1# -*- coding: utf-8 -*-
2
# Module metadata. NOTE(review): '$Id$' looks like a version-control keyword
# placeholder that was never expanded -- confirm before relying on it.
__author__ = 'Justin S Bayer, bayer.justin@googlemail.com'
__version__ = '$Id$'
5
6
7import copy
8
9from pybrain.datasets import SupervisedDataSet, UnsupervisedDataSet
10from pybrain.structure import BiasUnit, FeedForwardNetwork, FullConnection
11from pybrain.structure.networks.rbm import Rbm
12from pybrain.structure.modules.neuronlayer import NeuronLayer
13from pybrain.supervised.trainers import Trainer
14from pybrain.unsupervised.trainers.rbm import (RbmBernoulliTrainer,
15                                               RbmGaussTrainer)
16
17
class DeepBeliefTrainer(Trainer):
    """Trainer for deep networks.

    Trains the network by greedily training layer after layer: every pair of
    consecutive layers is wrapped into an RBM and trained with the RBM trainer
    selected through `distribution` (`RbmBernoulliTrainer` for 'bernoulli',
    `RbmGaussTrainer` for 'gauss').

    The network that is being trained is assumed to be a chain of layers that
    are connected with full connections and feature a bias each.

    The behaviour of the trainer is undefined for other cases.
    """

    # Maps the `distribution` constructor argument to an RBM trainer class.
    trainers = {
        'bernoulli': RbmBernoulliTrainer,
        'gauss': RbmGaussTrainer,
    }

    def __init__(self, net, dataset, epochs=50,
                 cfg=None, distribution='bernoulli'):
        """Set up the trainer.

        :arg net: network to train; its modules are sorted here.
        :arg dataset: a SupervisedDataSet (its 'input' field is used) or an
            UnsupervisedDataSet (its 'sample' field is used).
        :arg epochs: number of `train()` calls made on each RBM trainer.
        :arg cfg: handed unchanged to the RBM trainer as its third argument.
        :arg distribution: key into `trainers` ('bernoulli' or 'gauss').

        Raises ValueError for any other dataset class and KeyError for an
        unknown `distribution`.
        """
        if isinstance(dataset, SupervisedDataSet):
            self.datasetfield = 'input'
        elif isinstance(dataset, UnsupervisedDataSet):
            self.datasetfield = 'sample'
        else:
            raise ValueError("Wrong dataset class.")
        self.net = net
        self.net.sortModules()
        self.dataset = dataset
        self.epochs = epochs
        self.cfg = cfg
        self.trainerKlass = self.trainers[distribution]

    def _runTrainer(self, rbm, dataset):
        """Train `rbm` on `dataset` for `self.epochs` epochs.

        Returns the trainer object so callers can read what it built up.
        """
        trainer = self.trainerKlass(rbm, dataset, self.cfg)
        for _ in range(self.epochs):
            trainer.train()
        return trainer

    def trainRbm(self, rbm, dataset):
        """Train a single `rbm` on `dataset` and return that same rbm."""
        self._runTrainer(rbm, dataset)
        return rbm

    def iterRbms(self):
        """Yield every two consecutive layers as an rbm."""
        modules = self.net.modulesSorted
        layers = [m for m in modules
                  if isinstance(m, NeuronLayer) and not isinstance(m, BiasUnit)]
        # There will be a single bias.
        bias = [m for m in modules if isinstance(m, BiasUnit)][0]
        # Kept lazy on purpose: zip() below runs out of `layers[1:]` before it
        # asks for the last layer's outgoing connection, so the last layer
        # (which may have none) is never indexed.
        layercons = (self.net.connections[m][0] for m in layers)
        # The biascons will not be sorted; we have to sort them to zip nicely
        # with the corresponding layers. A sorted copy is used so that the
        # network's own connection list is left untouched.
        biascons = sorted(self.net.connections[bias],
                          key=lambda c: layers.index(c.outmod))
        # Materialise everything before the first yield.
        pairs = list(zip(layers, layers[1:], layercons, biascons))
        for visible, hidden, layercon, biascon in pairs:
            yield Rbm.fromModules(visible, hidden, bias, layercon, biascon)

    def train(self):
        """Greedily train one rbm per pair of consecutive layers.

        After the call, `self.rbms` holds the trained rbms and `self.invRbms`
        the `invRbm` attribute of each rbm's trainer, both in layer order.
        """
        # We will build up a network piecewise in order to create a new dataset
        # for each layer.
        dataset = self.dataset
        piecenet = FeedForwardNetwork()
        piecenet.addInputModule(copy.deepcopy(self.net.inmodules[0]))
        # Add a bias
        bias = BiasUnit()
        piecenet.addModule(bias)
        # Add the first visible layer
        # NOTE(review): the input module above and `visible` below are two
        # separate deep copies and no connection between them is added in this
        # method -- confirm that samples given to `piecenet.activate` actually
        # reach `visible`.
        firstRbm = next(self.iterRbms())
        visible = copy.deepcopy(firstRbm.visible)
        piecenet.addModule(visible)
        # For saving the rbms and their inverses
        self.invRbms = []
        self.rbms = []
        # Read the raw samples through the field chosen in __init__. Iterating
        # the dataset itself yields one item per linked field, which breaks
        # the single-item unpacking for a SupervisedDataSet (input and target).
        samples = self.dataset.getField(self.datasetfield)
        for rbm in self.iterRbms():
            self.net.sortModules()
            # Train the current layer pair with an rbm trainer for
            # `self.epochs` epochs.
            trainer = self._runTrainer(rbm, dataset)
            self.invRbms.append(trainer.invRbm)
            self.rbms.append(rbm)
            # Add the connections and the hidden layer of the rbm to the net.
            hidden = copy.deepcopy(rbm.hidden)
            biascon = FullConnection(bias, hidden)
            biascon.params[:] = rbm.biasWeights
            con = FullConnection(visible, hidden)
            con.params[:] = rbm.weights

            piecenet.addConnection(biascon)
            piecenet.addConnection(con)
            piecenet.addModule(hidden)
            # Overwrite old outputs
            piecenet.outmodules = [hidden]
            piecenet.outdim = rbm.hiddenDim
            piecenet.sortModules()

            # The next rbm is trained on the activations that the network
            # built so far produces for the original samples.
            dataset = UnsupervisedDataSet(rbm.hiddenDim)
            for sample in samples:
                dataset.addSample(piecenet.activate(sample))
            visible = hidden
117