1"""
2
3    >>> from pybrain.tools.shortcuts import buildNetwork
4    >>> from test_recurrent_network import buildRecurrentNetwork
5    >>> from test_peephole_lstm import buildMinimalLSTMNetwork
6    >>> from test_peephole_mdlstm import buildMinimalMDLSTMNetwork
7    >>> from test_nested_network import buildNestedNetwork
8    >>> from test_simple_lstm_network import buildSimpleLSTMNetwork
9    >>> from test_simple_mdlstm import buildSimpleMDLSTMNetwork
10    >>> from test_swiping_network import buildSwipingNetwork
11    >>> from test_shared_connections import buildSharedCrossedNetwork
12    >>> from test_sliced_connections import buildSlicedNetwork
13    >>> from test_borderswipingnetwork import buildSimpleBorderSwipingNet
14
Test a number of network architectures and check that they produce the same output
whether the Python implementation or the fast ctypes-based implementation is used.
17
18Use the network construction scripts in other test files to build a number of networks,
19and then test the equivalence of each.
20
21Simple net
22    >>> testEquivalence(buildNetwork(2,2))
23    True
24
25A lot of layers
26    >>> net = buildNetwork(2,3,4,3,2,3,4,3,2)
27    >>> testEquivalence(net)
28    True
29
30Nonstandard components
31    >>> from pybrain.structure import TanhLayer
32    >>> net = buildNetwork(2,3,2, bias = True, outclass = TanhLayer)
33    >>> testEquivalence(net)
34    True
35
36Shared connections
37    >>> net = buildSharedCrossedNetwork()
38    >>> testEquivalence(net)
39    True
40
41Sliced connections
42    >>> net = buildSlicedNetwork()
43    >>> testEquivalence(net)
44    True
45
Nested networks (these cannot be converted to a fast network yet)
47    >>> net = buildNestedNetwork()
48    >>> testEquivalence(net)
49    Network cannot be converted.
50
51Recurrent networks
52    >>> net = buildRecurrentNetwork()
53    >>> net.name = '22'
54    >>> net.params[:] = [1,1,0.5]
55    >>> testEquivalence(net)
56    True
57
58Swiping networks
59    >>> net = buildSwipingNetwork()
60    >>> testEquivalence(net)
61    True
62
63Border-swiping networks
64    >>> net = buildSimpleBorderSwipingNet()
65    >>> testEquivalence(net)
66    True
67
LSTM
69    >>> net = buildSimpleLSTMNetwork()
70    >>> testEquivalence(net)
71    True
72
MDLSTM
74    >>> net = buildSimpleMDLSTMNetwork()
75    >>> testEquivalence(net)
76    True
77
LSTM with peepholes
79    >>> net = buildMinimalLSTMNetwork(True)
80    >>> testEquivalence(net)
81    True
82
MDLSTM with peepholes
84    >>> net = buildMinimalMDLSTMNetwork(True)
85    >>> testEquivalence(net)
86    True
87
88
89TODO:
90- heavily nested
91- exotic module use
92
93"""
94
# Module metadata.
__author__ = "Tom Schaul, tom@idsia.ch"
# NOTE(review): presumably read by the pybrain test runner to skip this
# module when a listed package is not installed -- confirm.
_dependencies = ["arac"]
97
98from pybrain.tests.helpers import buildAppropriateDataset, epsilonCheck
99from pybrain.tests import runModuleTestSuite
100
def _pairedActivations(net, cnet, ds):
    """Activate ``net`` and ``cnet`` in lockstep on every input of ``ds``.

    Yields one ``(res, cres)`` pair per input: the output of the Python
    network followed by the output of its converted counterpart. For
    sequential networks ``ds`` is iterated as a collection of sequences and
    both networks are reset at the start of each sequence, so activation
    history never leaks from one sequence into the next. Dataset targets are
    ignored.
    """
    if net.sequential:
        for seq in ds:
            net.reset()
            cnet.reset()
            for inp, _ in seq:
                yield net.activate(inp), cnet.activate(inp)
    else:
        for inp, _ in ds:
            yield net.activate(inp), cnet.activate(inp)


def testEquivalence(net, epsilon=0.001):
    """Check that ``net`` and its fast (arac) conversion give the same outputs.

    The network is converted with ``net.convertToFastNetwork()``. Both versions
    are then activated on every input of ``buildAppropriateDataset(net)``.
    After *each* activation, the summed absolute difference of the two outputs
    must pass ``epsilonCheck`` with tolerance ``epsilon``.

    :param net: the Python network under test.
    :param epsilon: tolerance on the summed absolute output difference
        (defaults to the historical value 0.001).
    :returns: ``None`` if the network could not be converted, ``True`` if all
        outputs agree, otherwise the ``(res, cres)`` outputs of the first
        mismatching step. In the mismatch case the input/output buffers of
        both networks are printed first.
    :raises ValueError: if the dataset yields no inputs to compare.
    """
    cnet = net.convertToFastNetwork()
    if cnet is None:
        # The conversion returned nothing: this architecture is unsupported.
        return None

    compared = False
    for res, cres in _pairedActivations(net, cnet, buildAppropriateDataset(net)):
        compared = True
        # Check every step, not only the final one. Sum *absolute*
        # differences so that errors of opposite sign cannot cancel out.
        if not epsilonCheck(sum(abs(res - cres)), epsilon):
            print(('in-net', net.inputbuffer.T))
            print(('in-arac', cnet.inputbuffer.T))
            print(('out-net', net.outputbuffer.T))
            print(('out-arac', cnet.outputbuffer.T))
            return (res, cres)

    if not compared:
        # Previously this surfaced as an UnboundLocalError on `res`.
        raise ValueError('dataset is empty; nothing to compare')
    return True
136
137
if __name__ == '__main__':
    # Executed as a script: hand this very module to pybrain's suite runner
    # (presumably it collects the doctests in the module docstring).
    runModuleTestSuite(__import__(__name__))
140
141