1"""Gram-Schmidt process for generating orthogonal polynomials."""
2import logging
3
4import numpy
5import numpoly
6import chaospy
7
8
def gram_schmidt(order, dist, normed=False, graded=True, reverse=True,
                 retall=False, cross_truncation=1., **kws):
    """
    Gram-Schmidt process for generating orthogonal polynomials.

    Args:
        order (int, numpoly.ndpoly):
            The upper polynomial order. Alternatively, a custom polynomial
            basis can be used; its terms are orthogonalized in the order
            given.
        dist (Distribution):
            Weighting distribution(s) defining orthogonality. When ``order``
            is an integer, one variable is created per dimension of ``dist``.
        normed (bool):
            If True orthonormal polynomials will be used instead of monic.
        graded (bool):
            Graded sorting, meaning the indices are always sorted by the index
            sum. E.g. ``q0**2*q1**2*q2**2`` has an exponent sum of 6, and will
            therefore be considered larger than both ``q0**2*q1*q2``,
            ``q0*q1**2*q2`` and ``q0*q1*q2**2``, which all have exponent sum of
            5. Only used when ``order`` is an integer.
        reverse (bool):
            Reverse lexicographical sorting meaning that ``q0*q1**3`` is
            considered bigger than ``q0**3*q1``, instead of the opposite.
            Only used when ``order`` is an integer.
        retall (bool):
            If true return numerical stabilized norms as well. Roughly the same
            as ``cp.E(orth**2, dist)``.
        cross_truncation (float):
            Use hyperbolic cross truncation scheme to reduce the number of
            terms in expansion. Only used when ``order`` is an integer.
        kws:
            Extra keyword arguments forwarded to every ``chaospy.E`` call.

    Returns:
        (numpoly.ndpoly, Tuple[numpoly.ndpoly, numpy.ndarray]):
            The orthogonal polynomial expansion. If ``retall`` is true, also
            the per-term norms ``E[p**2]`` (all ones when ``normed``).
            If a term's norm is not positive, a warning is logged and the
            expansion is truncated before that term.

    Examples:
        >>> distribution = chaospy.J(chaospy.Normal(), chaospy.Normal())
        >>> polynomials, norms = chaospy.expansion.gram_schmidt(2, distribution, retall=True)
        >>> polynomials.round(10)
        polynomial([1.0, q1, q0, q1**2-1.0, q0*q1, q0**2-1.0])
        >>> norms.round(10)
        array([1., 1., 1., 2., 1., 2.])
        >>> polynomials = chaospy.expansion.gram_schmidt(2, distribution, normed=True)
        >>> polynomials.round(3)
        polynomial([1.0, q1, q0, 0.707*q1**2-0.707, q0*q1, 0.707*q0**2-0.707])

    """
    logger = logging.getLogger(__name__)
    dim = len(dist)

    if isinstance(order, (int, numpy.integer)):
        # One variable per distribution dimension; hard-coding the count
        # would build a basis in the wrong number of variables.
        order = numpoly.monomial(
            0,
            order+1,
            dimensions=numpoly.variable(dim).names,
            graded=graded,
            reverse=reverse,
            cross_truncation=cross_truncation,
        )
    basis = list(order)

    polynomials = []
    norms = []
    for idx, candidate in enumerate(basis):

        # Orthogonalize against every accepted polynomial. The projection
        # coefficient is computed from the already-updated candidate each
        # time (modified Gram-Schmidt).
        for previous, previous_norm in zip(polynomials, norms):
            orth = chaospy.E(candidate*previous, dist, **kws)
            candidate = candidate-previous*orth/previous_norm

        # The first term goes through the same path as the rest, so a custom
        # basis whose first element does not have unit norm is still
        # projected out (and normalized) correctly.
        norm = chaospy.E(candidate**2, dist, **kws)
        if norm <= 0:  # pragma: no cover
            logger.warning("Warning: Polynomial cutoff at term %d", idx)
            break

        if normed:
            candidate = candidate/numpy.sqrt(norm)
            norm = 1.
        norms.append(norm)
        polynomials.append(candidate)

    polynomials = chaospy.polynomial(polynomials).flatten()
    if retall:
        return polynomials, numpy.array(norms)
    return polynomials
91