1import gym
2from gym.spaces import Tuple
3from gym.vector.utils.spaces import batch_space
4
# Names exported by a star-import of this module. ``VectorEnvWrapper`` is
# listed alongside ``VectorEnv`` so that both public base classes are exposed.
__all__ = ["VectorEnv", "VectorEnvWrapper"]
6
7
class VectorEnv(gym.Env):
    r"""Base class for vectorized environments.

    Each observation returned from a vectorized environment is a batch of
    observations, one per sub-environment, and :meth:`step` is expected to
    receive a batch of actions, one per sub-environment.

    .. note::

        All sub-environments should share identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.

    Parameters
    ----------
    num_envs : int
        Number of environments in the vectorized environment.

    observation_space : `gym.spaces.Space` instance
        Observation space of a single environment.

    action_space : `gym.spaces.Space` instance
        Action space of a single environment.
    """

    def __init__(self, num_envs, observation_space, action_space):
        super(VectorEnv, self).__init__()
        self.num_envs = num_envs
        self.is_vector_env = True

        # Batched spaces exposed by the vectorized environment: observations
        # go through ``batch_space``; actions are a ``Tuple`` in which the
        # single action space is repeated ``num_envs`` times.
        self.observation_space = batch_space(observation_space, n=num_envs)
        self.action_space = Tuple((action_space,) * num_envs)

        self.closed = False
        self.viewer = None

        # The observation and action spaces of a single environment are
        # kept in separate attributes.
        self.single_observation_space = observation_space
        self.single_action_space = action_space

    def reset_async(self):
        r"""First half of :meth:`reset`.

        Does nothing in this base class; :meth:`reset` calls it before
        :meth:`reset_wait`.
        """
        pass

    def reset_wait(self, **kwargs):
        r"""Second half of :meth:`reset`; its return value is what :meth:`reset` returns.

        Raises
        ------
        NotImplementedError
            Always, in this base class. Subclasses must override it.
        """
        raise NotImplementedError()

    def reset(self):
        r"""Reset all sub-environments and return a batch of initial observations.

        Implemented as :meth:`reset_async` followed by :meth:`reset_wait`.

        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        """
        self.reset_async()
        return self.reset_wait()

    def step_async(self, actions):
        r"""First half of :meth:`step`.

        Does nothing in this base class; :meth:`step` calls it with the batch
        of actions before :meth:`step_wait`.

        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.
        """
        pass

    def step_wait(self, **kwargs):
        r"""Second half of :meth:`step`; its return value is what :meth:`step` returns.

        Raises
        ------
        NotImplementedError
            Always, in this base class. Subclasses must override it.
        """
        raise NotImplementedError()

    def step(self, actions):
        r"""Take an action for each sub-environment.

        Implemented as :meth:`step_async` followed by :meth:`step_wait`.

        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.

        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.

        rewards : `np.ndarray` instance (dtype `np.float_`)
            A vector of rewards from the vectorized environment.

        dones : `np.ndarray` instance (dtype `np.bool_`)
            A vector whose entries indicate whether the episode has ended.

        infos : list of dict
            A list of auxiliary diagnostic information dicts from sub-environments.
        """
        self.step_async(actions)
        return self.step_wait()

    def close_extras(self, **kwargs):
        r"""Clean up the extra resources, e.g. beyond what's in this base class.

        This is a hook called by :meth:`close` with the keyword arguments that
        :meth:`close` received. The base implementation does nothing, so that
        :meth:`close` (and therefore :meth:`__del__`) also works for subclasses
        that have no extra resources to release.
        """
        pass

    def close(self, **kwargs):
        r"""Close all sub-environments and release resources.

        It closes the image viewer (if any), then calls :meth:`close_extras`
        with ``**kwargs`` and sets :attr:`closed` to ``True``. Calling it again
        once :attr:`closed` is ``True`` does nothing.

        .. warning::

            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        .. note::

            This is also called from :meth:`__del__` if the environment has
            not been closed yet.

        """
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras(**kwargs)
        self.closed = True

    def seed(self, seeds=None):
        """Set the random seeds of the sub-environments.

        The base implementation does nothing.

        Parameters
        ----------
        seeds : list of int, or int, optional
            Random seed for each individual environment. If `seeds` is a list of
            length `num_envs`, then the items of the list are chosen as random
            seeds. If `seeds` is an int, then each environment uses the random
            seed `seeds + n`, where `n` is the index of the environment (between
            `0` and `num_envs - 1`).
        """
        pass

    def __del__(self):
        """Close the environment on finalization if it is still open."""
        # ``closed`` is only set by ``__init__``; defaulting to True skips the
        # close when the attribute is missing.
        if not getattr(self, "closed", True):
            self.close(terminate=True)

    def __repr__(self):
        """Return ``ClassName(num_envs)`` or, when a spec is set, ``ClassName(spec_id, num_envs)``."""
        if self.spec is None:
            return "{}({})".format(self.__class__.__name__, self.num_envs)
        else:
            return "{}({}, {})".format(
                self.__class__.__name__, self.spec.id, self.num_envs
            )
148
149
class VectorEnvWrapper(VectorEnv):
    r"""Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    .. note::

        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.

    Parameters
    ----------
    env : `VectorEnv` instance
        The vectorized environment to wrap.
    """

    def __init__(self, env):
        assert isinstance(env, VectorEnv)
        # ``VectorEnv.__init__`` is not called here: instance attributes it
        # would set (e.g. ``num_envs``, ``closed``) are instead looked up on
        # the wrapped environment through ``__getattr__``.
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)
    def reset_async(self):
        """Forward to ``self.env.reset_async()``."""
        return self.env.reset_async()

    def reset_wait(self, **kwargs):
        """Forward to ``self.env.reset_wait(**kwargs)``.

        Keyword arguments are passed through so the signature matches
        ``VectorEnv.reset_wait``.
        """
        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        """Forward to ``self.env.step_async(actions)``."""
        return self.env.step_async(actions)

    def step_wait(self, **kwargs):
        """Forward to ``self.env.step_wait(**kwargs)``.

        Keyword arguments are passed through so the signature matches
        ``VectorEnv.step_wait``.
        """
        return self.env.step_wait(**kwargs)

    def close(self, **kwargs):
        """Forward to ``self.env.close(**kwargs)``."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs):
        """Forward to ``self.env.close_extras(**kwargs)``."""
        return self.env.close_extras(**kwargs)

    def seed(self, seeds=None):
        """Forward to ``self.env.seed(seeds)``."""
        return self.env.seed(seeds)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):
        """Look up attributes missing on the wrapper on ``self.env``.

        Raises
        ------
        AttributeError
            If ``name`` starts with an underscore (private attributes are not
            forwarded), or if ``name`` is ``"env"``.
        """
        if name.startswith("_"):
            raise AttributeError(
                "attempted to get missing private attribute '{}'".format(name)
            )
        # ``__getattr__`` only runs when normal lookup fails, so reaching this
        # point with name == "env" means ``self.env`` has not been set (e.g.
        # ``__init__`` did not complete). Accessing ``self.env`` below would
        # then re-enter this method without end.
        if name == "env":
            raise AttributeError(
                "'{}' object has no attribute 'env'".format(type(self).__name__)
            )
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        """The ``unwrapped`` attribute of the wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
        """Return ``<ClassName, wrapped_env_repr>``."""
        return "<{}, {}>".format(self.__class__.__name__, self.env)
204