1from __future__ import print_function
2
3#!/usr/bin/env python
4""" A simple recurrent neural network that detects parity for arbitrary sequences. """
5
6__author__ = 'Tom Schaul (tom@idsia.ch)'
7
8from datasets import ParityDataSet #@UnresolvedImport
9from pybrain.supervised.trainers.backprop import BackpropTrainer
10from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection
11
12
def buildParityNet():
    """ Build the recurrent parity network and load its hand-picked weights.

    Modules: a LinearLayer(1) input named 'i', a TanhLayer(2) hidden layer
    named 'h', a BiasUnit named 'bias' and a TanhLayer(1) output named 'o'.
    Input and bias are fully connected to the hidden layer, bias and hidden
    layer to the output, and the output is connected back to the hidden layer
    through a recurrent connection.

    :return: the sorted RecurrentNetwork with its 9 parameters preset
    """
    network = RecurrentNetwork()

    # NOTE(review): the preset vector below presumably relies on the parameter
    # layout that results from this exact construction/registration order --
    # do not reorder these statements without re-deriving the weights.
    in_layer = LinearLayer(1, name='i')
    network.addInputModule(in_layer)
    hidden = TanhLayer(2, name='h')
    network.addModule(hidden)
    bias = BiasUnit('bias')
    network.addModule(bias)
    out_layer = TanhLayer(1, name='o')
    network.addOutputModule(out_layer)

    # Feed-forward wiring.
    network.addConnection(FullConnection(in_layer, hidden))
    network.addConnection(FullConnection(bias, hidden))
    network.addConnection(FullConnection(bias, out_layer))
    network.addConnection(FullConnection(hidden, out_layer))
    # Feedback from the output into the hidden layer.
    network.addRecurrentConnection(FullConnection(out_layer, hidden))
    network.sortModules()

    # Overwrite the parameters in place through a local alias, then scale
    # every weight by 10 (presumably to drive the tanh units into saturation).
    preset = [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1]
    weights = network.params
    weights[:] = preset
    weights *= 10.

    return network
31
def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
    """ Compute the mean squared error of a network over a sequential dataset.

    The network is reset before every sequence and then activated once per
    (input, target) pair. Each sample contributes the sum of its squared
    component errors; the grand total is divided by the number of samples.

    :arg net: object providing reset() and activate(input)
    :arg DS: iterable of sequences, each yielding (input, target) pairs
    :key verbose: print every (target, output) pair, plus a separator line
        after each sequence
    :key silent: suppress the final 'MSE:' line
    :return: the mean squared error over all samples
    :raise ZeroDivisionError: if DS yields no samples at all
    """
    squared_error_total = 0.
    sample_count = 0.
    for sequence in DS:
        # reset() before each sequence -- presumably clears the recurrent
        # state so that sequences are evaluated independently.
        net.reset()
        for inp, target in sequence:
            output = net.activate(inp)
            if verbose:
                print(target, output)
            residual = target - output
            squared_error_total += sum(residual ** 2)
            sample_count += 1
        if verbose:
            print('-' * 20)

    mse = squared_error_total / sample_count
    if silent:
        return mse
    print('MSE:', mse)
    return mse
50
if __name__ == "__main__":
    # 1) Network with the hand-set weights, evaluated with per-sample output.
    parity_net = buildParityNet()
    parity_data = ParityDataSet()
    evalRnnOnSeqDataset(parity_net, parity_data, verbose=True)
    print('(preset weights)')

    # 2) Same network after randomize(), evaluated again for comparison.
    parity_net.randomize()
    evalRnnOnSeqDataset(parity_net, parity_data)
    print('(random weights)')

    # 3) Backprop improves the network performance, and sometimes even finds
    #    the global optimum.
    parity_net.reset()
    trainer = BackpropTrainer(parity_net, parity_data, verbose=True)
    trainer.trainEpochs(5000)
    evalRnnOnSeqDataset(parity_net, parity_data)
    print('(backprop-trained weights)')
67