# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

'''
Created on Jun 15, 2017

@author: shujon
'''

from __future__ import print_function
import logging
from datetime import datetime
import os
import argparse
import errno
import mxnet as mx
import numpy as np
import cv2
from scipy.io import savemat

######################################################################
# An adversarial variational autoencoder (VAE-GAN) implementation in MXNet,
# following the implementation at https://github.com/JeremyCCHsu/tf-vaegan
# of the paper: Larsen, Anders Boesen Lindbo, et al. "Autoencoding beyond pixels
# using a learned similarity metric." arXiv preprint arXiv:1512.09300 (2015).
######################################################################

@mx.init.register
class MyConstant(mx.init.Initializer):
    '''Custom constant initializer; registered here for completeness but not
    used elsewhere in this example
    '''
    def __init__(self, value):
        super(MyConstant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        arr[:] = mx.nd.array(self.value)

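# Illustrative use of MyConstant (an assumption for illustration; this variable
# is not part of the networks below):
#   w = mx.sym.Variable('const_w', shape=(1,), init=MyConstant(value=[0.5]))
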
def encoder(nef, z_dim, batch_size, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''The encoder is a CNN which takes a 32x32 image as input and
    generates a z_dim-dimensional shape embedding, drawn as a sample from a
    normal distribution with predicted mean and (log-)variance
    '''
    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    e1 = mx.sym.Convolution(data, name='enc1', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef, no_bias=no_bias)
    ebn1 = BatchNorm(e1, name='encbn1', fix_gamma=fix_gamma, eps=eps)
    eact1 = mx.sym.LeakyReLU(ebn1, name='encact1', act_type='leaky', slope=0.2)

    e2 = mx.sym.Convolution(eact1, name='enc2', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*2, no_bias=no_bias)
    ebn2 = BatchNorm(e2, name='encbn2', fix_gamma=fix_gamma, eps=eps)
    eact2 = mx.sym.LeakyReLU(ebn2, name='encact2', act_type='leaky', slope=0.2)

    e3 = mx.sym.Convolution(eact2, name='enc3', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*4, no_bias=no_bias)
    ebn3 = BatchNorm(e3, name='encbn3', fix_gamma=fix_gamma, eps=eps)
    eact3 = mx.sym.LeakyReLU(ebn3, name='encact3', act_type='leaky', slope=0.2)

    e4 = mx.sym.Convolution(eact3, name='enc4', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*8, no_bias=no_bias)
    ebn4 = BatchNorm(e4, name='encbn4', fix_gamma=fix_gamma, eps=eps)
    eact4 = mx.sym.LeakyReLU(ebn4, name='encact4', act_type='leaky', slope=0.2)

    eact4 = mx.sym.Flatten(eact4)

    z_mu = mx.sym.FullyConnected(eact4, num_hidden=z_dim, name="enc_mu")
    z_lv = mx.sym.FullyConnected(eact4, num_hidden=z_dim, name="enc_lv")

    # reparameterization trick: z = mu + exp(0.5 * log_var) * eps, eps ~ N(0, I)
    z = z_mu + mx.symbol.broadcast_mul(mx.symbol.exp(0.5 * z_lv), mx.symbol.random_normal(loc=0, scale=1, shape=(batch_size, z_dim)))

    return z_mu, z_lv, z

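# A quick shape sanity check for the encoder (illustrative; run it in an
# interactive session rather than at import time):
#   z_mu, z_lv, z = encoder(nef=64, z_dim=100, batch_size=8)
#   _, out_shapes, _ = mx.sym.Group([z_mu, z_lv, z]).infer_shape(data=(8, 1, 32, 32))
#   print(out_shapes)  # expected: [(8, 100), (8, 100), (8, 100)]
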
def generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=100, activation='sigmoid'):
    '''The generator is a CNN which takes a z_dim-dimensional embedding as input
    and reconstructs the image that was given to the encoder
    '''
    BatchNorm = mx.sym.BatchNorm
    rand = mx.sym.Variable('rand')

    rand = mx.sym.Reshape(rand, shape=(-1, z_dim, 1, 1))

    g1 = mx.sym.Deconvolution(rand, name='gen1', kernel=(5,5), stride=(2,2), target_shape=(2,2), num_filter=ngf*8, no_bias=no_bias)
    gbn1 = BatchNorm(g1, name='genbn1', fix_gamma=fix_gamma, eps=eps)
    gact1 = mx.sym.Activation(gbn1, name='genact1', act_type='relu')

    g2 = mx.sym.Deconvolution(gact1, name='gen2', kernel=(5,5), stride=(2,2), target_shape=(4,4), num_filter=ngf*4, no_bias=no_bias)
    gbn2 = BatchNorm(g2, name='genbn2', fix_gamma=fix_gamma, eps=eps)
    gact2 = mx.sym.Activation(gbn2, name='genact2', act_type='relu')

    g3 = mx.sym.Deconvolution(gact2, name='gen3', kernel=(5,5), stride=(2,2), target_shape=(8,8), num_filter=ngf*2, no_bias=no_bias)
    gbn3 = BatchNorm(g3, name='genbn3', fix_gamma=fix_gamma, eps=eps)
    gact3 = mx.sym.Activation(gbn3, name='genact3', act_type='relu')

    g4 = mx.sym.Deconvolution(gact3, name='gen4', kernel=(5,5), stride=(2,2), target_shape=(16,16), num_filter=ngf, no_bias=no_bias)
    gbn4 = BatchNorm(g4, name='genbn4', fix_gamma=fix_gamma, eps=eps)
    gact4 = mx.sym.Activation(gbn4, name='genact4', act_type='relu')

    g5 = mx.sym.Deconvolution(gact4, name='gen5', kernel=(5,5), stride=(2,2), target_shape=(32,32), num_filter=nc, no_bias=no_bias)
    gout = mx.sym.Activation(g5, name='genact5', act_type=activation)

    return gout

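# The mirrored sanity check for the generator (illustrative):
#   symG = generator(ngf=64, nc=1)
#   _, out_shapes, _ = symG.infer_shape(rand=(8, 100, 1, 1))
#   print(out_shapes)  # expected: [(8, 1, 32, 32)]
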
def discriminator1(ndf, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''First part of the discriminator. It takes a 32x32 image as input and
    outputs a convolutional feature map of shape (batch, ndf*4, 4, 4), which
    is needed to compute the layer loss'''
    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    d1 = mx.sym.Convolution(data, name='d1', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf, no_bias=no_bias)
    dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)

    d2 = mx.sym.Convolution(dact1, name='d2', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*2, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)

    d3 = mx.sym.Convolution(dact2, name='d3', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*4, no_bias=no_bias)
    dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
    dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)

    return dact3

def discriminator2(ndf, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''Second part of the discriminator. It takes the (ndf*4)x4x4 feature map
    produced by discriminator1 (256x4x4 with the default ndf=64) and computes
    the logistic loss for the real-vs-fake decision'''

    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    label = mx.sym.Variable('label')

    d4 = mx.sym.Convolution(data, name='d4', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*8, no_bias=no_bias)
    dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
    dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)

    h = mx.sym.Flatten(dact4)

    d5 = mx.sym.FullyConnected(h, num_hidden=1, name="d5")

    dloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')

    return dloss

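# discriminator1 and discriminator2 together form the full discriminator; in
# train() below they are chained with mx.mod.SequentialModule so the
# intermediate feature map can be tapped for the layer loss.
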
def GaussianLogDensity(x, mu, log_var, name='GaussianLogDensity', EPSILON=1e-6):
    '''Per-sample Gaussian log-density, used for the layer-wise loss:
    log N(x; mu, var) = -0.5 * (log(2*pi) + log_var + (x - mu)^2 / var),
    summed over the feature axis
    '''
    c = mx.sym.ones_like(log_var) * 2.0 * np.pi
    c = mx.symbol.log(c)
    var = mx.sym.exp(log_var)
    x_mu2 = mx.symbol.square(x - mu)
    x_mu2_over_var = mx.symbol.broadcast_div(x_mu2, var + EPSILON)
    log_prob = -0.5 * (c + log_var + x_mu2_over_var)
    log_prob = mx.symbol.sum(log_prob, axis=1, name=name)
    return log_prob

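# Sanity check (illustrative): at x = mu with log_var = 0, each of the D summed
# dimensions contributes -0.5 * log(2 * pi), so the result is -0.5 * D * log(2 * pi).
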
def DiscriminatorLayerLoss():
    '''Discriminator layer loss: the negative Gaussian log-density of the real
    image's discriminator features (`label`) under a unit-variance Gaussian
    centered at the generated image's features (`data`)
    '''

    data = mx.sym.Variable('data')

    label = mx.sym.Variable('label')

    data = mx.sym.Flatten(data)
    label = mx.sym.Flatten(label)

    # the real features are a target, not a variable to optimize
    label = mx.sym.BlockGrad(label)

    zeros = mx.sym.zeros_like(data)

    output = -GaussianLogDensity(label, data, zeros)

    dloss = mx.symbol.MakeLoss(mx.symbol.mean(output), name='lloss')

    return dloss

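# With log_var fixed at zero, the negative Gaussian log-density above is, up to
# an additive constant, half the squared error between the two feature maps, so
# this is effectively an MSE feature-matching loss.
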
def KLDivergenceLoss():
    '''KL divergence between the encoder's predicted Gaussian N(mu1, exp(lv1))
    and the standard-normal prior, written below for a general N(mu2, exp(lv2))
    '''

    data = mx.sym.Variable('data')
    mu1, lv1 = mx.sym.split(data, num_outputs=2, axis=0)
    mu2 = mx.sym.zeros_like(mu1)
    lv2 = mx.sym.zeros_like(lv1)

    v1 = mx.sym.exp(lv1)
    v2 = mx.sym.exp(lv2)
    mu_diff_sq = mx.sym.square(mu1 - mu2)
    dimwise_kld = .5 * ((lv2 - lv1) + mx.symbol.broadcast_div(v1, v2)
                        + mx.symbol.broadcast_div(mu_diff_sq, v2) - 1.)
    KL = mx.symbol.sum(dimwise_kld, axis=1)

    KLloss = mx.symbol.MakeLoss(mx.symbol.mean(KL), name='KLloss')
    return KLloss

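# With the standard-normal prior (mu2 = 0, lv2 = 0) the expression above reduces
# to the familiar closed form 0.5 * sum(exp(lv1) + mu1^2 - 1 - lv1).
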
def get_data(path, activation):
    '''Load the 32x32 grayscale images under `path`, scale them to match the
    generator's output activation, and return them (shuffled with a fixed
    seed) together with the corresponding file names
    '''
    data = []
    image_names = []
    for filename in os.listdir(path):
        img = cv2.imread(os.path.join(path, filename), cv2.IMREAD_GRAYSCALE)
        if img is not None:
            image_names.append(filename)
            data.append(img)

    data = np.asarray(data)

    if activation == 'sigmoid':
        # scale to [0, 1]
        data = data.astype(np.float32) / 255.0
    elif activation == 'tanh':
        # scale to [-1, 1]
        data = data.astype(np.float32) / (255.0 / 2) - 1.0

    data = data.reshape((data.shape[0], 1, data.shape[1], data.shape[2]))

    np.random.seed(1234)
    p = np.random.permutation(data.shape[0])
    X = data[p]
    # keep the file names aligned with the shuffled images
    image_names = [image_names[i] for i in p]

    return X, image_names

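# Example, using the default dataset path from parse_args() below:
#   X, names = get_data('datasets/caltech101/data/images32x32', 'sigmoid')
#   # X has shape (N, 1, 32, 32) with float32 values in [0, 1]
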
class RandIter(mx.io.DataIter):
    '''A data iterator that yields i.i.d. standard-normal samples, used as the
    random input of the generator
    '''
    def __init__(self, batch_size, ndim):
        self.batch_size = batch_size
        self.ndim = ndim
        self.provide_data = [('rand', (batch_size, ndim, 1, 1))]
        self.provide_label = []

    def iter_next(self):
        return True

    def getdata(self):
        return [mx.random.normal(0, 1.0, shape=(self.batch_size, self.ndim, 1, 1))]

def fill_buf(buf, i, img, shape):
    '''fill the ith grid cell of the buffer matrix with the values from img
    buf : buffer matrix
    i : serial number of the image in the 2D grid
    img : image data
    shape : (height, width, depth) of the image'''

    # number of grid cells per row/column (the grid side length is a multiple
    # of the individual image side length)
    m = buf.shape[0] // shape[0]

    sx = (i % m) * shape[1]
    sy = (i // m) * shape[0]
    buf[sy:sy + shape[0], sx:sx + shape[1], :] = img

def visual(title, X, activation):
    '''create a grid of images and save it as a single image
    title : grid image name
    X : array of images
    '''
    assert len(X.shape) == 4

    X = X.transpose((0, 2, 3, 1))
    if activation == 'sigmoid':
        X = np.clip(X * 255.0, 0, 255).astype(np.uint8)
    elif activation == 'tanh':
        X = np.clip((X + 1.0) * (255.0 / 2.0), 0, 255).astype(np.uint8)
    n = np.ceil(np.sqrt(X.shape[0]))
    buff = np.zeros((int(n * X.shape[1]), int(n * X.shape[2]), int(X.shape[3])), dtype=np.uint8)
    for i, img in enumerate(X):
        fill_buf(buff, i, img, X.shape[1:3])
    cv2.imwrite('%s.jpg' % title, buff)

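# For example, a batch X of 64 images shaped (64, 1, 32, 32) becomes an 8x8
# grid (ceil(sqrt(64)) = 8) written to '<title>.jpg'.
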
def train(dataset, nef, ndf, ngf, nc, batch_size, Z, lr, beta1, epsilon, ctx, check_point, g_dl_weight, output_path, checkpoint_path, data_path, activation, num_epoch, save_after_every, visualize_after_every, show_after_every):
    '''adversarial training of the VAE: the encoder/generator pair is trained
    jointly with the discriminator
    '''

    # encoder
    z_mu, z_lv, z = encoder(nef, Z, batch_size)
    symE = mx.sym.Group([z_mu, z_lv, z])

    # generator
    symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)

    # discriminator, split in two parts so the intermediate feature map can be
    # read out for the layer loss
    symD1 = discriminator1(ndf)
    symD2 = discriminator2(ndf)


    # ==============data==============
    X_train, _ = get_data(data_path, activation)
    train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size, shuffle=True)
    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    # =============module E=============
    modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
    modE.bind(data_shapes=train_iter.provide_data)
    modE.init_params(initializer=mx.init.Normal(0.02))
    modE.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods = [modE]

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=rand_iter.provide_data, inputs_need_grad=True)
    modG.init_params(initializer=mx.init.Normal(0.02))
    modG.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
        })
    mods.append(modG)

    # =============module D=============
    modD1 = mx.mod.Module(symD1, label_names=[], context=ctx)
    modD2 = mx.mod.Module(symD2, label_names=('label',), context=ctx)
    modD = mx.mod.SequentialModule()
    modD.add(modD1).add(modD2, take_labels=True, auto_wiring=True)
    modD.bind(data_shapes=train_iter.provide_data,
              label_shapes=[('label', (batch_size,))],
              inputs_need_grad=True)
    modD.init_params(initializer=mx.init.Normal(0.02))
    modD.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-3,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modD)


    # =============module DL=============
    symDL = DiscriminatorLayerLoss()
    modDL = mx.mod.Module(symbol=symDL, data_names=('data',), label_names=('label',), context=ctx)
    # the discriminator feature map has shape (batch, ndf*4, 4, 4); nef == ndf
    # is assumed here (both default to 64)
    modDL.bind(data_shapes=[('data', (batch_size, nef * 4, 4, 4))],
               label_shapes=[('label', (batch_size, nef * 4, 4, 4))],
               inputs_need_grad=True)
    modDL.init_params(initializer=mx.init.Normal(0.02))
    modDL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })

    # =============module KL=============
    symKL = KLDivergenceLoss()
    modKL = mx.mod.Module(symbol=symKL, data_names=('data',), label_names=None, context=ctx)
    # the input stacks mu and log-var along axis 0, hence batch_size*2 rows
    modKL.bind(data_shapes=[('data', (batch_size*2, Z))],
               inputs_need_grad=True)
    modKL.init_params(initializer=mx.init.Normal(0.02))
    modKL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modKL)

    def norm_stat(d):
        return mx.nd.norm(d)/np.sqrt(d.size)
    # set `mon` to an mx.mon.Monitor instance for debugging, e.g.:
    # mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
    mon = None
    if mon is not None:
        for mod in mods:
            mod.install_monitor(mon)

    def facc(label, pred):
        '''prediction accuracy of the discriminator
        '''
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        '''binary cross-entropy loss
        '''
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    def kldivergence(label, pred):
        '''KL divergence loss; `nElements` is read from the enclosing scope,
        where it is set to the batch size before this metric is updated
        '''
        mean, log_var = np.split(pred, 2, axis=0)
        var = np.exp(log_var)
        KLLoss = -0.5 * np.sum(1 + log_var - np.power(mean, 2) - var)
        KLLoss = KLLoss / nElements
        return KLLoss

    mG = mx.metric.CustomMetric(fentropy)
    mD = mx.metric.CustomMetric(fentropy)
    mE = mx.metric.CustomMetric(kldivergence)
    mACC = mx.metric.CustomMetric(facc)

    print('Training...')
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')

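    # Per batch, the schedule below is:
    #   1) update the discriminator on fake (prior sample), decoded, and real
    #      images, folding half of the fake and decoded gradients into the
    #      real-image gradients before a single update;
    #   2) update the generator twice, each time combining the GAN gradients
    #      (from the prior sample and the decoded sample) with the weighted
    #      discriminator-layer similarity gradient;
    #   3) update the encoder from the KL gradient on (mu, log_var) plus the
    #      gradient passed back through the generator.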
    # =============train===============
    for epoch in range(num_epoch):
        train_iter.reset()
        for t, batch in enumerate(train_iter):

            rbatch = rand_iter.next()

            if mon is not None:
                mon.tic()

            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()

            # update discriminator on fake images sampled from the prior
            label[:] = 0
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            # stash these gradients so they can be folded into the real-image
            # gradients before a single discriminator update
            gradD11 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD12 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])


            # update discriminator on decoded (reconstructed) images
            modE.forward(batch, is_train=True)
            mu, lv, z = modE.get_outputs()
            z = z.reshape((batch_size, Z, 1, 1))
            sample = mx.io.DataBatch([z], label=None, provide_data=[('rand', (batch_size, Z, 1, 1))])
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 0
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()

            gradD21 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD22 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on real images, folding in half of the fake
            # and decoded gradients so one update covers all three passes
            label[:] = 1
            batch.label = [label]
            modD.forward(batch, is_train=True)
            # keep the real-image feature maps as the target for the layer loss
            lx = [out.copyto(out.context) for out in modD1.get_outputs()]
            modD.backward()
            for gradsr, gradsf, gradsd in zip(modD1._exec_group.grad_arrays, gradD11, gradD21):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)
            for gradsr, gradsf, gradsd in zip(modD2._exec_group.grad_arrays, gradD12, gradD22):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)

            modD.update()
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update generator: GAN gradient from a prior sample ...
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            # ... GAN gradient from the decoded sample ...
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            # ... and the discriminator-layer (similarity) gradient
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            # combine the stored gradients in place; assigning to the slice
            # (rather than rebinding the loop variable) updates the arrays
            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())

            # repeat the same combined scheme for a second generator update
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            # combine in place, as in the first generator update above
            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())

            # forward pass for the encoder update: decode the sample and score
            # it with the discriminator-layer loss
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()

            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            DLloss = modDL.get_outputs()
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            # update encoder: KL gradient on (mu, log_var) plus the gradient
            # passed back through the generator
            nElements = batch_size
            modKL.forward(mx.io.DataBatch([mx.ndarray.concat(mu, lv, dim=0)]), is_train=True)
            KLloss = modKL.get_outputs()
            modKL.backward()
            gradKLLoss = modKL.get_input_grads()
            diffG = modG.get_input_grads()
            diffG = diffG[0].reshape((batch_size, Z))
            modE.backward(mx.ndarray.split(gradKLLoss[0], num_outputs=2, axis=0) + [diffG])
            modE.update()
            pred = mx.ndarray.concat(mu, lv, dim=0)
            mE.update([pred], [pred])
            if mon is not None:
                mon.toc_print()

            t += 1
            if t % show_after_every == 0:
                print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get(), mE.get(), KLloss[0].asnumpy(), DLloss[0].asnumpy())
                mACC.reset()
                mG.reset()
                mD.reset()
                mE.reset()

            if epoch % visualize_after_every == 0:
                visual(output_path + 'gout' + str(epoch), outG[0].asnumpy(), activation)
                visual(output_path + 'data' + str(epoch), batch.data[0].asnumpy(), activation)

        if check_point and epoch % save_after_every == 0:
            print('Saving...')
            modG.save_params(checkpoint_path + '/%s_G-%04d.params' % (dataset, epoch))
            modD.save_params(checkpoint_path + '/%s_D-%04d.params' % (dataset, epoch))
            modE.save_params(checkpoint_path + '/%s_E-%04d.params' % (dataset, epoch))

def test(nef, ngf, nc, batch_size, Z, ctx, pretrained_encoder_path, pretrained_generator_path, output_path, data_path, activation, save_embedding, embedding_path=''):
    '''Test the VAE with a pretrained encoder and generator.
    Keep the batch size at 1.'''
    # encoder
    z_mu, z_lv, z = encoder(nef, Z, batch_size)
    symE = mx.sym.Group([z_mu, z_lv, z])

    # generator
    symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)

    # ==============data==============
    X_test, image_names = get_data(data_path, activation)
    test_iter = mx.io.NDArrayIter(X_test, batch_size=batch_size, shuffle=False)

    # =============module E=============
    modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
    modE.bind(data_shapes=test_iter.provide_data)
    modE.load_params(pretrained_encoder_path)

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=[('rand', (1, Z, 1, 1))])
    modG.load_params(pretrained_generator_path)

    print('Testing...')

    # =============test===============
    test_iter.reset()
    for t, batch in enumerate(test_iter):

        # reconstruct each test image from the mean of its embedding
        modE.forward(batch, is_train=False)
        mu, lv, z = modE.get_outputs()
        mu = mu.reshape((batch_size, Z, 1, 1))
        sample = mx.io.DataBatch([mu], label=None, provide_data=[('rand', (batch_size, Z, 1, 1))])
        modG.forward(sample, is_train=False)
        outG = modG.get_outputs()

        visual(output_path + '/' + 'gout' + str(t), outG[0].asnumpy(), activation)
        visual(output_path + '/' + 'data' + str(t), batch.data[0].asnumpy(), activation)
        image_name = os.path.splitext(image_names[t])[0]

        if save_embedding:
            savemat(embedding_path + '/' + image_name + '.mat', {'embedding': mu.asnumpy()})

def create_and_validate_dir(data_dir):
    '''Create the directory if it does not already exist
    '''
    if data_dir != "":
        if not os.path.exists(data_dir):
            try:
                logging.info('create directory %s', data_dir)
                os.makedirs(data_dir)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise OSError('failed to create ' + data_dir)


def parse_args():
    '''Parse command-line arguments
    '''
    parser = argparse.ArgumentParser(description='Train and Test an Adversarial Variational Autoencoder')

    parser.add_argument('--train', help='train the network', action='store_true')
    parser.add_argument('--test', help='test the network', action='store_true')
    parser.add_argument('--save_embedding', help='saves the shape embedding of each input image', action='store_true')
    parser.add_argument('--dataset', help='dataset name', default='caltech', type=str)
    parser.add_argument('--activation', help='activation, i.e. sigmoid or tanh', default='sigmoid', type=str)
    parser.add_argument('--training_data_path', help='training data path', default='datasets/caltech101/data/images32x32', type=str)
    parser.add_argument('--testing_data_path', help='testing data path', default='datasets/caltech101/test_data', type=str)
    parser.add_argument('--pretrained_encoder_path', help='pretrained encoder model path', default='checkpoints32x32_sigmoid/caltech_E-0045.params', type=str)
    parser.add_argument('--pretrained_generator_path', help='pretrained generator model path', default='checkpoints32x32_sigmoid/caltech_G-0045.params', type=str)
    parser.add_argument('--output_path', help='output path for the generated images', default='outputs32x32_sigmoid', type=str)
    parser.add_argument('--embedding_path', help='output path for the generated embeddings', default='outputs32x32_sigmoid', type=str)
    parser.add_argument('--checkpoint_path', help='checkpoint saving path', default='checkpoints32x32_sigmoid', type=str)
    parser.add_argument('--nef', help='encoder filter count in the first layer', default=64, type=int)
    parser.add_argument('--ndf', help='discriminator filter count in the first layer', default=64, type=int)
    parser.add_argument('--ngf', help='generator filter count in the second last layer', default=64, type=int)
    parser.add_argument('--nc', help='generator filter count in the last layer, i.e. 1 for grayscale images, 3 for RGB images', default=1, type=int)
    parser.add_argument('--batch_size', help='batch size; keep it 1 during testing', default=64, type=int)
    parser.add_argument('--Z', help='embedding size', default=100, type=int)
    parser.add_argument('--lr', help='learning rate', default=0.0002, type=float)
    parser.add_argument('--beta1', help='beta1 for adam optimizer', default=0.5, type=float)
    parser.add_argument('--epsilon', help='epsilon for adam optimizer', default=1e-5, type=float)
    parser.add_argument('--g_dl_weight', help='discriminator layer loss weight', default=1e-1, type=float)
    parser.add_argument('--gpu', help='gpu index', default=0, type=int)
    parser.add_argument('--use_cpu', help='use cpu', action='store_true')
    parser.add_argument('--num_epoch', help='maximum number of epochs', default=45, type=int)
    parser.add_argument('--save_after_every', help='save a checkpoint every this many epochs', default=5, type=int)
    parser.add_argument('--visualize_after_every', help='save output images every this many epochs', default=5, type=int)
    parser.add_argument('--show_after_every', help='show metrics every this many iterations', default=10, type=int)

    args = parser.parse_args()
    return args

def main():
    args = parse_args()

    if args.test:
        if not os.path.exists(args.testing_data_path):
            raise OSError("Provided Testing Path: {} does not exist".format(args.testing_data_path))
        if not os.path.exists(args.checkpoint_path):
            raise OSError("Provided Checkpoint Path: {} does not exist".format(args.checkpoint_path))

    create_and_validate_dir(args.checkpoint_path)
    create_and_validate_dir(args.output_path)

    # device context
    if args.use_cpu:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(args.gpu)

    # checkpoint saving flag
    check_point = True

    if args.train:
        train(args.dataset, args.nef, args.ndf, args.ngf, args.nc, args.batch_size, args.Z, args.lr, args.beta1, args.epsilon, ctx, check_point, args.g_dl_weight, args.output_path, args.checkpoint_path, args.training_data_path, args.activation, args.num_epoch, args.save_after_every, args.visualize_after_every, args.show_after_every)

    if args.test:
        test(args.nef, args.ngf, args.nc, 1, args.Z, ctx, args.pretrained_encoder_path, args.pretrained_generator_path, args.output_path, args.testing_data_path, args.activation, args.save_embedding, args.embedding_path)

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    main()