1{
2 "cells": [
3  {
4   "cell_type": "code",
5   "execution_count": 1,
6   "metadata": {},
7   "outputs": [],
8   "source": [
9    "import mxnet as mx\n",
10    "import numpy as np\n",
11    "import os\n",
12    "import logging\n",
13    "import matplotlib.pyplot as plt\n",
14    "import matplotlib.cm as cm"
15   ]
16  },
17  {
18   "cell_type": "markdown",
19   "metadata": {},
20   "source": [
21    "# Building a Variational Autoencoder in MXNet\n",
22    "\n",
23    "#### Xiaoyu Lu,  July 5th, 2017\n",
24    "\n",
25    "This tutorial guides you through the process of building a variational autoencoder (VAE) in MXNet. In this notebook we focus on an example using the MNIST dataset of handwritten digits. Refer to [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114/) for more details on the model description.\n",
26    "\n"
27   ]
28  },
29  {
30   "cell_type": "markdown",
31   "metadata": {},
32   "source": [
33    "## Prerequisites\n",
34    "\n",
35    "To complete this tutorial, we need the following Python packages:\n",
36    "\n",
37    "- numpy, matplotlib "
38   ]
39  },
40  {
41   "cell_type": "markdown",
42   "metadata": {},
43   "source": [
44    "## 1. Loading the Data\n",
45    "\n",
46    "We first load the MNIST dataset, which contains 60000 training and 10000 test examples. The following code loads the data using the modules imported above. The images are stored in a 4-D array with shape (`batch_size, num_channels, height, width`). For the MNIST dataset there is only one color channel, and both height and width are 28. Since the encoder and decoder are MLPs, we flatten each image into a vector of 28*28 = 784 features, and reshape it back to a 28x28 array only for display. See below for a visualization:\n"
47   ]
48  },
49  {
50   "cell_type": "code",
51   "execution_count": 2,
52   "metadata": {},
53   "outputs": [
54    {
55     "name": "stdout",
56     "output_type": "stream",
57     "text": [
58      "60000 784\n"
59     ]
60    }
61   ],
62   "source": [
63    "mnist = mx.test_utils.get_mnist()\n",
64    "image = np.reshape(mnist['train_data'],(60000,28*28))\n",
65    "label = image\n",
66    "image_test = np.reshape(mnist['test_data'],(10000,28*28))\n",
67    "label_test = image_test\n",
68    "[N,features] = np.shape(image)          #number of examples and features\n",
69    "print(N,features)"
70   ]
71  },
72  {
73   "cell_type": "code",
74   "execution_count": 3,
75   "metadata": {},
76   "outputs": [
77    {
78     "data": {
79      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAACWCAYAAAA7UIUvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFI5JREFUeJzt3X+wVfO/x/H3u9+JfikVnb5FaYTEnLmqe5mLDDKJwfGVaYzfVJRBP+5FgzHjd6YoMkyFcbvUJDToxnW7GDql6YpSFwnpVKRfqDM+94+2O+f9Waf94+xfa+3P8zHTnF6rtfd6n/a73afde62lzjkBAAAAQtSs3AUAAAAA5cJiGAAAAMFiMQwAAIBgsRgGAABAsFgMAwAAIFgshgEAABAsFsMAAAAIFothAAAABCuvxbCqnq+q61V1o6pOLlRRAAAAQCloU+9Ap6rNReQrETlXRL4XkRUicqVz7otDPaZLly6ud+/eTToe4uXbb7+V7du3azGemz6pLCtXrtzunOtajOemVyoH7ynIFu8pyEYu7ykt8jjOP4jIRufc1yIiqvpvIjJSRA65GO7du7fU1tbmcUjERXV1ddGemz6pLKq6qVjPTa9UDt5TkC3eU5CNXN5T8hmTOEZENjfI36e2Gap6o6rWqmrttm3b8jgcKhl9gmzRK8gGfYJs0Sso+gl0zrnZzrlq51x1165F+V8NVAD6BNmiV5AN+gTZoleQz2L4BxGpapB7prYBAAAAiZDPYniFiPRT1T6q2kpE/i4iiwtTFgAAAFB8TT6BzjlXr6rjROQdEWkuIi8459YWrDIAAACgyPK5moQ455aIyJIC1QIAAACUFHegAwAAQLBYDAMAACBYLIYBAAAQLBbDAAAACBaLYQAAAASLxTAAAACCxWIYAAAAwWIxDAAAgGCxGAYAAECw8roDHYDiue2220yeMWNGZJ8OHTqYvG7dOpO7d+9e+MIAAKggfDIMAACAYLEYBgAAQLBYDAMAACBYzAwXwdatW00eNGiQyUOHDjV5wYIFRa8J8bdq1SqTn376aZObNYv+23X37t0m19XVmczMMJBcf/zxh8nz5883+Zprrkn7+OXLl5s8ePBgk/ft2xd5TJs2bUw+cOCAyf77UOvWrdPWACQBnwwDAAAgWCyGAQAAECwWwwAAAAgWM8NFMHv2bJP9GeJFixaVshzEVH19vcnz5s3L+TmOP/54k3v16pVXTQDKY+/evZFtU6dONfnJJ580WVXTPuftt99u8mmnnWbyc889F3nMZZddZvL7779vcu/evU32z4Hxz1O48cYbTe7UqdOhCwbKhE+GAQAAECwWwwAAAAgWi2EAAAAEi5nhEnDOlbsExNAnn3xi8owZM3J+jsmTJ5vcsWPHvGoCUB6XXnppZNvSpUtNzjQj7Fu5cqXJtbW1GZ/vtddeS/ucO3bsyOk53333XZOXLFkSeU6uVZzehg0bTO7fv39knxEjRpg8a9Ysk48++ujCF+bZuXOnyX6v+OdTZTJu3LjItqqqqtwLywKfDAMAACBYLIYBAAAQLBbDAAAACBYzw0Wwdu1ak/0ZqksuuaSU5SAm/vzzT5OnTJmS0+NXrFgR2XbyySfnVRMyu+WWW0z+4osvTL7qqqtMHjhwoMmDBw8uTmFItLlz55r8wQcflKmS4vr6669N/uOPPyL7MDOc3vz5801ubNb7jTfeMNk/J8WfGT7qqKNMHjVqVNoaPvvss8g2v2d//PFHk7dt22ayf/5Uphn48847L7KNmWEAAACgwFgMAwAAIFgshgEAABAsZoaLwJ/F8TV2P3hUPn/m6sMPP0y7/xFHHGFyY9eJbNmyZf6FwfCvs/rss8+m3X/58uUm+3NwLVpE32b9uWL/epp9+vRJ+xz79u0zediwYWlrRPn512CdOnWqyQcOHChlOQXjnwMzcuRIky+//HKT27RpU/Sakqa+vt
7k1atXm/zEE0/k/Jz+vK6fff41rXOd7y2GIUOGlOxYfDIMAACAYLEYBgAAQLAyLoZV9QVVrVPVzxts66yqS1V1Q+prp+KWCQAAABReNjPDc0TkKRGZ12DbZBFZ5px7SFUnp/KkwpeXDPv37zd548aNZaoEcfbNN9/ktP8pp5xicvfu3QtZDg7Bn+dt27atyb/99lvax/uzdo3Ngq5cudLka665xuRmzeznFP68nn+Mww47zOQuXbpEjvnRRx+ZTD+V1po1a0z2/97wX1OR6LXJ/b7IpF27dia3b9/e5MauLeu/7/j8GeBWrVrlVBOidu3aZfLpp5+edv+amprItuHDh5vsn6OyaNEikzdt2pRLiSXhfw+NnW9RLBn/ZDnn/ktEfvY2jxSRv64YPldELi5wXQAAAEDRNXVmuJtzbkvq5z+JSLcC1QMAAACUTN4n0LmD/7cT/f+dFFW9UVVrVbU206U9EC76BNmiV5AN+gTZolfQ1MXwVlXtISKS+lp3qB2dc7Odc9XOuequXbs28XCodPQJskWvIBv0CbJFr6Cp08mLReRqEXko9fX1glWUQNu3bzf57bffNrmxEyMQnrfeeiun/R955JEiVYJ0jj/+eJP9P98zZ8402T/5xffMM89EtmX69Mk/cSqTPXv2pM0i0e9r4cKFJnPjjtKqq7OfITV2UwP/hDn/RMkrrrjC5NGjR5t87LHHmlxVVZVznSi8vXv3mnz33Xfn9Pjrr78+su2cc84x2e+FBx980OQtW7ZIvt59912T77zzTpMznWzsn5zsv1eW8qZS2Vxa7RUR+VhE+qvq96p6nRxcBJ+rqhtEZFgqAwAAAImS8ZNh59yVh/ilcw6xHQAAAEgE7kAHAACAYJXuisYV7OOPPza5sdkvhOfXX381edmyZTk9vrq6upDloIn8ubY77rgjp8dPmhS9H1GmmeBXXnnF5Ew38vEvqL9hw4bIPv4ccaZ5PhSWPyfaFL169TJ5xowZJvu9inj6/fffTZ4/f77J/nlGN998s8n+fHA2/N7w58kzWbduXWTbmDFjcnqODh06mPzhhx+afMwxx+T0fIXEJ8MAAAAIFothAAAABIvFMAAAAILFzHABnHrqqSb714Js1aqVyW3atCl6TSg//9qRP/zwQ9r9u3WzdzVn9rwy+O8H2bjhhhty2v/ee+81ee7cuZF9xo0bZ/LkyZNNHjFiRE7HRG78vyeOPPJIk3fs2JHxOdavX2/yrFmzTB45cqTJxx13XC4lokTat29v8oIFC0wePny4yY8//njRa/K9+eabJtfU1ET2yfR3lD+n7M8IDxgwoInVFR6fDAMAACBYLIYBAAAQLBbDAAAACBYzwwXw2Wefmbxv3z6T+/TpY3JTZggRf/X19Sb7830+f0b4pZdeMrlZM/6tiuy0a9fO5FGjRkX28WeGJ0yYUNSaYHXv3t3kfv36mZzNzLBv4sSJJk+dOtXkRx991GT/erUoj5YtW5o8ZMgQk9977z2TS3H96Hfeecfkiy66yORszmHxr8E+fvx4k8t5HeFM+NsWAAAAwWIxDAAAgGCxGAYAAECwmBkuAP/aef59xS+//PJSloMymT59usnLli1Lu/+JJ55o8tlnn13wmhCmZ555JuM+/gwrSuu1114zubq6OrKPf23yTOcR+OerjB071uS77rrL5BUrVkSeo2/fvia3aMEyodhat25t8uDBg4t+zO+++87kiy++OO3+Rx11VGTbLbfcYrI/w56keyrwyTAAAACCxWIYAAAAwWIxDAAAgGAxDFQA/uydfz2+QYMGlbIclEk2c5oNnXDCCUWqBKHZuXOnydOmTYvs48/vVVVVFbUmpOf/vTFv3rzIPqNHjzZ569ateR3Tnyk+6aSTIvv417g988wz8zom4sG/jrA/77t//36Tjz76aJM//fTTyHP6+yQZnwwDAAAgWCyGAQAAECwWwwAAAAgWM8MF4M/n+dcZXr16tckjRowoek2IP/oAhe
LPq2/bti2yj99vnMsQL41dZ9y/hr1/DeBiuOKKK0xesGCByUOHDi16Dcjfrl27TPZnwb/99tu0j9+7d6/J/nkJIswMAwAAABWBxTAAAACCxWIYAAAAwWIxDAAAgGBxAl0B+DfZ8DMqz6ZNmyLbfvzxx7SPad68ucmnn356QWtCOH7++WeTp0+fbnLPnj0jj3n55ZeLWhMKz78xyubNm01euHChyffff7/JO3bsyPmY/smXY8eONfmjjz4yuW3btjkfA8W3Zs0akx977DGTM61TZs2aZfKAAQMKU1hM8ckwAAAAgsViGAAAAMFiMQwAAIBgMTNcAP6Fp3/66SeTL7roolKWgxLwX2MRkd9++y3tY8aMGWNy+/btC1oTKtf+/ftNfvjhh032+/HCCy+MPMfhhx9e+MJQVP55Bj169DC5Y8eOJvs3WiiE1q1bp60J8VBbW2tyrjd1qqmpMfmCCy7Iu6Yk4ZNhAAAABCvjYlhVq1T1fVX9QlXXqur41PbOqrpUVTekvnYqfrkAAABA4WTzyXC9iNzhnBsgIoNFZKyqDhCRySKyzDnXT0SWpTIAAACQGBlnhp1zW0RkS+rnu1X1SxE5RkRGisg/p3abKyL/KSKTilJlzK1atcpk//p9nTrxoXmlefbZZ3N+zFNPPWXyTTfdZPIJJ5yQV02oXFu2bDH5kUceMdm/1usDDzxQ9JqQ3u7du02eOHGiye3atTO5W7dukedwzpn86quvmuz/3ZPJn3/+aXKzZpk/DzvssMNyOgaK7/XXX49su/baa03OND/euXNnk6dNm2Zyhw4dmlhdMuU0M6yqvUXkVBH5RES6pRbKIiI/iUj0TzIAAAAQY1kvhlX1cBFZICITnHPmnxzu4D9f3SEed6Oq1qpqrX9nG+Av9AmyRa8gG/QJskWvIKvFsKq2lIML4Zedc3/d/3GrqvZI/XoPEalr7LHOudnOuWrnXHXXrl0LUTMqEH2CbNEryAZ9gmzRK8g4M6wHB2CfF5EvnXNPNPilxSJytYg8lPoaHWIJRN++fU32r/nZq1evUpaDmPLn//w+YWYYh7Jz506T/fMSHnvsMZMHDRpU9JqQ3ubNm01+7rnncn4O/z3Df939nIk/I9zY4/33IX8+vVWrVjkdE/nbtGmTyffcc09kn19++cXkTL0xbNgwk0M/tymbm278o4iMFpH/UdXVqW3/IgcXwf+uqteJyCYRqTnE4wEAAIBYyuZqEv8tIof6J8Y5hS0HAAAAKB3uQAcAAIBgZTMmgQz8a4D6szr79u0zmes2Jl+PHj1yfow/a3fWWWcVqhxUmPr6epOnTp1qcsuWLU0eOnRo0WtCbvzruF533XUmP//886Usp1Ht27ePbFuyZInJVVVVpSoHKXV19noEp512msn+fLBIdL7cd8YZZ5j84osvmtyiRdjLQT4ZBgAAQLBYDAMAACBYLIYBAAAQrLCHRApkz549JvvzpMwIV55bb701sm3mzJkm+/d2nz59elFrQuVYu3atyYsXLzb5wgsvNJnrCsdP9+7dTX7wwQdNfvPNN03eunVr0WsaP368yaNGjYrsw4xw6flriP79+5u8a5e56W9W15e+7777TJ4wYYLJoc8I+/hkGAAAAMFiMQwAAIBgsRgGAABAsBgaKQD/PuH+rBgqT2OvcWPXfgSa4oEHHjDZn+/z508Rfx07djT5q6++MnnMmDGRx7z00ks5HaO6utrkadOmmTxkyJCcng/FsXfvXpMvu+wyk/0Z4Wz4M8KTJk0y2b/OPSw+GQYAAECwWAwDAAAgWCyGAQAAECwWwwAAAAgWJ9AVwP3331/uEgBUsDPOOMPkgQMHlqkSNJV/EqSf586dG3lMY9uQfO3atTO5rq4up8fX1NREtk2ZMsVkbqqRGz4ZBgAAQLBYDAMAACBYLIYBAAAQLIZKAAAAymTx4sUmjxs3zuS2bduaPGfOnMhzMCOcHz4ZBgAAQLBYDAMAACBYLIYBAAAQLIZMAC
Bm/JnBhQsXlqkSAMXWs2dPkxctWlSmSsLFJ8MAAAAIFothAAAABIvFMAAAAIKlzrnSHUx1m4hsEpEuIrK9ZAduGmpM72/Oua7FeOKE9YlIMuqkV8qPGtMrRZ+I8DoUSqX3Cq9B4ZSrzqz7pKSL4f8/qGqtc6665AfOATWWX1K+vyTUmYQa85GE748a4yEJ3yM1ll8Svr8k1CiSjDoZkwAAAECwWAwDAAAgWOVaDM8u03FzQY3ll5TvLwl1JqHGfCTh+6PGeEjC90iN5ZeE7y8JNYokoM6yzAwDAAAAccCYBAAAAIJV0sWwqp6vqutVdaOqTi7lsdNR1RdUtU5VP2+wrbOqLlXVDamvncpcY5Wqvq+qX6jqWlUdH8c6CyWOvUKfxE8c+0SEXokjeqXJ9QXVJyLx7JW490mqnsT2SskWw6raXESeFpELRGSAiFypqgNKdfwM5ojI+d62ySKyzDnXT0SWpXI51YvIHc65ASIyWETGpn7/4lZn3mLcK3OEPomNGPeJCL0SK/RKXoLpE5FY98ociXefiCS5V5xzJfkhIkNE5J0GeYqITCnV8bOor7eIfN4grxeRHqmf9xCR9eWu0av3dRE5N+51Vlqv0Cfx+RHnPqFX4vWDXqFPKqFXktQnSeuVUo5JHCMimxvk71Pb4qqbc25L6uc/iUi3chbTkKr2FpFTReQTiXGdeUhSr8T2958+iZ3Yvgb0SuzE8jUIoE9EktUrsX0NktYrnECXBXfwnzOxuOyGqh4uIgtEZIJzblfDX4tTnSGK0+8/fRJvcXoN6JV4i8trQJ/EW5xegyT2SikXwz+ISFWD3DO1La62qmoPEZHU17oy1yOq2lIONtjLzrmFqc2xq7MAktQrsfv9p09iK3avAb0SW7F6DQLqE5Fk9UrsXoOk9kopF8MrRKSfqvZR1VYi8ncRWVzC4+dqsYhcnfr51XJw9qVsVFVF5HkR+dI590SDX4pVnQWSpF6J1e8/fRLbPhGJ2WtAr9Ar2QisT0SS1Suxeg0S3SslHqYeLiJficj/isi/lntgukFdr4jIFhE5IAfng64TkSPl4FmPG0TkP0Skc5lr/Cc5+F8La0RkderH8LjVWcm9Qp/E70cc+4ReiecPeoU+SXKvxL1Pkt4r3IEOAAAAweIEOgAAAASLxTAAAACCxWIYAAAAwWIxDAAAgGCxGAYAAECwWAwDAAAgWCyGAQAAECwWwwAAAAjW/wEgPmufEARJLAAAAABJRU5ErkJggg==\n",
80      "text/plain": [
81       "<Figure size 864x216 with 5 Axes>"
82      ]
83     },
84     "metadata": {},
85     "output_type": "display_data"
86    }
87   ],
88   "source": [
89    "nsamples = 5\n",
90    "idx = np.random.choice(len(mnist['train_data']), nsamples)\n",
91    "_, axarr = plt.subplots(1, nsamples, sharex='col', sharey='row',figsize=(12,3))\n",
92    "\n",
93    "for i,j in enumerate(idx):\n",
94    "    axarr[i].imshow(np.reshape(image[j,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
95    "\n",
96    "plt.show()"
97   ]
98  },
99  {
100   "cell_type": "markdown",
101   "metadata": {},
102   "source": [
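103    "We can optionally save the parameters to the path prefix given by the variable `model_prefix`; leaving it as `None` disables checkpointing. We then create data iterators for MXNet, with each batch containing 100 images. Because the network is trained to reconstruct its input, the images themselves serve as the labels."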
104   ]
105  },
106  {
107   "cell_type": "code",
108   "execution_count": 4,
109   "metadata": {},
110   "outputs": [],
111   "source": [
112    "model_prefix = None\n",
113    "\n",
114    "batch_size = 100\n",
115    "latent_dim = 5\n",
116    "nd_iter = mx.io.NDArrayIter(data={'data':image},label={'loss_label':label},\n",
117    "                            batch_size = batch_size)\n",
118    "nd_iter_test = mx.io.NDArrayIter(data={'data':image_test},label={'loss_label':label_test},\n",
119    "                            batch_size = batch_size)"
120   ]
121  },
122  {
123   "cell_type": "markdown",
124   "metadata": {},
125   "source": [
126    "## 2.  Building the Network Architecture\n",
127    "\n",
128    "### 2.1 Gaussian MLP as encoder\n",
129    "Next we construct the neural network. As in the [paper](https://arxiv.org/abs/1312.6114/), we use a *Multilayer Perceptron (MLP)* for both the encoder and the decoder. For the encoder, a Gaussian MLP is used as follows:\n",
130    "\n",
131    "\\begin{align}\n",
132    "\\log q_{\\phi}(z|x) &= \\log \\mathcal{N}(z;\\mu,\\sigma^2I) \\\\\n",
133    "\\textit{ where } \\mu &= W_2h+b_2, \\log \\sigma^2 = W_3h+b_3\\\\\n",
134    "h &= \\tanh(W_1x+b_1)\n",
135    "\\end{align}\n",
136    "\n",
137    "where $\\{W_1,W_2,W_3,b_1,b_2,b_3\\}$ are the weights and biases of the MLP.\n",
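138    "Note below that `encoder_mu` (`mu`) and `encoder_logvar` (`logvar`) are symbols, so we can use `get_internals()` to retrieve their values. We then sample the latent variable as $z = \\mu + \\sigma \\odot \\epsilon$ with $\\epsilon \\sim \\mathcal{N}(0, I)$ and $\\sigma = \\exp(\\frac{1}{2}\\log \\sigma^2)$.\n",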
139    "\n",
140    "\n",
141    "\n"
142   ]
143  },
144  {
145   "cell_type": "code",
146   "execution_count": 5,
147   "metadata": {},
148   "outputs": [],
149   "source": [
150    "## define data and loss labels as symbols \n",
151    "data = mx.sym.var('data')\n",
152    "loss_label = mx.sym.var('loss_label')\n",
153    "\n",
154    "## define fully connected and activation layers for the encoder, where we used tanh activation function.\n",
155    "encoder_h  = mx.sym.FullyConnected(data=data, name=\"encoder_h\",num_hidden=400)\n",
156    "act_h = mx.sym.Activation(data=encoder_h, act_type=\"tanh\",name=\"activation_h\")\n",
157    "\n",
158    "## define mu and log variance which are the fully connected layers of the previous activation layer\n",
159    "mu  = mx.sym.FullyConnected(data=act_h, name=\"mu\",num_hidden = latent_dim)\n",
160    "logvar  = mx.sym.FullyConnected(data=act_h, name=\"logvar\",num_hidden = latent_dim)\n",
161    "\n",
162    "## sample the latent variables z according to Normal(mu,var)\n",
163    "z = mu + mx.symbol.broadcast_mul(mx.symbol.exp(0.5 * logvar), \n",
164    "                                 mx.symbol.random_normal(loc=0, scale=1, shape=(batch_size, latent_dim)),\n",
165    "                                 name=\"z\")"
166   ]
167  },
168  {
169   "cell_type": "markdown",
170   "metadata": {},
171   "source": [
172    "### 2.2 Bernoulli MLP as decoder\n",
173    "\n",
174    "In this case let $p_\\theta(x|z)$ be a multivariate Bernoulli whose probabilities are computed from $z$ with a feed forward neural network with a single hidden layer:\n",
175    "\n",
176    "\\begin{align}\n",
177    "\\log p(x|z) &= \\sum_{i=1}^D x_i\\log y_i + (1-x_i)\\log (1-y_i) \\\\\n",
178    "\\textit{ where }  y &= f_\\sigma(W_5\\tanh (W_4z+b_4)+b_5)\n",
179    "\\end{align}\n",
180    "\n",
181    "where $f_\\sigma(\\cdot)$ is the elementwise sigmoid activation function, and $\\{W_4,W_5,b_4,b_5\\}$ are the weights and biases of the decoder MLP. A Bernoulli likelihood is suitable for this type of data, but you can easily switch to other likelihood types by passing the argument `likelihood` to the `VAE` class; see section 4 for details."
182   ]
183  },
184  {
185   "cell_type": "code",
186   "execution_count": 6,
187   "metadata": {},
188   "outputs": [],
189   "source": [
190    "# define fully connected and tanh activation layers for the decoder\n",
191    "decoder_z = mx.sym.FullyConnected(data=z, name=\"decoder_z\",num_hidden=400)\n",
192    "act_z = mx.sym.Activation(data=decoder_z, act_type=\"tanh\",name=\"activation_z\")\n",
193    "\n",
194    "# define the output layer with sigmoid activation function, where the dimension is equal to the input dimension\n",
195    "decoder_x = mx.sym.FullyConnected(data=act_z, name=\"decoder_x\",num_hidden=features)\n",
196    "y = mx.sym.Activation(data=decoder_x, act_type=\"sigmoid\",name='activation_x')"
197   ]
198  },
199  {
200   "cell_type": "markdown",
201   "metadata": {},
202   "source": [
203    "### 2.3 Joint Loss Function for the Encoder and the Decoder\n",
204    "\n",
205    "The variational lower bound, also called the evidence lower bound (ELBO), can be estimated as:\n",
206    "\n",
207    "\\begin{align}\n",
208    "\\mathcal{L}(\\theta,\\phi;x^{(i)}) \\approx \\frac{1}{2}\\sum_{j=1}^{J}\\left(1+\\log ((\\sigma_j^{(i)})^2)-(\\mu_j^{(i)})^2-(\\sigma_j^{(i)})^2\\right) + \\log p_\\theta(x^{(i)}|z^{(i)})\n",
209    "\\end{align}\n",
210    "\n",
211    "where $J$ is the dimension of the latent variable $z$, the first term is the negative KL divergence of the approximate posterior from the prior, and the second term is an expected negative reconstruction error. We would like to maximize this lower bound, so we define the loss to be $-\\mathcal{L}$ (minus ELBO) for MXNet to minimize."
212   ]
213  },
214  {
215   "cell_type": "code",
216   "execution_count": 7,
217   "metadata": {},
218   "outputs": [],
219   "source": [
220    "# define the loss to be minimized (minus ELBO); note that `KL` below holds the negative KL divergence\n",
221    "KL = 0.5*mx.symbol.sum(1+logvar-pow( mu,2)-mx.symbol.exp(logvar),axis=1)\n",
222    "loss = -mx.symbol.sum(mx.symbol.broadcast_mul(loss_label,mx.symbol.log(y)) \n",
223    "                      + mx.symbol.broadcast_mul(1-loss_label,mx.symbol.log(1-y)),axis=1)-KL\n",
224    "output = mx.symbol.MakeLoss(sum(loss),name='loss')"
225   ]
226  },
227  {
228   "cell_type": "markdown",
229   "metadata": {},
230   "source": [
231    "## 3. Training the model\n",
232    "\n",
233    "Now we can define the model and train it. First we initialize the weights and biases from a zero-mean Gaussian with standard deviation 0.01, and then use stochastic gradient descent for optimization. To warm start the training, one may also initialize with pre-trained parameters `arg_params` using `init=mx.initializer.Load(arg_params)`.\n",
234    "\n",
235    "To save intermediate results, we can optionally use `epoch_end_callback = mx.callback.do_checkpoint(model_prefix, 1)`, which saves the parameters to the path given by `model_prefix` once every epoch. To assess the performance, we output $-\\mathcal{L}$ (minus ELBO) after each epoch by setting `eval_metric = 'Loss'`, which reports the value of the loss symbol defined above. We will also plot the training loss by recording the metric value in a list through a callback passed to the argument `batch_end_callback`."
236   ]
237  },
238  {
239   "cell_type": "code",
240   "execution_count": 8,
241   "metadata": {},
242   "outputs": [],
243   "source": [
244    "# set up the log\n",
245    "nd_iter.reset()\n",
246    "logging.getLogger().setLevel(logging.DEBUG)  \n",
247    "\n",
248    "# define a callback that records the training loss in a list\n",
249    "def log_to_list(period, lst):\n",
250    "    def _callback(param):\n",
251    "        \"\"\"Append the current metric value to lst every `period` batches.\"\"\"\n",
252    "        if param.nbatch % period == 0:\n",
253    "            name, value = param.eval_metric.get()\n",
254    "            lst.append(value)\n",
255    "    return _callback\n",
256    "\n",
257    "# define the model\n",
258    "model = mx.mod.Module(\n",
259    "    symbol = output ,\n",
260    "    data_names=['data'],\n",
261    "    label_names = ['loss_label'])"
262   ]
263  },
264  {
265   "cell_type": "code",
266   "execution_count": 9,
267   "metadata": {},
268   "outputs": [
269    {
270     "name": "stderr",
271     "output_type": "stream",
272     "text": [
273      "INFO:root:Epoch[0] Train-loss=373.547317\n",
274      "INFO:root:Epoch[0] Time cost=5.020\n",
275      "INFO:root:Epoch[1] Train-loss=212.232684\n",
276      "INFO:root:Epoch[1] Time cost=4.651\n",
277      "INFO:root:Epoch[2] Train-loss=207.448528\n",
278      "INFO:root:Epoch[2] Time cost=4.665\n",
279      "INFO:root:Epoch[3] Train-loss=205.369479\n",
280      "INFO:root:Epoch[3] Time cost=4.758\n",
281      "INFO:root:Epoch[4] Train-loss=203.651983\n",
282      "INFO:root:Epoch[4] Time cost=4.672\n",
283      "INFO:root:Epoch[5] Train-loss=202.061007\n",
284      "INFO:root:Epoch[5] Time cost=5.087\n",
285      "INFO:root:Epoch[6] Train-loss=199.348143\n",
286      "INFO:root:Epoch[6] Time cost=5.056\n",
287      "INFO:root:Epoch[7] Train-loss=196.266242\n",
288      "INFO:root:Epoch[7] Time cost=4.813\n",
289      "INFO:root:Epoch[8] Train-loss=194.694945\n",
290      "INFO:root:Epoch[8] Time cost=4.776\n",
291      "INFO:root:Epoch[9] Train-loss=193.699284\n",
292      "INFO:root:Epoch[9] Time cost=4.756\n",
293      "INFO:root:Epoch[10] Train-loss=193.036517\n",
294      "INFO:root:Epoch[10] Time cost=4.757\n",
295      "INFO:root:Epoch[11] Train-loss=192.555736\n",
296      "INFO:root:Epoch[11] Time cost=4.678\n",
297      "INFO:root:Epoch[12] Train-loss=192.020813\n",
298      "INFO:root:Epoch[12] Time cost=4.630\n",
299      "INFO:root:Epoch[13] Train-loss=191.648876\n",
300      "INFO:root:Epoch[13] Time cost=5.158\n",
301      "INFO:root:Epoch[14] Train-loss=191.057798\n",
302      "INFO:root:Epoch[14] Time cost=4.781\n",
303      "INFO:root:Epoch[15] Train-loss=190.315835\n",
304      "INFO:root:Epoch[15] Time cost=5.117\n",
305      "INFO:root:Epoch[16] Train-loss=189.311271\n",
306      "INFO:root:Epoch[16] Time cost=4.707\n",
307      "INFO:root:Epoch[17] Train-loss=187.285967\n",
308      "INFO:root:Epoch[17] Time cost=4.745\n",
309      "INFO:root:Epoch[18] Train-loss=185.271324\n",
310      "INFO:root:Epoch[18] Time cost=4.692\n",
311      "INFO:root:Epoch[19] Train-loss=183.510888\n",
312      "INFO:root:Epoch[19] Time cost=4.762\n",
313      "INFO:root:Epoch[20] Train-loss=181.756008\n",
314      "INFO:root:Epoch[20] Time cost=4.838\n",
315      "INFO:root:Epoch[21] Train-loss=180.546818\n",
316      "INFO:root:Epoch[21] Time cost=4.764\n",
317      "INFO:root:Epoch[22] Train-loss=179.479776\n",
318      "INFO:root:Epoch[22] Time cost=4.791\n",
319      "INFO:root:Epoch[23] Train-loss=178.352077\n",
320      "INFO:root:Epoch[23] Time cost=4.981\n",
321      "INFO:root:Epoch[24] Train-loss=177.385084\n",
322      "INFO:root:Epoch[24] Time cost=5.292\n",
323      "INFO:root:Epoch[25] Train-loss=175.920123\n",
324      "INFO:root:Epoch[25] Time cost=5.097\n",
325      "INFO:root:Epoch[26] Train-loss=174.377171\n",
326      "INFO:root:Epoch[26] Time cost=4.907\n",
327      "INFO:root:Epoch[27] Train-loss=172.590589\n",
328      "INFO:root:Epoch[27] Time cost=4.484\n",
329      "INFO:root:Epoch[28] Train-loss=170.933683\n",
330      "INFO:root:Epoch[28] Time cost=4.348\n",
331      "INFO:root:Epoch[29] Train-loss=169.866807\n",
332      "INFO:root:Epoch[29] Time cost=4.647\n",
333      "INFO:root:Epoch[30] Train-loss=169.182084\n",
334      "INFO:root:Epoch[30] Time cost=5.034\n",
335      "INFO:root:Epoch[31] Train-loss=168.121719\n",
336      "INFO:root:Epoch[31] Time cost=5.615\n",
337      "INFO:root:Epoch[32] Train-loss=167.389992\n",
338      "INFO:root:Epoch[32] Time cost=4.733\n",
339      "INFO:root:Epoch[33] Train-loss=166.189067\n",
340      "INFO:root:Epoch[33] Time cost=5.041\n",
341      "INFO:root:Epoch[34] Train-loss=163.783392\n",
342      "INFO:root:Epoch[34] Time cost=5.168\n",
343      "INFO:root:Epoch[35] Train-loss=162.167959\n",
344      "INFO:root:Epoch[35] Time cost=5.019\n",
345      "INFO:root:Epoch[36] Train-loss=161.192039\n",
346      "INFO:root:Epoch[36] Time cost=5.064\n",
347      "INFO:root:Epoch[37] Train-loss=160.307114\n",
348      "INFO:root:Epoch[37] Time cost=5.180\n",
349      "INFO:root:Epoch[38] Train-loss=159.591957\n",
350      "INFO:root:Epoch[38] Time cost=5.440\n",
351      "INFO:root:Epoch[39] Train-loss=159.109593\n",
352      "INFO:root:Epoch[39] Time cost=5.119\n",
353      "INFO:root:Epoch[40] Train-loss=158.463844\n",
354      "INFO:root:Epoch[40] Time cost=5.299\n",
355      "INFO:root:Epoch[41] Train-loss=158.037287\n",
356      "INFO:root:Epoch[41] Time cost=4.856\n",
357      "INFO:root:Epoch[42] Train-loss=157.598576\n",
358      "INFO:root:Epoch[42] Time cost=5.227\n",
359      "INFO:root:Epoch[43] Train-loss=157.097344\n",
360      "INFO:root:Epoch[43] Time cost=5.237\n",
361      "INFO:root:Epoch[44] Train-loss=156.594472\n",
362      "INFO:root:Epoch[44] Time cost=4.783\n",
363      "INFO:root:Epoch[45] Train-loss=156.177069\n",
364      "INFO:root:Epoch[45] Time cost=4.834\n",
365      "INFO:root:Epoch[46] Train-loss=155.825302\n",
366      "INFO:root:Epoch[46] Time cost=4.902\n",
367      "INFO:root:Epoch[47] Train-loss=155.318117\n",
368      "INFO:root:Epoch[47] Time cost=4.966\n",
369      "INFO:root:Epoch[48] Train-loss=154.890766\n",
370      "INFO:root:Epoch[48] Time cost=5.012\n",
371      "INFO:root:Epoch[49] Train-loss=154.504158\n",
372      "INFO:root:Epoch[49] Time cost=4.844\n",
373      "INFO:root:Epoch[50] Train-loss=154.035214\n",
374      "INFO:root:Epoch[50] Time cost=4.736\n",
375      "INFO:root:Epoch[51] Train-loss=153.692903\n",
376      "INFO:root:Epoch[51] Time cost=5.057\n",
377      "INFO:root:Epoch[52] Train-loss=153.257554\n",
378      "INFO:root:Epoch[52] Time cost=5.044\n",
379      "INFO:root:Epoch[53] Train-loss=152.849715\n",
380      "INFO:root:Epoch[53] Time cost=4.783\n",
381      "INFO:root:Epoch[54] Train-loss=152.483047\n",
382      "INFO:root:Epoch[54] Time cost=4.842\n",
383      "INFO:root:Epoch[55] Train-loss=152.091617\n",
384      "INFO:root:Epoch[55] Time cost=5.044\n",
385      "INFO:root:Epoch[56] Train-loss=151.715490\n",
386      "INFO:root:Epoch[56] Time cost=5.029\n",
387      "INFO:root:Epoch[57] Train-loss=151.362293\n",
388      "INFO:root:Epoch[57] Time cost=4.873\n",
389      "INFO:root:Epoch[58] Train-loss=151.003241\n",
390      "INFO:root:Epoch[58] Time cost=4.729\n",
391      "INFO:root:Epoch[59] Train-loss=150.619678\n",
392      "INFO:root:Epoch[59] Time cost=5.068\n",
393      "INFO:root:Epoch[60] Train-loss=150.296043\n",
394      "INFO:root:Epoch[60] Time cost=4.458\n",
395      "INFO:root:Epoch[61] Train-loss=149.964152\n",
396      "INFO:root:Epoch[61] Time cost=4.828\n",
397      "INFO:root:Epoch[62] Train-loss=149.694102\n",
398      "INFO:root:Epoch[62] Time cost=5.012\n",
399      "INFO:root:Epoch[63] Train-loss=149.290113\n",
400      "INFO:root:Epoch[63] Time cost=5.193\n",
401      "INFO:root:Epoch[64] Train-loss=148.934186\n",
402      "INFO:root:Epoch[64] Time cost=4.999\n",
403      "INFO:root:Epoch[65] Train-loss=148.657502\n",
404      "INFO:root:Epoch[65] Time cost=4.810\n",
405      "INFO:root:Epoch[66] Train-loss=148.331948\n",
406      "INFO:root:Epoch[66] Time cost=5.201\n",
407      "INFO:root:Epoch[67] Train-loss=148.018539\n",
408      "INFO:root:Epoch[67] Time cost=4.833\n",
409      "INFO:root:Epoch[68] Train-loss=147.746825\n",
410      "INFO:root:Epoch[68] Time cost=5.187\n",
411      "INFO:root:Epoch[69] Train-loss=147.406399\n",
412      "INFO:root:Epoch[69] Time cost=5.355\n",
413      "INFO:root:Epoch[70] Train-loss=147.181831\n",
414      "INFO:root:Epoch[70] Time cost=4.989\n",
415      "INFO:root:Epoch[71] Train-loss=146.860770\n",
416      "INFO:root:Epoch[71] Time cost=4.934\n",
417      "INFO:root:Epoch[72] Train-loss=146.604369\n",
418      "INFO:root:Epoch[72] Time cost=5.283\n",
419      "INFO:root:Epoch[73] Train-loss=146.351628\n",
420      "INFO:root:Epoch[73] Time cost=5.062\n",
421      "INFO:root:Epoch[74] Train-loss=146.102506\n",
422      "INFO:root:Epoch[74] Time cost=4.540\n",
423      "INFO:root:Epoch[75] Train-loss=145.828805\n",
424      "INFO:root:Epoch[75] Time cost=4.875\n",
425      "INFO:root:Epoch[76] Train-loss=145.571626\n",
426      "INFO:root:Epoch[76] Time cost=4.856\n",
427      "INFO:root:Epoch[77] Train-loss=145.365383\n",
428      "INFO:root:Epoch[77] Time cost=5.003\n",
429      "INFO:root:Epoch[78] Train-loss=145.101047\n",
430      "INFO:root:Epoch[78] Time cost=4.718\n",
431      "INFO:root:Epoch[79] Train-loss=144.810765\n",
432      "INFO:root:Epoch[79] Time cost=5.127\n",
433      "INFO:root:Epoch[80] Train-loss=144.619876\n",
434      "INFO:root:Epoch[80] Time cost=4.737\n",
435      "INFO:root:Epoch[81] Train-loss=144.399066\n",
436      "INFO:root:Epoch[81] Time cost=4.742\n",
437      "INFO:root:Epoch[82] Train-loss=144.220090\n",
438      "INFO:root:Epoch[82] Time cost=4.810\n",
439      "INFO:root:Epoch[83] Train-loss=143.904279\n",
440      "INFO:root:Epoch[83] Time cost=5.176\n",
441      "INFO:root:Epoch[84] Train-loss=143.734935\n",
442      "INFO:root:Epoch[84] Time cost=4.921\n",
443      "INFO:root:Epoch[85] Train-loss=143.499403\n",
444      "INFO:root:Epoch[85] Time cost=4.692\n",
445      "INFO:root:Epoch[86] Train-loss=143.304287\n",
446      "INFO:root:Epoch[86] Time cost=4.778\n",
447      "INFO:root:Epoch[87] Train-loss=143.096145\n",
448      "INFO:root:Epoch[87] Time cost=4.962\n",
449      "INFO:root:Epoch[88] Train-loss=142.877920\n",
450      "INFO:root:Epoch[88] Time cost=4.815\n",
451      "INFO:root:Epoch[89] Train-loss=142.677429\n",
452      "INFO:root:Epoch[89] Time cost=5.127\n",
453      "INFO:root:Epoch[90] Train-loss=142.499622\n",
454      "INFO:root:Epoch[90] Time cost=5.463\n",
455      "INFO:root:Epoch[91] Train-loss=142.300291\n",
456      "INFO:root:Epoch[91] Time cost=4.639\n",
457      "INFO:root:Epoch[92] Train-loss=142.111362\n",
458      "INFO:root:Epoch[92] Time cost=5.064\n",
459      "INFO:root:Epoch[93] Train-loss=141.912848\n",
460      "INFO:root:Epoch[93] Time cost=4.894\n",
461      "INFO:root:Epoch[94] Train-loss=141.723130\n",
462      "INFO:root:Epoch[94] Time cost=4.635\n",
463      "INFO:root:Epoch[95] Train-loss=141.516580\n",
464      "INFO:root:Epoch[95] Time cost=5.063\n",
465      "INFO:root:Epoch[96] Train-loss=141.362380\n",
466      "INFO:root:Epoch[96] Time cost=4.785\n",
467      "INFO:root:Epoch[97] Train-loss=141.178878\n",
468      "INFO:root:Epoch[97] Time cost=4.699\n",
469      "INFO:root:Epoch[98] Train-loss=141.004168\n",
470      "INFO:root:Epoch[98] Time cost=4.959\n",
471      "INFO:root:Epoch[99] Train-loss=140.865592\n",
472      "INFO:root:Epoch[99] Time cost=5.155\n"
473     ]
474    }
475   ],
476   "source": [
477    "# training the model, save training loss as a list.\n",
478    "training_loss=list()\n",
479    "\n",
480    "# initialize the parameters for training using Normal.\n",
481    "init = mx.init.Normal(0.01)\n",
482    "model.fit(nd_iter,  # train data\n",
483    "          initializer=init,\n",
484    "          # if eval_data is supplied, test loss will also be reported\n",
485    "          # eval_data = nd_iter_test,\n",
486    "          optimizer='sgd',  # use SGD to train\n",
487    "          optimizer_params={'learning_rate':1e-3,'wd':1e-2},  \n",
488    "          # save parameters for each epoch if model_prefix is supplied\n",
489    "          epoch_end_callback = None if model_prefix is None else mx.callback.do_checkpoint(model_prefix, 1),\n",
490    "          batch_end_callback = log_to_list(N/batch_size,training_loss), \n",
491    "          num_epoch=100,\n",
492    "          eval_metric = 'Loss')"
493   ]
494  },
495  {
496   "cell_type": "code",
497   "execution_count": 10,
498   "metadata": {},
499   "outputs": [
507    {
508     "data": {
509      "image/png": "\n",
510      "text/plain": [
511       "<Figure size 432x288 with 1 Axes>"
512      ]
513     },
514     "metadata": {},
515     "output_type": "display_data"
516    }
517   ],
518   "source": [
519    "ELBO = [-value for value in training_loss]\n",
520    "plt.plot(ELBO)\n",
521    "plt.ylabel('ELBO');plt.xlabel('epoch');plt.title(\"training curve for mini batches\")\n",
522    "plt.show()"
523   ]
524  },
525  {
526   "cell_type": "markdown",
527   "metadata": {},
528   "source": [
529    "As expected, the ELBO increases steadily over the epochs, which is consistent with the results reported in [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114/). Now we can extract or load the parameters and feed the network forward to calculate $y$, the reconstructed image. We can also calculate the ELBO for the test set."
530   ]
531  },
532  {
533   "cell_type": "code",
534   "execution_count": 11,
535   "metadata": {},
536   "outputs": [],
537   "source": [
538    "arg_params = model.get_params()[0]\n",
539    "nd_iter_test.reset()\n",
540    "test_batch = nd_iter_test.next()\n",
541    "\n",
542    "# if the parameters were saved, they can be loaded with `load_checkpoint`, e.g. at the 100th epoch\n",
543    "# sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 100)\n",
544    "# assert sym.tojson() == output.tojson()\n",
545    "\n",
546    "e = y.bind(mx.cpu(), {'data': test_batch.data[0],\n",
547    "                     'encoder_h_weight': arg_params['encoder_h_weight'],\n",
548    "                     'encoder_h_bias': arg_params['encoder_h_bias'],\n",
549    "                     'mu_weight': arg_params['mu_weight'],\n",
550    "                     'mu_bias': arg_params['mu_bias'],\n",
551    "                     'logvar_weight':arg_params['logvar_weight'],\n",
552    "                     'logvar_bias':arg_params['logvar_bias'],\n",
553    "                     'decoder_z_weight':arg_params['decoder_z_weight'],\n",
554    "                     'decoder_z_bias':arg_params['decoder_z_bias'],\n",
555    "                     'decoder_x_weight':arg_params['decoder_x_weight'],\n",
556    "                     'decoder_x_bias':arg_params['decoder_x_bias'],                \n",
557    "                     'loss_label':label})\n",
558    "\n",
559    "x_fit = e.forward()\n",
560    "x_construction = x_fit[0].asnumpy()"
561   ]
562  },
563  {
564   "cell_type": "code",
565   "execution_count": 12,
566   "metadata": {
567    "scrolled": true
568   },
569   "outputs": [
570    {
571     "data": {
572      "image/png": "\n",
573      "text/plain": [
574       "<Figure size 864x216 with 4 Axes>"
575      ]
576     },
577     "metadata": {},
578     "output_type": "display_data"
579    }
580   ],
581   "source": [
582    "# compare true test images with their reconstructions\n",
583    "f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))\n",
584    "ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
585    "ax1.set_title('True image')\n",
586    "ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
587    "ax2.set_title('Learned image')\n",
588    "ax3.imshow(np.reshape(image_test[99,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
589    "ax3.set_title('True image')\n",
590    "ax4.imshow(np.reshape(x_construction[99,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
591    "ax4.set_title('Learned image')\n",
592    "plt.show()"
593   ]
594  },
595  {
596   "cell_type": "code",
597   "execution_count": 13,
598   "metadata": {},
599   "outputs": [
600    {
601     "data": {
602      "text/plain": [
603       "[('loss', 140.17346005859375)]"
604      ]
605     },
606     "execution_count": 13,
607     "metadata": {},
608     "output_type": "execute_result"
609    }
610   ],
611   "source": [
612    "# compute the loss on the test set; the ELBO is the negative of this value\n",
613    "metric = mx.metric.Loss()\n",
614    "model.score(nd_iter_test, metric)"
615   ]
616  },
617  {
618   "cell_type": "markdown",
619   "metadata": {},
620   "source": [
621    "## 4. All together: MXNet-based class VAE"
622   ]
623  },
624  {
625   "cell_type": "code",
626   "execution_count": 14,
627   "metadata": {},
628   "outputs": [],
629   "source": [
630    "from VAE import VAE"
631   ]
632  },
633  {
634   "cell_type": "markdown",
635   "metadata": {},
636   "source": [
637    "The `VAE` class can be called directly to run the training:\n",
638    "```python\n",
639    "VAE(n_latent=5, num_hidden_ecoder=400, num_hidden_decoder=400, x_train=None, x_valid=None,\n",
640    "    batch_size=100, learning_rate=0.001, weight_decay=0.01, num_epoch=100, optimizer='sgd', model_prefix=None,\n",
641    "    initializer=mx.init.Normal(0.01), likelihood=Bernoulli)\n",
642    "```\n",
643    "The outputs are the learned model and the training loss."
644   ]
645  },
646  {
647   "cell_type": "code",
648   "execution_count": 15,
649   "metadata": {},
650   "outputs": [
651    {
652     "name": "stderr",
653     "output_type": "stream",
654     "text": [
655      "INFO:root:Epoch[0] Train-loss=383.478870\n",
656      "INFO:root:Epoch[0] Time cost=5.075\n",
657      "INFO:root:Epoch[1] Train-loss=211.923867\n",
658      "INFO:root:Epoch[1] Time cost=4.741\n",
659      "INFO:root:Epoch[2] Train-loss=206.789445\n",
660      "INFO:root:Epoch[2] Time cost=4.601\n",
661      "INFO:root:Epoch[3] Train-loss=204.428186\n",
662      "INFO:root:Epoch[3] Time cost=4.865\n",
663      "INFO:root:Epoch[4] Train-loss=202.417322\n",
664      "INFO:root:Epoch[4] Time cost=4.606\n",
665      "INFO:root:Epoch[5] Train-loss=200.635136\n",
666      "INFO:root:Epoch[5] Time cost=4.711\n",
667      "INFO:root:Epoch[6] Train-loss=199.009614\n",
668      "INFO:root:Epoch[6] Time cost=5.159\n",
669      "INFO:root:Epoch[7] Train-loss=197.565788\n",
670      "INFO:root:Epoch[7] Time cost=4.588\n",
671      "INFO:root:Epoch[8] Train-loss=196.524507\n",
672      "INFO:root:Epoch[8] Time cost=4.905\n",
673      "INFO:root:Epoch[9] Train-loss=195.725745\n",
674      "INFO:root:Epoch[9] Time cost=4.426\n",
675      "INFO:root:Epoch[10] Train-loss=194.902025\n",
676      "INFO:root:Epoch[10] Time cost=4.685\n",
677      "INFO:root:Epoch[11] Train-loss=194.026873\n",
678      "INFO:root:Epoch[11] Time cost=4.622\n",
679      "INFO:root:Epoch[12] Train-loss=193.350646\n",
680      "INFO:root:Epoch[12] Time cost=4.712\n",
681      "INFO:root:Epoch[13] Train-loss=192.737502\n",
682      "INFO:root:Epoch[13] Time cost=4.618\n",
683      "INFO:root:Epoch[14] Train-loss=192.338165\n",
684      "INFO:root:Epoch[14] Time cost=4.763\n",
685      "INFO:root:Epoch[15] Train-loss=191.888625\n",
686      "INFO:root:Epoch[15] Time cost=5.168\n",
687      "INFO:root:Epoch[16] Train-loss=191.170650\n",
688      "INFO:root:Epoch[16] Time cost=4.809\n",
689      "INFO:root:Epoch[17] Train-loss=190.307264\n",
690      "INFO:root:Epoch[17] Time cost=4.622\n",
691      "INFO:root:Epoch[18] Train-loss=188.988063\n",
692      "INFO:root:Epoch[18] Time cost=4.543\n",
693      "INFO:root:Epoch[19] Train-loss=187.616311\n",
694      "INFO:root:Epoch[19] Time cost=5.154\n",
695      "INFO:root:Epoch[20] Train-loss=186.352783\n",
696      "INFO:root:Epoch[20] Time cost=4.661\n",
697      "INFO:root:Epoch[21] Train-loss=185.428020\n",
698      "INFO:root:Epoch[21] Time cost=5.193\n",
699      "INFO:root:Epoch[22] Train-loss=184.543097\n",
700      "INFO:root:Epoch[22] Time cost=4.519\n",
701      "INFO:root:Epoch[23] Train-loss=184.029907\n",
702      "INFO:root:Epoch[23] Time cost=4.732\n",
703      "INFO:root:Epoch[24] Train-loss=183.643270\n",
704      "INFO:root:Epoch[24] Time cost=5.011\n",
705      "INFO:root:Epoch[25] Train-loss=183.246912\n",
706      "INFO:root:Epoch[25] Time cost=4.706\n",
707      "INFO:root:Epoch[26] Train-loss=183.065233\n",
708      "INFO:root:Epoch[26] Time cost=4.673\n",
709      "INFO:root:Epoch[27] Train-loss=182.680542\n",
710      "INFO:root:Epoch[27] Time cost=4.628\n",
711      "INFO:root:Epoch[28] Train-loss=182.428677\n",
712      "INFO:root:Epoch[28] Time cost=4.772\n",
713      "INFO:root:Epoch[29] Train-loss=182.219946\n",
714      "INFO:root:Epoch[29] Time cost=4.571\n",
715      "INFO:root:Epoch[30] Train-loss=182.070927\n",
716      "INFO:root:Epoch[30] Time cost=4.603\n",
717      "INFO:root:Epoch[31] Train-loss=181.837968\n",
718      "INFO:root:Epoch[31] Time cost=4.559\n",
719      "INFO:root:Epoch[32] Train-loss=181.624303\n",
720      "INFO:root:Epoch[32] Time cost=5.069\n",
721      "INFO:root:Epoch[33] Train-loss=181.534547\n",
722      "INFO:root:Epoch[33] Time cost=4.654\n",
723      "INFO:root:Epoch[34] Train-loss=181.239556\n",
724      "INFO:root:Epoch[34] Time cost=4.776\n",
725      "INFO:root:Epoch[35] Train-loss=181.098188\n",
726      "INFO:root:Epoch[35] Time cost=4.571\n",
727      "INFO:root:Epoch[36] Train-loss=180.820560\n",
728      "INFO:root:Epoch[36] Time cost=4.815\n",
729      "INFO:root:Epoch[37] Train-loss=180.828095\n",
730      "INFO:root:Epoch[37] Time cost=4.455\n",
731      "INFO:root:Epoch[38] Train-loss=180.495569\n",
732      "INFO:root:Epoch[38] Time cost=5.096\n",
733      "INFO:root:Epoch[39] Train-loss=180.389106\n",
734      "INFO:root:Epoch[39] Time cost=4.797\n",
735      "INFO:root:Epoch[40] Train-loss=180.200965\n",
736      "INFO:root:Epoch[40] Time cost=5.054\n",
737      "INFO:root:Epoch[41] Train-loss=179.851014\n",
738      "INFO:root:Epoch[41] Time cost=4.642\n",
739      "INFO:root:Epoch[42] Train-loss=179.719933\n",
740      "INFO:root:Epoch[42] Time cost=4.603\n",
741      "INFO:root:Epoch[43] Train-loss=179.431740\n",
742      "INFO:root:Epoch[43] Time cost=4.341\n",
743      "INFO:root:Epoch[44] Train-loss=179.235384\n",
744      "INFO:root:Epoch[44] Time cost=4.638\n",
745      "INFO:root:Epoch[45] Train-loss=179.108771\n",
746      "INFO:root:Epoch[45] Time cost=4.754\n",
747      "INFO:root:Epoch[46] Train-loss=178.714163\n",
748      "INFO:root:Epoch[46] Time cost=4.457\n",
749      "INFO:root:Epoch[47] Train-loss=178.508338\n",
750      "INFO:root:Epoch[47] Time cost=4.960\n",
751      "INFO:root:Epoch[48] Train-loss=178.288002\n",
752      "INFO:root:Epoch[48] Time cost=4.562\n",
753      "INFO:root:Epoch[49] Train-loss=178.083288\n",
754      "INFO:root:Epoch[49] Time cost=4.619\n",
755      "INFO:root:Epoch[50] Train-loss=177.791330\n",
756      "INFO:root:Epoch[50] Time cost=4.580\n",
757      "INFO:root:Epoch[51] Train-loss=177.570741\n",
758      "INFO:root:Epoch[51] Time cost=4.704\n",
759      "INFO:root:Epoch[52] Train-loss=177.287114\n",
760      "INFO:root:Epoch[52] Time cost=5.172\n",
761      "INFO:root:Epoch[53] Train-loss=177.122645\n",
762      "INFO:root:Epoch[53] Time cost=4.678\n",
763      "INFO:root:Epoch[54] Train-loss=176.816022\n",
764      "INFO:root:Epoch[54] Time cost=4.819\n",
765      "INFO:root:Epoch[55] Train-loss=176.670484\n",
766      "INFO:root:Epoch[55] Time cost=4.568\n",
767      "INFO:root:Epoch[56] Train-loss=176.459671\n",
768      "INFO:root:Epoch[56] Time cost=4.450\n",
769      "INFO:root:Epoch[57] Train-loss=176.174175\n",
770      "INFO:root:Epoch[57] Time cost=4.579\n",
771      "INFO:root:Epoch[58] Train-loss=175.935856\n",
772      "INFO:root:Epoch[58] Time cost=4.552\n",
773      "INFO:root:Epoch[59] Train-loss=175.739928\n",
774      "INFO:root:Epoch[59] Time cost=4.385\n",
775      "INFO:root:Epoch[60] Train-loss=175.579695\n",
776      "INFO:root:Epoch[60] Time cost=4.496\n",
777      "INFO:root:Epoch[61] Train-loss=175.403871\n",
778      "INFO:root:Epoch[61] Time cost=5.088\n",
779      "INFO:root:Epoch[62] Train-loss=175.157114\n",
780      "INFO:root:Epoch[62] Time cost=4.628\n",
781      "INFO:root:Epoch[63] Train-loss=174.953950\n",
782      "INFO:root:Epoch[63] Time cost=4.826\n",
783      "INFO:root:Epoch[64] Train-loss=174.743393\n",
784      "INFO:root:Epoch[64] Time cost=4.832\n",
785      "INFO:root:Epoch[65] Train-loss=174.554056\n",
786      "INFO:root:Epoch[65] Time cost=4.375\n",
787      "INFO:root:Epoch[66] Train-loss=174.366719\n",
788      "INFO:root:Epoch[66] Time cost=4.583\n",
789      "INFO:root:Epoch[67] Train-loss=174.160622\n",
790      "INFO:root:Epoch[67] Time cost=4.586\n",
791      "INFO:root:Epoch[68] Train-loss=173.981699\n",
792      "INFO:root:Epoch[68] Time cost=5.149\n",
793      "INFO:root:Epoch[69] Train-loss=173.751617\n",
794      "INFO:root:Epoch[69] Time cost=4.495\n",
795      "INFO:root:Epoch[70] Train-loss=173.548732\n",
796      "INFO:root:Epoch[70] Time cost=4.588\n",
797      "INFO:root:Epoch[71] Train-loss=173.380950\n",
798      "INFO:root:Epoch[71] Time cost=5.042\n",
799      "INFO:root:Epoch[72] Train-loss=173.158519\n",
800      "INFO:root:Epoch[72] Time cost=4.817\n",
801      "INFO:root:Epoch[73] Train-loss=172.970726\n",
802      "INFO:root:Epoch[73] Time cost=4.791\n",
803      "INFO:root:Epoch[74] Train-loss=172.782357\n",
804      "INFO:root:Epoch[74] Time cost=4.377\n",
805      "INFO:root:Epoch[75] Train-loss=172.581992\n",
806      "INFO:root:Epoch[75] Time cost=4.518\n",
807      "INFO:root:Epoch[76] Train-loss=172.385020\n",
808      "INFO:root:Epoch[76] Time cost=4.863\n",
809      "INFO:root:Epoch[77] Train-loss=172.198309\n",
810      "INFO:root:Epoch[77] Time cost=5.104\n",
811      "INFO:root:Epoch[78] Train-loss=172.022333\n",
812      "INFO:root:Epoch[78] Time cost=4.571\n",
813      "INFO:root:Epoch[79] Train-loss=171.816585\n",
814      "INFO:root:Epoch[79] Time cost=4.557\n",
815      "INFO:root:Epoch[80] Train-loss=171.643714\n",
816      "INFO:root:Epoch[80] Time cost=4.567\n",
817      "INFO:root:Epoch[81] Train-loss=171.460581\n",
818      "INFO:root:Epoch[81] Time cost=4.735\n",
819      "INFO:root:Epoch[82] Train-loss=171.284854\n",
820      "INFO:root:Epoch[82] Time cost=5.012\n",
821      "INFO:root:Epoch[83] Train-loss=171.113129\n",
822      "INFO:root:Epoch[83] Time cost=4.877\n",
823      "INFO:root:Epoch[84] Train-loss=170.947790\n",
824      "INFO:root:Epoch[84] Time cost=4.487\n",
825      "INFO:root:Epoch[85] Train-loss=170.766223\n",
826      "INFO:root:Epoch[85] Time cost=4.723\n",
827      "INFO:root:Epoch[86] Train-loss=170.602559\n",
828      "INFO:root:Epoch[86] Time cost=4.803\n",
829      "INFO:root:Epoch[87] Train-loss=170.448713\n",
830      "INFO:root:Epoch[87] Time cost=4.636\n",
831      "INFO:root:Epoch[88] Train-loss=170.273053\n",
832      "INFO:root:Epoch[88] Time cost=4.562\n",
833      "INFO:root:Epoch[89] Train-loss=170.099485\n",
834      "INFO:root:Epoch[89] Time cost=4.567\n",
835      "INFO:root:Epoch[90] Train-loss=169.934289\n",
836      "INFO:root:Epoch[90] Time cost=4.905\n",
837      "INFO:root:Epoch[91] Train-loss=169.768920\n",
838      "INFO:root:Epoch[91] Time cost=4.636\n",
839      "INFO:root:Epoch[92] Train-loss=169.620803\n",
840      "INFO:root:Epoch[92] Time cost=4.429\n",
841      "INFO:root:Epoch[93] Train-loss=169.448189\n",
842      "INFO:root:Epoch[93] Time cost=4.985\n",
843      "INFO:root:Epoch[94] Train-loss=169.295794\n",
844      "INFO:root:Epoch[94] Time cost=4.649\n",
845      "INFO:root:Epoch[95] Train-loss=169.143627\n",
846      "INFO:root:Epoch[95] Time cost=4.602\n",
847      "INFO:root:Epoch[96] Train-loss=168.989410\n",
848      "INFO:root:Epoch[96] Time cost=4.904\n",
849      "INFO:root:Epoch[97] Train-loss=168.841089\n",
850      "INFO:root:Epoch[97] Time cost=4.602\n",
851      "INFO:root:Epoch[98] Train-loss=168.694906\n",
852      "INFO:root:Epoch[98] Time cost=4.589\n",
853      "INFO:root:Epoch[99] Train-loss=168.527604\n",
854      "INFO:root:Epoch[99] Time cost=4.560\n",
855      "INFO:root:Epoch[100] Train-loss=168.385596\n",
856      "INFO:root:Epoch[100] Time cost=4.835\n",
857      "INFO:root:Epoch[101] Train-loss=168.246526\n",
858      "INFO:root:Epoch[101] Time cost=4.558\n",
859      "INFO:root:Epoch[102] Train-loss=168.093663\n",
860      "INFO:root:Epoch[102] Time cost=4.609\n",
861      "INFO:root:Epoch[103] Train-loss=167.938807\n",
862      "INFO:root:Epoch[103] Time cost=4.599\n",
863      "INFO:root:Epoch[104] Train-loss=167.814916\n",
864      "INFO:root:Epoch[104] Time cost=4.394\n",
865      "INFO:root:Epoch[105] Train-loss=167.676473\n"
866     ]
867    },
868    {
869     "name": "stderr",
870     "output_type": "stream",
871     "text": [
872      "INFO:root:Epoch[105] Time cost=4.724\n",
873      "INFO:root:Epoch[106] Train-loss=167.560241\n",
874      "INFO:root:Epoch[106] Time cost=4.316\n",
875      "INFO:root:Epoch[107] Train-loss=167.424132\n",
876      "INFO:root:Epoch[107] Time cost=4.646\n",
877      "INFO:root:Epoch[108] Train-loss=167.284482\n",
878      "INFO:root:Epoch[108] Time cost=4.472\n",
879      "INFO:root:Epoch[109] Train-loss=167.184511\n",
880      "INFO:root:Epoch[109] Time cost=4.768\n",
881      "INFO:root:Epoch[110] Train-loss=167.037793\n",
882      "INFO:root:Epoch[110] Time cost=4.717\n",
883      "INFO:root:Epoch[111] Train-loss=166.916652\n",
884      "INFO:root:Epoch[111] Time cost=4.803\n",
885      "INFO:root:Epoch[112] Train-loss=166.796803\n",
886      "INFO:root:Epoch[112] Time cost=4.617\n",
887      "INFO:root:Epoch[113] Train-loss=166.655028\n",
888      "INFO:root:Epoch[113] Time cost=4.420\n",
889      "INFO:root:Epoch[114] Train-loss=166.561129\n",
890      "INFO:root:Epoch[114] Time cost=4.333\n",
891      "INFO:root:Epoch[115] Train-loss=166.434593\n",
892      "INFO:root:Epoch[115] Time cost=4.526\n",
893      "INFO:root:Epoch[116] Train-loss=166.322805\n",
894      "INFO:root:Epoch[116] Time cost=4.310\n",
895      "INFO:root:Epoch[117] Train-loss=166.195452\n",
896      "INFO:root:Epoch[117] Time cost=4.458\n",
897      "INFO:root:Epoch[118] Train-loss=166.073792\n",
898      "INFO:root:Epoch[118] Time cost=4.333\n",
899      "INFO:root:Epoch[119] Train-loss=165.967437\n",
900      "INFO:root:Epoch[119] Time cost=4.459\n",
901      "INFO:root:Epoch[120] Train-loss=165.876094\n",
902      "INFO:root:Epoch[120] Time cost=5.070\n",
903      "INFO:root:Epoch[121] Train-loss=165.748064\n",
904      "INFO:root:Epoch[121] Time cost=4.782\n",
905      "INFO:root:Epoch[122] Train-loss=165.656283\n",
906      "INFO:root:Epoch[122] Time cost=4.640\n",
907      "INFO:root:Epoch[123] Train-loss=165.540462\n",
908      "INFO:root:Epoch[123] Time cost=4.522\n",
909      "INFO:root:Epoch[124] Train-loss=165.448734\n",
910      "INFO:root:Epoch[124] Time cost=4.858\n",
911      "INFO:root:Epoch[125] Train-loss=165.347751\n",
912      "INFO:root:Epoch[125] Time cost=4.842\n",
913      "INFO:root:Epoch[126] Train-loss=165.230048\n",
914      "INFO:root:Epoch[126] Time cost=4.495\n",
915      "INFO:root:Epoch[127] Train-loss=165.147932\n",
916      "INFO:root:Epoch[127] Time cost=4.766\n",
917      "INFO:root:Epoch[128] Train-loss=165.036021\n",
918      "INFO:root:Epoch[128] Time cost=4.526\n",
919      "INFO:root:Epoch[129] Train-loss=164.977613\n",
920      "INFO:root:Epoch[129] Time cost=5.091\n",
921      "INFO:root:Epoch[130] Train-loss=164.881467\n",
922      "INFO:root:Epoch[130] Time cost=5.223\n",
923      "INFO:root:Epoch[131] Train-loss=164.785627\n",
924      "INFO:root:Epoch[131] Time cost=4.165\n",
925      "INFO:root:Epoch[132] Train-loss=164.707629\n",
926      "INFO:root:Epoch[132] Time cost=4.527\n",
927      "INFO:root:Epoch[133] Train-loss=164.598039\n",
928      "INFO:root:Epoch[133] Time cost=4.167\n",
929      "INFO:root:Epoch[134] Train-loss=164.502932\n",
930      "INFO:root:Epoch[134] Time cost=4.354\n",
931      "INFO:root:Epoch[135] Train-loss=164.422286\n",
932      "INFO:root:Epoch[135] Time cost=4.387\n",
933      "INFO:root:Epoch[136] Train-loss=164.344749\n",
934      "INFO:root:Epoch[136] Time cost=4.662\n",
935      "INFO:root:Epoch[137] Train-loss=164.264898\n",
936      "INFO:root:Epoch[137] Time cost=4.671\n",
937      "INFO:root:Epoch[138] Train-loss=164.178707\n",
938      "INFO:root:Epoch[138] Time cost=4.776\n",
939      "INFO:root:Epoch[139] Train-loss=164.109071\n",
940      "INFO:root:Epoch[139] Time cost=4.787\n",
941      "INFO:root:Epoch[140] Train-loss=163.993291\n",
942      "INFO:root:Epoch[140] Time cost=4.726\n",
943      "INFO:root:Epoch[141] Train-loss=163.956234\n",
944      "INFO:root:Epoch[141] Time cost=4.337\n",
945      "INFO:root:Epoch[142] Train-loss=163.845638\n",
946      "INFO:root:Epoch[142] Time cost=4.787\n",
947      "INFO:root:Epoch[143] Train-loss=163.790882\n",
948      "INFO:root:Epoch[143] Time cost=5.563\n",
949      "INFO:root:Epoch[144] Train-loss=163.723495\n",
950      "INFO:root:Epoch[144] Time cost=4.529\n",
951      "INFO:root:Epoch[145] Train-loss=163.634262\n",
952      "INFO:root:Epoch[145] Time cost=5.028\n",
953      "INFO:root:Epoch[146] Train-loss=163.552854\n",
954      "INFO:root:Epoch[146] Time cost=4.933\n",
955      "INFO:root:Epoch[147] Train-loss=163.501429\n",
956      "INFO:root:Epoch[147] Time cost=4.912\n",
957      "INFO:root:Epoch[148] Train-loss=163.444245\n",
958      "INFO:root:Epoch[148] Time cost=5.034\n",
959      "INFO:root:Epoch[149] Train-loss=163.348476\n",
960      "INFO:root:Epoch[149] Time cost=4.600\n",
961      "INFO:root:Epoch[150] Train-loss=163.256955\n",
962      "INFO:root:Epoch[150] Time cost=4.704\n",
963      "INFO:root:Epoch[151] Train-loss=163.216139\n",
964      "INFO:root:Epoch[151] Time cost=4.670\n",
965      "INFO:root:Epoch[152] Train-loss=163.144691\n",
966      "INFO:root:Epoch[152] Time cost=4.678\n",
967      "INFO:root:Epoch[153] Train-loss=163.050236\n",
968      "INFO:root:Epoch[153] Time cost=4.595\n",
969      "INFO:root:Epoch[154] Train-loss=162.991225\n",
970      "INFO:root:Epoch[154] Time cost=5.307\n",
971      "INFO:root:Epoch[155] Train-loss=162.907200\n",
972      "INFO:root:Epoch[155] Time cost=4.684\n",
973      "INFO:root:Epoch[156] Train-loss=162.838075\n",
974      "INFO:root:Epoch[156] Time cost=4.686\n",
975      "INFO:root:Epoch[157] Train-loss=162.759286\n",
976      "INFO:root:Epoch[157] Time cost=4.750\n",
977      "INFO:root:Epoch[158] Train-loss=162.725998\n",
978      "INFO:root:Epoch[158] Time cost=4.637\n",
979      "INFO:root:Epoch[159] Train-loss=162.635852\n",
980      "INFO:root:Epoch[159] Time cost=4.498\n",
981      "INFO:root:Epoch[160] Train-loss=162.563777\n",
982      "INFO:root:Epoch[160] Time cost=5.048\n",
983      "INFO:root:Epoch[161] Train-loss=162.527387\n",
984      "INFO:root:Epoch[161] Time cost=5.040\n",
985      "INFO:root:Epoch[162] Train-loss=162.395881\n",
986      "INFO:root:Epoch[162] Time cost=4.764\n",
987      "INFO:root:Epoch[163] Train-loss=162.353654\n",
988      "INFO:root:Epoch[163] Time cost=4.561\n",
989      "INFO:root:Epoch[164] Train-loss=162.285584\n",
990      "INFO:root:Epoch[164] Time cost=5.051\n",
991      "INFO:root:Epoch[165] Train-loss=162.204332\n",
992      "INFO:root:Epoch[165] Time cost=4.455\n",
993      "INFO:root:Epoch[166] Train-loss=162.147100\n",
994      "INFO:root:Epoch[166] Time cost=5.021\n",
995      "INFO:root:Epoch[167] Train-loss=162.051296\n",
996      "INFO:root:Epoch[167] Time cost=4.551\n",
997      "INFO:root:Epoch[168] Train-loss=161.978708\n",
998      "INFO:root:Epoch[168] Time cost=4.744\n",
999      "INFO:root:Epoch[169] Train-loss=161.927990\n",
1000      "INFO:root:Epoch[169] Time cost=4.821\n",
1001      "INFO:root:Epoch[170] Train-loss=161.883088\n",
1002      "INFO:root:Epoch[170] Time cost=4.365\n",
1003      "INFO:root:Epoch[171] Train-loss=161.785367\n",
1004      "INFO:root:Epoch[171] Time cost=4.448\n",
1005      "INFO:root:Epoch[172] Train-loss=161.716386\n",
1006      "INFO:root:Epoch[172] Time cost=4.622\n",
1007      "INFO:root:Epoch[173] Train-loss=161.656391\n",
1008      "INFO:root:Epoch[173] Time cost=4.500\n",
1009      "INFO:root:Epoch[174] Train-loss=161.598127\n",
1010      "INFO:root:Epoch[174] Time cost=4.677\n",
1011      "INFO:root:Epoch[175] Train-loss=161.518613\n",
1012      "INFO:root:Epoch[175] Time cost=4.958\n",
1013      "INFO:root:Epoch[176] Train-loss=161.418783\n",
1014      "INFO:root:Epoch[176] Time cost=4.607\n",
1015      "INFO:root:Epoch[177] Train-loss=161.407767\n",
1016      "INFO:root:Epoch[177] Time cost=4.427\n",
1017      "INFO:root:Epoch[178] Train-loss=161.319552\n",
1018      "INFO:root:Epoch[178] Time cost=4.930\n",
1019      "INFO:root:Epoch[179] Train-loss=161.234087\n",
1020      "INFO:root:Epoch[179] Time cost=4.240\n",
1021      "INFO:root:Epoch[180] Train-loss=161.187404\n",
1022      "INFO:root:Epoch[180] Time cost=4.484\n",
1023      "INFO:root:Epoch[181] Train-loss=161.123118\n",
1024      "INFO:root:Epoch[181] Time cost=4.937\n",
1025      "INFO:root:Epoch[182] Train-loss=160.999420\n",
1026      "INFO:root:Epoch[182] Time cost=4.489\n",
1027      "INFO:root:Epoch[183] Train-loss=160.955369\n",
1028      "INFO:root:Epoch[183] Time cost=4.894\n",
1029      "INFO:root:Epoch[184] Train-loss=160.908542\n",
1030      "INFO:root:Epoch[184] Time cost=4.269\n",
1031      "INFO:root:Epoch[185] Train-loss=160.846908\n",
1032      "INFO:root:Epoch[185] Time cost=4.998\n",
1033      "INFO:root:Epoch[186] Train-loss=160.765964\n",
1034      "INFO:root:Epoch[186] Time cost=4.467\n",
1035      "INFO:root:Epoch[187] Train-loss=160.687773\n",
1036      "INFO:root:Epoch[187] Time cost=4.609\n",
1037      "INFO:root:Epoch[188] Train-loss=160.652674\n",
1038      "INFO:root:Epoch[188] Time cost=5.327\n",
1039      "INFO:root:Epoch[189] Train-loss=160.551175\n",
1040      "INFO:root:Epoch[189] Time cost=4.267\n",
1041      "INFO:root:Epoch[190] Train-loss=160.477424\n",
1042      "INFO:root:Epoch[190] Time cost=4.798\n",
1043      "INFO:root:Epoch[191] Train-loss=160.501221\n",
1044      "INFO:root:Epoch[191] Time cost=4.695\n",
1045      "INFO:root:Epoch[192] Train-loss=160.370335\n",
1046      "INFO:root:Epoch[192] Time cost=4.640\n",
1047      "INFO:root:Epoch[193] Train-loss=160.279749\n",
1048      "INFO:root:Epoch[193] Time cost=4.653\n",
1049      "INFO:root:Epoch[194] Train-loss=160.242415\n",
1050      "INFO:root:Epoch[194] Time cost=5.044\n",
1051      "INFO:root:Epoch[195] Train-loss=160.197063\n",
1052      "INFO:root:Epoch[195] Time cost=4.684\n",
1053      "INFO:root:Epoch[196] Train-loss=160.132983\n",
1054      "INFO:root:Epoch[196] Time cost=4.460\n",
1055      "INFO:root:Epoch[197] Train-loss=160.083149\n",
1056      "INFO:root:Epoch[197] Time cost=4.713\n",
1057      "INFO:root:Epoch[198] Train-loss=160.025012\n",
1058      "INFO:root:Epoch[198] Time cost=4.779\n",
1059      "INFO:root:Epoch[199] Train-loss=159.945513\n",
1060      "INFO:root:Epoch[199] Time cost=4.659\n"
1061     ]
1062    }
1063   ],
1064   "source": [
1065    "# the weights and biases can be initialized with previously learned parameters as follows:\n",
1066    "# init = mx.initializer.Load(params)\n",
1067    "\n",
1068    "# train the VAE; the output contains the learned model and the training loss\n",
1069    "out = VAE(n_latent=2, x_train=image, x_valid=None, num_epoch=200)"
1070   ]
1071  },
1072  {
1073   "cell_type": "code",
1074   "execution_count": 16,
1075   "metadata": {},
1076   "outputs": [],
1077   "source": [
1078    "# encode test images to obtain mu and logvar which are used for sampling\n",
1079    "[mu,logvar] = VAE.encoder(out,image_test)\n",
1080    "# sample in the latent space\n",
1081    "z = VAE.sampler(mu,logvar)\n",
1082    "# decode from the latent space to obtain reconstructed images\n",
1083    "x_construction = VAE.decoder(out,z)\n"
1084   ]
1085  },
1086  {
1087   "cell_type": "code",
1088   "execution_count": 17,
1089   "metadata": {},
1090   "outputs": [
1091    {
1092     "data": {
1093      "image/png": "\n",
1094      "text/plain": [
1095       "<Figure size 864x216 with 4 Axes>"
1096      ]
1097     },
1098     "metadata": {},
1099     "output_type": "display_data"
1100    }
1101   ],
1102   "source": [
1103    "f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))\n",
1104    "ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
1105    "ax1.set_title('True image')\n",
1106    "ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
1107    "ax2.set_title('Learned image')\n",
1108    "ax3.imshow(np.reshape(image_test[146,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
1109    "ax3.set_title('True image')\n",
1110    "ax4.imshow(np.reshape(x_construction[146,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
1111    "ax4.set_title('Learned image')\n",
1112    "plt.show()"
1113   ]
1114  },
1115  {
1116   "cell_type": "code",
1117   "execution_count": 18,
1118   "metadata": {},
1119   "outputs": [
1127    {
1128     "data": {
1129      "image/png": "\n",
1130      "text/plain": [
1131       "<Figure size 432x288 with 1 Axes>"
1132      ]
1133     },
1134     "metadata": {},
1135     "output_type": "display_data"
1136    },
1137    {
1138     "data": {
1139      "image/png": "\n",
1140      "text/plain": [
1141       "<Figure size 864x180 with 6 Axes>"
1142      ]
1143     },
1144     "metadata": {},
1145     "output_type": "display_data"
1146    }
1147   ],
1148   "source": [
1149    "z1 = z[:,0]\n",
1150    "z2 = z[:,1]\n",
1151    "\n",
1152    "fig = plt.figure()\n",
1153    "ax = fig.add_subplot(111)\n",
1154    "ax.plot(z1,z2,'ko')\n",
1155    "plt.title(\"latent space\")\n",
1156    "\n",
1157    "#np.where((z1>3) & (z2<2) & (z2>0))\n",
1158    "# select points from the latent space to highlight\n",
1159    "a_vec = [2,5,7,789,25,9993]\n",
1160    "for i in range(len(a_vec)):\n",
1161    "    ax.plot(z1[a_vec[i]],z2[a_vec[i]],'ro')  \n",
1162    "    ax.annotate('z%d' %i, xy=(z1[a_vec[i]],z2[a_vec[i]]), \n",
1163    "                xytext=(z1[a_vec[i]],z2[a_vec[i]]),color = 'r',fontsize=15)\n",
1164    "\n",
1165    "\n",
1166    "f, axes = plt.subplots(1, 6, sharex='col', sharey='row', figsize=(12,2.5))\n",
1167    "for i in range(len(a_vec)):\n",
1168    "    axes[i].imshow(np.reshape(x_construction[a_vec[i],:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
1169    "    axes[i].set_title('z%d'%i)\n",
1170    "\n",
1171    "plt.show()"
1172   ]
1173  },
1174  {
1175   "cell_type": "markdown",
1176   "metadata": {},
1177   "source": [
1178    "The plots above show points in the 2D latent space and their corresponding decoded images. Points that are close together in the latent space tend to be decoded to the same digit, and the decoded digit changes gradually as we move from left to right across the latent space."
1179   ]
1180  }
1181 ],
1182 "metadata": {
1183  "anaconda-cloud": {},
1184  "kernelspec": {
1185   "display_name": "Python 3",
1186   "language": "python",
1187   "name": "python3"
1188  },
1189  "language_info": {
1190   "codemirror_mode": {
1191    "name": "ipython",
1192    "version": 3
1193   },
1194   "file_extension": ".py",
1195   "mimetype": "text/x-python",
1196   "name": "python",
1197   "nbconvert_exporter": "python",
1198   "pygments_lexer": "ipython3",
1199   "version": "3.5.2"
1200  }
1201 },
1202 "nbformat": 4,
1203 "nbformat_minor": 2
1204}
1205