1import gym
2import pygame
3import matplotlib
4import argparse
5from gym import logger
6
7try:
8    matplotlib.use("TkAgg")
9    import matplotlib.pyplot as plt
10except ImportError as e:
11    logger.warn("failed to set matplotlib backend, plotting will not work: %s" % str(e))
12    plt = None
13
14from collections import deque
15from pygame.locals import VIDEORESIZE
16
17
def display_arr(screen, arr, video_size, transpose):
    """Draw an image array on ``screen``, stretched to ``video_size``.

    The array is min-max normalised to the 0-255 range, converted to a
    pygame surface, scaled and blitted at the top-left corner of ``screen``.

    Arguments
    ---------
    screen:
        Pygame surface the frame is blitted onto.
    arr:
        Image array (anything providing ``min``, ``max``, ``swapaxes`` and
        element-wise arithmetic, e.g. a numpy array).
    video_size:
        Target ``(width, height)`` passed to ``pygame.transform.scale``.
    transpose: bool
        If True, the first two axes of ``arr`` are swapped before the
        surface is created.
    """
    arr_min, arr_max = arr.min(), arr.max()
    if arr_max == arr_min:
        # A constant frame has zero dynamic range; dividing by it would fill
        # the image with NaN/inf values. Show it as an all-zero frame instead.
        arr = (arr - arr_min) * 0.0
    else:
        arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
    pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
    pyg_img = pygame.transform.scale(pyg_img, video_size)
    screen.blit(pyg_img, (0, 0))
24
25
def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None):
    """Allows one to play the game using keyboard.

    To simply play the game use:

        play(gym.make("Pong-v4"))

    Above code works also if env is wrapped, so it's particularly useful in
    verifying that the frame-level preprocessing does not render the game
    unplayable.

    If you wish to plot real time statistics as you play, you can use
    gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
    for the last 5 seconds of gameplay.

        def callback(obs_t, obs_tp1, action, rew, done, info):
            return [rew,]
        plotter = PlayPlot(callback, 30 * 5, ["reward"])

        env = gym.make("Pong-v4")
        play(env, callback=plotter.callback)

    The loop ends when the window is closed or the escape key is pressed.
    The pygame window is shut down on exit, including when the environment
    or the callback raises.

    Arguments
    ---------
    env: gym.Env
        Environment to use for playing.
    transpose: bool
        If True the output of observation is transposed.
        Defaults to true.
    fps: int
        Maximum number of steps of the environment to execute every second.
        Defaults to 30.
    zoom: float
        Make screen edge this many times bigger
    callback: lambda or None
        If a callback is provided it will be executed after
        every step. It takes the following input:
            obs_t: observation before performing action
            obs_tp1: observation after performing action
            action: action that was executed
            rew: reward that was received
            done: whether the environment is done or not
            info: debug info
    keys_to_action: dict: tuple(int) -> int or None
        Mapping from keys pressed to action performed.
        For example if pressing 'w' and space at the same time is supposed
        to trigger action number 2 then the keys_to_action dict would look
        like this:

            {
                # ...
                tuple(sorted((ord('w'), ord(' ')))): 2,
                # ...
            }
        Key combinations missing from the mapping trigger action 0.
        If None, default key_to_action mapping for that env is used, if provided.

    Raises
    ------
    AssertionError
        If ``keys_to_action`` is None and neither ``env`` nor
        ``env.unwrapped`` provides ``get_keys_to_action``.
    """
    env.reset()
    # An initial frame is rendered only to learn the native frame size.
    rendered = env.render(mode="rgb_array")

    if keys_to_action is None:
        if hasattr(env, "get_keys_to_action"):
            keys_to_action = env.get_keys_to_action()
        elif hasattr(env.unwrapped, "get_keys_to_action"):
            keys_to_action = env.unwrapped.get_keys_to_action()
        else:
            # Raised explicitly instead of `assert False` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(
                env.spec.id
                + " does not have explicit key to action mapping, "
                + "please specify one manually"
            )
    # Every key that takes part in at least one mapped combination.
    relevant_keys = {key for combo in keys_to_action for key in combo}

    video_size = [rendered.shape[1], rendered.shape[0]]
    if zoom is not None:
        video_size = int(video_size[0] * zoom), int(video_size[1] * zoom)

    pressed_keys = []
    running = True
    # Starting in the "done" state makes the first iteration reset the env.
    env_done = True

    try:
        screen = pygame.display.set_mode(video_size)
        clock = pygame.time.Clock()

        while running:
            if env_done:
                env_done = False
                obs = env.reset()
            else:
                action = keys_to_action.get(tuple(sorted(pressed_keys)), 0)
                prev_obs = obs
                obs, rew, env_done, info = env.step(action)
                if callback is not None:
                    callback(prev_obs, obs, action, rew, env_done, info)
            if obs is not None:
                rendered = env.render(mode="rgb_array")
                display_arr(screen, rendered, transpose=transpose, video_size=video_size)

            # process pygame events
            for event in pygame.event.get():
                # test events, set key states
                if event.type == pygame.KEYDOWN:
                    if event.key in relevant_keys:
                        # Guard against duplicates: a repeated KEYDOWN would
                        # otherwise yield a tuple that matches no mapping.
                        if event.key not in pressed_keys:
                            pressed_keys.append(event.key)
                    elif event.key == pygame.K_ESCAPE:
                        running = False
                elif event.type == pygame.KEYUP:
                    # A KEYUP can arrive without a recorded KEYDOWN (e.g. the
                    # key was already held when the window gained focus).
                    if event.key in pressed_keys:
                        pressed_keys.remove(event.key)
                elif event.type == pygame.QUIT:
                    running = False
                elif event.type == VIDEORESIZE:
                    video_size = event.size
                    screen = pygame.display.set_mode(video_size)
                    print(video_size)

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Always release the window, even if the env or callback raised.
        pygame.quit()
144
145
class PlayPlot(object):
    """Live scatter plots of per-step values, for use as a ``play`` callback.

    ``callback`` is called with ``(obs_t, obs_tp1, action, rew, done, info)``
    and must return one value per entry of ``plot_names``. Each value is
    appended to its own series, and only the most recent
    ``horizon_timesteps`` values of every series are kept and drawn.
    """

    def __init__(self, callback, horizon_timesteps, plot_names):
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        assert plt is not None, "matplotlib backend failed, plotting will not work"

        series_count = len(self.plot_names)
        figure, axes = plt.subplots(series_count)
        self.fig = figure
        # With a single plot a bare axis is returned; wrap it so that
        # indexing works the same way as in the multi-plot case.
        self.ax = [axes] if series_count == 1 else axes
        for axis, title in zip(self.ax, plot_names):
            axis.set_title(title)

        self.t = 0
        self.cur_plot = [None] * series_count
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(series_count)]

    def callback(self, obs_t, obs_tp1, action, rew, done, info):
        """Record the values for one step and redraw every plot."""
        new_points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
        for value, series in zip(new_points, self.data):
            series.append(value)
        self.t += 1

        # x-range covered by the values currently held in the series.
        window_start = max(0, self.t - self.horizon_timesteps)
        window_end = self.t

        for idx in range(len(self.cur_plot)):
            previous = self.cur_plot[idx]
            if previous is not None:
                # Drop the old scatter before drawing the refreshed one.
                previous.remove()
            axis = self.ax[idx]
            self.cur_plot[idx] = axis.scatter(
                range(window_start, window_end), list(self.data[idx]), c="blue"
            )
            axis.set_xlim(window_start, window_end)
        plt.pause(0.000001)
180
181
def main():
    """Read the ``--env`` option, create that environment and play it."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--env",
        type=str,
        default="MontezumaRevengeNoFrameskip-v4",
        help="Define Environment",
    )
    options = arg_parser.parse_args()
    play(gym.make(options.env), zoom=4, fps=60)


# Run the interactive player only when executed as a script.
if __name__ == "__main__":
    main()
197