1#
2#  Copyright (C) 2000-2008  greg Landrum
3#
4""" Contains the class _NetNode_ which is used to represent nodes in neural nets
5
6**Network Architecture:**
7
8  A tacit assumption in all of this stuff is that we're dealing with
9  feedforward networks.
10
11  The network itself is stored as a list of _NetNode_ objects.  The list
12  is ordered in the sense that nodes in earlier/later layers than a
13  given node are guaranteed to come before/after that node in the list.
  This way we can easily generate the values of each node by moving
  sequentially through the list: we're guaranteed that every input for a
  node has already been filled in.
17
18  Each node stores a list (_inputNodes_) of indices of its inputs in the
19  main node list.
20
21"""
22import numpy
23from . import ActFuncs
24
25
26# FIX: this class has not been updated to new-style classes
27# (RD Issue380) because that would break all of our legacy pickled
28# data. Until a solution is found for this breakage, an update is
29# impossible.
class NetNode:
  """ a node in a neural network

  **Attributes**

    - nodeIndex: the integer index of this node; _Eval()_ stores the node's
        value at this position of the value vector it is given

    - nodeList: a reference (not a copy) to the node list passed to the constructor

    - inputNodes: a copy of the sequence of input indices, or None if there are none

    - weights: a numpy array holding one weight per input, or None if not set

    - actFunc: the activation function object, built as _actFunc(*actFuncParms)_

  """

  def __init__(self, nodeIndex, nodeList, inputNodes=None, weights=None, actFunc=ActFuncs.Sigmoid,
               actFuncParms=()):
    """ Constructor

      **Arguments**

        - nodeIndex: the integer index of this node in _nodeList_

        - nodeList: the list of other _NetNodes_ already in the network

        - inputNodes: a sequence containing the indices of this node's inputs.
            None or an empty sequence means "no inputs".

        - weights: a sequence (list, tuple or numpy array) of this node's weights.
            None or an empty sequence means "no weights".

        - actFunc: the activation function to be used here.  Must support the API
            of _ActFuncs.ActFunc_.

        - actFuncParms: a tuple of extra arguments to be passed to the activation function
            constructor.

      **Notes**

        - There should be only one copy of _nodeList_, every _NetNode_ just has a pointer
          to it so that changes made at one node propagate automatically to the others.
          _inputNodes_, on the other hand, is copied.

        - If both _inputNodes_ and _weights_ are provided and their lengths differ,
          this will bomb out with an assertion.

    """
    # Explicit None/length checks instead of truthiness: "if weights:" raises
    # ValueError for numpy arrays with more than one element and silently
    # discards a one-element array whose value is zero.
    hasInputs = self._HasEntries(inputNodes)
    hasWeights = self._HasEntries(weights)
    if hasInputs and hasWeights:
      assert len(weights) == len(inputNodes), \
             'lengths of weights and nodes do not match'

    if hasWeights:
      self.weights = numpy.array(weights)
    else:
      self.weights = None
    if hasInputs:
      self.inputNodes = inputNodes[:]
    else:
      self.inputNodes = None

    self.nodeIndex = nodeIndex
    # there's only one of these, everybody has a pointer to it.
    self.nodeList = nodeList

    self.actFunc = actFunc(*actFuncParms)

  @staticmethod
  def _HasEntries(seq):
    """ returns True if _seq_ is not None and contains at least one element

      Works for lists, tuples and numpy arrays alike (the truth value of a
      multi-element numpy array is ambiguous, so plain truthiness cannot be used).

    """
    return seq is not None and len(seq) > 0

  def Eval(self, valVect):
    """Given a set of inputs (valVect), returns the output of this node

     **Arguments**

      - valVect: a list of inputs. It is indexed with this node's input indices
          and must be long enough to be assigned to at position _nodeIndex_.

     **Returns**

        the result of running the values in valVect through this node

     **Notes**

       - _valVect_ is modified in place: the node's value is stored at
         position _nodeIndex_.

       - a node without inputs always evaluates to 1

       - a node with inputs must also have its weights set before it is evaluated

    """
    if self._HasEntries(self.inputNodes):
      # grab our list of weighted inputs
      inputs = numpy.take(valVect, self.inputNodes)
      # weight them
      inputs = self.weights * inputs
      # run that through the activation function
      # (the builtin sum is kept on purpose so results stay bit-for-bit
      #  identical to those produced by earlier versions of this code)
      val = self.actFunc(sum(inputs))
    else:
      val = 1
    # put our value in the list and return it (just in case)
    valVect[self.nodeIndex] = val
    return val

  def SetInputs(self, inputNodes):
    """ Sets the input list

      **Arguments**

        - inputNodes: a sequence containing the indices (in the main node list)
            of the nodes which are to be used as inputs. A copy is stored.

      **Note**

        If this _NetNode_ already has weights set and _inputNodes_ is a different length,
        this will bomb out with an assertion.

    """
    if self.weights is not None:
      assert len(self.weights) == len(inputNodes), \
             'lengths of weights and nodes do not match'
    self.inputNodes = inputNodes[:]

  def GetInputs(self):
    """ returns the input list (the stored object itself, not a copy)

    """
    return self.inputNodes

  def SetWeights(self, weights):
    """ Sets the weight list

      **Arguments**

        - weights: a sequence of values which are to be used as weights.
            They are stored as a new numpy array.

      **Note**

        If this _NetNode_ already has _inputNodes_  and _weights_ is a different length,
        this will bomb out with an assertion.

    """
    # not a plain truthiness test: _inputNodes_ may be a numpy array
    if self._HasEntries(self.inputNodes):
      assert len(weights) == len(self.inputNodes),\
             'lengths of weights and nodes do not match'
    self.weights = numpy.array(weights)

  def GetWeights(self):
    """ returns the weights (a numpy array, or None if no weights have been set)

    """
    return self.weights
149