1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15import warnings
16
17import numpy as np
18import theano.tensor as tt
19
20from scipy.special import logit as nplogit
21
22from pymc3.distributions import distribution
23from pymc3.distributions.distribution import draw_values
24from pymc3.math import invlogit, logit, logsumexp
25from pymc3.model import FreeRV
26from pymc3.theanof import floatX, gradient
27
# Public API of this module: the transform classes plus the ready-made
# singleton instances / aliases used as the ``transform=...`` argument
# of random variables.
__all__ = [
    "Transform",
    "transform",
    "stick_breaking",
    "logodds",
    "interval",
    "log_exp_m1",
    "lowerbound",
    "upperbound",
    "ordered",
    "log",
    "sum_to_1",
    "circular",
    "CholeskyCovPacked",
    "Chain",
]
44
45
class Transform:
    """Base class for maps between a distribution's support and another space.

    Subclasses provide the symbolic ``forward``/``backward`` pair, the numeric
    ``forward_val`` and the log-Jacobian ``jacobian_det`` of ``backward``.

    Attributes
    ----------
    name: str
    """

    name = ""

    def forward(self, x):
        """Map the random variable `x` into the transformed space.
        When this transform is used on a distribution `p`, the variable
        sampled from `p` is pushed through this map.

        Parameters
        ----------
        x: tensor
            Input tensor to be transformed.

        Returns
        --------
        tensor
            Transformed tensor.
        """
        raise NotImplementedError

    def forward_val(self, x, point):
        """Numeric counterpart of `forward`, operating on constant data.

        Parameters
        ----------
        x: array_like
            Input array to be transformed.
        point: array_like, optional
            Test value used to draw (fix) bounds-like transformations

        Returns
        --------
        array_like
            Transformed array.
        """
        raise NotImplementedError

    def backward(self, z):
        """Map `z` from the transformed space back onto the support of the
        distribution (the inverse of `forward`).

        Parameters
        ----------
        z: tensor
            Input tensor to be inverse transformed.

        Returns
        -------
        tensor
            Inverse transformed tensor.
        """
        raise NotImplementedError

    def jacobian_det(self, x):
        """Logarithm of the absolute value of the Jacobian determinant of
        the `backward` transformation, evaluated at `x`.

        Parameters
        ----------
        x: tensor
            Input to calculate Jacobian determinant of.

        Returns
        -------
        tensor
            The log abs Jacobian determinant of `x` w.r.t. this transform.
        """
        raise NotImplementedError

    def apply(self, dist):
        # Wrap `dist` so sampling happens in the transformed space.
        return TransformedDistribution.dist(dist, self)

    def __str__(self):
        return f"{self.name} transform"
130
131
class ElemwiseTransform(Transform):
    """Base class for transforms that act independently on each element."""

    def jacobian_det(self, x):
        # For an elementwise map, grad(sum(backward(x))) w.r.t. x yields the
        # diagonal of the Jacobian; reshape restores the input's shape.
        diag = gradient(tt.sum(self.backward(x)), [x])
        return tt.log(tt.abs_(tt.reshape(diag, x.shape)))
136
137
class TransformedDistribution(distribution.Distribution):
    """A distribution that has been transformed from one space into another."""

    def __init__(self, dist, transform, *args, **kwargs):
        """
        Parameters
        ----------
        dist: Distribution
        transform: Transform
        args, kwargs
            arguments to Distribution"""
        forward = transform.forward
        # The default/test value lives in the transformed space.
        testval = forward(dist.default())

        self.dist = dist
        self.transform_used = transform
        # Push a dummy free RV through the transform to discover the theano
        # type (dtype / broadcastable pattern) of the transformed variable.
        v = forward(FreeRV(name="v", distribution=dist))
        self.type = v.type

        super().__init__(v.shape.tag.test_value, v.dtype, testval, dist.defaults, *args, **kwargs)

        if transform.name == "stickbreaking":
            b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
            # force the last dim not broadcastable
            self.type = tt.TensorType(v.dtype, b)

    def logp(self, x):
        """
        Calculate log-probability of Transformed distribution at specified value.

        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        logp_nojac = self.logp_nojac(x)
        jacobian_det = self.transform_used.jacobian_det(x)
        # Some transforms reduce the last axis in their jacobian_det (e.g. by
        # summing); collapse the elementwise logp to match before adding.
        if logp_nojac.ndim > jacobian_det.ndim:
            logp_nojac = logp_nojac.sum(axis=-1)
        return logp_nojac + jacobian_det

    def logp_nojac(self, x):
        """
        Calculate log-probability of Transformed distribution at specified value
        without jacobian term for transforms.

        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        # Evaluate the base distribution's logp on the back-transformed value.
        return self.dist.logp(self.transform_used.backward(x))

    def _repr_latex_(self, **kwargs):
        # prevent TransformedDistributions from ending up in LaTeX representations
        # of models
        return None

    def _distr_parameters_for_repr(self):
        return []
206
207
208transform = Transform
209
210
class Log(ElemwiseTransform):
    """Map positive values onto the whole real line via the natural log."""

    name = "log"

    def forward(self, x):
        return tt.log(x)

    def forward_val(self, x, point=None):
        # Numeric counterpart of `forward`.
        return np.log(x)

    def backward(self, x):
        return tt.exp(x)

    def jacobian_det(self, x):
        # d/dx exp(x) = exp(x), so log|J| is simply x.
        return x


log = Log()
228
229
class LogExpM1(ElemwiseTransform):
    """Softplus-based transform mapping positive values to the real line."""

    name = "log_exp_m1"

    def forward(self, x):
        """Inverse operation of softplus.

        y = Log(Exp(x) - 1)
          = Log(1 - Exp(-x)) + x

        The second form is the numerically stable one.
        """
        return tt.log(1.0 - tt.exp(-x)) + x

    def forward_val(self, x, point=None):
        # Numeric counterpart of `forward`, same stable formulation.
        return np.log(1.0 - np.exp(-x)) + x

    def backward(self, x):
        return tt.nnet.softplus(x)

    def jacobian_det(self, x):
        # d softplus(x)/dx = sigmoid(x), and log sigmoid(x) = -softplus(-x).
        return -tt.nnet.softplus(-x)


log_exp_m1 = LogExpM1()
252
253
class LogOdds(ElemwiseTransform):
    """Map the unit interval onto the real line via the logit function."""

    name = "logodds"

    def forward(self, x):
        return logit(x)

    def forward_val(self, x, point=None):
        # scipy's logit is the numeric counterpart of the symbolic `forward`.
        return nplogit(x)

    def backward(self, x):
        return invlogit(x, 0.0)


logodds = LogOdds()
268
269
class Interval(ElemwiseTransform):
    """Transform from real line interval [a,b] to whole real line."""

    name = "interval"

    def __init__(self, a, b):
        self.a = tt.as_tensor_variable(a)
        self.b = tt.as_tensor_variable(b)

    def forward(self, x):
        return tt.log(x - self.a) - tt.log(self.b - x)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.a - 0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        a, b = draw_values([self.a - 0.0, self.b - 0.0], point=point)
        return floatX(np.log(x - a) - np.log(b - x))

    def backward(self, x):
        # Convex combination of the bounds, weighted by sigmoid(x).
        w = tt.nnet.sigmoid(x)
        return w * self.b + (1 - w) * self.a

    def jacobian_det(self, x):
        # log|d backward/dx| = log(b - a) + log sig(x) + log(1 - sig(x))
        #                    = log(b - a) - softplus(-x) - (x + softplus(-x))
        s = tt.nnet.softplus(-x)
        return tt.log(self.b - self.a) - 2 * s - x


interval = Interval
302
303
class LowerBound(ElemwiseTransform):
    """Transform from real line interval [a,inf] to whole real line."""

    name = "lowerbound"

    def __init__(self, a):
        self.a = tt.as_tensor_variable(a)

    def forward(self, x):
        return tt.log(x - self.a)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.a - 0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        a = draw_values([self.a - 0.0], point=point)[0]
        return floatX(np.log(x - a))

    def backward(self, x):
        return tt.exp(x) + self.a

    def jacobian_det(self, x):
        # d/dx (exp(x) + a) = exp(x), so log|J| is simply x.
        return x


lowerbound = LowerBound
"""
Alias for ``LowerBound`` (:class: LowerBound) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""
337
338
class UpperBound(ElemwiseTransform):
    """Transform from real line interval [-inf,b] to whole real line."""

    name = "upperbound"

    def __init__(self, b):
        self.b = tt.as_tensor_variable(b)

    def forward(self, x):
        return tt.log(self.b - x)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.b - 0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        b = draw_values([self.b - 0.0], point=point)[0]
        return floatX(np.log(b - x))

    def backward(self, x):
        return self.b - tt.exp(x)

    def jacobian_det(self, x):
        # d/dx (b - exp(x)) = -exp(x), so log|J| is simply x.
        return x


upperbound = UpperBound
"""
Alias for ``UpperBound`` (:class: UpperBound) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""
372
373
class Ordered(Transform):
    """Map an increasing vector (along the last axis) to an unconstrained
    one: the first component is kept, the rest become log-differences."""

    name = "ordered"

    def forward(self, x):
        y = tt.zeros(x.shape)
        y = tt.inc_subtensor(y[..., 0], x[..., 0])
        y = tt.inc_subtensor(y[..., 1:], tt.log(x[..., 1:] - x[..., :-1]))
        return y

    def forward_val(self, x, point=None):
        # Numeric counterpart of `forward`.
        out = np.zeros_like(x)
        out[..., 0] = x[..., 0]
        out[..., 1:] = np.log(x[..., 1:] - x[..., :-1])
        return out

    def backward(self, y):
        # First component passes through; the rest are positive increments
        # accumulated by a cumulative sum.
        inc = tt.zeros(y.shape)
        inc = tt.inc_subtensor(inc[..., 0], y[..., 0])
        inc = tt.inc_subtensor(inc[..., 1:], tt.exp(y[..., 1:]))
        return tt.cumsum(inc, axis=-1)

    def jacobian_det(self, y):
        # Each increment scales by exp(y_i), so log|J| = sum of y[..., 1:].
        return tt.sum(y[..., 1:], axis=-1)


ordered = Ordered()
"""
Instantiation of ``Ordered`` (:class: Ordered) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""
404
405
class SumTo1(Transform):
    """
    Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
    This Transformation operates on the last dimension of the input tensor.
    """

    name = "sumto1"

    def forward(self, x):
        # Drop the last (redundant) component.
        return x[..., :-1]

    def forward_val(self, x, point=None):
        return x[..., :-1]

    def backward(self, y):
        # Restore the dropped component as whatever remains to reach 1.
        last = 1 - tt.sum(y[..., :], axis=-1, keepdims=True)
        return tt.concatenate([y[..., :], last], axis=-1)

    def jacobian_det(self, x):
        # The map is affine with unit Jacobian: log|J| is zero, shaped like
        # the input with the last axis summed out.
        return tt.sum(tt.zeros(x.shape), axis=-1)


sum_to_1 = SumTo1()
430
431
class StickBreaking(Transform):
    """
    Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of real values.
    This is a variant of the isometric logratio transformation ::

        Egozcue, J.J., Pawlowsky-Glahn, V., Mateu-Figueras, G. et al.
        Isometric Logratio Transformations for Compositional Data Analysis.
        Mathematical Geology 35, 279–300 (2003). https://doi.org/10.1023/A:1023818214614
    """

    name = "stickbreaking"

    def __init__(self, eps=None):
        # `eps` used to clip values away from the simplex boundary; it is
        # no longer used and only kept for backwards compatibility.
        if eps is not None:
            warnings.warn(
                "The argument `eps` is deprecated and will not be used.", DeprecationWarning
            )

    def forward(self, x_):
        # Transpose so the simplex axis (last axis of x_) comes first.
        x = x_.T
        n = x.shape[0]
        lx = tt.log(x)
        # Centered log-ratio: subtract the mean of the logs, then drop the
        # last (redundant) coordinate.
        shift = tt.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def forward_val(self, x_, point=None):
        # Numeric (numpy) counterpart of `forward`.
        x = x_.T
        n = x.shape[0]
        lx = np.log(x)
        shift = np.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def backward(self, y_):
        y = y_.T
        # Restore the dropped coordinate so the columns sum to zero.
        y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)])
        # "softmax" with vector support and no deprecation warning; the
        # max-subtraction keeps the exponentials from overflowing.
        e_y = tt.exp(y - tt.max(y, 0, keepdims=True))
        x = e_y / tt.sum(e_y, 0, keepdims=True)
        return floatX(x.T)

    def jacobian_det(self, y_):
        y = y_.T
        Km1 = y.shape[0] + 1
        sy = tt.sum(y, 0, keepdims=True)
        r = tt.concatenate([y + sy, tt.zeros(sy.shape)])
        # logsumexp keeps the normalization term numerically stable.
        sr = logsumexp(r, 0, keepdims=True)
        d = tt.log(Km1) + (Km1 * sy) - (Km1 * sr)
        return tt.sum(d, 0).T


stick_breaking = StickBreaking()
485
486
class Circular(ElemwiseTransform):
    """Transforms a linear space into a circular one."""

    name = "circular"

    def forward(self, x):
        return tt.as_tensor_variable(x)

    def forward_val(self, x, point=None):
        return x

    def backward(self, y):
        # arctan2(sin, cos) wraps y back into the principal interval.
        return tt.arctan2(tt.sin(y), tt.cos(y))

    def jacobian_det(self, x):
        # Wrapping has unit slope everywhere, so the log-Jacobian is zero.
        return tt.zeros(x.shape)


circular = Circular()
506
507
class CholeskyCovPacked(Transform):
    """Log-transform the diagonal entries of a packed (flattened
    lower-triangular) Cholesky factor, leaving the off-diagonal entries
    unchanged, so the diagonal is unconstrained on the real line."""

    name = "cholesky-cov-packed"

    def __init__(self, n):
        # Flat indices of the diagonal entries in row-major packed storage:
        # the k-th diagonal element sits at position 1 + 2 + ... + (k+1) - 1.
        self.diag_idxs = np.arange(1, n + 1).cumsum() - 1

    def backward(self, x):
        return tt.advanced_set_subtensor1(x, tt.exp(x[self.diag_idxs]), self.diag_idxs)

    def forward(self, y):
        return tt.advanced_set_subtensor1(y, tt.log(y[self.diag_idxs]), self.diag_idxs)

    def forward_val(self, y, point=None):
        # Fix: operate on a copy. The previous implementation assigned into
        # the caller's array, making forward_val the only transform with a
        # side effect on its input.
        y = np.array(y, copy=True)
        y[..., self.diag_idxs] = np.log(y[..., self.diag_idxs])
        return y

    def jacobian_det(self, y):
        # backward scales each diagonal entry by exp(y_i) and leaves the
        # off-diagonals untouched, so log|J| = sum of the diagonal y's.
        return tt.sum(y[self.diag_idxs])
526
527
class Chain(Transform):
    """Composition of transforms: `forward` applies them left to right,
    `backward` and `jacobian_det` right to left."""

    def __init__(self, transform_list):
        """
        Parameters
        ----------
        transform_list: list of Transform
            Transforms to compose, applied in list order on `forward`.
        """
        self.transform_list = transform_list
        self.name = "+".join([transf.name for transf in self.transform_list])

    def forward(self, x):
        y = x
        for transf in self.transform_list:
            y = transf.forward(y)
        return y

    def forward_val(self, x, point=None):
        y = x
        for transf in self.transform_list:
            # Fix: `point` was previously dropped here, so bounds-like
            # transforms inside a Chain could not resolve their bound
            # parameters from the provided point.
            y = transf.forward_val(y, point=point)
        return y

    def backward(self, y):
        x = y
        for transf in reversed(self.transform_list):
            x = transf.backward(x)
        return x

    def jacobian_det(self, y):
        y = tt.as_tensor_variable(y)
        det_list = []
        ndim0 = y.ndim
        # Accumulate each transform's jacobian_det while walking the chain
        # backwards, tracking the smallest determinant dimensionality.
        for transf in reversed(self.transform_list):
            det_ = transf.jacobian_det(y)
            det_list.append(det_)
            y = transf.backward(y)
            ndim0 = min(ndim0, det_.ndim)
        # match the shape of the smallest jacobian_det before summing
        det = 0.0
        for det_ in det_list:
            if det_.ndim > ndim0:
                det += det_.sum(axis=-1)
            else:
                det += det_
        return det
568