import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper


class GrayScaleObservation(ObservationWrapper):
    r"""Convert the image observation from RGB to gray scale.

    The wrapped environment's ``observation_space`` must have shape
    ``(height, width, 3)``. Each observation is converted with OpenCV's
    ``COLOR_RGB2GRAY`` conversion.

    Args:
        env: Environment to wrap.
        keep_dim: If ``True``, keep a trailing channel axis of size 1, so the
            observation shape is ``(height, width, 1)``. Otherwise the channel
            axis is dropped and the shape is ``(height, width)``.
    """

    def __init__(self, env, keep_dim=False):
        super(GrayScaleObservation, self).__init__(env)
        self.keep_dim = keep_dim

        input_shape = env.observation_space.shape
        # RGB2GRAY needs exactly three channels in the last axis.
        assert len(input_shape) == 3 and input_shape[-1] == 3, (
            "GrayScaleObservation expects an observation space of shape "
            "(height, width, 3), got {}".format(input_shape)
        )

        height, width = self.observation_space.shape[:2]
        if self.keep_dim:
            obs_shape = (height, width, 1)
        else:
            obs_shape = (height, width)
        # NOTE(review): the space is declared uint8 in [0, 255] regardless of
        # the wrapped env's dtype; presumably inputs are uint8 images - confirm.
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Return ``observation`` converted from RGB to a single gray channel.

        When ``keep_dim`` is set, a trailing axis of size 1 is appended so the
        result matches ``self.observation_space``.
        """
        # Local import: cv2 is loaded only when an observation is converted,
        # not when this module is imported.
        import cv2

        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            observation = np.expand_dims(observation, -1)
        return observation