1__author__ = ('Christian Osendorfer, osendorf@in.tum.de;'
2              'Justin S Bayer, bayerj@in.tum.de'
3              'SUN Yi, yi@idsia.ch')
4
5from scipy import random, outer, zeros, ones
6
7from pybrain.datasets import SupervisedDataSet, UnsupervisedDataSet
8from pybrain.supervised.trainers import Trainer
9from pybrain.utilities import abstractMethod
10
11
12class RbmGibbsTrainerConfig:
13    def __init__(self):
14        self.batchSize = 10		# how many samples in a batch
15
16        # training rate
17        self.rWeights = 0.1
18        self.rHidBias = 0.1
19        self.rVisBias = 0.1
20
21        # Several configurations, I have no idea why they are here...
22        self.weightCost = 0.0002
23
24        self.iniMm = 0.5		# initial momentum
25        self.finMm = 0.9		# final momentum
26        self.mmSwitchIter = 5	# at which iteration we switch the momentum
27        self.maxIter = 9		# how many iterations
28
29        self.visibleDistribution = 'bernoulli'
30
31
32class RbmGibbsTrainer(Trainer):
33    """Class for training rbms with contrastive divergence."""
34
35    def __init__(self, rbm, dataset, cfg=None):
36        self.rbm = rbm
37        self.invRbm = rbm.invert()
38        self.dataset = dataset
39        self.cfg = RbmGibbsTrainerConfig() if cfg is None else cfg
40
41        if isinstance(self.dataset, SupervisedDataSet):
42            self.datasetField = 'input'
43        elif isinstance(self.dataset, UnsupervisedDataSet):
44            self.datasetField = 'sample'
45
46    def train(self):
47        self.trainOnDataset(self.dataset)
48
49    def trainOnDataset(self, dataset):
50        """This function trains the RBM using the same algorithm and
51        implementation presented in:
52        http://www.cs.toronto.edu/~hinton/MatlabForSciencePaper.html"""
53        cfg = self.cfg
54        for rows in dataset.randomBatches(self.datasetField, cfg.batchSize):
55            olduw, olduhb, olduvb = \
56                zeros((self.rbm.visibleDim, self.rbm.hiddenDim)), \
57                zeros(self.rbm.hiddenDim), zeros(self.rbm.visibleDim)
58
59            for t in range(cfg.maxIter):
60                #print("*** Iteration %2d **************************************" % t)
61
62                params = self.rbm.params
63                params = params.reshape((self.rbm.visibleDim, self.rbm.hiddenDim))
64                biasParams = self.rbm.biasParams
65
66                mm = cfg.iniMm if t < cfg.mmSwitchIter else cfg.finMm
67
68                w, hb, vb = self.calcUpdateByRows(rows)
69
70                #print("Delta: ")
71                #print("Weight: ",)
72                #print(w)
73                #print("Visible bias: ",)
74                #print(vb)
75                #print("Hidden bias: ",)
76                #print(hb)
77                #print("")
78
79                olduw = uw = olduw * mm + \
80                	cfg.rWeights * (w - cfg.weightCost * params)
81                olduhb = uhb = olduhb * mm + cfg.rHidBias * hb
82                olduvb = uvb = olduvb * mm + cfg.rVisBias * vb
83
84                #print("Delta after momentum: ")
85                #print("Weight: ",)
86                #print(uw)
87                #print("Visible bias: ",)
88                #print(uvb)
89                #print("Hidden bias: ",)
90                #print(uhb)
91                #print("")
92
93                # update the parameters of the original rbm
94                params += uw
95                biasParams += uhb
96
97                # Create a new inverted rbm with correct parameters
98                invBiasParams = self.invRbm.biasParams
99                invBiasParams += uvb
100                self.invRbm = self.rbm.invert()
101                self.invRbm.biasParams[:] = invBiasParams
102
103                #print("Updated ")
104                #print("Weight: ",)
105                #print(self.rbm.connections[self.rbm['visible']][0].params.reshape( \)
106                #    (self.rbm.indim, self.rbm.outdim))
107                #print("Visible bias: ",)
108                #print(self.invRbm.connections[self.invRbm['bias']][0].params)
109                #print("Hidden bias: ",)
110                #print(self.rbm.connections[self.rbm['bias']][0].params)
111                #print("")
112
113    def calcUpdateByRow(self, row):
114        """This function trains the RBM using only one data row.
115        Return a 3-tuple consiting of updates for (weightmatrix,
116        hidden bias weights, visible bias weights)."""
117
118        # a) positive phase
119        poshp = self.rbm.activate(row)	# compute the posterior probability
120        pos = outer(row, poshp)       	# fraction from the positive phase
121        poshb = poshp
122        posvb = row
123
124        # b) the sampling & reconstruction
125        sampled = self.sampler(poshp)
126        recon = self.invRbm.activate(sampled)	# the re-construction of data
127
128        # c) negative phase
129        neghp = self.rbm.activate(recon)
130        neg = outer(recon, neghp)
131        neghb = neghp
132        negvb = recon
133
134        # compute the raw delta
135        # !!! note that this delta is only the 'theoretical' delta
136        return self.updater(pos, neg, poshb, neghb, posvb, negvb)
137
138    def sampler(self, probabilities):
139        abstractMethod()
140
141    def updater(self, pos, neg, poshb, neghb, posvb, negvb):
142        abstractMethod()
143
144    def calcUpdateByRows(self, rows):
145        """Return a 3-tuple constisting of update for (weightmatrix,
146        hidden bias weights, visible bias weights)."""
147
148        delta_w, delta_hb, delta_vb = \
149            zeros((self.rbm.visibleDim, self.rbm.hiddenDim)), \
150            zeros(self.rbm.hiddenDim), zeros(self.rbm.visibleDim)
151
152        for row in rows:
153            dw, dhb, dvb = self.calcUpdateByRow(row)
154            delta_w += dw
155            delta_hb += dhb
156            delta_vb += dvb
157
158        delta_w /= len(rows)
159        delta_hb /= len(rows)
160        delta_vb /= len(rows)
161
162        # !!! note that this delta is only the 'theoretical' delta
163        return delta_w, delta_hb, delta_vb
164
165
166class RbmBernoulliTrainer(RbmGibbsTrainer):
167
168    def sampler(self, probabilities):
169        result = probabilities > random.rand(self.rbm.hiddenDim)
170        return result.astype('int32')
171
172    def updater(self, pos, neg, poshb, neghb, posvb, negvb):
173        return pos - neg, poshb - neghb, posvb - negvb
174
175
176class RbmGaussTrainer(RbmGibbsTrainer):
177
178    def __init__(self, rbm, dataset, cfg=None):
179        super(RbmGaussTrainer, self).__init__(rbm, dataset, cfg)
180        #samples = self.dataset[self.datasetField]
181        # self.visibleVariances = samples.var(axis=0)
182        self.visibleVariances = ones(rbm.net.outdim)
183
184    def sampler(self, probabilities):
185        return random.normal(probabilities, self.visibleVariances)
186
187    def updater(self, pos, neg, poshb, neghb, posvb, negvb):
188        pos = pos / self.visibleVariances
189        return pos - neg, poshb - neghb, posvb - negvb
190
191
192
193