# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from itertools import chain

import numpy as np
import scipy.signal

import mxnet as mx


class Agent(object):
    def __init__(self, input_size, act_space, config):
        super(Agent, self).__init__()
        self.input_size = input_size
        self.num_envs = config.num_envs
        self.ctx = config.ctx
        self.act_space = act_space
        self.config = config

        # Shared network.
        net = mx.sym.Variable('data')
        net = mx.sym.FullyConnected(
            data=net, name='fc1', num_hidden=config.hidden_size, no_bias=True)
        net = mx.sym.Activation(data=net, name='relu1', act_type="relu")

        # Policy network.
        policy_fc = mx.sym.FullyConnected(
            data=net, name='policy_fc', num_hidden=act_space, no_bias=True)
        policy = mx.sym.SoftmaxActivation(data=policy_fc, name='policy')
        policy = mx.sym.clip(data=policy, a_min=1e-5, a_max=1 - 1e-5)
        log_policy = mx.sym.log(data=policy, name='log_policy')
        out_policy = mx.sym.BlockGrad(data=policy, name='out_policy')

        # Negative entropy, used as a regularizer that encourages exploration.
        neg_entropy = policy * log_policy
        neg_entropy = mx.sym.MakeLoss(
            data=neg_entropy, grad_scale=config.entropy_wt, name='neg_entropy')

        # Value network.
        value = mx.sym.FullyConnected(data=net, name='value', num_hidden=1)

        self.sym = mx.sym.Group([log_policy, value, neg_entropy, out_policy])
        self.model = mx.mod.Module(self.sym, data_names=('data',),
                                   label_names=None)

        self.parallel_num = config.num_envs * config.t_max
        self.model.bind(
            data_shapes=[('data', (self.parallel_num, input_size))],
            label_shapes=None,
            grad_req="write")

        self.model.init_params(config.init_func)

        optimizer_params = {'learning_rate': config.learning_rate,
                            'rescale_grad': 1.0}
        if config.grad_clip:
            optimizer_params['clip_gradient'] = config.clip_magnitude

        self.model.init_optimizer(
            kvstore='local', optimizer=config.update_rule,
            optimizer_params=optimizer_params)

    def act(self, ps):
        # Sample one action per row via inverse-CDF sampling: pick the first
        # index where the cumulative probability exceeds a uniform draw.
        us = np.random.uniform(size=ps.shape[0])[:, np.newaxis]
        as_ = (np.cumsum(ps, axis=1) > us).argmax(axis=1)
        return as_

    def train_step(self, env_xs, env_as, env_rs, env_vs):
        # NOTE(reed): Reshape to set the data shape.
        self.model.reshape([('data', (len(env_xs), self.input_size))])

        xs = mx.nd.array(env_xs, ctx=self.ctx)
        as_ = np.array(list(chain.from_iterable(env_as)))

        # Compute discounted rewards and advantages.
        advs = []
        gamma, lambda_ = self.config.gamma, self.config.lambda_
        for i in range(len(env_vs)):
            # Compute advantages using Generalized Advantage Estimation;
            # see eqn. (16) of [Schulman 2016].
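            # delta_t is the TD residual r_t + gamma * V(s_{t+1}) - V(s_t).
            # This assumes env_vs[i] carries one extra bootstrap value, i.e.
            # len(env_vs[i]) == len(env_rs[i]) + 1, so the slices below align.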
            delta_t = (env_rs[i] + gamma*np.array(env_vs[i][1:]) -
                       np.array(env_vs[i][:-1]))
            advs.extend(self._discount(delta_t, gamma * lambda_))

        # Negative generalized advantage estimates.
        neg_advs_v = -np.asarray(advs)

        # NOTE(reed): Only keep the grads for the selected actions.
        neg_advs_np = np.zeros((len(advs), self.act_space), dtype=np.float32)
        neg_advs_np[np.arange(neg_advs_np.shape[0]), as_] = neg_advs_v
        neg_advs = mx.nd.array(neg_advs_np, ctx=self.ctx)

        # NOTE(reed): The gradient of the value output is the negative
        # advantage, scaled by the value-function loss weight.
        v_grads = mx.nd.array(self.config.vf_wt * neg_advs_v[:, np.newaxis],
                              ctx=self.ctx)

        data_batch = mx.io.DataBatch(data=[xs], label=None)
        self._forward_backward(data_batch=data_batch,
                               out_grads=[neg_advs, v_grads])

        self._update_params()

    def _discount(self, x, gamma):
        # Discounted cumulative sum over time: y[t] = x[t] + gamma * y[t+1].
        return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

    def _forward_backward(self, data_batch, out_grads=None):
        self.model.forward(data_batch, is_train=True)
        self.model.backward(out_grads=out_grads)

    def _update_params(self):
        self.model.update()
        self.model._sync_params_from_devices()
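
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# `_Config` is a hypothetical stand-in for the config object a training
# script would supply; its field names mirror the attributes read above.
# Note the data layout: each env_vs[i] carries one bootstrap value
# (t_max + 1 entries), while each env_rs[i] carries t_max rewards.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _Config(object):
        num_envs = 2
        t_max = 4
        ctx = mx.cpu()
        hidden_size = 64
        entropy_wt = 0.01
        init_func = mx.init.Xavier()
        learning_rate = 1e-3
        grad_clip = False
        clip_magnitude = 40.0
        update_rule = 'adam'
        gamma = 0.99
        lambda_ = 1.0
        vf_wt = 0.5

    config = _Config()
    agent = Agent(input_size=8, act_space=3, config=config)

    # One synthetic rollout of t_max steps in each of num_envs environments.
    rng = np.random.RandomState(0)
    steps = config.num_envs * config.t_max
    env_xs = rng.randn(steps, 8).astype(np.float32)
    env_as = [rng.randint(3, size=config.t_max) for _ in range(config.num_envs)]
    env_rs = [np.ones(config.t_max) for _ in range(config.num_envs)]
    env_vs = [np.zeros(config.t_max + 1) for _ in range(config.num_envs)]
    agent.train_step(env_xs, env_as, env_rs, env_vs)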