__author__ = 'Daan Wierstra and Tom Schaul'

from itertools import chain

from numpy import zeros

from pybrain.structure.networks.feedforward import FeedForwardNetwork
from pybrain.structure.networks.recurrent import RecurrentNetwork
from pybrain.structure.modules.neuronlayer import NeuronLayer
from pybrain.structure.connections import FullConnection

# CHECKME: allow modules that do not inherit from NeuronLayer? and treat them as single neurons?


class NeuronDecomposableNetwork(object):
    """ A Network, that allows accessing parameters decomposed by their
    corresponding individual neuron.

    This class is a mixin: it must come first in the bases of a concrete
    pybrain network class (see FeedForwardDecomposableNetwork and
    RecurrentDecomposableNetwork at the end of this module). """

    # ESP style treatment: when True, a parameter whose target neuron lies in
    # one of the output modules is attributed to its *source* neuron rather
    # than to its target neuron. (Presumably named after Enforced
    # SubPopulations, where a hidden neuron owns its outgoing weights.)
    espStyleDecomposition = True

    def addModule(self, m):
        """ Add a module to the network; only NeuronLayer modules are allowed,
        since parameters must be attributable to individual neurons. """
        assert isinstance(m, NeuronLayer)
        super(NeuronDecomposableNetwork, self).addModule(m)

    def sortModules(self):
        """ Sort the modules as the base network does, then build the mapping
        from each neuron to the indices of the parameters it owns. """
        super(NeuronDecomposableNetwork, self).sortModules()
        self._constructParameterInfo()

        # maps every neuron (module, index) to the list of indices into
        # self.params that belong to that neuron
        self.decompositionIndices = dict((neuron, []) for neuron in self._neuronIterator())
        for w in range(self.paramdim):
            inneuron, outneuron = self.paramInfo[w]
            if self.espStyleDecomposition and outneuron[0] in self.outmodules:
                owner = inneuron
            else:
                owner = outneuron
            self.decompositionIndices[owner].append(w)

    def _neuronIterator(self):
        """ Yield every neuron of the network as a (module, index) tuple. """
        for m in self.modules:
            for n in range(m.dim):
                yield (m, n)

    def _constructParameterInfo(self):
        """ construct a dictionary with information about each parameter:
        The key is the index in self.params, and the value is a tuple containing
        (inneuron, outneuron), where a neuron is a tuple of its module and an index.

        :raises NotImplementedError: if a parameter container is neither a
            FullConnection nor a NeuronLayer, since its parameters cannot be
            attributed to individual neurons.
        """
        self.paramInfo = {}
        index = 0
        for x in self._containerIterator():
            if isinstance(x, FullConnection):
                for w in range(x.paramdim):
                    inbuf, outbuf = x.whichBuffers(w)
                    self.paramInfo[index + w] = ((x.inmod, x.inmod.whichNeuron(outputIndex=inbuf)),
                                                 (x.outmod, x.outmod.whichNeuron(inputIndex=outbuf)))
            elif isinstance(x, NeuronLayer):
                # the layer's n-th parameter is attributed to its n-th neuron
                for n in range(x.paramdim):
                    self.paramInfo[index + n] = ((x, n), (x, n))
            else:
                # A bare `raise` here has no active exception to re-raise and
                # fails with an unrelated error; report the real problem instead.
                # NotImplementedError subclasses RuntimeError, which is what the
                # bare `raise` produced on Python 3.
                raise NotImplementedError(
                    'Cannot decompose the parameters of %r: only FullConnection '
                    'and NeuronLayer parameter containers are supported.' % (x,))
            index += x.paramdim

    def getDecomposition(self):
        """ return a list of arrays, each corresponding to one neuron's relevant parameters.

        Neurons that own no parameters are skipped. The order is the same one
        setDecomposition expects, since both walk _neuronIterator() and skip
        the same neurons. """
        res = []
        for neuron in self._neuronIterator():
            nIndices = self.decompositionIndices[neuron]
            if nIndices:
                tmp = zeros(len(nIndices))
                for i, ni in enumerate(nIndices):
                    tmp[i] = self.params[ni]
                res.append(tmp)
        return res

    def setDecomposition(self, decomposedParams):
        """ set parameters by neuron decomposition,
        each corresponding to one neuron's relevant parameters.

        :param decomposedParams: a sequence in the format returned by
            getDecomposition, i.e. one entry per neuron that owns parameters.
        """
        nindex = 0
        for neuron in self._neuronIterator():
            nIndices = self.decompositionIndices[neuron]
            if nIndices:
                for i, ni in enumerate(nIndices):
                    self.params[ni] = decomposedParams[nindex][i]
                nindex += 1

    @staticmethod
    def convertNormalNetwork(n):
        """ convert a normal network into a decomposable one.

        A RecurrentNetwork becomes a RecurrentDecomposableNetwork, and any
        other network becomes a FeedForwardDecomposableNetwork. The modules and
        connections of ``n`` are added to the new network as-is, not copied.
        The returned network has already been sorted. """
        if isinstance(n, RecurrentNetwork):
            res = RecurrentDecomposableNetwork()
            for c in n.recurrentConns:
                res.addRecurrentConnection(c)
        else:
            res = FeedForwardDecomposableNetwork()
        for m in n.inmodules:
            res.addInputModule(m)
        for m in n.outmodules:
            res.addOutputModule(m)
        for m in n.modules:
            res.addModule(m)
        for c in chain.from_iterable(n.connections.values()):
            res.addConnection(c)
        res.name = n.name
        res.sortModules()
        return res


class FeedForwardDecomposableNetwork(NeuronDecomposableNetwork, FeedForwardNetwork):
    """ A FeedForwardNetwork whose parameters can be accessed per neuron. """


class RecurrentDecomposableNetwork(NeuronDecomposableNetwork, RecurrentNetwork):
    """ A RecurrentNetwork whose parameters can be accessed per neuron. """