import numpy as np
import copy
import os
import time
import sys
from abc import abstractmethod
import abc
if sys.version_info >= (3, 4):
  ABC = abc.ABC
else:
  ABC = abc.ABCMeta('ABC', (), {})

from enum import Enum

from pybullet_envs.deep_mimic.learning.path import *
from pybullet_envs.deep_mimic.learning.exp_params import ExpParams
from pybullet_envs.deep_mimic.learning.normalizer import Normalizer
from pybullet_envs.deep_mimic.learning.replay_buffer import ReplayBuffer
from pybullet_utils.logger import Logger
import pybullet_utils.mpi_util as MPIUtil
import pybullet_utils.math_util as MathUtil


class RLAgent(ABC):

  class Mode(Enum):
    TRAIN = 0
    TEST = 1
    TRAIN_END = 2

  NAME = "None"

  UPDATE_PERIOD_KEY = "UpdatePeriod"
  ITERS_PER_UPDATE = "ItersPerUpdate"
  DISCOUNT_KEY = "Discount"
  MINI_BATCH_SIZE_KEY = "MiniBatchSize"
  REPLAY_BUFFER_SIZE_KEY = "ReplayBufferSize"
  INIT_SAMPLES_KEY = "InitSamples"
  NORMALIZER_SAMPLES_KEY = "NormalizerSamples"

  OUTPUT_ITERS_KEY = "OutputIters"
  INT_OUTPUT_ITERS_KEY = "IntOutputIters"
  TEST_EPISODES_KEY = "TestEpisodes"

  EXP_ANNEAL_SAMPLES_KEY = "ExpAnnealSamples"
  EXP_PARAM_BEG_KEY = "ExpParamsBeg"
  EXP_PARAM_END_KEY = "ExpParamsEnd"

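  # The keys above name the fields that _load_params() reads from the agent's
  # JSON configuration. An illustrative snippet (hypothetical values, not taken
  # from any shipped config file):
  #
  #   {
  #     "UpdatePeriod": 1.0,
  #     "ItersPerUpdate": 1,
  #     "Discount": 0.95,
  #     "MiniBatchSize": 32,
  #     "ReplayBufferSize": 50000,
  #     "InitSamples": 1000,
  #     "NormalizerSamples": 300000,
  #     "OutputIters": 100,
  #     "IntOutputIters": 100,
  #     "TestEpisodes": 32,
  #     "ExpAnnealSamples": 320000,
  #     "ExpParamsBeg": {...},
  #     "ExpParamsEnd": {...}
  #   }
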
  def __init__(self, world, id, json_data):
    self.world = world
    self.id = id
    self.logger = Logger()
    self._mode = self.Mode.TRAIN

    assert self._check_action_space(), \
        Logger.print2("Invalid action space, got {:s}".format(str(self.get_action_space())))

    self._enable_training = True
    self.path = Path()
    self.iter = int(0)
    self.start_time = time.time()
    self._update_counter = 0

    self.update_period = 1.0  # simulated time (seconds) before each training update
    self.iters_per_update = int(1)
    self.discount = 0.95
    self.mini_batch_size = int(32)
    self.replay_buffer_size = int(50000)
    self.init_samples = int(1000)
    self.normalizer_samples = np.inf
    self._local_mini_batch_size = self.mini_batch_size  # per-worker mini-batch size when running with multiple MPI processes
    self._need_normalizer_update = True
    self._total_sample_count = 0

    self._output_dir = ""
    self._int_output_dir = ""
    self.output_iters = 100
    self.int_output_iters = 100

    self.train_return = 0.0
    self.test_episodes = int(0)
    self.test_episode_count = int(0)
    self.test_return = 0.0
    self.avg_test_return = 0.0

    self.exp_anneal_samples = 320000
    self.exp_params_beg = ExpParams()
    self.exp_params_end = ExpParams()
    self.exp_params_curr = ExpParams()

    self._load_params(json_data)
    self._build_replay_buffer(self.replay_buffer_size)
    self._build_normalizers()
    self._build_bounds()
    self.reset()

    return

  def __str__(self):
    action_space_str = str(self.get_action_space())
    info_str = ""
    info_str += '"ID": {:d},\n "Type": "{:s}",\n "ActionSpace": "{:s}",\n "StateDim": {:d},\n "GoalDim": {:d},\n "ActionDim": {:d}'.format(
        self.id, self.NAME, action_space_str[action_space_str.rfind('.') + 1:],
        self.get_state_size(), self.get_goal_size(), self.get_action_size())
    return "{\n" + info_str + "\n}"

  def get_output_dir(self):
    return self._output_dir

  def set_output_dir(self, out_dir):
    self._output_dir = out_dir
    if (self._output_dir != ""):
      self.logger.configure_output_file(out_dir + "/agent" + str(self.id) + "_log.txt")
    return

  output_dir = property(get_output_dir, set_output_dir)

  def get_int_output_dir(self):
    return self._int_output_dir

  def set_int_output_dir(self, out_dir):
    self._int_output_dir = out_dir
    return

  int_output_dir = property(get_int_output_dir, set_int_output_dir)

  def reset(self):
    self.path.clear()
    return

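  # update() is called once per simulation step: it queries a new action from
  # the policy whenever the env asks for one, and accumulates the elapsed
  # simulated time. Each time the accumulated time exceeds update_period, a
  # round of training (_train) is run and the exploration parameters are
  # re-annealed.
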
  def update(self, timestep):
    if self.need_new_action():
      #print("update_new_action!!!")
      self._update_new_action()

    if (self._mode == self.Mode.TRAIN and self.enable_training):
      self._update_counter += timestep

      while self._update_counter >= self.update_period:
        self._train()
        self._update_exp_params()
        self.world.env.set_sample_count(self._total_sample_count)
        self._update_counter -= self.update_period

    return

  def end_episode(self):
    if (self.path.pathlength() > 0):
      self._end_path()

      if (self._mode == self.Mode.TRAIN or self._mode == self.Mode.TRAIN_END):
        if (self.enable_training and self.path.pathlength() > 0):
          self._store_path(self.path)
      elif (self._mode == self.Mode.TEST):
        self._update_test_return(self.path)
      else:
        assert False, Logger.print2("Unsupported RL agent mode: " + str(self._mode))

      self._update_mode()
    return

  def has_goal(self):
    return self.get_goal_size() > 0

  def predict_val(self):
    return 0

  def get_enable_training(self):
    return self._enable_training

  def set_enable_training(self, enable):
    print("set_enable_training=", enable)
    self._enable_training = enable
    if (self._enable_training):
      self.reset()
    return

  enable_training = property(get_enable_training, set_enable_training)

  def enable_testing(self):
    return self.test_episodes > 0

  def get_name(self):
    return self.NAME

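  # Subclasses must implement the abstract methods below: model I/O
  # (save_model/load_model), action selection (_decide_action, which returns an
  # action and its log-probability), checkpoint path construction, a single
  # training step, and a check that the env's action space is supported. A
  # hedged, minimal sketch of a concrete subclass is given at the end of this
  # file.
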
  @abstractmethod
  def save_model(self, out_path):
    pass

  @abstractmethod
  def load_model(self, in_path):
    pass

  @abstractmethod
  def _decide_action(self, s, g):
    pass

  @abstractmethod
  def _get_output_path(self):
    pass

  @abstractmethod
  def _get_int_output_path(self):
    pass

  @abstractmethod
  def _train_step(self):
    pass

  @abstractmethod
  def _check_action_space(self):
    pass

  def get_action_space(self):
    return self.world.env.get_action_space(self.id)

  def get_state_size(self):
    return self.world.env.get_state_size(self.id)

  def get_goal_size(self):
    return self.world.env.get_goal_size(self.id)

  def get_action_size(self):
    return self.world.env.get_action_size(self.id)

  def get_num_actions(self):
    return self.world.env.get_num_actions(self.id)

  def need_new_action(self):
    return self.world.env.need_new_action(self.id)

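  # The env supplies offset/scale pairs for states, goals and actions. The
  # normalizers below store them as mean = -offset and std = 1 / scale, so that
  # (assuming the usual (x - mean) / std normalization) normalizing x computes
  # (x + offset) * scale.
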
  def _build_normalizers(self):
    self.s_norm = Normalizer(self.get_state_size(),
                             self.world.env.build_state_norm_groups(self.id))
    self.s_norm.set_mean_std(-self.world.env.build_state_offset(self.id),
                             1 / self.world.env.build_state_scale(self.id))

    self.g_norm = Normalizer(self.get_goal_size(), self.world.env.build_goal_norm_groups(self.id))
    self.g_norm.set_mean_std(-self.world.env.build_goal_offset(self.id),
                             1 / self.world.env.build_goal_scale(self.id))

    self.a_norm = Normalizer(self.get_action_size())
    self.a_norm.set_mean_std(-self.world.env.build_action_offset(self.id),
                             1 / self.world.env.build_action_scale(self.id))
    return

  def _build_bounds(self):
    self.a_bound_min = self.world.env.build_action_bound_min(self.id)
    self.a_bound_max = self.world.env.build_action_bound_max(self.id)
    return

  def _load_params(self, json_data):
    if (self.UPDATE_PERIOD_KEY in json_data):
      self.update_period = json_data[self.UPDATE_PERIOD_KEY]

    if (self.ITERS_PER_UPDATE in json_data):
      self.iters_per_update = int(json_data[self.ITERS_PER_UPDATE])

    if (self.DISCOUNT_KEY in json_data):
      self.discount = json_data[self.DISCOUNT_KEY]

    if (self.MINI_BATCH_SIZE_KEY in json_data):
      self.mini_batch_size = int(json_data[self.MINI_BATCH_SIZE_KEY])

    if (self.REPLAY_BUFFER_SIZE_KEY in json_data):
      self.replay_buffer_size = int(json_data[self.REPLAY_BUFFER_SIZE_KEY])

    if (self.INIT_SAMPLES_KEY in json_data):
      self.init_samples = int(json_data[self.INIT_SAMPLES_KEY])

    if (self.NORMALIZER_SAMPLES_KEY in json_data):
      self.normalizer_samples = int(json_data[self.NORMALIZER_SAMPLES_KEY])

    if (self.OUTPUT_ITERS_KEY in json_data):
      self.output_iters = json_data[self.OUTPUT_ITERS_KEY]

    if (self.INT_OUTPUT_ITERS_KEY in json_data):
      self.int_output_iters = json_data[self.INT_OUTPUT_ITERS_KEY]

    if (self.TEST_EPISODES_KEY in json_data):
      self.test_episodes = int(json_data[self.TEST_EPISODES_KEY])

    if (self.EXP_ANNEAL_SAMPLES_KEY in json_data):
      self.exp_anneal_samples = json_data[self.EXP_ANNEAL_SAMPLES_KEY]

    if (self.EXP_PARAM_BEG_KEY in json_data):
      self.exp_params_beg.load(json_data[self.EXP_PARAM_BEG_KEY])

    if (self.EXP_PARAM_END_KEY in json_data):
      self.exp_params_end.load(json_data[self.EXP_PARAM_END_KEY])

    num_procs = MPIUtil.get_num_procs()
    self._local_mini_batch_size = int(np.ceil(self.mini_batch_size / num_procs))
    self._local_mini_batch_size = np.maximum(self._local_mini_batch_size, 1)
    self.mini_batch_size = self._local_mini_batch_size * num_procs

    assert (self.exp_params_beg.noise == self.exp_params_end.noise)  # noise std should not change
    self.exp_params_curr = copy.deepcopy(self.exp_params_beg)
    self.exp_params_end.noise = self.exp_params_beg.noise

    self._need_normalizer_update = self.normalizer_samples > 0

    return

  def _record_state(self):
    s = self.world.env.record_state(self.id)
    return s

  def _record_goal(self):
    g = self.world.env.record_goal(self.id)
    return g

  def _record_reward(self):
    r = self.world.env.calc_reward(self.id)
    return r

  def _apply_action(self, a):
    self.world.env.set_action(self.id, a)
    return

  def _record_flags(self):
    return int(0)

  def _is_first_step(self):
    return len(self.path.states) == 0

  def _end_path(self):
    s = self._record_state()
    g = self._record_goal()
    r = self._record_reward()

    self.path.rewards.append(r)
    self.path.states.append(s)
    self.path.goals.append(g)
    self.path.terminate = self.world.env.check_terminate(self.id)

    return

  def _update_new_action(self):
    #print("_update_new_action!")
    s = self._record_state()
    #np.savetxt("pb_record_state_s.csv", s, delimiter=",")
    g = self._record_goal()

    if not (self._is_first_step()):
      r = self._record_reward()
      self.path.rewards.append(r)

    a, logp = self._decide_action(s=s, g=g)
    assert len(np.shape(a)) == 1
    assert len(np.shape(logp)) <= 1

    flags = self._record_flags()
    self._apply_action(a)

    self.path.states.append(s)
    self.path.goals.append(g)
    self.path.actions.append(a)
    self.path.logps.append(logp)
    self.path.flags.append(flags)

    if self._enable_draw():
      self._log_val(s, g)

    return

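  # Exploration parameters are annealed linearly in the total sample count:
  # lerp goes from 0 at the start of training to 1 once exp_anneal_samples
  # samples have been collected, interpolating exp_params_beg -> exp_params_end.
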
  def _update_exp_params(self):
    lerp = float(self._total_sample_count) / self.exp_anneal_samples
    lerp = np.clip(lerp, 0.0, 1.0)
    self.exp_params_curr = self.exp_params_beg.lerp(self.exp_params_end, lerp)
    return

  def _update_test_return(self, path):
    path_reward = path.calc_return()
    self.test_return += path_reward
    self.test_episode_count += 1
    return

  def _update_mode(self):
    if (self._mode == self.Mode.TRAIN):
      self._update_mode_train()
    elif (self._mode == self.Mode.TRAIN_END):
      self._update_mode_train_end()
    elif (self._mode == self.Mode.TEST):
      self._update_mode_test()
    else:
      assert False, Logger.print2("Unsupported RL agent mode: " + str(self._mode))
    return

  def _update_mode_train(self):
    return

  def _update_mode_train_end(self):
    self._init_mode_test()
    return

  def _update_mode_test(self):
    if (self.test_episode_count * MPIUtil.get_num_procs() >= self.test_episodes):
      global_return = MPIUtil.reduce_sum(self.test_return)
      global_count = MPIUtil.reduce_sum(self.test_episode_count)
      avg_return = global_return / global_count
      self.avg_test_return = avg_return

      if self.enable_training:
        self._init_mode_train()
    return

  def _init_mode_train(self):
    self._mode = self.Mode.TRAIN
    self.world.env.set_mode(self._mode)
    return

  def _init_mode_train_end(self):
    self._mode = self.Mode.TRAIN_END
    return

  def _init_mode_test(self):
    self._mode = self.Mode.TEST
    self.test_return = 0.0
    self.test_episode_count = 0
    self.world.env.set_mode(self._mode)
    return

  def _enable_output(self):
    return MPIUtil.is_root_proc() and self.output_dir != ""

  def _enable_int_output(self):
    return MPIUtil.is_root_proc() and self.int_output_dir != ""

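  # The value bounds below follow from the geometric series: a constant
  # per-step reward r accumulated forever with discount factor gamma < 1 sums
  # to r * (1 + gamma + gamma^2 + ...) = r / (1 - gamma).
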
  def _calc_val_bounds(self, discount):
    r_min = self.world.env.get_reward_min(self.id)
    r_max = self.world.env.get_reward_max(self.id)
    assert (r_min <= r_max)

    val_min = r_min / (1.0 - discount)
    val_max = r_max / (1.0 - discount)
    return val_min, val_max

  def _calc_val_offset_scale(self, discount):
    val_min, val_max = self._calc_val_bounds(discount)
    val_offset = 0
    val_scale = 1

    if (np.isfinite(val_min) and np.isfinite(val_max)):
      val_offset = -0.5 * (val_max + val_min)
      val_scale = 2 / (val_max - val_min)

    return val_offset, val_scale

  def _calc_term_vals(self, discount):
    r_fail = self.world.env.get_reward_fail(self.id)
    r_succ = self.world.env.get_reward_succ(self.id)

    r_min = self.world.env.get_reward_min(self.id)
    r_max = self.world.env.get_reward_max(self.id)
    assert (r_fail <= r_max and r_fail >= r_min)
    assert (r_succ <= r_max and r_succ >= r_min)
    assert (not np.isinf(r_fail))
    assert (not np.isinf(r_succ))

    if (discount == 0):
      val_fail = 0
      val_succ = 0
    else:
      val_fail = r_fail / (1.0 - discount)
      val_succ = r_succ / (1.0 - discount)

    return val_fail, val_succ

  def _update_iter(self, iter):
    if (self._enable_output() and self.iter % self.output_iters == 0):
      output_path = self._get_output_path()
      output_dir = os.path.dirname(output_path)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)
      self.save_model(output_path)

    if (self._enable_int_output() and self.iter % self.int_output_iters == 0):
      int_output_path = self._get_int_output_path()
      int_output_dir = os.path.dirname(int_output_path)
      if not os.path.exists(int_output_dir):
        os.makedirs(int_output_dir)
      self.save_model(int_output_path)

    self.iter = iter
    return

  def _enable_draw(self):
    return self.world.env.enable_draw

  def _log_val(self, s, g):
    pass

  def _build_replay_buffer(self, buffer_size):
    num_procs = MPIUtil.get_num_procs()
    buffer_size = int(buffer_size / num_procs)
    self.replay_buffer = ReplayBuffer(buffer_size=buffer_size)
    self.replay_buffer_initialized = False
    return

  def _store_path(self, path):
    path_id = self.replay_buffer.store(path)
    valid_path = path_id != MathUtil.INVALID_IDX

    if valid_path:
      self.train_return = path.calc_return()

      if self._need_normalizer_update:
        self._record_normalizers(path)

    return path_id

  def _record_normalizers(self, path):
    states = np.array(path.states)
    self.s_norm.record(states)

    if self.has_goal():
      goals = np.array(path.goals)
      self.g_norm.record(goals)

    return

  def _update_normalizers(self):
    self.s_norm.update()

    if self.has_goal():
      self.g_norm.update()
    return

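  # _train() first waits until init_samples experience samples (summed across
  # MPI workers) have been collected to initialize the replay buffer. After
  # that, each call runs _get_iters_per_update() training iterations, logging
  # progress and periodically saving checkpoints via _update_iter(). Crossing
  # an int_output_iters boundary sets end_training, which moves the agent into
  # TRAIN_END and then TEST mode at the end of the current episode (when test
  # episodes are configured).
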
  def _train(self):
    samples = self.replay_buffer.total_count
    self._total_sample_count = int(MPIUtil.reduce_sum(samples))
    end_training = False

    if (self.replay_buffer_initialized):
      if (self._valid_train_step()):
        prev_iter = self.iter
        iters = self._get_iters_per_update()
        avg_train_return = MPIUtil.reduce_avg(self.train_return)

        for i in range(iters):
          curr_iter = self.iter
          wall_time = time.time() - self.start_time
          wall_time /= 60 * 60  # store time in hours

          has_goal = self.has_goal()
          s_mean = np.mean(self.s_norm.mean)
          s_std = np.mean(self.s_norm.std)
          g_mean = np.mean(self.g_norm.mean) if has_goal else 0
          g_std = np.mean(self.g_norm.std) if has_goal else 0

          self.logger.log_tabular("Iteration", self.iter)
          self.logger.log_tabular("Wall_Time", wall_time)
          self.logger.log_tabular("Samples", self._total_sample_count)
          self.logger.log_tabular("Train_Return", avg_train_return)
          self.logger.log_tabular("Test_Return", self.avg_test_return)
          self.logger.log_tabular("State_Mean", s_mean)
          self.logger.log_tabular("State_Std", s_std)
          self.logger.log_tabular("Goal_Mean", g_mean)
          self.logger.log_tabular("Goal_Std", g_std)
          self._log_exp_params()

          self._update_iter(self.iter + 1)
          self._train_step()

          Logger.print2("Agent " + str(self.id))
          self.logger.print_tabular()
          Logger.print2("")

          if (self._enable_output() and curr_iter % self.int_output_iters == 0):
            self.logger.dump_tabular()

        if (prev_iter // self.int_output_iters != self.iter // self.int_output_iters):
          end_training = self.enable_testing()

    else:

      Logger.print2("Agent " + str(self.id))
      Logger.print2("Samples: " + str(self._total_sample_count))
      Logger.print2("")

      if (self._total_sample_count >= self.init_samples):
        self.replay_buffer_initialized = True
        end_training = self.enable_testing()

    if self._need_normalizer_update:
      self._update_normalizers()
      self._need_normalizer_update = self.normalizer_samples > self._total_sample_count

    if end_training:
      self._init_mode_train_end()

    return

  def _get_iters_per_update(self):
    return MPIUtil.get_num_procs() * self.iters_per_update

  def _valid_train_step(self):
    return True

  def _log_exp_params(self):
    self.logger.log_tabular("Exp_Rate", self.exp_params_curr.rate)
    self.logger.log_tabular("Exp_Noise", self.exp_params_curr.noise)
    self.logger.log_tabular("Exp_Temp", self.exp_params_curr.temp)
    return

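
# ----------------------------------------------------------------------------
# Minimal illustrative sketch (not part of the original API): a toy concrete
# agent showing how the abstract methods of RLAgent plug together. The class
# name, checkpoint filenames and the uniform-random policy are hypothetical;
# real agents (e.g. the policy-gradient agents in this package) implement
# _decide_action with a learned policy and _train_step with an actual update.
class _ExampleRandomAgent(RLAgent):

  NAME = "ExampleRandom"

  def save_model(self, out_path):
    # A real agent would serialize its network weights here.
    Logger.print2("ExampleRandomAgent: nothing to save to " + out_path)
    return

  def load_model(self, in_path):
    # A real agent would restore its network weights here.
    Logger.print2("ExampleRandomAgent: nothing to load from " + in_path)
    return

  def _decide_action(self, s, g):
    # Uniform-random action inside the env's action bounds; logp is a dummy
    # scalar since this sketch has no learned action distribution.
    a = np.random.uniform(self.a_bound_min, self.a_bound_max)
    logp = 0.0
    return a, logp

  def _get_output_path(self):
    # Hypothetical checkpoint path layout.
    return os.path.join(self.output_dir, "agent{:d}_model".format(self.id))

  def _get_int_output_path(self):
    return os.path.join(self.int_output_dir,
                        "agent{:d}_model_{:010d}".format(self.id, self.iter))

  def _train_step(self):
    # No learning in this sketch; a real agent samples mini-batches from
    # self.replay_buffer and updates its networks here.
    return

  def _check_action_space(self):
    # Accept whatever the env reports; real agents compare against the
    # action-space type they support (e.g. continuous vs. discrete).
    return True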