# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

'''
Created on Jun 15, 2017

@author: shujon
'''

from __future__ import print_function
import logging
from datetime import datetime
import os
import argparse
import errno
import mxnet as mx
import numpy as np
import cv2
from scipy.io import savemat

######################################################################
# An adversarial variational autoencoder implementation in MXNet,
# following the implementation at https://github.com/JeremyCCHsu/tf-vaegan
# of the paper `Larsen, Anders Boesen Lindbo, et al. "Autoencoding beyond
# pixels using a learned similarity metric." arXiv preprint
# arXiv:1512.09300 (2015).`
######################################################################

@mx.init.register
class MyConstant(mx.init.Initializer):
    '''Constant initializer; kept for reference but not used in the code.
    '''
    def __init__(self, value):
        super(MyConstant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        arr[:] = mx.nd.array(self.value)

def encoder(nef, z_dim, batch_size, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''The encoder is a CNN that takes a 32x32 image as input and generates
    a z_dim-dimensional shape embedding, sampled from a normal distribution
    parameterized by the predicted mean and log-variance
    (the reparameterization trick).
    '''
    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    e1 = mx.sym.Convolution(data, name='enc1', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef, no_bias=no_bias)
    ebn1 = BatchNorm(e1, name='encbn1', fix_gamma=fix_gamma, eps=eps)
    eact1 = mx.sym.LeakyReLU(ebn1, name='encact1', act_type='leaky', slope=0.2)

    e2 = mx.sym.Convolution(eact1, name='enc2', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*2, no_bias=no_bias)
    ebn2 = BatchNorm(e2, name='encbn2', fix_gamma=fix_gamma, eps=eps)
    eact2 = mx.sym.LeakyReLU(ebn2, name='encact2', act_type='leaky', slope=0.2)

    e3 = mx.sym.Convolution(eact2, name='enc3', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*4, no_bias=no_bias)
    ebn3 = BatchNorm(e3, name='encbn3', fix_gamma=fix_gamma, eps=eps)
    eact3 = mx.sym.LeakyReLU(ebn3, name='encact3', act_type='leaky', slope=0.2)

    e4 = mx.sym.Convolution(eact3, name='enc4', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=nef*8, no_bias=no_bias)
    ebn4 = BatchNorm(e4, name='encbn4', fix_gamma=fix_gamma, eps=eps)
    eact4 = mx.sym.LeakyReLU(ebn4, name='encact4', act_type='leaky', slope=0.2)

    eact4 = mx.sym.Flatten(eact4)

    z_mu = mx.sym.FullyConnected(eact4, num_hidden=z_dim, name='enc_mu')
    z_lv = mx.sym.FullyConnected(eact4, num_hidden=z_dim, name='enc_lv')

    # reparameterization: z = mu + exp(0.5 * log_var) * epsilon, epsilon ~ N(0, 1)
    z = z_mu + mx.symbol.broadcast_mul(mx.symbol.exp(0.5*z_lv), mx.symbol.random_normal(loc=0, scale=1, shape=(batch_size, z_dim)))

    return z_mu, z_lv, z
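
# A minimal sketch of how the encoder's output shapes can be inspected,
# assuming a 1-channel 32x32 input and the default nef=64, z_dim=100;
# the batch size of 8 here is illustrative only:
#
#   z_mu, z_lv, z = encoder(nef=64, z_dim=100, batch_size=8)
#   sym = mx.sym.Group([z_mu, z_lv, z])
#   _, out_shapes, _ = sym.infer_shape(data=(8, 1, 32, 32))
#   print(out_shapes)  # three (8, 100) outputs: mean, log-variance, sample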

def generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=100, activation='sigmoid'):
    '''The generator is a CNN that takes a z_dim-dimensional embedding as input
    and reconstructs the image that was given to the encoder.
    '''
    BatchNorm = mx.sym.BatchNorm
    rand = mx.sym.Variable('rand')

    rand = mx.sym.Reshape(rand, shape=(-1, z_dim, 1, 1))

    g1 = mx.sym.Deconvolution(rand, name='gen1', kernel=(5,5), stride=(2,2), target_shape=(2,2), num_filter=ngf*8, no_bias=no_bias)
    gbn1 = BatchNorm(g1, name='genbn1', fix_gamma=fix_gamma, eps=eps)
    gact1 = mx.sym.Activation(gbn1, name='genact1', act_type='relu')

    g2 = mx.sym.Deconvolution(gact1, name='gen2', kernel=(5,5), stride=(2,2), target_shape=(4,4), num_filter=ngf*4, no_bias=no_bias)
    gbn2 = BatchNorm(g2, name='genbn2', fix_gamma=fix_gamma, eps=eps)
    gact2 = mx.sym.Activation(gbn2, name='genact2', act_type='relu')

    g3 = mx.sym.Deconvolution(gact2, name='gen3', kernel=(5,5), stride=(2,2), target_shape=(8,8), num_filter=ngf*2, no_bias=no_bias)
    gbn3 = BatchNorm(g3, name='genbn3', fix_gamma=fix_gamma, eps=eps)
    gact3 = mx.sym.Activation(gbn3, name='genact3', act_type='relu')

    g4 = mx.sym.Deconvolution(gact3, name='gen4', kernel=(5,5), stride=(2,2), target_shape=(16,16), num_filter=ngf, no_bias=no_bias)
    gbn4 = BatchNorm(g4, name='genbn4', fix_gamma=fix_gamma, eps=eps)
    gact4 = mx.sym.Activation(gbn4, name='genact4', act_type='relu')

    g5 = mx.sym.Deconvolution(gact4, name='gen5', kernel=(5,5), stride=(2,2), target_shape=(32,32), num_filter=nc, no_bias=no_bias)
    gout = mx.sym.Activation(g5, name='genact5', act_type=activation)

    return gout
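
# A minimal sketch of the generator's upsampling path, assuming the default
# ngf=64, nc=1, z_dim=100; the batch size of 8 is illustrative only:
#
#   gout = generator(ngf=64, nc=1)
#   _, out_shapes, _ = gout.infer_shape(rand=(8, 100, 1, 1))
#   print(out_shapes)  # [(8, 1, 32, 32)]: 1x1 -> 2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32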

def discriminator1(ndf, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''First part of the discriminator, which takes a 32x32 image as input
    and outputs a convolutional feature map; this intermediate feature map
    is required to calculate the layer loss.
    '''
    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    d1 = mx.sym.Convolution(data, name='d1', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf, no_bias=no_bias)
    dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)

    d2 = mx.sym.Convolution(dact1, name='d2', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*2, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)

    d3 = mx.sym.Convolution(dact2, name='d3', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*4, no_bias=no_bias)
    dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
    dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)

    return dact3

def discriminator2(ndf, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    '''Second part of the discriminator, which takes the (ndf*4)x4x4 feature
    map from discriminator1 (256x4x4 with the default ndf=64) and computes the
    loss based on whether the input image was real or fake.
    '''
    BatchNorm = mx.sym.BatchNorm

    data = mx.sym.Variable('data')

    label = mx.sym.Variable('label')

    d4 = mx.sym.Convolution(data, name='d4', kernel=(5,5), stride=(2,2), pad=(2,2), num_filter=ndf*8, no_bias=no_bias)
    dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
    dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)

    h = mx.sym.Flatten(dact4)

    d5 = mx.sym.FullyConnected(h, num_hidden=1, name='d5')

    dloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')

    return dloss
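
# A minimal sketch of how the two discriminator halves compose, assuming the
# default ndf=64; the batch size of 8 is illustrative only:
#
#   feat = discriminator1(ndf=64)
#   _, feat_shapes, _ = feat.infer_shape(data=(8, 1, 32, 32))
#   print(feat_shapes)  # [(8, 256, 4, 4)]: the feature map used for the layer loss
#   dloss = discriminator2(ndf=64)
#   _, loss_shapes, _ = dloss.infer_shape(data=(8, 256, 4, 4), label=(8,))
#   print(loss_shapes)  # [(8, 1)]: per-image real/fake probability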

def GaussianLogDensity(x, mu, log_var, name='GaussianLogDensity', EPSILON=1e-6):
    '''Element-wise log-density of a diagonal Gaussian, used for the layer-wise loss:
    log N(x; mu, var) = -0.5 * (log(2*pi) + log_var + (x - mu)^2 / var),
    summed over the feature axis.
    '''
    c = mx.sym.ones_like(log_var) * 2.0 * 3.1416   # 2*pi (approximate)
    c = mx.symbol.log(c)
    var = mx.sym.exp(log_var)
    x_mu2 = mx.symbol.square(x - mu)
    x_mu2_over_var = mx.symbol.broadcast_div(x_mu2, var + EPSILON)
    log_prob = -0.5 * (c + log_var + x_mu2_over_var)
    log_prob = mx.symbol.sum(log_prob, axis=1, name=name)
    return log_prob

def DiscriminatorLayerLoss():
    '''Discriminator layer loss: the negative Gaussian log-likelihood of the
    real image's discriminator features (label) under a unit-variance Gaussian
    centred on the reconstructed image's features (data).
    '''
    data = mx.sym.Variable('data')

    label = mx.sym.Variable('label')

    data = mx.sym.Flatten(data)
    label = mx.sym.Flatten(label)

    label = mx.sym.BlockGrad(label)

    zeros = mx.sym.zeros_like(data)

    output = -GaussianLogDensity(label, data, zeros)

    dloss = mx.symbol.MakeLoss(mx.symbol.mean(output), name='lloss')

    return dloss

def KLDivergenceLoss():
    '''KL divergence between the encoder's predicted Gaussian N(mu1, var1)
    and the standard normal prior N(mu2=0, var2=1), computed dimension-wise:
    KL = 0.5 * sum(log(var2/var1) + var1/var2 + (mu1 - mu2)^2/var2 - 1)
    '''
    data = mx.sym.Variable('data')
    mu1, lv1 = mx.sym.split(data, num_outputs=2, axis=0)
    mu2 = mx.sym.zeros_like(mu1)
    lv2 = mx.sym.zeros_like(lv1)

    v1 = mx.sym.exp(lv1)
    v2 = mx.sym.exp(lv2)
    mu_diff_sq = mx.sym.square(mu1 - mu2)
    dimwise_kld = .5 * (
        (lv2 - lv1) + mx.symbol.broadcast_div(v1, v2) + mx.symbol.broadcast_div(mu_diff_sq, v2) - 1.)
    KL = mx.symbol.sum(dimwise_kld, axis=1)

    KLloss = mx.symbol.MakeLoss(mx.symbol.mean(KL), name='KLloss')
    return KLloss
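
# A minimal NumPy cross-check of the dimension-wise KL formula above against
# the familiar VAE form -0.5 * (1 + log_var - mu^2 - var); the sample values
# are illustrative only:
#
#   mu, lv = np.array([0.3, -0.7]), np.array([0.1, -0.2])
#   v = np.exp(lv)
#   kl = 0.5 * np.sum(-lv + v + mu**2 - 1.0)      # formula above with N(0, 1) prior
#   kl_vae = -0.5 * np.sum(1 + lv - mu**2 - v)    # the equivalent VAE form
#   assert np.isclose(kl, kl_vae)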

def get_data(path, activation):
    '''Load the dataset as a NumPy array, scaled to match the range of the
    generator's output activation.
    '''
    data = []
    image_names = []
    for filename in os.listdir(path):
        img = cv2.imread(os.path.join(path, filename), cv2.IMREAD_GRAYSCALE)
        if img is not None:
            image_names.append(filename)
            data.append(img)

    data = np.asarray(data)

    if activation == 'sigmoid':
        # scale to [0, 1] to match the sigmoid output range
        data = data.astype(np.float32)/(255.0)
    elif activation == 'tanh':
        # scale to [-1, 1] to match the tanh output range
        data = data.astype(np.float32)/(255.0/2) - 1.0

    data = data.reshape((data.shape[0], 1, data.shape[1], data.shape[2]))

    np.random.seed(1234)
    p = np.random.permutation(data.shape[0])
    X = data[p]
    # keep the image names aligned with the shuffled data
    image_names = [image_names[i] for i in p]

    return X, image_names

class RandIter(mx.io.DataIter):
    '''A data iterator that returns batches of random normal vectors,
    used as the prior samples fed to the generator.
    '''
    def __init__(self, batch_size, ndim):
        self.batch_size = batch_size
        self.ndim = ndim
        self.provide_data = [('rand', (batch_size, ndim, 1, 1))]
        self.provide_label = []

    def iter_next(self):
        return True

    def getdata(self):
        return [mx.random.normal(0, 1.0, shape=(self.batch_size, self.ndim, 1, 1))]

def fill_buf(buf, i, img, shape):
    '''Fill the i-th grid cell of the buffer matrix with the values from img.
    buf   : buffer matrix
    i     : serial number of the image in the 2D grid
    img   : image data
    shape : (height, width) of the image
    '''
    # the grid is m cells wide; the grid height is a multiple of the image height
    m = buf.shape[0] // shape[0]

    sx = int((i % m) * shape[1])
    sy = int((i // m) * shape[0])
    buf[sy:sy+shape[0], sx:sx+shape[1], :] = img

def visual(title, X, activation):
    '''Arrange a batch of images into a grid and save it as a single image.
    title : grid image name
    X     : array of images
    '''
    assert len(X.shape) == 4

    X = X.transpose((0, 2, 3, 1))
    if activation == 'sigmoid':
        X = np.clip(X * 255.0, 0, 255).astype(np.uint8)
    elif activation == 'tanh':
        X = np.clip((X + 1.0) * (255.0/2.0), 0, 255).astype(np.uint8)
    n = np.ceil(np.sqrt(X.shape[0]))
    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
    for i, img in enumerate(X):
        fill_buf(buff, i, img, X.shape[1:3])
    cv2.imwrite('%s.jpg' % title, buff)
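
# A minimal sketch of the grid layout produced by visual(); the array below is
# illustrative only (ten 1x32x32 images land on a 4x4 grid, since
# ceil(sqrt(10)) = 4, and the grid is written to 'demo.jpg'):
#
#   imgs = np.random.rand(10, 1, 32, 32).astype(np.float32)
#   visual('demo', imgs, 'sigmoid')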

def train(dataset, nef, ndf, ngf, nc, batch_size, Z, lr, beta1, epsilon, ctx, check_point, g_dl_weight, output_path, checkpoint_path, data_path, activation, num_epoch, save_after_every, visualize_after_every, show_after_every):
    '''Adversarial training of the VAE.
    '''

    # encoder
    z_mu, z_lv, z = encoder(nef, Z, batch_size)
    symE = mx.sym.Group([z_mu, z_lv, z])

    # generator
    symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)

    # discriminator
    h = discriminator1(ndf)
    dloss = discriminator2(ndf)
    symD1 = h
    symD2 = dloss

    # ==============data==============
    X_train, _ = get_data(data_path, activation)
    train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size, shuffle=True)
    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    # =============module E=============
    modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
    modE.bind(data_shapes=train_iter.provide_data)
    modE.init_params(initializer=mx.init.Normal(0.02))
    modE.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods = [modE]

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=rand_iter.provide_data, inputs_need_grad=True)
    modG.init_params(initializer=mx.init.Normal(0.02))
    modG.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
        })
    mods.append(modG)

    # =============module D=============
    modD1 = mx.mod.Module(symD1, label_names=[], context=ctx)
    modD2 = mx.mod.Module(symD2, label_names=('label',), context=ctx)
    modD = mx.mod.SequentialModule()
    modD.add(modD1).add(modD2, take_labels=True, auto_wiring=True)
    modD.bind(data_shapes=train_iter.provide_data,
              label_shapes=[('label', (batch_size,))],
              inputs_need_grad=True)
    modD.init_params(initializer=mx.init.Normal(0.02))
    modD.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-3,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modD)

    # =============module DL=============
    symDL = DiscriminatorLayerLoss()
    modDL = mx.mod.Module(symbol=symDL, data_names=('data',), label_names=('label',), context=ctx)
    # FIXME: the feature map from discriminator1 has ndf*4 channels and 4x4
    # spatial size, so this shape is only correct while nef == ndf
    modDL.bind(data_shapes=[('data', (batch_size, nef * 4, 4, 4))],
               label_shapes=[('label', (batch_size, nef * 4, 4, 4))],
               inputs_need_grad=True)
    modDL.init_params(initializer=mx.init.Normal(0.02))
    modDL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })

    # =============module KL=============
    symKL = KLDivergenceLoss()
    modKL = mx.mod.Module(symbol=symKL, data_names=('data',), label_names=None, context=ctx)
    modKL.bind(data_shapes=[('data', (batch_size*2, Z))],
               inputs_need_grad=True)
    modKL.init_params(initializer=mx.init.Normal(0.02))
    modKL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modKL)

    def norm_stat(d):
        return mx.nd.norm(d)/np.sqrt(d.size)
    # to debug gradient norms, set mon to
    # mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
    mon = None
    if mon is not None:
        for mod in mods:
            mod.install_monitor(mon)

    def facc(label, pred):
        '''Prediction accuracy.
        '''
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        '''Binary cross-entropy loss.
        '''
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    def kldivergence(label, pred):
        '''KL divergence loss metric; nElements is assigned in the training
        loop below before this metric is first updated.
        '''
        mean, log_var = np.split(pred, 2, axis=0)
        var = np.exp(log_var)
        KLLoss = -0.5 * np.sum(1 + log_var - np.power(mean, 2) - var)
        KLLoss = KLLoss / nElements
        return KLLoss

    mG = mx.metric.CustomMetric(fentropy)
    mD = mx.metric.CustomMetric(fentropy)
    mE = mx.metric.CustomMetric(kldivergence)
    mACC = mx.metric.CustomMetric(facc)

    print('Training...')
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')

    # =============train===============
    for epoch in range(num_epoch):
        train_iter.reset()
        for t, batch in enumerate(train_iter):

            rbatch = rand_iter.next()

            if mon is not None:
                mon.tic()

            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()

            # update discriminator on fake (prior) samples
            label[:] = 0
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            gradD11 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD12 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on decoded (posterior) samples
            modE.forward(batch, is_train=True)
            mu, lv, z = modE.get_outputs()
            z = z.reshape((batch_size, Z, 1, 1))
            sample = mx.io.DataBatch([z], label=None, provide_data=[('rand', (batch_size, Z, 1, 1))])
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 0
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()

            gradD21 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD22 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on real images, averaging in the gradients
            # saved from the two fake passes before the update
            label[:] = 1
            batch.label = [label]
            modD.forward(batch, is_train=True)
            lx = [out.copyto(out.context) for out in modD1.get_outputs()]
            modD.backward()
            for gradsr, gradsf, gradsd in zip(modD1._exec_group.grad_arrays, gradD11, gradD21):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)
            for gradsr, gradsf, gradsd in zip(modD2._exec_group.grad_arrays, gradD12, gradD22):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)

            modD.update()
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # first generator update: GAN gradients from the prior sample ...
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            # ... GAN gradients from the decoded (posterior) sample ...
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            # ... and layer-loss gradients, blended with the GAN gradients
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    # blend in place; a plain assignment would only rebind the loop variable
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())

            # second generator update on the same batch (the generator is
            # updated twice per discriminator update)
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    # blend in place, as above
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())
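
            # update the encoder: reconstruct from the posterior sample, push
            # the layer loss back through discriminator1 and the generator to
            # obtain the gradient w.r.t. z, then combine it with the KL
            # divergence gradient before modE.backward()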
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()

            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            DLloss = modDL.get_outputs()
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            # KL divergence gradient for the encoder
            nElements = batch_size
            modKL.forward(mx.io.DataBatch([mx.ndarray.concat(mu, lv, dim=0)]), is_train=True)
            KLloss = modKL.get_outputs()
            modKL.backward()
            gradKLLoss = modKL.get_input_grads()
            diffG = modG.get_input_grads()
            diffG = diffG[0].reshape((batch_size, Z))
            modE.backward(mx.ndarray.split(gradKLLoss[0], num_outputs=2, axis=0) + [diffG])
            modE.update()
            pred = mx.ndarray.concat(mu, lv, dim=0)
            mE.update([pred], [pred])
            if mon is not None:
                mon.toc_print()

            t += 1
            if t % show_after_every == 0:
                print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get(), mE.get(), KLloss[0].asnumpy(), DLloss[0].asnumpy())
                mACC.reset()
                mG.reset()
                mD.reset()
                mE.reset()

        if epoch % visualize_after_every == 0:
            visual(output_path + 'gout' + str(epoch), outG[0].asnumpy(), activation)
            visual(output_path + 'data' + str(epoch), batch.data[0].asnumpy(), activation)

        if check_point and epoch % save_after_every == 0:
            print('Saving...')
            modG.save_params(checkpoint_path + '/%s_G-%04d.params' % (dataset, epoch))
            modD.save_params(checkpoint_path + '/%s_D-%04d.params' % (dataset, epoch))
            modE.save_params(checkpoint_path + '/%s_E-%04d.params' % (dataset, epoch))

def test(nef, ngf, nc, batch_size, Z, ctx, pretrained_encoder_path, pretrained_generator_path, output_path, data_path, activation, save_embedding, embedding_path=''):
    '''Test the VAE with a pretrained encoder and generator.
    Keep the batch size at 1.
    '''
    # encoder
    z_mu, z_lv, z = encoder(nef, Z, batch_size)
    symE = mx.sym.Group([z_mu, z_lv, z])

    # generator
    symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)

    # ==============data==============
    X_test, image_names = get_data(data_path, activation)
    test_iter = mx.io.NDArrayIter(X_test, batch_size=batch_size, shuffle=False)

    # =============module E=============
    modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
    modE.bind(data_shapes=test_iter.provide_data)
    modE.load_params(pretrained_encoder_path)

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=[('rand', (1, Z, 1, 1))])
    modG.load_params(pretrained_generator_path)

    print('Testing...')

    # =============test===============
    test_iter.reset()
    for t, batch in enumerate(test_iter):

        # reconstruct the image from the predicted mean embedding
        modE.forward(batch, is_train=False)
        mu, lv, z = modE.get_outputs()
        mu = mu.reshape((batch_size, Z, 1, 1))
        sample = mx.io.DataBatch([mu], label=None, provide_data=[('rand', (batch_size, Z, 1, 1))])
        modG.forward(sample, is_train=False)
        outG = modG.get_outputs()

        visual(output_path + '/' + 'gout' + str(t), outG[0].asnumpy(), activation)
        visual(output_path + '/' + 'data' + str(t), batch.data[0].asnumpy(), activation)
        image_name = image_names[t].split('.')[0]

        if save_embedding:
            savemat(embedding_path + '/' + image_name + '.mat', {'embedding': mu.asnumpy()})
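
# A minimal sketch of running inference programmatically rather than through
# the CLI, using the default checkpoint and data paths from parse_args() below;
# the checkpoints must already exist:
#
#   test(nef=64, ngf=64, nc=1, batch_size=1, Z=100, ctx=mx.cpu(),
#        pretrained_encoder_path='checkpoints32x32_sigmoid/caltech_E-0045.params',
#        pretrained_generator_path='checkpoints32x32_sigmoid/caltech_G-0045.params',
#        output_path='outputs32x32_sigmoid', data_path='datasets/caltech101/test_data',
#        activation='sigmoid', save_embedding=False)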

def create_and_validate_dir(data_dir):
    '''Create the directory if it does not already exist.
    '''
    if data_dir != "":
        if not os.path.exists(data_dir):
            try:
                logging.info('create directory %s', data_dir)
                os.makedirs(data_dir)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise OSError('failed to create ' + data_dir)


def parse_args():
    '''Parse command line arguments.
    '''
    parser = argparse.ArgumentParser(description='Train and Test an Adversarial Variational Encoder')

    parser.add_argument('--train', help='train the network', action='store_true')
    parser.add_argument('--test', help='test the network', action='store_true')
    parser.add_argument('--save_embedding', help='saves the shape embedding of each input image', action='store_true')
    parser.add_argument('--dataset', help='dataset name', default='caltech', type=str)
    parser.add_argument('--activation', help='activation i.e. sigmoid or tanh', default='sigmoid', type=str)
    parser.add_argument('--training_data_path', help='training data path', default='datasets/caltech101/data/images32x32', type=str)
    parser.add_argument('--testing_data_path', help='testing data path', default='datasets/caltech101/test_data', type=str)
    parser.add_argument('--pretrained_encoder_path', help='pretrained encoder model path', default='checkpoints32x32_sigmoid/caltech_E-0045.params', type=str)
    parser.add_argument('--pretrained_generator_path', help='pretrained generator model path', default='checkpoints32x32_sigmoid/caltech_G-0045.params', type=str)
    parser.add_argument('--output_path', help='output path for the generated images', default='outputs32x32_sigmoid', type=str)
    parser.add_argument('--embedding_path', help='output path for the generated embeddings', default='outputs32x32_sigmoid', type=str)
    parser.add_argument('--checkpoint_path', help='checkpoint saving path', default='checkpoints32x32_sigmoid', type=str)
    parser.add_argument('--nef', help='encoder filter count in the first layer', default=64, type=int)
    parser.add_argument('--ndf', help='discriminator filter count in the first layer', default=64, type=int)
    parser.add_argument('--ngf', help='generator filter count in the second last layer', default=64, type=int)
    parser.add_argument('--nc', help='generator filter count in the last layer i.e. 1 for grayscale image, 3 for RGB image', default=1, type=int)
    parser.add_argument('--batch_size', help='batch size, keep it 1 during testing', default=64, type=int)
    parser.add_argument('--Z', help='embedding size', default=100, type=int)
    parser.add_argument('--lr', help='learning rate', default=0.0002, type=float)
    parser.add_argument('--beta1', help='beta1 for adam optimizer', default=0.5, type=float)
    parser.add_argument('--epsilon', help='epsilon for adam optimizer', default=1e-5, type=float)
    parser.add_argument('--g_dl_weight', help='discriminator layer loss weight', default=1e-1, type=float)
    parser.add_argument('--gpu', help='gpu index', default=0, type=int)
    parser.add_argument('--use_cpu', help='use cpu', action='store_true')
    parser.add_argument('--num_epoch', help='number of maximum epochs', default=45, type=int)
    parser.add_argument('--save_after_every', help='save checkpoint after every this number of epochs', default=5, type=int)
    parser.add_argument('--visualize_after_every', help='save output images after every this number of epochs', default=5, type=int)
    parser.add_argument('--show_after_every', help='show metrics after this number of iterations', default=10, type=int)

    args = parser.parse_args()
    return args

def main():
    args = parse_args()

    if args.test:
        if not os.path.exists(args.testing_data_path):
            raise OSError("Provided Testing Path: {} does not exist".format(args.testing_data_path))
        if not os.path.exists(args.checkpoint_path):
            raise OSError("Provided Checkpoint Path: {} does not exist".format(args.checkpoint_path))

    create_and_validate_dir(args.checkpoint_path)
    create_and_validate_dir(args.output_path)

    # gpu context
    if args.use_cpu:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(args.gpu)

    # checkpoint saving flag
    check_point = True

    if args.train:
        train(args.dataset, args.nef, args.ndf, args.ngf, args.nc, args.batch_size, args.Z, args.lr, args.beta1, args.epsilon, ctx, check_point, args.g_dl_weight, args.output_path, args.checkpoint_path, args.training_data_path, args.activation, args.num_epoch, args.save_after_every, args.visualize_after_every, args.show_after_every)

    if args.test:
        test(args.nef, args.ngf, args.nc, 1, args.Z, ctx, args.pretrained_encoder_path, args.pretrained_generator_path, args.output_path, args.testing_data_path, args.activation, args.save_embedding, args.embedding_path)

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    main()
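
# Example invocations (a sketch; 'vaegan_mxnet.py' stands in for whatever this
# file is saved as, and the dataset directory of 32x32 grayscale images must
# already exist):
#
#   python vaegan_mxnet.py --train --training_data_path datasets/caltech101/data/images32x32
#   python vaegan_mxnet.py --test --batch_size 1 --save_embedding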