# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Proximal Policy Optimization algorithm.

Based on John Schulman's implementation in Python and Theano:
https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

try:
  import tensorflow.compat.v1 as tf
except Exception:
  import tensorflow as tf

from . import memory
from . import normalize
from . import utility


class PPOAlgorithm(object):
  """A vectorized implementation of the PPO algorithm by John Schulman."""

  def __init__(self, batch_env, step, is_training, should_log, config):
    """Create an instance of the PPO algorithm.

    Args:
      batch_env: In-graph batch environment.
      step: Integer tensor holding the current training step.
      is_training: Boolean tensor for whether the algorithm should train.
      should_log: Boolean tensor for whether summaries should be returned.
      config: Object containing the agent configuration as attributes.
    """
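    # The config object is expected to provide at least the attributes read in
    # this module, among them: network, optimizer, learning_rate, update_every,
    # max_length, use_gpu, discount, gae_lambda, update_epochs,
    # kl_init_penalty, kl_target, kl_cutoff_factor, kl_cutoff_coef,
    # train_on_agent_action, and weight_summaries.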
    self._batch_env = batch_env
    self._step = step
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._observ_filter = normalize.StreamingNormalize(self._batch_env.observ[0],
                                                       center=True,
                                                       scale=True,
                                                       clip=5,
                                                       name='normalize_observ')
    self._reward_filter = normalize.StreamingNormalize(self._batch_env.reward[0],
                                                       center=False,
                                                       scale=True,
                                                       clip=10,
                                                       name='normalize_reward')
    # Memory stores tuple of observ, action, mean, logstd, reward.
    template = (self._batch_env.observ[0], self._batch_env.action[0], self._batch_env.action[0],
                self._batch_env.action[0], self._batch_env.reward[0])
    self._memory = memory.EpisodeMemory(template, config.update_every, config.max_length, 'memory')
    self._memory_index = tf.Variable(0, False)
    use_gpu = self._config.use_gpu and utility.available_gpus()
    with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
      # Create network variables for later calls to reuse.
      action_size = self._batch_env.action.shape[1].value
      self._network = tf.make_template('network',
                                       functools.partial(config.network, config, action_size))
      output = self._network(
          tf.zeros_like(self._batch_env.observ)[:, None], tf.ones(len(self._batch_env)))
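      # Variables under 'ppo_temporary' only hold data for episodes that are
      # still being collected: the per-agent episode buffers and the policy
      # outputs of the most recent step, which experience() reads back when
      # appending transitions.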
      with tf.variable_scope('ppo_temporary'):
        self._episodes = memory.EpisodeMemory(template, len(batch_env), config.max_length,
                                              'episodes')
        if output.state is None:
          self._last_state = None
        else:
          # Ensure the batch dimension is set.
          tf.contrib.framework.nest.map_structure(
              lambda x: x.set_shape([len(batch_env)] + x.shape.as_list()[1:]), output.state)
          # pylint: disable=undefined-variable
          self._last_state = tf.contrib.framework.nest.map_structure(
              lambda x: tf.Variable(lambda: tf.zeros_like(x), False), output.state)
        self._last_action = tf.Variable(tf.zeros_like(self._batch_env.action),
                                        False,
                                        name='last_action')
        self._last_mean = tf.Variable(tf.zeros_like(self._batch_env.action),
                                      False,
                                      name='last_mean')
        self._last_logstd = tf.Variable(tf.zeros_like(self._batch_env.action),
                                        False,
                                        name='last_logstd')
    self._penalty = tf.Variable(self._config.kl_init_penalty, False, dtype=tf.float32)
    self._optimizer = self._config.optimizer(self._config.learning_rate)

  def begin_episode(self, agent_indices):
    """Reset the recurrent states and stored episode.

    Args:
      agent_indices: Tensor containing current batch indices.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('begin_episode/'):
      if self._last_state is None:
        reset_state = tf.no_op()
      else:
        reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
      reset_buffer = self._episodes.clear(agent_indices)
      with tf.control_dependencies([reset_state, reset_buffer]):
        return tf.constant('')

  def perform(self, agent_indices, observ):
    """Compute a batch of actions and a summary for a batch of observations.

    Args:
      agent_indices: Tensor containing current batch indices.
      observ: Tensor of a batch of observations for all agents.

    Returns:
      Tuple of action batch tensor and summary tensor.
    """
    with tf.name_scope('perform/'):
      observ = self._observ_filter.transform(observ)
      if self._last_state is None:
        state = None
      else:
        state = tf.contrib.framework.nest.map_structure(lambda x: tf.gather(x, agent_indices),
                                                        self._last_state)
      output = self._network(observ[:, None], tf.ones(observ.shape[0]), state)
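      # Sample a stochastic action while training; act deterministically with
      # the policy mean otherwise.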
      action = tf.cond(self._is_training, output.policy.sample, lambda: output.mean)
      logprob = output.policy.log_prob(action)[:, 0]
      # pylint: disable=g-long-lambda
      summary = tf.cond(
          self._should_log, lambda: tf.summary.merge([
              tf.summary.histogram('mean', output.mean[:, 0]),
              tf.summary.histogram('std', tf.exp(output.logstd[:, 0])),
              tf.summary.histogram('action', action[:, 0]),
              tf.summary.histogram('logprob', logprob)
          ]), str)
      # Remember current policy to append to memory in the experience callback.
      if self._last_state is None:
        assign_state = tf.no_op()
      else:
        assign_state = utility.assign_nested_vars(self._last_state, output.state, agent_indices)
      with tf.control_dependencies([
          assign_state,
          tf.scatter_update(self._last_action, agent_indices, action[:, 0]),
          tf.scatter_update(self._last_mean, agent_indices, output.mean[:, 0]),
          tf.scatter_update(self._last_logstd, agent_indices, output.logstd[:, 0])
      ]):
        return tf.check_numerics(action[:, 0], 'action'), tf.identity(summary)

  def experience(self, agent_indices, observ, action, reward, unused_done, unused_nextob):
    """Process the transition tuple of the current step.

    When training, add the current transition tuple to the episode buffer and
    update the streaming statistics for observations and rewards. A summary
    string is returned if requested at this step.

    Args:
      agent_indices: Tensor containing current batch indices.
      observ: Batch tensor of observations.
      action: Batch tensor of actions.
      reward: Batch tensor of rewards.
      unused_done: Batch tensor of done flags.
      unused_nextob: Batch tensor of successor observations.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('experience/'):
      return tf.cond(
          self._is_training,
          # pylint: disable=g-long-lambda
          lambda: self._define_experience(agent_indices, observ, action, reward),
          str)

  def _define_experience(self, agent_indices, observ, action, reward):
    """Implement the branch of experience() entered during training."""
    update_filters = tf.summary.merge(
        [self._observ_filter.update(observ),
         self._reward_filter.update(reward)])
    with tf.control_dependencies([update_filters]):
      if self._config.train_on_agent_action:
        # NOTE: Doesn't seem to change much.
        action = self._last_action
      batch = (observ, action, tf.gather(self._last_mean, agent_indices),
               tf.gather(self._last_logstd, agent_indices), reward)
      append = self._episodes.append(batch, agent_indices)
    with tf.control_dependencies([append]):
      norm_observ = self._observ_filter.transform(observ)
      norm_reward = tf.reduce_mean(self._reward_filter.transform(reward))
      # pylint: disable=g-long-lambda
      summary = tf.cond(
          self._should_log, lambda: tf.summary.merge([
              update_filters,
              self._observ_filter.summary(),
              self._reward_filter.summary(),
              tf.summary.scalar('memory_size', self._memory_index),
              tf.summary.histogram('normalized_observ', norm_observ),
              tf.summary.histogram('action', self._last_action),
              tf.summary.scalar('normalized_reward', norm_reward)
          ]), str)
      return summary

  def end_episode(self, agent_indices):
    """Add episodes to the memory and perform update steps if memory is full.

    During training, add the collected episodes of the batch indices that
    finished their episode to the memory. If the memory is full, train on it,
    and then clear the memory. A summary string is returned if requested at
    this step.

    Args:
      agent_indices: Tensor containing current batch indices.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('end_episode/'):
      return tf.cond(self._is_training, lambda: self._define_end_episode(agent_indices), str)

  def _define_end_episode(self, agent_indices):
    """Implement the branch of end_episode() entered during training."""
    episodes, length = self._episodes.data(agent_indices)
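    # Append at most as many episodes as still fit into the memory; surplus
    # finished episodes are dropped.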
    space_left = self._config.update_every - self._memory_index
    use_episodes = tf.range(tf.minimum(tf.shape(agent_indices)[0], space_left))
    episodes = [tf.gather(elem, use_episodes) for elem in episodes]
    append = self._memory.replace(episodes, tf.gather(length, use_episodes),
                                  use_episodes + self._memory_index)
    with tf.control_dependencies([append]):
      inc_index = self._memory_index.assign_add(tf.shape(use_episodes)[0])
    with tf.control_dependencies([inc_index]):
      memory_full = self._memory_index >= self._config.update_every
      return tf.cond(memory_full, self._training, str)

  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Train on the episodes collected in the memory and reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('training'):
      assert_full = tf.assert_equal(self._memory_index, self._config.update_every)
      with tf.control_dependencies([assert_full]):
        data = self._memory.data()
      (observ, action, old_mean, old_logstd, reward), length = data
      with tf.control_dependencies([tf.assert_greater(length, 0)]):
        length = tf.identity(length)
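      # Normalize observations and rewards with the same streaming statistics
      # that were updated during data collection.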
      observ = self._observ_filter.transform(observ)
      reward = self._reward_filter.transform(reward)
      update_summary = self._perform_update_steps(observ, action, old_mean, old_logstd, reward,
                                                  length)
      with tf.control_dependencies([update_summary]):
        penalty_summary = self._adjust_penalty(observ, old_mean, old_logstd, length)
      with tf.control_dependencies([penalty_summary]):
        clear_memory = tf.group(self._memory.clear(), self._memory_index.assign(0))
      with tf.control_dependencies([clear_memory]):
        weight_summary = utility.variable_summaries(tf.trainable_variables(),
                                                    self._config.weight_summaries)
        return tf.summary.merge([update_summary, penalty_summary, weight_summary])

  def _perform_update_steps(self, observ, action, old_mean, old_logstd, reward, length):
    """Perform multiple update steps of value function and policy.

    The advantage is computed once at the beginning and shared across
    iterations. Only one iteration's summary can be returned, so we pick the
    one from the middle of the iterations.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      reward: Sequences of rewards.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    return_ = utility.discounted_return(reward, length, self._config.discount)
    value = self._network(observ, length).value
    if self._config.gae_lambda:
      advantage = utility.lambda_return(reward, value, length, self._config.discount,
                                        self._config.gae_lambda)
    else:
      advantage = return_ - value
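    # Standardize advantages to zero mean and unit variance over batch and time
    # so that the scale of the policy gradient does not depend on the reward
    # magnitude.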
    mean, variance = tf.nn.moments(advantage, axes=[0, 1], keep_dims=True)
    advantage = (advantage - mean) / (tf.sqrt(variance) + 1e-8)
    advantage = tf.Print(advantage,
                         [tf.reduce_mean(return_), tf.reduce_mean(value)], 'return and value: ')
    advantage = tf.Print(advantage, [tf.reduce_mean(advantage)], 'normalized advantage: ')
    # pylint: disable=g-long-lambda
    value_loss, policy_loss, summary = tf.scan(
        lambda _1, _2: self._update_step(
            observ, action, old_mean, old_logstd, reward, advantage, length),
        tf.range(self._config.update_epochs), [0., 0., ''],
        parallel_iterations=1)
    print_losses = tf.group(tf.Print(0, [tf.reduce_mean(value_loss)], 'value loss: '),
                            tf.Print(0, [tf.reduce_mean(policy_loss)], 'policy loss: '))
    with tf.control_dependencies([value_loss, policy_loss, print_losses]):
      return summary[self._config.update_epochs // 2]

  def _update_step(self, observ, action, old_mean, old_logstd, reward, advantage, length):
    """Compute the current combined loss and perform a gradient update step.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      reward: Sequences of rewards.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of value loss, policy loss, and summary tensor.
    """
    value_loss, value_summary = self._value_loss(observ, reward, length)
    network = self._network(observ, length)
    policy_loss, policy_summary = self._policy_loss(network.mean, network.logstd, old_mean,
                                                    old_logstd, action, advantage, length)
    value_gradients, value_variables = (zip(*self._optimizer.compute_gradients(value_loss)))
    policy_gradients, policy_variables = (zip(*self._optimizer.compute_gradients(policy_loss)))
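    # Value and policy gradients are concatenated and applied in a single
    # optimizer step.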
    all_gradients = value_gradients + policy_gradients
    all_variables = value_variables + policy_variables
    optimize = self._optimizer.apply_gradients(zip(all_gradients, all_variables))
    summary = tf.summary.merge([
        value_summary, policy_summary,
        tf.summary.scalar('value_gradient_norm', tf.global_norm(value_gradients)),
        tf.summary.scalar('policy_gradient_norm', tf.global_norm(policy_gradients)),
        utility.gradient_summaries(zip(value_gradients, value_variables), dict(value=r'.*')),
        utility.gradient_summaries(zip(policy_gradients, policy_variables), dict(policy=r'.*'))
    ])
    with tf.control_dependencies([optimize]):
      return [tf.identity(x) for x in (value_loss, policy_loss, summary)]

  def _value_loss(self, observ, reward, length):
    """Compute the loss function for the value baseline.

    The value loss is half the squared difference between empirical and
    approximated returns over the collected episodes. Returns the loss tensor
    and a summary string.

    Args:
      observ: Sequences of observations.
      reward: Sequences of rewards.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('value_loss'):
      value = self._network(observ, length).value
      return_ = utility.discounted_return(reward, length, self._config.discount)
      advantage = return_ - value
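      # Half the squared error between predicted values and empirical returns;
      # padding steps are zeroed out by the mask before averaging.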
      value_loss = 0.5 * self._mask(advantage**2, length)
      summary = tf.summary.merge([
          tf.summary.histogram('value_loss', value_loss),
          tf.summary.scalar('avg_value_loss', tf.reduce_mean(value_loss))
      ])
      value_loss = tf.reduce_mean(value_loss)
      return tf.check_numerics(value_loss, 'value_loss'), summary

  def _policy_loss(self, mean, logstd, old_mean, old_logstd, action, advantage, length):
    """Compute the policy loss composed of multiple components.

    1. The policy gradient loss is importance sampled from the data-collecting
       behavioral policy captured at the beginning of the training phase.
    2. The second term is a KL penalty between that behavioral policy and the
       current policy.
    3. Additionally, if this KL divergence already exceeds the target
       multiplied by the configured cutoff factor, we activate a strong penalty
       discouraging further divergence.

    Args:
      mean: Sequences of action means of the current policy.
      logstd: Sequences of action log stddevs of the current policy.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      action: Sequences of actions.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('policy_loss'):
      entropy = utility.diag_normal_entropy(mean, logstd)
      kl = tf.reduce_mean(
          self._mask(utility.diag_normal_kl(old_mean, old_logstd, mean, logstd), length), 1)
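      # Importance ratio of the current policy over the behavioral policy,
      # computed as the exponential of the log-density difference for
      # numerical stability.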
      policy_gradient = tf.exp(
          utility.diag_normal_logpdf(mean, logstd, action) -
          utility.diag_normal_logpdf(old_mean, old_logstd, action))
      surrogate_loss = -tf.reduce_mean(
          self._mask(policy_gradient * tf.stop_gradient(advantage), length), 1)
      kl_penalty = self._penalty * kl
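      # A quadratic cutoff penalty activates once the KL divergence exceeds
      # kl_cutoff_factor times the target, discouraging further divergence.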
      cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor
      cutoff_count = tf.reduce_sum(tf.cast(kl > cutoff_threshold, tf.int32))
      with tf.control_dependencies(
          [tf.cond(cutoff_count > 0, lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]):
        kl_cutoff = (self._config.kl_cutoff_coef * tf.cast(kl > cutoff_threshold, tf.float32) *
                     (kl - cutoff_threshold)**2)
      policy_loss = surrogate_loss + kl_penalty + kl_cutoff
      summary = tf.summary.merge([
          tf.summary.histogram('entropy', entropy),
          tf.summary.histogram('kl', kl),
          tf.summary.histogram('surrogate_loss', surrogate_loss),
          tf.summary.histogram('kl_penalty', kl_penalty),
          tf.summary.histogram('kl_cutoff', kl_cutoff),
          tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff),
          tf.summary.histogram('policy_loss', policy_loss),
          tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)),
          tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)),
          tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss))
      ])
      policy_loss = tf.reduce_mean(policy_loss, 0)
      return tf.check_numerics(policy_loss, 'policy_loss'), summary

  def _adjust_penalty(self, observ, old_mean, old_logstd, length):
    """Adjust the penalty on the KL between behavioral and current policy.

    Compute how much the policy actually changed during the multiple
    update steps. Adjust the penalty strength for the next training phase if we
    overshot or undershot the target divergence too much.

    Args:
      observ: Sequences of observations.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('adjust_penalty'):
      network = self._network(observ, length)
      assert_change = tf.assert_equal(tf.reduce_all(tf.equal(network.mean, old_mean)),
                                      False,
                                      message='policy should change')
      print_penalty = tf.Print(0, [self._penalty], 'current penalty: ')
      with tf.control_dependencies([assert_change, print_penalty]):
        kl_change = tf.reduce_mean(
            self._mask(utility.diag_normal_kl(old_mean, old_logstd, network.mean, network.logstd),
                       length))
        kl_change = tf.Print(kl_change, [kl_change], 'kl change: ')
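        # Grow the penalty if the measured KL exceeded 1.3 times the target and
        # shrink it if the KL stayed below 0.7 times the target, so the next
        # training phase lands closer to the target divergence.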
        maybe_increase = tf.cond(
            kl_change > 1.3 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty * 1.5), [0], 'increase penalty '),
            float)
        maybe_decrease = tf.cond(
            kl_change < 0.7 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty / 1.5), [0], 'decrease penalty '),
            float)
      with tf.control_dependencies([maybe_increase, maybe_decrease]):
        return tf.summary.merge([
            tf.summary.scalar('kl_change', kl_change),
            tf.summary.scalar('penalty', self._penalty)
        ])

  def _mask(self, tensor, length):
    """Set padding elements of a batch of sequences to zero.

    This makes it safe to sum along the time dimension afterwards.

    Args:
      tensor: Tensor of sequences.
      length: Batch of sequence lengths.

    Returns:
      Masked sequences.
    """
    with tf.name_scope('mask'):
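      # Build a [batch, time] mask that is one for valid steps and zero for
      # padding beyond each sequence's length.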
      range_ = tf.range(tensor.shape[1].value)
      mask = tf.cast(range_[None, :] < length[:, None], tf.float32)
      masked = tensor * mask
      return tf.check_numerics(masked, 'masked')
