1__author__ = 'Daan Wierstra and Tom Schaul'
2
3from itertools import chain
4
5from scipy import zeros
6
7from pybrain.structure.networks.feedforward import FeedForwardNetwork
8from pybrain.structure.networks.recurrent import RecurrentNetwork
9from pybrain.structure.modules.neuronlayer import NeuronLayer
10from pybrain.structure.connections import FullConnection
11
# CHECKME: should modules that do not inherit from NeuronLayer be allowed, and be treated as single neurons?
13
14
class NeuronDecomposableNetwork(object):
    """ A Network, that allows accessing parameters decomposed by their
    corresponding individual neuron.

    This is a mixin: it expects to be combined with a concrete network class
    (as done by FeedForwardDecomposableNetwork and
    RecurrentDecomposableNetwork) that provides addModule, sortModules,
    params, paramdim, modules, outmodules and _containerIterator.

    A neuron is identified by a ``(module, index)`` tuple. After
    sortModules(), ``decompositionIndices`` maps every neuron to the list of
    indices into ``self.params`` that are attributed to it.
    """

    # ESP style treatment: a weight whose outgoing neuron lies in an output
    # module is attributed to its incoming neuron instead of the output neuron.
    espStyleDecomposition = True

    def addModule(self, m):
        """ Add module m to the network. Only NeuronLayer instances are
        accepted, since the decomposition enumerates the neurons of every
        module via its ``dim``. """
        assert isinstance(m, NeuronLayer), \
            "NeuronDecomposableNetwork only accepts NeuronLayer modules, got %r" % (m,)
        super(NeuronDecomposableNetwork, self).addModule(m)

    def sortModules(self):
        """ Sort the modules (delegated to the concrete network class), then
        build paramInfo and the per-neuron decompositionIndices. """
        super(NeuronDecomposableNetwork, self).sortModules()
        self._constructParameterInfo()

        # dict: neuron (module, index) -> list of indices into self.params
        self.decompositionIndices = {neuron: [] for neuron in self._neuronIterator()}
        for w in range(self.paramdim):
            inneuron, outneuron = self.paramInfo[w]
            if self.espStyleDecomposition and outneuron[0] in self.outmodules:
                owner = inneuron
            else:
                owner = outneuron
            self.decompositionIndices[owner].append(w)

    def _neuronIterator(self):
        """ Yield every neuron as a (module, index) tuple, with index in
        range(module.dim), module by module in the iteration order of
        self.modules. """
        # NOTE(review): getDecomposition and setDecomposition both rely on this
        # order being identical between calls; if self.modules is an unordered
        # collection that only holds while it is not modified — confirm.
        for m in self.modules:
            for n in range(m.dim):
                yield (m, n)

    def _ownedIndexLists(self):
        """ Yield the parameter-index list of each neuron that owns at least
        one parameter, in _neuronIterator() order. Shared by getDecomposition
        and setDecomposition so both always use the same layout. """
        for neuron in self._neuronIterator():
            nIndices = self.decompositionIndices[neuron]
            if nIndices:
                yield nIndices

    def _constructParameterInfo(self):
        """ Construct a dictionary with information about each parameter:
        the key is the index in self.params, and the value is a tuple
        (inneuron, outneuron), where a neuron is a tuple of its module and an
        index.

        Raises NotImplementedError if a parameter container is neither a
        FullConnection nor a NeuronLayer.
        """
        self.paramInfo = {}
        index = 0
        for x in self._containerIterator():
            if isinstance(x, FullConnection):
                for w in range(x.paramdim):
                    inbuf, outbuf = x.whichBuffers(w)
                    self.paramInfo[index + w] = (
                        (x.inmod, x.inmod.whichNeuron(outputIndex=inbuf)),
                        (x.outmod, x.outmod.whichNeuron(inputIndex=outbuf)))
            elif isinstance(x, NeuronLayer):
                # A layer's own parameter n is attributed to neuron n.
                # NOTE(review): assumes x.paramdim <= x.dim; otherwise
                # sortModules hits a KeyError for neuron (x, n >= dim) — confirm.
                for n in range(x.paramdim):
                    self.paramInfo[index + n] = ((x, n), (x, n))
            else:
                # NotImplementedError subclasses RuntimeError, so handlers
                # catching RuntimeError still work; the message says why.
                raise NotImplementedError(
                    "Cannot decompose parameters of %r: only FullConnection "
                    "and NeuronLayer containers are supported" % (x,))
            index += x.paramdim

    def getDecomposition(self):
        """ Return a list of arrays, each corresponding to one neuron's
        relevant parameters (copied out of self.params).

        Neurons owning no parameters are skipped, so the k-th array belongs to
        the k-th neuron of _neuronIterator() that owns at least one parameter.
        """
        res = []
        for nIndices in self._ownedIndexLists():
            tmp = zeros(len(nIndices))
            for i, ni in enumerate(nIndices):
                tmp[i] = self.params[ni]
            res.append(tmp)
        return res

    def setDecomposition(self, decomposedParams):
        """ Set parameters by neuron decomposition.

        decomposedParams has the layout returned by getDecomposition(): its
        k-th entry holds the values of the k-th neuron owning at least one
        parameter. Values are written into self.params in place; an
        IndexError is raised if decomposedParams (or one of its entries) is
        too short.
        """
        for nindex, nIndices in enumerate(self._ownedIndexLists()):
            values = decomposedParams[nindex]
            for i, ni in enumerate(nIndices):
                self.params[ni] = values[i]

    @staticmethod
    def convertNormalNetwork(n):
        """ Convert a normal network into a decomposable one.

        Returns a RecurrentDecomposableNetwork (with n's recurrent connections
        added) if n is a RecurrentNetwork, otherwise a
        FeedForwardDecomposableNetwork. The result reuses n's module and
        connection objects rather than copying them, takes over n's name, and
        is sorted before being returned. Every module of n goes through
        addModule, so each must be a NeuronLayer.
        """
        if isinstance(n, RecurrentNetwork):
            res = RecurrentDecomposableNetwork()
            for c in n.recurrentConns:
                res.addRecurrentConnection(c)
        else:
            res = FeedForwardDecomposableNetwork()
        for m in n.inmodules:
            res.addInputModule(m)
        for m in n.outmodules:
            res.addOutputModule(m)
        for m in n.modules:
            res.addModule(m)
        for c in chain.from_iterable(n.connections.values()):
            res.addConnection(c)
        res.name = n.name
        res.sortModules()
        return res
109
110
class FeedForwardDecomposableNetwork(NeuronDecomposableNetwork, FeedForwardNetwork):
    """ Feed-forward network whose parameters can additionally be read and
    written per neuron (NeuronDecomposableNetwork mixed into
    FeedForwardNetwork). """
113
114
class RecurrentDecomposableNetwork(NeuronDecomposableNetwork, RecurrentNetwork):
    """ Recurrent network whose parameters can additionally be read and
    written per neuron (NeuronDecomposableNetwork mixed into
    RecurrentNetwork). """
117