#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import collections
import os
import shutil

import numpy as np
import numpy.testing as npt
import pytest
import theano

from pymc3.backends import base
from pymc3.tests import models


class ModelBackendSetupTestCase:
    """Set up a backend trace.

    Provides the attributes
        - test_point
        - model
        - strace
        - draws

    Children must define
        - backend
        - name
        - shape

    Children may define
        - sampler_vars
    """

    def setup_method(self):
        """Create the model, instantiate the backend and set it up for chain 0."""
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        # The backend is instantiated inside the model context.
        with self.model:
            self.strace = self.backend(self.name)
        self.draws, self.chain = 3, 0
        # `sampler_vars` is optional on children; normalise its absence to None.
        if not hasattr(self, "sampler_vars"):
            self.sampler_vars = None
        if self.sampler_vars is not None:
            assert self.strace.supports_sampler_stats
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
        else:
            self.strace.setup(self.draws, self.chain)

    def test_append_invalid(self):
        """A second ``setup`` with different sampler vars must be rejected."""
        if self.sampler_vars is not None:
            with pytest.raises(ValueError):
                self.strace.setup(self.draws, self.chain)
            # Build the argument outside ``pytest.raises`` so that only the
            # ``setup`` call itself is expected to raise.
            extended_vars = self.sampler_vars + [{"a": bool}]
            with pytest.raises(ValueError):
                self.strace.setup(self.draws, self.chain, extended_vars)
        else:
            with pytest.raises((ValueError, TypeError)):
                self.strace.setup(self.draws, self.chain, [{"a": bool}])

    def test_append(self):
        """A second ``setup`` with the same arguments leaves the trace empty."""
        if self.sampler_vars is None:
            self.strace.setup(self.draws, self.chain)
        else:
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
        assert len(self.strace) == 0

    def test_double_close(self):
        """Closing the trace twice must not raise."""
        self.strace.close()
        self.strace.close()

    def teardown_method(self):
        """Remove whatever the backend wrote under ``name``, if a name was given."""
        if self.name is not None:
            remove_file_or_directory(self.name)


class StatsTestCase:
    """Test backend setup with sampler statistics dtypes.

    Provides the attributes
        - test_point
        - model
        - draws

    Children must define
        - backend
        - name
        - shape
    """

    def setup_method(self):
        """Create the model and the draw/chain counters used by the tests."""
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        self.draws, self.chain = 3, 0

    def test_bad_dtype(self):
        """Conflicting dtypes for one stat name are rejected; matching ones are not."""
        bad_vars = [{"a": np.float64}, {"a": bool}]
        good_vars = [{"a": np.float64}, {"a": np.float64}]
        with self.model:
            strace = self.backend(self.name)
        with pytest.raises((ValueError, TypeError)):
            strace.setup(self.draws, self.chain, bad_vars)
        # The valid setup is only attempted unguarded when the backend claims
        # stats support; otherwise the call itself is the expected failure.
        # (Previously it ran unconditionally, which made the ``else`` branch
        # unreachable for a backend that rejects sampler vars.)
        if strace.supports_sampler_stats:
            strace.setup(self.draws, self.chain, good_vars)
            assert strace.stat_names == {"a"}
        else:
            with pytest.raises((ValueError, TypeError)):
                strace.setup(self.draws, self.chain, good_vars)

    def teardown_method(self):
        """Remove whatever the backend wrote under ``name``, if a name was given."""
        if self.name is not None:
            remove_file_or_directory(self.name)


class ModelBackendSampledTestCase:
    """Setup and sample a backend trace.

    Provides the attributes
        - test_point
        - model
        - mtrace (MultiTrace object)
        - draws
        - expected
            Expected values mapped to chain number and variable name.
        - stat_dtypes

    Children must define
        - backend
        - name
        - shape

    Children may define
        - sampler_vars
        - write_partial_chain
    """

    @classmethod
    def setup_class(cls):
        """Record ``draws`` synthetic points into two chains and wrap them in a MultiTrace."""
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

        # With ``write_partial_chain`` set to True the first unobserved RV is
        # left out of the variables handed to the backend.
        if getattr(cls, "write_partial_chain", False) is True:
            cls.chain_vars = cls.model.unobserved_RVs[1:]
        else:
            cls.chain_vars = cls.model.unobserved_RVs

        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
            strace1 = cls.backend(cls.name, vars=cls.chain_vars)

        if not hasattr(cls, "sampler_vars"):
            cls.sampler_vars = None

        cls.draws = 5
        if cls.sampler_vars is not None:
            strace0.setup(cls.draws, chain=0, sampler_vars=cls.sampler_vars)
            strace1.setup(cls.draws, chain=1, sampler_vars=cls.sampler_vars)
        else:
            strace0.setup(cls.draws, chain=0)
            strace1.setup(cls.draws, chain=1)

        varnames = list(cls.test_point.keys())
        shapes = {varname: value.shape for varname, value in cls.test_point.items()}
        dtypes = {varname: value.dtype for varname, value in cls.test_point.items()}

        # Chain 0 holds 0..N-1 reshaped to (draws, *shape); chain 1 holds the
        # same values multiplied by 100 so the chains are distinguishable.
        cls.expected = {0: {}, 1: {}}
        for varname in varnames:
            mcmc_shape = (cls.draws,) + shapes[varname]
            values = np.arange(cls.draws * np.prod(shapes[varname]), dtype=dtypes[varname])
            cls.expected[0][varname] = values.reshape(mcmc_shape)
            cls.expected[1][varname] = values.reshape(mcmc_shape) * 100

        if cls.sampler_vars is not None:
            cls.expected_stats = {0: [], 1: []}
            for var_dtypes in cls.sampler_vars:
                # Both chains share the same expected-stats dict object.
                stats = {}
                cls.expected_stats[0].append(stats)
                cls.expected_stats[1].append(stats)
                for key, dtype in var_dtypes.items():
                    if dtype == bool:
                        stats[key] = np.zeros(cls.draws, dtype=dtype)
                    else:
                        stats[key] = np.arange(cls.draws, dtype=dtype)

        for idx in range(cls.draws):
            point0 = {varname: cls.expected[0][varname][idx, ...] for varname in varnames}
            point1 = {varname: cls.expected[1][varname][idx, ...] for varname in varnames}
            if cls.sampler_vars is not None:
                stats0 = [
                    {key: val[idx] for key, val in stats.items()}
                    for stats in cls.expected_stats[0]
                ]
                stats1 = [
                    {key: val[idx] for key, val in stats.items()}
                    for stats in cls.expected_stats[1]
                ]
                strace0.record(point=point0, sampler_stats=stats0)
                strace1.record(point=point1, sampler_stats=stats1)
            else:
                strace0.record(point=point0)
                strace1.record(point=point1)
        strace0.close()
        strace1.close()
        cls.mtrace = base.MultiTrace([strace0, strace1])

        # Flatten the per-sampler stat declarations into one name -> dtype map
        # and count how many samplers declare each stat name.
        cls.stat_dtypes = {}
        cls.stats_counts = collections.Counter()
        for stats in cls.sampler_vars or []:
            cls.stat_dtypes.update(stats)
            cls.stats_counts.update(stats.keys())

    @classmethod
    def teardown_class(cls):
        """Remove whatever the backend wrote under ``name``, if a name was given."""
        if cls.name is not None:
            remove_file_or_directory(cls.name)

    def test_varnames_nonempty(self):
        """Guard against an empty test point."""
        # Make sure the test_point has variables names because many
        # tests rely on looping through these and would pass silently
        # if the loop is never entered.
        assert list(self.test_point.keys())

    def test_stat_names(self):
        """The MultiTrace reports exactly the stat names declared in ``sampler_vars``."""
        names = set()
        for var_dtypes in self.sampler_vars or []:
            names.update(var_dtypes.keys())
        assert self.mtrace.stat_names == names


class SamplingTestCase(ModelBackendSetupTestCase):
    """Test backend sampling.

    Children must define
        - backend
        - name
        - shape
    """

    def record_point(self, val):
        """Record one point in which every variable (and every stat) equals ``val``."""
        point = {varname: np.tile(val, value.shape) for varname, value in self.test_point.items()}
        if self.sampler_vars is not None:
            stats = [
                {key: dtype(val) for key, dtype in var_dtypes.items()}
                for var_dtypes in self.sampler_vars
            ]
            self.strace.record(point=point, sampler_stats=stats)
        else:
            self.strace.record(point=point)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_standard_close(self):
        """After recording all draws, the first and last stored values are as written."""
        for idx in range(self.draws):
            self.record_point(idx)
        self.strace.close()

        for varname in self.test_point.keys():
            npt.assert_equal(
                self.strace.get_values(varname)[0, ...], np.zeros(self.strace.var_shapes[varname])
            )
            last_idx = self.draws - 1
            npt.assert_equal(
                self.strace.get_values(varname)[last_idx, ...],
                np.tile(last_idx, self.strace.var_shapes[varname]),
            )
        if self.sampler_vars:
            for varname in self.strace.stat_names:
                vals = self.strace.get_sampler_stats(varname)
                assert vals.shape[0] == self.draws

    def test_missing_stats(self):
        """Recording without stats is an error when sampler vars were declared."""
        if self.sampler_vars is not None:
            with pytest.raises(ValueError):
                self.strace.record(point=self.test_point)

    def test_clean_interrupt(self):
        """Closing after a single recorded draw exposes exactly one sample."""
        self.record_point(0)
        self.strace.close()
        for varname in self.test_point.keys():
            assert self.strace.get_values(varname).shape[0] == 1
        for statname in self.strace.stat_names:
            assert self.strace.get_sampler_stats(statname).shape[0] == 1


class SelectionTestCase(ModelBackendSampledTestCase):
    """Test backend selection.

    Children must define
        - backend
        - name
        - shape
    """

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_default(self):
        """Default ``get_values`` returns both chains concatenated."""
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname] for chain in [0, 1]])
            result = self.mtrace.get_values(varname)
            npt.assert_equal(result, expected)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_nocombine_burn_keyword(self):
        """``burn`` drops leading draws from each chain when not combining."""
        burn = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][burn:], self.expected[1][varname][burn:]]
            result = self.mtrace.get_values(varname, burn=burn, combine=False)
            npt.assert_equal(result, expected)

    def test_len(self):
        """The MultiTrace length equals the number of draws."""
        assert len(self.mtrace) == self.draws

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtypes(self):
        """Stored values and sampler stats keep the dtypes they were recorded with."""
        for varname in self.test_point.keys():
            assert (
                self.expected[0][varname].dtype == self.mtrace.get_values(varname, chains=0).dtype
            )

        for statname in self.mtrace.stat_names:
            assert (
                self.stat_dtypes[statname]
                == self.mtrace.get_sampler_stats(statname, chains=0).dtype
            )

    def test_get_values_nocombine_thin_keyword(self):
        """``thin`` strides each chain when not combining."""
        thin = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][::thin], self.expected[1][varname][::thin]]
            result = self.mtrace.get_values(varname, thin=thin, combine=False)
            npt.assert_equal(result, expected)

    def test_get_point(self):
        """``point(idx)`` without a chain argument matches the values expected for chain 1."""
        idx = 2
        result = self.mtrace.point(idx)
        for varname in self.test_point.keys():
            expected = self.expected[1][varname][idx]
            npt.assert_equal(result[varname], expected)

    def test_get_slice(self):
        """Slicing the MultiTrace with ``[2:]`` slices every chain."""
        expected = []
        for chain in [0, 1]:
            expected.append(
                {varname: self.expected[chain][varname][2:] for varname in self.mtrace.varnames}
            )
        result = self.mtrace[2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(
                    result.get_values(varname, chains=[chain]), expected[chain][varname]
                )

    def test_get_slice_step(self):
        """Check the reported length of full and stepped slices."""
        result = self.mtrace[:]
        assert len(result) == self.draws

        result = self.mtrace[::2]
        assert len(result) == self.draws // 2

    def test_get_slice_neg_step(self):
        """Check the reported length of negative-step slices."""
        # Children opt out by defining ``skip_test_get_slice_neg_step``; the
        # test then passes trivially.
        if hasattr(self, "skip_test_get_slice_neg_step"):
            return

        result = self.mtrace[::-1]
        assert len(result) == self.draws

        result = self.mtrace[::-2]
        assert len(result) == self.draws // 2

    def test_get_neg_slice(self):
        """Slicing the MultiTrace with ``[-2:]`` slices every chain."""
        expected = []
        for chain in [0, 1]:
            expected.append(
                {varname: self.expected[chain][varname][-2:] for varname in self.mtrace.varnames}
            )
        result = self.mtrace[-2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(
                    result.get_values(varname, chains=[chain]), expected[chain][varname]
                )

    def test_get_values_one_chain(self):
        """``chains=[0]`` returns only chain 0."""
        for varname in self.test_point.keys():
            expected = self.expected[0][varname]
            result = self.mtrace.get_values(varname, chains=[0])
            npt.assert_equal(result, expected)

    def test_get_values_nocombine_chains_reversed(self):
        """The order of ``chains`` determines the order of the returned list."""
        for varname in self.test_point.keys():
            expected = [self.expected[1][varname], self.expected[0][varname]]
            result = self.mtrace.get_values(varname, chains=[1, 0], combine=False)
            npt.assert_equal(result, expected)

    def test_nchains(self):
        """The MultiTrace built in ``setup_class`` holds two chains."""
        # The comparison used to be a bare expression, so it could never fail.
        assert self.mtrace.nchains == 2

    def test_get_values_one_chain_int_arg(self):
        """An int ``chains`` argument behaves like a one-element list."""
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace.get_values(varname, chains=[0]),
                self.mtrace.get_values(varname, chains=0),
            )

    def test_get_values_combine(self):
        """``combine=True`` concatenates the chains."""
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True)
            npt.assert_equal(result, expected)

    def test_get_values_combine_burn_arg(self):
        """``burn`` is applied per chain before combining."""
        burn = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][burn:] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, burn=burn)
            npt.assert_equal(result, expected)

    def test_get_values_combine_thin_arg(self):
        """``thin`` is applied per chain before combining."""
        thin = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][::thin] for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, thin=thin)
            npt.assert_equal(result, expected)

    def test_getitem_equivalence(self):
        """``mtrace[name, slice]`` matches ``get_values`` with the equivalent burn/thin."""
        mtrace = self.mtrace
        for varname in self.test_point.keys():
            npt.assert_equal(mtrace[varname], mtrace.get_values(varname, combine=True))
            npt.assert_equal(mtrace[varname, 2:], mtrace.get_values(varname, burn=2, combine=True))
            npt.assert_equal(
                mtrace[varname, 2::2], mtrace.get_values(varname, burn=2, thin=2, combine=True)
            )

    def test_selection_method_equivalence(self):
        """``get_values``, item access and attribute access agree."""
        varname = self.mtrace.varnames[0]
        mtrace = self.mtrace
        npt.assert_equal(mtrace.get_values(varname), mtrace[varname])
        npt.assert_equal(mtrace[varname], mtrace.__getattr__(varname))


class DumpLoadTestCase(ModelBackendSampledTestCase):
    """Test equality of a dumped and loaded trace with original.

    Children must define
        - backend
        - load_func
            Function to load dumped backend
        - name
        - shape
    """

    @classmethod
    def setup_class(cls):
        """Sample via the parent fixture, then load the dumped trace back."""
        super().setup_class()
        try:
            with cls.model:
                cls.dumped = cls.load_func(cls.name)
        except BaseException:
            # Same reach as the former bare ``except``: clean up even on
            # KeyboardInterrupt, then propagate the original error.
            if cls.name is not None:
                remove_file_or_directory(cls.name)
            raise

    @classmethod
    def teardown_class(cls):
        """Remove whatever the backend wrote under ``name``, if a name was given."""
        # Guarded like the other teardowns so a None name cannot reach
        # ``os.remove``.
        if cls.name is not None:
            remove_file_or_directory(cls.name)

    def test_nchains(self):
        """Original and reloaded traces have the same number of chains."""
        assert self.mtrace.nchains == self.dumped.nchains

    def test_varnames(self):
        """Original and reloaded traces expose the same variable names."""
        trace_names = sorted(self.mtrace.varnames)
        dumped_names = sorted(self.dumped.varnames)
        assert trace_names == dumped_names

    def test_values(self):
        """Original and reloaded traces hold equal values for every chain and variable."""
        trace = self.mtrace
        dumped = self.dumped
        for chain in trace.chains:
            for varname in self.chain_vars:
                data = trace.get_values(varname, chains=[chain])
                dumped_data = dumped.get_values(varname, chains=[chain])
                npt.assert_equal(data, dumped_data)


class BackendEqualityTestCase(ModelBackendSampledTestCase):
    """Test equality of attributes from two backends.

    Children must define
        - backend0
        - backend1
        - name0
        - name1
        - shape
    """

    @classmethod
    def setup_class(cls):
        """Run the parent fixture once per backend, keeping both MultiTraces."""
        cls.backend = cls.backend0
        cls.name = cls.name0
        super().setup_class()
        cls.mtrace0 = cls.mtrace

        cls.backend = cls.backend1
        cls.name = cls.name1
        super().setup_class()
        cls.mtrace1 = cls.mtrace

    @classmethod
    def teardown_class(cls):
        """Remove the output of both backends, skipping names that are None."""
        for name in [cls.name0, cls.name1]:
            if name is not None:
                remove_file_or_directory(name)

    def test_chain_length(self):
        """Both traces have the same number of chains and the same length."""
        assert self.mtrace0.nchains == self.mtrace1.nchains
        assert len(self.mtrace0) == len(self.mtrace1)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtype(self):
        """Both traces return the same dtype for every variable."""
        for varname in self.test_point.keys():
            assert (
                self.mtrace0.get_values(varname, chains=0).dtype
                == self.mtrace1.get_values(varname, chains=0).dtype
            )

    def test_number_of_draws(self):
        """The first chain of each trace holds ``draws`` samples."""
        for varname in self.test_point.keys():
            values0 = self.mtrace0.get_values(varname, combine=False, squeeze=False)
            values1 = self.mtrace1.get_values(varname, combine=False, squeeze=False)
            assert values0[0].shape[0] == self.draws
            assert values1[0].shape[0] == self.draws

    def test_get_item(self):
        """Item access by variable name agrees across backends."""
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace0[varname], self.mtrace1[varname])

    def test_get_values(self):
        """``get_values`` agrees across backends, combined or not."""
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf),
                    self.mtrace1.get_values(varname, combine=cf),
                )

    def test_get_values_no_squeeze(self):
        """``get_values`` agrees across backends with ``squeeze=False``."""
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace0.get_values(varname, combine=False, squeeze=False),
                self.mtrace1.get_values(varname, combine=False, squeeze=False),
            )

    def test_get_values_combine_and_no_squeeze(self):
        """``get_values`` agrees across backends with ``combine=True, squeeze=False``."""
        for varname in self.test_point.keys():
            npt.assert_equal(
                self.mtrace0.get_values(varname, combine=True, squeeze=False),
                self.mtrace1.get_values(varname, combine=True, squeeze=False),
            )

    def test_get_values_with_burn(self):
        """``burn`` yields the same values from both backends."""
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=3),
                    self.mtrace1.get_values(varname, combine=cf, burn=3),
                )
                # Burn to one value.
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=self.draws - 1),
                    self.mtrace1.get_values(varname, combine=cf, burn=self.draws - 1),
                )

    def test_get_values_with_thin(self):
        """``thin`` yields the same values from both backends."""
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, thin=2),
                    self.mtrace1.get_values(varname, combine=cf, thin=2),
                )

    def test_get_values_with_burn_and_thin(self):
        """``burn`` together with ``thin`` yields the same values from both backends."""
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, combine=cf, burn=2, thin=2),
                    self.mtrace1.get_values(varname, combine=cf, burn=2, thin=2),
                )

    def test_get_values_with_chains_arg(self):
        """Selecting a single chain yields the same values from both backends."""
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(
                    self.mtrace0.get_values(varname, chains=[0], combine=cf),
                    self.mtrace1.get_values(varname, chains=[0], combine=cf),
                )

    def test_get_point(self):
        """Integer indexing yields the same point from both backends."""
        npoint, spoint = self.mtrace0[4], self.mtrace1[4]
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])

    def test_point_with_chain_arg(self):
        """``point(idx, chain=0)`` yields the same point from both backends."""
        npoint = self.mtrace0.point(4, chain=0)
        spoint = self.mtrace1.point(4, chain=0)
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])


def remove_file_or_directory(name):
    """Delete ``name`` as a file, falling back to a recursive directory removal.

    ``os.remove`` is tried first. If it raises ``OSError``, ``shutil.rmtree``
    is called with ``ignore_errors=True``, so failures in that fallback are
    silently ignored.
    """
    try:
        os.remove(name)
    except OSError:
        shutil.rmtree(name, ignore_errors=True)