# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from utils import define_qfunc
import mxnet as mx


class QFunc(object):
    """
    Base class for Q-Value Function.

    Subclasses must override ``get_qvals``.
    """

    def __init__(self, env_spec):
        """Store the environment specification on the instance.

        Parameters
        ----------
        env_spec : object
            Environment specification; kept as ``self.env_spec`` and not
            otherwise interpreted by this class.
        """

        self.env_spec = env_spec

    def get_qvals(self, obs, act):
        """Return Q-values for the given observations and actions.

        Raises
        ------
        NotImplementedError
            Always; concrete subclasses provide the implementation.
        """

        raise NotImplementedError


class ContinuousMLPQ(QFunc):
    """
    Continuous Multi-Layer Perceptron Q-Value Network
    for deterministic policy training.

    Typical call order: ``__init__`` -> ``define_loss`` -> ``define_exe``
    -> ``update_params`` / ``get_qvals``.  ``define_exe`` requires
    ``self.loss`` (set by ``define_loss``), and the last two methods require
    ``self.exe`` (set by ``define_exe``).
    """

    def __init__(
        self,
        env_spec):
        """Build the symbolic Q-value graph.

        Creates the ``obs``, ``act`` and ``yval`` input variables and the
        Q-value symbol produced by ``define_qfunc(obs, act)``.

        Parameters
        ----------
        env_spec : object
            Environment specification, forwarded to ``QFunc.__init__``.
        """

        super(ContinuousMLPQ, self).__init__(env_spec)

        self.obs = mx.symbol.Variable("obs")
        self.act = mx.symbol.Variable("act")
        self.qval = define_qfunc(self.obs, self.act)
        self.yval = mx.symbol.Variable("yval")

    def get_output_symbol(self):
        """Return the Q-value output symbol."""

        return self.qval

    def get_loss_symbols(self):
        """Return the symbols a caller needs to build a loss expression.

        Returns
        -------
        dict
            ``{"qval": <Q-value symbol>, "yval": <target variable>}``.
        """

        return {"qval": self.qval,
                "yval": self.yval}

    def define_loss(self, loss_exp):
        """Wrap a loss expression and attach the Q-value as a second output.

        Parameters
        ----------
        loss_exp : mx.symbol.Symbol
            Loss expression, wrapped with ``MakeLoss`` under the name
            ``"qfunc_loss"``.
        """

        self.loss = mx.symbol.MakeLoss(loss_exp, name="qfunc_loss")
        # Group the loss with a gradient-blocked copy of the Q-value so the
        # bound executor exposes it as output index 1 (read by ``get_qvals``)
        # without that output contributing gradients.
        self.loss = mx.symbol.Group([self.loss, mx.symbol.BlockGrad(self.qval)])

    def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
                   grad_req=None):
        """Bind an executor for the grouped loss and initialize its arguments.

        Parameters
        ----------
        ctx : mx.Context
            Passed to ``simple_bind`` as ``ctx``.
        init : callable
            Invoked as ``init(name, arr)`` for every executor argument whose
            name is not a key of ``input_shapes``.
        updater : callable
            Stored and later invoked by ``update_params`` as
            ``updater(index, grad, weight)``.
        input_shapes : dict, optional
            Keyword arguments forwarded to ``simple_bind``; its keys are also
            the argument names skipped during initialization.  ``None`` is
            treated as an empty dict.
        args : optional
            Accepted for interface compatibility; currently unused.
        grad_req : optional
            Accepted for interface compatibility; currently unused.
        """

        # The signature defaults ``input_shapes`` to None, but it is both
        # unpacked with ``**`` and used in a membership test below; normalize
        # it so the default does not fail with an opaque TypeError.
        if input_shapes is None:
            input_shapes = {}

        # define an executor, initializer and updater for batch version loss
        self.exe = self.loss.simple_bind(ctx=ctx, **input_shapes)
        self.arg_arrays = self.exe.arg_arrays
        self.grad_arrays = self.exe.grad_arrays
        self.arg_dict = self.exe.arg_dict

        # Only arguments that are not fed as inputs get initialized.
        for name, arr in self.arg_dict.items():
            if name not in input_shapes:
                init(name, arr)

        self.updater = updater

    def update_params(self, obs, act, yval):
        """Run one forward/backward pass and apply the updater.

        Parameters
        ----------
        obs, act, yval
            Values copied in place into the executor's ``"obs"``, ``"act"``
            and ``"yval"`` argument arrays.
        """

        self.arg_dict["obs"][:] = obs
        self.arg_dict["act"][:] = act
        self.arg_dict["yval"][:] = yval

        self.exe.forward(is_train=True)
        self.exe.backward()

        # NOTE(review): this walks every executor argument, which appears to
        # include the obs/act/yval input arrays as well as the parameters;
        # the inputs are overwritten on the next call -- confirm intended.
        for i, pair in enumerate(zip(self.arg_arrays, self.grad_arrays)):
            weight, grad = pair
            self.updater(i, grad, weight)

    def get_qvals(self, obs, act):
        """Compute Q-values for the given observations and actions.

        Parameters
        ----------
        obs, act
            Values copied in place into the executor's ``"obs"`` and
            ``"act"`` argument arrays.

        Returns
        -------
        numpy.ndarray
            Executor output at index 1 (the gradient-blocked Q-value set up
            in ``define_loss``), converted with ``asnumpy()``.
        """

        self.exe.arg_dict["obs"][:] = obs
        self.exe.arg_dict["act"][:] = act
        self.exe.forward(is_train=False)

        return self.exe.outputs[1].asnumpy()