# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

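"""An advantage actor-critic agent built on the MXNet Module API."""
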
from itertools import chain
import numpy as np
import scipy.signal
import mxnet as mx


class Agent(object):
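    """Actor-critic agent with a shared feature layer.

    The network shares one fully connected layer between a softmax policy
    head and a scalar value head. `train_step` weights the policy gradient
    by Generalized Advantage Estimation and adds an entropy bonus scaled by
    `config.entropy_wt`.
    """
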
    def __init__(self, input_size, act_space, config):
        super(Agent, self).__init__()
        self.input_size = input_size
        self.num_envs = config.num_envs
        self.ctx = config.ctx
        self.act_space = act_space
        self.config = config

        # Shared feature layer used by both the policy and value heads.
        net = mx.sym.Variable('data')
        net = mx.sym.FullyConnected(
            data=net, name='fc1', num_hidden=config.hidden_size, no_bias=True)
        net = mx.sym.Activation(data=net, name='relu1', act_type="relu")

        # Policy network: softmax over actions, clipped away from 0 and 1 so
        # that log_policy stays finite.
        policy_fc = mx.sym.FullyConnected(
            data=net, name='policy_fc', num_hidden=act_space, no_bias=True)
        policy = mx.sym.SoftmaxActivation(data=policy_fc, name='policy')
        policy = mx.sym.clip(data=policy, a_min=1e-5, a_max=1 - 1e-5)
        log_policy = mx.sym.log(data=policy, name='log_policy')
        # BlockGrad exposes the probabilities for action sampling without
        # contributing to the gradient.
        out_policy = mx.sym.BlockGrad(data=policy, name='out_policy')

        # Negative entropy, added as a loss term (scaled by config.entropy_wt)
        # to encourage exploration.
        neg_entropy = policy * log_policy
        neg_entropy = mx.sym.MakeLoss(
            data=neg_entropy, grad_scale=config.entropy_wt, name='neg_entropy')

        # Value network: a scalar state-value estimate.
        value = mx.sym.FullyConnected(data=net, name='value', num_hidden=1)

        # Output order matters: train_step supplies head gradients for
        # log_policy and value, while neg_entropy (MakeLoss) and out_policy
        # (BlockGrad) need none.
        self.sym = mx.sym.Group([log_policy, value, neg_entropy, out_policy])
        self.model = mx.mod.Module(self.sym, data_names=('data',),
                                   label_names=None)
        self.parallel_num = config.num_envs * config.t_max
        self.model.bind(
            data_shapes=[('data', (self.parallel_num, input_size))],
            label_shapes=None,
            grad_req="write")

        self.model.init_params(config.init_func)

        optimizer_params = {'learning_rate': config.learning_rate,
                            'rescale_grad': 1.0}
        if config.grad_clip:
            optimizer_params['clip_gradient'] = config.clip_magnitude

        self.model.init_optimizer(
            kvstore='local', optimizer=config.update_rule,
            optimizer_params=optimizer_params)

    def act(self, ps):
        """Sample one action per environment from probabilities `ps`.

        Draws a uniform random number per row and returns the index of the
        first action whose cumulative probability exceeds it (inverse-CDF
        sampling over each categorical distribution).
        """
        us = np.random.uniform(size=ps.shape[0])[:, np.newaxis]
        as_ = (np.cumsum(ps, axis=1) > us).argmax(axis=1)
        return as_

    def train_step(self, env_xs, env_as, env_rs, env_vs):
        # NOTE(reed): Reshape the module so the data shape matches the number
        # of observations gathered for this update.
        self.model.reshape([('data', (len(env_xs), self.input_size))])

        xs = mx.nd.array(env_xs, ctx=self.ctx)
        # Flatten the per-environment action lists into one batch.
        as_ = np.array(list(chain.from_iterable(env_as)))

        # Compute GAE advantages from the rewards and value estimates.
        advs = []
        gamma, lambda_ = self.config.gamma, self.config.lambda_
        for i in range(len(env_vs)):
            # Compute advantages using Generalized Advantage Estimation;
            # see eqn. (16) of [Schulman 2016]. delta_t is the TD residual
            # r_t + gamma * V(s_{t+1}) - V(s_t).
            delta_t = (env_rs[i] + gamma*np.array(env_vs[i][1:]) -
                       np.array(env_vs[i][:-1]))
            advs.extend(self._discount(delta_t, gamma * lambda_))

        # Negative generalized advantage estimates.
        neg_advs_v = -np.asarray(advs)

        # NOTE(reed): Only keep the gradients for the selected actions; every
        # other entry of the log-policy head gradient stays zero.
        neg_advs_np = np.zeros((len(advs), self.act_space), dtype=np.float32)
        neg_advs_np[np.arange(neg_advs_np.shape[0]), as_] = neg_advs_v
        neg_advs = mx.nd.array(neg_advs_np, ctx=self.ctx)

        # NOTE(reed): The gradient of the value head is the negative
        # advantage, scaled by config.vf_wt.
        v_grads = mx.nd.array(self.config.vf_wt * neg_advs_v[:, np.newaxis],
                              ctx=self.ctx)

        data_batch = mx.io.DataBatch(data=[xs], label=None)
        self._forward_backward(data_batch=data_batch,
                               out_grads=[neg_advs, v_grads])

        self._update_params()

    def _discount(self, x, gamma):
        # Discounted cumulative sum: filtering the reversed sequence with the
        # IIR filter [1, -gamma] computes y_t = x_t + gamma * y_{t+1}.
        return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

    def _forward_backward(self, data_batch, out_grads=None):
        self.model.forward(data_batch, is_train=True)
        self.model.backward(out_grads=out_grads)

    def _update_params(self):
        self.model.update()
        # Copy the updated parameters back from the devices so later reads of
        # the module's parameters see the fresh values.
        self.model._sync_params_from_devices()
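
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original example: the config
    # attribute names mirror the ones the Agent reads above, but the values
    # here are arbitrary placeholders. Only act() is exercised.
    from types import SimpleNamespace

    config = SimpleNamespace(
        num_envs=2, t_max=4, ctx=mx.cpu(), hidden_size=16,
        entropy_wt=0.01, vf_wt=0.5, gamma=0.99, lambda_=1.0,
        learning_rate=1e-3, grad_clip=False, clip_magnitude=40.0,
        update_rule='adam', init_func=mx.init.Xavier())

    agent = Agent(input_size=8, act_space=3, config=config)
    probs = np.full((config.num_envs, 3), 1.0 / 3)
    print(agent.act(probs))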