1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import pytest
7import numpy as np
8import nevergrad as ng
9from nevergrad.common.tools import flatten
10from .optimizerlib import registry
11from .externalbo import _hp_parametrization_to_dict, _hp_dict_to_parametrization
12
13
@pytest.mark.parametrize(  # type: ignore
    "parametrization,has_transform",
    [
        (ng.p.Choice(list(range(10))), True),
        (ng.p.Scalar(lower=0, upper=1), True),
        (ng.p.Scalar(lower=0, upper=10).set_integer_casting(), True),
        (ng.p.Log(lower=1e-3, upper=1e3), True),
        (ng.p.Array(init=np.zeros(10)), True),
        (ng.p.Instrumentation(ng.p.Scalar(lower=0, upper=1), a=ng.p.Choice(list(range(10)))), False),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice([ng.p.Scalar(lower=0, upper=1), ng.p.Scalar(lower=100, upper=1000)])
            ),
            True,
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Choice(list(range(10))),
                        ng.p.Scalar(lower=0, upper=1),
                    ]
                )
            ),
            False,
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Instrumentation(
                            b=ng.p.Choice(list(range(10))), c=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                        ng.p.Instrumentation(
                            d=ng.p.Scalar(lower=0, upper=1), e=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                    ]
                )
            ),
            False,
        ),
    ],
)
def test_hyperopt(parametrization, has_transform) -> None:
    """Check HyperOpt ask/tell consistency, transform detection and parallel asks.

    Two optimizers receive the same candidates:

    - ``optim1`` goes through the regular ask/tell cycle.
    - ``optim2`` is told the candidates as "not asked", because their ``trial_id`` is removed first.

    Their recorded hyperopt trial values and their trial-id counters must stay
    identical. ``has_transform`` states whether ``optim1`` is expected to use
    an internal transform (``optim1._transform is not None``).
    """
    num_tells = 4
    optim1 = registry["HyperOpt"](parametrization=parametrization, budget=5)
    optim2 = registry["HyperOpt"](parametrization=parametrization.copy(), budget=5)
    for it in range(num_tells):
        cand = optim1.ask()
        optim1.tell(cand, 0)  # Tell asked
        # Removing the id makes optim2 handle it as a candidate it never asked for.
        del cand._meta["trial_id"]
        optim2.tell(cand, 0)  # Tell not asked
        assert flatten(optim1.trials._dynamic_trials[it]["misc"]["vals"]) == pytest.approx(  # type: ignore
            flatten(optim2.trials._dynamic_trials[it]["misc"]["vals"])  # type: ignore
        )

    # Both trials objects must hand out the same next id.
    assert optim1.trials.new_trial_ids(1) == optim2.trials.new_trial_ids(1)  # type: ignore
    # The call above already reserved one id after the num_tells told trials,
    # hence num_tells + 1. The original expression was `it + 2` on the leaked
    # last loop index.
    assert optim1.trials.new_trial_ids(1)[0] == num_tells + 1  # type: ignore
    assert (optim1._transform is not None) == has_transform  # type: ignore

    # Test parallelization: ask more candidates than the budget with several
    # workers, telling only the first one. A fresh copy is used so this optimizer
    # does not share the parametrization instance already bound to optim1.
    opt = registry["HyperOpt"](parametrization=parametrization.copy(), budget=30, num_workers=5)
    for k in range(40):
        cand = opt.ask()
        if k == 0:
            opt.tell(cand, 1)
79
80
@pytest.mark.parametrize(  # type: ignore
    "parametrization,values",
    [
        (
            ng.p.Instrumentation(
                a=ng.p.Choice([ng.p.Choice(list(range(10))), ng.p.Scalar(lower=0, upper=1)])
            ),
            [
                (((), {"a": 0.5}), {"a": [1], "a__1": [0.5]}, {"args": {}, "kwargs": {"a": 0.5}}),
                (((), {"a": 1}), {"a": [0], "a__0": [1]}, {"args": {}, "kwargs": {"a": 1}}),
            ],
        ),
        (
            ng.p.Instrumentation(ng.p.Scalar(lower=0, upper=1), a=ng.p.Choice(list(range(10)))),
            [
                (((0.5,), {"a": 3}), {"0": [0.5], "a": [3]}, {"args": {"0": 0.5}, "kwargs": {"a": 3}}),
                (((0.99,), {"a": 0}), {"0": [0.99], "a": [0]}, {"args": {"0": 0.99}, "kwargs": {"a": 0}}),
            ],
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Instrumentation(
                            b=ng.p.Choice(list(range(10))), c=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                        ng.p.Instrumentation(
                            d=ng.p.Scalar(lower=0, upper=1), e=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                    ]
                )
            ),
            [
                (
                    ((), {"a": ((), {"d": 0.5, "e": 1.0})}),
                    {"a": [1], "d": [0.5], "e": [1.0]},
                    {"args": {}, "kwargs": {"a": {"args": {}, "kwargs": {"d": 0.5, "e": 1.0}}}},
                ),
                (
                    ((), {"a": ((), {"b": 0, "c": 0.014})}),
                    {"a": [0], "b": [0], "c": [0.014]},
                    {"args": {}, "kwargs": {"a": {"args": {}, "kwargs": {"b": 0, "c": 0.014}}}},
                ),
            ],
        ),
    ],
)
def test_hyperopt_helpers(parametrization, values):
    """Check the hyperopt <-> parametrization conversion helpers on known values.

    Each entry of ``values`` is a triple ``(value, expected_dict, hyperopt_value)``.
    After ``parametrization.value = value``, two checks are made, comparing
    flattened structures approximately:

    - ``_hp_parametrization_to_dict`` must match ``expected_dict``.
    - ``_hp_dict_to_parametrization(hyperopt_value)`` must match the
      parametrization's current value.
    """
    for value, expected_dict, hyperopt_value in values:
        parametrization.value = value
        # Forward direction: parametrization -> hyperopt-style dict.
        as_dict = flatten(_hp_parametrization_to_dict(parametrization))
        assert as_dict == pytest.approx(flatten(expected_dict))
        # Backward direction: hyperopt sample -> parametrization value.
        restored = flatten(_hp_dict_to_parametrization(hyperopt_value))
        assert restored == pytest.approx(flatten(parametrization.value))
135