1# -*- coding: utf-8 -*-
2
3__author__ = 'Justin S Bayer, bayer.justin@googlemail.com'
4__version__ = '$Id$'
5
6
7from pybrain.structure import (LinearLayer, SigmoidLayer, FullConnection,
8                               BiasUnit, FeedForwardNetwork)
9
10
class Rbm(object):
    """Restricted Boltzmann machine wrapper around a PyBrain network.

    The wrapped network is expected to contain a module named ``'visible'``,
    a module named ``'hidden'`` and at least one ``BiasUnit``. The class
    keeps references to those modules and to their first outgoing
    connections, and offers shortcuts for reading their parameters,
    building machines from dimensions or modules, and inverting them.
    """

    @property
    def params(self):
        """Weights of the connection leaving the visible layer.

        Returns the connection's ``params`` object itself; no copy is made.
        """
        return self.con.params

    @property
    def biasParams(self):
        """Weights of the connection leaving the bias unit."""
        return self.biascon.params

    @property
    def visibleDim(self):
        """Input dimension of the wrapped network."""
        return self.net.indim

    @property
    def hiddenDim(self):
        """Output dimension of the wrapped network."""
        return self.net.outdim

    def __init__(self, net):
        """Wrap ``net`` and cache references to its RBM components.

        :param net: PyBrain network containing modules named ``'visible'``
            and ``'hidden'`` and at least one ``BiasUnit``. Its
            ``sortModules()`` is called before the lookups below.
        """
        self.net = net
        self.net.sortModules()
        # The first BiasUnit among the modules supplies the biases; an
        # IndexError is raised if the network contains none.
        self.bias = [i for i in self.net.modules if isinstance(i, BiasUnit)][0]
        self.biascon = self.net.connections[self.bias][0]
        self.visible = net['visible']
        self.hidden = net['hidden']
        # The weights live on the first connection leaving the visible layer.
        self.con = self.net.connections[self.visible][0]

    @classmethod
    def fromDims(cls, visibledim, hiddendim, params=None, biasParams=None):
        """Return a restricted Boltzmann machine of the given dimensions.

        The network consists of a linear ``'visible'`` input layer, a sigmoid
        ``'hidden'`` output layer, a full visible-to-hidden connection and a
        full connection from a bias unit to the hidden layer.

        :param visibledim: number of visible units.
        :param hiddendim: number of hidden units.
        :param params: optional initial weights, copied element-wise into the
            visible-to-hidden connection; must match its parameter count.
        :param biasParams: optional initial biases, copied element-wise into
            the bias-to-hidden connection; must match its parameter count.
        """
        net = FeedForwardNetwork()
        bias = BiasUnit('bias')
        visible = LinearLayer(visibledim, 'visible')
        hidden = SigmoidLayer(hiddendim, 'hidden')
        con1 = FullConnection(visible, hidden)
        con2 = FullConnection(bias, hidden)
        if params is not None:
            con1.params[:] = params
        if biasParams is not None:
            con2.params[:] = biasParams

        net.addInputModule(visible)
        net.addModule(bias)
        net.addOutputModule(hidden)
        net.addConnection(con1)
        net.addConnection(con2)
        net.sortModules()
        return cls(net)

    @classmethod
    def fromModules(cls, visible, hidden, bias, con, biascon):
        """Return a restricted Boltzmann machine assembled from existing parts.

        ``visible`` becomes the input module and ``hidden`` the output module
        of a new feed-forward network. ``bias`` is added as a plain module,
        and ``con`` and ``biascon`` are added as its connections. The modules
        must be named ``'visible'`` and ``'hidden'`` because the constructor
        looks them up by those names.
        """
        net = FeedForwardNetwork()
        net.addInputModule(visible)
        net.addModule(bias)
        net.addOutputModule(hidden)
        net.addConnection(con)
        net.addConnection(biascon)
        net.sortModules()
        return cls(net)

    @staticmethod
    def _transposedWeights(params, outdim, indim):
        """Return the flat weights of the reversed connection.

        Assumes PyBrain's ``FullConnection`` layout, where the flat parameter
        vector is an ``(outdim, indim)`` matrix in row-major order. The result
        is a new flat ``(indim, outdim)`` matrix, i.e. the transpose, which
        maps in the opposite direction. ``params`` is not modified.
        """
        return params.reshape(outdim, indim).T.flatten()

    def invert(self):
        """Return the inverse rbm.

        The result has ``hiddenDim`` visible units and ``visibleDim`` hidden
        units. Its weight matrix is the transpose of this machine's, so the
        weight joining any visible/hidden unit pair is preserved. Biases are
        not carried over; the inverse's bias connection keeps the default
        initialisation of ``fromDims``.
        """
        # Copying the flat vector unchanged would only be correct for square,
        # symmetric weight matrices. The reversed connection needs the
        # transposed layout instead.
        weights = self._transposedWeights(self.params, self.hiddenDim,
                                          self.visibleDim)
        return self.__class__.fromDims(self.hiddenDim, self.visibleDim,
                                       params=weights)

    def activate(self, inpt):
        """Propagate ``inpt`` through the wrapped network and return the output."""
        return self.net.activate(inpt)
82