'''
Parallel implementation of the Augmented Random Search method.
Horia Mania --- hmania@berkeley.edu
Aurelia Guy
Benjamin Recht
'''

import time
import os
import numpy as np
import gym

import arspb.logz as logz
import ray
import arspb.utils as utils
import arspb.optimizers as optimizers
from arspb.policies import *
import socket
from arspb.shared_noise import *

##############################
# Create an envs_v2 Laikago env.

import gin
from pybullet_envs.minitaur.envs_v2 import env_loader
import pybullet_data as pd


def create_laikago_env():
  CONFIG_DIR = pd.getDataPath() + "/configs_v2/"
  CONFIG_FILES = [
      os.path.join(CONFIG_DIR, "base/laikago_with_imu.gin"),
      os.path.join(CONFIG_DIR, "tasks/fwd_task_no_termination.gin"),
      os.path.join(CONFIG_DIR, "wrappers/pmtg_wrapper.gin"),
      os.path.join(CONFIG_DIR, "scenes/simple_scene.gin")
  ]

  #gin.bind_parameter("scene_base.SceneBase.data_root", pd.getDataPath()+"/")
  for gin_file in CONFIG_FILES:
    gin.parse_config_file(gin_file)
  gin.bind_parameter("SimulationParameters.enable_rendering", False)
  gin.bind_parameter("terminal_conditions.maxstep_terminal_condition.max_step",
                     10000)
  env = env_loader.load()
  return env

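
# Minimal smoke-test sketch for the env factory above. This helper is an
# illustrative assumption, not part of the training pipeline, and is never
# called by this script; it only shows the Gym-style reset/step loop that the
# Worker actors below rely on.
def _example_laikago_smoke_test(num_steps=5):
  env = create_laikago_env()
  ob = env.reset()
  total_reward = 0.0
  for _ in range(num_steps):
    # random actions are enough to exercise the observation/action spaces
    ob, reward, done, _ = env.step(env.action_space.sample())
    total_reward += reward
    if done:
      break
  return total_reward
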
##############################

@ray.remote
class Worker(object):
    """
    Object class for parallel rollout generation.
    """

    def __init__(self, env_seed,
                 env_name='',
                 policy_params = None,
                 deltas=None,
                 rollout_length=1000,
                 delta_std=0.01):

        # initialize the simulation environment for each worker
        # (the imports below only register optional extra envs)
        try:
          import pybullet_envs
        except ImportError:
          pass
        try:
          import tds_environments
        except ImportError:
          pass

        self.env = create_laikago_env()
        self.env.seed(env_seed)

        # each worker gets access to the shared noise table, with an
        # independent random stream for sampling from it
        self.deltas = SharedNoiseTable(deltas, env_seed + 7)
        self.policy_params = policy_params
        if policy_params['type'] == 'linear':
            print("LinearPolicy2")
            self.policy = LinearPolicy2(policy_params)
        elif policy_params['type'] == 'nn':
            print("FullyConnectedNeuralNetworkPolicy")
            self.policy = FullyConnectedNeuralNetworkPolicy(policy_params)
        else:
            raise NotImplementedError

        self.delta_std = delta_std
        self.rollout_length = rollout_length


    def get_weights_plus_stats(self):
        """
        Get current policy weights and current statistics of past states.
        """
        #assert self.policy_params['type'] == 'linear'
        return self.policy.get_weights_plus_stats()


    def rollout(self, shift = 0., rollout_length = None):
        """
        Performs one rollout of maximum length rollout_length.
        At each time step it subtracts shift from the reward.
        """

        if rollout_length is None:
            rollout_length = self.rollout_length

        total_reward = 0.
        steps = 0

        ob = self.env.reset()
        for i in range(rollout_length):
            action = self.policy.act(ob)
            ob, reward, done, _ = self.env.step(action)
            steps += 1
            total_reward += (reward - shift)
            if done:
                break

        return total_reward, steps

    def do_rollouts(self, w_policy, num_rollouts = 1, shift = 1, evaluate = False):
        """
        Generate multiple rollouts with a policy parametrized by w_policy.
        """

        rollout_rewards, deltas_idx = [], []
        steps = 0

        for i in range(num_rollouts):

            if evaluate:
                self.policy.update_weights(w_policy)
                deltas_idx.append(-1)

                # set to false so that evaluation rollouts are not used for updating state statistics
                self.policy.update_filter = False

                # for evaluation we do not shift the rewards (shift = 0) and we use the
                # default rollout length (1000 for the MuJoCo locomotion tasks)
                reward, r_steps = self.rollout(shift = 0., rollout_length = self.rollout_length)
                rollout_rewards.append(reward)

            else:
                idx, delta = self.deltas.get_delta(w_policy.size)

                delta = (self.delta_std * delta).reshape(w_policy.shape)
                deltas_idx.append(idx)

                # set to true so that state statistics are updated
                self.policy.update_filter = True

                # compute reward and number of timesteps used for positive perturbation rollout
                self.policy.update_weights(w_policy + delta)
                pos_reward, pos_steps = self.rollout(shift = shift)

                # compute reward and number of timesteps used for negative perturbation rollout
                self.policy.update_weights(w_policy - delta)
                neg_reward, neg_steps = self.rollout(shift = shift)
                steps += pos_steps + neg_steps

                rollout_rewards.append([pos_reward, neg_reward])

        return {'deltas_idx': deltas_idx, 'rollout_rewards': rollout_rewards, "steps" : steps}

    def stats_increment(self):
        self.policy.observation_filter.stats_increment()
        return

    def get_weights(self):
        return self.policy.get_weights()

    def get_filter(self):
        return self.policy.observation_filter

    def sync_filter(self, other):
        self.policy.observation_filter.sync(other)
        return

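
# Illustrative sketch (plain numpy, detached from the Ray actor above) of the
# antithetic perturbation scheme Worker.do_rollouts uses: every sampled noise
# row is evaluated at both w + delta_std * noise and w - delta_std * noise.
# This helper is an assumption for exposition only and is not called anywhere.
def _example_antithetic_perturbations(w, delta_std, noise_rows):
  """Returns (w_plus, w_minus) weight pairs, one pair per noise row."""
  pairs = []
  for noise in noise_rows:
    delta = (delta_std * np.asarray(noise)).reshape(w.shape)
    pairs.append((w + delta, w - delta))
  return pairs

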
class ARSLearner(object):
    """
    Object class implementing the ARS algorithm.
    """

    def __init__(self, env_name='HalfCheetah-v1',
                 policy_params=None,
                 num_workers=32,
                 num_deltas=320,
                 deltas_used=320,
                 delta_std=0.01,
                 logdir=None,
                 rollout_length=4000,
                 step_size=0.01,
                 shift='constant zero',
                 params=None,
                 seed=123):

        logz.configure_output_dir(logdir)
        logz.save_params(params)
        try:
          import pybullet_envs
        except ImportError:
          pass
        try:
          import tds_environments
        except ImportError:
          pass

        env = create_laikago_env()

        self.timesteps = 0
        self.action_size = env.action_space.shape[0]
        self.ob_size = env.observation_space.shape[0]
        self.num_deltas = num_deltas
        self.deltas_used = deltas_used
        self.rollout_length = rollout_length
        self.step_size = step_size
        self.delta_std = delta_std
        self.logdir = logdir
        self.shift = shift
        self.params = params
        self.max_past_avg_reward = float('-inf')
        self.num_episodes_used = float('inf')


        # create shared table for storing noise
        print("Creating deltas table.")
        deltas_id = create_shared_noise.remote()
        self.deltas = SharedNoiseTable(ray.get(deltas_id), seed = seed + 3)
        print('Created deltas table.')

        # initialize workers with different random seeds
        print('Initializing workers.')
        self.num_workers = num_workers
        self.workers = [Worker.remote(seed + 7 * i,
                                      env_name=env_name,
                                      policy_params=policy_params,
                                      deltas=deltas_id,
                                      rollout_length=rollout_length,
                                      delta_std=delta_std) for i in range(num_workers)]


        # initialize policy
        if policy_params['type'] == 'linear':
            print("LinearPolicy2")
            self.policy = LinearPolicy2(policy_params)
            self.w_policy = self.policy.get_weights()
        elif policy_params['type'] == 'nn':
            print("FullyConnectedNeuralNetworkPolicy")
            self.policy = FullyConnectedNeuralNetworkPolicy(policy_params)
            self.w_policy = self.policy.get_weights()
        else:
            raise NotImplementedError

        # initialize optimization algorithm
        self.optimizer = optimizers.SGD(self.w_policy, self.step_size)
        print("Initialization of ARS complete.")

    def aggregate_rollouts(self, num_rollouts = None, evaluate = False):
        """
        Aggregate update step from rollouts generated in parallel.
        """

        if num_rollouts is None:
            num_deltas = self.num_deltas
        else:
            num_deltas = num_rollouts

        # put policy weights in the object store
        policy_id = ray.put(self.w_policy)

        t1 = time.time()
        num_rollouts = int(num_deltas / self.num_workers)

        # parallel generation of rollouts
        rollout_ids_one = [worker.do_rollouts.remote(policy_id,
                                                     num_rollouts = num_rollouts,
                                                     shift = self.shift,
                                                     evaluate = evaluate) for worker in self.workers]

        rollout_ids_two = [worker.do_rollouts.remote(policy_id,
                                                     num_rollouts = 1,
                                                     shift = self.shift,
                                                     evaluate = evaluate) for worker in self.workers[:(num_deltas % self.num_workers)]]

        # gather results
        results_one = ray.get(rollout_ids_one)
        results_two = ray.get(rollout_ids_two)

        rollout_rewards, deltas_idx = [], []

        for result in results_one:
            if not evaluate:
                self.timesteps += result["steps"]
            deltas_idx += result['deltas_idx']
            rollout_rewards += result['rollout_rewards']

        for result in results_two:
            if not evaluate:
                self.timesteps += result["steps"]
            deltas_idx += result['deltas_idx']
            rollout_rewards += result['rollout_rewards']

        deltas_idx = np.array(deltas_idx)
        rollout_rewards = np.array(rollout_rewards, dtype = np.float64)

        print('Maximum reward of collected rollouts:', rollout_rewards.max())
        t2 = time.time()

        print('Time to generate rollouts:', t2 - t1)

        if evaluate:
            return rollout_rewards

        # select top performing directions if deltas_used < num_deltas
        max_rewards = np.max(rollout_rewards, axis = 1)
        if self.deltas_used > self.num_deltas:
            self.deltas_used = self.num_deltas

        idx = np.arange(max_rewards.size)[max_rewards >= np.percentile(max_rewards, 100 * (1 - (self.deltas_used / self.num_deltas)))]
        deltas_idx = deltas_idx[idx]
        rollout_rewards = rollout_rewards[idx, :]

        # normalize rewards by their standard deviation
        np_std = np.std(rollout_rewards)
        if np_std > 1e-6:
          rollout_rewards /= np_std

        t1 = time.time()
        # aggregate rollouts to form g_hat, the gradient used to compute SGD step
        g_hat, count = utils.batched_weighted_sum(rollout_rewards[:, 0] - rollout_rewards[:, 1],
                                                  (self.deltas.get(idx, self.w_policy.size)
                                                   for idx in deltas_idx),
                                                  batch_size = 500)
        g_hat /= deltas_idx.size
        t2 = time.time()
        print('time to aggregate rollouts', t2 - t1)
        return g_hat


    def train_step(self):
        """
        Perform one update step of the policy weights.
        """

        g_hat = self.aggregate_rollouts()
        print("Euclidean norm of update step:", np.linalg.norm(g_hat))
        self.w_policy -= self.optimizer._compute_step(g_hat).reshape(self.w_policy.shape)
        return

    def train(self, num_iter):

        start = time.time()
        best_mean_rewards = -1e30

        for i in range(num_iter):

            t1 = time.time()
            self.train_step()
            t2 = time.time()
            print('total time of one step', t2 - t1)
            print('iter ', i, ' done')

            # record statistics every 10 iterations
            if ((i + 1) % 10 == 0):

                rewards = self.aggregate_rollouts(num_rollouts = 100, evaluate = True)
                w = ray.get(self.workers[0].get_weights_plus_stats.remote())
                np.savez(self.logdir + "/lin_policy_plus_latest", w)

                mean_rewards = np.mean(rewards)
                if (mean_rewards > best_mean_rewards):
                  best_mean_rewards = mean_rewards
                  np.savez(self.logdir + "/lin_policy_plus_best_" + str(i + 1), w)


                print(sorted(self.params.items()))
                logz.log_tabular("Time", time.time() - start)
                logz.log_tabular("Iteration", i + 1)
                logz.log_tabular("AverageReward", np.mean(rewards))
                logz.log_tabular("StdRewards", np.std(rewards))
                logz.log_tabular("MaxRewardRollout", np.max(rewards))
                logz.log_tabular("MinRewardRollout", np.min(rewards))
                logz.log_tabular("timesteps", self.timesteps)
                logz.dump_tabular()

            t1 = time.time()
            # get statistics from all workers
            for j in range(self.num_workers):
                self.policy.observation_filter.update(ray.get(self.workers[j].get_filter.remote()))
            self.policy.observation_filter.stats_increment()

            # make sure master filter buffer is clear
            self.policy.observation_filter.clear_buffer()
            # sync all workers
            filter_id = ray.put(self.policy.observation_filter)
            setting_filters_ids = [worker.sync_filter.remote(filter_id) for worker in self.workers]
            # waiting for sync of all workers
            ray.get(setting_filters_ids)

            increment_filters_ids = [worker.stats_increment.remote() for worker in self.workers]
            # waiting for increment of all workers
            ray.get(increment_filters_ids)
            t2 = time.time()
            print('Time to sync statistics:', t2 - t1)

        return

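
# Illustrative, standalone numpy sketch (an assumption, not used by ARSLearner)
# of the update that aggregate_rollouts/train_step compute: keep the directions
# whose best return is largest, normalize returns by their standard deviation,
# and average (r_plus - r_minus) * delta into the gradient estimate g_hat.
def _example_ars_gradient(rewards_pos, rewards_neg, deltas, deltas_used):
  rewards_pos = np.asarray(rewards_pos, dtype=np.float64)
  rewards_neg = np.asarray(rewards_neg, dtype=np.float64)
  deltas = np.asarray(deltas, dtype=np.float64)  # shape (num_deltas, dim)
  # keep the top deltas_used directions, ranked by max(r_plus, r_minus)
  order = np.argsort(np.maximum(rewards_pos, rewards_neg))[::-1][:deltas_used]
  rp, rn, d = rewards_pos[order], rewards_neg[order], deltas[order]
  # normalize rewards by their standard deviation, as aggregate_rollouts does
  std = np.std(np.concatenate([rp, rn]))
  if std > 1e-6:
    rp, rn = rp / std, rn / std
  return np.dot(rp - rn, d) / len(order)

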
def run_ars(params):
    dir_path = params['dir_path']

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    logdir = dir_path

    try:
      import pybullet_envs
    except ImportError:
      pass
    try:
      import tds_environments
    except ImportError:
      pass
    env = create_laikago_env()
    ob_dim = env.observation_space.shape[0]
    ac_dim = env.action_space.shape[0]
    ac_lb = env.action_space.low
    ac_ub = env.action_space.high

    # set policy parameters. Possible filters: 'MeanStdFilter' for v2, 'NoFilter' for v1.
    if params["policy_type"] == "nn":
      policy_sizes_string = params['policy_network_size_list'].split(',')
      print("policy_sizes_string=", policy_sizes_string)
      policy_sizes_list = [int(item) for item in policy_sizes_string]
      print("policy_sizes_list=", policy_sizes_list)
      activation = params['activation']
      policy_params = {'type': params["policy_type"],
                       'ob_filter': params['filter'],
                       'policy_network_size': policy_sizes_list,
                       'ob_dim': ob_dim,
                       'ac_dim': ac_dim,
                       'activation': activation,
                       'action_lower_bound': ac_lb,
                       'action_upper_bound': ac_ub,
                       }
    else:
      del params['policy_network_size_list']
      del params['activation']
      policy_params = {'type': params["policy_type"],
                       'ob_filter': params['filter'],
                       'ob_dim': ob_dim,
                       'ac_dim': ac_dim,
                       'action_lower_bound': ac_lb,
                       'action_upper_bound': ac_ub,
                       }


    ARS = ARSLearner(env_name=params['env_name'],
                     policy_params=policy_params,
                     num_workers=params['n_workers'],
                     num_deltas=params['n_directions'],
                     deltas_used=params['deltas_used'],
                     step_size=params['step_size'],
                     delta_std=params['delta_std'],
                     logdir=logdir,
                     rollout_length=params['rollout_length'],
                     shift=params['shift'],
                     params=params,
                     seed = params['seed'])

    ARS.train(params['n_iter'])

    return

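
# Illustrative params dict (an assumption: the values mirror the argparse
# defaults below) for calling run_ars programmatically. ray.init(...) must be
# called first; this dict is not used anywhere in this script.
_EXAMPLE_PARAMS = {
    'env_name': 'InvertedPendulumSwingupBulletEnv-v0',
    'n_iter': 1000,
    'n_directions': 16,
    'deltas_used': 16,
    'step_size': 0.03,
    'delta_std': 0.03,
    'n_workers': 18,
    'rollout_length': 2000,
    'shift': 0,
    'seed': 37,
    'policy_type': 'linear',
    'dir_path': 'data',
    'filter': 'MeanStdFilter',
    'activation': 'tanh',
    'policy_network_size_list': '64,64',
}
# run_ars(_EXAMPLE_PARAMS)
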

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='InvertedPendulumSwingupBulletEnv-v0')
    parser.add_argument('--n_iter', '-n', type=int, default=1000)
    parser.add_argument('--n_directions', '-nd', type=int, default=16)
    parser.add_argument('--deltas_used', '-du', type=int, default=16)
    parser.add_argument('--step_size', '-s', type=float, default=0.03)
    parser.add_argument('--delta_std', '-std', type=float, default=.03)
    parser.add_argument('--n_workers', '-e', type=int, default=18)
    parser.add_argument('--rollout_length', '-r', type=int, default=2000)

    # for Swimmer-v1 and HalfCheetah-v1 use shift = 0
    # for Hopper-v1, Walker2d-v1, and Ant-v1 use shift = 1
    # for Humanoid-v1 use shift = 5
    parser.add_argument('--shift', type=float, default=0)
    parser.add_argument('--seed', type=int, default=37)
    parser.add_argument('--policy_type', type=str, help="Policy type, linear or nn (neural network)", default='linear')
    parser.add_argument('--dir_path', type=str, default='data')

    # for ARS V1 use filter = 'NoFilter'
    parser.add_argument('--filter', type=str, default='MeanStdFilter')
    parser.add_argument('--activation', type=str, help="Neural network policy activation function, tanh or clip", default="tanh")
    # comma-separated hidden layer sizes, e.g. "64,64" (parsed in run_ars)
    parser.add_argument('--policy_network_size', action='store', dest='policy_network_size_list', type=str, default='64,64')

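    # Example invocation (an assumption; adjust the script name and flags to
    # your setup). ray.init below attaches to an already-running Ray cluster:
    #   ray start --head --port=6379
    #   python <this script> --policy_type nn --n_workers 18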
    local_ip = socket.gethostbyname(socket.gethostname())
    ray.init(address=local_ip + ':6379')

    args = parser.parse_args()
    params = vars(args)
    run_ars(params)