# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for the array transforms defined in the sibling ``transforms`` module."""

import typing as tp
import pytest
import numpy as np
from ..common import testing
from . import transforms


@testing.parametrized(
    affine=(transforms.Affine(3, 4), "Af(3,4)"),
    reverted=(transforms.Affine(3, 4).reverted(), "Rv(Af(3,4))"),
    exponentiate=(transforms.Exponentiate(3, 4), "Ex(3,4)"),
    tanh=(transforms.TanhBound(3.0, 4.5), "Th(3,4.5)"),
    arctan=(transforms.ArctanBound(3, 4), "At(3,4)"),
    cumdensity=(transforms.CumulativeDensity(), "Cd(0,1)"),
    cumdensity2=(transforms.CumulativeDensity(1, 3), "Cd(1,3)"),
    clipping=(transforms.Clipping(None, 1e12), "Cl(None,1000000000000)"),
    bouncing=(transforms.Clipping(-12000, 12000, bounce=True), "Cl(-12000,12000,b)"),
    fourrier=(transforms.Fourrier(), "F(0)"),
)
def test_back_and_forth(transform: transforms.Transform, string: str) -> None:
    """Check that ``backward`` inverts ``forward`` and that ``name`` equals ``string``.

    Args:
        transform: the transform under test.
        string: the expected value of ``transform.name``.
    """
    # A dedicated, seeded generator makes the sampled input reproducible
    # (a failure can be replayed) and leaves numpy's global random state alone.
    rng = np.random.RandomState(seed=12)
    x = rng.normal(0, 1, size=12)
    y = transform.forward(x)
    x2 = transform.backward(y)
    np.testing.assert_array_almost_equal(x2, x)
    np.testing.assert_equal(transform.name, string)


@testing.parametrized(
    affine=(transforms.Affine(3, 4), [0, 1, 2], [4, 7, 10]),
    reverted=(transforms.Affine(3, 4).reverted(), [4, 7, 10], [0, 1, 2]),
    exponentiate=(transforms.Exponentiate(10, -1.0), [0, 1, 2], [1, 0.1, 0.01]),
    tanh=(transforms.TanhBound(3, 5), [-100000, 100000, 0], [3, 5, 4]),
    arctan=(transforms.ArctanBound(3, 5), [-100000, 100000, 0], [3, 5, 4]),
    bouncing=(transforms.Clipping(0, 10, bounce=True), [-1, 22, 3], [1, 0, 3]),
    cumdensity=(transforms.CumulativeDensity(), [-10, 0, 10], [0, 0.5, 1]),
    cumdensity_bounds=(transforms.CumulativeDensity(2, 4), [-10, 0, 10], [2, 3, 4]),
)
def test_vals(transform: transforms.Transform, x: tp.List[float], expected: tp.List[float]) -> None:
    """Check ``forward`` against hand-computed outputs, to 5 decimal places.

    Args:
        transform: the transform under test.
        x: input values fed to ``forward``.
        expected: the expected output for ``x``.
    """
    y = transform.forward(np.array(x))
    np.testing.assert_almost_equal(y, expected, decimal=5)


@testing.parametrized(
    tanh=(transforms.TanhBound(0, 5), [2, 4], None),
    tanh_err=(transforms.TanhBound(0, 5), [2, 4, 6], ValueError),
    clipping=(transforms.Clipping(0), [2, 4, 6], None),
    clipping_err=(transforms.Clipping(0), [-2, 4, 6], ValueError),
    arctan=(transforms.ArctanBound(0, 5), [2, 4, 5], None),
    arctan_err=(transforms.ArctanBound(0, 5), [-1, 4, 5], ValueError),
    cumdensity=(transforms.CumulativeDensity(), [0, 0.5], None),
    cumdensity_err=(transforms.CumulativeDensity(), [-0.1, 0.5], ValueError),
)
def test_out_of_bound(
    transform: transforms.Transform, x: tp.List[float], expected: tp.Optional[tp.Type[Exception]]
) -> None:
    """Check that ``backward`` accepts in-range inputs and rejects out-of-range ones.

    Args:
        transform: the transform under test.
        x: input values fed to ``backward``.
        expected: ``None`` when ``backward`` must succeed; otherwise the
            exception type it must raise.
    """
    if expected is None:
        transform.backward(np.array(x))
    else:
        with pytest.raises(expected):
            transform.backward(np.array(x))


@testing.parametrized(
    tanh=(transforms.TanhBound, [1.0, 100.0]),
    arctan=(transforms.ArctanBound, [0.9968, 99.65]),
    clipping=(transforms.Clipping, [1, 90]),
)
def test_multibounds(transform_cls: tp.Type[transforms.BoundTransform], expected: tp.List[float]) -> None:
    """Check bound transforms built with per-coordinate (list) bounds.

    Covers the forward values, rejection of inputs whose length does not match
    the bounds, and rejection of invalid bound arguments at construction time.

    Args:
        transform_cls: the bound transform class under test.
        expected: the expected ``forward`` output for the input ``[100, 90]``
            with bounds ``[0, 0]`` / ``[1, 100]``, to 2 decimal places.
    """
    transform = transform_cls([0, 0], [1, 100])
    output = transform.forward(np.array([100, 90]))
    np.testing.assert_almost_equal(output, expected, decimal=2)
    # shapes: a 3-element input does not match the 2-element bounds
    with pytest.raises(ValueError):
        transform.forward(np.array([-3, 5, 4]))
    with pytest.raises(ValueError):
        transform.backward(np.array([-3, 5, 4]))
    # bound error: the first coordinate has equal lower and upper bounds
    with pytest.raises(ValueError):
        transform_cls([0, 0], [0, 100])
    # two Nones: at least one bound must be provided
    with pytest.raises(ValueError):
        transform_cls(None, None)


@testing.parametrized(
    both_sides=(transforms.Clipping(0, 1), [0, 1.0]),
    one_side=(transforms.Clipping(a_max=1), [-3, 1.0]),
)
def test_clipping(transform: transforms.Transform, expected: tp.List[float]) -> None:
    """Check ``Clipping.forward`` on ``[-3, 5]`` with two-sided and one-sided bounds.

    Args:
        transform: the clipping transform under test.
        expected: the exact expected output for the input ``[-3, 5]``.
    """
    y = transform.forward(np.array([-3, 5]))
    np.testing.assert_array_equal(y, expected)