# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import mxnet.ndarray as nd
import numpy
import cv2
from scipy.stats import entropy
from utils import *

class DQNOutput(mx.operator.CustomOp):
    def __init__(self):
        super(DQNOutput, self).__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        self.assign(out_data[0], req[0], in_data[0])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # TODO Backward using NDArray will cause some troubles, see https://github.com/apache/mxnet/issues/1720
        x = out_data[0].asnumpy()
        action = in_data[1].asnumpy().astype(numpy.int32)
        reward = in_data[2].asnumpy()
        dx = in_grad[0]
        # Gradient w.r.t. the Q-values: only the entry of the chosen action is
        # non-zero, and the error term is clipped to [-1, 1] (error clipping).
        ret = numpy.zeros(shape=dx.shape, dtype=numpy.float32)
        ret[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
        self.assign(dx, req[0], ret)


@mx.operator.register("DQNOutput")
class DQNOutputProp(mx.operator.CustomOpProp):
    def __init__(self):
        # The DQN output acts as the loss layer, so no gradient flows in from above.
        super(DQNOutputProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'action', 'reward']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        action_shape = (in_shape[0][0],)
        reward_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, action_shape, reward_shape], [output_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return DQNOutput()


# Alternative implementation of the same output layer using mx.operator.NumpyOp,
# the older pure-NumPy custom-operator interface.
class DQNOutputNpyOp(mx.operator.NumpyOp):
    def __init__(self):
        super(DQNOutputNpyOp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'action', 'reward']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        action_shape = (in_shape[0][0],)
        reward_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, action_shape, reward_shape], [output_shape]

    def forward(self, in_data, out_data):
        x = in_data[0]
        y = out_data[0]
        y[:] = x

    def backward(self, out_grad, in_data, out_data, in_grad):
        x = out_data[0]
        action = in_data[1].astype(numpy.int32)
        reward = in_data[2]
        dx = in_grad[0]
        # Same clipped error gradient as DQNOutput.backward above.
        dx[:] = 0
        dx[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
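

# A minimal standalone sketch (not part of the original example) of the arithmetic the
# two output operators above perform in backward(): for each sample, the gradient
# w.r.t. the Q-value of the chosen action is the difference between that Q-value and
# the supplied 'reward' input, clipped to [-1, 1]; all other entries stay zero.
# The helper name `_clipped_td_grad` is hypothetical and not used elsewhere in this file.
def _clipped_td_grad(q_values, actions, rewards):
    grad = numpy.zeros_like(q_values, dtype=numpy.float32)
    rows = numpy.arange(actions.shape[0])
    grad[rows, actions] = numpy.clip(q_values[rows, actions] - rewards, -1, 1)
    return grad
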
def dqn_sym_nips(action_num, data=None, name='dqn'):
    """Structure of the Deep Q Network in the NIPS 2013 workshop paper:
    Playing Atari with Deep Reinforcement Learning (https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)

    Parameters
    ----------
    action_num : int
        Number of discrete actions, i.e. the size of the final output layer.
    data : mxnet.sym.Symbol, optional
        Input symbol. A new 'data' variable is created if not given.
    name : str, optional
        Name of the DQNOutput custom operator at the top of the network.
    """
    if data is None:
        net = mx.symbol.Variable('data')
    else:
        net = data
    net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=16)
    net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
    net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=32)
    net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
    net = mx.symbol.Flatten(data=net)
    net = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=256)
    net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
    net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=action_num)
    net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
    return net


def dqn_sym_nature(action_num, data=None, name='dqn'):
125    """Structure of the Deep Q Network in the Nature 2015 paper:
126    Human-level control through deep reinforcement learning
127    (http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
128
129    Parameters
130    ----------
131    action_num : int
132    data : mxnet.sym.Symbol, optional
133    name : str, optional
134    """
135    if data is None:
136        net = mx.symbol.Variable('data')
137    else:
138        net = data
139    net = mx.symbol.Variable('data')
140    net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=32)
141    net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
142    net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=64)
143    net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
144    net = mx.symbol.Convolution(data=net, name='conv3', kernel=(3, 3), stride=(1, 1), num_filter=64)
145    net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
146    net = mx.symbol.Flatten(data=net)
147    net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=512)
148    net = mx.symbol.Activation(data=net, name='relu4', act_type="relu")
149    net = mx.symbol.FullyConnected(data=net, name='fc5', num_hidden=action_num)
150    net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
151    return net
152
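
# A minimal, hedged usage sketch (not part of the original example): build both
# network variants and print the shapes MXNet infers, assuming the common Atari
# preprocessing of 4 stacked 84x84 grayscale frames. The batch size of 32 and
# action_num of 6 are illustrative assumptions only.
if __name__ == '__main__':
    for builder in (dqn_sym_nips, dqn_sym_nature):
        sym = builder(action_num=6)
        arg_shapes, out_shapes, _ = sym.infer_shape(data=(32, 4, 84, 84))
        print(builder.__name__, list(zip(sym.list_arguments(), arg_shapes)))
        print(builder.__name__, 'output shape:', out_shapes[0])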