# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Proximal Policy Optimization algorithm.
15
16Based on John Schulman's implementation in Python and Theano:
17https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import collections
25
import tensorflow.compat.v1 as tf

from . import memory
from . import normalize
from . import utility

_NetworkOutput = collections.namedtuple('NetworkOutput', 'policy, mean, logstd, value, state')


class PPOAlgorithm(object):
  """A vectorized implementation of the PPO algorithm by John Schulman."""

  def __init__(self, batch_env, step, is_training, should_log, config):
    """Create an instance of the PPO algorithm.

    Args:
      batch_env: In-graph batch environment.
      step: Integer tensor holding the current training step.
      is_training: Boolean tensor for whether the algorithm should train.
      should_log: Boolean tensor for whether summaries should be returned.
      config: Object containing the agent configuration as attributes.
    """
    self._batch_env = batch_env
    self._step = step
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._observ_filter = normalize.StreamingNormalize(self._batch_env.observ[0],
                                                       center=True,
                                                       scale=True,
                                                       clip=5,
                                                       name='normalize_observ')
    self._reward_filter = normalize.StreamingNormalize(self._batch_env.reward[0],
                                                       center=False,
                                                       scale=True,
                                                       clip=10,
                                                       name='normalize_reward')
    # Memory stores tuple of observ, action, mean, logstd, reward.
    template = (self._batch_env.observ[0], self._batch_env.action[0], self._batch_env.action[0],
                self._batch_env.action[0], self._batch_env.reward[0])
    self._memory = memory.EpisodeMemory(template, config.update_every, config.max_length, 'memory')
    self._memory_index = tf.Variable(0, False)
    use_gpu = self._config.use_gpu and utility.available_gpus()
    with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
      # Create network variables for later calls to reuse.
      self._network(tf.zeros_like(self._batch_env.observ)[:, None],
                    tf.ones(len(self._batch_env)),
                    reuse=None)
      cell = self._config.network(self._batch_env.action.shape[1].value)
      with tf.variable_scope('ppo_temporary'):
        self._episodes = memory.EpisodeMemory(template, len(batch_env), config.max_length,
                                              'episodes')
        self._last_state = utility.create_nested_vars(cell.zero_state(len(batch_env), tf.float32))
        self._last_action = tf.Variable(tf.zeros_like(self._batch_env.action),
                                        False,
                                        name='last_action')
        self._last_mean = tf.Variable(tf.zeros_like(self._batch_env.action),
                                      False,
                                      name='last_mean')
        self._last_logstd = tf.Variable(tf.zeros_like(self._batch_env.action),
                                        False,
                                        name='last_logstd')
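    # The adaptive KL penalty coefficient (beta in the PPO paper). It is
    # increased or decreased by _adjust_penalty() after every training phase,
    # depending on how far the measured KL divergence deviates from
    # config.kl_target.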
    self._penalty = tf.Variable(self._config.kl_init_penalty, False, dtype=tf.float32)
    self._policy_optimizer = self._config.policy_optimizer(self._config.policy_lr,
                                                           name='policy_optimizer')
    self._value_optimizer = self._config.value_optimizer(self._config.value_lr,
                                                         name='value_optimizer')

  def begin_episode(self, agent_indices):
    """Reset the recurrent states and stored episode.

    Args:
      agent_indices: 1D tensor of batch indices for agents starting an episode.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('begin_episode/'):
      reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
      reset_buffer = self._episodes.clear(agent_indices)
      with tf.control_dependencies([reset_state, reset_buffer]):
        return tf.constant('')

  def perform(self, observ):
110    """Compute batch of actions and a summary for a batch of observation.

    Args:
      observ: Tensor of a batch of observations for all agents.

    Returns:
      Tuple of action batch tensor and summary tensor.
    """
    with tf.name_scope('perform/'):
      observ = self._observ_filter.transform(observ)
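      # The recurrent network expects a [batch, time, ...] sequence, so insert
      # a time axis of length one to step it by a single transition while
      # carrying the recurrent state in self._last_state.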
      network = self._network(observ[:, None], tf.ones(observ.shape[0]), self._last_state)
      action = tf.cond(self._is_training, network.policy.sample, lambda: network.mean)
      logprob = network.policy.log_prob(action)[:, 0]
      # pylint: disable=g-long-lambda
      summary = tf.cond(
          self._should_log, lambda: tf.summary.merge([
              tf.summary.histogram('mean', network.mean[:, 0]),
              tf.summary.histogram('std', tf.exp(network.logstd[:, 0])),
              tf.summary.histogram('action', action[:, 0]),
              tf.summary.histogram('logprob', logprob)
          ]), str)
      # Remember current policy to append to memory in the experience callback.
      with tf.control_dependencies([
          utility.assign_nested_vars(self._last_state, network.state),
          self._last_action.assign(action[:, 0]),
          self._last_mean.assign(network.mean[:, 0]),
          self._last_logstd.assign(network.logstd[:, 0])
      ]):
        return tf.check_numerics(action[:, 0], 'action'), tf.identity(summary)

  def experience(self, observ, action, reward, unused_done, unused_nextob):
    """Process the transition tuple of the current step.

    When training, add the current transition tuple to the memory and update
    the streaming statistics for observations and rewards. A summary string is
    returned if requested at this step.

    Args:
      observ: Batch tensor of observations.
      action: Batch tensor of actions.
      reward: Batch tensor of rewards.
      unused_done: Batch tensor of done flags.
      unused_nextob: Batch tensor of successor observations.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('experience/'):
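      # The built-in str serves as the false branch: calling it with no
      # arguments yields an empty string, which tf.cond turns into an empty
      # summary when the algorithm is not training.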
      return tf.cond(self._is_training, lambda: self._define_experience(observ, action, reward),
                     str)

  def _define_experience(self, observ, action, reward):
    """Implement the branch of experience() entered during training."""
    update_filters = tf.summary.merge(
        [self._observ_filter.update(observ),
         self._reward_filter.update(reward)])
    with tf.control_dependencies([update_filters]):
      if self._config.train_on_agent_action:
        # NOTE: Doesn't seem to change much.
        action = self._last_action
      batch = observ, action, self._last_mean, self._last_logstd, reward
      append = self._episodes.append(batch, tf.range(len(self._batch_env)))
    with tf.control_dependencies([append]):
      norm_observ = self._observ_filter.transform(observ)
      norm_reward = tf.reduce_mean(self._reward_filter.transform(reward))
      # pylint: disable=g-long-lambda
      summary = tf.cond(
          self._should_log, lambda: tf.summary.merge([
              update_filters,
              self._observ_filter.summary(),
              self._reward_filter.summary(),
              tf.summary.scalar('memory_size', self._memory_index),
              tf.summary.histogram('normalized_observ', norm_observ),
              tf.summary.histogram('action', self._last_action),
              tf.summary.scalar('normalized_reward', norm_reward)
          ]), str)
      return summary

  def end_episode(self, agent_indices):
    """Add episodes to the memory and perform update steps if memory is full.

    During training, add the collected episodes of the batch indices that
    finished their episode to the memory. If the memory is full, train on it,
    and then clear the memory. A summary string is returned if requested at
    this step.

    Args:
      agent_indices: 1D tensor of batch indices for agents ending an episode.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('end_episode/'):
      return tf.cond(self._is_training, lambda: self._define_end_episode(agent_indices), str)

  def _define_end_episode(self, agent_indices):
    """Implement the branch of end_episode() entered during training."""
    episodes, length = self._episodes.data(agent_indices)
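    # Only as many finished episodes as there are free slots left in the
    # training memory are copied over; any remaining episodes of this batch
    # are dropped for the current phase.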
    space_left = self._config.update_every - self._memory_index
    use_episodes = tf.range(tf.minimum(tf.shape(agent_indices)[0], space_left))
    episodes = [tf.gather(elem, use_episodes) for elem in episodes]
    append = self._memory.replace(episodes, tf.gather(length, use_episodes),
                                  use_episodes + self._memory_index)
    with tf.control_dependencies([append]):
      inc_index = self._memory_index.assign_add(tf.shape(use_episodes)[0])
    with tf.control_dependencies([inc_index]):
      memory_full = self._memory_index >= self._config.update_every
      return tf.cond(memory_full, self._training, str)

  def _training(self):
    """Perform multiple training iterations of both policy and value baseline.

    Train on the episodes collected in the memory and reset the memory
    afterwards. Always returns a summary string.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('training'):
      assert_full = tf.assert_equal(self._memory_index, self._config.update_every)
      with tf.control_dependencies([assert_full]):
        data = self._memory.data()
      (observ, action, old_mean, old_logstd, reward), length = data
      with tf.control_dependencies([tf.assert_greater(length, 0)]):
        length = tf.identity(length)
      observ = self._observ_filter.transform(observ)
      reward = self._reward_filter.transform(reward)
      policy_summary = self._update_policy(observ, action, old_mean, old_logstd, reward, length)
      with tf.control_dependencies([policy_summary]):
        value_summary = self._update_value(observ, reward, length)
      with tf.control_dependencies([value_summary]):
        penalty_summary = self._adjust_penalty(observ, old_mean, old_logstd, length)
      with tf.control_dependencies([penalty_summary]):
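        # Discard the collected episodes after training so that the next
        # phase gathers fresh on-policy data with the updated policy.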
        clear_memory = tf.group(self._memory.clear(), self._memory_index.assign(0))
      with tf.control_dependencies([clear_memory]):
        weight_summary = utility.variable_summaries(tf.trainable_variables(),
                                                    self._config.weight_summaries)
        return tf.summary.merge([policy_summary, value_summary, penalty_summary, weight_summary])

  def _update_value(self, observ, reward, length):
    """Perform multiple update steps of the value baseline.

    Only one iteration's summary can be returned, so we choose the one
    produced after half of the iterations.

    Args:
      observ: Sequences of observations.
      reward: Sequences of reward.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('update_value'):
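      # tf.scan runs config.update_epochs_value sequential gradient steps
      # inside the graph; the initializer [0., ''] fixes the loss and summary
      # output types, and parallel_iterations=1 prevents overlapping execution
      # of the update steps.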
      loss, summary = tf.scan(lambda _1, _2: self._update_value_step(observ, reward, length),
                              tf.range(self._config.update_epochs_value), [0., ''],
                              parallel_iterations=1)
      print_loss = tf.Print(0, [tf.reduce_mean(loss)], 'value loss: ')
      with tf.control_dependencies([loss, print_loss]):
        return summary[self._config.update_epochs_value // 2]

  def _update_value_step(self, observ, reward, length):
    """Compute the current value loss and perform a gradient update step.

    Args:
      observ: Sequences of observations.
      reward: Sequences of reward.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    loss, summary = self._value_loss(observ, reward, length)
    gradients, variables = (zip(*self._value_optimizer.compute_gradients(loss)))
    optimize = self._value_optimizer.apply_gradients(zip(gradients, variables))
    summary = tf.summary.merge([
        summary,
        tf.summary.scalar('gradient_norm', tf.global_norm(gradients)),
        utility.gradient_summaries(zip(gradients, variables), dict(value=r'.*'))
    ])
    with tf.control_dependencies([optimize]):
      return [tf.identity(loss), tf.identity(summary)]

  def _value_loss(self, observ, reward, length):
    """Compute the loss function for the value baseline.

    The value loss is the mean squared difference between empirical and
    approximated returns over the collected episodes. Returns the loss tensor
    and a summary string.

    Args:
      observ: Sequences of observations.
      reward: Sequences of reward.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('value_loss'):
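      # The regression target is the Monte Carlo discounted return of each
      # sequence; the loss below is 0.5 * (return - value)^2, with padding
      # steps masked to zero before averaging.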
      value = self._network(observ, length).value
      return_ = utility.discounted_return(reward, length, self._config.discount)
      advantage = return_ - value
      value_loss = 0.5 * self._mask(advantage**2, length)
      summary = tf.summary.merge([
          tf.summary.histogram('value_loss', value_loss),
          tf.summary.scalar('avg_value_loss', tf.reduce_mean(value_loss))
      ])
      value_loss = tf.reduce_mean(value_loss)
      return tf.check_numerics(value_loss, 'value_loss'), summary

  def _update_policy(self, observ, action, old_mean, old_logstd, reward, length):
    """Perform multiple update steps of the policy.

    The advantage is computed once at the beginning and shared across
    iterations. Only one iteration's summary can be returned, so we choose
    the one produced after half of the iterations.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      reward: Sequences of rewards.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('update_policy'):
      return_ = utility.discounted_return(reward, length, self._config.discount)
      value = self._network(observ, length).value
      if self._config.gae_lambda:
        advantage = utility.lambda_return(reward, value, length, self._config.discount,
                                          self._config.gae_lambda)
      else:
        advantage = return_ - value
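      # Standardize the advantages to zero mean and unit variance across the
      # whole batch and time dimension; this keeps the scale of the policy
      # gradient roughly constant between training phases.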
      mean, variance = tf.nn.moments(advantage, axes=[0, 1], keep_dims=True)
      advantage = (advantage - mean) / (tf.sqrt(variance) + 1e-8)
      advantage = tf.Print(
          advantage, [tf.reduce_mean(return_), tf.reduce_mean(value)], 'return and value: ')
      advantage = tf.Print(advantage, [tf.reduce_mean(advantage)], 'normalized advantage: ')
      # pylint: disable=g-long-lambda
      loss, summary = tf.scan(lambda _1, _2: self._update_policy_step(
          observ, action, old_mean, old_logstd, advantage, length),
                              tf.range(self._config.update_epochs_policy), [0., ''],
                              parallel_iterations=1)
      print_loss = tf.Print(0, [tf.reduce_mean(loss)], 'policy loss: ')
      with tf.control_dependencies([loss, print_loss]):
        return summary[self._config.update_epochs_policy // 2]

  def _update_policy_step(self, observ, action, old_mean, old_logstd, advantage, length):
    """Compute the current policy loss and perform a gradient update step.

    Args:
      observ: Sequences of observations.
      action: Sequences of actions.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    network = self._network(observ, length)
    loss, summary = self._policy_loss(network.mean, network.logstd, old_mean, old_logstd, action,
                                      advantage, length)
    gradients, variables = (zip(*self._policy_optimizer.compute_gradients(loss)))
    optimize = self._policy_optimizer.apply_gradients(zip(gradients, variables))
    summary = tf.summary.merge([
        summary,
        tf.summary.scalar('gradient_norm', tf.global_norm(gradients)),
        utility.gradient_summaries(zip(gradients, variables), dict(policy=r'.*'))
    ])
    with tf.control_dependencies([optimize]):
      return [tf.identity(loss), tf.identity(summary)]

  def _policy_loss(self, mean, logstd, old_mean, old_logstd, action, advantage, length):
    """Compute the policy loss composed of multiple components.

    1. The policy gradient loss is importance sampled from the data-collecting
       policy at the beginning of training.
    2. The second term is a KL penalty between the policy at the beginning of
       training and the current policy.
    3. Additionally, if this KL already changed more than twice the target
       amount, we activate a strong penalty discouraging further divergence.

    Args:
      mean: Sequences of action means of the current policy.
      logstd: Sequences of action log stddevs of the current policy.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      action: Sequences of actions.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('policy_loss'):
      entropy = utility.diag_normal_entropy(mean, logstd)
      kl = tf.reduce_mean(
          self._mask(utility.diag_normal_kl(old_mean, old_logstd, mean, logstd), length), 1)
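      # Likelihood ratio r = pi(a|s) / pi_old(a|s) of the current policy over
      # the behavioral policy, computed as the exponential of the difference
      # of log densities for numerical stability. The overall objective is the
      # PPO penalty form: -E[r * advantage] + penalty * KL, plus a quadratic
      # cutoff term once the KL exceeds the cutoff threshold.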
      policy_gradient = tf.exp(
          utility.diag_normal_logpdf(mean, logstd, action) -
          utility.diag_normal_logpdf(old_mean, old_logstd, action))
      surrogate_loss = -tf.reduce_mean(
          self._mask(policy_gradient * tf.stop_gradient(advantage), length), 1)
      kl_penalty = self._penalty * kl
      cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor
      cutoff_count = tf.reduce_sum(tf.cast(kl > cutoff_threshold, tf.int32))
      with tf.control_dependencies(
          [tf.cond(cutoff_count > 0, lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]):
        kl_cutoff = (self._config.kl_cutoff_coef * tf.cast(kl > cutoff_threshold, tf.float32) *
                     (kl - cutoff_threshold)**2)
      policy_loss = surrogate_loss + kl_penalty + kl_cutoff
      summary = tf.summary.merge([
          tf.summary.histogram('entropy', entropy),
          tf.summary.histogram('kl', kl),
          tf.summary.histogram('surrogate_loss', surrogate_loss),
          tf.summary.histogram('kl_penalty', kl_penalty),
          tf.summary.histogram('kl_cutoff', kl_cutoff),
          tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff),
          tf.summary.histogram('policy_loss', policy_loss),
          tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)),
          tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)),
          tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss))
      ])
      policy_loss = tf.reduce_mean(policy_loss, 0)
      return tf.check_numerics(policy_loss, 'policy_loss'), summary

  def _adjust_penalty(self, observ, old_mean, old_logstd, length):
441    """Adjust the KL policy between the behavioral and current policy.

    Compute how much the policy actually changed during the multiple
    update steps. Adjust the penalty strength for the next training phase if we
    overshot or undershot the target divergence too much.

    Args:
      observ: Sequences of observations.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('adjust_penalty'):
      network = self._network(observ, length)
      assert_change = tf.assert_equal(tf.reduce_all(tf.equal(network.mean, old_mean)),
                                      False,
                                      message='policy should change')
      print_penalty = tf.Print(0, [self._penalty], 'current penalty: ')
      with tf.control_dependencies([assert_change, print_penalty]):
        kl_change = tf.reduce_mean(
            self._mask(utility.diag_normal_kl(old_mean, old_logstd, network.mean, network.logstd),
                       length))
        kl_change = tf.Print(kl_change, [kl_change], 'kl change: ')
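        # Adaptive penalty rule: if the measured KL overshot the target by
        # more than 30%, scale the penalty up by 1.5; if it undershot by more
        # than 30%, scale it down by 1.5. Otherwise the penalty is unchanged.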
        maybe_increase = tf.cond(
            kl_change > 1.3 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty * 1.5), [0], 'increase penalty '),
            float)
        maybe_decrease = tf.cond(
            kl_change < 0.7 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty / 1.5), [0], 'decrease penalty '),
            float)
      with tf.control_dependencies([maybe_increase, maybe_decrease]):
        return tf.summary.merge([
            tf.summary.scalar('kl_change', kl_change),
            tf.summary.scalar('penalty', self._penalty)
        ])

  def _mask(self, tensor, length):
    """Set padding elements of a batch of sequences to zero.

    Useful to then safely sum along the time dimension.

    Args:
      tensor: Tensor of sequences.
      length: Batch of sequence lengths.

    Returns:
      Masked sequences.
    """
    with tf.name_scope('mask'):
      range_ = tf.range(tensor.shape[1].value)
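      # Broadcasting the [1, time] step indices against the [batch, 1] lengths
      # yields a [batch, time] mask that is 1 for valid steps and 0 for
      # padding.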
      mask = tf.cast(range_[None, :] < length[:, None], tf.float32)
      masked = tensor * mask
      return tf.check_numerics(masked, 'masked')

  def _network(self, observ, length=None, state=None, reuse=True):
    """Compute the network output for a batched sequence of observations.

    Optionally, the initial state can be specified. The weights should be
    reused for all calls, except for the first one. Output is a named tuple
    containing the policy as a TensorFlow distribution, the policy mean and log
    standard deviation, the approximated state value, and the new recurrent
    state.

    Args:
      observ: Sequences of observations.
      length: Batch of sequence lengths.
      state: Batch of initial recurrent states.
      reuse: Python boolean whether to reuse previous variables.

    Returns:
      NetworkOutput tuple.
    """
    with tf.variable_scope('network', reuse=reuse):
      observ = tf.convert_to_tensor(observ)
      use_gpu = self._config.use_gpu and utility.available_gpus()
      with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
        observ = tf.check_numerics(observ, 'observ')
        cell = self._config.network(self._batch_env.action.shape[1].value)
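        # The configured network is an RNN cell that maps each observation to
        # the tuple (mean, logstd, value); dynamic_rnn unrolls it over the
        # padded sequences and also returns the final recurrent state.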
        (mean, logstd, value), state = tf.nn.dynamic_rnn(cell,
                                                         observ,
                                                         length,
                                                         state,
                                                         tf.float32,
                                                         swap_memory=True)
      mean = tf.check_numerics(mean, 'mean')
      logstd = tf.check_numerics(logstd, 'logstd')
      value = tf.check_numerics(value, 'value')
      policy = tf.contrib.distributions.MultivariateNormalDiag(mean, tf.exp(logstd))
      return _NetworkOutput(policy, mean, logstd, value, state)