1import collections
2from collections.abc import MutableMapping
3import copy
4import numpy as np
5
6from gym import spaces
7from gym import ObservationWrapper
8
9
# Observation-dict key under which PixelObservationWrapper stores the wrapped
# environment's original observation when that observation is a single `Box`
# (i.e. not already a dict) and `pixels_only=False`. It is also the key that
# `pixel_keys` may not use in that case.
STATE_KEY = "state"
11
12
class PixelObservationWrapper(ObservationWrapper):
    """Augment observations by pixel values.

    Every observation returned by the wrapped environment is turned into a
    dictionary that contains one rendered frame per entry of `pixel_keys`
    and, unless `pixels_only` is set, the original observation as well.
    """

    def __init__(
        self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
    ):
        """Initializes a new pixel Wrapper.

        Note that `env.render` is called once per pixel key during
        construction in order to infer the shape and dtype of the pixel
        observation spaces.

        Args:
            env: The environment to wrap.
            pixels_only: If `True` (default), the original observation returned
                by the wrapped environment will be discarded, and a dictionary
                observation will only include pixels. If `False`, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            render_kwargs: Optional `dict` mapping each pixel key to a `dict`
                of keyword arguments passed to `env.render` for that key.
                Missing keys default to `{}`. The passed-in dictionaries are
                copied, not modified.
            pixel_keys: Optional sequence of strings specifying the keys of
                the pixel observations in the dictionary of observations.
                Defaults to `("pixels",)`.

        Raises:
            ValueError: If `env`'s observation spec is not compatible with the
                wrapper. Supported formats are a single array, or a dict of
                arrays.
            ValueError: If `pixels_only` is `False` and `env`'s observation
                already contains any of the specified `pixel_keys`.
            AssertionError: If a render mode other than "rgb_array" is
                requested through `render_kwargs`.
            TypeError: If the rendered pixels are neither of an integer nor of
                a floating point dtype.
        """

        super().__init__(env)

        # Work on copies so that the caller's `render_kwargs` (and the
        # per-key dicts nested inside it) are never mutated.
        render_kwargs = {} if render_kwargs is None else dict(render_kwargs)

        for key in pixel_keys:
            key_kwargs = dict(render_kwargs.get(key, {}))

            # Only "rgb_array" rendering yields pixel arrays; force the mode.
            render_mode = key_kwargs.pop("mode", "rgb_array")
            assert render_mode == "rgb_array", render_mode
            key_kwargs["mode"] = "rgb_array"

            render_kwargs[key] = key_kwargs

        wrapped_observation_space = env.observation_space

        if isinstance(wrapped_observation_space, spaces.Box):
            self._observation_is_dict = False
            invalid_keys = {STATE_KEY}
        elif isinstance(wrapped_observation_space, (spaces.Dict, MutableMapping)):
            self._observation_is_dict = True
            invalid_keys = set(wrapped_observation_space.spaces.keys())
        else:
            raise ValueError("Unsupported observation space structure.")

        if not pixels_only:
            # Make sure that no keys in the `pixel_keys` overlap with the
            # keys already used by the original observation.
            overlapping_keys = set(pixel_keys) & invalid_keys
            if overlapping_keys:
                raise ValueError(
                    "Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
                )

        if pixels_only:
            self.observation_space = spaces.Dict()
        elif self._observation_is_dict:
            self.observation_space = copy.deepcopy(wrapped_observation_space)
        else:
            self.observation_space = spaces.Dict()
            self.observation_space.spaces[STATE_KEY] = wrapped_observation_space

        # Extend observation space with pixels. A frame is rendered for each
        # key so that the space can mirror its shape and dtype.
        pixels_spaces = {}
        for pixel_key in pixel_keys:
            pixels = self.env.render(**render_kwargs[pixel_key])

            if np.issubdtype(pixels.dtype, np.integer):
                low, high = (0, 255)
            # `np.floating` matches every float dtype; the former `np.float`
            # alias was removed from NumPy and raised AttributeError here.
            elif np.issubdtype(pixels.dtype, np.floating):
                low, high = (-float("inf"), float("inf"))
            else:
                raise TypeError(pixels.dtype)

            pixels_spaces[pixel_key] = spaces.Box(
                shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
            )

        self.observation_space.spaces.update(pixels_spaces)

        self._env = env
        self._pixels_only = pixels_only
        self._render_kwargs = render_kwargs
        self._pixel_keys = pixel_keys

    def observation(self, observation):
        """Return `observation` augmented with freshly rendered pixels."""
        return self._add_pixel_observation(observation)

    def _add_pixel_observation(self, wrapped_observation):
        """Build the observation dictionary for one wrapped observation.

        Args:
            wrapped_observation: The observation produced by the wrapped
                environment.

        Returns:
            A mapping holding one rendered frame per pixel key. When
            `pixels_only` is `False` it also holds the original observation:
            a copy of it if it is a dict, otherwise the observation itself
            stored under `STATE_KEY`.
        """
        if self._pixels_only:
            observation = collections.OrderedDict()
        elif self._observation_is_dict:
            # Copy via the observation's own type so that the wrapped
            # environment's mapping is not modified by the update below.
            observation = type(wrapped_observation)(wrapped_observation)
        else:
            observation = collections.OrderedDict()
            observation[STATE_KEY] = wrapped_observation

        pixel_observations = {
            pixel_key: self.env.render(**self._render_kwargs[pixel_key])
            for pixel_key in self._pixel_keys
        }

        observation.update(pixel_observations)

        return observation
128