1"""Multivariate Normal Distribution."""
2import logging
3import numpy
4from scipy import special
5import chaospy
6
7from .normal import normal
8from ..baseclass import MeanCovarianceDistribution
9
10
class MvNormal(MeanCovarianceDistribution):
    r"""
    Multivariate Normal Distribution.

    Args:
        mu (float, numpy.ndarray):
            Mean vector
        sigma (float, numpy.ndarray):
            Covariance matrix or variance vector if sigma is a 1-d vector.
        rotation (Sequence[int]):
            The order in which to resolve conditionals.
            Defaults to `range(len(distribution))` which is the same as
            `p(x0), p(x1|x0), p(x2|x0,x1), ...`.

    Examples:
        >>> distribution = chaospy.MvNormal([10, 20, 30],
        ...     [[1, 0.2, 0.3], [0.2, 2, 0.4], [0.3, 0.4, 1]], rotation=[1, 2, 0])
        >>> distribution  # doctest: +NORMALIZE_WHITESPACE
        MvNormal(mu=[10, 20, 30],
                 sigma=[[1, 0.2, 0.3], [0.2, 2, 0.4], [0.3, 0.4, 1]])
        >>> chaospy.E(distribution)
        array([10., 20., 30.])
        >>> chaospy.Cov(distribution)
        array([[1. , 0.2, 0.3],
               [0.2, 2. , 0.4],
               [0.3, 0.4, 1. ]])
        >>> mesh = numpy.mgrid[:2, :2, :2].reshape(3, -1)*.5+.1
        >>> mesh
        array([[0.1, 0.1, 0.1, 0.1, 0.6, 0.6, 0.6, 0.6],
               [0.1, 0.1, 0.6, 0.6, 0.1, 0.1, 0.6, 0.6],
               [0.1, 0.6, 0.1, 0.6, 0.1, 0.6, 0.1, 0.6]])
        >>> mapped_samples = distribution.inv(mesh)
        >>> mapped_samples.round(2)
        array([[ 8.25,  8.67,  8.47,  8.88,  9.71, 10.13,  9.93, 10.35],
               [18.19, 18.19, 20.36, 20.36, 18.19, 18.19, 20.36, 20.36],
               [28.41, 29.88, 28.84, 30.31, 28.41, 29.88, 28.84, 30.31]])
        >>> numpy.allclose(distribution.fwd(mapped_samples), mesh)
        True
        >>> distribution.pdf(mapped_samples).round(4)
        array([0.0042, 0.0092, 0.0092, 0.0203, 0.0092, 0.0203, 0.0203, 0.0446])
        >>> distribution.sample(4).round(4)
        array([[10.3396,  9.0158, 11.1009, 10.0971],
               [21.6096, 18.871 , 17.5357, 19.6314],
               [29.6231, 30.7349, 28.7239, 30.5507]])

    """

    def __init__(
            self,
            mu,
            sigma=None,
            rotation=None,
    ):
        # Only ``mu`` and ``sigma`` go into the representation; ``rotation``
        # is forwarded to the base class but omitted from ``repr``.
        super(MvNormal, self).__init__(
            dist=normal(),
            mean=mu,
            covariance=sigma,
            rotation=rotation,
            repr_args=chaospy.format_repr_kwargs(mu=(mu, None))+
                      chaospy.format_repr_kwargs(sigma=(sigma, None)),
        )

    def _mom(self, k, mean, sigma, cache):
        """
        Raw statistical moment of order ``k``.

        Expands the raw moment binomially around the mean::

            E[prod_i X_i**k_i] = sum_j prod_i comb(k_i, j_i)
                                 * mean_i**(k_i-j_i)
                                 * E[prod_i (X_i-mean_i)**j_i]

        The central moments on the right-hand side come from
        ``isserlis_moment``.

        Args:
            k (numpy.ndarray):
                Moment orders, one entry per dimension.
            mean (numpy.ndarray, chaospy.Distribution):
                Mean vector, or a distribution whose cached value is used
                as the mean.
            sigma (numpy.ndarray):
                Covariance matrix.
            cache (dict):
                Cache used to resolve ``mean`` when it is a distribution.

        Returns:
            The raw moment of order ``k``.

        Raises:
            chaospy.UnsupportedFeature:
                If ``mean`` still resolves to a distribution after the cache
                lookup, i.e. the mean is conditional.
        """
        if isinstance(mean, chaospy.Distribution):
            mean = mean._get_cache(None, cache=cache, get=0)
            if isinstance(mean, chaospy.Distribution):
                raise chaospy.UnsupportedFeature(
                    "Analytical moment of a conditional not supported")
        out = 0.
        # ``kdx`` runs over every multi-index with 0 <= kdx_i <= k_i, so the
        # exponent ``k - kdx`` is never negative and needs no masking.
        for kdx in numpy.ndindex(*[order+1 for order in k]):
            coef = numpy.prod(special.comb(k.T, kdx).T, 0)
            location_ = numpy.prod(mean**(k.T-kdx), -1)
            out = out+coef*location_*isserlis_moment(tuple(kdx), sigma)

        return out
91
92
def isserlis_moment(k, scale):
    """
    Centralized statistical moments using Isserlis' theorem.

    Computes ``E[prod_i Y_i**k_i]`` for a zero-mean multivariate normal
    ``Y`` with covariance ``scale``. The recursion pairs the first non-zero
    order with every remaining factor. Intermediate results are memoized on
    the order tuple, so the work is bounded by the number of distinct
    sub-orders (at most ``prod(k_i+1)``) instead of growing exponentially
    with ``sum(k)``.

    Args:
        k (Tuple[int, ...]):
            Moment orders, one entry per dimension.
        scale (numpy.ndarray):
            Covariance matrix with shape ``(d, d)``, or a stack of covariance
            matrices with shape ``(n, d, d)``. Array-likes are converted
            with ``numpy.asarray``.

    Returns:
        Central statistical moment of order ``k`` given covariance ``scale``.
        This is a scalar for a 2-D ``scale``, or an array with one moment per
        matrix for a 3-D ``scale``.

    Raises:
        ValueError:
            If ``k`` is not one-dimensional.

    Examples:
        >>> scale = 0.5*numpy.eye(3)+0.5
        >>> isserlis_moment((2, 2, 2), scale)
        3.5
        >>> isserlis_moment((0, 0, 0), scale)
        1.0
        >>> isserlis_moment((1, 0, 0), scale)
        0.0
        >>> isserlis_moment((0, 1, 1), scale)
        0.5
        >>> isserlis_moment((0, 0, 2), scale)
        1.0
    """
    scale = numpy.asarray(scale)
    if scale.ndim == 2:
        # Treat a single matrix as a batch of one and unwrap the result.
        return isserlis_moment(k, scale[numpy.newaxis])[0]

    k = numpy.asarray(k)
    if k.ndim != 1:
        raise ValueError(
            "moment orders `k` must be one-dimensional, got shape %s"
            % (k.shape,))

    size = len(scale)
    eye = numpy.eye(len(k), dtype=int)
    cache = {}

    def _moment(orders):
        """Memoized Isserlis recursion for the order vector ``orders``."""
        key = tuple(orders.tolist())
        if key in cache:
            return cache[key]

        if (numpy.sum(orders) % 2 == 1) or numpy.any(orders < 0):
            # An odd total order, or an order pushed below zero by the
            # recursion, has a vanishing moment.
            result = numpy.zeros(size)
        else:
            pivots = numpy.nonzero(orders)[0]
            if not pivots.size:
                # All orders are zero: the moment of a constant 1.
                result = numpy.ones(size)
            else:
                idx = pivots[0]
                # Pair the pivot with one of its own (orders[idx]-1) copies.
                # The product is a fresh array, so in-place ``+=`` below never
                # touches a cached sub-result.
                result = (orders[idx]-1)*scale[:, idx, idx]*_moment(
                    orders-2*eye[idx])
                # Or pair it with one of the orders[idj] copies of a later
                # variable. Earlier variables have order zero because idx is
                # the first non-zero entry.
                for idj in range(idx+1, len(orders)):
                    result += orders[idj]*scale[:, idx, idj]*_moment(
                        orders-eye[idx]-eye[idj])

        cache[key] = result
        return result

    return _moment(k)
143