"""Shift-Scale transformation"""
import numpy
from scipy.special import comb
import numpoly
import chaospy

from .distribution import Distribution


class ShiftScaleDistribution(Distribution):
    """
    Shift-Scale transformation.

    Linear transforms any distribution of the form `A*X+b` where A is a
    scaling matrix and `b` is a shift vector.

    Args:
        dist (Distribution):
            The underlying distribution to be scaled.
        shift (float, Sequence[float], Distribution):
            Additive offset `b` applied after scaling.
        scale (float, Sequence[float], Distribution):
            Multiplicative factor `A` applied to `dist`. Must be strictly
            positive. If omitted, assumed to be 1.
        rotation (Sequence[int], Sequence[Sequence[bool]]):
            The order of which to resolve conditionals. Either as a sequence of
            column rotations, or as a permutation matrix.
            Defaults to `range(len(distribution))` which is the same as
            `p(x0), p(x1|x0), p(x2|x0,x1), ...`.
        repr_args (Optional[Sequence]):
            Arguments used for the string representation. If omitted, a copy
            of `dist._repr_args` is used. The entries produced by
            `chaospy.format_repr_kwargs` for `scale` and `shift` are appended.

    """

    def __init__(
        self,
        dist,
        shift=0,
        scale=1,
        rotation=None,
        repr_args=None,
    ):
        assert isinstance(dist, Distribution), "'dist' should be a distribution"
        if repr_args is None:
            repr_args = dist._repr_args[:]
        # NOTE(review): `+=` extends in place, so a caller-supplied list is
        # presumably mutated here -- confirm callers do not reuse it.
        repr_args += chaospy.format_repr_kwargs(scale=(scale, 1))
        repr_args += chaospy.format_repr_kwargs(shift=(shift, 0))
        # Only pass an explicit length on for multivariate distributions.
        length = len(dist) if len(dist) > 1 else None
        dependencies, parameters, rotation = chaospy.declare_dependencies(
            distribution=self,
            parameters=dict(shift=shift, scale=scale),
            rotation=rotation,
            is_operator=True,
            length=length,
            extra_parameters=dict(dist=dist),
        )
        super(ShiftScaleDistribution, self).__init__(
            parameters=parameters,
            rotation=rotation,
            dependencies=dependencies,
            repr_args=repr_args,
        )
        self._dist = dist
        # Square 0/1 matrix where row ``i`` has its single 1 in column
        # ``self._rotation[i]``.
        size = len(self._rotation)
        permute = numpy.zeros((size, size), dtype=int)
        permute[numpy.arange(size, dtype=int), self._rotation] = 1
        self._permute = permute

    def get_parameters(self, idx, cache, assert_numerical=True):
        """
        Resolve `shift` and `scale` into the values used for evaluation.

        A parameter that is a distribution is looked up through
        `_get_cache(idx, cache, get=0)`; otherwise, when `idx` is given and
        the parameter holds more than one element, element `idx` is selected.

        Args:
            idx (Optional[int]):
                Index of the dimension being evaluated, or `None`.
            cache (Dict):
                Evaluation cache forwarded to the parameter look-ups.
            assert_numerical (bool):
                If true, assert that neither resolved parameter is a
                distribution. The unconditional assertions in this method
                already enforce the same condition.

        Returns:
            (Dict): The keys `idx`, `dist`, `shift`, `scale` and `cache`.

        Raises:
            AssertionError:
                If a resolved parameter is still a distribution, or if any
                element of `scale` is not strictly positive.

        """
        shift = self._parameters["shift"]
        if isinstance(shift, Distribution):
            shift = shift._get_cache(idx, cache, get=0)
        elif idx is not None and len(shift) > 1:
            shift = shift[idx]
        assert not isinstance(shift, Distribution), shift

        scale = self._parameters["scale"]
        if isinstance(scale, Distribution):
            scale = scale._get_cache(idx, cache, get=0)
        elif idx is not None and len(scale) > 1:
            scale = scale[idx]
        assert not isinstance(scale, Distribution), scale
        # Compare element-wise *before* reducing. The former
        # `numpy.all([scale]) > 0` reduced to a bool first and therefore only
        # rejected zeros, letting negative scales through.
        assert numpy.all(numpy.asarray(scale) > 0), (
            "condition not satisfied: `scale > 0`")

        assert not assert_numerical or not (isinstance(shift, Distribution) or
                                            isinstance(scale, Distribution))

        return dict(idx=idx, dist=self._dist, shift=shift, scale=scale, cache=cache)

    def _ppf(self, qloc, idx, dist, shift, scale, cache):
        """Inverse of `dist` at `qloc`, multiplied by `scale` plus `shift`."""
        return dist._get_inv(qloc, idx, cache=cache)*scale+shift

    def _cdf(self, xloc, idx, dist, shift, scale, cache):
        """Forward map of `dist` evaluated at `(xloc-shift)/scale`."""
        return dist._get_fwd((xloc-shift)/scale, idx, cache=cache)

    def _pdf(self, xloc, idx, dist, shift, scale, cache):
        """Density of `dist` at `(xloc-shift)/scale`, divided by `scale`."""
        return dist._get_pdf((xloc-shift)/scale, idx, cache=cache)/scale

    def get_mom_parameters(self):
        """
        Parameters passed to `_mom`.

        Calls `get_parameters` with `idx=None` and a fresh cache, then drops
        the `idx` and `cache` entries.

        Returns:
            (Dict): The keys `dist`, `shift` and `scale`.

        """
        parameters = self.get_parameters(
            idx=None, cache={}, assert_numerical=False)
        del parameters["idx"]
        del parameters["cache"]
        return parameters

    def _mom(self, kloc, dist, shift, scale):
        """
        Raw moment of order `kloc`.

        Expresses the transformed variable as a polynomial in the variables
        of `dist`, raises it to `kloc`, and sums the moments of `dist` for
        each exponent weighted by the matching polynomial coefficient.
        """
        poly = numpoly.variable(len(self))
        poly = numpoly.sum(scale*poly, axis=-1)+shift
        poly = numpoly.set_dimensions(numpoly.prod(poly**kloc), len(self))
        out = sum([dist._get_mom(key)*coeff
                   for key, coeff in zip(poly.exponents, poly.coefficients)])
        return out

    def get_ttr_parameters(self, idx):
        """
        Parameters passed to `_ttr`.

        Calls `get_parameters` for dimension `idx` with a fresh cache, then
        drops the `cache` entry.

        Returns:
            (Dict): The keys `idx`, `dist`, `shift` and `scale`.

        """
        parameters = self.get_parameters(
            idx=idx, cache={}, assert_numerical=False)
        del parameters["cache"]
        return parameters

    def _ttr(self, kloc, idx, dist, shift, scale):
        """
        Three terms recurrence coefficients.

        Takes the pair returned by `dist._get_ttr(kloc, idx)`; the first
        coefficient is multiplied by `scale` and offset by `shift`, the second
        is multiplied by `scale` squared.
        """
        coeff0, coeff1 = dist._get_ttr(kloc, idx)
        coeff0 = coeff0*scale+shift
        coeff1 = coeff1*scale*scale
        return coeff0, coeff1

    def _lower(self, idx, dist, shift, scale, cache):
        """Lower bound of `dist`, multiplied by `scale` plus `shift`."""
        return dist._get_lower(idx, cache=cache)*scale+shift

    def _upper(self, idx, dist, shift, scale, cache):
        """Upper bound of `dist`, multiplied by `scale` plus `shift`."""
        return dist._get_upper(idx, cache=cache)*scale+shift