from __future__ import print_function

__author__ = 'Tom Schaul, tom@idsia.ch and Daan Wierstra, daan@idsia.ch'

from numpy import zeros, array, mean, exp, dot, argmax
from numpy.random import randn

from pybrain.datasets import ReinforcementDataSet, ImportanceDataSet, SequentialDataSet
from pybrain.supervised import BackpropTrainer
from pybrain.utilities import drawIndex
from pybrain.rl.learners.directsearch.directsearch import DirectSearchLearner


# TODO: greedy runs: start once in every possible starting state!
# TODO: supervised: train-set, test-set, early stopping -> actual convergence!

class RWR(DirectSearchLearner):
    """ Reward-weighted regression.

    The algorithm is currently limited to discrete-action episodic tasks,
    i.e. subclasses of POMDPTask.
    """

    # parameters
    batchSize = 20

    # feedback settings
    verbose = True
    greedyRuns = 20
    supervisedPlotting = False

    # settings for the supervised training
    learningRate = 0.005
    momentum = 0.9
    maxEpochs = 20
    validationProportion = 0.33
    continueEpochs = 2

    # parameters for the variation that uses a value function
    # TODO: split into 2 classes.
    valueLearningRate = None
    valueMomentum = None
    #valueTrainEpochs = 5
    resetAllWeights = False
    netweights = 0.01

    def __init__(self, net, task, valueNetwork=None, **args):
        """ :arg net: the action-selection (policy) network to be trained
            :arg task: the episodic task to learn on
            :key valueNetwork: optional network for estimating state values """
        self.net = net
        self.task = task
        self.setArgs(**args)
        if self.valueLearningRate is None:
            self.valueLearningRate = self.learningRate
        if self.valueMomentum is None:
            self.valueMomentum = self.momentum
        if self.supervisedPlotting:
            from pylab import ion
            ion()

        # adaptive temperature:
        self.tau = 1.

        # prepare the datasets to be used
        self.weightedDs = ImportanceDataSet(self.task.outdim, self.task.indim)
        self.rawDs = ReinforcementDataSet(self.task.outdim, self.task.indim)
        self.valueDs = SequentialDataSet(self.task.outdim, 1)

        # prepare the supervised trainers
        self.bp = BackpropTrainer(self.net, self.weightedDs, self.learningRate,
                                  self.momentum, verbose=False,
                                  batchlearning=True)

        # CHECKME: outsource
        self.vnet = valueNetwork
        if valueNetwork is not None:
            self.vbp = BackpropTrainer(self.vnet, self.valueDs, self.valueLearningRate,
                                       self.valueMomentum, verbose=self.verbose)

        # keep information:
        self.totalSteps = 0
        self.totalEpisodes = 0

    def shapingFunction(self, R):
        """ Exponentially shape returns (or advantages) into positive weights. """
        return exp(self.tau * R)

    def updateTau(self, R, U):
        """ Adapt the temperature tau from the returns R and their shaped weights U. """
        self.tau = sum(U) / dot((R - self.task.minReward), U)
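
    # Note on the temperature update: after updateTau, tau is the reciprocal of the
    # U-weighted mean of (R - task.minReward). Batches whose shifted returns are
    # already large therefore get a flatter exponential weighting, while low-return
    # batches are sharpened more strongly by exp(tau * R).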

    def reset(self):
        """ Clear the collected data and reset the trainers' momentum. """
        self.weightedDs.clear()
        self.valueDs.clear()
        self.rawDs.clear()
        self.bp.momentumvector *= 0.0
        if self.vnet is not None:
            self.vbp.momentumvector *= 0.0
            if self.resetAllWeights:
                self.vnet.params[:] = randn(len(self.vnet.params)) * self.netweights

    def greedyEpisode(self):
        """ Run one episode with greedy decisions, return the list of rewards received. """
        rewards = []
        self.task.reset()
        self.net.reset()
        while not self.task.isFinished():
            obs = self.task.getObservation()
            act = self.net.activate(obs)
            chosen = argmax(act)
            self.task.performAction(chosen)
            reward = self.task.getReward()
            rewards.append(reward)
        return rewards

    def learn(self, batches):
        """ Run the given number of batches; after each batch, evaluate the current
        greedy policy and keep the running statistics. """
        self.greedyAvg = []
        self.rewardAvg = []
        self.lengthAvg = []
        self.initr0Avg = []
        for b in range(batches):
            if self.verbose:
                print()
                print('Batch', b + 1)
            self.reset()
            self.learnOneBatch()
            self.totalEpisodes += self.batchSize

            # greedy measure (avg over some greedy runs)
            rws = 0.
            for dummy in range(self.greedyRuns):
                tmp = self.greedyEpisode()
                rws += (sum(tmp) / float(len(tmp)))
            self.greedyAvg.append(rws / self.greedyRuns)
            if self.verbose:
                print('::', round(rws / self.greedyRuns, 5), '::')

    def learnOneBatch(self):
        """ Collect one batch of episodes, weight them by their shaped returns,
        and retrain the acting network on the weighted data. """
        # collect a batch of runs as experience
        r0s = []
        lens = []
        avgReward = 0.
        for dummy in range(self.batchSize):
            self.rawDs.newSequence()
            self.valueDs.newSequence()
            self.task.reset()
            self.net.reset()
            acts, obss, rewards = [], [], []
            while not self.task.isFinished():
                obs = self.task.getObservation()
                act = self.net.activate(obs)
                chosen = drawIndex(act)
                self.task.performAction(chosen)
                reward = self.task.getReward()
                obss.append(obs)
                y = zeros(len(act))
                y[chosen] = 1
                acts.append(y)
                rewards.append(reward)
            avgReward += sum(rewards) / float(len(rewards))

            # compute the returns from the list of rewards
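            # (e.g. with gamma = self.task.discount and rewards [r0, r1, r2], the loop
            #  below produces returns == [r0 + gamma*r1 + gamma**2*r2, r1 + gamma*r2, r2])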
            current = 0
            returns = []
            for r in reversed(rewards):
                current *= self.task.discount
                current += r
                returns.append(current)
            returns.reverse()
            for i in range(len(obss)):
                self.rawDs.addSample(obss[i], acts[i], returns[i])
                self.valueDs.addSample(obss[i], returns[i])
            r0s.append(returns[0])
            lens.append(len(returns))

        r0s = array(r0s)
        self.totalSteps += sum(lens)
        avgLen = sum(lens) / float(self.batchSize)
        avgR0 = mean(r0s)
        avgReward /= self.batchSize
        if self.verbose:
            print('***', round(avgLen, 3), '***', '(avg init exp. return:', round(avgR0, 5), ')')
            print('avg reward', round(avgReward, 5), '(tau:', round(self.tau, 3), ')')
            print(lens)
        # storage:
        self.rewardAvg.append(avgReward)
        self.lengthAvg.append(avgLen)
        self.initr0Avg.append(avgR0)


#        if self.vnet == None:
#            # case 1: no value estimator:

        # prepare the dataset for training the acting network
        shaped = self.shapingFunction(r0s)
        self.updateTau(r0s, shaped)
        shaped /= max(shaped)
        for i, seq in enumerate(self.rawDs):
            self.weightedDs.newSequence()
            for sample in seq:
                obs, act, dummy = sample
                self.weightedDs.addSample(obs, act, shaped[i])
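        # At this point every transition of episode i carries the importance weight
        # shaped[i]: a positive, monotone function of that episode's return, normalized
        # so the best episode in the batch has weight 1. The supervised step below is
        # thus a regression onto the actions the stochastic policy actually took,
        # weighted by how well the whole episode did.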

#        else:
#            # case 2: value estimator:
#
#
#            # train the value estimating network
#            if self.verbose: print('Old value error:  ', self.vbp.testOnData())
#            self.vbp.trainEpochs(self.valueTrainEpochs)
#            if self.verbose: print('New value error:  ', self.vbp.testOnData())
#
#            # produce the values and analyze
#            rminusvs = []
#            sizes = []
#            for i, seq in enumerate(self.valueDs):
#                self.vnet.reset()
#                seq = list(seq)
#                for sample in seq:
#                    obs, ret = sample
#                    val = self.vnet.activate(obs)
#                    rminusvs.append(ret-val)
#                    sizes.append(len(seq))
#
#            rminusvs = array(rminusvs)
#            shapedRminusv = self.shapingFunction(rminusvs)
#            # CHECKME: here?
#            self.updateTau(rminusvs, shapedRminusv)
#            shapedRminusv /= array(sizes)
#            shapedRminusv /= max(shapedRminusv)
#
#            # prepare the dataset for training the acting network
#            rvindex = 0
#            for i, seq in enumerate(self.rawDs):
#                self.weightedDs.newSequence()
#                self.vnet.reset()
#                for sample in seq:
#                    obs, act, ret = sample
#                    self.weightedDs.addSample(obs, act, shapedRminusv[rvindex])
#                    rvindex += 1

        # train the acting network
        tmp1, tmp2 = self.bp.trainUntilConvergence(maxEpochs=self.maxEpochs,
                                                   validationProportion=self.validationProportion,
                                                   continueEpochs=self.continueEpochs,
                                                   verbose=self.verbose)
        if self.supervisedPlotting:
            from pylab import plot, legend, figure, clf, draw
            figure(1)
            clf()
            plot(tmp1, label='train')
            plot(tmp2, label='valid')
            legend()
            draw()

        return avgLen, avgR0