import numpy as np
import copy
import os
import time
import sys
from abc import abstractmethod
import abc

if sys.version_info >= (3, 4):
  ABC = abc.ABC
else:
  ABC = abc.ABCMeta('ABC', (), {})

from enum import Enum

from pybullet_envs.deep_mimic.learning.path import *
from pybullet_envs.deep_mimic.learning.exp_params import ExpParams
from pybullet_envs.deep_mimic.learning.normalizer import Normalizer
from pybullet_envs.deep_mimic.learning.replay_buffer import ReplayBuffer
from pybullet_utils.logger import Logger
import pybullet_utils.mpi_util as MPIUtil
import pybullet_utils.math_util as MathUtil


class RLAgent(ABC):

  class Mode(Enum):
    TRAIN = 0
    TEST = 1
    TRAIN_END = 2

  NAME = "None"

  UPDATE_PERIOD_KEY = "UpdatePeriod"
  ITERS_PER_UPDATE = "ItersPerUpdate"
  DISCOUNT_KEY = "Discount"
  MINI_BATCH_SIZE_KEY = "MiniBatchSize"
  REPLAY_BUFFER_SIZE_KEY = "ReplayBufferSize"
  INIT_SAMPLES_KEY = "InitSamples"
  NORMALIZER_SAMPLES_KEY = "NormalizerSamples"

  OUTPUT_ITERS_KEY = "OutputIters"
  INT_OUTPUT_ITERS_KEY = "IntOutputIters"
  TEST_EPISODES_KEY = "TestEpisodes"

  EXP_ANNEAL_SAMPLES_KEY = "ExpAnnealSamples"
  EXP_PARAM_BEG_KEY = "ExpParamsBeg"
  EXP_PARAM_END_KEY = "ExpParamsEnd"
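  # The keys above are read from the agent's JSON configuration in _load_params().
  # A minimal illustrative snippet of such a config (values are examples only,
  # not defaults taken from any particular pybullet agent file) might look like:
  #
  #   {
  #     "UpdatePeriod": 1,
  #     "ItersPerUpdate": 1,
  #     "Discount": 0.95,
  #     "MiniBatchSize": 32,
  #     "ReplayBufferSize": 50000,
  #     "InitSamples": 1000,
  #     "NormalizerSamples": 300000,
  #     "OutputIters": 10,
  #     "IntOutputIters": 400,
  #     "TestEpisodes": 32,
  #     "ExpAnnealSamples": 32000000
  #   }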
  def __init__(self, world, id, json_data):
    self.world = world
    self.id = id
    self.logger = Logger()
    self._mode = self.Mode.TRAIN

    assert self._check_action_space(), \
        Logger.print2("Invalid action space, got {:s}".format(str(self.get_action_space())))

    self._enable_training = True
    self.path = Path()
    self.iter = int(0)
    self.start_time = time.time()
    self._update_counter = 0

    self.update_period = 1.0  # simulated time (seconds) between training updates
    self.iters_per_update = int(1)
    self.discount = 0.95
    self.mini_batch_size = int(32)
    self.replay_buffer_size = int(50000)
    self.init_samples = int(1000)
    self.normalizer_samples = np.inf
    self._local_mini_batch_size = self.mini_batch_size  # batch size for each worker when running with multiple MPI processes
    self._need_normalizer_update = True
    self._total_sample_count = 0

    self._output_dir = ""
    self._int_output_dir = ""
    self.output_iters = 100
    self.int_output_iters = 100

    self.train_return = 0.0
    self.test_episodes = int(0)
    self.test_episode_count = int(0)
    self.test_return = 0.0
    self.avg_test_return = 0.0

    self.exp_anneal_samples = 320000
    self.exp_params_beg = ExpParams()
    self.exp_params_end = ExpParams()
    self.exp_params_curr = ExpParams()

    self._load_params(json_data)
    self._build_replay_buffer(self.replay_buffer_size)
    self._build_normalizers()
    self._build_bounds()
    self.reset()

    return

  def __str__(self):
    action_space_str = str(self.get_action_space())
    info_str = ""
    info_str += '"ID": {:d},\n "Type": "{:s}",\n "ActionSpace": "{:s}",\n "StateDim": {:d},\n "GoalDim": {:d},\n "ActionDim": {:d}'.format(
        self.id, self.NAME, action_space_str[action_space_str.rfind('.') + 1:],
        self.get_state_size(), self.get_goal_size(), self.get_action_size())
    return "{\n" + info_str + "\n}"

  def get_output_dir(self):
    return self._output_dir

  def set_output_dir(self, out_dir):
    self._output_dir = out_dir
    if (self._output_dir != ""):
      self.logger.configure_output_file(out_dir + "/agent" + str(self.id) + "_log.txt")
    return

  output_dir = property(get_output_dir, set_output_dir)

  def get_int_output_dir(self):
    return self._int_output_dir

  def set_int_output_dir(self, out_dir):
    self._int_output_dir = out_dir
    return

  int_output_dir = property(get_int_output_dir, set_int_output_dir)

  def reset(self):
    self.path.clear()
    return

  def update(self, timestep):
    if self.need_new_action():
      #print("update_new_action!!!")
      self._update_new_action()

    if (self._mode == self.Mode.TRAIN and self.enable_training):
      self._update_counter += timestep

      while self._update_counter >= self.update_period:
        self._train()
        self._update_exp_params()
        self.world.env.set_sample_count(self._total_sample_count)
        self._update_counter -= self.update_period

    return

  def end_episode(self):
    if (self.path.pathlength() > 0):
      self._end_path()

      if (self._mode == self.Mode.TRAIN or self._mode == self.Mode.TRAIN_END):
        if (self.enable_training and self.path.pathlength() > 0):
          self._store_path(self.path)
      elif (self._mode == self.Mode.TEST):
        self._update_test_return(self.path)
      else:
        assert False, Logger.print2("Unsupported RL agent mode: " + str(self._mode))

      self._update_mode()
    return

  def has_goal(self):
    return self.get_goal_size() > 0

  def predict_val(self):
    return 0

  def get_enable_training(self):
    return self._enable_training

  def set_enable_training(self, enable):
    print("set_enable_training=", enable)
    self._enable_training = enable
    if (self._enable_training):
      self.reset()
    return

  enable_training = property(get_enable_training, set_enable_training)

  def enable_testing(self):
    return self.test_episodes > 0

  def get_name(self):
    return self.NAME

  @abstractmethod
  def save_model(self, out_path):
    pass

  @abstractmethod
  def load_model(self, in_path):
    pass

  @abstractmethod
  def _decide_action(self, s, g):
    pass

  @abstractmethod
  def _get_output_path(self):
    pass

  @abstractmethod
  def _get_int_output_path(self):
    pass

  @abstractmethod
  def _train_step(self):
    pass

  @abstractmethod
  def _check_action_space(self):
    pass

  def get_action_space(self):
    return self.world.env.get_action_space(self.id)

  def get_state_size(self):
    return self.world.env.get_state_size(self.id)

  def get_goal_size(self):
    return self.world.env.get_goal_size(self.id)

  def get_action_size(self):
    return self.world.env.get_action_size(self.id)

  def get_num_actions(self):
    return self.world.env.get_num_actions(self.id)

  def need_new_action(self):
    return self.world.env.need_new_action(self.id)

  def _build_normalizers(self):
    self.s_norm = Normalizer(self.get_state_size(),
                             self.world.env.build_state_norm_groups(self.id))
    self.s_norm.set_mean_std(-self.world.env.build_state_offset(self.id),
                             1 / self.world.env.build_state_scale(self.id))

    self.g_norm = Normalizer(self.get_goal_size(), self.world.env.build_goal_norm_groups(self.id))
    self.g_norm.set_mean_std(-self.world.env.build_goal_offset(self.id),
                             1 / self.world.env.build_goal_scale(self.id))

    self.a_norm = Normalizer(self.world.env.get_action_size(self.id))
    self.a_norm.set_mean_std(-self.world.env.build_action_offset(self.id),
                             1 / self.world.env.build_action_scale(self.id))
    return
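  # Note on the normalizer setup above: the environment exposes per-dimension
  # offset and scale values, while Normalizer works in terms of mean and std.
  # Setting mean = -offset and std = 1 / scale makes the usual
  # (x - mean) / std transform equal to (x + offset) * scale, so the agent-side
  # normalization matches the environment's offset/scale convention.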
  def _build_bounds(self):
    self.a_bound_min = self.world.env.build_action_bound_min(self.id)
    self.a_bound_max = self.world.env.build_action_bound_max(self.id)
    return

  def _load_params(self, json_data):
    if (self.UPDATE_PERIOD_KEY in json_data):
      self.update_period = int(json_data[self.UPDATE_PERIOD_KEY])

    if (self.ITERS_PER_UPDATE in json_data):
      self.iters_per_update = int(json_data[self.ITERS_PER_UPDATE])

    if (self.DISCOUNT_KEY in json_data):
      self.discount = json_data[self.DISCOUNT_KEY]

    if (self.MINI_BATCH_SIZE_KEY in json_data):
      self.mini_batch_size = int(json_data[self.MINI_BATCH_SIZE_KEY])

    if (self.REPLAY_BUFFER_SIZE_KEY in json_data):
      self.replay_buffer_size = int(json_data[self.REPLAY_BUFFER_SIZE_KEY])

    if (self.INIT_SAMPLES_KEY in json_data):
      self.init_samples = int(json_data[self.INIT_SAMPLES_KEY])

    if (self.NORMALIZER_SAMPLES_KEY in json_data):
      self.normalizer_samples = int(json_data[self.NORMALIZER_SAMPLES_KEY])

    if (self.OUTPUT_ITERS_KEY in json_data):
      self.output_iters = json_data[self.OUTPUT_ITERS_KEY]

    if (self.INT_OUTPUT_ITERS_KEY in json_data):
      self.int_output_iters = json_data[self.INT_OUTPUT_ITERS_KEY]

    if (self.TEST_EPISODES_KEY in json_data):
      self.test_episodes = int(json_data[self.TEST_EPISODES_KEY])

    if (self.EXP_ANNEAL_SAMPLES_KEY in json_data):
      self.exp_anneal_samples = json_data[self.EXP_ANNEAL_SAMPLES_KEY]

    if (self.EXP_PARAM_BEG_KEY in json_data):
      self.exp_params_beg.load(json_data[self.EXP_PARAM_BEG_KEY])

    if (self.EXP_PARAM_END_KEY in json_data):
      self.exp_params_end.load(json_data[self.EXP_PARAM_END_KEY])

    num_procs = MPIUtil.get_num_procs()
    self._local_mini_batch_size = int(np.ceil(self.mini_batch_size / num_procs))
    self._local_mini_batch_size = np.maximum(self._local_mini_batch_size, 1)
    self.mini_batch_size = self._local_mini_batch_size * num_procs

    assert (self.exp_params_beg.noise == self.exp_params_end.noise)  # noise std should not change
    self.exp_params_curr = copy.deepcopy(self.exp_params_beg)
    self.exp_params_end.noise = self.exp_params_beg.noise

    self._need_normalizer_update = self.normalizer_samples > 0

    return

  def _record_state(self):
    s = self.world.env.record_state(self.id)
    return s

  def _record_goal(self):
    g = self.world.env.record_goal(self.id)
    return g

  def _record_reward(self):
    r = self.world.env.calc_reward(self.id)
    return r

  def _apply_action(self, a):
    self.world.env.set_action(self.id, a)
    return

  def _record_flags(self):
    return int(0)

  def _is_first_step(self):
    return len(self.path.states) == 0

  def _end_path(self):
    s = self._record_state()
    g = self._record_goal()
    r = self._record_reward()

    self.path.rewards.append(r)
    self.path.states.append(s)
    self.path.goals.append(g)
    self.path.terminate = self.world.env.check_terminate(self.id)

    return
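  # Layout of the recorded path (assuming one call to _update_new_action per
  # control step and one call to _end_path at episode end): for an episode with
  # T actions, states and goals end up with T + 1 entries (the final entry is
  # appended by _end_path), while actions, logps and flags have T entries and
  # rewards has T entries, each reward belonging to the preceding action.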
  def _update_new_action(self):
    #print("_update_new_action!")
    s = self._record_state()
    #np.savetxt("pb_record_state_s.csv", s, delimiter=",")
    g = self._record_goal()

    if not (self._is_first_step()):
      r = self._record_reward()
      self.path.rewards.append(r)

    a, logp = self._decide_action(s=s, g=g)
    assert len(np.shape(a)) == 1
    assert len(np.shape(logp)) <= 1

    flags = self._record_flags()
    self._apply_action(a)

    self.path.states.append(s)
    self.path.goals.append(g)
    self.path.actions.append(a)
    self.path.logps.append(logp)
    self.path.flags.append(flags)

    if self._enable_draw():
      self._log_val(s, g)

    return

  def _update_exp_params(self):
    lerp = float(self._total_sample_count) / self.exp_anneal_samples
    lerp = np.clip(lerp, 0.0, 1.0)
    self.exp_params_curr = self.exp_params_beg.lerp(self.exp_params_end, lerp)
    return

  def _update_test_return(self, path):
    path_reward = path.calc_return()
    self.test_return += path_reward
    self.test_episode_count += 1
    return

  def _update_mode(self):
    if (self._mode == self.Mode.TRAIN):
      self._update_mode_train()
    elif (self._mode == self.Mode.TRAIN_END):
      self._update_mode_train_end()
    elif (self._mode == self.Mode.TEST):
      self._update_mode_test()
    else:
      assert False, Logger.print2("Unsupported RL agent mode: " + str(self._mode))
    return

  def _update_mode_train(self):
    return

  def _update_mode_train_end(self):
    self._init_mode_test()
    return

  def _update_mode_test(self):
    if (self.test_episode_count * MPIUtil.get_num_procs() >= self.test_episodes):
      global_return = MPIUtil.reduce_sum(self.test_return)
      global_count = MPIUtil.reduce_sum(self.test_episode_count)
      avg_return = global_return / global_count
      self.avg_test_return = avg_return

      if self.enable_training:
        self._init_mode_train()
    return

  def _init_mode_train(self):
    self._mode = self.Mode.TRAIN
    self.world.env.set_mode(self._mode)
    return

  def _init_mode_train_end(self):
    self._mode = self.Mode.TRAIN_END
    return

  def _init_mode_test(self):
    self._mode = self.Mode.TEST
    self.test_return = 0.0
    self.test_episode_count = 0
    self.world.env.set_mode(self._mode)
    return

  def _enable_output(self):
    return MPIUtil.is_root_proc() and self.output_dir != ""

  def _enable_int_output(self):
    return MPIUtil.is_root_proc() and self.int_output_dir != ""

  def _calc_val_bounds(self, discount):
    r_min = self.world.env.get_reward_min(self.id)
    r_max = self.world.env.get_reward_max(self.id)
    assert (r_min <= r_max)

    val_min = r_min / (1.0 - discount)
    val_max = r_max / (1.0 - discount)
    return val_min, val_max

  def _calc_val_offset_scale(self, discount):
    val_min, val_max = self._calc_val_bounds(discount)
    val_offset = 0
    val_scale = 1

    if (np.isfinite(val_min) and np.isfinite(val_max)):
      val_offset = -0.5 * (val_max + val_min)
      val_scale = 2 / (val_max - val_min)

    return val_offset, val_scale

  def _calc_term_vals(self, discount):
    r_fail = self.world.env.get_reward_fail(self.id)
    r_succ = self.world.env.get_reward_succ(self.id)

    r_min = self.world.env.get_reward_min(self.id)
    r_max = self.world.env.get_reward_max(self.id)
    assert (r_fail <= r_max and r_fail >= r_min)
    assert (r_succ <= r_max and r_succ >= r_min)
    assert (not np.isinf(r_fail))
    assert (not np.isinf(r_succ))

    if (discount == 0):
      val_fail = 0
      val_succ = 0
    else:
      val_fail = r_fail / (1.0 - discount)
      val_succ = r_succ / (1.0 - discount)

    return val_fail, val_succ
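  # The value bounds above follow the usual geometric-series argument: if every
  # step yielded the constant reward r, the discounted return would be
  # r * (1 + discount + discount^2 + ...) = r / (1 - discount), so r_min and
  # r_max give val_min and val_max, and the terminal fail/success rewards give
  # val_fail and val_succ in the same way.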
  def _update_iter(self, iter):
    if (self._enable_output() and self.iter % self.output_iters == 0):
      output_path = self._get_output_path()
      output_dir = os.path.dirname(output_path)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)
      self.save_model(output_path)

    if (self._enable_int_output() and self.iter % self.int_output_iters == 0):
      int_output_path = self._get_int_output_path()
      int_output_dir = os.path.dirname(int_output_path)
      if not os.path.exists(int_output_dir):
        os.makedirs(int_output_dir)
      self.save_model(int_output_path)

    self.iter = iter
    return

  def _enable_draw(self):
    return self.world.env.enable_draw

  def _log_val(self, s, g):
    pass

  def _build_replay_buffer(self, buffer_size):
    num_procs = MPIUtil.get_num_procs()
    buffer_size = int(buffer_size / num_procs)
    self.replay_buffer = ReplayBuffer(buffer_size=buffer_size)
    self.replay_buffer_initialized = False
    return

  def _store_path(self, path):
    path_id = self.replay_buffer.store(path)
    valid_path = path_id != MathUtil.INVALID_IDX

    if valid_path:
      self.train_return = path.calc_return()

      if self._need_normalizer_update:
        self._record_normalizers(path)

    return path_id

  def _record_normalizers(self, path):
    states = np.array(path.states)
    self.s_norm.record(states)

    if self.has_goal():
      goals = np.array(path.goals)
      self.g_norm.record(goals)

    return

  def _update_normalizers(self):
    self.s_norm.update()

    if self.has_goal():
      self.g_norm.update()
    return

  def _train(self):
    samples = self.replay_buffer.total_count
    self._total_sample_count = int(MPIUtil.reduce_sum(samples))
    end_training = False

    if (self.replay_buffer_initialized):
      if (self._valid_train_step()):
        prev_iter = self.iter
        iters = self._get_iters_per_update()
        avg_train_return = MPIUtil.reduce_avg(self.train_return)

        for i in range(iters):
          curr_iter = self.iter
          wall_time = time.time() - self.start_time
          wall_time /= 60 * 60  # store time in hours

          has_goal = self.has_goal()
          s_mean = np.mean(self.s_norm.mean)
          s_std = np.mean(self.s_norm.std)
          g_mean = np.mean(self.g_norm.mean) if has_goal else 0
          g_std = np.mean(self.g_norm.std) if has_goal else 0

          self.logger.log_tabular("Iteration", self.iter)
          self.logger.log_tabular("Wall_Time", wall_time)
          self.logger.log_tabular("Samples", self._total_sample_count)
          self.logger.log_tabular("Train_Return", avg_train_return)
          self.logger.log_tabular("Test_Return", self.avg_test_return)
          self.logger.log_tabular("State_Mean", s_mean)
          self.logger.log_tabular("State_Std", s_std)
          self.logger.log_tabular("Goal_Mean", g_mean)
          self.logger.log_tabular("Goal_Std", g_std)
          self._log_exp_params()

          self._update_iter(self.iter + 1)
          self._train_step()

          Logger.print2("Agent " + str(self.id))
          self.logger.print_tabular()
          Logger.print2("")

          if (self._enable_output() and curr_iter % self.int_output_iters == 0):
            self.logger.dump_tabular()

        if (prev_iter // self.int_output_iters != self.iter // self.int_output_iters):
          end_training = self.enable_testing()

    else:
      Logger.print2("Agent " + str(self.id))
      Logger.print2("Samples: " + str(self._total_sample_count))
      Logger.print2("")

      if (self._total_sample_count >= self.init_samples):
        self.replay_buffer_initialized = True
        end_training = self.enable_testing()

    if self._need_normalizer_update:
      self._update_normalizers()
      self._need_normalizer_update = self.normalizer_samples > self._total_sample_count

    if end_training:
      self._init_mode_train_end()

    return

  def _get_iters_per_update(self):
    return MPIUtil.get_num_procs() * self.iters_per_update
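  # Note on multi-process scaling: _load_params splits the configured
  # MiniBatchSize across MPI workers (each worker trains on
  # _local_mini_batch_size samples per step), while _get_iters_per_update
  # multiplies ItersPerUpdate by the number of workers, presumably so that the
  # amount of training done per collected sample stays roughly constant as
  # workers are added.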
  def _valid_train_step(self):
    return True

  def _log_exp_params(self):
    self.logger.log_tabular("Exp_Rate", self.exp_params_curr.rate)
    self.logger.log_tabular("Exp_Noise", self.exp_params_curr.noise)
    self.logger.log_tabular("Exp_Temp", self.exp_params_curr.temp)
    return
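
# A minimal, hypothetical sketch of a concrete subclass, showing which abstract
# methods a new agent has to provide. "RandomAgent" is illustrative only and is
# not part of pybullet; a real agent (such as the PPO agent in this package)
# would build networks and run optimization in _decide_action and _train_step.
class RandomAgent(RLAgent):

  NAME = "Random"

  def save_model(self, out_path):
    pass  # nothing to save for a random policy

  def load_model(self, in_path):
    pass  # nothing to load for a random policy

  def _decide_action(self, s, g):
    # Sample a uniform action inside the bounds built by _build_bounds;
    # the base class expects a 1-D action and a scalar log-probability.
    a = np.random.uniform(self.a_bound_min, self.a_bound_max)
    logp = 0.0
    return a, logp

  def _get_output_path(self):
    # Hypothetical naming scheme; real agents define their own output paths.
    return os.path.join(self.output_dir, "agent{:d}_model".format(self.id))

  def _get_int_output_path(self):
    return os.path.join(self.int_output_dir,
                        "agent{:d}_model_{:010d}".format(self.id, self.iter))

  def _train_step(self):
    pass  # a learning agent would run one optimization step here

  def _check_action_space(self):
    # Accept whatever action space the environment reports; a real agent
    # should verify that it supports the returned action space value.
    return True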