# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import mxnet as mx
import numpy as np
import rl_data
import sym
import argparse
import logging
import os
import gym
from datetime import datetime
import time
import sys
try:
    # Python 2 has reload() as a builtin; Python 3 provides it via importlib.
    from importlib import reload
except ImportError:
    pass

parser = argparse.ArgumentParser(description='Train A3C with OpenAI Gym')
parser.add_argument('--test', action='store_true', help='run testing', default=False)
parser.add_argument('--log-file', type=str, help='the name of log file')
parser.add_argument('--log-dir', type=str, default="./log", help='directory of the log file')
parser.add_argument('--model-prefix', type=str, help='the prefix of the model to load')
parser.add_argument('--save-model-prefix', type=str, help='the prefix of the model to save')
parser.add_argument('--load-epoch', type=int, help="load the model on an epoch using the model-prefix")

parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type')
parser.add_argument('--gpus', type=str, help='the gpus to be used, e.g. "0,1,2,3"')

parser.add_argument('--num-epochs', type=int, default=120, help='the number of training epochs')
parser.add_argument('--num-examples', type=int, default=1000000, help='the number of training examples')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--input-length', type=int, default=4)

parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--wd', type=float, default=0)
parser.add_argument('--t-max', type=int, default=4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--beta', type=float, default=0.08)

args = parser.parse_args()

def log_config(log_dir=None, log_file=None, prefix=None, rank=0):
    # reload() resets any handlers installed by a previous basicConfig call
    reload(logging)
    head = '%(asctime)-15s Node[' + str(rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    if log_dir:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if not log_file:
            log_file = (prefix if prefix else '') + datetime.now().strftime('_%Y_%m_%d-%H_%M.log')
            log_file = log_file.replace('/', '-')
        log_file_full_name = os.path.join(log_dir, log_file)
        handler = logging.FileHandler(log_file_full_name, mode='w')
        handler.setFormatter(logging.Formatter(head))
        logging.getLogger().addHandler(handler)
    logging.info('start with arguments %s', args)

def train():
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix

    log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    # module
    dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True)
    net = sym.get_symbol_atari(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'), context=devs)
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                grad_req='add')
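    # NOTE: grad_req='add' makes backward() accumulate gradients instead of
    # overwriting them, so they are zeroed explicitly at the start of each update.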

    # load model
    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    else:
        arg_params = aux_params = None

    # save model
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

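    # Small uniform weights for the policy/value output heads, Xavier for the
    # rest; mx.init.Mixed matches parameter names against the patterns in order.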
    init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'],
                         [mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)])
    module.init_params(initializer=init,
                       arg_params=arg_params, aux_params=aux_params)

    # optimizer
    module.init_optimizer(kvstore=kv, optimizer='adam',
                          optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3})

    # logging
    np.set_printoptions(precision=3, suppress=True)

    T = 0
    dataiter.reset()
    score = np.zeros((args.batch_size, 1))
    final_score = np.zeros((args.batch_size, 1))
    for epoch in range(args.num_epochs):
        if save_model_prefix:
            module.save_params('%s-%04d.params' % (save_model_prefix, epoch))

        for _ in range(int(epoch_size / args.t_max)):
            tic = time.time()
            # clear gradients
            for exe in module._exec_group.grad_arrays:
                for g in exe:
                    g[:] = 0

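            # Roll out t_max steps: record state, sampled action, reward and
            # done flags per step, plus one extra forward pass for the
            # bootstrap value V[t_max] of the state reached after the rollout.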
            S, A, V, r, D = [], [], [], [], []
            for t in range(args.t_max + 1):
                data = dataiter.data()
                module.forward(mx.io.DataBatch(data=data, label=None), is_train=False)
                act, _, val = module.get_outputs()
                V.append(val.asnumpy())
                if t < args.t_max:
                    act = act.asnumpy()
                    # sample an action per environment from the policy distribution
                    act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
                    reward, done = dataiter.act(act)
                    S.append(data)
                    A.append(act)
                    r.append(reward.reshape((-1, 1)))
                    D.append(done.reshape((-1, 1)))

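            # Work backwards through the rollout computing the n-step return
            # R_t = r_t + gamma * (1 - done_t) * R_{t+1}, seeded with the
            # bootstrap value V[t_max]; the advantage is R_t - V(s_t).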
            err = 0
            R = V[args.t_max]
            for i in reversed(range(args.t_max)):
                R = r[i] + args.gamma * (1 - D[i]) * R
                adv = np.tile(R - V[i], (1, dataiter.act_dim))

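                # Re-run the forward pass in training mode, then backpropagate
                # hand-computed head gradients: -log(max pi) * advantage tiled
                # over the policy output, plus the entropy term
                # h = -beta * pi * log(pi), which encourages exploration.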
                batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)])
                module.forward(batch, is_train=True)

                pi = module.get_outputs()[1]
                h = -args.beta * (mx.nd.log(pi + 1e-7) * pi)
                out_acts = np.amax(pi.asnumpy(), 1)
                out_acts = np.reshape(out_acts, (-1, 1))
                out_acts_tile = np.tile(-np.log(out_acts + 1e-7), (1, dataiter.act_dim))
                module.backward([mx.nd.array(out_acts_tile * adv), h])

                print('pi', pi[0].asnumpy())
                print('h', h[0].asnumpy())
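                # per-environment episode bookkeeping: score accumulates reward,
                # final_score holds the last completed episode's score, and T
                # counts finished episodes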
                err += (adv**2).mean()
                score += r[i]
                final_score *= (1 - D[i])
                final_score += score * D[i]
                score *= 1 - D[i]
                T += D[i].sum()

            module.update()
            logging.info('fps: %f err: %f score: %f final: %f T: %f' % (
                args.batch_size / (time.time() - tic), err / args.t_max,
                score.mean(), final_score.mean(), T))
            print(score.squeeze())
            print(final_score.squeeze())

def test():
    log_config()

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # module
    dataiter = rl_data.GymDataIter('scenes', args.batch_size, args.input_length, web_viz=True)
    print(dataiter.provide_data)
    net = sym.get_symbol_thor(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'), context=devs)
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                for_training=False)

    # load model
    assert args.load_epoch is not None
    assert args.model_prefix is not None
    module.load_params('%s-%04d.params' % (args.model_prefix, args.load_epoch))

    N = args.num_epochs * args.num_examples // args.batch_size
220
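    # R/T is the running mean episode score: R sums the scores of finished
    # episodes and T counts them (initialized to a tiny value to avoid
    # division by zero before the first episode ends).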
    R = 0
    T = 1e-20
    score = np.zeros((args.batch_size,))
    for t in range(N):
        dataiter.clear_history()
        data = dataiter.next()
        module.forward(data, is_train=False)
        act = module.get_outputs()[0].asnumpy()
        act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
        dataiter.act(act)
        time.sleep(0.05)
        _, reward, _, done = dataiter.history[0]
        T += done.sum()
        score += reward
        R += (done * score).sum()
        score *= (1 - done)

        if t % 100 == 0:
            logging.info('n %d score: %f T: %f' % (t, R / T, T))


if __name__ == '__main__':
    if args.test:
        test()
    else:
        train()
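
# Example usage (paths and device ids are illustrative):
#   train on GPU 0:    python a3c.py --gpus 0 --batch-size 32 --save-model-prefix ./a3c
#   test a checkpoint: python a3c.py --test --model-prefix ./a3c --load-epoch 10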