1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18from __future__ import print_function 19import mxnet as mx 20import numpy as np 21import rl_data 22import sym 23import argparse 24import logging 25import os 26import gym 27from datetime import datetime 28import time 29import sys 30try: 31 from importlib import reload 32except ImportError: 33 pass 34 35parser = argparse.ArgumentParser(description='Traing A3C with OpenAI Gym') 36parser.add_argument('--test', action='store_true', help='run testing', default=False) 37parser.add_argument('--log-file', type=str, help='the name of log file') 38parser.add_argument('--log-dir', type=str, default="./log", help='directory of the log file') 39parser.add_argument('--model-prefix', type=str, help='the prefix of the model to load') 40parser.add_argument('--save-model-prefix', type=str, help='the prefix of the model to save') 41parser.add_argument('--load-epoch', type=int, help="load the model on an epoch using the model-prefix") 42 43parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type') 44parser.add_argument('--gpus', type=str, help='the gpus will be used, e.g "0,1,2,3"') 45 46parser.add_argument('--num-epochs', type=int, default=120, help='the number of training epochs') 47parser.add_argument('--num-examples', type=int, default=1000000, help='the number of training examples') 48parser.add_argument('--batch-size', type=int, default=32) 49parser.add_argument('--input-length', type=int, default=4) 50 51parser.add_argument('--lr', type=float, default=0.0001) 52parser.add_argument('--wd', type=float, default=0) 53parser.add_argument('--t-max', type=int, default=4) 54parser.add_argument('--gamma', type=float, default=0.99) 55parser.add_argument('--beta', type=float, default=0.08) 56 57args = parser.parse_args() 58 59def log_config(log_dir=None, log_file=None, prefix=None, rank=0): 60 reload(logging) 61 head = '%(asctime)-15s Node[' + str(rank) + '] %(message)s' 62 if log_dir: 63 logging.basicConfig(level=logging.DEBUG, format=head) 64 if not os.path.exists(log_dir): 65 os.makedirs(log_dir) 66 if not log_file: 67 log_file = (prefix if prefix else '') + datetime.now().strftime('_%Y_%m_%d-%H_%M.log') 68 log_file = log_file.replace('/', '-') 69 else: 70 log_file = log_file 71 log_file_full_name = os.path.join(log_dir, log_file) 72 handler = logging.FileHandler(log_file_full_name, mode='w') 73 formatter = logging.Formatter(head) 74 handler.setFormatter(formatter) 75 logging.getLogger().addHandler(handler) 76 logging.info('start with arguments %s', args) 77 else: 78 logging.basicConfig(level=logging.DEBUG, format=head) 79 logging.info('start with arguments %s', args) 80 81def train(): 82 # kvstore 83 kv = mx.kvstore.create(args.kv_store) 84 85 model_prefix = args.model_prefix 86 if model_prefix is not None: 87 model_prefix += "-%d" % (kv.rank) 88 save_model_prefix = args.save_model_prefix 89 if save_model_prefix is None: 90 save_model_prefix = model_prefix 91 92 log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank) 93 94 devs = mx.cpu() if args.gpus is None else [ 95 mx.gpu(int(i)) for i in args.gpus.split(',')] 96 97 epoch_size = args.num_examples / args.batch_size 98 99 if args.kv_store == 'dist_sync': 100 epoch_size /= kv.num_workers 101 102 # disable kvstore for single device 103 if 'local' in kv.type and ( 104 args.gpus is None or len(args.gpus.split(',')) is 1): 105 kv = None 106 107 # module 108 dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True) 109 net = sym.get_symbol_atari(dataiter.act_dim) 110 module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data], label_names=('policy_label', 'value_label'), context=devs) 111 module.bind(data_shapes=dataiter.provide_data, 112 label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))], 113 grad_req='add') 114 115 # load model 116 117 if args.load_epoch is not None: 118 assert model_prefix is not None 119 _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch) 120 else: 121 arg_params = aux_params = None 122 123 # save model 124 checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix) 125 126 init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'], 127 [mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)]) 128 module.init_params(initializer=init, 129 arg_params=arg_params, aux_params=aux_params) 130 131 # optimizer 132 module.init_optimizer(kvstore=kv, optimizer='adam', 133 optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3}) 134 135 # logging 136 np.set_printoptions(precision=3, suppress=True) 137 138 T = 0 139 dataiter.reset() 140 score = np.zeros((args.batch_size, 1)) 141 final_score = np.zeros((args.batch_size, 1)) 142 for epoch in range(args.num_epochs): 143 if save_model_prefix: 144 module.save_params('%s-%04d.params'%(save_model_prefix, epoch)) 145 146 147 for _ in range(int(epoch_size/args.t_max)): 148 tic = time.time() 149 # clear gradients 150 for exe in module._exec_group.grad_arrays: 151 for g in exe: 152 g[:] = 0 153 154 S, A, V, r, D = [], [], [], [], [] 155 for t in range(args.t_max + 1): 156 data = dataiter.data() 157 module.forward(mx.io.DataBatch(data=data, label=None), is_train=False) 158 act, _, val = module.get_outputs() 159 V.append(val.asnumpy()) 160 if t < args.t_max: 161 act = act.asnumpy() 162 act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])] 163 reward, done = dataiter.act(act) 164 S.append(data) 165 A.append(act) 166 r.append(reward.reshape((-1, 1))) 167 D.append(done.reshape((-1, 1))) 168 169 err = 0 170 R = V[args.t_max] 171 for i in reversed(range(args.t_max)): 172 R = r[i] + args.gamma * (1 - D[i]) * R 173 adv = np.tile(R - V[i], (1, dataiter.act_dim)) 174 175 batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)]) 176 module.forward(batch, is_train=True) 177 178 pi = module.get_outputs()[1] 179 h = -args.beta*(mx.nd.log(pi+1e-7)*pi) 180 out_acts = np.amax(pi.asnumpy(), 1) 181 out_acts=np.reshape(out_acts,(-1,1)) 182 out_acts_tile=np.tile(-np.log(out_acts + 1e-7),(1, dataiter.act_dim)) 183 module.backward([mx.nd.array(out_acts_tile*adv), h]) 184 185 print('pi', pi[0].asnumpy()) 186 print('h', h[0].asnumpy()) 187 err += (adv**2).mean() 188 score += r[i] 189 final_score *= (1-D[i]) 190 final_score += score * D[i] 191 score *= 1-D[i] 192 T += D[i].sum() 193 194 module.update() 195 logging.info('fps: %f err: %f score: %f final: %f T: %f'%(args.batch_size/(time.time()-tic), err/args.t_max, score.mean(), final_score.mean(), T)) 196 print(score.squeeze()) 197 print(final_score.squeeze()) 198 199def test(): 200 log_config() 201 202 devs = mx.cpu() if args.gpus is None else [ 203 mx.gpu(int(i)) for i in args.gpus.split(',')] 204 205 # module 206 dataiter = rl_data.GymDataIter('scenes', args.batch_size, args.input_length, web_viz=True) 207 print(dataiter.provide_data) 208 net = sym.get_symbol_thor(dataiter.act_dim) 209 module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data], label_names=('policy_label', 'value_label'), context=devs) 210 module.bind(data_shapes=dataiter.provide_data, 211 label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))], 212 for_training=False) 213 214 # load model 215 assert args.load_epoch is not None 216 assert args.model_prefix is not None 217 module.load_params('%s-%04d.params'%(args.model_prefix, args.load_epoch)) 218 219 N = args.num_epochs * args.num_examples / args.batch_size 220 221 R = 0 222 T = 1e-20 223 score = np.zeros((args.batch_size,)) 224 for t in range(N): 225 dataiter.clear_history() 226 data = dataiter.next() 227 module.forward(data, is_train=False) 228 act = module.get_outputs()[0].asnumpy() 229 act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])] 230 dataiter.act(act) 231 time.sleep(0.05) 232 _, reward, _, done = dataiter.history[0] 233 T += done.sum() 234 score += reward 235 R += (done*score).sum() 236 score *= (1-done) 237 238 if t % 100 == 0: 239 logging.info('n %d score: %f T: %f'%(t, R/T, T)) 240 241 242if __name__ == '__main__': 243 if args.test: 244 test() 245 else: 246 train() 247 248 249