# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for the registry's ``HyperOpt`` optimizer and for the
parametrization <-> hyperopt-dict conversion helpers of ``externalbo``."""

import pytest
import numpy as np
import nevergrad as ng
from nevergrad.common.tools import flatten
from .optimizerlib import registry
from .externalbo import _hp_parametrization_to_dict, _hp_dict_to_parametrization


# Each case pairs a parametrization with whether HyperOpt is expected to hold an
# internal transform for it (checked through ``optim1._transform`` in the test).
@pytest.mark.parametrize(  # type: ignore
    "parametrization,has_transform",
    [
        (ng.p.Choice(list(range(10))), True),
        (ng.p.Scalar(lower=0, upper=1), True),
        (ng.p.Scalar(lower=0, upper=10).set_integer_casting(), True),
        (ng.p.Log(lower=1e-3, upper=1e3), True),
        (ng.p.Array(init=np.zeros(10)), True),
        (ng.p.Instrumentation(ng.p.Scalar(lower=0, upper=1), a=ng.p.Choice(list(range(10)))), False),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice([ng.p.Scalar(lower=0, upper=1), ng.p.Scalar(lower=100, upper=1000)])
            ),
            True,
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Choice(list(range(10))),
                        ng.p.Scalar(lower=0, upper=1),
                    ]
                )
            ),
            False,
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Instrumentation(
                            b=ng.p.Choice(list(range(10))), c=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                        ng.p.Instrumentation(
                            d=ng.p.Scalar(lower=0, upper=1), e=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                    ]
                )
            ),
            False,
        ),
    ],
)
def test_hyperopt(parametrization, has_transform: bool) -> None:
    """Check HyperOpt bookkeeping for asked vs. non-asked candidates, then parallel asks.

    Each candidate asked from ``optim1`` is told back to ``optim1``. Its
    ``trial_id`` meta entry is then deleted and it is told to ``optim2``, which
    never asked for it. Both optimizers must record the same trial values,
    agree on the next trial ids, and build a transform exactly when
    ``has_transform`` says so. Finally, 40 candidates are asked from a
    5-worker optimizer while only the first one is told.
    """
    optim1 = registry["HyperOpt"](parametrization=parametrization, budget=5)
    # optim2 works on its own copy, so the two optimizers do not share one parametrization instance.
    optim2 = registry["HyperOpt"](parametrization=parametrization.copy(), budget=5)
    for it in range(4):
        cand = optim1.ask()
        optim1.tell(cand, 0)  # Tell asked
        # NOTE(review): presumably HyperOpt links a told candidate to its own trial through
        # ``_meta["trial_id"]``; deleting it forces optim2 down the "not asked" path — confirm.
        del cand._meta["trial_id"]
        optim2.tell(cand, 0)  # Tell not asked
        # The values recorded for trial ``it`` must be the same in both optimizers.
        assert flatten(optim1.trials._dynamic_trials[it]["misc"]["vals"]) == pytest.approx(  # type: ignore
            flatten(optim2.trials._dynamic_trials[it]["misc"]["vals"])  # type: ignore
        )

    # ``it`` is 3 here (last value of range(4)).
    # NOTE(review): new_trial_ids presumably allocates a fresh id on every call, so the
    # first assert consumes one id per optimizer. That would be why the next id expected
    # from optim1 is ``it + 2`` rather than ``it + 1`` — confirm against hyperopt's Trials.
    assert optim1.trials.new_trial_ids(1) == optim2.trials.new_trial_ids(1)  # type: ignore
    assert optim1.trials.new_trial_ids(1)[0] == (it + 2)  # type: ignore
    assert (optim1._transform is not None) == has_transform  # type: ignore

    # Test parallelization
    # 40 asks against a budget of 30 with 5 workers. Only the first candidate (k == 0)
    # is told, so every later ask happens while earlier candidates are still untold.
    opt = registry["HyperOpt"](parametrization=parametrization, budget=30, num_workers=5)
    for k in range(40):
        cand = opt.ask()
        if not k:
            opt.tell(cand, 1)


# Each case is a parametrization plus a list of triples:
#   (value to set on the parametrization,
#    expected output of _hp_parametrization_to_dict,
#    hyperopt-style nested dict passed to _hp_dict_to_parametrization).
# In the expected dicts, "a" holds the index of the chosen option. Keys like "a__0"/"a__1"
# appear to hold the value of that option (inferred from these cases only).
@pytest.mark.parametrize(  # type: ignore
    "parametrization,values",
    [
        (
            ng.p.Instrumentation(
                a=ng.p.Choice([ng.p.Choice(list(range(10))), ng.p.Scalar(lower=0, upper=1)])
            ),
            [
                (((), {"a": 0.5}), {"a": [1], "a__1": [0.5]}, {"args": {}, "kwargs": {"a": 0.5}}),
                (((), {"a": 1}), {"a": [0], "a__0": [1]}, {"args": {}, "kwargs": {"a": 1}}),
            ],
        ),
        (
            ng.p.Instrumentation(ng.p.Scalar(lower=0, upper=1), a=ng.p.Choice(list(range(10)))),
            [
                (((0.5,), {"a": 3}), {"0": [0.5], "a": [3]}, {"args": {"0": 0.5}, "kwargs": {"a": 3}}),
                (((0.99,), {"a": 0}), {"0": [0.99], "a": [0]}, {"args": {"0": 0.99}, "kwargs": {"a": 0}}),
            ],
        ),
        (
            ng.p.Instrumentation(
                a=ng.p.Choice(
                    [
                        ng.p.Instrumentation(
                            b=ng.p.Choice(list(range(10))), c=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                        ng.p.Instrumentation(
                            d=ng.p.Scalar(lower=0, upper=1), e=ng.p.Log(lower=1e-3, upper=1e3)
                        ),
                    ]
                )
            ),
            [
                (
                    ((), {"a": ((), {"d": 0.5, "e": 1.0})}),
                    {"a": [1], "d": [0.5], "e": [1.0]},
                    {"args": {}, "kwargs": {"a": {"args": {}, "kwargs": {"d": 0.5, "e": 1.0}}}},
                ),
                (
                    ((), {"a": ((), {"b": 0, "c": 0.014})}),
                    {"a": [0], "b": [0], "c": [0.014]},
                    {"args": {}, "kwargs": {"a": {"args": {}, "kwargs": {"b": 0, "c": 0.014}}}},
                ),
            ],
        ),
    ],
)
def test_hyperopt_helpers(parametrization, values) -> None:
    """Check both directions of the hyperopt conversion helpers.

    For every ``(val, dict_val, hyperopt_val)`` triple, ``val`` is first set on
    the parametrization. Then ``_hp_parametrization_to_dict`` must produce
    ``dict_val``, and ``_hp_dict_to_parametrization(hyperopt_val)`` must equal
    the parametrization's current value. Both comparisons are made on flattened
    structures with ``pytest.approx``.
    """
    for val, dict_val, hyperopt_val in values:
        # Mutates the parametrization object coming from the parametrize table above.
        parametrization.value = val
        assert flatten(_hp_parametrization_to_dict(parametrization)) == pytest.approx(flatten(dict_val))
        assert flatten(_hp_dict_to_parametrization(hyperopt_val)) == pytest.approx(
            flatten(parametrization.value)
        )