"""Helpers for playing a gym environment interactively with the keyboard."""
import argparse
from collections import deque

import gym
import matplotlib
import pygame
from gym import logger
from pygame.locals import VIDEORESIZE

try:
    # Select the backend before pyplot is imported.
    matplotlib.use("TkAgg")
    import matplotlib.pyplot as plt
except ImportError as e:
    logger.warn("failed to set matplotlib backend, plotting will not work: %s" % str(e))
    plt = None


def display_arr(screen, arr, video_size, transpose):
    """Draw a frame onto ``screen``, stretched to ``video_size``.

    The frame is min-max normalised to the 0..255 range before it is
    converted into a pygame surface.

    Arguments
    ---------
    screen: pygame surface
        Surface the frame is blitted onto, at position (0, 0).
    arr: array
        Frame to display. It must provide ``min``, ``max`` and ``swapaxes``
        (presumably a numpy array returned by ``env.render`` -- confirm
        against callers).
    video_size: (int, int)
        Target size handed to ``pygame.transform.scale``.
    transpose: bool
        If True, the first two axes of the frame are swapped before display.
    """
    arr_min, arr_max = arr.min(), arr.max()
    value_range = arr_max - arr_min
    if value_range > 0:
        arr = 255.0 * (arr - arr_min) / value_range
    else:
        # A constant frame has nothing to normalise over: dividing would be
        # 0/0 and produce NaN pixels. Map every pixel to 0 instead, which is
        # where the minimum value lands in the regular case as well.
        arr = arr - arr_min
    pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
    pyg_img = pygame.transform.scale(pyg_img, video_size)
    screen.blit(pyg_img, (0, 0))


def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None):
    """Allows one to play the game using keyboard.

    To simply play the game use:

        play(gym.make("Pong-v4"))

    Above code works also if env is wrapped, so it's particularly useful in
    verifying that the frame-level preprocessing does not render the game
    unplayable.

    If you wish to plot real time statistics as you play, you can use
    gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
    for the last 5 seconds of gameplay.

        def callback(obs_t, obs_tp1, action, rew, done, info):
            return [rew,]
        plotter = PlayPlot(callback, 30 * 5, ["reward"])

        env = gym.make("Pong-v4")
        play(env, callback=plotter.callback)

    The loop runs until the window is closed or Escape is pressed (unless
    Escape is itself one of the mapped keys). The environment is reset
    whenever an episode finishes. pygame is shut down when the function
    exits, including when the environment or the callback raises.

    Arguments
    ---------
    env: gym.Env
        Environment to use for playing.
    transpose: bool
        If True the output of observation is transposed.
        Defaults to true.
    fps: int
        Maximum number of steps of the environment to execute every second.
        Defaults to 30.
    zoom: float
        Make screen edge this many times bigger
    callback: lambda or None
        Callback if a callback is provided it will be executed after
        every step. It takes the following input:
            obs_t: observation before performing action
            obs_tp1: observation after performing action
            action: action that was executed
            rew: reward that was received
            done: whether the environment is done or not
            info: debug info
    keys_to_action: dict: tuple(int) -> int or None
        Mapping from keys pressed to action performed.
        For example if pressed 'w' and space at the same time is supposed
        to trigger action number 2 then key_to_action dict would look like this:

            {
                # ...
                tuple(sorted((ord('w'), ord(' ')))): 2,
                # ...
            }
        Key combinations that are not in the mapping trigger action 0.
        If None, default key_to_action mapping for that env is used, if provided.

    Raises
    ------
    AssertionError
        If ``keys_to_action`` is None and neither the environment nor its
        unwrapped version provides ``get_keys_to_action``.
    """
    # The environment has to be reset before the first frame can be rendered;
    # that frame is only used to derive the window size.
    env.reset()
    rendered = env.render(mode="rgb_array")

    if keys_to_action is None:
        if hasattr(env, "get_keys_to_action"):
            keys_to_action = env.get_keys_to_action()
        elif hasattr(env.unwrapped, "get_keys_to_action"):
            keys_to_action = env.unwrapped.get_keys_to_action()
        else:
            # Raised explicitly (instead of ``assert False``) so the check
            # still fires when Python runs with assertions disabled (-O).
            spec = getattr(env, "spec", None)
            env_id = spec.id if spec is not None else str(env)
            raise AssertionError(
                env_id
                + " does not have explicit key to action mapping, "
                + "please specify one manually"
            )
    # Every individual key that appears in at least one mapped combination.
    relevant_keys = {key for combo in keys_to_action for key in combo}

    video_size = [rendered.shape[1], rendered.shape[0]]
    if zoom is not None:
        video_size = int(video_size[0] * zoom), int(video_size[1] * zoom)

    # A set keeps the state consistent even if a KEYDOWN is delivered twice
    # or a KEYUP arrives for a key whose KEYDOWN was never seen.
    pressed_keys = set()
    running = True
    env_done = True

    try:
        screen = pygame.display.set_mode(video_size)
        clock = pygame.time.Clock()

        while running:
            if env_done:
                env_done = False
                obs = env.reset()
            else:
                action = keys_to_action.get(tuple(sorted(pressed_keys)), 0)
                prev_obs = obs
                obs, rew, env_done, info = env.step(action)
                if callback is not None:
                    callback(prev_obs, obs, action, rew, env_done, info)
            if obs is not None:
                rendered = env.render(mode="rgb_array")
                display_arr(screen, rendered, transpose=transpose, video_size=video_size)

            # process pygame events
            for event in pygame.event.get():
                # test events, set key states
                if event.type == pygame.KEYDOWN:
                    if event.key in relevant_keys:
                        pressed_keys.add(event.key)
                    elif event.key == pygame.K_ESCAPE:
                        running = False
                elif event.type == pygame.KEYUP:
                    if event.key in relevant_keys:
                        pressed_keys.discard(event.key)
                elif event.type == pygame.QUIT:
                    running = False
                elif event.type == VIDEORESIZE:
                    video_size = event.size
                    screen = pygame.display.set_mode(video_size)
                    print(video_size)

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Always release the pygame window, even if a step or callback failed.
        pygame.quit()


class PlayPlot(object):
    """Live scatter plots of per-step statistics collected while playing.

    Pass ``PlayPlot(...).callback`` as the ``callback`` argument of ``play``.
    After every step the user-supplied ``callback`` is asked for one value
    per plot, and the most recent ``horizon_timesteps`` values of each series
    are redrawn.
    """

    def __init__(self, callback, horizon_timesteps, plot_names):
        """Create the figure with one subplot per entry of ``plot_names``.

        Arguments
        ---------
        callback: callable
            Called as ``callback(obs_t, obs_tp1, action, rew, done, info)``;
            its result is iterated to obtain one data point per plot.
        horizon_timesteps: int
            Number of most recent data points kept (and shown) per plot.
        plot_names: list of str
            Titles of the subplots.

        Raises
        ------
        AssertionError
            If matplotlib's pyplot could not be imported.
        """
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        # Raised explicitly so the check survives ``python -O``.
        if plt is None:
            raise AssertionError("matplotlib backend failed, plotting will not work")

        num_plots = len(self.plot_names)
        self.fig, self.ax = plt.subplots(num_plots)
        if num_plots == 1:
            # With a single subplot a bare axes object is returned; wrap it
            # so the rest of the class can always index ``self.ax``.
            self.ax = [self.ax]
        for axis, name in zip(self.ax, plot_names):
            axis.set_title(name)
        self.t = 0
        self.cur_plot = [None for _ in range(num_plots)]
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]

    def callback(self, obs_t, obs_tp1, action, rew, done, info):
        """Record the data points for one step and redraw every plot.

        The arguments are forwarded unchanged to the data callback given to
        the constructor.
        """
        points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
        for point, data_series in zip(points, self.data):
            data_series.append(point)
        self.t += 1

        # Window of timesteps currently held in the deques.
        xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t

        for i, plot in enumerate(self.cur_plot):
            if plot is not None:
                # Drop the previous scatter before drawing the new one.
                plot.remove()
            self.cur_plot[i] = self.ax[i].scatter(
                range(xmin, xmax), list(self.data[i]), c="blue"
            )
            self.ax[i].set_xlim(xmin, xmax)
        # A very short pause; presumably what lets the figure redraw.
        plt.pause(0.000001)


def main():
    """Parse ``--env`` from the command line and play that environment."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        type=str,
        default="MontezumaRevengeNoFrameskip-v4",
        help="Define Environment",
    )
    args = parser.parse_args()
    env = gym.make(args.env)
    play(env, zoom=4, fps=60)


if __name__ == "__main__":
    main()