1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import time
7import numpy as np
8import nevergrad as ng
9import nevergrad.common.typing as tp
10from nevergrad.common import testing
11from . import recaster
12from . import optimizerlib
13
14
def test_message() -> None:
    """A Message keeps its call arguments and only becomes done once a result is assigned."""
    msg = recaster.Message(1, 2, blublu=3)
    # freshly created: pending, with positional and keyword arguments recorded
    np.testing.assert_equal(msg.done, False)
    np.testing.assert_equal(msg.args, [1, 2])
    np.testing.assert_equal(msg.kwargs, {"blublu": 3})
    # assigning the result flips the message to done
    msg.result = 3
    np.testing.assert_equal(msg.done, True)
    np.testing.assert_equal(msg.result, 3)
23
24
def fake_caller(func: tp.Callable[[int], int]) -> int:
    """Call ``func`` on 0, 1, ..., 9 (in that order) and return the sum of the results."""
    return sum(func(k) for k in range(10))
30
31
@testing.parametrized(
    finished=(10, 30),
    unfinished=(2, None),  # should not hang at deletion!
)
def test_messaging_thread(num_iter: int, output: tp.Optional[int]) -> None:
    """Answer ``num_iter`` messages posted by ``fake_caller`` and check the thread output.

    Every answer is 3, so answering all 10 calls of ``fake_caller`` should give 30.
    Answering only 2 leaves the caller blocked on its third call, so no output is
    expected (``None``) and the thread must still be torn down without hanging.
    """
    # fail with an explicit error rather than spinning forever if the thread misbehaves
    deadline = time.monotonic() + 10.0
    thread = recaster.MessagingThread(fake_caller)
    num_answers = 0
    while num_answers < num_iter:
        assert time.monotonic() < deadline, f"Timed out after answering {num_answers}/{num_iter} messages"
        if thread.messages and not thread.messages[0].done:
            thread.messages[0].result = 3
            num_answers += 1
        time.sleep(0.001)
    if output is not None:
        # The caller's return value only becomes available once it returns, which can
        # lag behind the last answer: wait for the thread to finish instead of racing it.
        while thread.is_alive() and time.monotonic() < deadline:
            time.sleep(0.001)
    with testing.skip_error_on_systems(AssertionError, systems=("Windows",)):  # TODO fix
        np.testing.assert_equal(thread.output, output)
46
47
def test_automatic_thread_deletion() -> None:
    """A newly created MessagingThread is running right away.

    It is never stopped explicitly here; the test presumably relies on it being
    cleaned up automatically when the object goes away (confirm in recaster).
    """
    messaging = recaster.MessagingThread(fake_caller)
    assert messaging.is_alive()
51
52
def fake_cost_function(x: tp.ArrayLike) -> float:
    """Return the sum of squared entries of ``x`` as a Python float."""
    squared = np.square(np.asarray(x))
    return float(squared.sum())
55
56
class FakeOptimizer(recaster.SequentialRecastOptimizer):
    """Minimal recast optimizer whose recast optimization loop is a OnePlusOne run.

    Used by the tests in this module to exercise the messaging machinery of
    ``SequentialRecastOptimizer``.
    """

    def get_optimization_function(self) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.ArrayLike]:
        # Hand out a method bound to a brand-new instance (same parametrization,
        # budget and num_workers) rather than to ``self``, to avoid a deadlock.
        twin = type(self)(self.parametrization, self.budget, self.num_workers)
        return twin._optim_function

    def _optim_function(self, func: tp.Callable[..., tp.Any]) -> tp.ArrayLike:
        """Minimize ``func`` with a 2-dimensional OnePlusOne within this optimizer's budget.

        Returns the recommendation as standardized data relative to ``self.parametrization``.
        """
        inner = optimizerlib.OnePlusOne(parametrization=2, budget=self.budget)
        best = inner.minimize(func)
        return best.get_standardized_data(reference=self.parametrization)
66
67
def test_recast_optimizer() -> None:
    """A full minimize run goes through the messaging thread once per unit of budget."""
    budget = 100
    optim = FakeOptimizer(parametrization=2, budget=budget)
    optim.minimize(fake_cost_function)
    messaging = optim._messaging_thread
    assert messaging is not None
    # the faked inner function is expected to have been called exactly `budget` times
    np.testing.assert_equal(messaging._thread.call_count, budget)
73
74
def test_recast_optimizer_with_error() -> None:
    """An error raised by the objective must propagate out of ``minimize`` without hanging.

    The objective is only evaluated after ``minimize`` has asked for a candidate,
    which is when the recast messaging thread gets involved. Calling
    ``optimizer.minimize()`` with no argument would instead fail at argument
    binding, before any thread exists, and would not exercise this path.
    """
    optimizer = FakeOptimizer(parametrization=2, budget=100)

    def failing_cost_function(x: tp.ArrayLike) -> float:
        raise TypeError("failing on purpose")

    np.testing.assert_raises(TypeError, optimizer.minimize, failing_cost_function)  # did hang in some versions
78
79
def test_recast_optimizer_and_stop() -> None:
    """Dropping a recast optimizer mid-run (after a single ask) must not hang."""
    recast = FakeOptimizer(budget=100, parametrization=2)
    recast.ask()
    # the messaging thread is left waiting for an answer here, yet teardown must not block
84
85
def test_provide_recommendation() -> None:
    """SQP recommends something before any tell, then the lowest-loss told point."""
    optim = optimizerlib.SQP(parametrization=2, budget=100)
    initial = optim.provide_recommendation()
    assert isinstance(initial, ng.p.Parameter), "Recommendation should be available from start"
    # ask/tell two points in sequence; the second one gets the lower loss
    told = []
    for loss in (10, 5):
        candidate = optim.ask()
        optim.tell(candidate, loss)
        told.append(candidate)
    best_told = told[-1]
    np.testing.assert_array_almost_equal(optim.provide_recommendation().value, best_told.value)
98