#!/usr/bin/env python
"""A simple recurrent neural network that detects parity for arbitrary sequences."""

from __future__ import print_function

__author__ = 'Tom Schaul (tom@idsia.ch)'

from datasets import ParityDataSet #@UnresolvedImport
from pybrain.supervised.trainers.backprop import BackpropTrainer
from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection


def buildParityNet():
    """Build a small recurrent network hand-wired to solve the parity task.

    Topology: 1 linear input -> 2 tanh hidden units -> 1 tanh output,
    with a bias unit feeding both hidden and output layers, and a
    recurrent connection from the output back to the hidden layer
    (carrying the running parity estimate across time steps).

    :return: a RecurrentNetwork with preset (non-random) weights.
    """
    net = RecurrentNetwork()
    net.addInputModule(LinearLayer(1, name = 'i'))
    net.addModule(TanhLayer(2, name = 'h'))
    net.addModule(BiasUnit('bias'))
    net.addOutputModule(TanhLayer(1, name = 'o'))
    net.addConnection(FullConnection(net['i'], net['h']))
    net.addConnection(FullConnection(net['bias'], net['h']))
    net.addConnection(FullConnection(net['bias'], net['o']))
    net.addConnection(FullConnection(net['h'], net['o']))
    net.addRecurrentConnection(FullConnection(net['o'], net['h']))
    net.sortModules()

    # Hand-picked weights that solve parity exactly; scaling by 10 drives
    # the tanh units into saturation so outputs are effectively binary (+/-1).
    p = net.params
    p[:] = [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1]
    p *= 10.

    return net


def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
    """Evaluate the network on all the sequences of a dataset.

    The network is reset before each sequence so no state leaks between
    sequences. The squared error is accumulated over every time step of
    every sequence and averaged over the total number of time steps.

    :param net: a (recurrent) network with reset()/activate() methods.
    :param DS: a sequential dataset yielding (input, target) pairs per sequence.
    :key verbose: print target and output for every time step.
    :key silent: suppress the final MSE printout.
    :return: the mean squared error per time step.
    """
    r = 0.
    samples = 0.
    for seq in DS:
        net.reset()
        for i, t in seq:
            res = net.activate(i)
            if verbose:
                print(t, res)
            r += sum((t-res)**2)
            samples += 1
        if verbose:
            print('-'*20)
    r /= samples
    if not silent:
        print('MSE:', r)
    return r


if __name__ == "__main__":
    N = buildParityNet()
    DS = ParityDataSet()
    evalRnnOnSeqDataset(N, DS, verbose = True)
    print('(preset weights)')
    N.randomize()
    evalRnnOnSeqDataset(N, DS)
    print('(random weights)')

    # Backprop improves the network performance, and sometimes even finds the global optimum.
    N.reset()
    bp = BackpropTrainer(N, DS, verbose = True)
    bp.trainEpochs(5000)
    evalRnnOnSeqDataset(N, DS)
    print('(backprop-trained weights)')