1# Copyright 2017 The TensorFlow Agents Authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Proximal Policy Optimization algorithm. 15 16Based on John Schulman's implementation in Python and Theano: 17https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import collections 25 26import tf.compat.v1 as tf 27 28from . import memory 29from . import normalize 30from . import utility 31 32_NetworkOutput = collections.namedtuple('NetworkOutput', 'policy, mean, logstd, value, state') 33 34 35class PPOAlgorithm(object): 36 """A vectorized implementation of the PPO algorithm by John Schulman.""" 37 38 def __init__(self, batch_env, step, is_training, should_log, config): 39 """Create an instance of the PPO algorithm. 40 41 Args: 42 batch_env: In-graph batch environment. 43 step: Integer tensor holding the current training step. 44 is_training: Boolean tensor for whether the algorithm should train. 45 should_log: Boolean tensor for whether summaries should be returned. 46 config: Object containing the agent configuration as attributes. 47 """ 48 self._batch_env = batch_env 49 self._step = step 50 self._is_training = is_training 51 self._should_log = should_log 52 self._config = config 53 self._observ_filter = normalize.StreamingNormalize(self._batch_env.observ[0], 54 center=True, 55 scale=True, 56 clip=5, 57 name='normalize_observ') 58 self._reward_filter = normalize.StreamingNormalize(self._batch_env.reward[0], 59 center=False, 60 scale=True, 61 clip=10, 62 name='normalize_reward') 63 # Memory stores tuple of observ, action, mean, logstd, reward. 64 template = (self._batch_env.observ[0], self._batch_env.action[0], self._batch_env.action[0], 65 self._batch_env.action[0], self._batch_env.reward[0]) 66 self._memory = memory.EpisodeMemory(template, config.update_every, config.max_length, 'memory') 67 self._memory_index = tf.Variable(0, False) 68 use_gpu = self._config.use_gpu and utility.available_gpus() 69 with tf.device('/gpu:0' if use_gpu else '/cpu:0'): 70 # Create network variables for later calls to reuse. 71 self._network(tf.zeros_like(self._batch_env.observ)[:, None], 72 tf.ones(len(self._batch_env)), 73 reuse=None) 74 cell = self._config.network(self._batch_env.action.shape[1].value) 75 with tf.variable_scope('ppo_temporary'): 76 self._episodes = memory.EpisodeMemory(template, len(batch_env), config.max_length, 77 'episodes') 78 self._last_state = utility.create_nested_vars(cell.zero_state(len(batch_env), tf.float32)) 79 self._last_action = tf.Variable(tf.zeros_like(self._batch_env.action), 80 False, 81 name='last_action') 82 self._last_mean = tf.Variable(tf.zeros_like(self._batch_env.action), 83 False, 84 name='last_mean') 85 self._last_logstd = tf.Variable(tf.zeros_like(self._batch_env.action), 86 False, 87 name='last_logstd') 88 self._penalty = tf.Variable(self._config.kl_init_penalty, False, dtype=tf.float32) 89 self._policy_optimizer = self._config.policy_optimizer(self._config.policy_lr, 90 name='policy_optimizer') 91 self._value_optimizer = self._config.value_optimizer(self._config.value_lr, 92 name='value_optimizer') 93 94 def begin_episode(self, agent_indices): 95 """Reset the recurrent states and stored episode. 96 97 Args: 98 agent_indices: 1D tensor of batch indices for agents starting an episode. 99 100 Returns: 101 Summary tensor. 102 """ 103 with tf.name_scope('begin_episode/'): 104 reset_state = utility.reinit_nested_vars(self._last_state, agent_indices) 105 reset_buffer = self._episodes.clear(agent_indices) 106 with tf.control_dependencies([reset_state, reset_buffer]): 107 return tf.constant('') 108 109 def perform(self, observ): 110 """Compute batch of actions and a summary for a batch of observation. 111 112 Args: 113 observ: Tensor of a batch of observations for all agents. 114 115 Returns: 116 Tuple of action batch tensor and summary tensor. 117 """ 118 with tf.name_scope('perform/'): 119 observ = self._observ_filter.transform(observ) 120 network = self._network(observ[:, None], tf.ones(observ.shape[0]), self._last_state) 121 action = tf.cond(self._is_training, network.policy.sample, lambda: network.mean) 122 logprob = network.policy.log_prob(action)[:, 0] 123 # pylint: disable=g-long-lambda 124 summary = tf.cond( 125 self._should_log, lambda: tf.summary.merge([ 126 tf.summary.histogram('mean', network.mean[:, 0]), 127 tf.summary.histogram('std', tf.exp(network.logstd[:, 0])), 128 tf.summary.histogram('action', action[:, 0]), 129 tf.summary.histogram('logprob', logprob) 130 ]), str) 131 # Remember current policy to append to memory in the experience callback. 132 with tf.control_dependencies([ 133 utility.assign_nested_vars(self._last_state, network.state), 134 self._last_action.assign(action[:, 0]), 135 self._last_mean.assign(network.mean[:, 0]), 136 self._last_logstd.assign(network.logstd[:, 0]) 137 ]): 138 return tf.check_numerics(action[:, 0], 'action'), tf.identity(summary) 139 140 def experience(self, observ, action, reward, unused_done, unused_nextob): 141 """Process the transition tuple of the current step. 142 143 When training, add the current transition tuple to the memory and update 144 the streaming statistics for observations and rewards. A summary string is 145 returned if requested at this step. 146 147 Args: 148 observ: Batch tensor of observations. 149 action: Batch tensor of actions. 150 reward: Batch tensor of rewards. 151 unused_done: Batch tensor of done flags. 152 unused_nextob: Batch tensor of successor observations. 153 154 Returns: 155 Summary tensor. 156 """ 157 with tf.name_scope('experience/'): 158 return tf.cond(self._is_training, lambda: self._define_experience(observ, action, reward), 159 str) 160 161 def _define_experience(self, observ, action, reward): 162 """Implement the branch of experience() entered during training.""" 163 update_filters = tf.summary.merge( 164 [self._observ_filter.update(observ), 165 self._reward_filter.update(reward)]) 166 with tf.control_dependencies([update_filters]): 167 if self._config.train_on_agent_action: 168 # NOTE: Doesn't seem to change much. 169 action = self._last_action 170 batch = observ, action, self._last_mean, self._last_logstd, reward 171 append = self._episodes.append(batch, tf.range(len(self._batch_env))) 172 with tf.control_dependencies([append]): 173 norm_observ = self._observ_filter.transform(observ) 174 norm_reward = tf.reduce_mean(self._reward_filter.transform(reward)) 175 # pylint: disable=g-long-lambda 176 summary = tf.cond( 177 self._should_log, lambda: tf.summary.merge([ 178 update_filters, 179 self._observ_filter.summary(), 180 self._reward_filter.summary(), 181 tf.summary.scalar('memory_size', self._memory_index), 182 tf.summary.histogram('normalized_observ', norm_observ), 183 tf.summary.histogram('action', self._last_action), 184 tf.summary.scalar('normalized_reward', norm_reward) 185 ]), str) 186 return summary 187 188 def end_episode(self, agent_indices): 189 """Add episodes to the memory and perform update steps if memory is full. 190 191 During training, add the collected episodes of the batch indices that 192 finished their episode to the memory. If the memory is full, train on it, 193 and then clear the memory. A summary string is returned if requested at 194 this step. 195 196 Args: 197 agent_indices: 1D tensor of batch indices for agents starting an episode. 198 199 Returns: 200 Summary tensor. 201 """ 202 with tf.name_scope('end_episode/'): 203 return tf.cond(self._is_training, lambda: self._define_end_episode(agent_indices), str) 204 205 def _define_end_episode(self, agent_indices): 206 """Implement the branch of end_episode() entered during training.""" 207 episodes, length = self._episodes.data(agent_indices) 208 space_left = self._config.update_every - self._memory_index 209 use_episodes = tf.range(tf.minimum(tf.shape(agent_indices)[0], space_left)) 210 episodes = [tf.gather(elem, use_episodes) for elem in episodes] 211 append = self._memory.replace(episodes, tf.gather(length, use_episodes), 212 use_episodes + self._memory_index) 213 with tf.control_dependencies([append]): 214 inc_index = self._memory_index.assign_add(tf.shape(use_episodes)[0]) 215 with tf.control_dependencies([inc_index]): 216 memory_full = self._memory_index >= self._config.update_every 217 return tf.cond(memory_full, self._training, str) 218 219 def _training(self): 220 """Perform multiple training iterations of both policy and value baseline. 221 222 Training on the episodes collected in the memory. Reset the memory 223 afterwards. Always returns a summary string. 224 225 Returns: 226 Summary tensor. 227 """ 228 with tf.name_scope('training'): 229 assert_full = tf.assert_equal(self._memory_index, self._config.update_every) 230 with tf.control_dependencies([assert_full]): 231 data = self._memory.data() 232 (observ, action, old_mean, old_logstd, reward), length = data 233 with tf.control_dependencies([tf.assert_greater(length, 0)]): 234 length = tf.identity(length) 235 observ = self._observ_filter.transform(observ) 236 reward = self._reward_filter.transform(reward) 237 policy_summary = self._update_policy(observ, action, old_mean, old_logstd, reward, length) 238 with tf.control_dependencies([policy_summary]): 239 value_summary = self._update_value(observ, reward, length) 240 with tf.control_dependencies([value_summary]): 241 penalty_summary = self._adjust_penalty(observ, old_mean, old_logstd, length) 242 with tf.control_dependencies([penalty_summary]): 243 clear_memory = tf.group(self._memory.clear(), self._memory_index.assign(0)) 244 with tf.control_dependencies([clear_memory]): 245 weight_summary = utility.variable_summaries(tf.trainable_variables(), 246 self._config.weight_summaries) 247 return tf.summary.merge([policy_summary, value_summary, penalty_summary, weight_summary]) 248 249 def _update_value(self, observ, reward, length): 250 """Perform multiple update steps of the value baseline. 251 252 We need to decide for the summary of one iteration, and thus choose the one 253 after half of the iterations. 254 255 Args: 256 observ: Sequences of observations. 257 reward: Sequences of reward. 258 length: Batch of sequence lengths. 259 260 Returns: 261 Summary tensor. 262 """ 263 with tf.name_scope('update_value'): 264 loss, summary = tf.scan(lambda _1, _2: self._update_value_step(observ, reward, length), 265 tf.range(self._config.update_epochs_value), [0., ''], 266 parallel_iterations=1) 267 print_loss = tf.Print(0, [tf.reduce_mean(loss)], 'value loss: ') 268 with tf.control_dependencies([loss, print_loss]): 269 return summary[self._config.update_epochs_value // 2] 270 271 def _update_value_step(self, observ, reward, length): 272 """Compute the current value loss and perform a gradient update step. 273 274 Args: 275 observ: Sequences of observations. 276 reward: Sequences of reward. 277 length: Batch of sequence lengths. 278 279 Returns: 280 Tuple of loss tensor and summary tensor. 281 """ 282 loss, summary = self._value_loss(observ, reward, length) 283 gradients, variables = (zip(*self._value_optimizer.compute_gradients(loss))) 284 optimize = self._value_optimizer.apply_gradients(zip(gradients, variables)) 285 summary = tf.summary.merge([ 286 summary, 287 tf.summary.scalar('gradient_norm', tf.global_norm(gradients)), 288 utility.gradient_summaries(zip(gradients, variables), dict(value=r'.*')) 289 ]) 290 with tf.control_dependencies([optimize]): 291 return [tf.identity(loss), tf.identity(summary)] 292 293 def _value_loss(self, observ, reward, length): 294 """Compute the loss function for the value baseline. 295 296 The value loss is the difference between empirical and approximated returns 297 over the collected episodes. Returns the loss tensor and a summary strin. 298 299 Args: 300 observ: Sequences of observations. 301 reward: Sequences of reward. 302 length: Batch of sequence lengths. 303 304 Returns: 305 Tuple of loss tensor and summary tensor. 306 """ 307 with tf.name_scope('value_loss'): 308 value = self._network(observ, length).value 309 return_ = utility.discounted_return(reward, length, self._config.discount) 310 advantage = return_ - value 311 value_loss = 0.5 * self._mask(advantage**2, length) 312 summary = tf.summary.merge([ 313 tf.summary.histogram('value_loss', value_loss), 314 tf.summary.scalar('avg_value_loss', tf.reduce_mean(value_loss)) 315 ]) 316 value_loss = tf.reduce_mean(value_loss) 317 return tf.check_numerics(value_loss, 'value_loss'), summary 318 319 def _update_policy(self, observ, action, old_mean, old_logstd, reward, length): 320 """Perform multiple update steps of the policy. 321 322 The advantage is computed once at the beginning and shared across 323 iterations. We need to decide for the summary of one iteration, and thus 324 choose the one after half of the iterations. 325 326 Args: 327 observ: Sequences of observations. 328 action: Sequences of actions. 329 old_mean: Sequences of action means of the behavioral policy. 330 old_logstd: Sequences of action log stddevs of the behavioral policy. 331 reward: Sequences of rewards. 332 length: Batch of sequence lengths. 333 334 Returns: 335 Summary tensor. 336 """ 337 with tf.name_scope('update_policy'): 338 return_ = utility.discounted_return(reward, length, self._config.discount) 339 value = self._network(observ, length).value 340 if self._config.gae_lambda: 341 advantage = utility.lambda_return(reward, value, length, self._config.discount, 342 self._config.gae_lambda) 343 else: 344 advantage = return_ - value 345 mean, variance = tf.nn.moments(advantage, axes=[0, 1], keep_dims=True) 346 advantage = (advantage - mean) / (tf.sqrt(variance) + 1e-8) 347 advantage = tf.Print( 348 advantage, [tf.reduce_mean(return_), tf.reduce_mean(value)], 'return and value: ') 349 advantage = tf.Print(advantage, [tf.reduce_mean(advantage)], 'normalized advantage: ') 350 # pylint: disable=g-long-lambda 351 loss, summary = tf.scan(lambda _1, _2: self._update_policy_step( 352 observ, action, old_mean, old_logstd, advantage, length), 353 tf.range(self._config.update_epochs_policy), [0., ''], 354 parallel_iterations=1) 355 print_loss = tf.Print(0, [tf.reduce_mean(loss)], 'policy loss: ') 356 with tf.control_dependencies([loss, print_loss]): 357 return summary[self._config.update_epochs_policy // 2] 358 359 def _update_policy_step(self, observ, action, old_mean, old_logstd, advantage, length): 360 """Compute the current policy loss and perform a gradient update step. 361 362 Args: 363 observ: Sequences of observations. 364 action: Sequences of actions. 365 old_mean: Sequences of action means of the behavioral policy. 366 old_logstd: Sequences of action log stddevs of the behavioral policy. 367 advantage: Sequences of advantages. 368 length: Batch of sequence lengths. 369 370 Returns: 371 Tuple of loss tensor and summary tensor. 372 """ 373 network = self._network(observ, length) 374 loss, summary = self._policy_loss(network.mean, network.logstd, old_mean, old_logstd, action, 375 advantage, length) 376 gradients, variables = (zip(*self._policy_optimizer.compute_gradients(loss))) 377 optimize = self._policy_optimizer.apply_gradients(zip(gradients, variables)) 378 summary = tf.summary.merge([ 379 summary, 380 tf.summary.scalar('gradient_norm', tf.global_norm(gradients)), 381 utility.gradient_summaries(zip(gradients, variables), dict(policy=r'.*')) 382 ]) 383 with tf.control_dependencies([optimize]): 384 return [tf.identity(loss), tf.identity(summary)] 385 386 def _policy_loss(self, mean, logstd, old_mean, old_logstd, action, advantage, length): 387 """Compute the policy loss composed of multiple components. 388 389 1. The policy gradient loss is importance sampled from the data-collecting 390 policy at the beginning of training. 391 2. The second term is a KL penalty between the policy at the beginning of 392 training and the current policy. 393 3. Additionally, if this KL already changed more than twice the target 394 amount, we activate a strong penalty discouraging further divergence. 395 396 Args: 397 mean: Sequences of action means of the current policy. 398 logstd: Sequences of action log stddevs of the current policy. 399 old_mean: Sequences of action means of the behavioral policy. 400 old_logstd: Sequences of action log stddevs of the behavioral policy. 401 action: Sequences of actions. 402 advantage: Sequences of advantages. 403 length: Batch of sequence lengths. 404 405 Returns: 406 Tuple of loss tensor and summary tensor. 407 """ 408 with tf.name_scope('policy_loss'): 409 entropy = utility.diag_normal_entropy(mean, logstd) 410 kl = tf.reduce_mean( 411 self._mask(utility.diag_normal_kl(old_mean, old_logstd, mean, logstd), length), 1) 412 policy_gradient = tf.exp( 413 utility.diag_normal_logpdf(mean, logstd, action) - 414 utility.diag_normal_logpdf(old_mean, old_logstd, action)) 415 surrogate_loss = -tf.reduce_mean( 416 self._mask(policy_gradient * tf.stop_gradient(advantage), length), 1) 417 kl_penalty = self._penalty * kl 418 cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor 419 cutoff_count = tf.reduce_sum(tf.cast(kl > cutoff_threshold, tf.int32)) 420 with tf.control_dependencies( 421 [tf.cond(cutoff_count > 0, lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]): 422 kl_cutoff = (self._config.kl_cutoff_coef * tf.cast(kl > cutoff_threshold, tf.float32) * 423 (kl - cutoff_threshold)**2) 424 policy_loss = surrogate_loss + kl_penalty + kl_cutoff 425 summary = tf.summary.merge([ 426 tf.summary.histogram('entropy', entropy), 427 tf.summary.histogram('kl', kl), 428 tf.summary.histogram('surrogate_loss', surrogate_loss), 429 tf.summary.histogram('kl_penalty', kl_penalty), 430 tf.summary.histogram('kl_cutoff', kl_cutoff), 431 tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff), 432 tf.summary.histogram('policy_loss', policy_loss), 433 tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)), 434 tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)), 435 tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss)) 436 ]) 437 policy_loss = tf.reduce_mean(policy_loss, 0) 438 return tf.check_numerics(policy_loss, 'policy_loss'), summary 439 440 def _adjust_penalty(self, observ, old_mean, old_logstd, length): 441 """Adjust the KL policy between the behavioral and current policy. 442 443 Compute how much the policy actually changed during the multiple 444 update steps. Adjust the penalty strength for the next training phase if we 445 overshot or undershot the target divergence too much. 446 447 Args: 448 observ: Sequences of observations. 449 old_mean: Sequences of action means of the behavioral policy. 450 old_logstd: Sequences of action log stddevs of the behavioral policy. 451 length: Batch of sequence lengths. 452 453 Returns: 454 Summary tensor. 455 """ 456 with tf.name_scope('adjust_penalty'): 457 network = self._network(observ, length) 458 assert_change = tf.assert_equal(tf.reduce_all(tf.equal(network.mean, old_mean)), 459 False, 460 message='policy should change') 461 print_penalty = tf.Print(0, [self._penalty], 'current penalty: ') 462 with tf.control_dependencies([assert_change, print_penalty]): 463 kl_change = tf.reduce_mean( 464 self._mask(utility.diag_normal_kl(old_mean, old_logstd, network.mean, network.logstd), 465 length)) 466 kl_change = tf.Print(kl_change, [kl_change], 'kl change: ') 467 maybe_increase = tf.cond( 468 kl_change > 1.3 * self._config.kl_target, 469 # pylint: disable=g-long-lambda 470 lambda: tf.Print(self._penalty.assign(self._penalty * 1.5), [0], 'increase penalty '), 471 float) 472 maybe_decrease = tf.cond( 473 kl_change < 0.7 * self._config.kl_target, 474 # pylint: disable=g-long-lambda 475 lambda: tf.Print(self._penalty.assign(self._penalty / 1.5), [0], 'decrease penalty '), 476 float) 477 with tf.control_dependencies([maybe_increase, maybe_decrease]): 478 return tf.summary.merge([ 479 tf.summary.scalar('kl_change', kl_change), 480 tf.summary.scalar('penalty', self._penalty) 481 ]) 482 483 def _mask(self, tensor, length): 484 """Set padding elements of a batch of sequences to zero. 485 486 Useful to then safely sum along the time dimension. 487 488 Args: 489 tensor: Tensor of sequences. 490 length: Batch of sequence lengths. 491 492 Returns: 493 Masked sequences. 494 """ 495 with tf.name_scope('mask'): 496 range_ = tf.range(tensor.shape[1].value) 497 mask = tf.cast(range_[None, :] < length[:, None], tf.float32) 498 masked = tensor * mask 499 return tf.check_numerics(masked, 'masked') 500 501 def _network(self, observ, length=None, state=None, reuse=True): 502 """Compute the network output for a batched sequence of observations. 503 504 Optionally, the initial state can be specified. The weights should be 505 reused for all calls, except for the first one. Output is a named tuple 506 containing the policy as a TensorFlow distribution, the policy mean and log 507 standard deviation, the approximated state value, and the new recurrent 508 state. 509 510 Args: 511 observ: Sequences of observations. 512 length: Batch of sequence lengths. 513 state: Batch of initial recurrent states. 514 reuse: Python boolean whether to reuse previous variables. 515 516 Returns: 517 NetworkOutput tuple. 518 """ 519 with tf.variable_scope('network', reuse=reuse): 520 observ = tf.convert_to_tensor(observ) 521 use_gpu = self._config.use_gpu and utility.available_gpus() 522 with tf.device('/gpu:0' if use_gpu else '/cpu:0'): 523 observ = tf.check_numerics(observ, 'observ') 524 cell = self._config.network(self._batch_env.action.shape[1].value) 525 (mean, logstd, value), state = tf.nn.dynamic_rnn(cell, 526 observ, 527 length, 528 state, 529 tf.float32, 530 swap_memory=True) 531 mean = tf.check_numerics(mean, 'mean') 532 logstd = tf.check_numerics(logstd, 'logstd') 533 value = tf.check_numerics(value, 'value') 534 policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd)) 535 return _NetworkOutput(policy, mean, logstd, value, state) 536