1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from utils import define_qfunc
19import mxnet as mx
20
21
class QFunc(object):
    """
    Base class for Q-Value Function.
    """

    def __init__(self, env_spec):
        """Store the environment specification.

        Args:
            env_spec: environment description; only stored here.
        """
        self.env_spec = env_spec

    def get_qvals(self, obs, act):
        """Return Q-values for ``(obs, act)``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class ContinuousMLPQ(QFunc):
    """
    Continuous Multi-Layer Perceptron Q-Value Network
    for deterministic policy training.

    The Q-value symbol is built by ``define_qfunc`` from the "obs" and "act"
    input variables. A separate "yval" input variable is exposed so a
    regression loss against target values can be built from
    ``get_loss_symbols`` and installed with ``define_loss``.
    """

    # Input variables created in __init__ / written by update_params. They are
    # data, not trainable parameters, so they are never initialized or updated.
    _DATA_NAMES = ("obs", "act", "yval")

    def __init__(self, env_spec):

        super(ContinuousMLPQ, self).__init__(env_spec)

        self.obs = mx.symbol.Variable("obs")
        self.act = mx.symbol.Variable("act")
        self.qval = define_qfunc(self.obs, self.act)
        self.yval = mx.symbol.Variable("yval")

    def get_output_symbol(self):
        """Return the Q-value output symbol."""
        return self.qval

    def get_loss_symbols(self):
        """Return the symbols a loss expression is built from.

        Returns:
            dict with "qval" (network output symbol) and "yval" (target
            input variable).
        """
        return {"qval": self.qval,
                "yval": self.yval}

    def define_loss(self, loss_exp):
        """Install ``loss_exp`` as the training loss in ``self.loss``.

        ``self.loss`` becomes a group of two outputs: index 0 is the loss
        (wrapped with MakeLoss), index 1 is the Q-value with gradients
        blocked. ``get_qvals`` reads output index 1.
        """
        self.loss = mx.symbol.MakeLoss(loss_exp, name="qfunc_loss")
        self.loss = mx.symbol.Group([self.loss, mx.symbol.BlockGrad(self.qval)])

    def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
                   grad_req=None):
        """Bind ``self.loss`` to an executor, initialize parameters, keep updater.

        Must be called after ``define_loss``.

        Args:
            ctx: MXNet context to bind the executor on.
            init: callable ``init(name, array)`` applied to every bound
                argument that is not a data input.
            updater: callable ``updater(index, grad, weight)`` used by
                ``update_params``.
            input_shapes: mapping from input variable name (e.g. "obs",
                "act", "yval") to its shape; forwarded to ``simple_bind``.
                Required despite the ``None`` default.
            args: unused; accepted only for interface compatibility.
            grad_req: optional gradient request forwarded to ``simple_bind``;
                when None the ``simple_bind`` default is used.

        Raises:
            ValueError: if ``input_shapes`` is None.
        """
        if input_shapes is None:
            raise ValueError(
                "input_shapes is required: map each input variable "
                "(e.g. 'obs', 'act', 'yval') to its shape")

        bind_kwargs = dict(input_shapes)
        if grad_req is not None:
            bind_kwargs["grad_req"] = grad_req

        # define an executor, initializer and updater for batch version loss
        self.exe = self.loss.simple_bind(ctx=ctx, **bind_kwargs)
        self.arg_arrays = self.exe.arg_arrays
        self.grad_arrays = self.exe.grad_arrays
        self.arg_dict = self.exe.arg_dict

        # Argument names in the same order as arg_arrays / grad_arrays, so
        # update_params can tell trainable parameters from data inputs.
        self._arg_names = list(self.loss.list_arguments())
        self._input_names = frozenset(input_shapes).union(self._DATA_NAMES)

        for name, arr in self.arg_dict.items():
            if name not in self._input_names:
                init(name, arr)

        self.updater = updater

    def update_params(self, obs, act, yval):
        """Run one training step on a batch and update the parameters.

        Inputs are copied into the bound arrays, so they presumably must be
        compatible with the shapes given to ``define_exe`` (confirm with
        caller).
        """
        self.arg_dict["obs"][:] = obs
        self.arg_dict["act"][:] = act
        self.arg_dict["yval"][:] = yval

        self.exe.forward(is_train=True)
        self.exe.backward()

        for i, (name, weight, grad) in enumerate(
                zip(self._arg_names, self.arg_arrays, self.grad_arrays)):
            # Data inputs also receive gradients under the default grad_req,
            # but they are not trainable: updating them would mutate the
            # inputs and allocate optimizer state for them. Entries with
            # grad_req="null" have no gradient array at all.
            if name in self._input_names or grad is None:
                continue
            # Keep the original positional index so updater state per
            # parameter is keyed exactly as before.
            self.updater(i, grad, weight)

    def get_qvals(self, obs, act):
        """Evaluate Q(obs, act) with the bound executor in inference mode.

        Returns:
            numpy array taken from output index 1 of the executor (the
            gradient-blocked Q-value grouped in by ``define_loss``).
        """
        self.exe.arg_dict["obs"][:] = obs
        self.exe.arg_dict["act"][:] = act
        self.exe.forward(is_train=False)

        return self.exe.outputs[1].asnumpy()
102
103
104