<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Conditional Generative Adversarial Network with MXNet R package

This tutorial shows how to build and train a Conditional Generative Adversarial Network (CGAN) on MNIST images.

## How GANs work
A Generative Adversarial Network (GAN) simultaneously trains two models: a generator that learns to output fake samples mimicking an unknown data distribution, and a discriminator that learns to distinguish fake samples from real ones.

The CGAN is a conditional variant of the GAN in which the generator is instructed to generate a sample with specific characteristics rather than a generic sample from the full distribution. The condition can be the label associated with an image, as in this tutorial, or a more detailed description, as shown in the example below:

![Conditional GAN network](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gan/CGAN_mnist_R/dcgan_network.jpg)

Image credit: [Scott Reed](https://github.com/reedscot/icml2016)

## Initial setup

The following packages are needed to run the tutorial:

```r
require("imager")
require("dplyr")
require("readr")
require("mxnet")
```

The full demo consists of the following two scripts:

- `CGAN_mnist_setup.R`: prepares the data and defines the model structure
- `CGAN_train.R`: runs the training

## Data preparation

The MNIST dataset is available [here](https://www.kaggle.com/c/digit-recognizer/data). Once `train.csv` is downloaded into the `data/` folder, it can be imported into R:

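```r
# train.csv: the first column is the digit label, the remaining 784 columns are pixel values
train <- read_csv('data/train.csv')
train <- data.matrix(train)

# Scale pixels from [0, 255] to [-1, 1] and transpose so that each column is one image
train_data <- train[,-1]
train_data <- t(train_data/255*2-1)
train_label <- as.integer(train[,1])

# Reshape to 28 x 28 x 1 channel x number of images
dim(train_data) <- c(28, 28, 1, ncol(train_data))
```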
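
The scaling and reshaping can be checked without downloading the data by applying the same steps to a small synthetic matrix laid out like `train.csv` (label in column 1, 784 pixel columns). This is a minimal sketch: `fake_train` and `fake_data` are hypothetical names that are not part of the tutorial scripts.

```r
# Hypothetical stand-in for two rows of train.csv
# (column 1 = label, columns 2..785 = pixel values in [0, 255])
set.seed(1)
fake_train <- cbind(label = c(3, 7),
                    matrix(sample(0:255, 2 * 784, replace = TRUE), nrow = 2))
fake_train[1, 2] <- 0    # force a black pixel in the first image
fake_train[1, 3] <- 255  # force a white pixel in the first image

# Same preprocessing as above: scale to [-1, 1], one image per column, then reshape
fake_data <- t(fake_train[, -1] / 255 * 2 - 1)
dim(fake_data) <- c(28, 28, 1, ncol(fake_data))

print(dim(fake_data))    # 28 28 1 2
print(range(fake_data))  # -1 1
```

Pixel value 0 maps to -1 and 255 maps to 1, and the pixels of each image fill its 28x28 slice in R's column-major order.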

Custom iterators are defined in `iterators.R` and imported by `CGAN_mnist_setup.R`.

## Generator
The generator is a network that creates novel samples (MNIST images) from two inputs:

- Noise vector
- Labels defining the object condition (which digit to produce)

The noise vector provides the building blocks to the generator model, which learns how to structure that noise into a sample. The `mx.symbol.Deconvolution` operator is used to upsample the initial input from a 1x1 shape up to a 28x28 image.
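
Each `mx.symbol.Deconvolution` layer (a transposed convolution) enlarges the spatial size following the standard formula out = (in - 1) * stride - 2 * pad + kernel, ignoring dilation and the optional `adj` output adjustment. The helper `deconv_out` and the kernel/stride/pad values below are hypothetical, chosen only to show one way of going from 1x1 to 28x28; the actual layer settings live in `CGAN_mnist_setup.R` and may differ.

```r
# Spatial output size of a transposed convolution (no dilation, no adj)
deconv_out <- function(input, kernel, stride = 1, pad = 0) {
  (input - 1) * stride - 2 * pad + kernel
}

# Hypothetical stack of four deconvolution layers
layers <- list(
  list(kernel = 4, stride = 1, pad = 0),  # 1  -> 4
  list(kernel = 3, stride = 2, pad = 1),  # 4  -> 7
  list(kernel = 4, stride = 2, pad = 1),  # 7  -> 14
  list(kernel = 4, stride = 2, pad = 1)   # 14 -> 28
)

size <- 1
sizes <- size
for (l in layers) {
  size <- deconv_out(size, l$kernel, l$stride, l$pad)
  sizes <- c(sizes, size)
}
print(sizes)  # 1 4 7 14 28
```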

The label of the digit to generate is provided as a one-hot encoding of the label index, appended to the random noise. For MNIST, each index 0-9 is therefore converted into a binary vector of length 10. More complex applications would require embeddings rather than a simple one-hot encoding to represent the condition.
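
A base-R sketch of such a generator input: each digit label becomes a length-10 one-hot vector stacked under a noise vector, with one column per sample as in the data preparation step. The noise size of 100, the Gaussian noise and the helper `one_hot` are assumptions for illustration; the actual sizes and tensor layout are defined in `CGAN_mnist_setup.R`.

```r
# One-hot encode digit labels (0-9): label k sets row k + 1 of its column
one_hot <- function(labels, depth = 10) {
  m <- matrix(0, nrow = depth, ncol = length(labels))
  m[cbind(labels + 1, seq_along(labels))] <- 1
  m
}

set.seed(42)
batch_size <- 4
noise_dim <- 100                 # hypothetical noise size
labels <- c(0, 3, 7, 9)          # digits to generate, one per sample

noise <- matrix(rnorm(noise_dim * batch_size), nrow = noise_dim)
condition <- one_hot(labels)

# One column per sample: noise on top, one-hot condition appended below
G_input <- rbind(noise, condition)
print(dim(G_input))              # 110 4
```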

## Discriminator
The discriminator attempts to distinguish between fake samples produced by the generator and real ones sampled from the MNIST training data.

In a conditional GAN, the labels associated with the samples are also provided to the discriminator. In this demo, this information is again provided as a one-hot encoding of the label, broadcast to match the image dimensions (10 -> 28x28x10).
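
The broadcast can be sketched in base R: each of the 10 channels of a 28x28x10 array is filled with the matching entry of the one-hot vector, so exactly one channel is all ones. `broadcast_label` is a hypothetical helper, not a function from the scripts; the channel-last layout is assumed to follow the `c(28, 28, 1, n)` image layout used above.

```r
# Broadcast a one-hot label (length 10) to a 28 x 28 x 10 array:
# channel k is constant and equal to the k-th entry of the one-hot vector
broadcast_label <- function(label, height = 28, width = 28, depth = 10) {
  oh <- numeric(depth)
  oh[label + 1] <- 1
  array(rep(oh, each = height * width), dim = c(height, width, depth))
}

digit_map <- broadcast_label(7)
print(dim(digit_map))         # 28 28 10
print(sum(digit_map[, , 8]))  # 784: the channel for digit 7 is all ones
```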

## Training logic
Training the discriminator is the most straightforward part: the target is a binary real/fake response (label 0 for fake samples, 1 for real ones), and the resulting loss is backpropagated through the CNN. It can therefore be understood as a simple binary classification problem.
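
For reference, binary cross-entropy, the standard loss for this kind of binary classifier, takes a few lines of base R. It is an assumption that the discriminator in `CGAN_mnist_setup.R` optimizes exactly this loss; the probabilities below are made up to illustrate the label convention (0 for fake, 1 for real) used in the training loop.

```r
# Binary cross-entropy for discriminator outputs p (probability of "real") and targets y
bce <- function(p, y) {
  eps <- 1e-12                       # keep log() away from 0
  p <- pmin(pmax(p, eps), 1 - eps)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

batch_size <- 4
p_fake <- c(0.1, 0.2, 0.3, 0.4)      # hypothetical D outputs on generated images
p_real <- c(0.9, 0.8, 0.7, 0.6)      # hypothetical D outputs on real images

loss_fake <- bce(p_fake, rep(0, batch_size))  # fake samples are labelled 0
loss_real <- bce(p_real, rep(1, batch_size))  # real samples are labelled 1
print(c(loss_fake, loss_real))
```

Because the fake probabilities mirror the real ones around 0.5, both losses come out equal: the discriminator is equally confident on both halves of the batch.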

```r
### Train loop on fake: generated samples are labelled 0
mx.exec.update.arg.arrays(exec_D, arg.arrays =
  list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(0, batch_size))),
  match.name=TRUE)
mx.exec.forward(exec_D, is.train=TRUE)
mx.exec.backward(exec_D)
update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)

### Train loop on real: MNIST samples are labelled 1
mx.exec.update.arg.arrays(exec_D, arg.arrays =
  list(data=D_data_real, digit=D_digit_real, label=mx.nd.array(rep(1, batch_size))),
  match.name=TRUE)
mx.exec.forward(exec_D, is.train=TRUE)
mx.exec.backward(exec_D)
update_args_D <- updater_D(weight = exec_D$ref.arg.arrays, grad = exec_D$ref.grad.arrays)
mx.exec.update.arg.arrays(exec_D, update_args_D, skip.null=TRUE)
```

The generator loss comes from backpropagating the discriminator loss into the generated output. By labelling the generated samples as real (label 1) when they are fed to the discriminator, the loss backpropagated through the discriminator tells the generator how to adapt its parameters to trick the discriminator into believing the fake samples are real.

This requires backpropagating the gradients all the way to the input data of the discriminator, whereas this input gradient is typically ignored when training a vanilla feedforward network.
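
The idea can be illustrated without MXNet on a toy problem: a logistic "discriminator" with frozen weights, and a "generator" whose output is simply its parameter vector. With the flipped label 1, the loss is -log(D(x)), whose gradient with respect to the input x is (D(x) - 1) * w. All names here (`D`, `w`, `g`, `lr`) are hypothetical and only mimic the roles of the tutorial's executors.

```r
# Toy logistic discriminator D(x) = sigmoid(sum(w * x) + b) with frozen weights
sigmoid <- function(z) 1 / (1 + exp(-z))
set.seed(7)
w <- rnorm(5)
b <- 0
D <- function(x) sigmoid(sum(w * x) + b)

# Toy generator: its output x is its parameter vector g
g <- rnorm(5)
lr <- 0.5

d_before <- D(g)
for (step in 1:20) {
  # Gradient of -log(D(x)) (fake sample labelled as real) w.r.t. the discriminator INPUT
  grad_x <- (D(g) - 1) * w
  # x = g, so dx/dg is the identity and grad_x is also the generator gradient
  g <- g - lr * grad_x
}
d_after <- D(g)
print(c(d_before, d_after))  # the discriminator's belief that the fake is real increases
```

In the tutorial, the same two steps are carried out by `mx.exec.backward(exec_D_back)`, which writes the input gradient to `exec_D_back$ref.grad.arrays$data`, and by `mx.exec.backward(exec_G, out_grads=D_grads)`, which continues the chain rule through the generator.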

```r
### Update Generator weights - use a separate executor for writing data gradients
# grad.reqs = "write" for every argument, including the input `data`,
# so that the gradient w.r.t. the fake input images is computed
exec_D_back <- mxnet:::mx.symbol.bind(symbol = D_sym,
  arg.arrays = exec_D$arg.arrays,
  aux.arrays = exec_D$aux.arrays, grad.reqs = rep("write", length(exec_D$arg.arrays)),
  ctx = devices)

# Feed the fake samples labelled as real (1)
mx.exec.update.arg.arrays(exec_D_back, arg.arrays =
  list(data=D_data_fake, digit=D_digit_fake, label=mx.nd.array(rep(1, batch_size))),
  match.name=TRUE)
mx.exec.forward(exec_D_back, is.train=TRUE)
mx.exec.backward(exec_D_back)
# Gradient w.r.t. the discriminator input, backpropagated into the generator
D_grads <- exec_D_back$ref.grad.arrays$data
mx.exec.backward(exec_G, out_grads=D_grads)

# Only the generator weights are updated in this step
update_args_G <- updater_G(weight = exec_G$ref.arg.arrays, grad = exec_G$ref.grad.arrays)
mx.exec.update.arg.arrays(exec_G, update_args_G, skip.null=TRUE)
```

The above training steps are executed in the `CGAN_train.R` script.

## Monitor the training

During training, the [imager](http://dahtah.github.io/imager/) package facilitates visual quality assessment of the fake samples.

```r
# Plot 9 generated samples at the first iteration, then every 100 iterations
if (iteration == 1 || iteration %% 100 == 0) {
  par(mfrow=c(3,3), mar=c(0.1,0.1,0.1,0.1))
  for (i in 1:9) {
    img <- as.array(exec_G$ref.outputs$G_sym_output)[,,,i]
    plot(as.cimg(img), axes=FALSE)
  }
}
```

Below are samples obtained at different stages of the training.

Starting from noise:

![Generated samples at iteration 1](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gan/CGAN_mnist_R/CGAN_1.png)

Slowly getting it - iteration 200:

![Generated samples at iteration 200](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gan/CGAN_mnist_R/CGAN_200.png)

Generating the specified digits on demand - iteration 2400:

![Generated samples at iteration 2400](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gan/CGAN_mnist_R/CGAN_2400.png)

## Inference

Once the model is trained, synthetic images of the desired digit can be produced by feeding the generator with fixed labels rather than the randomly generated ones used during training.

Here we generate fake images of the digit `9`:

```r
# Condition every sample in the batch on the digit 9
digit <- mx.nd.array(rep(9, times=batch_size))
data <- mx.nd.one.hot(indices = digit, depth = 10)
data <- mx.nd.reshape(data = data, shape = c(1,1,-1, batch_size))

# Bind the generator for inference only (no gradients) and load the trained parameters
exec_G <- mx.simple.bind(symbol = G_sym, data=data_shape_G, ctx = devices, grad.req = "null")
mx.exec.update.arg.arrays(exec_G, G_arg_params, match.name=TRUE)
mx.exec.update.arg.arrays(exec_G, list(data=data), match.name=TRUE)
mx.exec.update.aux.arrays(exec_G, G_aux_params, match.name=TRUE)

mx.exec.forward(exec_G, is.train=FALSE)
```

![Generated samples of the digit 9](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gan/CGAN_mnist_R/CGAN_infer_9.png)
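
To check the layout of the condition tensor built with `mx.nd.one.hot` and `mx.nd.reshape`, it can be mimicked in base R as one 1x1x10 one-hot column per sample. This sketch assumes that, with MXNet R's column-major convention, the one-hot result is a 10 x `batch_size` array before reshaping; `make_condition` and the batch size of 64 are hypothetical, and the second call shows how a batch cycling through all ten digits could be requested instead.

```r
# Build a 1 x 1 x 10 x batch_size condition array: one one-hot column per sample
make_condition <- function(digits, depth = 10) {
  oh <- matrix(0, nrow = depth, ncol = length(digits))
  oh[cbind(digits + 1, seq_along(digits))] <- 1
  dim(oh) <- c(1, 1, depth, length(digits))
  oh
}

batch_size <- 64
cond_9 <- make_condition(rep(9, times = batch_size))  # every sample conditioned on 9
print(dim(cond_9))                                    # 1 1 10 64

cond_all <- make_condition(rep(0:9, length.out = batch_size))  # cycle through 0-9
print(which(cond_all[1, 1, , 3] == 1) - 1)                     # digit of sample 3: 2
```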

Further details of the CGAN methodology can be found in the paper [Generative Adversarial Text to Image Synthesis](https://arxiv.org/abs/1605.05396).