1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import uuid
7import itertools
8import numpy as np
9from scipy import stats
10import nevergrad.common.typing as tp
11from . import utils
12
13
class Transform:
    """Abstract base for transforms that can be applied (``forward``) and
    inverted (``backward``).

    Every instance exposes a ``name`` attribute. By default it is a random
    hexadecimal uuid; subclasses are expected to replace it with a short,
    descriptive representation.
    """

    def __init__(self) -> None:
        # placeholder identifier, meant to be overwritten by subclasses
        self.name = uuid.uuid4().hex

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the transform to ``x``; subclasses must override this."""
        raise NotImplementedError

    def backward(self, y: np.ndarray) -> np.ndarray:
        """Apply the inverse transform to ``y``; subclasses must override this."""
        raise NotImplementedError

    def reverted(self) -> "Transform":
        """Return a transform whose forward is this transform's backward (and vice versa)."""
        return Reverted(self)

    def __repr__(self) -> str:
        # Only public attributes (no leading underscore) are shown, ordered by name.
        public_items = sorted(
            (key, value) for key, value in self.__dict__.items() if not key.startswith("_")
        )
        arguments = ", ".join(f"{key}={value}" for key, value in public_items)
        return f"{type(self).__name__}({arguments})"
36
37
class Reverted(Transform):
    """Inverse of a transform: the ``forward`` and ``backward`` methods of the
    wrapped transform are swapped.

    Parameters
    ----------
    transform: Transform
        the transform being inverted
    """

    def __init__(self, transform: Transform) -> None:
        super().__init__()
        self.transform = transform
        self.name = f"Rv({transform.name})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the wrapped transform's inverse."""
        inverse = self.transform.backward
        return inverse(x)

    def backward(self, y: np.ndarray) -> np.ndarray:
        """Apply the wrapped transform's forward direction."""
        direct = self.transform.forward
        return direct(y)
56
57
class Affine(Transform):
    """Affine transform ``a * x + b``, inverted as ``(y - b) / a``.

    Parameters
    ----------
    a: float
        slope; must be non-zero, otherwise the transform is not invertible
    b: float
        offset
    """

    def __init__(self, a: float, b: float) -> None:
        super().__init__()
        if not a:
            raise ValueError('"a" parameter should be non-zero to prevent information loss.')
        self.a, self.b = a, b
        self.name = f"Af({a},{b})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        slope, offset = self.a, self.b
        return slope * x + offset  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        slope, offset = self.a, self.b
        return (y - offset) / slope  # type: ignore
80
81
class Exponentiate(Transform):
    """Exponentiation transform ``base ** (coeff * x)``.

    This can for instance be used to get logarithmically distributed values,
    such as ``10 ** -1``, ``10 ** -2``, ``10 ** -3`` from inputs ``[1, 2, 3]``
    with ``coeff=-1``.

    Parameters
    ----------
    base: float
        base of the exponentiation
    coeff: float
        multiplier applied to the input before exponentiation
    """

    def __init__(self, base: float = 10.0, coeff: float = 1.0) -> None:
        super().__init__()
        self.base, self.coeff = base, coeff
        self.name = f"Ex({base},{coeff})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        exponent = float(self.coeff) * x
        return self.base ** exponent  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        # logarithm in base ``base``, then undo the ``coeff`` scaling
        log_y = np.log(y)
        denominator = float(self.coeff) * np.log(self.base)
        return log_y / denominator  # type: ignore
103
104
# Accepted bound values: a scalar, an array-like, or None (bound not specified).
BoundType = tp.Optional[tp.Union[tp.ArrayLike, float]]


def _f(x: BoundType) -> BoundType:
    """Format a bound for display in transform names.

    Arrays, lists and tuples holding exactly one element (shape ``(1,)``) are
    collapsed to a float, and floats with an integral value are shown as ints
    (e.g. ``[2.0]`` -> ``2``). Other sequences are returned as numpy arrays,
    and any other value is returned unchanged.
    """
    if isinstance(x, (np.ndarray, list, tuple)):
        # np.asarray reuses arrays without copying, like copy=False used to.
        # np.array(x, copy=False) raises a ValueError on NumPy >= 2 whenever a
        # copy is required, which is always the case for lists and tuples.
        x = np.asarray(x)
        if x.shape == (1,):
            x = float(x[0])
    if isinstance(x, float) and x.is_integer():
        x = int(x)
    return x
119
120
class BoundTransform(Transform):  # pylint: disable=abstract-method
    """Base class for transforms relying on a lower and/or an upper bound.

    Bounds are stored as numpy arrays in ``a_min``/``a_max``. Scalars become
    1-element arrays, and a missing bound stays ``None``. ``shape`` records the
    shape of the (first available) bound array.

    Parameters
    ----------
    a_min: float, array-like or None
        lower bound
    a_max: float, array-like or None
        upper bound

    Raises
    ------
    ValueError
        if no bound is provided, or if a lower bound is not strictly smaller
        than the corresponding upper bound
    """

    def __init__(self, a_min: BoundType = None, a_max: BoundType = None) -> None:
        super().__init__()
        self.a_min: tp.Optional[np.ndarray] = None
        self.a_max: tp.Optional[np.ndarray] = None
        for name, value in [("a_min", a_min), ("a_max", a_max)]:
            if value is not None:
                isarray = isinstance(value, (tuple, list, np.ndarray))
                # np.asarray avoids copying existing arrays, as copy=False used to;
                # np.array(value, copy=False) raises on NumPy >= 2 whenever a copy
                # is needed (e.g. for lists and tuples).
                setattr(self, name, np.asarray(value) if isarray else np.array([value]))
        if not (self.a_min is None or self.a_max is None):
            if (self.a_min >= self.a_max).any():
                raise ValueError(f"Lower bounds {a_min} should be strictly smaller than upper bounds {a_max}")
        if self.a_min is None and self.a_max is None:
            raise ValueError("At least one bound must be specified")
        self.shape: tp.Tuple[int, ...] = self.a_min.shape if self.a_min is not None else self.a_max.shape

    def _check_shape(self, x: np.ndarray) -> None:
        """Raise a ValueError if ``x`` cannot be broadcast against the bounds' shape."""
        # NumPy broadcasting aligns shapes on their trailing dimensions, so the
        # dimensions are paired from the end; missing leading dimensions count as 1.
        pairs = itertools.zip_longest(reversed(x.shape), reversed(self.shape), fillvalue=1)
        for dim_x, dim_bound in pairs:
            if dim_x != dim_bound and 1 not in (dim_x, dim_bound):  # same or broadcastable
                raise ValueError(f"Shapes do not match: {self.shape} and {x.shape}")
141
142
class TanhBound(BoundTransform):
    """Bounds all real values into [a_min, a_max] using a tanh transform,
    i.e. ``center + half_width * tanh(x)``.
    Beware, tanh goes very fast to its limits.

    Parameters
    ----------
    a_min: float or array-like
        lower bound
    a_max: float or array-like
        upper bound
    """

    def __init__(self, a_min: tp.Union[tp.ArrayLike, float], a_max: tp.Union[tp.ArrayLike, float]) -> None:
        super().__init__(a_min=a_min, a_max=a_max)
        if self.a_min is None or self.a_max is None:
            raise ValueError("Both bounds must be specified")
        self._b = 0.5 * (self.a_max + self.a_min)  # center of the interval
        self._a = 0.5 * (self.a_max - self.a_min)  # half width of the interval
        self.name = f"Th({_f(a_min)},{_f(a_max)})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._check_shape(x)
        squashed = np.tanh(x)
        return self._b + self._a * squashed  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        self._check_shape(y)
        above = (y > self.a_max).any()
        if above or (y < self.a_min).any():
            raise ValueError(
                f"Only data between {self.a_min} and {self.a_max} "
                "can be transformed back (bounds lead to infinity)."
            )
        normalized = (y - self._b) / self._a
        return np.arctanh(normalized)  # type: ignore
173
174
class Clipping(BoundTransform):
    """Bounds all real values into [a_min, a_max] using clipping (not bijective).

    Parameters
    ----------
    a_min: float, array-like or None
        lower bound
    a_max: float, array-like or None
        upper bound
    bounce: bool
        bounce (once) on borders instead of just clipping: an out-of-bounds value
        is mirrored around the bound it crossed, then clipped again
    """

    def __init__(
        self,
        a_min: BoundType = None,
        a_max: BoundType = None,
        bounce: bool = False,
    ) -> None:
        super().__init__(a_min=a_min, a_max=a_max)
        self._bounce = bounce
        suffix = ",b" if bounce else ""
        self.name = f"Cl({_f(a_min)},{_f(a_max)}{suffix})"
        self.checker = utils.BoundChecker(self.a_min, self.a_max)

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._check_shape(x)
        if self.checker(x):
            # a truthy checker result is treated as "within bounds": returned as is
            return x
        clipped = np.clip(x, self.a_min, self.a_max)
        if not self._bounce:
            return clipped  # type: ignore
        # 2 * clipped - x reflects the overshoot back across the crossed bound
        return np.clip(2 * clipped - x, self.a_min, self.a_max)  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        self._check_shape(y)
        if self.checker(y):
            return y
        raise ValueError(
            f"Only data between {self.a_min} and {self.a_max} can be transformed back.\nGot: {y}"
        )
216
217
class ArctanBound(BoundTransform):
    """Bounds all real values into [a_min, a_max] using an arctan transform,
    i.e. ``center + (a_max - a_min) / pi * arctan(x)``.
    This is a much softer approach compared to tanh.

    Parameters
    ----------
    a_min: float or array-like
        lower bound
    a_max: float or array-like
        upper bound
    """

    def __init__(self, a_min: tp.Union[tp.ArrayLike, float], a_max: tp.Union[tp.ArrayLike, float]) -> None:
        super().__init__(a_min=a_min, a_max=a_max)
        if self.a_min is None or self.a_max is None:
            raise ValueError("Both bounds must be specified")
        self._b = 0.5 * (self.a_max + self.a_min)  # center of the interval
        self._a = (self.a_max - self.a_min) / np.pi  # arctan spans (-pi/2, pi/2)
        self.name = f"At({_f(a_min)},{_f(a_max)})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._check_shape(x)
        angle = np.arctan(x)
        return self._b + self._a * angle  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        self._check_shape(y)
        above = (y > self.a_max).any()
        if above or (y < self.a_min).any():
            raise ValueError(f"Only data between {self.a_min} and {self.a_max} can be transformed back.")
        angle = (y - self._b) / self._a
        return np.tan(angle)  # type: ignore
245
246
class CumulativeDensity(BoundTransform):
    """Bounds all real values into [lower, upper] (by default [0, 1]) using a
    gaussian cumulative density function (cdf): ``lower + (upper - lower) * cdf(x)``.
    Beware, cdf goes very fast to its limits.

    Parameters
    ----------
    lower: float
        lower bound
    upper: float
        upper bound
    eps: float
        in ``backward``, normalized values are clipped to [eps, 1 - eps] before
        applying the inverse cdf, so the bounds do not map to infinite values
    """

    def __init__(self, lower: float = 0.0, upper: float = 1.0, eps: float = 1e-9) -> None:
        super().__init__(a_min=lower, a_max=upper)
        self._b = lower  # offset
        self._a = upper - lower  # scale
        self._eps = eps
        self.name = f"Cd({_f(lower)},{_f(upper)})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        cdf = stats.norm.cdf(x)
        return self._a * cdf + self._b  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        above = (y > self.a_max).any()
        if above or (y < self.a_min).any():
            raise ValueError(
                f"Only data between {self.a_min} and {self.a_max} can be transformed back.\nGot: {y}"
            )
        normalized = (y - self._b) / self._a
        return stats.norm.ppf(np.clip(normalized, self._eps, 1 - self._eps))
269
270
class Fourrier(Transform):
    """Orthonormal real-input Fourier transform over the given axes.

    ``forward`` applies ``np.fft.rfftn`` and ``backward`` applies
    ``np.fft.irfftn``, both with ``norm="ortho"``. Sizes along the transformed
    axes must be even.

    Parameters
    ----------
    axes: int or sequence of int
        axis or axes along which the transform is computed
    """

    def __init__(self, axes: tp.Union[int, tp.Sequence[int]] = 0) -> None:
        super().__init__()
        normalized_axes: tp.Tuple[int, ...] = (axes,) if isinstance(axes, int) else tuple(axes)  # type: ignore
        self.axes = normalized_axes
        self.name = f"F({axes})"

    def forward(self, x: np.ndarray) -> np.ndarray:
        # irfftn (used in backward without an explicit ``s``) rebuilds an even
        # length along the last transformed axis, hence the even-size requirement
        has_odd_axis = any(x.shape[axis] % 2 for axis in self.axes)
        if has_odd_axis:
            raise ValueError(f"Only even shapes are allowed for Fourrier transform, got {x.shape}")
        return np.fft.rfftn(x, axes=self.axes, norm="ortho")  # type: ignore

    def backward(self, y: np.ndarray) -> np.ndarray:
        return np.fft.irfftn(y, axes=self.axes, norm="ortho")  # type: ignore
284