1from gym import ObservationWrapper
2
3
class TransformObservation(ObservationWrapper):
    r"""Transform the observation via an arbitrary function.

    The wrapper stores the callable ``f`` and returns ``f(observation)``
    from :meth:`observation`.

    Example::

        >>> import gym
        >>> import numpy as np
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape))
        >>> env.reset()
        array([-0.08319338,  0.04635121, -0.07394746,  0.20877492])

    Args:
        env (Env): environment
        f (callable): a function that transforms the observation

    Raises:
        AssertionError: if ``f`` is not callable.

    """

    def __init__(self, env, f):
        """Wrap ``env`` and remember the transformation ``f``."""
        super(TransformObservation, self).__init__(env)
        # Raise explicitly instead of using an ``assert`` statement: asserts
        # are removed when Python runs with ``-O``, which would let a
        # non-callable ``f`` through and defer the failure to the first
        # ``observation`` call. AssertionError is kept as the exception type
        # so existing callers that catch it are unaffected.
        if not callable(f):
            raise AssertionError("f must be callable, got %r" % (f,))
        self.f = f

    def observation(self, observation):
        """Return the result of applying the stored function to ``observation``."""
        return self.f(observation)
28