# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Execute operations in a loop and coordinate logging and checkpoints."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import os
22
import tensorflow.compat.v1 as tf

from pybullet_envs.minitaur.agents.tools import streaming_mean

_Phase = collections.namedtuple(
    'Phase', 'name, writer, op, batch, steps, feed, report_every, log_every,'
    'checkpoint_every')


class Loop(object):
  """Execute operations in a loop and coordinate logging and checkpoints.

  Supports multiple phases that define their own operations to run, and
  intervals for reporting scores, logging summaries, and storing checkpoints.
  All class state is stored in-graph to properly recover from checkpoints.
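
  Example (a minimal sketch; the `done`, `score`, and `summary` tensors are
  assumed to come from the model):
    loop = Loop(logdir='/tmp/experiment')
    loop.add_phase('train', done, score, summary, steps=900, report_every=100)
    loop.add_phase('eval', done, score, summary, steps=100, report_every=100)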
38  """
39
40  def __init__(self, logdir, step=None, log=None, report=None, reset=None):
41    """Execute operations in a loop and coordinate logging and checkpoints.
42
    The step, log, report, and reset arguments will be created if not
    provided. Reset is used to indicate switching to a new phase, so that the
    model can start a new computation in case its computation is split over
    multiple training steps.

    Args:
      logdir: Will contain checkpoints and summaries for each phase.
      step: Variable of the global step (optional).
      log: Tensor indicating to the model to compute summary tensors.
      report: Tensor indicating to the loop to report the current mean score.
      reset: Tensor indicating to the model to start a new computation.
    """
    self._logdir = logdir
    self._step = (
        tf.Variable(0, False, name='global_step') if step is None else step)
    self._log = tf.placeholder(tf.bool) if log is None else log
    self._report = tf.placeholder(tf.bool) if report is None else report
    self._reset = tf.placeholder(tf.bool) if reset is None else reset
    self._phases = []

  def add_phase(self,
                name,
                done,
                score,
                summary,
                steps,
                report_every=None,
                log_every=None,
                checkpoint_every=None,
                feed=None):
    """Add a phase to the loop protocol.

    If the model breaks long computation into multiple steps, the done tensor
    indicates whether the current score should be added to the mean counter.
    For example, in reinforcement learning we only have a valid score at the
    end of the episode.

    Score and done tensors can either be scalars or vectors, to support
    single and batched computations.

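    Example (a sketch; `model.done`, `model.score`, and `model.summary` are
    hypothetical tensors):
      loop.add_phase(
          'train', model.done, model.score, model.summary, steps=1000,
          report_every=100, log_every=500, checkpoint_every=500)
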
    Args:
      name: Name for the phase, used for the summary writer.
      done: Tensor indicating whether current score can be used.
      score: Tensor holding the current, possibly intermediate, score.
      summary: Tensor holding summary string to write if not an empty string.
      steps: Duration of the phase in steps.
      report_every: Yield the mean score every this many steps.
      log_every: Request summaries via `log` tensor every this many steps.
      checkpoint_every: Write a checkpoint every this many steps.
      feed: Additional feed dictionary for the session run call.

    Raises:
      ValueError: Unknown rank for done or score tensors.
    """
    done = tf.convert_to_tensor(done, tf.bool)
    score = tf.convert_to_tensor(score, tf.float32)
    summary = tf.convert_to_tensor(summary, tf.string)
    feed = feed or {}
    if done.shape.ndims is None or score.shape.ndims is None:
      raise ValueError("Rank of 'done' and 'score' tensors must be known.")
    writer = self._logdir and tf.summary.FileWriter(
        os.path.join(self._logdir, name), tf.get_default_graph(),
        flush_secs=60)
    op = self._define_step(done, score, summary)
    batch = 1 if score.shape.ndims == 0 else score.shape[0].value
    self._phases.append(
        _Phase(name, writer, op, batch, int(steps), feed, report_every,
               log_every, checkpoint_every))

  def run(self, sess, saver, max_step=None):
    """Run the loop schedule for a specified number of steps.

    Call the operation of the current phase until the global step reaches the
    specified maximum step. Phases are repeated over and over in the order they
    were added.

    Args:
      sess: Session to use to run the phase operation.
      saver: Saver used for checkpointing.
      max_step: Run the operations until the step reaches this limit.

    Yields:
      Reported mean scores.
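
    Example (a sketch; `sess` and `saver` are assumed to exist):
      for mean_score in loop.run(sess, saver, max_step=10000):
        print('Mean score:', mean_score)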
124    """
125    global_step = sess.run(self._step)
126    steps_made = 1
127    while True:
128      if max_step and global_step >= max_step:
129        break
130      phase, epoch, steps_in = self._find_current_phase(global_step)
131      phase_step = epoch * phase.steps + steps_in
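      # The phase just (re)started if we are fewer steps into it than the
      # previous session run made.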
      if steps_in % phase.steps < steps_made:
        message = '\n' + ('-' * 50) + '\n'
        message += 'Phase {} (phase step {}, global step {}).'
        tf.logging.info(message.format(phase.name, phase_step, global_step))
      # Populate bookkeeping tensors.
      phase.feed[self._reset] = (steps_in < steps_made)
      phase.feed[self._log] = (
          phase.writer and
          self._is_every_steps(phase_step, phase.batch, phase.log_every))
      phase.feed[self._report] = (self._is_every_steps(phase_step, phase.batch,
                                                       phase.report_every))
      summary, mean_score, global_step, steps_made = sess.run(
          phase.op, phase.feed)
      if self._is_every_steps(phase_step, phase.batch, phase.checkpoint_every):
        self._store_checkpoint(sess, saver, global_step)
      if self._is_every_steps(phase_step, phase.batch, phase.report_every):
        yield mean_score
      if summary and phase.writer:
        # We want smaller phases to catch up at the beginning of each epoch so
        # that their graphs are aligned.
        longest_phase = max(phase.steps for phase in self._phases)
        summary_step = epoch * longest_phase + steps_in
        phase.writer.add_summary(summary, summary_step)

  def _is_every_steps(self, phase_step, batch, every):
    """Determine whether a periodic event should happen at this step.

    Args:
      phase_step: The incrementing step.
      batch: The number of steps progressed at once.
      every: The period of the event, in steps.

    Returns:
      Boolean of whether the event should happen.
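
    Example: with phase_step=98, batch=4, and every=100, the covered steps
    are 98..101; step 99 satisfies (99 + 1) % 100 == 0, so this returns True.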
164    """
165    if not every:
166      return False
167    covered_steps = range(phase_step, phase_step + batch)
168    return any((step + 1) % every == 0 for step in covered_steps)
169
170  def _find_current_phase(self, global_step):
171    """Determine the current phase based on the global step.
172
    This ensures continuing the correct phase after restoring checkpoints.

    Args:
      global_step: The global number of steps performed across all phases.

    Returns:
      Tuple of phase object, epoch number, and phase steps within the epoch.
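
    Example: with two phases of 200 and 100 steps, global_step=450 gives
    epoch 1 (the epoch size is 300) and steps_in=150, which falls into the
    first phase.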
180    """
181    epoch_size = sum(phase.steps for phase in self._phases)
182    epoch = int(global_step // epoch_size)
183    steps_in = global_step % epoch_size
184    for phase in self._phases:
185      if steps_in < phase.steps:
186        return phase, epoch, steps_in
187      steps_in -= phase.steps
188
189  def _define_step(self, done, score, summary):
190    """Combine operations of a phase.
191
192    Keeps track of the mean score and when to report it.
193
194    Args:
195      done: Tensor indicating whether current score can be used.
196      score: Tensor holding the current, possibly intermediate, score.
197      summary: Tensor holding summary string to write if not an empty string.
198
199    Returns:
      Tuple of summary tensor, mean score, new global step, and the number of
      steps made. The mean score is zero for non-reporting steps.
202    """
203    if done.shape.ndims == 0:
204      done = done[None]
205    if score.shape.ndims == 0:
206      score = score[None]
207    score_mean = streaming_mean.StreamingMean((), tf.float32)
208    with tf.control_dependencies([done, score, summary]):
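      # Keep only the scores of episodes that finished at this step.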
      done_score = tf.gather(score, tf.where(done)[:, 0])
      submit_score = tf.cond(
          tf.reduce_any(done), lambda: score_mean.submit(done_score), tf.no_op)
    with tf.control_dependencies([submit_score]):
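      # On reporting steps, return the running mean and clear it; otherwise
      # return 0.0.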
      mean_score = tf.cond(self._report, score_mean.clear, float)
      steps_made = tf.shape(score)[0]
      next_step = self._step.assign_add(steps_made)
    with tf.control_dependencies([mean_score, next_step]):
      return tf.identity(summary), mean_score, next_step, steps_made

  def _store_checkpoint(self, sess, saver, global_step):
    """Store a checkpoint if a log directory was provided to the constructor.

    The directory will be created if needed.

    Args:
      sess: Session containing variables to store.
      saver: Saver used for checkpointing.
      global_step: Step number to use in the checkpoint name.
227    """
228    if not self._logdir or not saver:
229      return
230    tf.gfile.MakeDirs(self._logdir)
231    filename = os.path.join(self._logdir, 'model.ckpt')
232    saver.save(sess, filename, global_step)
233