| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# InfoGAN Tutorial" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "This tutorial walks through an implementation of InfoGAN as described in [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657).\n", | |
| "\n", | |
| "To learn more about InfoGAN, see this [Medium post](https://medium.com/p/dd710852db46). To learn more about GANs generally, see [this one](https://medium.com/@awjuliani/generative-adversarial-networks-explained-with-a-classic-spongebob-squarepants-episode-54deab2fce39#.692jyamki)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#Import the libraries we will need.\n", | |
| "import tensorflow as tf\n", | |
| "import numpy as np\n", | |
| "import input_data\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import tensorflow.contrib.slim as slim\n", | |
| "import os\n", | |
| "import scipy.misc\n", | |
| "import scipy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Load the MNIST dataset." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Helper Functions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#This function performs a leaky relu activation, which is needed for the discriminator network.\n", | |
| "def lrelu(x, leak=0.2, name=\"lrelu\"):\n", | |
| "    with tf.variable_scope(name):\n", | |
| "        f1 = 0.5 * (1 + leak)\n", | |
| "        f2 = 0.5 * (1 - leak)\n", | |
| "        return f1 * x + f2 * abs(x)\n", | |
| "    \n", | |
| "#The below functions are taken from carpedm20's implementation https://github.com/carpedm20/DCGAN-tensorflow\n", | |
| "#They allow for saving sample images from the generator to follow progress\n", | |
| "def save_images(images, size, image_path):\n", | |
| "    return imsave(inverse_transform(images), size, image_path)\n", | |
| "\n", | |
| "def imsave(images, size, path):\n", | |
| "    return scipy.misc.imsave(path, merge(images, size))\n", | |
| "\n", | |
| "def inverse_transform(images):\n", | |
| "    return (images+1.)/2.\n", | |
| "\n", | |
| "def merge(images, size):\n", | |
| "    h, w = images.shape[1], images.shape[2]\n", | |
| "    img = np.zeros((h * size[0], w * size[1]))\n", | |
| "\n", | |
| "    for idx, image in enumerate(images):\n", | |
| "        i = idx % size[1]  #Column of this image in the grid\n", | |
| "        j = idx // size[1]  #Row of this image in the grid (integer division)\n", | |
| "        img[j*h:j*h+h, i*w:i*w+w] = image\n", | |
| "\n", | |
| "    return img" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Defining the Adversarial Networks" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Generator Network\n", | |
| "\n", | |
| "The generator takes a vector of random numbers and transforms it into a 32x32 image. Each layer in the network involves a convolution, batch normalization, and rectified nonlinearity, followed by a depth-to-space upsampling step (`tf.depth_to_space`) that takes the place of a strided transpose convolution. TensorFlow's slim library allows us to easily define each of these layers." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def generator(z):\n", | |
| "    \n", | |
| "    zP = slim.fully_connected(z,4*4*256,normalizer_fn=slim.batch_norm,\\\n", | |
| "        activation_fn=tf.nn.relu,scope='g_project',weights_initializer=initializer)\n", | |
| "    zCon = tf.reshape(zP,[-1,4,4,256])\n", | |
| "    \n", | |
| "    gen1 = slim.convolution2d(\\\n", | |
| "        zCon,num_outputs=128,kernel_size=[3,3],\\\n", | |
| "        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", | |
| "        activation_fn=tf.nn.relu,scope='g_conv1', weights_initializer=initializer)\n", | |
| "    gen1 = tf.depth_to_space(gen1,2)\n", | |
| "    \n", | |
| "    gen2 = slim.convolution2d(\\\n", | |
| "        gen1,num_outputs=64,kernel_size=[3,3],\\\n", | |
| "        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", | |
| "        activation_fn=tf.nn.relu,scope='g_conv2', weights_initializer=initializer)\n", | |
| "    gen2 = tf.depth_to_space(gen2,2)\n", | |
| "    \n", | |
| "    gen3 = slim.convolution2d(\\\n", | |
| "        gen2,num_outputs=32,kernel_size=[3,3],\\\n", | |
| "        padding=\"SAME\",normalizer_fn=slim.batch_norm,\\\n", | |
| "        activation_fn=tf.nn.relu,scope='g_conv3', weights_initializer=initializer)\n", | |
| "    gen3 = tf.depth_to_space(gen3,2)\n", | |
| "    \n", | |
| "    g_out = slim.convolution2d(\\\n", | |
| "        gen3,num_outputs=1,kernel_size=[32,32],padding=\"SAME\",\\\n", | |
| "        biases_initializer=None,activation_fn=tf.nn.tanh,\\\n", | |
| "        scope='g_out', weights_initializer=initializer)\n", | |
| "    \n", | |
| "    return g_out" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Discriminator Network\n", | |
| "The discriminator network takes as input a 32x32 image and transforms it into a single-valued probability that the image came from real-world data. The same network also returns the Q-network outputs, which estimate the categorical and continuous latent variables. Again we use tf.slim to define the convolutional layers, batch normalization, and weight initialization." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def discriminator(bottom, cat_list,conts, reuse=False):\n", | |
| "    \n", | |
| "    dis1 = slim.convolution2d(bottom,32,[3,3],padding=\"SAME\",\\\n", | |
| "        biases_initializer=None,activation_fn=lrelu,\\\n", | |
| "        reuse=reuse,scope='d_conv1',weights_initializer=initializer)\n", | |
| "    dis1 = tf.space_to_depth(dis1,2)\n", | |
| "    \n", | |
| "    dis2 = slim.convolution2d(dis1,64,[3,3],padding=\"SAME\",\\\n", | |
| "        normalizer_fn=slim.batch_norm,activation_fn=lrelu,\\\n", | |
| "        reuse=reuse,scope='d_conv2', weights_initializer=initializer)\n", | |
| "    dis2 = tf.space_to_depth(dis2,2)\n", | |
| "    \n", | |
| "    dis3 = slim.convolution2d(dis2,128,[3,3],padding=\"SAME\",\\\n", | |
| "        normalizer_fn=slim.batch_norm,activation_fn=lrelu,\\\n", | |
| "        reuse=reuse,scope='d_conv3',weights_initializer=initializer)\n", | |
| "    dis3 = tf.space_to_depth(dis3,2)\n", | |
| "    \n", | |
| "    dis4 = slim.fully_connected(slim.flatten(dis3),1024,activation_fn=lrelu,\\\n", | |
| "        reuse=reuse,scope='d_fc1', weights_initializer=initializer)\n", | |
| "    \n", | |
| "    d_out = slim.fully_connected(dis4,1,activation_fn=tf.nn.sigmoid,\\\n", | |
| "        reuse=reuse,scope='d_out', weights_initializer=initializer)\n", | |
| "    \n", | |
| "    q_a = slim.fully_connected(dis4,128,normalizer_fn=slim.batch_norm,\\\n", | |
| "        reuse=reuse,scope='q_fc1', weights_initializer=initializer)\n", | |
| "    \n", | |
| "    \n", | |
| "    ## Here we define the unique layers used for the q-network. The number of outputs depends on the number of \n", | |
| "    ## latent variables we choose to define.\n", | |
| "    q_cat_outs = []\n", | |
| "    for idx,var in enumerate(cat_list):\n", | |
| "        q_outA = slim.fully_connected(q_a,var,activation_fn=tf.nn.softmax,\\\n", | |
| "            reuse=reuse,scope='q_out_cat_'+str(idx), weights_initializer=initializer)\n", | |
| "        q_cat_outs.append(q_outA)\n", | |
| "    \n", | |
| "    q_cont_outs = None\n", | |
| "    if conts > 0:\n", | |
| "        q_cont_outs = slim.fully_connected(q_a,conts,activation_fn=tf.nn.tanh,\\\n", | |
| "            reuse=reuse,scope='q_out_cont_'+str(conts), weights_initializer=initializer)\n", | |
| "    \n", | |
| "    return d_out,q_cat_outs,q_cont_outs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Connecting them together" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "tf.reset_default_graph()\n", | |
| "\n", | |
| "z_size = 64 #Size of initial z vector used for generator.\n", | |
| "\n", | |
| "# Define latent variables.\n", | |
| "categorical_list = [10] # Each entry in this list defines a categorical variable of a specific size.\n", | |
| "number_continuous = 2 # The number of continuous variables.\n", | |
| "\n", | |
| "#This initializer is used to initialize all the weights of the network.\n", | |
| "initializer = tf.truncated_normal_initializer(stddev=0.02)\n", | |
| "\n", | |
| "#These placeholders are used for input into the generator and discriminator, respectively.\n", | |
| "z_in = tf.placeholder(shape=[None,z_size],dtype=tf.float32) #Random vector\n", | |
| "real_in = tf.placeholder(shape=[None,32,32,1],dtype=tf.float32) #Real images\n", | |
| "\n", | |
| "#These placeholders load the latent variables.\n", | |
| "latent_cat_in = tf.placeholder(shape=[None,len(categorical_list)],dtype=tf.int32)\n", | |
| "latent_cat_list = tf.split(1,len(categorical_list),latent_cat_in)\n", | |
| "latent_cont_in = tf.placeholder(shape=[None,number_continuous],dtype=tf.float32)\n", | |
| "\n", | |
| "oh_list = []\n", | |
| "for idx,var in enumerate(categorical_list):\n", | |
| "    latent_oh = tf.one_hot(tf.reshape(latent_cat_list[idx],[-1]),var)\n", | |
| "    oh_list.append(latent_oh)\n", | |
| "\n", | |
| "#Concatenate all c and z variables.\n", | |
| "z_lats = oh_list[:]\n", | |
| "z_lats.append(z_in)\n", | |
| "z_lats.append(latent_cont_in)\n", | |
| "z_lat = tf.concat(1,z_lats)\n", | |
| "\n", | |
| "\n", | |
| "Gz = generator(z_lat) #Generates images from random z vectors\n", | |
| "Dx,_,_ = discriminator(real_in,categorical_list,number_continuous) #Produces probabilities for real images\n", | |
| "Dg,QgCat,QgCont = discriminator(Gz,categorical_list,number_continuous,reuse=True) #Produces probabilities for generator images\n", | |
| "\n", | |
| "#These functions together define the optimization objective of the GAN.\n", | |
| "d_loss = -tf.reduce_mean(tf.log(Dx) + tf.log(1.-Dg)) #This optimizes the discriminator.\n", | |
| "g_loss = -tf.reduce_mean(tf.log((Dg/(1-Dg)))) #KL Divergence optimizer\n", | |
| "\n", | |
| "#Combine losses for each of the categorical variables.\n", | |
| "cat_losses = []\n", | |
| "for idx,latent_var in enumerate(oh_list):\n", | |
| "    cat_loss = -tf.reduce_sum(latent_var*tf.log(QgCat[idx]),reduction_indices=1)\n", | |
| "    cat_losses.append(cat_loss)\n", | |
| "    \n", | |
| "#Combine losses for each of the continuous variables.\n", | |
| "if number_continuous > 0:\n", | |
| "    q_cont_loss = tf.reduce_sum(0.5 * tf.square(latent_cont_in - QgCont),reduction_indices=1)\n", | |
| "else:\n", | |
| "    q_cont_loss = tf.constant(0.0)\n", | |
| "\n", | |
| "q_cont_loss = tf.reduce_mean(q_cont_loss)\n", | |
| "q_cat_loss = tf.reduce_mean(cat_losses)\n", | |
| "q_loss = tf.add(q_cat_loss,q_cont_loss)\n", | |
| "tvars = tf.trainable_variables()\n", | |
| "\n", | |
| "#The below code is responsible for applying gradient descent to update the GAN.\n", | |
| "trainerD = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5)\n", | |
| "trainerG = tf.train.AdamOptimizer(learning_rate=0.002,beta1=0.5)\n", | |
| "trainerQ = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5)\n", | |
| "d_grads = trainerD.compute_gradients(d_loss,tvars[9:-2-((number_continuous>0)*2)-(len(categorical_list)*2)]) #Only update the weights for the discriminator network.\n", | |
| "g_grads = trainerG.compute_gradients(g_loss, tvars[0:9]) #Only update the weights for the generator network.\n", | |
| "q_grads = trainerG.compute_gradients(q_loss, tvars) \n", | |
| "\n", | |
| "update_D = trainerD.apply_gradients(d_grads)\n", | |
| "update_G = trainerG.apply_gradients(g_grads)\n", | |
| "update_Q = trainerQ.apply_gradients(q_grads)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "source": [ | |
| "## Training the network\n", | |
| "Now that we have fully defined our network, it is time to train it!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "batch_size = 64 #Size of image batch to apply at each iteration.\n", | |
| "iterations = 500000 #Total number of iterations to use.\n", | |
| "sample_directory = './figsTut' #Directory to save sample images from generator in.\n", | |
| "model_directory = './models' #Directory to save trained model to.\n", | |
| "\n", | |
| "init = tf.initialize_all_variables()\n", | |
| "saver = tf.train.Saver()\n", | |
| "with tf.Session() as sess: \n", | |
| "    sess.run(init)\n", | |
| "    for i in range(iterations):\n", | |
| "        zs = np.random.uniform(-1.0,1.0,size=[batch_size,z_size]).astype(np.float32) #Generate a random z batch\n", | |
| "        lcat = np.random.randint(0,10,[batch_size,len(categorical_list)]) #Generate random c batch\n", | |
| "        lcont = np.random.uniform(-1,1,[batch_size,number_continuous]) #Generate random continuous latent batch\n", | |
| "        \n", | |
| "        xs,_ = mnist.train.next_batch(batch_size) #Draw a sample batch from MNIST dataset.\n", | |
| "        xs = (np.reshape(xs,[batch_size,28,28,1]) - 0.5) * 2.0 #Transform it to be between -1 and 1\n", | |
| "        xs = np.lib.pad(xs, ((0,0),(2,2),(2,2),(0,0)),'constant', constant_values=(-1, -1)) #Pad the images so they are 32x32\n", | |
| "        \n", | |
| "        _,dLoss = sess.run([update_D,d_loss],feed_dict={z_in:zs,real_in:xs,latent_cat_in:lcat,latent_cont_in:lcont}) #Update the discriminator\n", | |
| "        _,gLoss = sess.run([update_G,g_loss],feed_dict={z_in:zs,latent_cat_in:lcat,latent_cont_in:lcont}) #Update the generator.\n", | |
| "        _,qLoss,qK,qC = sess.run([update_Q,q_loss,q_cont_loss,q_cat_loss],feed_dict={z_in:zs,latent_cat_in:lcat,latent_cont_in:lcont}) #Update to optimize mutual information.\n", | |
| "        if i % 100 == 0:\n", | |
| "            print \"Gen Loss: \" + str(gLoss) + \" Disc Loss: \" + str(dLoss) + \" Q Losses: \" + str([qK,qC])\n", | |
| "            z_sample = np.random.uniform(-1.0,1.0,size=[100,z_size]).astype(np.float32) #Generate another z batch\n", | |
| "            lcat_sample = np.reshape(np.array([e for e in range(10) for _ in range(10)]),[100,1])\n", | |
| "            a = np.reshape(np.array([[(e/4.5 - 1.)] for e in range(10) for _ in range(10)]),[10,10]).T\n", | |
| "            b = np.reshape(a,[100,1])\n", | |
| "            c = np.zeros_like(b)\n", | |
| "            lcont_sample = np.hstack([b,c])\n", | |
| "            samples = sess.run(Gz,feed_dict={z_in:z_sample,latent_cat_in:lcat_sample,latent_cont_in:lcont_sample}) #Use new z to get sample images from generator.\n", | |
| "            if not os.path.exists(sample_directory):\n", | |
| "                os.makedirs(sample_directory)\n", | |
| "            #Save sample generator images for viewing training progress.\n", | |
| "            save_images(np.reshape(samples[0:100],[100,32,32]),[10,10],sample_directory+'/fig'+str(i)+'.png')\n", | |
| "        if i % 1000 == 0 and i != 0:\n", | |
| "            if not os.path.exists(model_directory):\n", | |
| "                os.makedirs(model_directory)\n", | |
| "            saver.save(sess,model_directory+'/model-'+str(i)+'.cptk')\n", | |
| "            print \"Saved Model\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Using a trained network\n", | |
| "Once we have a trained model saved, we may want to use it to generate new images, and explore the representation it has learned." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "sample_directory = './figsTut' #Directory to save sample images from generator in.\n", | |
| "model_directory = './models' #Directory to load trained model from.\n", | |
| "\n", | |
| "init = tf.initialize_all_variables()\n", | |
| "saver = tf.train.Saver()\n", | |
| "with tf.Session() as sess: \n", | |
| "    sess.run(init)\n", | |
| "    #Reload the model.\n", | |
| "    print 'Loading Model...'\n", | |
| "    ckpt = tf.train.get_checkpoint_state(model_directory)\n", | |
| "    saver.restore(sess,ckpt.model_checkpoint_path)\n", | |
| "    \n", | |
| "    z_sample = np.random.uniform(-1.0,1.0,size=[100,z_size]).astype(np.float32) #Generate another z batch\n", | |
| "    lcat_sample = np.reshape(np.array([e for e in range(10) for _ in range(10)]),[100,1])\n", | |
| "    a = np.reshape(np.array([[(e/4.5 - 1.)] for e in range(10) for _ in range(10)]),[10,10]).T\n", | |
| "    b = np.reshape(a,[100,1])\n", | |
| "    c = np.zeros_like(b)\n", | |
| "    lcont_sample = np.hstack([b,c])\n", | |
| "    samples = sess.run(Gz,feed_dict={z_in:z_sample,latent_cat_in:lcat_sample,latent_cont_in:lcont_sample}) #Use new z to get sample images from generator.\n", | |
| "    if not os.path.exists(sample_directory):\n", | |
| "        os.makedirs(sample_directory)\n", | |
| "    #Save sample generator images from the restored model.\n", | |
| "    save_images(np.reshape(samples[0:100],[100,32,32]),[10,10],sample_directory+'/fig_test'+'.png')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
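The sampling code in the training and inference cells builds a 10x10 grid of latent codes. As a minimal NumPy-only sketch (the helper name `make_latent_grid` is hypothetical and not part of the notebook), the grid assigns one categorical value per row and sweeps the first continuous code from -1 to 1 across the columns, holding the second continuous code at 0:

```python
import numpy as np

def make_latent_grid(n=10):
    """Build latent codes for an n x n sample grid (hypothetical helper)."""
    # Sample k gets category k // n, so every image in a grid row shares one category.
    lcat = np.repeat(np.arange(n), n).reshape(n * n, 1)
    # First continuous code: n evenly spaced values in [-1, 1], cycling across columns.
    sweep = np.linspace(-1.0, 1.0, n)
    first = np.tile(sweep, n).reshape(n * n, 1)
    # Second continuous code is held at zero, as in the notebook.
    second = np.zeros_like(first)
    lcont = np.hstack([first, second])
    return lcat, lcont

lcat_sample, lcont_sample = make_latent_grid()
print(lcat_sample.shape, lcont_sample.shape)
```

With this layout, reading a saved sample image left to right should show the effect of the first continuous code, and reading top to bottom should show the effect of the categorical code, provided the model has learned to use them.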
Love the minimalism of this example.
One question -- it seems here that G is only being optimized to fool D, but not at all wrt maximizing mutual information via Q. Is this the case, or is G getting gradients from somewhere I'm missing?
For reference, in the infoGAN paper it suggests (right under eq (5)) "LI can be maximized w.r.t. Q directly and w.r.t. G via the reparametrization trick" and lines 82 and 97 of https://github.com/openai/InfoGAN/blob/master/infogan/algos/infogan_trainer.py are adding Q losses to the generator loss, which presumably propagates the Q errors through G as well.
Also it looks like the continuous Q loss is a reconstruction error, not conditional entropy? Where are the log-likelihoods as in https://github.com/openai/InfoGAN/blob/master/infogan/algos/infogan_trainer.py#L87
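On the continuous loss: one way to read the `0.5 * tf.square(latent_cont_in - QgCont)` term, which is an interpretation rather than anything the notebook states, is as the negative log-likelihood of a Gaussian whose variance is fixed at 1. Under that assumption the two differ only by the constant 0.5*log(2*pi) per code dimension, so their gradients match. A NumPy sketch (function names are hypothetical):

```python
import numpy as np

def gaussian_nll(c, mean, std=1.0):
    """Negative log-likelihood of c under N(mean, std^2), summed over code dimensions."""
    var = std ** 2
    return np.sum(0.5 * np.log(2.0 * np.pi * var) + (c - mean) ** 2 / (2.0 * var), axis=1)

def squared_error(c, mean):
    """The reconstruction-style loss used in the notebook."""
    return np.sum(0.5 * (c - mean) ** 2, axis=1)

rng = np.random.RandomState(0)
c = rng.uniform(-1, 1, size=(4, 2))       # sampled continuous codes
q_mean = rng.uniform(-1, 1, size=(4, 2))  # stand-in for the Q-network output

# The gap is the same for every sample: (number of dims) * 0.5 * log(2*pi).
gap = gaussian_nll(c, q_mean) - squared_error(c, q_mean)
print(gap)
```

If Q also predicted a standard deviation, the log term would no longer be constant and would have to stay in the loss.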
Hi Arthur,
I'm interested whether this:
q_grads = trainerG.compute_gradients(q_loss, tvars)
is supposed to be:
q_grads = trainerQ.compute_gradients(q_loss, tvars)
(If not then) could you please explain the mechanism to implement the mutual information factor into GAN in your implementation a little bit?
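On the mechanism: as far as the code shows, `q_loss` is differentiated with respect to all of `tvars`, so `update_Q` would also adjust the generator and discriminator weights, since `QgCat` and `QgCont` are computed from `Gz`. `compute_gradients` only builds (gradient, variable) pairs, and `update_Q` applies them through `trainerQ`, so the step should use `trainerQ`'s learning rate either way. The slices passed to the other two trainers can be illustrated in plain Python. The 9/9/6 variable counts below are an assumption from reading the layer definitions (with slim's batch norm adding one variable per normalized layer), not something the notebook prints, and `split_variables` is a hypothetical helper:

```python
def split_variables(tvars, categorical_list, number_continuous, n_gen=9):
    """Mirror the notebook's slicing of tf.trainable_variables() (hypothetical helper)."""
    # Variables created last: q_fc1 (2), each categorical head (2), continuous head (2 if present).
    q_tail = 2 + (number_continuous > 0) * 2 + len(categorical_list) * 2
    g_vars = tvars[0:n_gen]        # differentiated against g_loss
    d_vars = tvars[n_gen:-q_tail]  # differentiated against d_loss
    q_vars = tvars                 # q_loss is differentiated w.r.t. everything
    return g_vars, d_vars, q_vars

# Stand-in labels instead of real tf.Variable objects (assumed counts).
tvars = ["G"] * 9 + ["D"] * 9 + ["Q"] * 6
g_vars, d_vars, q_vars = split_variables(tvars, categorical_list=[10], number_continuous=2)
print(len(g_vars), len(d_vars), len(q_vars))
```

Slicing by position like this depends on the order in which layers are created, so adding or removing a layer would require updating the indices.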