1# Copyright 2017 The TensorFlow Agents Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Batch of environments inside the TensorFlow graph."""
15
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pdb  # NOTE(review): debugging leftover; unused in this file — consider removing.

import gym
# BUG FIX: was `import tf.compat.v1 as tf`, which fails at import time
# (there is no top-level package named `tf`).
import tensorflow.compat.v1 as tf
22
23
class InGraphBatchEnv(object):
  """Batch of environments inside the TensorFlow graph.

  The batch of environments will be stepped and reset inside of the graph using
  a tf.py_func(). The current batch of observations, actions, rewards, and done
  flags are held in corresponding non-trainable variables, one row per
  environment.
  """

  def __init__(self, batch_env):
    """Batch of environments inside the TensorFlow graph.

    Creates the graph variables that mirror the batch state: observation,
    last action, last reward, and done flag.

    Args:
      batch_env: Batch environment. Assumed to support `len()`, indexing,
        `step()`, `reset()`, and to expose Gym-style `observation_space` and
        `action_space` attributes — TODO confirm against the concrete
        batch-env implementation used by callers.
    """
    self._batch_env = batch_env
    # Variable shapes/dtypes are derived from the Gym spaces of the wrapped
    # batch environment (Discrete -> scalar int32, Box -> float32; see
    # _parse_shape/_parse_dtype below).
    observ_shape = self._parse_shape(self._batch_env.observation_space)
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    action_shape = self._parse_shape(self._batch_env.action_space)
    action_dtype = self._parse_dtype(self._batch_env.action_space)
    with tf.variable_scope('env_temporary'):
      # Leading dimension of every variable is the number of environments.
      self._observ = tf.Variable(tf.zeros((len(self._batch_env),) + observ_shape, observ_dtype),
                                 name='observ',
                                 trainable=False)
      self._action = tf.Variable(tf.zeros((len(self._batch_env),) + action_shape, action_dtype),
                                 name='action',
                                 trainable=False)
      self._reward = tf.Variable(tf.zeros((len(self._batch_env),), tf.float32),
                                 name='reward',
                                 trainable=False)
      # Initialized to all-True, so every environment reads as finished until
      # it has been reset for the first time.
      self._done = tf.Variable(tf.cast(tf.ones((len(self._batch_env),)), tf.bool),
                               name='done',
                               trainable=False)

  def __getattr__(self, name):
    """Forward unimplemented attributes to one of the original environments.

    Only called by Python when normal attribute lookup on this wrapper fails,
    so it never shadows the wrapper's own attributes.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value behind the attribute name in one of the original environments.
    """
    return getattr(self._batch_env, name)

  def __len__(self):
    """Number of combined environments."""
    return len(self._batch_env)

  def __getitem__(self, index):
    """Access an underlying environment by index."""
    return self._batch_env[index]

  def simulate(self, action):
    """Step the batch of environments.

    The results of the step can be accessed from the variables defined below.

    Args:
      action: Tensor holding the batch of actions to apply.

    Returns:
      Operation that, when run, steps the environments and assigns the
      results to the observ/action/reward/done variables.
    """
    with tf.name_scope('environment/simulate'):
      # Validate floating-point actions for NaN/Inf before stepping; integer
      # (Discrete) actions are passed through unchecked.
      if action.dtype in (tf.float16, tf.float32, tf.float64):
        action = tf.check_numerics(action, 'action')
      observ_dtype = self._parse_dtype(self._batch_env.observation_space)
      # The [:3] drops any extra step() outputs (e.g. a Gym-style info dict)
      # so py_func only has to convert observ, reward, and done.
      observ, reward, done = tf.py_func(lambda a: self._batch_env.step(a)[:3], [action],
                                        [observ_dtype, tf.float32, tf.bool],
                                        name='step')
      observ = tf.check_numerics(observ, 'observ')
      reward = tf.check_numerics(reward, 'reward')
      # Group the four assignments into a single op; no value is returned.
      return tf.group(self._observ.assign(observ), self._action.assign(action),
                      self._reward.assign(reward), self._done.assign(done))

  def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    # Reset happens outside the graph; assumes batch_env.reset(indices)
    # returns the new observations for exactly those indices — TODO confirm.
    observ = tf.py_func(self._batch_env.reset, [indices], observ_dtype, name='reset')
    observ = tf.check_numerics(observ, 'observ')
    # Freshly reset environments start with zero reward and done=False.
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.bool)
    # Only update the rows for the requested indices; other environments'
    # state variables are left untouched.
    with tf.control_dependencies([
        tf.scatter_update(self._observ, indices, observ),
        tf.scatter_update(self._reward, indices, reward),
        tf.scatter_update(self._done, indices, done)
    ]):
      # Identity inside the control-dependency scope ensures the variable
      # updates run before the observation is consumed.
      return tf.identity(observ)

  @property
  def observ(self):
    """Access the variable holding the current observation."""
    return self._observ

  @property
  def action(self):
    """Access the variable holding the last received action."""
    return self._action

  @property
  def reward(self):
    """Access the variable holding the current reward."""
    return self._reward

  @property
  def done(self):
    """Access the variable indicating whether the episode is done."""
    return self._done

  def close(self):
    """Send close messages to the external process and join them."""
    self._batch_env.close()

  def _parse_shape(self, space):
    """Get a tensor shape from a OpenAI Gym space.

    Args:
      space: Gym space.

    Returns:
      Shape tuple. Discrete spaces map to a scalar (empty shape).

    Raises:
      NotImplementedError: For space types other than Discrete and Box.
    """
    if isinstance(space, gym.spaces.Discrete):
      return ()
    if isinstance(space, gym.spaces.Box):
      return space.shape
    raise NotImplementedError()

  def _parse_dtype(self, space):
    """Get a tensor dtype from a OpenAI Gym space.

    Args:
      space: Gym space.

    Returns:
      TensorFlow data type: int32 for Discrete, float32 for Box.

    Raises:
      NotImplementedError: For space types other than Discrete and Box.
    """
    if isinstance(space, gym.spaces.Discrete):
      return tf.int32
    if isinstance(space, gym.spaces.Box):
      return tf.float32
    raise NotImplementedError()
175