1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4"""
5Distribution class and associated machinery.
6"""
7import builtins
8
9import numpy as np
10
11from astropy import units as u
12from astropy import stats
13
# Names exported by a star-import of this module.
__all__ = ['Distribution']


# Factor that rescales a median absolute deviation so that it matches the
# standard deviation for a normal distribution.
# We set this by hand because the symbolic expression (below) requires scipy:
# SMAD_SCALE_FACTOR = 1 / scipy.stats.norm.ppf(0.75)
SMAD_SCALE_FACTOR = 1.48260221850560203193936104071326553821563720703125
20
21
class Distribution:
    """
    A scalar value or array values with associated uncertainty distribution.

    The exact type of the resulting object follows from the type of the
    ``samples`` argument.  Usually that is an `~astropy.units.Quantity` or a
    `numpy.ndarray`, but anything that `numpy.asanyarray` accepts is possible.

    See also: https://docs.astropy.org/en/stable/uncertainty/

    Parameters
    ----------
    samples : array-like
        The distribution, with sampling along the *trailing* axis. If 1D, the
        sole dimension is used as the sampling axis (i.e., it is a scalar
        distribution).

    Raises
    ------
    TypeError
        If ``samples`` is a scalar (has an empty shape).
    """
    # Maps a samples class to the Distribution subclass generated for it.
    _generated_subclasses = {}

    def __new__(cls, samples):
        # An existing Distribution is unwrapped to its samples; anything else
        # is coerced to an array (subclasses are passed through).
        if not isinstance(samples, Distribution):
            samples = np.asanyarray(samples, order='C')
        else:
            samples = samples.distribution
        if samples.shape == ():
            raise TypeError('Attempted to initialize a Distribution with a scalar')

        # Structured dtype with a single field holding all samples of one
        # element as a sub-array.
        n_samples = samples.shape[-1]
        struct_dtype = np.dtype({'names': ['samples'],
                                 'formats': [(samples.dtype, (n_samples,))]})

        base_cls = type(samples)
        registry = cls._generated_subclasses
        dist_cls = registry.get(base_cls)
        if dist_cls is None:
            # Make a new class with the combined name, inserting Distribution
            # itself below the samples class since that way Quantity methods
            # like ".to" just work (as .view() gets intercepted).  However,
            # repr and str are problems, so we put those on top.
            # TODO: try to deal with this at the lower level.  The problem is
            # that array2string does not allow one to override how structured
            # arrays are typeset, leading to all samples to be shown.  It may
            # be possible to hack oneself out by temporarily becoming a void.
            dist_cls = type(
                base_cls.__name__ + cls.__name__,
                (_DistributionRepr, base_cls, ArrayDistribution),
                {'_samples_cls': base_cls})
            registry[base_cls] = dist_cls

        instance = samples.view(dtype=struct_dtype, type=dist_cls)
        # Get rid of trailing dimension of 1.
        instance.shape = samples.shape[:-1]
        return instance

    @property
    def distribution(self):
        """The ``'samples'`` field of this object."""
        return self['samples']

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the samples of any Distribution operands.

        Distribution inputs and outputs are replaced by their samples, other
        inputs that have a non-empty shape get a new trailing axis, and any
        result with a non-empty shape that was not written to a given output
        is wrapped in a new Distribution.
        """
        outputs = kwargs.pop('out', None)
        if outputs:
            kwargs['out'] = tuple(
                out.distribution if isinstance(out, Distribution) else out
                for out in outputs)

        if method in {'reduce', 'accumulate', 'reduceat'}:
            if kwargs.get('axis', None) is None:
                # Without an explicit axis, act on all axes of the
                # Distribution, i.e. all but the trailing samples axis.
                assert isinstance(inputs[0], Distribution)
                kwargs['axis'] = tuple(range(inputs[0].ndim))

        def _as_samples(operand):
            if isinstance(operand, Distribution):
                return operand.distribution
            if getattr(operand, 'shape', ()):
                # Add a trailing axis to line up with the samples axis.
                return operand[..., np.newaxis]
            return operand

        converted = [_as_samples(operand) for operand in inputs]

        results = getattr(ufunc, method)(*converted, **kwargs)

        if not isinstance(results, tuple):
            results = (results,)
        if outputs is None:
            outputs = (None,) * len(results)

        finals = []
        for result, output in zip(results, outputs):
            if output is not None:
                final = output
            elif getattr(result, 'shape', False):
                final = Distribution(result)
            else:
                final = result
            finals.append(final)

        if len(finals) > 1:
            return finals
        return finals[0]

    @property
    def n_samples(self):
        """
        The number of samples of this distribution.  A single `int`.
        """
        samples_field = self.dtype['samples']
        return samples_field.shape[0]

    def pdf_mean(self, dtype=None, out=None):
        """
        The mean of this distribution, taken over the samples axis.

        Arguments are as for `numpy.mean`.
        """
        samples = self.distribution
        return samples.mean(axis=-1, dtype=dtype, out=out)

    def pdf_std(self, dtype=None, out=None, ddof=0):
        """
        The standard deviation of this distribution, taken over the samples
        axis.

        Arguments are as for `numpy.std`.
        """
        samples = self.distribution
        return samples.std(axis=-1, dtype=dtype, out=out, ddof=ddof)

    def pdf_var(self, dtype=None, out=None, ddof=0):
        """
        The variance of this distribution, taken over the samples axis.

        Arguments are as for `numpy.var`.
        """
        samples = self.distribution
        return samples.var(axis=-1, dtype=dtype, out=out, ddof=ddof)

    def pdf_median(self, out=None):
        """
        The median of this distribution, taken over the samples axis.

        Parameters
        ----------
        out : array, optional
            Alternative output array in which to place the result. It must
            have the same shape and buffer length as the expected output,
            but the type (of the output) will be cast if necessary.
        """
        samples = self.distribution
        return np.median(samples, axis=-1, out=out)

    def pdf_mad(self, out=None):
        """
        The median absolute deviation of this distribution.

        Parameters
        ----------
        out : array, optional
            Alternative output array in which to place the result. It must
            have the same shape and buffer length as the expected output,
            but the type (of the output) will be cast if necessary.
        """
        center = self.pdf_median(out=out)
        deviations = np.abs(self - center)
        # The median itself is no longer needed, so it is reused as output.
        return np.median(deviations.distribution, axis=-1, out=center,
                         overwrite_input=True)

    def pdf_smad(self, out=None):
        """
        The median absolute deviation of this distribution rescaled to match the
        standard deviation for a normal distribution.

        Parameters
        ----------
        out : array, optional
            Alternative output array in which to place the result. It must
            have the same shape and buffer length as the expected output,
            but the type (of the output) will be cast if necessary.
        """
        scaled = self.pdf_mad(out=out)
        scaled *= SMAD_SCALE_FACTOR
        return scaled

    def pdf_percentiles(self, percentile, **kwargs):
        """
        Compute percentiles of this Distribution.

        Parameters
        ----------
        percentile : float or array of float or `~astropy.units.Quantity`
            The desired percentiles of the distribution (i.e., on [0,100]).
            `~astropy.units.Quantity` will be converted to percent, meaning
            that a ``dimensionless_unscaled`` `~astropy.units.Quantity` will
            be interpreted as a quantile.

        Additional keywords are passed into `numpy.percentile`.

        Returns
        -------
        percentiles : array
            The requested percentiles, computed over the samples axis.  If
            the samples provide a ``_new_view`` method, the result is passed
            through it before being returned.
        """
        percentile = u.Quantity(percentile, u.percent).value
        samples = self.distribution
        percs = np.percentile(samples, percentile, axis=-1, **kwargs)
        # numpy.percentile strips units for unclear reasons, so we have to make
        # a new object with units
        if not hasattr(samples, '_new_view'):
            return percs
        return samples._new_view(percs)

    def pdf_histogram(self, **kwargs):
        """
        Compute histogram over the samples in the distribution.

        Parameters
        ----------
        **kwargs
            All keyword arguments are passed into `astropy.stats.histogram`.
            Note that some of these options may not be valid for some
            multidimensional distributions.

        Returns
        -------
        hist : array
            The values of the histogram. Trailing dimension is the histogram
            dimension.
        bin_edges : array of dtype float
            Return the bin edges ``(length(hist)+1)``. Trailing dimension is the
            bin histogram dimension.
        """
        samples = self.distribution
        per_element = samples.shape[-1]
        # One row per element of the distribution, one column per sample.
        rows = samples.reshape(samples.size//per_element, per_element)

        histograms = [stats.histogram(row, **kwargs) for row in rows]

        counts = np.array([counts_ for counts_, _ in histograms])
        edges = np.array([edges_ for _, edges_ in histograms])

        counts_shape = self.shape + (counts.size//self.size,)
        edges_shape = self.shape + (edges.size//self.size,)
        return counts.reshape(counts_shape), edges.reshape(edges_shape)
259
260
class ScalarDistribution(Distribution, np.void):
    """Scalar distribution.

    A scalar element, still with n_samples samples.  This class mostly
    exists to make printing of arrays possible for all subclasses.
    """
268
269
class ArrayDistribution(Distribution, np.ndarray):
    """Array-valued distribution.

    Provides the overrides of ``view`` and ``__getitem__`` that are needed
    for all ndarray subclass Distributions, but not for the scalar one.
    """
    # Class used to present the 'samples' field; generated subclasses
    # replace this with the class of their samples.
    _samples_cls = np.ndarray

    # Override view so that we stay a Distribution version of the new type.
    def view(self, dtype=None, type=None):
        """New view of array with the same data.

        Like `~numpy.ndarray.view` except that the result will always be a new
        `~astropy.uncertainty.Distribution` instance.  If the requested
        ``type`` is a `~astropy.uncertainty.Distribution`, then no change in
        ``dtype`` is allowed.

        Parameters
        ----------
        dtype : data-type or ndarray subclass, optional
            Data-type of the view.  If an `~numpy.ndarray` subclass is passed
            here and ``type`` is not given, it is taken to be ``type``.
        type : type, optional
            Class of the view.

        Raises
        ------
        ValueError
            If the view is to remain a Distribution (no ``type`` given, or
            ``type`` is a Distribution subclass) and ``dtype`` differs from
            the current dtype.
        """
        # The ``type`` argument shadows the builtin, hence ``builtins.type``.
        if type is None and (isinstance(dtype, builtins.type)
                             and issubclass(dtype, np.ndarray)):
            type = dtype
            dtype = None

        view_args = [item for item in (dtype, type) if item is not None]

        if type is None or (isinstance(type, builtins.type)
                            and issubclass(type, Distribution)):
            if dtype is not None and dtype != self.dtype:
                raise ValueError('cannot view as Distribution subclass with a new dtype.')
            return super().view(*view_args)

        # View as the new non-Distribution class, but turn into a Distribution again.
        result = self.distribution.view(*view_args)
        return Distribution(result)

    # Override __getitem__ so that 'samples' is returned as the sample class.
    def __getitem__(self, item):
        """Get an item, returning ``'samples'`` as the samples class.

        Single elements (`numpy.void` results) are viewed as
        `ScalarDistribution`; any other result is returned unchanged.
        """
        result = super().__getitem__(item)
        # Only compare strings with the field name: for an array index,
        # ``item == 'samples'`` can be evaluated elementwise, which makes the
        # truth value of the comparison ambiguous.
        if isinstance(item, str) and item == 'samples':
            # Here, we need to avoid our own redefinition of view.
            return super(ArrayDistribution, result).view(self._samples_cls)
        elif isinstance(result, np.void):
            return result.view((ScalarDistribution, result.dtype))
        else:
            return result
313
314
class _DistributionRepr:
    """Mixin providing string representations that include ``n_samples``.

    The methods rely on the ``distribution`` and ``n_samples`` attributes of
    the class this is mixed into.
    """

    def __repr__(self):
        """Return the repr of the samples, relabelled with this class's name.

        A repr ending in ``'>'`` has everything up to its first space
        replaced by ``'<ClassName'``; any other repr has everything before
        its first ``'('`` replaced by the class name.  In both cases the
        number of samples is added.
        """
        reprarr = repr(self.distribution)
        cls_name = self.__class__.__name__
        if reprarr.endswith('>'):
            firstspace = reprarr.find(' ')
            reprarr = reprarr[firstspace+1:-1]  # :-1] removes the ending '>'
            return f'<{cls_name} {reprarr} with n_samples={self.n_samples}>'

        # numpy array-like
        firstparen = reprarr.find('(')
        reprarr = reprarr[firstparen:]
        return f'{cls_name}{reprarr} with n_samples={self.n_samples}'

    def __str__(self):
        """Return the str of the samples followed by the number of samples."""
        distrstr = str(self.distribution)
        toadd = f' with n_samples={self.n_samples}'
        return distrstr + toadd

    def _repr_latex_(self):
        """Return the LaTeX repr of the samples with the sample count added.

        The count is inserted just before the final character of the
        samples' own LaTeX representation.  Returns `None` if the samples do
        not provide ``_repr_latex_``.
        """
        if hasattr(self.distribution, '_repr_latex_'):
            superlatex = self.distribution._repr_latex_()
            toadd = fr', \; n_{{\rm samp}}={self.n_samples}'
            return superlatex[:-1] + toadd + superlatex[-1]
        else:
            return None
341
342
class NdarrayDistribution(_DistributionRepr, ArrayDistribution):
    """Distribution whose samples are presented as `numpy.ndarray`."""


# Register the base NdarrayDistribution as the class to use for plain
# `numpy.ndarray` samples.
Distribution._generated_subclasses[np.ndarray] = NdarrayDistribution
349