# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Proximal Policy Optimization algorithm.

Based on John Schulman's implementation in Python and Theano:
https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

try:
  import tensorflow.compat.v1 as tf
except Exception:
  import tensorflow as tf

from . import memory
from . import normalize
from . import utility


class PPOAlgorithm(object):
  """A vectorized implementation of the PPO algorithm by John Schulman."""

  def __init__(self, batch_env, step, is_training, should_log, config):
    """Create an instance of the PPO algorithm.

    Args:
      batch_env: In-graph batch environment.
      step: Integer tensor holding the current training step.
      is_training: Boolean tensor for whether the algorithm should train.
      should_log: Boolean tensor for whether summaries should be returned.
      config: Object containing the agent configuration as attributes.
    """
    self._batch_env = batch_env
    self._step = step
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._observ_filter = normalize.StreamingNormalize(
        self._batch_env.observ[0], center=True, scale=True, clip=5,
        name='normalize_observ')
    self._reward_filter = normalize.StreamingNormalize(
        self._batch_env.reward[0], center=False, scale=True, clip=10,
        name='normalize_reward')
    # Memory stores tuple of observ, action, mean, logstd, reward.
    template = (
        self._batch_env.observ[0], self._batch_env.action[0],
        self._batch_env.action[0], self._batch_env.action[0],
        self._batch_env.reward[0])
    self._memory = memory.EpisodeMemory(
        template, config.update_every, config.max_length, 'memory')
    self._memory_index = tf.Variable(0, False)
    use_gpu = self._config.use_gpu and utility.available_gpus()
    with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
      # Create network variables for later calls to reuse.
      action_size = self._batch_env.action.shape[1].value
      self._network = tf.make_template(
          'network', functools.partial(config.network, config, action_size))
      output = self._network(
          tf.zeros_like(self._batch_env.observ)[:, None],
          tf.ones(len(self._batch_env)))
      with tf.variable_scope('ppo_temporary'):
        self._episodes = memory.EpisodeMemory(
            template, len(batch_env), config.max_length, 'episodes')
        if output.state is None:
          self._last_state = None
        else:
          # Ensure the batch dimension is set.
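          # set_shape() adds static shape information only; the variables
          # created from tf.zeros_like() below need a fully defined shape.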
          tf.contrib.framework.nest.map_structure(
              lambda x: x.set_shape([len(batch_env)] + x.shape.as_list()[1:]),
              output.state)
          # pylint: disable=undefined-variable
          self._last_state = tf.contrib.framework.nest.map_structure(
              lambda x: tf.Variable(lambda: tf.zeros_like(x), False),
              output.state)
        self._last_action = tf.Variable(
            tf.zeros_like(self._batch_env.action), False, name='last_action')
        self._last_mean = tf.Variable(
            tf.zeros_like(self._batch_env.action), False, name='last_mean')
        self._last_logstd = tf.Variable(
            tf.zeros_like(self._batch_env.action), False, name='last_logstd')
    self._penalty = tf.Variable(
        self._config.kl_init_penalty, False, dtype=tf.float32)
    self._optimizer = self._config.optimizer(self._config.learning_rate)

  def begin_episode(self, agent_indices):
    """Reset the recurrent states and stored episode.

    Args:
      agent_indices: Tensor containing current batch indices.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('begin_episode/'):
      if self._last_state is None:
        reset_state = tf.no_op()
      else:
        reset_state = utility.reinit_nested_vars(
            self._last_state, agent_indices)
      reset_buffer = self._episodes.clear(agent_indices)
      with tf.control_dependencies([reset_state, reset_buffer]):
        return tf.constant('')

  def perform(self, agent_indices, observ):
    """Compute a batch of actions and a summary for a batch of observations.

    Args:
      agent_indices: Tensor containing current batch indices.
      observ: Tensor of a batch of observations for all agents.

    Returns:
      Tuple of action batch tensor and summary tensor.
    """
    with tf.name_scope('perform/'):
      observ = self._observ_filter.transform(observ)
      if self._last_state is None:
        state = None
      else:
        state = tf.contrib.framework.nest.map_structure(
            lambda x: tf.gather(x, agent_indices), self._last_state)
      output = self._network(observ[:, None], tf.ones(observ.shape[0]), state)
      action = tf.cond(
          self._is_training, output.policy.sample, lambda: output.mean)
      logprob = output.policy.log_prob(action)[:, 0]
      # pylint: disable=g-long-lambda
      summary = tf.cond(self._should_log, lambda: tf.summary.merge([
          tf.summary.histogram('mean', output.mean[:, 0]),
          tf.summary.histogram('std', tf.exp(output.logstd[:, 0])),
          tf.summary.histogram('action', action[:, 0]),
          tf.summary.histogram('logprob', logprob)]), str)
      # Remember the current policy to append to memory in the experience
      # callback.
      if self._last_state is None:
        assign_state = tf.no_op()
      else:
        assign_state = utility.assign_nested_vars(
            self._last_state, output.state, agent_indices)
      with tf.control_dependencies([
          assign_state,
          tf.scatter_update(self._last_action, agent_indices, action[:, 0]),
          tf.scatter_update(self._last_mean, agent_indices, output.mean[:, 0]),
          tf.scatter_update(
              self._last_logstd, agent_indices, output.logstd[:, 0])]):
        return tf.check_numerics(action[:, 0], 'action'), tf.identity(summary)

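  # perform() computes actions before the environment step; experience()
  # receives the resulting transition afterwards and appends it, together
  # with the policy statistics remembered above, to the per-agent episode
  # buffer.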
  def experience(
      self, agent_indices, observ, action, reward, unused_done, unused_nextob):
    """Process the transition tuple of the current step.

    When training, add the current transition tuple to the memory and update
    the streaming statistics for observations and rewards. A summary string is
    returned if requested at this step.

    Args:
      agent_indices: Tensor containing current batch indices.
      observ: Batch tensor of observations.
      action: Batch tensor of actions.
      reward: Batch tensor of rewards.
      unused_done: Batch tensor of done flags.
      unused_nextob: Batch tensor of successor observations.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('experience/'):
      return tf.cond(
          self._is_training,
          # pylint: disable=g-long-lambda
          lambda: self._define_experience(
              agent_indices, observ, action, reward), str)

  def _define_experience(self, agent_indices, observ, action, reward):
    """Implement the branch of experience() entered during training."""
    update_filters = tf.summary.merge([
        self._observ_filter.update(observ),
        self._reward_filter.update(reward)])
    with tf.control_dependencies([update_filters]):
      if self._config.train_on_agent_action:
        # NOTE: Doesn't seem to change much.
        action = self._last_action
      batch = (
          observ, action,
          tf.gather(self._last_mean, agent_indices),
          tf.gather(self._last_logstd, agent_indices), reward)
      append = self._episodes.append(batch, agent_indices)
    with tf.control_dependencies([append]):
      norm_observ = self._observ_filter.transform(observ)
      norm_reward = tf.reduce_mean(self._reward_filter.transform(reward))
      # pylint: disable=g-long-lambda
      summary = tf.cond(self._should_log, lambda: tf.summary.merge([
          update_filters,
          self._observ_filter.summary(),
          self._reward_filter.summary(),
          tf.summary.scalar('memory_size', self._memory_index),
          tf.summary.histogram('normalized_observ', norm_observ),
          tf.summary.histogram('action', self._last_action),
          tf.summary.scalar('normalized_reward', norm_reward)]), str)
      return summary

  def end_episode(self, agent_indices):
    """Add episodes to the memory and perform update steps if memory is full.

    During training, add the collected episodes of the batch indices that
    finished their episode to the memory. If the memory is full, train on it,
    and then clear the memory. A summary string is returned if requested at
    this step.

    Args:
      agent_indices: Tensor containing current batch indices.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('end_episode/'):
      return tf.cond(
          self._is_training,
          lambda: self._define_end_episode(agent_indices), str)

  def _define_end_episode(self, agent_indices):
    """Implement the branch of end_episode() entered during training."""
    episodes, length = self._episodes.data(agent_indices)
    space_left = self._config.update_every - self._memory_index
    use_episodes = tf.range(tf.minimum(
        tf.shape(agent_indices)[0], space_left))
    episodes = [tf.gather(elem, use_episodes) for elem in episodes]
    append = self._memory.replace(
        episodes, tf.gather(length, use_episodes),
        use_episodes + self._memory_index)
    with tf.control_dependencies([append]):
      inc_index = self._memory_index.assign_add(tf.shape(use_episodes)[0])
    with tf.control_dependencies([inc_index]):
      memory_full = self._memory_index >= self._config.update_every
      return tf.cond(memory_full, self._training, str)

  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Trains on the episodes collected in the memory and resets the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
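    # Control dependencies enforce the phase order: run the update epochs,
    # then adjust the KL penalty, then clear the memory for the next batch
    # of episodes.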
    with tf.name_scope('training'):
      assert_full = tf.assert_equal(
          self._memory_index, self._config.update_every)
      with tf.control_dependencies([assert_full]):
        data = self._memory.data()
      (observ, action, old_mean, old_logstd, reward), length = data
      with tf.control_dependencies([tf.assert_greater(length, 0)]):
        length = tf.identity(length)
      observ = self._observ_filter.transform(observ)
      reward = self._reward_filter.transform(reward)
      update_summary = self._perform_update_steps(
          observ, action, old_mean, old_logstd, reward, length)
      with tf.control_dependencies([update_summary]):
        penalty_summary = self._adjust_penalty(
            observ, old_mean, old_logstd, length)
      with tf.control_dependencies([penalty_summary]):
        clear_memory = tf.group(
            self._memory.clear(), self._memory_index.assign(0))
      with tf.control_dependencies([clear_memory]):
        weight_summary = utility.variable_summaries(
            tf.trainable_variables(), self._config.weight_summaries)
        return tf.summary.merge([
            update_summary, penalty_summary, weight_summary])

  def _perform_update_steps(
      self, observ, action, old_mean, old_logstd, reward, length):
    """Perform multiple update steps of value function and policy.

    The advantage is computed once at the beginning and shared across
    iterations. Only one iteration's summary can be returned, so we choose
    the one from the middle of the iterations.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      reward: Sequences of rewards.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    return_ = utility.discounted_return(reward, length, self._config.discount)
    value = self._network(observ, length).value
    if self._config.gae_lambda:
      advantage = utility.lambda_return(
          reward, value, length, self._config.discount,
          self._config.gae_lambda)
    else:
      advantage = return_ - value
    mean, variance = tf.nn.moments(advantage, axes=[0, 1], keep_dims=True)
    advantage = (advantage - mean) / (tf.sqrt(variance) + 1e-8)
    advantage = tf.Print(
        advantage, [tf.reduce_mean(return_), tf.reduce_mean(value)],
        'return and value: ')
    advantage = tf.Print(
        advantage, [tf.reduce_mean(advantage)], 'normalized advantage: ')
    # pylint: disable=g-long-lambda
    value_loss, policy_loss, summary = tf.scan(
        lambda _1, _2: self._update_step(
            observ, action, old_mean, old_logstd, reward, advantage, length),
        tf.range(self._config.update_epochs),
        [0., 0., ''], parallel_iterations=1)
    print_losses = tf.group(
        tf.Print(0, [tf.reduce_mean(value_loss)], 'value loss: '),
        tf.Print(0, [tf.reduce_mean(policy_loss)], 'policy loss: '))
    with tf.control_dependencies([value_loss, policy_loss, print_losses]):
      return summary[self._config.update_epochs // 2]

  def _update_step(
      self, observ, action, old_mean, old_logstd, reward, advantage, length):
    """Compute the current combined loss and perform a gradient update step.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      reward: Sequences of rewards.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of value loss, policy loss, and summary tensor.
    """
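    # The value and policy losses are computed separately, but their
    # gradients are concatenated and applied in a single optimizer step.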
    value_loss, value_summary = self._value_loss(observ, reward, length)
    network = self._network(observ, length)
    policy_loss, policy_summary = self._policy_loss(
        network.mean, network.logstd, old_mean, old_logstd, action, advantage,
        length)
    value_gradients, value_variables = (
        zip(*self._optimizer.compute_gradients(value_loss)))
    policy_gradients, policy_variables = (
        zip(*self._optimizer.compute_gradients(policy_loss)))
    all_gradients = value_gradients + policy_gradients
    all_variables = value_variables + policy_variables
    optimize = self._optimizer.apply_gradients(
        zip(all_gradients, all_variables))
    summary = tf.summary.merge([
        value_summary, policy_summary,
        tf.summary.scalar(
            'value_gradient_norm', tf.global_norm(value_gradients)),
        tf.summary.scalar(
            'policy_gradient_norm', tf.global_norm(policy_gradients)),
        utility.gradient_summaries(
            zip(value_gradients, value_variables), dict(value=r'.*')),
        utility.gradient_summaries(
            zip(policy_gradients, policy_variables), dict(policy=r'.*'))])
    with tf.control_dependencies([optimize]):
      return [tf.identity(x) for x in (value_loss, policy_loss, summary)]

  def _value_loss(self, observ, reward, length):
    """Compute the loss function for the value baseline.

    The value loss is the squared difference between empirical and
    approximated returns over the collected episodes. Returns the loss tensor
    and a summary string.

    Args:
      observ: Sequences of observations.
      reward: Sequences of rewards.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('value_loss'):
      value = self._network(observ, length).value
      return_ = utility.discounted_return(
          reward, length, self._config.discount)
      advantage = return_ - value
      value_loss = 0.5 * self._mask(advantage ** 2, length)
      summary = tf.summary.merge([
          tf.summary.histogram('value_loss', value_loss),
          tf.summary.scalar('avg_value_loss', tf.reduce_mean(value_loss))])
      value_loss = tf.reduce_mean(value_loss)
      return tf.check_numerics(value_loss, 'value_loss'), summary

  def _policy_loss(
      self, mean, logstd, old_mean, old_logstd, action, advantage, length):
    """Compute the policy loss composed of multiple components.

    1. The policy gradient loss is importance sampled from the data-collecting
       policy at the beginning of training.
    2. The second term is a KL penalty between the policy at the beginning of
       training and the current policy.
    3. Additionally, if this KL already changed more than twice the target
       amount, we activate a strong penalty discouraging further divergence.

    Args:
      mean: Sequences of action means of the current policy.
      logstd: Sequences of action log stddevs of the current policy.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      action: Sequences of actions.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
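    # The importance ratio pi(a|s) / pi_old(a|s) is computed in log space and
    # then exponentiated, which is more stable than dividing densities.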
    with tf.name_scope('policy_loss'):
      entropy = utility.diag_normal_entropy(mean, logstd)
      kl = tf.reduce_mean(self._mask(
          utility.diag_normal_kl(old_mean, old_logstd, mean, logstd),
          length), 1)
      policy_gradient = tf.exp(
          utility.diag_normal_logpdf(mean, logstd, action) -
          utility.diag_normal_logpdf(old_mean, old_logstd, action))
      surrogate_loss = -tf.reduce_mean(self._mask(
          policy_gradient * tf.stop_gradient(advantage), length), 1)
      kl_penalty = self._penalty * kl
      cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor
      cutoff_count = tf.reduce_sum(tf.cast(kl > cutoff_threshold, tf.int32))
      with tf.control_dependencies([tf.cond(
          cutoff_count > 0,
          lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]):
        kl_cutoff = (
            self._config.kl_cutoff_coef *
            tf.cast(kl > cutoff_threshold, tf.float32) *
            (kl - cutoff_threshold) ** 2)
        policy_loss = surrogate_loss + kl_penalty + kl_cutoff
        summary = tf.summary.merge([
            tf.summary.histogram('entropy', entropy),
            tf.summary.histogram('kl', kl),
            tf.summary.histogram('surrogate_loss', surrogate_loss),
            tf.summary.histogram('kl_penalty', kl_penalty),
            tf.summary.histogram('kl_cutoff', kl_cutoff),
            tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff),
            tf.summary.histogram('policy_loss', policy_loss),
            tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)),
            tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)),
            tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss))])
        policy_loss = tf.reduce_mean(policy_loss, 0)
        return tf.check_numerics(policy_loss, 'policy_loss'), summary

  def _adjust_penalty(self, observ, old_mean, old_logstd, length):
    """Adjust the KL penalty between the behavioral and current policy.

    Compute how much the policy actually changed during the multiple update
    steps. Adjust the penalty strength for the next training phase if we
    overshot or undershot the target divergence too much.

    Args:
      observ: Sequences of observations.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
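    # Adaptive KL penalty: if the measured KL exceeds 1.3 times the target,
    # the penalty coefficient is multiplied by 1.5; if it falls below 0.7
    # times the target, it is divided by 1.5.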
    with tf.name_scope('adjust_penalty'):
      network = self._network(observ, length)
      assert_change = tf.assert_equal(
          tf.reduce_all(tf.equal(network.mean, old_mean)), False,
          message='policy should change')
      print_penalty = tf.Print(0, [self._penalty], 'current penalty: ')
      with tf.control_dependencies([assert_change, print_penalty]):
        kl_change = tf.reduce_mean(self._mask(
            utility.diag_normal_kl(
                old_mean, old_logstd, network.mean, network.logstd),
            length))
        kl_change = tf.Print(kl_change, [kl_change], 'kl change: ')
        maybe_increase = tf.cond(
            kl_change > 1.3 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(
                self._penalty * 1.5), [0], 'increase penalty '),
            float)
        maybe_decrease = tf.cond(
            kl_change < 0.7 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(
                self._penalty / 1.5), [0], 'decrease penalty '),
            float)
      with tf.control_dependencies([maybe_increase, maybe_decrease]):
        return tf.summary.merge([
            tf.summary.scalar('kl_change', kl_change),
            tf.summary.scalar('penalty', self._penalty)])

  def _mask(self, tensor, length):
    """Set padding elements of a batch of sequences to zero.

    Useful to then safely sum along the time dimension.

    Args:
      tensor: Tensor of sequences.
      length: Batch of sequence lengths.

    Returns:
      Masked sequences.
    """
    with tf.name_scope('mask'):
      range_ = tf.range(tensor.shape[1].value)
      mask = tf.cast(range_[None, :] < length[:, None], tf.float32)
      masked = tensor * mask
      return tf.check_numerics(masked, 'masked')