#   Copyright 2021 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

"""
Created on Mar 7, 2011

@author: johnsalvatier
"""
import numpy as np
import scipy.linalg
import scipy.special
import scipy.stats
import theano
import theano.tensor as tt

from theano import scan
from theano.compile.builders import OpFromGraph
from theano.graph.basic import Apply
from theano.graph.op import Op
from theano.scalar import UnaryScalarOp, upgrade_to_float_no_complex
from theano.scan import until
from theano.tensor.slinalg import Cholesky

from pymc3.distributions.shape_utils import to_tuple
from pymc3.distributions.special import gammaln
from pymc3.model import modelcontext
from pymc3.theanof import floatX

f = floatX
c = -0.5 * np.log(2.0 * np.pi)
_beta_clip_values = {
    dtype: (np.nextafter(0, 1, dtype=dtype), np.nextafter(1, 0, dtype=dtype))
    for dtype in ["float16", "float32", "float64"]
}


def bound(logp, *conditions, **kwargs):
    """
    Bounds a log probability density with several conditions.

    Parameters
    ----------
    logp: float
    *conditions: booleans
    broadcast_conditions: bool (optional, default=True)
        If True, broadcasts logp to match the largest shape of the conditions.
        This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
        is specified via the conditions.
        If False, will return the same shape as logp.
        This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.

    Returns
    -------
    logp with elements set to -inf where any condition is False
    """

    # If called inside a model context, see if bounds check is disabled
    try:
        model = modelcontext(kwargs.get("model"))
        if not model.check_bounds:
            return logp
    except TypeError:  # No model found
        pass

    broadcast_conditions = kwargs.get("broadcast_conditions", True)

    if broadcast_conditions:
        alltrue = alltrue_elemwise
    else:
        alltrue = alltrue_scalar

    return tt.switch(alltrue(conditions), logp, -np.inf)
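

# A minimal usage sketch for `bound` (illustrative only; the variable and the
# logp expression below are made up for the example):
#
#     import numpy as np
#     import theano.tensor as tt
#     x = tt.dvector("x")
#     logp = bound(-tt.sqr(x), x > 0)  # -inf wherever the condition fails
#     print(logp.eval({x: np.array([-1.0, 2.0])}))  # [-inf  -4.]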


def alltrue_elemwise(vals):
    """Elementwise logical AND of a sequence of boolean tensors."""
    ret = 1
    for c in vals:
        ret = ret * (1 * c)
    return ret


def alltrue_scalar(vals):
    """Reduce a sequence of boolean tensors to a single scalar that is
    True only when every element of every condition is True."""
    return tt.all([tt.all(1 * val) for val in vals])


def logpow(x, m):
    """
    Calculate log(x**m) since m*log(x) will fail when m = x = 0.
    """
    # The naive m * tt.log(x) gives nan at m = x = 0, where the correct limit is 0.
    return tt.switch(tt.eq(x, 0), tt.switch(tt.eq(m, 0), 0.0, -np.inf), m * tt.log(x))
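

# Illustrative check: the naive form m * log(x) yields nan at m = x = 0,
# while logpow returns the limiting value 0.
#
#     import theano.tensor as tt
#     (0.0 * tt.log(0.0)).eval()  # nan
#     logpow(0.0, 0.0).eval()     # 0.0
#     logpow(0.0, 2.0).eval()     # -inf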


def factln(n):
    """Log factorial, log(n!), computed via the gamma function."""
    return gammaln(n + 1)


def binomln(n, k):
    """Log of the binomial coefficient, log(C(n, k))."""
    return factln(n) - factln(k) - factln(n - k)


def betaln(x, y):
    """Log of the beta function, log(B(x, y))."""
    return gammaln(x) + gammaln(y) - gammaln(x + y)


def std_cdf(x):
    """
    Calculates the standard normal cumulative distribution function.
    """
    return 0.5 + 0.5 * tt.erf(x / tt.sqrt(2.0))


def normal_lcdf(mu, sigma, x):
    """Compute the log of the cumulative distribution function of the normal."""
    z = (x - mu) / sigma
    return tt.switch(
        tt.lt(z, -1.0),
        tt.log(tt.erfcx(-z / tt.sqrt(2.0)) / 2.0) - tt.sqr(z) / 2.0,
        tt.log1p(-tt.erfc(z / tt.sqrt(2.0)) / 2.0),
    )


def normal_lccdf(mu, sigma, x):
    """Compute the log of the normal's complementary CDF (survival function)."""
    z = (x - mu) / sigma
    return tt.switch(
        tt.gt(z, 1.0),
        tt.log(tt.erfcx(z / tt.sqrt(2.0)) / 2.0) - tt.sqr(z) / 2.0,
        tt.log1p(-tt.erfc(-z / tt.sqrt(2.0)) / 2.0),
    )
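

# Why the erfcx branch matters (illustrative): deep in the lower tail the
# naive log(std_cdf(x)) underflows to -inf, while normal_lcdf stays finite.
#
#     import theano.tensor as tt
#     z = tt.constant(-20.0)
#     tt.log(std_cdf(z)).eval()        # -inf, since erf saturates at -1
#     normal_lcdf(0.0, 1.0, z).eval()  # approximately -203.9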


def log_diff_normal_cdf(mu, sigma, x, y):
    """
    Compute :math:`\\log(\\Phi(\\frac{x - \\mu}{\\sigma}) - \\Phi(\\frac{y - \\mu}{\\sigma}))` safely in log space.

    Parameters
    ----------
    mu: float
        mean
    sigma: float
        std

    x: float

    y: float
        must be strictly less than x.

    Returns
    -------
    :math:`\\log(\\Phi(x) - \\Phi(y))`

    """
    x = (x - mu) / sigma / tt.sqrt(2.0)
    y = (y - mu) / sigma / tt.sqrt(2.0)

    # To stabilize the computation, consider these three regions:
    # 1) x > y > 0 => Use erf(x) = 1 - e^{-x^2} erfcx(x) and erf(y) = 1 - e^{-y^2} erfcx(y)
    # 2) 0 > x > y => Use erf(x) = e^{-x^2} erfcx(-x) - 1 and erf(y) = e^{-y^2} erfcx(-y) - 1
    #    (the -1 terms cancel in the difference)
    # 3) x > 0 > y => Naive formula log( (erf(x) - erf(y)) / 2 ) works fine.
    return tt.log(0.5) + tt.switch(
        tt.gt(y, 0),
        -tt.square(y) + tt.log(tt.erfcx(y) - tt.exp(tt.square(y) - tt.square(x)) * tt.erfcx(x)),
        tt.switch(
            tt.lt(x, 0),  # 0 > x > y
            -tt.square(x)
            + tt.log(tt.erfcx(-x) - tt.exp(tt.square(x) - tt.square(y)) * tt.erfcx(-y)),
            tt.log(tt.erf(x) - tt.erf(y)),  # x > 0 > y
        ),
    )
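

# Quick numeric sanity check (illustrative): Phi(1) - Phi(-1) is about 0.6827,
# so the log difference should be about -0.3816.
#
#     log_diff_normal_cdf(0.0, 1.0, 1.0, -1.0).eval()  # approximately -0.3816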


def sigma2rho(sigma):
    """
    `sigma -> rho` theano converter
    :math:`mu + sigma*e = mu + log(1+exp(rho))*e`"""
    return tt.log(tt.exp(tt.abs_(sigma)) - 1.0)


def rho2sigma(rho):
    """
    `rho -> sigma` theano converter
    :math:`mu + sigma*e = mu + log(1+exp(rho))*e`"""
    return tt.nnet.softplus(rho)


rho2sd = rho2sigma
sd2rho = sigma2rho
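

# The two converters are inverses of each other (illustrative round trip):
#
#     import numpy as np
#     sigma = np.float64(2.5)
#     rho2sigma(sigma2rho(sigma)).eval()  # approximately 2.5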


def log_normal(x, mean, **kwargs):
    """
    Calculate the logarithm of the normal density at point `x`
    with given `mean` and standard deviation.

    Parameters
    ----------
    x: Tensor
        point of evaluation
    mean: Tensor
        mean of the normal distribution
    kwargs: exactly one of the parameters `{sigma, tau, w, rho}`,
        plus optionally `eps`, a jitter added to the standard deviation

    Notes
    -----
    There are four variants for the density parametrization.
    They are:
        1) standard deviation - `sigma`
        2) `w`, logarithm of `sigma` :math:`w = log(sigma)`
        3) `rho` that follows this equation :math:`rho = log(exp(sigma) - 1)`
        4) `tau` that follows this equation :math:`tau = sigma^{-1}`
    """
    sigma = kwargs.get("sigma")
    w = kwargs.get("w")
    rho = kwargs.get("rho")
    tau = kwargs.get("tau")
    eps = kwargs.get("eps", 0.0)
    check = sum(map(lambda a: a is not None, [sigma, w, rho, tau]))
    if check > 1:
        raise ValueError("more than one of the required kwargs was passed")
    if check == 0:
        raise ValueError("none of the required kwargs was passed")
    if sigma is not None:
        std = sigma
    elif w is not None:
        std = tt.exp(w)
    elif rho is not None:
        std = rho2sigma(rho)
    else:
        std = tau ** (-1)
    std += f(eps)
    return f(c) - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2.0 * std ** 2)
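

# Illustrative cross-check against scipy (the values are made up, and a
# float64 theano floatX is assumed): the parametrizations agree with
# scipy.stats.norm.logpdf.
#
#     import numpy as np
#     import scipy.stats
#     scipy.stats.norm.logpdf(1.0, loc=0.0, scale=2.0)        # -1.737086
#     log_normal(np.float64(1.0), 0.0, sigma=2.0).eval()      # -1.737086
#     log_normal(np.float64(1.0), 0.0, w=np.log(2.0)).eval()  # -1.737086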


def MvNormalLogp():
    """Compute the log pdf of a multivariate normal distribution.

    This should be used in MvNormal.logp once Theano#5908 is released.

    Parameters
    ----------
    cov: tt.matrix
        The covariance matrix.
    delta: tt.matrix
        Array of deviations from the mean.
    """
    cov = tt.matrix("cov")
    cov.tag.test_value = floatX(np.eye(3))
    delta = tt.matrix("delta")
    delta.tag.test_value = floatX(np.zeros((2, 3)))

    solve_lower = tt.slinalg.Solve(A_structure="lower_triangular")
    solve_upper = tt.slinalg.Solve(A_structure="upper_triangular")
    cholesky = Cholesky(lower=True, on_error="nan")

    n, k = delta.shape
    n, k = f(n), f(k)
    chol_cov = cholesky(cov)
    diag = tt.nlinalg.diag(chol_cov)
    ok = tt.all(diag > 0)

    chol_cov = tt.switch(ok, chol_cov, tt.fill(chol_cov, 1))
    delta_trans = solve_lower(chol_cov, delta.T).T

    result = n * k * tt.log(f(2) * np.pi)
    result += f(2) * n * tt.sum(tt.log(diag))
    result += (delta_trans ** f(2)).sum()
    result = f(-0.5) * result
    logp = tt.switch(ok, result, -np.inf)

    def dlogp(inputs, gradients):
        (g_logp,) = gradients
        cov, delta = inputs

        g_logp.tag.test_value = floatX(1.0)
        n, k = delta.shape

        chol_cov = cholesky(cov)
        diag = tt.nlinalg.diag(chol_cov)
        ok = tt.all(diag > 0)

        chol_cov = tt.switch(ok, chol_cov, tt.fill(chol_cov, 1))
        delta_trans = solve_lower(chol_cov, delta.T).T

        inner = n * tt.eye(k) - tt.dot(delta_trans.T, delta_trans)
        g_cov = solve_upper(chol_cov.T, inner)
        g_cov = solve_upper(chol_cov.T, g_cov.T)

        tau_delta = solve_upper(chol_cov.T, delta_trans.T)
        g_delta = tau_delta.T

        g_cov = tt.switch(ok, g_cov, -np.nan)
        g_delta = tt.switch(ok, g_delta, -np.nan)

        return [-0.5 * g_cov * g_logp, -g_delta * g_logp]

    return OpFromGraph([cov, delta], [logp], grad_overrides=dlogp, inline=True)
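

# Usage sketch (illustrative; the inputs are made up and float64 floatX is
# assumed). The returned Op takes a covariance matrix and a matrix of
# deviations from the mean, one row per observation:
#
#     import numpy as np
#     logp_op = MvNormalLogp()
#     cov = np.eye(2)
#     delta = np.ones((1, 2))
#     logp_op(cov, delta).eval()  # approximately -2.8379, matching
#     # scipy.stats.multivariate_normal.logpdf(np.ones(2), cov=np.eye(2))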


class SplineWrapper(Op):
    """
    Creates a theano operation from scipy.interpolate.UnivariateSpline
    """

    __props__ = ("spline",)

    def __init__(self, spline):
        self.spline = spline

    def make_node(self, x):
        x = tt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    @property
    def grad_op(self):
        if not hasattr(self, "_grad_op"):
            try:
                self._grad_op = SplineWrapper(self.spline.derivative())
            except ValueError:
                self._grad_op = None

        if self._grad_op is None:
            raise NotImplementedError("Spline of order 0 is not differentiable")
        return self._grad_op

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = np.asarray(self.spline(x))

    def grad(self, inputs, grads):
        (x,) = inputs
        (x_grad,) = grads

        return [x_grad * self.grad_op(x)]
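

# Usage sketch (illustrative; the spline and data are made up):
#
#     import numpy as np
#     from scipy.interpolate import UnivariateSpline
#     xs = np.linspace(0.0, 1.0, 10)
#     spline = UnivariateSpline(xs, xs ** 2, k=3, s=0)  # interpolates x**2
#     spline_op = SplineWrapper(spline)
#     spline_op(np.float64(0.5)).eval()  # approximately 0.25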


class I1e(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 1, exponentially scaled.
    """

    nfunc_spec = ("scipy.special.i1e", 1, 1)

    def impl(self, x):
        return scipy.special.i1e(x)


i1e_scalar = I1e(upgrade_to_float_no_complex, name="i1e")
i1e = tt.Elemwise(i1e_scalar, name="Elemwise{i1e,no_inplace}")


class I0e(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 0, exponentially scaled.
    """

    nfunc_spec = ("scipy.special.i0e", 1, 1)

    def impl(self, x):
        return scipy.special.i0e(x)

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        return (gz * (i1e_scalar(x) - theano.scalar.sgn(x) * i0e_scalar(x)),)


i0e_scalar = I0e(upgrade_to_float_no_complex, name="i0e")
i0e = tt.Elemwise(i0e_scalar, name="Elemwise{i0e,no_inplace}")
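

# Elementwise usage sketch (illustrative), matching scipy's values:
#
#     import numpy as np
#     import scipy.special
#     x = np.array([0.0, 1.0, 10.0])
#     i0e(x).eval()  # same values as scipy.special.i0e(x)
#     i1e(x).eval()  # same values as scipy.special.i1e(x)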


def random_choice(*args, **kwargs):
    """Return draws from a categorical probability distribution.

    Parameters
    ----------
    p: array
        Probability of each class. If p.ndim > 1, the last axis is
        interpreted as the probability of each class, and numpy.random.choice
        is iterated for every other axis element.
    size: int or tuple
        Shape of the desired output array. If p is multidimensional, size
        should broadcast with p.shape[:-1].

    Returns
    -------
    random sample: array

    """
    p = kwargs.pop("p")
    size = kwargs.pop("size")
    k = p.shape[-1]

    if p.ndim > 1:
        # If p is an nd-array, the last axis is interpreted as the class
        # probability. We must iterate over the elements of all the other
        # dimensions.
        # We first ensure that p is broadcasted to the output's shape
        size = to_tuple(size) + (1,)
        p = np.broadcast_arrays(p, np.empty(size))[0]
        out_shape = p.shape[:-1]
        # np.random.choice accepts 1D p arrays, so we semiflatten p to
        # iterate calls using the last axis as the category probabilities
        p = np.reshape(p, (-1, p.shape[-1]))
        samples = np.array([np.random.choice(k, p=p_) for p_ in p])
        # We reshape to the desired output shape
        samples = np.reshape(samples, out_shape)
    else:
        samples = np.random.choice(k, p=p, size=size)
    return samples
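

# Usage sketch (illustrative; the probabilities are made up). Note that `p`
# and `size` must be passed as keyword arguments:
#
#     import numpy as np
#     random_choice(p=np.array([0.1, 0.9]), size=4)  # e.g. array([1, 1, 0, 1])
#     p = np.array([[1.0, 0.0], [0.0, 1.0]])
#     random_choice(p=p, size=(3, 2))  # first column all 0s, second all 1s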


def zvalue(value, sigma, mu):
    """
    Calculate the z-value for a normal distribution.
    """
    return (value - mu) / sigma


def incomplete_beta_cfe(a, b, x, small):
    """Incomplete beta continued fraction expansions
    based on Cephes library by Steve Moshier (incbet.c).
    small: Choose element-wise which continued fraction expansion to use.
    """
    BIG = tt.constant(4.503599627370496e15, dtype="float64")
    BIGINV = tt.constant(2.22044604925031308085e-16, dtype="float64")
    THRESH = tt.constant(3.0 * np.MachAr().eps, dtype="float64")

    zero = tt.constant(0.0, dtype="float64")
    one = tt.constant(1.0, dtype="float64")
    two = tt.constant(2.0, dtype="float64")

    k1 = a
    k3 = a
    k4 = a + one
    k5 = one
    k8 = a + two

    k2 = tt.switch(small, a + b, b - one)
    k6 = tt.switch(small, b - one, a + b)
    k7 = tt.switch(small, k4, a + one)
    k26update = tt.switch(small, one, -one)
    x = tt.switch(small, x, x / (one - x))

    pkm2 = zero
    qkm2 = one
    pkm1 = one
    qkm1 = one
    r = one

    def _step(i, pkm1, pkm2, qkm1, qkm2, k1, k2, k3, k4, k5, k6, k7, k8, r):
        xk = -(x * k1 * k2) / (k3 * k4)
        pk = pkm1 + pkm2 * xk
        qk = qkm1 + qkm2 * xk
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk

        xk = (x * k5 * k6) / (k7 * k8)
        pk = pkm1 + pkm2 * xk
        qk = qkm1 + qkm2 * xk
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk

        old_r = r
        r = tt.switch(tt.eq(qk, zero), r, pk / qk)

        k1 += one
        k2 += k26update
        k3 += two
        k4 += two
        k5 += one
        k6 -= k26update
        k7 += two
        k8 += two

        big_cond = tt.gt(tt.abs_(qk) + tt.abs_(pk), BIG)
        biginv_cond = tt.or_(tt.lt(tt.abs_(qk), BIGINV), tt.lt(tt.abs_(pk), BIGINV))

        pkm2 = tt.switch(big_cond, pkm2 * BIGINV, pkm2)
        pkm1 = tt.switch(big_cond, pkm1 * BIGINV, pkm1)
        qkm2 = tt.switch(big_cond, qkm2 * BIGINV, qkm2)
        qkm1 = tt.switch(big_cond, qkm1 * BIGINV, qkm1)

        pkm2 = tt.switch(biginv_cond, pkm2 * BIG, pkm2)
        pkm1 = tt.switch(biginv_cond, pkm1 * BIG, pkm1)
        qkm2 = tt.switch(biginv_cond, qkm2 * BIG, qkm2)
        qkm1 = tt.switch(biginv_cond, qkm1 * BIG, qkm1)

        return (
            (pkm1, pkm2, qkm1, qkm2, k1, k2, k3, k4, k5, k6, k7, k8, r),
            until(tt.abs_(old_r - r) < (THRESH * tt.abs_(r))),
        )

    (pkm1, pkm2, qkm1, qkm2, k1, k2, k3, k4, k5, k6, k7, k8, r), _ = scan(
        _step,
        sequences=[tt.arange(0, 300)],
        outputs_info=[
            e
            for e in tt.cast((pkm1, pkm2, qkm1, qkm2, k1, k2, k3, k4, k5, k6, k7, k8, r), "float64")
        ],
    )

    return r[-1]


def incomplete_beta_ps(a, b, value):
    """Power series for incomplete beta
    Use when b*x is small and value not too close to 1.
    Based on Cephes library by Steve Moshier (incbet.c)
    """
    one = tt.constant(1, dtype="float64")
    ai = one / a
    u = (one - b) * value
    t1 = u / (a + one)
    t = u
    threshold = np.MachAr().eps * ai
    s = tt.constant(0, dtype="float64")

    def _step(i, t, s):
        t *= (i - b) * value / i
        step = t / (a + i)
        s += step
        return ((t, s), until(tt.abs_(step) < threshold))

    (t, s), _ = scan(
        _step, sequences=[tt.arange(2, 302)], outputs_info=[e for e in tt.cast((t, s), "float64")]
    )

    s = s[-1] + t1 + ai

    t = gammaln(a + b) - gammaln(a) - gammaln(b) + a * tt.log(value) + tt.log(s)
    return tt.exp(t)


def incomplete_beta(a, b, value):
    """Incomplete beta implementation.
    The power series and continued fraction expansions are chosen
    depending on the inputs for the best numerical convergence.
    """
    machep = tt.constant(np.MachAr().eps, dtype="float64")
    one = tt.constant(1, dtype="float64")
    w = one - value

    ps = incomplete_beta_ps(a, b, value)

    flip = tt.gt(value, (a / (a + b)))
    aa, bb = a, b
    a = tt.switch(flip, bb, aa)
    b = tt.switch(flip, aa, bb)
    xc = tt.switch(flip, value, w)
    x = tt.switch(flip, w, value)

    tps = incomplete_beta_ps(a, b, x)
    tps = tt.switch(tt.le(tps, machep), one - machep, one - tps)

    # Choose which continued fraction expansion for best convergence.
    small = tt.lt(x * (a + b - 2.0) - (a - one), 0.0)
    cfe = incomplete_beta_cfe(a, b, x, small)
    w = tt.switch(small, cfe, cfe / xc)

    # Direct incomplete beta accounting for flipped a, b.
    t = tt.exp(
        a * tt.log(x) + b * tt.log(xc) + gammaln(a + b) - gammaln(a) - gammaln(b) + tt.log(w / a)
    )

    t = tt.switch(flip, tt.switch(tt.le(t, machep), one - machep, one - t), t)
    return tt.switch(
        tt.and_(flip, tt.and_(tt.le((b * x), one), tt.le(x, 0.95))),
        tps,
        tt.switch(tt.and_(tt.le(b * value, one), tt.le(value, 0.95)), ps, t),
    )
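

# Illustrative check against scipy's regularized incomplete beta function
# (the parameter values are made up):
#
#     import scipy.special
#     scipy.special.betainc(2.0, 3.0, 0.5)   # 0.6875
#     incomplete_beta(2.0, 3.0, 0.5).eval()  # approximately 0.6875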


def clipped_beta_rvs(a, b, size=None, dtype="float64"):
    """Draw beta distributed random samples in the open :math:`(0, 1)` interval.

    The samples are generated with ``scipy.stats.beta.rvs``, but any value that
    is equal to 0 or 1 will be shifted to the nearest floating point number
    inside the open :math:`(0, 1)` interval, depending on the floating point
    precision that is given by ``dtype``.

    Parameters
    ----------
    a : float or array_like of floats
        Alpha, strictly positive (>0).
    b : float or array_like of floats
        Beta, strictly positive (>0).
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
        a single value is returned if ``a`` and ``b`` are both scalars.
        Otherwise, ``np.broadcast(a, b).size`` samples are drawn.
    dtype : str or dtype instance
        The floating point precision that the samples should have. This also
        determines the value that will be used to shift any samples returned
        by the numpy random number generator that are zero or one.

    Returns
    -------
    out : ndarray or scalar
        Drawn samples from the parameterized beta distribution. The scipy
        implementation can yield values that are equal to zero or one. We
        assume the support of the Beta distribution to be in the open interval
        :math:`(0, 1)`, so we shift any sample that is equal to 0 to
        ``np.nextafter(0, 1, dtype=dtype)`` and any sample that is equal to 1
        to ``np.nextafter(1, 0, dtype=dtype)``.

    """
    out = scipy.stats.beta.rvs(a, b, size=size).astype(dtype)
    lower, upper = _beta_clip_values[dtype]
    return np.maximum(np.minimum(out, upper), lower)
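

# Usage sketch (illustrative): with very small shape parameters scipy's
# sampler frequently returns exact 0s and 1s, which the clipping removes.
#
#     draws = clipped_beta_rvs(0.01, 0.01, size=10000, dtype="float32")
#     assert 0 < draws.min() and draws.max() < 1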