1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15import numpy as np
16import pytest
17import theano
18import theano.tensor as tt
19
20import pymc3 as pm
21import pymc3.distributions.transforms as tr
22
23from pymc3.tests.checks import close_to, close_to_logical
24from pymc3.tests.helpers import SeededTest
25from pymc3.tests.test_distributions import (
26    Circ,
27    MultiSimplex,
28    R,
29    Rminusbig,
30    Rplusbig,
31    Simplex,
32    SortedVector,
33    Unit,
34    UnitSortedVector,
35    Vector,
36)
37from pymc3.theanof import jacobian
38
# some transforms (stick breaking) require addition of a small slack in order to be
# numerically stable. The minimal addable slack for float32 is higher, thus we need to be less strict
41tol = 1e-7 if theano.config.floatX == "float64" else 1e-6
42
43
def check_transform(transform, domain, constructor=tt.dscalar, test=0):
    """Check that ``backward(forward(x)) == x`` over ``domain`` and that
    ``forward_val`` agrees with the compiled symbolic ``forward``."""
    var = constructor("x")
    var.tag.test_value = test
    # compiled forward pass, compared against transform.forward_val below
    forward_fn = theano.function([var], transform.forward(var))
    # the backward/forward round trip should be the identity
    roundtrip_fn = theano.function([var], transform.backward(transform.forward(var)))
    for value in domain.vals:
        close_to(value, roundtrip_fn(value), tol)
        close_to(transform.forward_val(value), forward_fn(value), tol)
54
55
def check_vector_transform(transform, domain):
    """Vector-valued variant of ``check_transform`` with a zero test vector."""
    zero_vec = np.array([0, 0])
    return check_transform(transform, domain, tt.dvector, test=zero_vec)
58
59
def get_values(transform, domain=R, constructor=tt.dscalar, test=0):
    """Apply ``transform.backward`` to every value in ``domain`` and return
    the results stacked into a numpy array."""
    var = constructor("x")
    var.tag.test_value = test
    backward_fn = theano.function([var], transform.backward(var))
    return np.array([backward_fn(value) for value in domain.vals])
65
66
def check_jacobian_det(
    transform, domain, constructor=tt.dscalar, test=0, make_comparable=None, elemwise=False
):
    """Compare the transform's analytic ``jacobian_det`` against a log-Jacobian
    determinant computed symbolically via automatic differentiation."""
    y = constructor("y")
    y.tag.test_value = test

    x = transform.backward(y)
    if make_comparable:
        # drop/reshape components so the Jacobian is square (e.g. simplex)
        x = make_comparable(x)

    if elemwise:
        # elementwise transform: Jacobian is diagonal
        log_jac = tt.log(tt.abs_(tt.diag(jacobian(x, [y]))))
    else:
        log_jac = tt.log(tt.nlinalg.det(jacobian(x, [y])))

    # ljd = log jacobian det
    actual_ljd = theano.function([y], log_jac)
    computed_ljd = theano.function(
        [y], tt.as_tensor_variable(transform.jacobian_det(y)), on_unused_input="ignore"
    )

    for point in domain.vals:
        close_to(actual_ljd(point), computed_ljd(point), tol)
91
92
def test_stickbreaking():
    """Stick-breaking round-trips on simplices; the `eps` argument warns."""
    # constructing with the deprecated `eps` argument must emit a warning
    with pytest.warns(
        DeprecationWarning, match="The argument `eps` is deprecated and will not be used."
    ):
        tr.StickBreaking(eps=1e-9)

    for simplex_size in (2, 4):
        check_vector_transform(tr.stick_breaking, Simplex(simplex_size))

    check_transform(
        tr.stick_breaking, MultiSimplex(3, 2), constructor=tt.dmatrix, test=np.zeros((2, 2))
    )
104
105
def test_stickbreaking_bounds():
    """Backward of stick-breaking lands strictly inside the simplex."""
    simplex_points = get_values(tr.stick_breaking, Vector(R, 2), tt.dvector, np.array([0, 0]))

    # each row must be a valid simplex point: sums to 1, entries in (0, 1)
    close_to(simplex_points.sum(axis=1), 1, tol)
    close_to_logical(simplex_points > 0, True, tol)
    close_to_logical(simplex_points < 1, True, tol)

    check_jacobian_det(
        tr.stick_breaking, Vector(R, 2), tt.dvector, np.array([0, 0]), lambda x: x[:-1]
    )
116
117
def test_stickbreaking_accuracy():
    """Round trip stays accurate far in the tail of the unconstrained space."""
    extreme_val = np.array([-30])
    var = tt.dvector("x")
    var.tag.test_value = extreme_val
    roundtrip_fn = theano.function(
        [var], tr.stick_breaking.forward(tr.stick_breaking.backward(var))
    )
    close_to(extreme_val, roundtrip_fn(extreme_val), tol)
124
125
def test_sum_to_1():
    """Round-trip and Jacobian checks for the sum-to-one transform."""
    for simplex_size in (2, 4):
        check_vector_transform(tr.sum_to_1, Simplex(simplex_size))

    check_jacobian_det(
        tr.sum_to_1, Vector(Unit, 2), tt.dvector, np.array([0, 0]), lambda x: x[:-1]
    )
131
132
def test_log():
    """The log transform round-trips and maps onto the positive reals."""
    check_transform(tr.log, Rplusbig)

    check_jacobian_det(tr.log, Rplusbig, elemwise=True)
    check_jacobian_det(tr.log, Vector(Rplusbig, 2), tt.dvector, [0, 0], elemwise=True)

    backward_vals = get_values(tr.log)
    close_to_logical(backward_vals > 0, True, tol)
141
142
def test_log_exp_m1():
    """The log-expm1 transform round-trips and maps onto the positive reals."""
    check_transform(tr.log_exp_m1, Rplusbig)

    check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), tt.dvector, [0, 0], elemwise=True)

    backward_vals = get_values(tr.log_exp_m1)
    close_to_logical(backward_vals > 0, True, tol)
151
152
def test_logodds():
    """The log-odds transform round-trips and maps onto the unit interval."""
    check_transform(tr.logodds, Unit)

    check_jacobian_det(tr.logodds, Unit, elemwise=True)
    check_jacobian_det(tr.logodds, Vector(Unit, 2), tt.dvector, [0.5, 0.5], elemwise=True)

    backward_vals = get_values(tr.logodds)
    close_to_logical(backward_vals > 0, True, tol)
    close_to_logical(backward_vals < 1, True, tol)
162
163
def test_lowerbound():
    """A zero lower bound maps onto the positive reals."""
    bounded = tr.lowerbound(0.0)
    check_transform(bounded, Rplusbig)

    check_jacobian_det(bounded, Rplusbig, elemwise=True)
    check_jacobian_det(bounded, Vector(Rplusbig, 2), tt.dvector, [0, 0], elemwise=True)

    backward_vals = get_values(bounded)
    close_to_logical(backward_vals > 0, True, tol)
173
174
def test_upperbound():
    """A zero upper bound maps onto the negative reals."""
    bounded = tr.upperbound(0.0)
    check_transform(bounded, Rminusbig)

    check_jacobian_det(bounded, Rminusbig, elemwise=True)
    check_jacobian_det(bounded, Vector(Rminusbig, 2), tt.dvector, [-1, -1], elemwise=True)

    backward_vals = get_values(bounded)
    close_to_logical(backward_vals < 0, True, tol)
184
185
def test_interval():
    """Interval transforms map onto (lower, upper) for several bound pairs."""
    for lower, upper in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]:
        # rescale the unit domain onto (lower, upper)
        domain = Unit * np.float64(upper - lower) + np.float64(lower)
        bounded = tr.interval(lower, upper)
        check_transform(bounded, domain)

        check_jacobian_det(bounded, domain, elemwise=True)

        backward_vals = get_values(bounded)
        close_to_logical(backward_vals > lower, True, tol)
        close_to_logical(backward_vals < upper, True, tol)
197
198
@pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit")
def test_interval_near_boundary():
    """A test value just inside the upper bound still gives the expected logp."""
    lower, upper = -1.0, 1e-7
    # the closest representable float below the upper bound
    start = np.nextafter(upper, lower)

    with pm.Model() as model:
        pm.Uniform("x", testval=start, lower=lower, upper=upper)

    log_prob = model.check_test_point()
    np.testing.assert_allclose(log_prob.values, np.array([-52.68]))
210
211
def test_circular():
    """The circular transform maps onto (-pi, pi)."""
    transform = tr.circular
    check_transform(transform, Circ)

    check_jacobian_det(transform, Circ)

    backward_vals = get_values(transform)
    close_to_logical(backward_vals > -np.pi, True, tol)
    close_to_logical(backward_vals < np.pi, True, tol)

    # a plain python scalar should come back as a theano constant
    assert isinstance(transform.forward(1), tt.TensorConstant)
223
224
def test_ordered():
    """Backward of the ordered transform yields non-decreasing vectors."""
    check_vector_transform(tr.ordered, SortedVector(6))

    check_jacobian_det(tr.ordered, Vector(R, 2), tt.dvector, np.array([0, 0]), elemwise=False)

    backward_vals = get_values(tr.ordered, Vector(R, 3), tt.dvector, np.zeros(3))
    close_to_logical(np.diff(backward_vals) >= 0, True, tol)
232
233
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
def test_chain():
    """A chained logodds + ordered transform behaves like a single transform."""
    chained = tr.Chain([tr.logodds, tr.ordered])
    check_vector_transform(chained, UnitSortedVector(3))

    check_jacobian_det(chained, Vector(R, 4), tt.dvector, np.zeros(4), elemwise=False)

    backward_vals = get_values(chained, Vector(R, 5), tt.dvector, np.zeros(5))
    close_to_logical(np.diff(backward_vals) >= 0, True, tol)
243
244
class TestElementWiseLogp(SeededTest):
    """Check that transformed log-probabilities include the Jacobian correction
    and have the expected elementwise shape across many distributions/transforms."""

    def build_model(self, distfam, params, shape, transform, testval=None):
        """Build a single-variable model with the given distribution, shape,
        and transform; returns the model."""
        if testval is not None:
            testval = pm.floatX(testval)
        with pm.Model() as m:
            distfam("x", shape=shape, transform=transform, testval=testval, **params)
        return m

    def check_transform_elementwise_logp(self, model):
        """For an elementwise transform: the transformed logp keeps the variable's
        ndim and equals logp(backward(y)) + jacobian_det(y) at a random point."""
        x0 = model.deterministics[0]
        x = model.free_RVs[0]
        assert x.ndim == x.logp_elemwiset.ndim

        pt = model.test_point
        array = np.random.randn(*pt[x.name].shape)
        pt[x.name] = array
        dist = x.distribution
        # logp of the untransformed variable at the back-transformed point
        logp_nojac = x0.distribution.logp(dist.transform_used.backward(array))
        jacob_det = dist.transform_used.jacobian_det(theano.shared(array))
        assert x.logp_elemwiset.ndim == jacob_det.ndim

        elementwiselogp = logp_nojac + jacob_det

        close_to(x.logp_elemwise(pt), elementwiselogp.eval(), tol)

    def check_vectortransform_elementwise_logp(self, model, vect_opt=0):
        """For a vector-valued transform: the transformed logp drops one dim
        relative to the variable and matches the manual logp + Jacobian sum."""
        x0 = model.deterministics[0]
        x = model.free_RVs[0]
        assert (x.ndim - 1) == x.logp_elemwiset.ndim

        pt = model.test_point
        array = np.random.randn(*pt[x.name].shape)
        pt[x.name] = array
        dist = x.distribution
        logp_nojac = x0.distribution.logp(dist.transform_used.backward(array))
        jacob_det = dist.transform_used.jacobian_det(theano.shared(array))
        assert x.logp_elemwiset.ndim == jacob_det.ndim

        if vect_opt == 0:
            # the original distribution is univariate
            elementwiselogp = logp_nojac.sum(axis=-1) + jacob_det
        else:
            elementwiselogp = logp_nojac + jacob_det
        # Hack to get relative tolerance
        a = x.logp_elemwise(pt)
        b = elementwiselogp.eval()
        close_to(a, b, np.abs(0.5 * (a + b) * tol))

    @pytest.mark.parametrize(
        "sd,shape",
        [
            (2.5, 2),
            (5.0, (2, 3)),
            (np.ones(3) * 10.0, (4, 3)),
        ],
    )
    def test_half_normal(self, sd, shape):
        """HalfNormal with a log transform (elementwise)."""
        model = self.build_model(pm.HalfNormal, {"sd": sd}, shape=shape, transform=tr.log)
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize("lam,shape", [(2.5, 2), (5.0, (2, 3)), (np.ones(3), (4, 3))])
    def test_exponential(self, lam, shape):
        """Exponential with a log transform (elementwise)."""
        model = self.build_model(pm.Exponential, {"lam": lam}, shape=shape, transform=tr.log)
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
        "a,b,shape",
        [
            (1.0, 1.0, 2),
            (0.5, 0.5, (2, 3)),
            (np.ones(3), np.ones(3), (4, 3)),
        ],
    )
    def test_beta(self, a, b, shape):
        """Beta with a log-odds transform (elementwise)."""
        model = self.build_model(
            pm.Beta, {"alpha": a, "beta": b}, shape=shape, transform=tr.logodds
        )
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
        "lower,upper,shape",
        [
            (0.0, 1.0, 2),
            (0.5, 5.5, (2, 3)),
            (pm.floatX(np.zeros(3)), pm.floatX(np.ones(3)), (4, 3)),
        ],
    )
    def test_uniform(self, lower, upper, shape):
        """Uniform with an interval transform (elementwise)."""
        interval = tr.Interval(lower, upper)
        model = self.build_model(
            pm.Uniform, {"lower": lower, "upper": upper}, shape=shape, transform=interval
        )
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
        "mu,kappa,shape", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))]
    )
    def test_vonmises(self, mu, kappa, shape):
        """VonMises with the circular transform (elementwise)."""
        model = self.build_model(
            pm.VonMises, {"mu": mu, "kappa": kappa}, shape=shape, transform=tr.circular
        )
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
        "a,shape", [(np.ones(2), 2), (np.ones((2, 3)) * 0.5, (2, 3)), (np.ones(3), (4, 3))]
    )
    def test_dirichlet(self, a, shape):
        """Dirichlet with stick-breaking (vector transform, multivariate dist)."""
        model = self.build_model(pm.Dirichlet, {"a": a}, shape=shape, transform=tr.stick_breaking)
        self.check_vectortransform_elementwise_logp(model, vect_opt=1)

    def test_normal_ordered(self):
        """Normal with the ordered transform (vector transform, univariate dist)."""
        model = self.build_model(
            pm.Normal,
            {"mu": 0.0, "sd": 1.0},
            shape=3,
            testval=np.asarray([-1.0, 1.0, 4.0]),
            transform=tr.ordered,
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "sd,shape",
        [
            (2.5, (2,)),
            (np.ones(3), (4, 3)),
        ],
    )
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_half_normal_ordered(self, sd, shape):
        """HalfNormal with a chained log + ordered transform."""
        testval = np.sort(np.abs(np.random.randn(*shape)))
        model = self.build_model(
            pm.HalfNormal,
            {"sd": sd},
            shape=shape,
            testval=testval,
            transform=tr.Chain([tr.log, tr.ordered]),
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize("lam,shape", [(2.5, (2,)), (np.ones(3), (4, 3))])
    def test_exponential_ordered(self, lam, shape):
        """Exponential with a chained log + ordered transform."""
        testval = np.sort(np.abs(np.random.randn(*shape)))
        model = self.build_model(
            pm.Exponential,
            {"lam": lam},
            shape=shape,
            testval=testval,
            transform=tr.Chain([tr.log, tr.ordered]),
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "a,b,shape",
        [
            (1.0, 1.0, (2,)),
            (np.ones(3), np.ones(3), (4, 3)),
        ],
    )
    def test_beta_ordered(self, a, b, shape):
        """Beta with a chained logodds + ordered transform."""
        testval = np.sort(np.abs(np.random.rand(*shape)))
        model = self.build_model(
            pm.Beta,
            {"alpha": a, "beta": b},
            shape=shape,
            testval=testval,
            transform=tr.Chain([tr.logodds, tr.ordered]),
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "lower,upper,shape",
        [(0.0, 1.0, (2,)), (pm.floatX(np.zeros(3)), pm.floatX(np.ones(3)), (4, 3))],
    )
    def test_uniform_ordered(self, lower, upper, shape):
        """Uniform with a chained interval + ordered transform."""
        interval = tr.Interval(lower, upper)
        testval = np.sort(np.abs(np.random.rand(*shape)))
        model = self.build_model(
            pm.Uniform,
            {"lower": lower, "upper": upper},
            shape=shape,
            testval=testval,
            transform=tr.Chain([interval, tr.ordered]),
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "mu,kappa,shape", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))]
    )
    def test_vonmises_ordered(self, mu, kappa, shape):
        """VonMises with a chained circular + ordered transform."""
        testval = np.sort(np.abs(np.random.rand(*shape)))
        model = self.build_model(
            pm.VonMises,
            {"mu": mu, "kappa": kappa},
            shape=shape,
            testval=testval,
            transform=tr.Chain([tr.circular, tr.ordered]),
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "lower,upper,shape,transform",
        [
            (0.0, 1.0, (2,), tr.stick_breaking),
            (0.5, 5.5, (2, 3), tr.stick_breaking),
            (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
        ],
    )
    def test_uniform_other(self, lower, upper, shape, transform):
        """Uniform with simplex-type transforms (stick-breaking / sum-to-1 chain)."""
        testval = np.ones(shape) / shape[-1]
        model = self.build_model(
            pm.Uniform,
            {"lower": lower, "upper": upper},
            shape=shape,
            testval=testval,
            transform=transform,
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=0)

    @pytest.mark.parametrize(
        "mu,cov,shape",
        [
            (np.zeros(2), np.diag(np.ones(2)), (2,)),
            (np.zeros(3), np.diag(np.ones(3)), (4, 3)),
        ],
    )
    def test_mvnormal_ordered(self, mu, cov, shape):
        """MvNormal with the ordered transform (multivariate dist)."""
        testval = np.sort(np.random.randn(*shape))
        model = self.build_model(
            pm.MvNormal, {"mu": mu, "cov": cov}, shape=shape, testval=testval, transform=tr.ordered
        )
        self.check_vectortransform_elementwise_logp(model, vect_opt=1)
476