1"""Smolyak sparse grid constructor."""
2from collections import defaultdict
3from itertools import product
4
5import numpy
6from scipy.special import comb
7
8import numpoly
9import chaospy
10
11
def sparse_grid(
        order,
        dist,
        growth=None,
        recurrence_algorithm="stieltjes",
        rule="gaussian",
        tolerance=1e-10,
        scaling=3,
        n_max=5000,
):
    """
    Smolyak sparse grid constructor.

    Args:
        order (int, Sequence[int], numpy.ndarray):
            The order of the grid. A scalar gives every dimension the same
            order. An array-like with one entry per dimension gives each
            dimension its own maximum order (anisotropic grid).
        dist (chaospy.distributions.baseclass.Distribution):
            The distribution which density will be used as weight function.
            Anything that is not already a ``chaospy.J`` or ``chaospy.Iid``
            is wrapped in ``chaospy.J``.
        growth (bool, None):
            If True sets the growth rule for the quadrature rule to only
            include orders that enhances nested samples. Defaults to True if
            omitted.
        recurrence_algorithm (str):
            Name of the algorithm used to generate abscissas and weights in
            case of Gaussian quadrature scheme. Forwarded unchanged to
            ``chaospy.generate_quadrature``.
        rule (str, Sequence[str]):
            Rule for generating abscissas and weights. Either done with
            quadrature rules, or with random samples with constant weights.
            A single string is used for every dimension; a sequence must
            provide exactly one rule per dimension.
        tolerance (float):
            The allowed relative error in norm between two quadrature orders
            before method assumes convergence. Forwarded to
            ``chaospy.generate_quadrature``.
        scaling (float):
            A multiplier the adaptive order increases with for each step
            quadrature order is not converged. Use 0 to indicate unit
            increments. Forwarded to ``chaospy.generate_quadrature``.
        n_max (int):
            The allowed number of quadrature points to use in approximation.
            Forwarded to ``chaospy.generate_quadrature``.

    Returns:
        (numpy.ndarray, numpy.ndarray):
            Abscissas and weights created from sparse grid rule. Flatten such
            that ``abscissas.shape == (len(dist), len(weights))``. Points are
            sorted lexicographically.

    Raises:
        AssertionError:
            If ``dist`` is not a ``chaospy.Distribution``.
        ValueError:
            If ``rule`` is a sequence whose length differs from the number of
            dimensions of ``dist``.

    Example:
        >>> distribution = chaospy.J(chaospy.Normal(0, 1), chaospy.Uniform(-1, 1))
        >>> abscissas, weights = chaospy.quadrature.sparse_grid(1, distribution)
        >>> abscissas.round(4)
        array([[-1.    ,  0.    ,  0.    ,  0.    ,  1.    ],
               [ 0.    , -0.5774,  0.    ,  0.5774,  0.    ]])
        >>> weights.round(4)
        array([ 0.5,  0.5, -1. ,  0.5,  0.5])
        >>> abscissas, weights = chaospy.quadrature.sparse_grid([2, 1], distribution)
        >>> abscissas.round(2)
        array([[-1.73, -1.  , -1.  , -1.  ,  0.  ,  1.  ,  1.  ,  1.  ,  1.73],
               [ 0.  , -0.58,  0.  ,  0.58,  0.  , -0.58,  0.  ,  0.58,  0.  ]])
        >>> weights.round(2)
        array([ 0.17,  0.25, -0.5 ,  0.25,  0.67,  0.25, -0.5 ,  0.25,  0.17])
    """
    # Validate before anything calls len(dist), so a wrong type reports the
    # intended message instead of an unrelated error.
    assert isinstance(dist, chaospy.Distribution), "dist must be chaospy.Distribution"

    # Broadcast a scalar (or per-dimension) order to one entry per dimension.
    orders = order*numpy.ones(len(dist), dtype=int)
    growth = True if growth is None else growth
    dist = dist if isinstance(dist, (chaospy.J, chaospy.Iid)) else chaospy.J(dist)

    if isinstance(rule, str):
        rule = (rule,)*len(dist)
    rule = tuple(rule)
    # zip() in the helpers would silently drop dimensions on a short list,
    # producing lower-dimensional nodes; reject the mismatch up front.
    if len(rule) != len(dist):
        raise ValueError(
            "rule must be a string or have one entry per dimension: "
            "got %d rules for %d dimensions" % (len(rule), len(dist)))

    x_lookup, w_lookup = _construct_lookup(
        orders=orders,
        dists=dist,
        growth=growth,
        recurrence_algorithm=recurrence_algorithm,
        rules=rule,
        tolerance=tolerance,
        scaling=scaling,
        n_max=n_max,
    )
    # Pass the normalised per-dimension orders (not the raw argument) so the
    # skew is always an integer array of length len(dist).
    collection = _construct_collection(
        orders, dist, x_lookup, w_lookup)

    abscissas = sorted(collection)
    weights = numpy.array([collection[key] for key in abscissas])
    abscissas = numpy.array(abscissas).T
    return abscissas, weights
98
99
def _construct_collection(
        orders,
        dist,
        x_lookup,
        w_lookup,
):
    """
    Combine 1-D quadrature rules into a Smolyak {abscissa: weight} mapping.

    The smallest entry of ``orders`` is the isotropic base level ``q``; the
    excess ``orders - q`` is added to every multi-index as a per-dimension
    offset. Each multi-index ``k`` contributes the tensor product of the
    1-D rules it selects, scaled by the Smolyak combination coefficient
    ``(-1)**(q-|k|) * comb(d-1, q-|k|)``.

    Args:
        orders (int, numpy.ndarray):
            Per-dimension maximum order (or a scalar order).
        dist (chaospy.Distribution):
            Used only for its number of dimensions ``d = len(dist)``.
        x_lookup (list):
            Abscissas indexed as ``x_lookup[dim][level]``.
        w_lookup (list):
            Weights indexed as ``w_lookup[dim][level]``.

    Returns:
        (collections.defaultdict):
            Mapping from abscissa tuples to accumulated weights. Tensor grids
            that hit exactly the same point have their weights summed.
    """
    n_dims = len(dist)
    base = numpy.min(orders)
    offsets = orders-base

    # Presumably every multi-index with total order in [q-d+1, q]
    # (numpoly.glexindex start/stop semantics -- confirm against numpoly).
    multi_indices = numpoly.glexindex(
        base-n_dims+1, base+1, dimensions=n_dims)
    gaps = base-numpy.sum(multi_indices, -1)
    # +1 when the gap q-|k| is even, -1 when odd.
    signs = 2*((gaps+1) % 2)-1
    coefficients = signs*comb(n_dims-1, gaps)

    collection = defaultdict(float)
    for row, coefficient in zip(multi_indices+offsets, coefficients.tolist()):
        nodes_1d = [table[pos] for pos, table in zip(row, x_lookup)]
        weights_1d = [table[pos] for pos, table in zip(row, w_lookup)]
        for node, weight_tuple in zip(product(*nodes_1d), product(*weights_1d)):
            collection[node] += numpy.prod(weight_tuple)*coefficient
    return collection
124
125
def _construct_lookup(
        orders,
        dists,
        growth,
        recurrence_algorithm,
        rules,
        tolerance,
        scaling,
        n_max,
):
    """
    Precompute the 1-D abscissas and weights for every dimension and level.

    For each dimension ``i``, ``chaospy.generate_quadrature`` is called once per
    level ``0..orders[i]``, so the values need not be recalculated later.

    Returns:
        (list, list):
            ``x_lookup[i][j]`` and ``w_lookup[i][j]`` hold the abscissas and
            weights of the level-``j`` rule in dimension ``i``.
    """
    def _quadrature_1d(level, marginal, marginal_rule):
        # The returned abscissas must have exactly one row (a 1-D rule).
        (nodes,), node_weights = chaospy.generate_quadrature(
            order=level,
            dist=marginal,
            growth=growth,
            recurrence_algorithm=recurrence_algorithm,
            rule=marginal_rule,
            tolerance=tolerance,
            scaling=scaling,
            n_max=n_max,
        )
        return nodes, node_weights

    x_lookup = []
    w_lookup = []
    for top_level, marginal, marginal_rule in zip(orders, dists, rules):
        pairs = [
            _quadrature_1d(level, marginal, marginal_rule)
            for level in range(top_level+1)
        ]
        x_lookup.append([nodes for nodes, _ in pairs])
        w_lookup.append([node_weights for _, node_weights in pairs])
    return x_lookup, w_lookup
159