1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import typing as tp
7import pytest
8import numpy as np
9from ..common import testing
10from . import transforms
11
12
@testing.parametrized(
    affine=(transforms.Affine(3, 4), "Af(3,4)"),
    reverted=(transforms.Affine(3, 4).reverted(), "Rv(Af(3,4))"),
    exponentiate=(transforms.Exponentiate(3, 4), "Ex(3,4)"),
    tanh=(transforms.TanhBound(3.0, 4.5), "Th(3,4.5)"),
    arctan=(transforms.ArctanBound(3, 4), "At(3,4)"),
    cumdensity=(transforms.CumulativeDensity(), "Cd(0,1)"),
    cumdensity2=(transforms.CumulativeDensity(1, 3), "Cd(1,3)"),
    clipping=(transforms.Clipping(None, 1e12), "Cl(None,1000000000000)"),
    bouncing=(transforms.Clipping(-12000, 12000, bounce=True), "Cl(-12000,12000,b)"),
    fourrier=(transforms.Fourrier(), "F(0)"),
)
def test_back_and_forth(transform: transforms.Transform, string: str) -> None:
    """Check that ``backward`` inverts ``forward`` and that ``name`` is the expected string.

    A locally seeded generator is used rather than the global numpy RNG, so the
    sampled input (and hence any failure) is reproducible and does not depend on
    which tests ran before this one.
    """
    rng = np.random.RandomState(12)
    x = rng.normal(0, 1, size=12)
    y = transform.forward(x)
    x2 = transform.backward(y)
    np.testing.assert_array_almost_equal(x2, x)
    np.testing.assert_equal(transform.name, string)
31
32
@testing.parametrized(
    affine=(transforms.Affine(3, 4), [0, 1, 2], [4, 7, 10]),
    reverted=(transforms.Affine(3, 4).reverted(), [4, 7, 10], [0, 1, 2]),
    exponentiate=(transforms.Exponentiate(10, -1.0), [0, 1, 2], [1, 0.1, 0.01]),
    tanh=(transforms.TanhBound(3, 5), [-100000, 100000, 0], [3, 5, 4]),
    arctan=(transforms.ArctanBound(3, 5), [-100000, 100000, 0], [3, 5, 4]),
    bouncing=(transforms.Clipping(0, 10, bounce=True), [-1, 22, 3], [1, 0, 3]),
    cumdensity=(transforms.CumulativeDensity(), [-10, 0, 10], [0, 0.5, 1]),
    cumdensity_bounds=(transforms.CumulativeDensity(2, 4), [-10, 0, 10], [2, 3, 4]),
)
def test_vals(transform: transforms.Transform, x: tp.List[float], expected: tp.List[float]) -> None:
    """Compare ``forward`` outputs with hand-computed reference values, to 5 decimals."""
    inputs = np.array(x)
    np.testing.assert_almost_equal(transform.forward(inputs), expected, decimal=5)
46
47
@testing.parametrized(
    tanh=(transforms.TanhBound(0, 5), [2, 4], None),
    tanh_err=(transforms.TanhBound(0, 5), [2, 4, 6], ValueError),
    clipping=(transforms.Clipping(0), [2, 4, 6], None),
    clipping_err=(transforms.Clipping(0), [-2, 4, 6], ValueError),
    arctan=(transforms.ArctanBound(0, 5), [2, 4, 5], None),
    arctan_err=(transforms.ArctanBound(0, 5), [-1, 4, 5], ValueError),
    cumdensity=(transforms.CumulativeDensity(), [0, 0.5], None),
    cumdensity_err=(transforms.CumulativeDensity(), [-0.1, 0.5], ValueError),
)
def test_out_of_bound(
    transform: transforms.Transform, x: tp.List[float], expected: tp.Optional[tp.Type[Exception]]
) -> None:
    """Check how ``backward`` behaves on values inside vs. outside the transform's bounds.

    If ``expected`` is None the call must succeed; otherwise it must raise ``expected``.
    """
    data = np.array(x)
    if expected is not None:
        with pytest.raises(expected):
            transform.backward(data)
        return
    transform.backward(data)
66
67
@testing.parametrized(
    tanh=(transforms.TanhBound, [1.0, 100.0]),
    arctan=(transforms.ArctanBound, [0.9968, 99.65]),
    clipping=(transforms.Clipping, [1, 90]),
)
def test_multibounds(transform_cls: tp.Type[transforms.BoundTransform], expected: tp.List[float]) -> None:
    """Exercise a bound transform built with per-dimension (list) bounds.

    Checks the forward values, that an input whose length differs from the bounds
    is rejected in both directions, that a dimension whose lower bound equals its
    upper bound is rejected, and that omitting both bounds is rejected.
    """
    lower, upper = [0, 0], [1, 100]
    transform = transform_cls(lower, upper)
    np.testing.assert_almost_equal(transform.forward(np.array([100, 90])), expected, decimal=2)
    # a 3-element input does not match the 2-element bounds
    for method in (transform.forward, transform.backward):
        with pytest.raises(ValueError):
            method(np.array([-3, 5, 4]))
    # first dimension has lower bound == upper bound
    with pytest.raises(ValueError):
        transform_cls([0, 0], [0, 100])
    # neither bound provided
    with pytest.raises(ValueError):
        transform_cls(None, None)
88
89
@testing.parametrized(
    both_sides=(transforms.Clipping(0, 1), [0, 1.0]),
    one_side=(transforms.Clipping(a_max=1), [-3, 1.0]),
)
def test_clipping(transform: transforms.Transform, expected: tp.List[float]) -> None:
    """Check that clipping clamps values to the given bounds and leaves an unbounded side untouched."""
    np.testing.assert_array_equal(transform.forward(np.array([-3, 5])), expected)
97