# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
import operator
import copy as _copy
import typing as tp
import gym
import numpy as np
from nevergrad.common.tools import pytorch_import_fix
from nevergrad.parametrization import parameter as p
from ..base import ExperimentFunction
from . import base
from . import envs

pytorch_import_fix()

# pylint: disable=wrong-import-position,wrong-import-order
import torch as torch  # noqa
import torch.nn.functional as F  # noqa
from torch import nn  # noqa
from torch.utils.data import WeightedRandomSampler  # noqa


class RandomAgent(base.Agent):
    """Agent that plays randomly.

    Only discrete action spaces are supported: each action is drawn uniformly
    among the ``env.action_space.n`` possible actions.
    """

    def __init__(self, env: gym.Env) -> None:
        self.env = env
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self.num_outputs = env.action_space.n

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        """Return a uniformly random action index in [0, num_outputs); all inputs are ignored."""
        return np.random.randint(self.num_outputs)

    def copy(self) -> "RandomAgent":
        """Return a new agent of the same class bound to the same environment."""
        return self.__class__(self.env)


class Agent007(base.Agent):
    """Agent that plays slightly better than random on the 007 game.

    The environment must be a ``envs.DoubleOSeven`` instance, or a
    ``base.SingleAgentEnv`` wrapping one.
    """

    def __init__(self, env: gym.Env) -> None:
        self.env = env
        assert isinstance(env, envs.DoubleOSeven) or (
            isinstance(env, base.SingleAgentEnv) and isinstance(env.env, envs.DoubleOSeven)
        )

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        """Pick an action from simple rules and return its index in ``envs.JamesBond.actions``.

        The observation is unpacked into four values; the names suggest own/opponent
        ammunition and protection counters (NOTE(review): confirm against envs.DoubleOSeven).
        """
        my_amm, my_prot, their_amm, their_prot = observation  # pylint: disable=unused-variable
        if their_prot == 4 and my_amm:
            # NOTE(review): the meaning of the value 4 is defined by the environment — confirm
            action = "fire"
        elif their_amm == 0:
            # opponent has no ammunition: protecting is never chosen in this branch
            action = np.random.choice(["fire", "reload"])
        else:
            action = np.random.choice(["fire", "protect", "reload"])
        return envs.JamesBond.actions.index(action)

    def copy(self) -> "Agent007":
        """Return a new agent of the same class bound to the same environment."""
        return self.__class__(self.env)


class TorchAgent(base.Agent):
    """Agent that plays through a torch neural network.

    Parameters
    ----------
    module: nn.Module
        network called on the observation tensor; its output is turned into
        action probabilities with a softmax along dimension 0
    deterministic: bool
        if True, play the action with the highest probability, otherwise sample
        an action index according to those probabilities
    instrumentation_std: float
        mutation sigma of the parametrization of each tensor of the module state dict
    """

    def __init__(
        self, module: nn.Module, deterministic: bool = True, instrumentation_std: float = 0.1
    ) -> None:
        super().__init__()
        self.deterministic = deterministic
        # kept so that copy() reproduces the same parametrization
        self.instrumentation_std = instrumentation_std
        self.module = module
        # one bounded array parameter per tensor of the state dict, keyed by tensor name
        kwargs = {
            name: p.Array(shape=value.shape)
            .set_mutation(sigma=instrumentation_std)
            .set_bounds(-10, 10, method="arctan")
            for name, value in module.state_dict().items()  # type: ignore
        }  # bounded to avoid overflows
        self.instrumentation = p.Instrumentation(**kwargs)

    @classmethod
    def from_module_maker(
        cls,
        env: gym.Env,
        module_maker: tp.Callable[[tp.Tuple[int, ...], int], nn.Module],
        deterministic: bool = True,
        instrumentation_std: float = 0.1,
    ) -> "TorchAgent":
        """Build the module from the environment spaces and wrap it in an agent.

        ``module_maker`` receives the observation space shape and the number of
        discrete actions. Requires a ``Box`` observation space and a ``Discrete``
        action space.
        """
        assert isinstance(env.action_space, gym.spaces.Discrete)
        assert isinstance(env.observation_space, gym.spaces.Box)
        module = module_maker(env.observation_space.shape, env.action_space.n)
        return cls(module, deterministic=deterministic, instrumentation_std=instrumentation_std)

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        """Return the index of the action to play for the given numpy observation."""
        obs = torch.from_numpy(observation.astype(np.float32))
        # call the module itself (not .forward) so that registered hooks are run
        scores = self.module(obs)  # type: ignore
        # assumes a 1-D output with one score per action — softmax is taken along dim 0
        probas = F.softmax(scores, dim=0)
        if self.deterministic:
            return probas.max(0)[1].item()
        return next(iter(WeightedRandomSampler(probas, 1)))

    def copy(self) -> "TorchAgent":
        """Return an agent with a deep copy of the module and the same settings."""
        return TorchAgent(
            _copy.deepcopy(self.module), self.deterministic, instrumentation_std=self.instrumentation_std
        )

    def load_state_dict(self, state_dict: tp.Dict[str, np.ndarray]) -> None:
        """Load numpy arrays (cast to float32 tensors) as the module weights."""
        # pylint: disable=not-callable
        self.module.load_state_dict({x: torch.tensor(y.astype(np.float32)) for x, y in state_dict.items()})


class TorchAgentFunction(ExperimentFunction):
    """Instrumented function which plays the agent using an environment runner.

    The function input is the set of weights of the agent module; its value is
    ``reward_postprocessing`` applied to the reward returned by the runner
    (negated by default, so that minimizing the function maximizes the reward).
    """

    _num_test_evaluations = 1000

    def __init__(
        self,
        agent: TorchAgent,
        env_runner: base.EnvironmentRunner,
        reward_postprocessing: tp.Callable[[float], float] = operator.neg,
    ) -> None:
        assert isinstance(env_runner.env, gym.Env)
        self.agent = agent.copy()
        self.runner = env_runner.copy()
        self.reward_postprocessing = reward_postprocessing
        super().__init__(self.compute, self.agent.instrumentation.copy().set_name(""))
        # the runner plays games, so repeated calls may return different values
        self.parametrization.function.deterministic = False
        self.add_descriptors(
            num_repetitions=self.runner.num_repetitions, archi=self.agent.module.__class__.__name__
        )

    def compute(self, **kwargs: np.ndarray) -> float:
        """Load the given weights into the agent, run it, and return the postprocessed reward.

        A RuntimeError raised during the run is reported as a warning and a raw
        reward of 0 is used instead.
        """
        self.agent.load_state_dict(kwargs)
        try:  # safeguard against nans
            with torch.no_grad():
                reward = self.runner.run(self.agent)
        except RuntimeError as e:
            warnings.warn(f"Returning 0 after error: {e}")
            reward = 0.0
        assert isinstance(reward, (int, float))
        return self.reward_postprocessing(reward)

    def evaluation_function(self, *recommendations: p.Parameter) -> float:
        """Average ``compute`` on the single recommendation over several runs.

        The number of runs is ``max(1, int(_num_test_evaluations / num_repetitions))``
        so that the total number of games played stays around ``_num_test_evaluations``.
        """
        assert len(recommendations) == 1, "Should not be a pareto set for a singleobjective function"
        num_tests = max(1, int(self._num_test_evaluations / self.runner.num_repetitions))
        return sum(self.compute(**recommendations[0].kwargs) for _ in range(num_tests)) / num_tests


class Perceptron(nn.Module):
    """Single linear layer mapping a 1-D input to ``output_size`` scores."""

    def __init__(self, input_shape: tp.Tuple[int, ...], output_size: int) -> None:
        super().__init__()  # type: ignore
        assert len(input_shape) == 1
        self.head = nn.Linear(input_shape[0], output_size)  # type: ignore

    def forward(self, *args: tp.Any) -> tp.Any:
        assert len(args) == 1
        return self.head(args[0])


class DenseNet(nn.Module):
    """Fully connected network: three hidden layers of 16 ReLU units, then a linear head.

    Note: despite the name, there are no dense (concatenating) skip connections.
    """

    def __init__(self, input_shape: tp.Tuple[int, ...], output_size: int) -> None:
        super().__init__()  # type: ignore
        assert len(input_shape) == 1
        self.lin1 = nn.Linear(input_shape[0], 16)  # type: ignore
        self.lin2 = nn.Linear(16, 16)  # type: ignore
        self.lin3 = nn.Linear(16, 16)  # type: ignore
        self.head = nn.Linear(16, output_size)  # type: ignore

    def forward(self, *args: tp.Any) -> tp.Any:
        assert len(args) == 1
        x = F.relu(self.lin1(args[0]))
        x = F.relu(self.lin2(x))
        x = F.relu(self.lin3(x))
        return self.head(x)