1"""Shift-Scale transformation"""
2import numpy
3from scipy.special import comb
4import numpoly
5import chaospy
6
7from .distribution import Distribution
8
9
class ShiftScaleDistribution(Distribution):
    """
    Shift-Scale transformation.

    Linearly transforms any distribution into the form `A*X+b`, where `A`
    is a scaling factor and `b` is a shift vector.

    Args:
        dist (Distribution):
            The underlying distribution to be scaled.
        shift (float, Sequence[float], Distribution):
            Shift added after scaling (`b` above). Defaults to 0.
        scale (float, Sequence[float], Distribution):
            Scaling applied to `dist` (`A` above). Must be strictly
            positive. Defaults to 1.
        rotation (Sequence[int], Sequence[Sequence[bool]]):
            The order in which to resolve conditionals. Either as a
            sequence of column rotations, or as a permutation matrix.
            Defaults to `range(len(distribution))` which is the same as
            `p(x0), p(x1|x0), p(x2|x0,x1), ...`.
        repr_args (Optional[Sequence]):
            Arguments used for the string representation. Defaults to a
            copy of `dist._repr_args`. A supplied sequence is copied, not
            modified in place; `scale` and `shift` entries are appended
            to the copy.

    """

    def __init__(
            self,
            dist,
            shift=0,
            scale=1,
            rotation=None,
            repr_args=None,
    ):
        assert isinstance(dist, Distribution), "'dist' should be a distribution"
        if repr_args is None:
            repr_args = dist._repr_args
        # Always work on a private copy so that neither `dist._repr_args`
        # nor a caller-supplied list is extended in place below.
        repr_args = list(repr_args)
        repr_args += chaospy.format_repr_kwargs(scale=(scale, 1))
        repr_args += chaospy.format_repr_kwargs(shift=(shift, 0))
        length = len(dist) if len(dist) > 1 else None
        dependencies, parameters, rotation = chaospy.declare_dependencies(
            distribution=self,
            parameters=dict(shift=shift, scale=scale),
            rotation=rotation,
            is_operator=True,
            length=length,
            extra_parameters=dict(dist=dist),
        )
        super(ShiftScaleDistribution, self).__init__(
            parameters=parameters,
            rotation=rotation,
            dependencies=dependencies,
            repr_args=repr_args,
        )
        self._dist = dist
        # Permutation matrix: row `i` has a single one in column
        # `self._rotation[i]`.
        permute = numpy.zeros((len(self._rotation), len(self._rotation)), dtype=int)
        permute[numpy.arange(len(self._rotation), dtype=int), self._rotation] = 1
        self._permute = permute

    def get_parameters(self, idx, cache, assert_numerical=True):
        """
        Resolve the `shift` and `scale` parameters.

        Args:
            idx (Optional[int]):
                Dimension to extract parameters for. When None,
                non-distribution parameters are returned unindexed.
            cache (dict):
                Cache passed on to distribution-valued parameters.
            assert_numerical (bool):
                Assert that neither parameter is still a distribution.

        Returns:
            dict: Keyword arguments `idx`, `dist`, `shift`, `scale` and
            `cache` as consumed by `_ppf`, `_cdf`, `_pdf`, `_lower` and
            `_upper`.

        Raises:
            AssertionError:
                If a parameter is unresolved, or if any `scale` entry is
                not strictly positive.

        """
        shift = self._parameters["shift"]
        if isinstance(shift, Distribution):
            shift = shift._get_cache(idx, cache, get=0)
        elif idx is not None and len(shift) > 1:
            shift = shift[idx]
        assert not isinstance(shift, Distribution), shift

        scale = self._parameters["scale"]
        if isinstance(scale, Distribution):
            scale = scale._get_cache(idx, cache, get=0)
        elif idx is not None and len(scale) > 1:
            scale = scale[idx]
        assert not isinstance(scale, Distribution), scale
        # Checked elementwise: a negative scale would make `_cdf` decreasing
        # and `_pdf` negative, so only strictly positive values are valid.
        assert numpy.all(numpy.asarray(scale) > 0), (
            "condition not satisfied: `scale > 0`")

        # NOTE(review): already implied by the two unconditional asserts above.
        assert not assert_numerical or not (isinstance(shift, Distribution) or
                                            isinstance(scale, Distribution))

        return dict(idx=idx, dist=self._dist, shift=shift, scale=scale, cache=cache)

    def _ppf(self, qloc, idx, dist, shift, scale, cache):
        """Inverse CDF: invert `dist`, then apply `scale` and `shift`."""
        return dist._get_inv(qloc, idx, cache=cache)*scale+shift

    def _cdf(self, xloc, idx, dist, shift, scale, cache):
        """CDF: undo `shift` and `scale`, then evaluate `dist`'s CDF."""
        return dist._get_fwd((xloc-shift)/scale, idx, cache=cache)

    def _pdf(self, xloc, idx, dist, shift, scale, cache):
        """Density: `dist`'s density at the untransformed point, over `scale`."""
        # Division by `scale` is the change-of-variables Jacobian; valid
        # because `scale` is required to be positive.
        return dist._get_pdf((xloc-shift)/scale, idx, cache=cache)/scale

    def get_mom_parameters(self):
        """Parameters for `_mom`: `get_parameters` without `idx` and `cache`."""
        parameters = self.get_parameters(
            idx=None, cache={}, assert_numerical=False)
        del parameters["idx"]
        del parameters["cache"]
        return parameters

    def _mom(self, kloc, dist, shift, scale):
        """
        Raw moment of order `kloc`.

        Expands the transformed variable as a polynomial in the variables
        of `dist`. The result is the sum of `dist`'s raw moments weighted
        by the polynomial's coefficients.

        """
        poly = numpoly.variable(len(self))
        poly = numpoly.sum(scale*poly, axis=-1)+shift
        poly = numpoly.set_dimensions(numpoly.prod(poly**kloc), len(self))
        return sum(dist._get_mom(key)*coeff
                   for key, coeff in zip(poly.exponents, poly.coefficients))

    def get_ttr_parameters(self, idx):
        """Parameters for `_ttr`: `get_parameters` for `idx` without `cache`."""
        parameters = self.get_parameters(
            idx=idx, cache={}, assert_numerical=False)
        del parameters["cache"]
        return parameters

    def _ttr(self, kloc, idx, dist, shift, scale):
        """
        Transform the recurrence coefficients of `dist`.

        These are presumably the three-terms-recurrence coefficients.
        `coeff0` is scaled and shifted; `coeff1` is scaled by `scale`
        squared.

        """
        coeff0, coeff1 = dist._get_ttr(kloc, idx)
        coeff0 = coeff0*scale+shift
        coeff1 = coeff1*scale*scale
        return coeff0, coeff1

    def _lower(self, idx, dist, shift, scale, cache):
        """Lower bound of `dist` mapped through `scale` and `shift`."""
        return dist._get_lower(idx, cache=cache)*scale+shift

    def _upper(self, idx, dist, shift, scale, cache):
        """Upper bound of `dist` mapped through `scale` and `shift`."""
        return dist._get_upper(idx, cache=cache)*scale+shift
129