1import numpy as np
2from gym.spaces import Box
3from gym import ObservationWrapper
4
5
class GrayScaleObservation(ObservationWrapper):
    r"""Convert the image observation from RGB to gray scale.

    The wrapped environment's observation space must be 3-D with a last
    axis of size 3, i.e. ``(height, width, 3)``. Each observation is
    converted with OpenCV's ``COLOR_RGB2GRAY`` conversion.

    Args:
        env: Environment to wrap.
        keep_dim (bool): If ``True``, keep a trailing channel axis of size 1
            so observations have shape ``(height, width, 1)``; otherwise
            they have shape ``(height, width)``.
    """

    def __init__(self, env, keep_dim=False):
        super(GrayScaleObservation, self).__init__(env)
        self.keep_dim = keep_dim

        shape = env.observation_space.shape
        # Test for ``None`` first: spaces without a shape would otherwise
        # fail inside ``len()`` with an unrelated TypeError.
        assert shape is not None and len(shape) == 3 and shape[-1] == 3, (
            "GrayScaleObservation expects an observation space of shape "
            "(height, width, 3), got {}".format(shape)
        )

        obs_shape = tuple(shape[:2])
        if self.keep_dim:
            # Preserve a singleton channel axis, matching np.expand_dims
            # in observation().
            obs_shape = obs_shape + (1,)
        # NOTE(review): the resulting space is always uint8 in [0, 255];
        # this presumes the wrapped env yields uint8 RGB images — confirm.
        self.observation_space = Box(
            low=0, high=255, shape=obs_shape, dtype=np.uint8
        )

    def observation(self, observation):
        """Return *observation* converted from RGB to a single gray channel.

        Args:
            observation: RGB image with the channel axis last.

        Returns:
            The gray image of shape ``(height, width)``, or
            ``(height, width, 1)`` when ``keep_dim`` is set.
        """
        # Imported here rather than at module level, so importing this
        # module does not itself require OpenCV.
        import cv2

        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            observation = np.expand_dims(observation, -1)
        return observation
35