# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import mxnet.ndarray as nd
import numpy
import cv2
from scipy.stats import entropy
from utils import *


class DQNOutput(mx.operator.CustomOp):
    """DQN loss layer implemented as a CustomOp.

    Forward is the identity on the predicted Q-values.  Backward produces a
    gradient that is non-zero only at the index of the action actually taken:
    the TD error ``Q(s, a) - reward`` clipped to ``[-1, 1]``, as in the DQN
    papers (``reward`` here is presumably the full TD target computed by the
    caller — confirm against the training loop).
    """

    def __init__(self):
        super(DQNOutput, self).__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        # Identity: the output is the incoming Q-value prediction.
        self.assign(out_data[0], req[0], in_data[0])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # TODO Backward using NDArray will cause some troubles see `https://github.com/apache/mxnet/issues/1720'
        x = out_data[0].asnumpy()
        # `numpy.int` was deprecated in NumPy 1.20 and removed in 1.24; the
        # builtin `int` is the documented, behavior-identical replacement.
        action = in_data[1].asnumpy().astype(int)
        reward = in_data[2].asnumpy()
        dx = in_grad[0]
        ret = numpy.zeros(shape=dx.shape, dtype=numpy.float32)
        # Only the chosen action's Q-value receives gradient; clip the TD
        # error to [-1, 1] for stability (error clipping from the DQN paper).
        ret[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
        self.assign(dx, req[0], ret)


@mx.operator.register("DQNOutput")
class DQNOutputProp(mx.operator.CustomOpProp):
    """Properties for :class:`DQNOutput`.

    ``need_top_grad=False`` because this is a loss layer: the gradient is
    computed internally from (data, action, reward), not from an incoming
    top gradient.
    """

    def __init__(self):
        super(DQNOutputProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'action', 'reward']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # data: (batch, action_num); action/reward: one scalar per sample.
        data_shape = in_shape[0]
        action_shape = (in_shape[0][0],)
        reward_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, action_shape, reward_shape], [output_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return DQNOutput()


class DQNOutputNpyOp(mx.operator.NumpyOp):
    """Legacy NumpyOp version of the DQN loss layer.

    Same semantics as :class:`DQNOutput`: identity forward, clipped TD-error
    gradient on the taken action in backward.
    """

    def __init__(self):
        super(DQNOutputNpyOp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'action', 'reward']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # data: (batch, action_num); action/reward: one scalar per sample.
        data_shape = in_shape[0]
        action_shape = (in_shape[0][0],)
        reward_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, action_shape, reward_shape], [output_shape]

    def forward(self, in_data, out_data):
        x = in_data[0]
        y = out_data[0]
        y[:] = x

    def backward(self, out_grad, in_data, out_data, in_grad):
        x = out_data[0]
        # `numpy.int` was removed in NumPy 1.24; use the builtin `int`.
        action = in_data[1].astype(int)
        reward = in_data[2]
        dx = in_grad[0]
        dx[:] = 0
        # Gradient only on the taken action, clipped to [-1, 1].
        dx[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)


def dqn_sym_nips(action_num, data=None, name='dqn'):
    """Structure of the Deep Q Network in the NIPS 2013 workshop paper:
    Playing Atari with Deep Reinforcement Learning (https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)

    Parameters
    ----------
    action_num : int
        Number of discrete actions (size of the final fully-connected layer).
    data : mxnet.sym.Symbol, optional
        Input symbol; a fresh 'data' Variable is created when omitted.
    name : str, optional
        Name given to the DQNOutput layer.
    """
    if data is None:
        net = mx.symbol.Variable('data')
    else:
        net = data
    net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=16)
    net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
    net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=32)
    net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
    net = mx.symbol.Flatten(data=net)
    net = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=256)
    net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
    net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=action_num)
    net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
    return net


def dqn_sym_nature(action_num, data=None, name='dqn'):
    """Structure of the Deep Q Network in the Nature 2015 paper:
    Human-level control through deep reinforcement learning
    (http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)

    Parameters
    ----------
    action_num : int
        Number of discrete actions (size of the final fully-connected layer).
    data : mxnet.sym.Symbol, optional
        Input symbol; a fresh 'data' Variable is created when omitted.
    name : str, optional
        Name given to the DQNOutput layer.
    """
    if data is None:
        net = mx.symbol.Variable('data')
    else:
        net = data
    # BUG FIX: the original unconditionally re-created the 'data' Variable
    # here, silently discarding a caller-supplied `data` symbol.  Removed so
    # the `data` parameter behaves as documented (and as in dqn_sym_nips).
    net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=32)
    net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
    net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=64)
    net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
    net = mx.symbol.Convolution(data=net, name='conv3', kernel=(3, 3), stride=(1, 1), num_filter=64)
    net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
    net = mx.symbol.Flatten(data=net)
    net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=512)
    net = mx.symbol.Activation(data=net, name='relu4', act_type="relu")
    net = mx.symbol.FullyConnected(data=net, name='fc5', num_hidden=action_num)
    net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
    return net