# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import random as pyrnd
import argparse
import numpy as np
import mxnet as mx
from matplotlib import pyplot as plt
import binary_rbm


### Command line arguments

parser = argparse.ArgumentParser(description='Restricted Boltzmann machine learning on MNIST')
parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
parser.add_argument('--batch-size', type=int, default=80, help='batch size')
parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this by `1 / batch_size`
parser.add_argument('--momentum', type=float, default=0.3, help='momentum for stochastic gradient descent')
parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn-in steps for each intermediate distribution of AIS to estimate the log-likelihood')
parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
parser.set_defaults(cuda=True)

args = parser.parse_args()
print(args)
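
# Example invocation (illustrative only; "binary_rbm_module.py" is a placeholder for
# whatever name this script is saved under):
#   python binary_rbm_module.py --num-hidden 500 --k 30 --batch-size 80 --cuda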

### Global environment

mx.random.seed(pyrnd.getrandbits(32))
ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()

### Prepare data

mnist = mx.test_utils.get_mnist() # Each pixel has a value in [0, 1].
mnist_train_data = mnist['train_data']
mnist_test_data = mnist['test_data']
img_height = mnist_train_data.shape[2]
img_width = mnist_train_data.shape[3]
num_visible = img_width * img_height

# The iterators generate arrays with shape (batch_size, num_channel = 1, height = 28, width = 28)
train_iter = mx.io.NDArrayIter(
    data={'data': mnist_train_data},
    batch_size=args.batch_size,
    shuffle=True)
test_iter = mx.io.NDArrayIter(
    data={'data': mnist_test_data},
    batch_size=args.batch_size,
    shuffle=True)


### Define symbols

data = mx.sym.Variable('data') # (batch_size, num_channel = 1, height, width)
flattened_data = mx.sym.flatten(data=data) # (batch_size, num_channel * height * width)
visible_layer_bias = mx.sym.Variable('visible_layer_bias', init=mx.init.Normal(sigma=.01))
hidden_layer_bias = mx.sym.Variable('hidden_layer_bias', init=mx.init.Normal(sigma=.01))
interaction_weight = mx.sym.Variable('interaction_weight', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_sample = mx.sym.Variable('aux_hidden_layer_sample', init=mx.init.Normal(sigma=.01))
aux_hidden_layer_prob_1 = mx.sym.Variable('aux_hidden_layer_prob_1', init=mx.init.Constant(0))
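
# For reference, a binary RBM alternates between the conditionals
# p(h = 1 | v) = sigmoid(hidden_layer_bias + v W) and p(v = 1 | h) = sigmoid(visible_layer_bias + W h).
# The sampling itself is implemented by the custom BinaryRBM operator in binary_rbm.py; the NumPy
# sketch below is purely illustrative, assumes a weight matrix of shape (num_visible, num_hidden),
# and is not used anywhere in this script.
def _gibbs_step_reference(v, weight, visible_bias, hidden_bias, rng=np.random):
    """One Gibbs step of a binary RBM (reference sketch only)."""
    p_h = 1.0 / (1.0 + np.exp(-(hidden_bias + v.dot(weight))))        # p(h = 1 | v)
    h = (rng.uniform(size=p_h.shape) < p_h).astype(v.dtype)           # sample hidden units
    p_v = 1.0 / (1.0 + np.exp(-(visible_bias + h.dot(weight.T))))     # p(v = 1 | h)
    v_new = (rng.uniform(size=p_v.shape) < p_v).astype(v.dtype)       # sample visible units
    return v_new, p_v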


### Train

rbm = mx.sym.Custom(
    flattened_data,
    visible_layer_bias,
    hidden_layer_bias,
    interaction_weight,
    aux_hidden_layer_sample,
    aux_hidden_layer_prob_1,
    num_hidden=args.num_hidden,
    k=args.k,
    for_training=True,
    op_type='BinaryRBM',
    name='rbm')
model = mx.mod.Module(symbol=rbm, context=ctx, data_names=['data'], label_names=None)
model.bind(data_shapes=train_iter.provide_data)
model.init_params()
model.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': args.learning_rate, 'momentum': args.momentum})
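
# The custom operator trains with persistent contrastive divergence (PCD-k): the log-likelihood
# gradient of an RBM is <v h^T>_data - <v h^T>_model, and the intractable model expectation is
# approximated by advancing a persistent Gibbs chain k steps per update (presumably the role of
# the aux_hidden_layer_* states above) instead of restarting the chain from the data each time.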

for epoch in range(args.num_epoch):
    # Update parameters
    train_iter.reset()
    for batch in train_iter:
        model.forward(batch)
        model.backward()
        model.update()
    mx.nd.waitall()

    # Monitor the performance of the model
    params = model.get_params()[0]
    param_visible_layer_bias = params['visible_layer_bias'].as_in_context(ctx)
    param_hidden_layer_bias = params['hidden_layer_bias'].as_in_context(ctx)
    param_interaction_weight = params['interaction_weight'].as_in_context(ctx)
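    # estimate_log_likelihood uses annealed importance sampling (AIS) to approximate the
    # intractable partition function, so the reported log-likelihoods are estimates whose
    # accuracy depends on the --ais-* settings above.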
    test_iter.reset()
    test_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_iter, ctx)
    train_iter.reset()
    train_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_iter, ctx)
    print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))

### Show some samples.

# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
# Starting from the real data is just for convenience of implementation;
# after this many steps the resulting samples are essentially uncorrelated with the initial states.
# You could equally start from random states and run the Gibbs chain for a sufficiently long time.

print("Preparing showcase")

showcase_gibbs_sampling_steps = 3000
showcase_num_samples_w = 15
showcase_num_samples_h = 15
showcase_num_samples = showcase_num_samples_w * showcase_num_samples_h
showcase_img_shape = (showcase_num_samples_h * img_height, 2 * showcase_num_samples_w * img_width)
showcase_img_column_shape = (showcase_num_samples_h * img_height, img_width)
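
# The showcase image has two halves separated by a vertical line: the left half shows the real
# digits used to initialize the Gibbs chains, and the right half shows the corresponding samples
# produced by the showcase RBM (plotted as pixel-wise probabilities rather than binary values).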

params = model.get_params()[0] # We don't need aux states here
showcase_rbm = mx.sym.Custom(
    flattened_data,
    visible_layer_bias,
    hidden_layer_bias,
    interaction_weight,
    num_hidden=args.num_hidden,
    k=showcase_gibbs_sampling_steps,
    for_training=False,
    op_type='BinaryRBM',
    name='showcase_rbm')
showcase_iter = mx.io.NDArrayIter(
    data={'data': mnist['train_data']},
    batch_size=showcase_num_samples_h,
    shuffle=True)
showcase_model = mx.mod.Module(symbol=showcase_rbm, context=ctx, data_names=['data'], label_names=None)
showcase_model.bind(data_shapes=showcase_iter.provide_data, for_training=False)
showcase_model.set_params(params, aux_params=None)
showcase_img = np.zeros(showcase_img_shape)
for sample_batch, i, data_batch in showcase_model.iter_predict(eval_data=showcase_iter, num_batch=showcase_num_samples_w):
    # Each pixel is the probability that the unit is 1.
    showcase_img[:, i * img_width : (i + 1) * img_width] = data_batch.data[0].reshape(showcase_img_column_shape).asnumpy()
    showcase_img[:, (showcase_num_samples_w + i) * img_width : (showcase_num_samples_w + i + 1) * img_width
                ] = sample_batch[0].reshape(showcase_img_column_shape).asnumpy()
plt.imshow(showcase_img, cmap='gray')
plt.axis('off')
plt.axvline(showcase_num_samples_w * img_width, color='y')
plt.show()

print("Done")