# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
import operator
import copy as _copy
import typing as tp
import gym
import numpy as np
from nevergrad.common.tools import pytorch_import_fix
from nevergrad.parametrization import parameter as p
from ..base import ExperimentFunction
from . import base
from . import envs

pytorch_import_fix()

# pylint: disable=wrong-import-position,wrong-import-order
import torch  # noqa
import torch.nn.functional as F  # noqa
from torch import nn  # noqa
from torch.utils.data import WeightedRandomSampler  # noqa


class RandomAgent(base.Agent):
    """Agent that plays randomly."""

    def __init__(self, env: gym.Env) -> None:
        self.env = env
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self.num_outputs = env.action_space.n

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        return np.random.randint(self.num_outputs)

    def copy(self) -> "RandomAgent":
        return self.__class__(self.env)


class Agent007(base.Agent):
    """Agent that plays slightly better than random on the 007 game."""

    def __init__(self, env: gym.Env) -> None:
        self.env = env
        assert isinstance(env, envs.DoubleOSeven) or (
            isinstance(env, base.SingleAgentEnv) and isinstance(env.env, envs.DoubleOSeven)
        )

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        my_amm, my_prot, their_amm, their_prot = observation  # pylint: disable=unused-variable
        if their_prot == 4 and my_amm:
            # the opponent has already protected 4 times: fire if we have ammunition
            action = "fire"
        elif their_amm == 0:
            # the opponent has no ammunition and cannot fire: no need to protect
            action = np.random.choice(["fire", "reload"])
        else:
            action = np.random.choice(["fire", "protect", "reload"])
        return envs.JamesBond.actions.index(action)

    def copy(self) -> "Agent007":
        return self.__class__(self.env)


class TorchAgent(base.Agent):
    """Agent that plays through a torch neural network."""

    def __init__(
        self, module: nn.Module, deterministic: bool = True, instrumentation_std: float = 0.1
    ) -> None:
        super().__init__()
        self.deterministic = deterministic
        self.module = module
        # one parametrized array per tensor of the module state dict,
        # bounded to avoid overflows
        kwargs = {
            name: p.Array(shape=value.shape)
            .set_mutation(sigma=instrumentation_std)
            .set_bounds(-10, 10, method="arctan")
            for name, value in module.state_dict().items()  # type: ignore
        }
        self.instrumentation = p.Instrumentation(**kwargs)

    @classmethod
    def from_module_maker(
        cls,
        env: gym.Env,
        module_maker: tp.Callable[[tp.Tuple[int, ...], int], nn.Module],
        deterministic: bool = True,
    ) -> "TorchAgent":
        assert isinstance(env.action_space, gym.spaces.Discrete)
        assert isinstance(env.observation_space, gym.spaces.Box)
        module = module_maker(env.observation_space.shape, env.action_space.n)
        return cls(module, deterministic=deterministic)

    def act(
        self,
        observation: tp.Any,
        reward: tp.Any,
        done: bool,
        info: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None,
    ) -> tp.Any:
        obs = torch.from_numpy(observation.astype(np.float32))
        forward = self.module.forward(obs)  # type: ignore
        probas = F.softmax(forward, dim=0)
        if self.deterministic:
            # greedy: pick the action with the highest probability
            return probas.max(0)[1].view(1, 1).item()
        else:
            # stochastic: sample an action according to the probabilities
            return next(iter(WeightedRandomSampler(probas, 1)))

    def copy(self) -> "TorchAgent":
        return TorchAgent(_copy.deepcopy(self.module), self.deterministic)

    def load_state_dict(self, state_dict: tp.Dict[str, np.ndarray]) -> None:
        # pylint: disable=not-callable
        self.module.load_state_dict({x: torch.tensor(y.astype(np.float32)) for x, y in state_dict.items()})


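# Usage sketch (for illustration only; `env` and `observation` are placeholders, and
# Perceptron is just one possible module maker): the instrumentation kwargs mirror the
# keys of module.state_dict(), so an optimizer candidate round-trips into the agent:
#
#     agent = TorchAgent.from_module_maker(env, Perceptron)
#     candidate = agent.instrumentation.spawn_child()  # one array per weight tensor
#     agent.load_state_dict(candidate.kwargs)  # numpy arrays back into torch tensors
#     action = agent.act(observation, reward=0.0, done=False)

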
class TorchAgentFunction(ExperimentFunction):
    """Instrumented function which plays the agent using an environment runner"""

    _num_test_evaluations = 1000

    def __init__(
        self,
        agent: TorchAgent,
        env_runner: base.EnvironmentRunner,
        reward_postprocessing: tp.Callable[[float], float] = operator.neg,
    ) -> None:
        assert isinstance(env_runner.env, gym.Env)
        self.agent = agent.copy()
        self.runner = env_runner.copy()
        self.reward_postprocessing = reward_postprocessing
        super().__init__(self.compute, self.agent.instrumentation.copy().set_name(""))
        self.parametrization.function.deterministic = False
        self.add_descriptors(
            num_repetitions=self.runner.num_repetitions, archi=self.agent.module.__class__.__name__
        )

    def compute(self, **kwargs: np.ndarray) -> float:
        self.agent.load_state_dict(kwargs)
        try:  # safeguard against nans
            with torch.no_grad():
                reward = self.runner.run(self.agent)
        except RuntimeError as e:
            warnings.warn(f"Returning 0 after error: {e}")
            reward = 0.0
        assert isinstance(reward, (int, float))
        return self.reward_postprocessing(reward)

    def evaluation_function(self, *recommendations: p.Parameter) -> float:
        """Averages the post-processed reward over enough runs to provide
        a low-noise test-time evaluation of the recommendation.
        """
        assert len(recommendations) == 1, "Should not be a pareto set for a single-objective function"
        num_tests = max(1, int(self._num_test_evaluations / self.runner.num_repetitions))
        return sum(self.compute(**recommendations[0].kwargs) for _ in range(num_tests)) / num_tests


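# Optimization sketch (for illustration only; the optimizer choice and budget are
# arbitrary placeholders, assuming the standard nevergrad optimizer API):
#
#     import nevergrad as ng
#     func = TorchAgentFunction(agent, runner)
#     optimizer = ng.optimizers.OnePlusOne(parametrization=func.parametrization, budget=100)
#     recommendation = optimizer.minimize(func)
#     test_loss = func.evaluation_function(recommendation)

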
class Perceptron(nn.Module):
    """Single linear layer mapping a flat observation to one logit per action."""

    def __init__(self, input_shape: tp.Tuple[int, ...], output_size: int) -> None:
        super().__init__()  # type: ignore
        assert len(input_shape) == 1
        self.head = nn.Linear(input_shape[0], output_size)  # type: ignore

    def forward(self, *args: tp.Any) -> tp.Any:
        assert len(args) == 1
        return self.head(args[0])


class DenseNet(nn.Module):
    """Three hidden ReLU layers of width 16, followed by a linear head."""

    def __init__(self, input_shape: tp.Tuple[int, ...], output_size: int) -> None:
        super().__init__()  # type: ignore
        assert len(input_shape) == 1
        self.lin1 = nn.Linear(input_shape[0], 16)  # type: ignore
        self.lin2 = nn.Linear(16, 16)  # type: ignore
        self.lin3 = nn.Linear(16, 16)  # type: ignore
        self.head = nn.Linear(16, output_size)  # type: ignore

    def forward(self, *args: tp.Any) -> tp.Any:
        assert len(args) == 1
        x = F.relu(self.lin1(args[0]))
        x = F.relu(self.lin2(x))
        x = F.relu(self.lin3(x))
        return self.head(x)

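
if __name__ == "__main__":
    # Minimal end-to-end sketch, for illustration only. Assumptions: "CartPole-v0" is
    # available in the local gym installation, and base.EnvironmentRunner can be built
    # as EnvironmentRunner(env, num_repetitions=...) - only its copy/run/num_repetitions
    # usage appears above, so the constructor signature is an assumption.
    import nevergrad as ng

    demo_env = gym.make("CartPole-v0")
    demo_agent = TorchAgent.from_module_maker(demo_env, Perceptron)
    demo_runner = base.EnvironmentRunner(demo_env, num_repetitions=3)
    demo_function = TorchAgentFunction(demo_agent, demo_runner)
    optimizer = ng.optimizers.OnePlusOne(parametrization=demo_function.parametrization, budget=20)
    recommendation = optimizer.minimize(demo_function)
    print("loss on recommendation:", demo_function.evaluation_function(recommendation))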