import collections
from collections.abc import MutableMapping
import copy
import numpy as np

from gym import spaces
from gym import ObservationWrapper


# Key under which a non-dict wrapped observation is stored when it is kept
# alongside the pixel observations.
STATE_KEY = "state"


class PixelObservationWrapper(ObservationWrapper):
    """Augment observations by pixel values."""

    def __init__(
        self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
    ):
        """Initializes a new pixel Wrapper.

        Args:
            env: The environment to wrap.
            pixels_only: If `True` (default), the original observation returned
                by the wrapped environment will be discarded, and a dictionary
                observation will only include pixels. If `False`, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            render_kwargs: Optional `dict` mapping each pixel key to a `dict`
                of keyword arguments passed to `env.render` when rendering
                that key. Missing keys default to `{}`. The argument is
                copied; the caller's dictionaries are not modified.
            pixel_keys: Optional sequence of strings specifying the pixel
                observations' keys in the `OrderedDict` of observations.
                Defaults to `('pixels',)`.

        Raises:
            ValueError: If `env`'s observation spec is not compatible with the
                wrapper. Supported formats are a single array, or a dict of
                arrays.
            ValueError: If `env`'s observation already contains any of the
                specified `pixel_keys`.
            TypeError: If a rendered frame has a dtype that is neither an
                integer nor a floating-point type.
        """

        super(PixelObservationWrapper, self).__init__(env)

        # Work on copies so that neither the caller's outer dict nor its
        # per-key kwargs dicts are mutated by the normalisation below.
        render_kwargs = {} if render_kwargs is None else dict(render_kwargs)

        for key in pixel_keys:
            key_kwargs = dict(render_kwargs.get(key, {}))

            render_mode = key_kwargs.get("mode", "rgb_array")
            assert render_mode == "rgb_array", render_mode
            key_kwargs["mode"] = "rgb_array"

            render_kwargs[key] = key_kwargs

        wrapped_observation_space = env.observation_space

        if isinstance(wrapped_observation_space, spaces.Box):
            self._observation_is_dict = False
            invalid_keys = {STATE_KEY}
        elif isinstance(wrapped_observation_space, (spaces.Dict, MutableMapping)):
            self._observation_is_dict = True
            invalid_keys = set(wrapped_observation_space.spaces.keys())
        else:
            raise ValueError("Unsupported observation space structure.")

        if not pixels_only:
            # Make sure that no keys in the `pixel_keys` overlap with
            # `observation_keys`
            overlapping_keys = set(pixel_keys) & set(invalid_keys)
            if overlapping_keys:
                raise ValueError(
                    "Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
                )

        if pixels_only:
            self.observation_space = spaces.Dict()
        elif self._observation_is_dict:
            self.observation_space = copy.deepcopy(wrapped_observation_space)
        else:
            self.observation_space = spaces.Dict()
            self.observation_space.spaces[STATE_KEY] = wrapped_observation_space

        # Extend observation space with pixels.

        pixels_spaces = {}
        for pixel_key in pixel_keys:
            # Render once up front: the frame's shape and dtype define the
            # Box space advertised for this pixel key.
            pixels = self.env.render(**render_kwargs[pixel_key])

            if np.issubdtype(pixels.dtype, np.integer):
                low, high = (0, 255)
            elif np.issubdtype(pixels.dtype, np.floating):
                # `np.floating` is the abstract float type; the `np.float`
                # alias was removed from NumPy.
                low, high = (-float("inf"), float("inf"))
            else:
                raise TypeError(pixels.dtype)

            pixels_space = spaces.Box(
                shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
            )
            pixels_spaces[pixel_key] = pixels_space

        self.observation_space.spaces.update(pixels_spaces)

        self._env = env
        self._pixels_only = pixels_only
        self._render_kwargs = render_kwargs
        self._pixel_keys = pixel_keys

    def observation(self, observation):
        """Returns `observation` augmented with (or replaced by) pixels."""
        pixel_observation = self._add_pixel_observation(observation)
        return pixel_observation

    def _add_pixel_observation(self, wrapped_observation):
        """Builds the observation mapping for one wrapped observation.

        Args:
            wrapped_observation: The observation produced by the wrapped
                environment.

        Returns:
            A mapping containing one rendered frame per pixel key. If
            `pixels_only` is `False` it also contains the wrapped
            observation: copied key by key when the wrapped observation
            space is a dict, or stored under `STATE_KEY` otherwise.
        """
        if self._pixels_only:
            observation = collections.OrderedDict()
        elif self._observation_is_dict:
            # Shallow copy in the same mapping type, so the wrapped
            # environment's own observation object is not modified.
            observation = type(wrapped_observation)(wrapped_observation)
        else:
            observation = collections.OrderedDict()
            observation[STATE_KEY] = wrapped_observation

        pixel_observations = {
            pixel_key: self.env.render(**self._render_kwargs[pixel_key])
            for pixel_key in self._pixel_keys
        }

        observation.update(pixel_observations)

        return observation