from gym import ObservationWrapper


class TransformObservation(ObservationWrapper):
    r"""Transform the observation via an arbitrary function.

    Every observation passed to :meth:`observation` is replaced by
    ``f(observation)``.

    Example::

        >>> import gym
        >>> import numpy as np
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape))
        >>> env.reset()  # noise is random, so the exact values will differ
        array([-0.08319338,  0.04635121, -0.07394746,  0.20877492])

    Args:
        env (Env): environment
        f (callable): a function that transforms the observation

    Raises:
        TypeError: if ``f`` is not callable.

    """

    def __init__(self, env, f):
        super(TransformObservation, self).__init__(env)
        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, which would let a non-callable ``f``
        # through and defer the failure to the first ``observation`` call.
        if not callable(f):
            raise TypeError(
                "f must be callable, got {!r}".format(type(f).__name__)
            )
        self.f = f

    def observation(self, observation):
        """Return ``self.f(observation)``, the transformed observation."""
        return self.f(observation)