1"""Lower-Upper transformation."""
2import numpy
3from scipy.special import comb
4import numpoly
5import chaospy
6
7from .distribution import Distribution
8
9
class LowerUpperDistribution(Distribution):
    """
    Lower-Upper transformation.

    Linearly transforms a one-dimensional distribution so that the interval
    ``[lower0, upper0]`` given by the underlying distribution's own lower and
    upper bounds is mapped onto ``[lower, upper]``::

        Y = scale*X + shift
        scale = (upper-lower)/(upper0-lower0)
        shift = lower - lower0*scale

    Args:
        dist (Distribution):
            The underlying distribution to be scaled. Must be
            one-dimensional.
        lower (float, Sequence[float], Distribution):
            Lower bounds. Must be strictly less than ``upper``.
        upper (float, Sequence[float], Distribution):
            Upper bounds.
        rotation:
            Forwarded to ``chaospy.declare_dependencies``; the rotation it
            returns is passed on to the ``Distribution`` base class.
        repr_args:
            Arguments used for the object's representation. Defaults to the
            underlying distribution's ``_repr_args``. The ``lower`` and
            ``upper`` entries are appended to a copy, so a caller-supplied
            sequence is never modified.

    """

    def __init__(
            self,
            dist,
            lower=0.,
            upper=1.,
            rotation=None,
            repr_args=None,
    ):
        assert isinstance(dist, Distribution), "'dist' should be a distribution"
        assert len(dist) == 1
        # Always work on a fresh list: ``+=`` on a caller-supplied list would
        # otherwise append ``lower``/``upper`` to the caller's own object.
        repr_args = list(dist._repr_args if repr_args is None else repr_args)
        # NOTE(review): the (value, 0)/(value, 1) pairs presumably let
        # format_repr_kwargs omit bounds equal to their defaults -- confirm.
        repr_args += chaospy.format_repr_kwargs(lower=(lower, 0), upper=(upper, 1))

        dependencies, parameters, rotation, = chaospy.declare_dependencies(
            distribution=self,
            parameters=dict(lower=lower, upper=upper),
            is_operator=True,
            rotation=rotation,
            extra_parameters=dict(dist=dist),
        )
        # Only a single dimension with one lower and one upper bound is
        # supported by this transformation.
        assert len(dependencies) == 1
        assert len(parameters["lower"]) == 1
        assert len(parameters["upper"]) == 1
        super(LowerUpperDistribution, self).__init__(
            parameters=parameters,
            dependencies=dependencies,
            rotation=rotation,
            repr_args=repr_args,
        )
        self._dist = dist

    def get_parameters(self, idx, cache, assert_numerical=True):
        """
        Resolve the keyword arguments used by the private methods.

        Args:
            idx:
                Index forwarded to the base class and to the underlying
                distribution's lookups.
            cache:
                Cache forwarded to the base class and the underlying
                distribution.
            assert_numerical (bool):
                If true, assert that neither bound is still an unresolved
                ``Distribution``.

        Returns:
            (Dict[str, Any]):
                ``dist`` (the underlying distribution), ``scale`` and
                ``shift`` (coefficients of the affine map), and
                ``parameters`` (the underlying distribution's own
                parameters).

        Raises:
            AssertionError:
                If ``upper > lower`` does not hold, or if a bound is still a
                distribution while ``assert_numerical`` is true.

        """
        parameters = super(LowerUpperDistribution, self).get_parameters(
            idx, cache, assert_numerical=assert_numerical)
        # Bounds given as distributions are substituted through their cache
        # entry (see ``Distribution._get_cache``).
        lower = parameters["lower"]
        if isinstance(lower, Distribution):
            lower = lower._get_cache(idx, cache, get=0)
        upper = parameters["upper"]
        if isinstance(upper, Distribution):
            upper = upper._get_cache(idx, cache, get=0)
        assert not assert_numerical or not (isinstance(lower, Distribution) or
                                            isinstance(upper, Distribution))
        assert numpy.all(upper > lower), (
            "condition not satisfied: `upper > lower`")
        # Copies of the cache are used for the bound lookups, presumably so
        # they leave no entries behind in the caller's cache.
        lower0 = self._dist._get_lower(idx, cache.copy())
        upper0 = self._dist._get_upper(idx, cache.copy())
        scale = (upper-lower)/(upper0-lower0)
        # Chosen so that lower0 is mapped exactly onto lower.
        shift = lower-lower0*scale
        parameters = self._dist.get_parameters(idx, cache, assert_numerical=assert_numerical)
        return dict(dist=self._dist, scale=scale, shift=shift, parameters=parameters)

    def _lower(self, dist, scale, shift, parameters):
        """Lower bound: the underlying lower bound pushed through the map."""
        return dist._lower(**parameters)*scale+shift

    def _upper(self, dist, scale, shift, parameters):
        """Upper bound: the underlying upper bound pushed through the map."""
        return dist._upper(**parameters)*scale+shift

    def _ppf(self, qloc, dist, scale, shift, parameters):
        """Inverse CDF: underlying quantiles pushed through the affine map."""
        return dist._ppf(qloc, **parameters)*scale+shift

    def _cdf(self, xloc, dist, scale, shift, parameters):
        """CDF: evaluate the underlying CDF at the pre-image of ``xloc``."""
        return dist._cdf((xloc-shift)/scale, **parameters)

    def _pdf(self, xloc, dist, scale, shift, parameters):
        """
        Density: underlying density at the pre-image of ``xloc``, divided by
        the Jacobian ``scale`` (assumes ``scale > 0``).
        """
        return dist._pdf((xloc-shift)/scale, **parameters)/scale

    def _mom(self, kloc, dist, scale, shift, parameters):
        """
        Raw moments: expand ``(scale*X + shift)**kloc`` as a polynomial in X
        and combine the underlying raw moments with its coefficients.
        """
        del parameters
        poly = numpoly.variable(len(self))
        poly = numpoly.sum(scale*poly, axis=-1)+shift
        poly = numpoly.set_dimensions(numpoly.prod(poly**kloc), len(self))
        out = sum(dist._get_mom(key)*coeff
                  for key, coeff in zip(poly.exponents, poly.coefficients))
        return out

    def _ttr(self, kloc, dist, scale, shift, parameters):
        """
        Recurrence coefficients under the affine map ``Y = scale*X + shift``.

        The first coefficient transforms like a location
        (``scale*coeff0 + shift``), the second like a variance
        (``scale**2 * coeff1``).
        """
        coeff0, coeff1 = dist._ttr(kloc, **parameters)
        coeff0 = coeff0*scale+shift
        coeff1 = coeff1*scale*scale
        return coeff0, coeff1
107