# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import random as pyrnd
import argparse
import numpy as np
import mxnet as mx
from matplotlib import pyplot as plt
import binary_rbm


### Command line arguments

parser = argparse.ArgumentParser(description='Train a restricted Boltzmann machine on MNIST')
parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
parser.add_argument('--batch-size', type=int, default=80, help='batch size')
parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent')  # The optimizer rescales this with `1 / batch_size`
parser.add_argument('--momentum', type=float, default=0.3, help='momentum for stochastic gradient descent')
parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn-in steps for each intermediate distribution of AIS to estimate the log-likelihood')
parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
parser.set_defaults(cuda=True)

args = parser.parse_args()
print(args)

### Global environment

mx.random.seed(pyrnd.getrandbits(32))
ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()

### Prepare data

mnist = mx.test_utils.get_mnist()  # Each pixel has a value in [0, 1].
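# Added sanity check (a sketch, not part of the original example): the binary RBM
# treats each pixel value in [0, 1] as the probability that the corresponding
# visible unit is 1, so we assert the expected range before training.
assert mnist['train_data'].min() >= 0.0 and mnist['train_data'].max() <= 1.0
assert mnist['test_data'].min() >= 0.0 and mnist['test_data'].max() <= 1.0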
mnist_train_data = mnist['train_data']
mnist_test_data = mnist['test_data']
img_height = mnist_train_data.shape[2]
img_width = mnist_train_data.shape[3]
num_visible = img_width * img_height

# The iterators generate arrays with shape (batch_size, num_channel = 1, height = 28, width = 28).
train_iter = mx.io.NDArrayIter(
    data={'data': mnist_train_data},
    batch_size=args.batch_size,
    shuffle=True)
test_iter = mx.io.NDArrayIter(
    data={'data': mnist_test_data},
    batch_size=args.batch_size,
    shuffle=True)


### Define symbols

data = mx.sym.Variable('data')  # (batch_size, num_channel = 1, height, width)
flattened_data = mx.sym.flatten(data=data)  # (batch_size, num_channel * height * width)
visible_layer_bias = mx.sym.Variable('visible_layer_bias', init=mx.init.Normal(sigma=.01))
hidden_layer_bias = mx.sym.Variable('hidden_layer_bias', init=mx.init.Normal(sigma=.01))
interaction_weight = mx.sym.Variable('interaction_weight', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_sample = mx.sym.Variable('aux_hidden_layer_sample', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_prob_1 = mx.sym.Variable('aux_hidden_layer_prob_1', init=mx.init.Constant(0))


### Train

# The custom BinaryRBM operator trains with persistent contrastive divergence (PCD)
# using k Gibbs sampling steps; the two aux variables carry the persistent chain state.
rbm = mx.sym.Custom(
    flattened_data,
    visible_layer_bias,
    hidden_layer_bias,
    interaction_weight,
    aux_hidden_layer_sample,
    aux_hidden_layer_prob_1,
    num_hidden=args.num_hidden,
    k=args.k,
    for_training=True,
    op_type='BinaryRBM',
    name='rbm')
model = mx.mod.Module(symbol=rbm, context=ctx, data_names=['data'], label_names=None)
model.bind(data_shapes=train_iter.provide_data)
model.init_params()
model.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': args.learning_rate, 'momentum': args.momentum})

for epoch in range(args.num_epoch):
    # Update parameters
    train_iter.reset()
    for batch in train_iter:
        model.forward(batch)
        model.backward()
        model.update()
    mx.nd.waitall()

    # Monitor the performance of the model by estimating the log-likelihood
    # on both sets with annealed importance sampling (AIS).
    params = model.get_params()[0]
    param_visible_layer_bias = params['visible_layer_bias'].as_in_context(ctx)
    param_hidden_layer_bias = params['hidden_layer_bias'].as_in_context(ctx)
    param_interaction_weight = params['interaction_weight'].as_in_context(ctx)
    test_iter.reset()
    test_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_iter, ctx)
    train_iter.reset()
    train_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_iter, ctx)
    print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))

### Show some samples.

# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
# Starting from real data is just for convenience of implementation: after this many
# steps, the samples should be essentially uncorrelated with their initial states.
# You could equally start from random states and run the Gibbs chain for a
# sufficiently long time, as the sketch below illustrates.
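# A minimal sketch (not part of the original example) of the random-start
# alternative described above. It assumes the BinaryRBM op accepts any
# (batch, 1, height, width) input in [0, 1]; `n` is a hypothetical batch size.
# Feeding `random_iter` to the showcase model below in place of MNIST data
# should produce samples of similar quality once the chain is long enough:
#
#   n = 15
#   random_start = mx.nd.random.uniform(shape=(n, 1, img_height, img_width))
#   random_iter = mx.io.NDArrayIter(data={'data': random_start}, batch_size=n)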

print("Preparing showcase")

showcase_gibbs_sampling_steps = 3000
showcase_num_samples_w = 15
showcase_num_samples_h = 15
showcase_num_samples = showcase_num_samples_w * showcase_num_samples_h
showcase_img_shape = (showcase_num_samples_h * img_height, 2 * showcase_num_samples_w * img_width)
showcase_img_column_shape = (showcase_num_samples_h * img_height, img_width)

params = model.get_params()[0]  # We do not need the aux states here.
showcase_rbm = mx.sym.Custom(
    flattened_data,
    visible_layer_bias,
    hidden_layer_bias,
    interaction_weight,
    num_hidden=args.num_hidden,
    k=showcase_gibbs_sampling_steps,
    for_training=False,
    op_type='BinaryRBM',
    name='showcase_rbm')
showcase_iter = mx.io.NDArrayIter(
    data={'data': mnist['train_data']},
    batch_size=showcase_num_samples_h,
    shuffle=True)
showcase_model = mx.mod.Module(symbol=showcase_rbm, context=ctx, data_names=['data'], label_names=None)
showcase_model.bind(data_shapes=showcase_iter.provide_data, for_training=False)
showcase_model.set_params(params, aux_params=None)
showcase_img = np.zeros(showcase_img_shape)
# Real data fills the left half of the image; the generated samples fill the right half.
for sample_batch, i, data_batch in showcase_model.iter_predict(eval_data=showcase_iter, num_batch=showcase_num_samples_w):
    # Each pixel is the probability that the unit is 1.
    showcase_img[:, i * img_width : (i + 1) * img_width] = data_batch.data[0].reshape(showcase_img_column_shape).asnumpy()
    showcase_img[:, (showcase_num_samples_w + i) * img_width : (showcase_num_samples_w + i + 1) * img_width] = \
        sample_batch[0].reshape(showcase_img_column_shape).asnumpy()
plt.imshow(showcase_img, cmap='gray')
plt.axis('off')
plt.axvline(showcase_num_samples_w * img_width, color='y')
plt.show()

print("Done")
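# Example invocation (a sketch; the script file name is illustrative, and the
# flags are those defined by the argument parser above):
#
#   python binary_rbm_script.py --num-hidden 500 --k 30 --num-epoch 130 --no-cuda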