1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15import collections
16import os
17import shutil
18
19import numpy as np
20import numpy.testing as npt
21import pytest
22import theano
23
24from pymc3.backends import base
25from pymc3.tests import models
26
27
class ModelBackendSetupTestCase:
    """Set up a backend trace.

    Provides the attributes
    - test_point
    - model
    - strace
    - draws

    Children must define
    - backend
    - name
    - shape

    Children may define
    - sampler_vars
    """

    def setup_method(self):
        """Build the model, create a trace with ``self.backend`` and set it up."""
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        with self.model:
            self.strace = self.backend(self.name)
        self.draws, self.chain = 3, 0
        # Children may omit ``sampler_vars``; normalize it to ``None`` so the
        # tests can branch on it uniformly.
        if not hasattr(self, "sampler_vars"):
            self.sampler_vars = None
        if self.sampler_vars is not None:
            assert self.strace.supports_sampler_stats
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
        else:
            self.strace.setup(self.draws, self.chain)

    def test_append_invalid(self):
        """``setup`` must reject sampler stats that do not match the trace."""
        if self.sampler_vars is not None:
            # A trace configured with sampler stats may not be set up without them.
            with pytest.raises(ValueError):
                self.strace.setup(self.draws, self.chain)
            # Build the mismatched spec outside ``pytest.raises`` so that only
            # the ``setup`` call itself is expected to raise.
            extra_vars = self.sampler_vars + [{"a": bool}]
            with pytest.raises(ValueError):
                self.strace.setup(self.draws, self.chain, extra_vars)
        else:
            with pytest.raises((ValueError, TypeError)):
                self.strace.setup(self.draws, self.chain, [{"a": bool}])

    def test_append(self):
        """Re-running ``setup`` with the original arguments leaves the trace empty."""
        if self.sampler_vars is None:
            self.strace.setup(self.draws, self.chain)
        else:
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
        assert len(self.strace) == 0

    def test_double_close(self):
        """Closing a trace twice must not raise."""
        self.strace.close()
        self.strace.close()

    def teardown_method(self):
        # Only remove an on-disk artifact when the trace was given a name.
        if self.name is not None:
            remove_file_or_directory(self.name)
85
86
class StatsTestCase:
    """Test for init and setup of backends.

    Provides the attributes
    - test_point
    - model
    - draws

    Children must define
    - backend
    - name
    - shape
    """

    def setup_method(self):
        """Build the model; each test creates its own trace."""
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        self.draws, self.chain = 3, 0

    def test_bad_dtype(self):
        """Sampler stats declaring one name with conflicting dtypes are rejected.

        A consistent spec is then accepted. Backends that report
        ``supports_sampler_stats`` must expose the stat name. For the others,
        setting up again with stats must raise.
        """
        # "a" is declared as float64 in one sampler and bool in the other.
        bad_vars = [{"a": np.float64}, {"a": bool}]
        good_vars = [{"a": np.float64}, {"a": np.float64}]
        with self.model:
            strace = self.backend(self.name)
        with pytest.raises((ValueError, TypeError)):
            strace.setup(self.draws, self.chain, bad_vars)
        strace.setup(self.draws, self.chain, good_vars)
        if strace.supports_sampler_stats:
            assert strace.stat_names == {"a"}
        else:
            with pytest.raises((ValueError, TypeError)):
                strace.setup(self.draws, self.chain, good_vars)

    def teardown_method(self):
        # Only remove an on-disk artifact when the trace was given a name.
        if self.name is not None:
            remove_file_or_directory(self.name)
122
123
class ModelBackendSampledTestCase:
    """Setup and sample a backend trace.

    Provides the attributes
    - test_point
    - model
    - mtrace (MultiTrace object)
    - draws
    - expected
        Expected values mapped to chain number and variable name.
    - stat_dtypes

    Children must define
    - backend
    - name
    - shape

    Children may define
    - sampler_vars
    - write_partial_chain
    """

    @classmethod
    def setup_class(cls):
        """Record two deterministic chains into ``cls.backend`` traces."""
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

        # With ``write_partial_chain`` the first unobserved RV is dropped, so
        # the traces record only a subset of the model's variables.
        if getattr(cls, "write_partial_chain", False) is True:
            cls.chain_vars = cls.model.unobserved_RVs[1:]
        else:
            cls.chain_vars = cls.model.unobserved_RVs

        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
            strace1 = cls.backend(cls.name, vars=cls.chain_vars)

        if not hasattr(cls, "sampler_vars"):
            cls.sampler_vars = None

        cls.draws = 5
        if cls.sampler_vars is not None:
            strace0.setup(cls.draws, chain=0, sampler_vars=cls.sampler_vars)
            strace1.setup(cls.draws, chain=1, sampler_vars=cls.sampler_vars)
        else:
            strace0.setup(cls.draws, chain=0)
            strace1.setup(cls.draws, chain=1)

        varnames = list(cls.test_point.keys())
        shapes = {varname: value.shape for varname, value in cls.test_point.items()}
        dtypes = {varname: value.dtype for varname, value in cls.test_point.items()}

        # Deterministic draws: chain 0 counts 0, 1, 2, ... and chain 1 holds
        # the same values times 100, so the two chains are distinguishable.
        cls.expected = {0: {}, 1: {}}
        for varname in varnames:
            mcmc_shape = (cls.draws,) + shapes[varname]
            values = np.arange(cls.draws * np.prod(shapes[varname]), dtype=dtypes[varname])
            cls.expected[0][varname] = values.reshape(mcmc_shape)
            cls.expected[1][varname] = values.reshape(mcmc_shape) * 100

        if cls.sampler_vars is not None:
            # Both chains share the same dict per sampler, so identical stats
            # are recorded for each chain.
            cls.expected_stats = {0: [], 1: []}
            for stat_spec in cls.sampler_vars:
                stats = {}
                cls.expected_stats[0].append(stats)
                cls.expected_stats[1].append(stats)
                for key, dtype in stat_spec.items():
                    # bool stats are all False; other dtypes count up per draw.
                    if dtype == bool:
                        stats[key] = np.zeros(cls.draws, dtype=dtype)
                    else:
                        stats[key] = np.arange(cls.draws, dtype=dtype)

        for idx in range(cls.draws):
            point0 = {varname: cls.expected[0][varname][idx, ...] for varname in varnames}
            point1 = {varname: cls.expected[1][varname][idx, ...] for varname in varnames}
            if cls.sampler_vars is not None:
                stats0 = [
                    {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[0]
                ]
                stats1 = [
                    {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[1]
                ]
                strace0.record(point=point0, sampler_stats=stats0)
                strace1.record(point=point1, sampler_stats=stats1)
            else:
                strace0.record(point=point0)
                strace1.record(point=point1)
        strace0.close()
        strace1.close()
        cls.mtrace = base.MultiTrace([strace0, strace1])

        # Flatten the per-sampler specs: a later sampler's dtype wins for a
        # repeated stat name, and ``stats_counts`` records how often each occurs.
        cls.stat_dtypes = {}
        cls.stats_counts = collections.Counter()
        for stats in cls.sampler_vars or []:
            cls.stat_dtypes.update(stats)
            cls.stats_counts.update(stats.keys())

    @classmethod
    def teardown_class(cls):
        # Only remove an on-disk artifact when the traces were given a name.
        if cls.name is not None:
            remove_file_or_directory(cls.name)

    def test_varnames_nonempty(self):
        # Make sure the test_point has variables names because many
        # tests rely on looping through these and would pass silently
        # if the loop is never entered.
        assert list(self.test_point.keys())

    def test_stat_names(self):
        """The trace exposes the union of all samplers' stat names."""
        names = set()
        for stat_spec in self.sampler_vars or []:
            names.update(stat_spec.keys())
        assert self.mtrace.stat_names == names
234
235
class SamplingTestCase(ModelBackendSetupTestCase):
    """Test backend sampling.

    Children must define
    - backend
    - name
    - shape
    """

    def record_point(self, val):
        """Record one draw in which every variable is filled with ``val``.

        When sampler stats are configured, each stat is ``val`` converted by
        calling its declared dtype.
        """
        point = {varname: np.tile(val, value.shape) for varname, value in self.test_point.items()}
        if self.sampler_vars is not None:
            stats = [
                {key: dtype(val) for key, dtype in stat_spec.items()}
                for stat_spec in self.sampler_vars
            ]
            self.strace.record(point=point, sampler_stats=stats)
        else:
            self.strace.record(point=point)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_standard_close(self):
        """Filling every draw and closing keeps the first and last values."""
        for idx in range(self.draws):
            self.record_point(idx)
        self.strace.close()

        for varname in self.test_point.keys():
            npt.assert_equal(
                self.strace.get_values(varname)[0, ...], np.zeros(self.strace.var_shapes[varname])
            )
            last_idx = self.draws - 1
            npt.assert_equal(
                self.strace.get_values(varname)[last_idx, ...],
                np.tile(last_idx, self.strace.var_shapes[varname]),
            )
        if self.sampler_vars is not None:
            for varname in self.strace.stat_names:
                vals = self.strace.get_sampler_stats(varname)
                assert vals.shape[0] == self.draws

    def test_missing_stats(self):
        """A trace set up with sampler stats rejects a draw recorded without them."""
        if self.sampler_vars is not None:
            with pytest.raises(ValueError):
                self.strace.record(point=self.test_point)

    def test_clean_interrupt(self):
        """Closing after one of ``draws`` records truncates the trace to one draw."""
        self.record_point(0)
        self.strace.close()
        for varname in self.test_point.keys():
            assert self.strace.get_values(varname).shape[0] == 1
        for statname in self.strace.stat_names:
            assert self.strace.get_sampler_stats(statname).shape[0] == 1
285
286
class SelectionTestCase(ModelBackendSampledTestCase):
    """Test backend selection.

    Children must define
    - backend
    - name
    - shape
    """

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_default(self):
        """By default ``get_values`` concatenates all chains in order."""
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname] for chain in [0, 1]])
            result = self.mtrace.get_values(varname)
            npt.assert_equal(result, expected)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_nocombine_burn_keyword(self):
        burn = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][burn:], self.expected[1][varname][burn:]]
            result = self.mtrace.get_values(varname, burn=burn, combine=False)
            npt.assert_equal(result, expected)

    def test_len(self):
        """``len`` of a MultiTrace is the number of draws per chain."""
        assert len(self.mtrace) == self.draws

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtypes(self):
        """Stored values and sampler stats keep their original dtypes."""
        for varname in self.test_point.keys():
            assert (
                self.expected[0][varname].dtype == self.mtrace.get_values(varname, chains=0).dtype
            )

        for statname in self.mtrace.stat_names:
            assert (
                self.stat_dtypes[statname]
                == self.mtrace.get_sampler_stats(statname, chains=0).dtype
            )

    def test_get_values_nocombine_thin_keyword(self):
        thin = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][::thin], self.expected[1][varname][::thin]]
            result = self.mtrace.get_values(varname, thin=thin, combine=False)
            npt.assert_equal(result, expected)

    def test_get_point(self):
        """Without a ``chain`` argument, ``point`` is expected to read chain 1."""
        idx = 2
        result = self.mtrace.point(idx)
        for varname in self.test_point.keys():
            expected = self.expected[1][varname][idx]
            npt.assert_equal(result[varname], expected)

    def test_get_slice(self):
        expected = []
        for chain in [0, 1]:
            expected.append(
                {varname: self.expected[chain][varname][2:] for varname in self.mtrace.varnames}
            )
        result = self.mtrace[2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(
                    result.get_values(varname, chains=[chain]), expected[chain][varname]
                )

    def test_get_slice_step(self):
        result = self.mtrace[:]
        assert len(result) == self.draws

        result = self.mtrace[::2]
        assert len(result) == self.draws // 2

    def test_get_slice_neg_step(self):
        # Report an opted-out backend as skipped rather than silently passing.
        if hasattr(self, "skip_test_get_slice_neg_step"):
            pytest.skip("skip_test_get_slice_neg_step is set for this test case")

        result = self.mtrace[::-1]
        assert len(result) == self.draws

        result = self.mtrace[::-2]
        assert len(result) == self.draws // 2

    def test_get_neg_slice(self):
        expected = []
        for chain in [0, 1]:
            expected.append(
                {varname: self.expected[chain][varname][-2:] for varname in self.mtrace.varnames}
            )
        result = self.mtrace[-2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(
                    result.get_values(varname, chains=[chain]), expected[chain][varname]
                )

    def test_get_values_one_chain(self):
        for varname in self.test_point.keys():
            expected = self.expected[0][varname]
            result = self.mtrace.get_values(varname, chains=[0])
            npt.assert_equal(result, expected)

    def test_get_values_nocombine_chains_reversed(self):
        for varname in self.test_point.keys():
            expected = [self.expected[1][varname], self.expected[0][varname]]
            result = self.mtrace.get_values(varname, chains=[1, 0], combine=False)
            npt.assert_equal(result, expected)

    def test_nchains(self):
        """Two straces were combined in ``setup_class``."""
        # The original bare comparison could never fail; assert it.
        assert self.mtrace.nchains == 2

    def test_get_values_one_chain_int_arg(self):
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace.get_values(varname, chains=[0]),
                self.mtrace.get_values(varname, chains=0),
            )

    def test_get_values_combine(self):
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True)
            npt.assert_equal(result, expected)

    def test_get_values_combine_burn_arg(self):
        burn = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][burn:] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, burn=burn)
            npt.assert_equal(result, expected)

    def test_get_values_combine_thin_arg(self):
        thin = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][::thin] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, thin=thin)
            npt.assert_equal(result, expected)

    def test_getitem_equivalence(self):
        """``mtrace[name, slice]`` matches ``get_values`` with burn/thin."""
        mtrace = self.mtrace
        for varname in self.test_point.keys():
            npt.assert_equal(mtrace[varname], mtrace.get_values(varname, combine=True))
            npt.assert_equal(mtrace[varname, 2:], mtrace.get_values(varname, burn=2, combine=True))
            npt.assert_equal(
                mtrace[varname, 2::2], mtrace.get_values(varname, burn=2, thin=2, combine=True)
            )

    def test_selection_method_equivalence(self):
        """``get_values``, item access and attribute access agree."""
        varname = self.mtrace.varnames[0]
        mtrace = self.mtrace
        npt.assert_equal(mtrace.get_values(varname), mtrace[varname])
        npt.assert_equal(mtrace[varname], mtrace.__getattr__(varname))
440
441
class DumpLoadTestCase(ModelBackendSampledTestCase):
    """Test equality of a dumped and loaded trace with original.

    Children must define
    - backend
    - load_func
        Function to load dumped backend
    - name
    - shape
    """

    @classmethod
    def setup_class(cls):
        """Sample the traces, then reload them from ``cls.name`` via ``load_func``."""
        # Sampling is inside the ``try`` as well, so a failure in either step
        # removes whatever was already written before re-raising.
        try:
            super().setup_class()
            with cls.model:
                cls.dumped = cls.load_func(cls.name)
        except BaseException:
            if cls.name is not None:
                remove_file_or_directory(cls.name)
            raise

    @classmethod
    def teardown_class(cls):
        # Match the base-class teardown: nothing to remove without a name.
        if cls.name is not None:
            remove_file_or_directory(cls.name)

    def test_nchains(self):
        assert self.mtrace.nchains == self.dumped.nchains

    def test_varnames(self):
        trace_names = list(sorted(self.mtrace.varnames))
        dumped_names = list(sorted(self.dumped.varnames))
        assert trace_names == dumped_names

    def test_values(self):
        """Every recorded variable round-trips unchanged for every chain."""
        trace = self.mtrace
        dumped = self.dumped
        for chain in trace.chains:
            for varname in self.chain_vars:
                data = trace.get_values(varname, chains=[chain])
                dumped_data = dumped.get_values(varname, chains=[chain])
                npt.assert_equal(data, dumped_data)
483
484
class BackendEqualityTestCase(ModelBackendSampledTestCase):
    """Test equality of attributes from two backends.

    Children must define
    - backend0
    - backend1
    - name0
    - name1
    - shape
    """

    @classmethod
    def setup_class(cls):
        """Sample the same synthetic chains into both backends."""
        # The base setup reads ``cls.backend``/``cls.name``, so run it once per
        # backend and keep each resulting MultiTrace.
        cls.backend = cls.backend0
        cls.name = cls.name0
        super().setup_class()
        cls.mtrace0 = cls.mtrace

        cls.backend = cls.backend1
        cls.name = cls.name1
        super().setup_class()
        cls.mtrace1 = cls.mtrace

    @classmethod
    def teardown_class(cls):
        for name in [cls.name0, cls.name1]:
            if name is not None:
                remove_file_or_directory(name)

    def test_chain_length(self):
        assert self.mtrace0.nchains == self.mtrace1.nchains
        assert len(self.mtrace0) == len(self.mtrace1)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtype(self):
        for varname in self.test_point.keys():
            assert (
                self.mtrace0.get_values(varname, chains=0).dtype
                == self.mtrace1.get_values(varname, chains=0).dtype
            )

    def test_number_of_draws(self):
        for varname in self.test_point.keys():
            values0 = self.mtrace0.get_values(varname, combine=False, squeeze=False)
            values1 = self.mtrace1.get_values(varname, combine=False, squeeze=False)
            assert values0[0].shape[0] == self.draws
            assert values1[0].shape[0] == self.draws

    def test_get_item(self):
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace0[varname], self.mtrace1[varname])

    def test_get_values(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf),
                    self.mtrace1.get_values(varname, combine=cf),
                )

    def test_get_values_no_squeeze(self):
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace0.get_values(varname, combine=False, squeeze=False),
                self.mtrace1.get_values(varname, combine=False, squeeze=False),
            )

    def test_get_values_combine_and_no_squeeze(self):
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace0.get_values(varname, combine=True, squeeze=False),
                self.mtrace1.get_values(varname, combine=True, squeeze=False),
            )

    def test_get_values_with_burn(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=3),
                    self.mtrace1.get_values(varname, combine=cf, burn=3),
                )
                # Burn to one value.
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=self.draws - 1),
                    self.mtrace1.get_values(varname, combine=cf, burn=self.draws - 1),
                )

    def test_get_values_with_thin(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, thin=2),
                    self.mtrace1.get_values(varname, combine=cf, thin=2),
                )

    def test_get_values_with_burn_and_thin(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=2, thin=2),
                    self.mtrace1.get_values(varname, combine=cf, burn=2, thin=2),
                )

    def test_get_values_with_chains_arg(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, chains=[0], combine=cf),
                    self.mtrace1.get_values(varname, chains=[0], combine=cf),
                )

    def test_get_point(self):
        """Both backends return the same values for the last draw."""
        last_idx = self.draws - 1
        npoint, spoint = self.mtrace0[last_idx], self.mtrace1[last_idx]
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])

    def test_point_with_chain_arg(self):
        """Both backends agree on the last draw of chain 0."""
        last_idx = self.draws - 1
        npoint = self.mtrace0.point(last_idx, chain=0)
        spoint = self.mtrace1.point(last_idx, chain=0)
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])
606
607
def remove_file_or_directory(name):
    """Delete ``name``, whether it is a single file or a directory tree.

    Missing paths are ignored, and ``name=None`` is a no-op, so callers need
    not guard against it themselves.
    """
    if name is None:
        # ``os.remove(None)`` would raise TypeError, which the OSError
        # fallback below does not catch.
        return
    try:
        os.remove(name)
    except OSError:
        # ``name`` is a directory or does not exist: remove it recursively,
        # ignoring errors so cleanup never fails on a missing path.
        shutil.rmtree(name, ignore_errors=True)
613