1# Copyright 2017 The TensorFlow Agents Authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Batch of environments inside the TensorFlow graph.""" 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19import pdb 20import gym 21import tf.compat.v1 as tf 22 23 24class InGraphBatchEnv(object): 25 """Batch of environments inside the TensorFlow graph. 26 27 The batch of environments will be stepped and reset inside of the graph using 28 a tf.py_func(). The current batch of observations, actions, rewards, and done 29 flags are held in according variables. 30 """ 31 32 def __init__(self, batch_env): 33 """Batch of environments inside the TensorFlow graph. 34 35 Args: 36 batch_env: Batch environment. 37 """ 38 self._batch_env = batch_env 39 observ_shape = self._parse_shape(self._batch_env.observation_space) 40 observ_dtype = self._parse_dtype(self._batch_env.observation_space) 41 action_shape = self._parse_shape(self._batch_env.action_space) 42 action_dtype = self._parse_dtype(self._batch_env.action_space) 43 with tf.variable_scope('env_temporary'): 44 self._observ = tf.Variable(tf.zeros((len(self._batch_env),) + observ_shape, observ_dtype), 45 name='observ', 46 trainable=False) 47 self._action = tf.Variable(tf.zeros((len(self._batch_env),) + action_shape, action_dtype), 48 name='action', 49 trainable=False) 50 self._reward = tf.Variable(tf.zeros((len(self._batch_env),), tf.float32), 51 name='reward', 52 trainable=False) 53 self._done = tf.Variable(tf.cast(tf.ones((len(self._batch_env),)), tf.bool), 54 name='done', 55 trainable=False) 56 57 def __getattr__(self, name): 58 """Forward unimplemented attributes to one of the original environments. 59 60 Args: 61 name: Attribute that was accessed. 62 63 Returns: 64 Value behind the attribute name in one of the original environments. 65 """ 66 return getattr(self._batch_env, name) 67 68 def __len__(self): 69 """Number of combined environments.""" 70 return len(self._batch_env) 71 72 def __getitem__(self, index): 73 """Access an underlying environment by index.""" 74 return self._batch_env[index] 75 76 def simulate(self, action): 77 """Step the batch of environments. 78 79 The results of the step can be accessed from the variables defined below. 80 81 Args: 82 action: Tensor holding the batch of actions to apply. 83 84 Returns: 85 Operation. 86 """ 87 with tf.name_scope('environment/simulate'): 88 if action.dtype in (tf.float16, tf.float32, tf.float64): 89 action = tf.check_numerics(action, 'action') 90 observ_dtype = self._parse_dtype(self._batch_env.observation_space) 91 observ, reward, done = tf.py_func(lambda a: self._batch_env.step(a)[:3], [action], 92 [observ_dtype, tf.float32, tf.bool], 93 name='step') 94 observ = tf.check_numerics(observ, 'observ') 95 reward = tf.check_numerics(reward, 'reward') 96 return tf.group(self._observ.assign(observ), self._action.assign(action), 97 self._reward.assign(reward), self._done.assign(done)) 98 99 def reset(self, indices=None): 100 """Reset the batch of environments. 101 102 Args: 103 indices: The batch indices of the environments to reset; defaults to all. 104 105 Returns: 106 Batch tensor of the new observations. 107 """ 108 if indices is None: 109 indices = tf.range(len(self._batch_env)) 110 observ_dtype = self._parse_dtype(self._batch_env.observation_space) 111 observ = tf.py_func(self._batch_env.reset, [indices], observ_dtype, name='reset') 112 observ = tf.check_numerics(observ, 'observ') 113 reward = tf.zeros_like(indices, tf.float32) 114 done = tf.zeros_like(indices, tf.bool) 115 with tf.control_dependencies([ 116 tf.scatter_update(self._observ, indices, observ), 117 tf.scatter_update(self._reward, indices, reward), 118 tf.scatter_update(self._done, indices, done) 119 ]): 120 return tf.identity(observ) 121 122 @property 123 def observ(self): 124 """Access the variable holding the current observation.""" 125 return self._observ 126 127 @property 128 def action(self): 129 """Access the variable holding the last recieved action.""" 130 return self._action 131 132 @property 133 def reward(self): 134 """Access the variable holding the current reward.""" 135 return self._reward 136 137 @property 138 def done(self): 139 """Access the variable indicating whether the episode is done.""" 140 return self._done 141 142 def close(self): 143 """Send close messages to the external process and join them.""" 144 self._batch_env.close() 145 146 def _parse_shape(self, space): 147 """Get a tensor shape from a OpenAI Gym space. 148 149 Args: 150 space: Gym space. 151 152 Returns: 153 Shape tuple. 154 """ 155 if isinstance(space, gym.spaces.Discrete): 156 return () 157 if isinstance(space, gym.spaces.Box): 158 return space.shape 159 raise NotImplementedError() 160 161 def _parse_dtype(self, space): 162 """Get a tensor dtype from a OpenAI Gym space. 163 164 Args: 165 space: Gym space. 166 167 Returns: 168 TensorFlow data type. 169 """ 170 if isinstance(space, gym.spaces.Discrete): 171 return tf.int32 172 if isinstance(space, gym.spaces.Box): 173 return tf.float32 174 raise NotImplementedError() 175