# -*- coding: utf-8 -*-

__author__ = 'Justin S Bayer, bayer.justin@googlemail.com'
__version__ = '$Id$'


from pybrain.structure import (LinearLayer, SigmoidLayer, FullConnection,
                               BiasUnit, FeedForwardNetwork)


class Rbm(object):
    """Class that holds a network and offers some shortcuts.

    The wrapped network must contain a module named 'visible', a module
    named 'hidden' and a BiasUnit. The first connection leaving 'visible'
    is treated as the weight connection. The first connection leaving the
    bias unit is treated as the bias connection.
    """

    @property
    def params(self):
        """Flat parameter vector of the visible -> hidden weight connection."""
        return self.con.params

    @property
    def biasParams(self):
        """Flat parameter vector of the bias -> hidden connection."""
        return self.biascon.params

    @property
    def visibleDim(self):
        """Input dimension of the wrapped network (size of the visible layer)."""
        return self.net.indim

    @property
    def hiddenDim(self):
        """Output dimension of the wrapped network (size of the hidden layer)."""
        return self.net.outdim

    def __init__(self, net):
        """Wrap `net` and cache references to its RBM-specific parts.

        :param net: a network containing modules named 'visible' and
            'hidden' plus a BiasUnit (see the class docstring).
        :raises IndexError: if the network contains no BiasUnit.
        """
        self.net = net
        self.net.sortModules()
        # The first BiasUnit found among the network's modules is used.
        self.bias = [i for i in self.net.modules if isinstance(i, BiasUnit)][0]
        self.biascon = self.net.connections[self.bias][0]
        self.visible = net['visible']
        self.hidden = net['hidden']
        self.con = self.net.connections[self.visible][0]

    @classmethod
    def fromDims(cls, visibledim, hiddendim, params=None, biasParams=None):
        """Return a restricted Boltzmann machine of the given dimensions.

        :param visibledim: size of the linear visible layer.
        :param hiddendim: size of the sigmoid hidden layer.
        :param params: optional flat weights copied into the
            visible -> hidden connection.
        :param biasParams: optional flat weights copied into the
            bias -> hidden connection.
        """
        net = FeedForwardNetwork()
        bias = BiasUnit('bias')
        visible = LinearLayer(visibledim, 'visible')
        hidden = SigmoidLayer(hiddendim, 'hidden')
        con1 = FullConnection(visible, hidden)
        con2 = FullConnection(bias, hidden)
        # Copy in place so the connections keep ownership of their buffers.
        if params is not None:
            con1.params[:] = params
        if biasParams is not None:
            con2.params[:] = biasParams

        net.addInputModule(visible)
        net.addModule(bias)
        net.addOutputModule(hidden)
        net.addConnection(con1)
        net.addConnection(con2)
        net.sortModules()
        return cls(net)

    @classmethod
    def fromModules(cls, visible, hidden, bias, con, biascon):
        """Return an rbm assembled from already-built modules and connections.

        `visible` becomes the input module, `hidden` the output module and
        `bias` an inner module. `con` and `biascon` are added as given.
        The constructor looks modules up by the names 'visible' and 'hidden',
        so the passed modules must carry those names.
        """
        net = FeedForwardNetwork()
        net.addInputModule(visible)
        net.addModule(bias)
        net.addOutputModule(hidden)
        net.addConnection(con)
        net.addConnection(biascon)
        net.sortModules()
        return cls(net)

    def invert(self):
        """Return the inverse rbm, mapping hidden units back to visible ones.

        The inverse has visible and hidden dimensions swapped. Its weight
        matrix is the transpose of this rbm's weight matrix.

        NOTE: PyBrain's FullConnection reads its flat parameter vector as
        an (outdim, indim) matrix. Reusing the flat vector unchanged with
        swapped dimensions would scramble the weights rather than transpose
        them, so the weights are transposed explicitly here.

        Bias weights are not carried over: only `params` is passed on.
        """
        weights = self.params.reshape(self.hiddenDim, self.visibleDim)
        return self.__class__.fromDims(self.hiddenDim, self.visibleDim,
                                       params=weights.T.flatten())

    def activate(self, inpt):
        """Propagate `inpt` through the wrapped network and return its output."""
        return self.net.activate(inpt)