import warnings
from datetime import datetime as dt
import numpy as np
import copy
import multiprocessing as mp
import pandas as pd
import os

from .sqlite import head_to_sql, start_sql
from .plotting import plot_trace
from collections import OrderedDict

import six
if not six.PY3:
    range = xrange


__all__ = ['Sampler_Mixin', 'Hashmap', 'Trace']

######################
# SAMPLER MECHANISMS #
######################


class Sampler_Mixin(object):
    """
    A mixin class designed to facilitate code reuse. This should be the
    parent class of anything that uses the sampling framework in this
    package.
    """
    def __init__(self):
        super(Sampler_Mixin, self).__init__()

    def sample(self, n_samples, n_jobs=1):
        """
        Sample from the joint posterior distribution defined by all of the
        parameters in the Gibbs sampler.

        Parameters
        ----------
        n_samples : int
            number of samples from the joint posterior density to take
        n_jobs : int
            number of parallel chains to run.

        Returns
        -------
        None. All sampled values are updated in place.
        """
        try:
            from tqdm import tqdm
        except ImportError:
            from .utils import thru_op
            tqdm = thru_op
            msg = '`tqdm` is not available. '
            msg += 'Using `spvcm.utils.thru_op` in place of `tqdm`.'
            warnings.warn(msg, stacklevel=2)

        if n_jobs > 1:
            self._parallel_sample(n_samples, n_jobs)
            return
        elif isinstance(self.state, list):
            self._parallel_sample(n_samples, n_jobs=len(self.state))
            return
        _start = dt.now()
        try:
            for i in tqdm(range(n_samples)):
                if (self._verbose > 1) and ((n_samples - i) % 100 == 0):
                    print('{} Draws to go'.format(n_samples - i))
                self.draw()
        except KeyboardInterrupt:
            warnings.warn('Sampling interrupted, drew {} samples'.format(self.cycles))
        finally:
            _stop = dt.now()
            if not hasattr(self, 'total_sample_time'):
                self.total_sample_time = _stop - _start
            else:
                self.total_sample_time += _stop - _start

    def draw(self):
        """
        Take exactly one sample from the joint posterior distribution.
        """
        if self.cycles == 0:
            self._finalize()
        self._iteration()
        self.cycles += 1
        for param in self.traced_params:
            self.trace.chains[0][param].append(self.state[param])
        if self.database is not None:
            head_to_sql(self, self._cur, self._cxn)
            for param in self.traced_params:
                self.trace.chains[0][param] = [self.trace[param, -1]]
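    # A minimal usage sketch for `sample`. `MyModel` is a hypothetical
    # subclass of this mixin, not defined in this module:
    #
    # >>> model = MyModel(y, X, n_samples=0)
    # >>> model.sample(1000)            # one chain, drawn in-process
    # >>> model.sample(1000, n_jobs=4)  # four parallel chains
    # >>> model.total_sample_time       # accumulates across calls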
    def _parallel_sample(self, n_samples, n_jobs):
        """
        Run n_jobs parallel samplers of a given model.
        Not intended to be called directly; it should be called via
        model.sample.
        """
        models = [copy.deepcopy(self) for _ in range(n_jobs)]
        for i, model in enumerate(models):
            if isinstance(model.state, list):
                models[i].state = copy.deepcopy(self.state[i])
            if hasattr(model, 'configs'):
                if isinstance(model.configs, list):
                    models[i].configs = copy.deepcopy(self.configs[i])
            if self.database is not None:
                models[i].database = self.database + str(i)
            models[i].trace = Trace(**{k: [] for k in model.trace.varnames})
            if self.cycles == 0:
                models[i]._fuzz_starting_values()
        n_samples = [n_samples] * n_jobs
        _start = dt.now()
        seed = np.random.randint(0, 10000, size=n_jobs).tolist()
        P = mp.Pool(n_jobs)
        results = P.map(_reflexive_sample, zip(models, n_samples, seed))
        P.close()
        _stop = dt.now()
        if self.cycles > 0:
            new_traces = []
            for i, model in enumerate(results):
                # model.trace.chains is always single-chain here, since
                # everything has been broken into single chains
                new_traces.append(Hashmap(**{k: param + model.trace.chains[0][k]
                                             for k, param
                                             in self.trace.chains[i].items()}))
            new_trace = Trace(*new_traces)
        else:
            new_trace = Trace(*[model.trace.chains[0] for model in results])
        self.trace = new_trace
        self.state = [model.state for model in results]
        self.cycles += n_samples[0]
        self.configs = [model.configs for model in results]
        if hasattr(self, 'total_sample_time'):
            self.total_sample_time += _stop - _start
        else:
            self.total_sample_time = _stop - _start

    def _fuzz_starting_values(self, state=None):
        """
        Overdisperse the starting values used in the package, so that
        parallel chains begin at distinct points in the parameter space.
        """
        st = self.state
        if hasattr(st, 'Betas'):
            st.Betas += np.random.normal(0, 5, size=st.Betas.shape)
        if hasattr(st, 'Alphas'):
            st.Alphas += np.random.normal(0, 5, size=st.Alphas.shape)
        if hasattr(st, 'Sigma2'):
            st.Sigma2 += np.random.uniform(0, 5)
        if hasattr(st, 'Tau2'):
            st.Tau2 += np.random.uniform(0, 5)
        if hasattr(st, 'Lambda'):
            st.Lambda += np.random.uniform(-.25, .25)
        if hasattr(st, 'Rho'):
            st.Rho += np.random.uniform(-.25, .25)

    def _finalize(self, **args):
        """
        Abstract method to ensure inheritors define a _finalize method. This
        method should compute all derived quantities used in the _iteration()
        function that would change if the user changed priors, starting
        values, or other information. This ensures that, if the user
        initializes the sampler with n_samples=0 and then changes the state,
        the derived quantities used in sampling remain correct.
        """
        raise NotImplementedError

    def _setup_priors(self, **args):
        """
        Abstract method to ensure inheritors define a _setup_priors method.
        This method should assign into the state the correct priors for all
        parameters in the model.
        """
        raise NotImplementedError

    def _setup_truncation(self, **args):
        """
        Abstract method to ensure inheritors define a _setup_truncation
        method. This method should truncate the parameter space to given
        arbitrary bounds.
        """
        raise NotImplementedError

    def _setup_starting_values(self, **args):
        """
        Abstract method to ensure that inheritors define a
        _setup_starting_values method. This method should assign the correct
        starting value for each of the parameters into model.state.
        """
        raise NotImplementedError

    @property
    def database(self):
        """
        The database used by the model.
        """
        return getattr(self, '_db', None)

    @database.setter
    def database(self, filename):
        self._cxn, self._cur = start_sql(self, tracename=filename)
        self._db = filename
        from .sqlite import trace_from_sql

        def load_sqlite():
            return trace_from_sql(filename)

        self.trace.load_sqlite = load_sqlite
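# Assigning to `model.database` (the setter above) opens a SQLite connection
# and attaches a loader to the trace. A hedged sketch, assuming a fitted
# model instance named `model`:
#
# >>> model.database = 'my_trace'         # subsequent draws stream to SQL
# >>> stored = model.trace.load_sqlite()  # re-read the stored trace later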
def _reflexive_sample(tup):
    """
    A helper function to sample a number of models in parallel.

    The tuple must contain:

    model : model object
    n_samples : int, the number of samples to draw
    seed : seed to use for the sampler
    """
    model, n_samples, seed = tup
    np.random.seed(seed)
    model.sample(n_samples=n_samples)
    return model


def _noop(*args, **kwargs):
    pass

#######################
# MAPS AND CONTAINERS #
#######################

class Hashmap(dict):
    """
    A dictionary with dot access on attributes
    """
    def __init__(self, **kw):
        super(Hashmap, self).__init__(**kw)
        if kw != dict():
            for k in kw:
                self[k] = kw[k]

    def __getattr__(self, attr):
        try:
            r = self[attr]
        except KeyError:
            try:
                r = getattr(super(Hashmap, self), attr)
            except AttributeError:
                raise AttributeError("'{}' object has no attribute '{}'"
                                     .format(self.__class__, attr))
        return r

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Hashmap, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Hashmap, self).__delitem__(key)
        del self.__dict__[key]
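# `Hashmap` keeps item access and attribute access in sync, so the two
# notations are interchangeable:
#
# >>> h = Hashmap(a=1)
# >>> h.b = 2
# >>> (h.a, h['a'], h.b, h['b'])
# (1, 1, 2, 2)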
class Trace(object):
    """
    Object to contain results from sampling.

    Arguments
    ---------
    chains : a chain or comma-separated sequence of chains
        a chain is a dict-like collection, where keys are the parameter names
        and the values are the sampled values of the chain.
    kwargs : a dictionary splatted into keyword arguments
        the name of each argument is taken to be a parameter name, and the
        value is taken to be a chain of that parameter.

    Examples
    --------
    >>> Trace(a=[1,2,3], b=[4,2,5], c=[1,9,23]) # Trace with one chain
    >>> Trace({'a':[1,2,3], 'b':[4,2,5], 'c':[1,9,23]},
    ...       {'a':[2,5,1], 'b':[2,9,1], 'c':[9,21,1]}) # Trace with two chains
    """
    def __init__(self, *chains, **kwargs):
        if not chains and kwargs:
            self.chains = _maybe_hashmap(kwargs)
        if chains:
            self.chains = _maybe_hashmap(*chains)
            if kwargs:
                self.chains.extend(_maybe_hashmap(kwargs))
        self._validate_schema()

    @property
    def varnames(self, chain=None):
        """
        Names of variables contained in the trace.
        """
        try:
            return self._varnames
        except AttributeError:
            try:
                self._validate_schema()
            except KeyError:
                if chain is None:
                    raise Exception('Variable names are heterogeneous in'
                                    ' chains and no default index provided.')
                else:
                    warnings.warn('Variable names are heterogeneous in chains!',
                                  stacklevel=2)
                    return list(self.chains[chain].keys())
        self._varnames = list(self.chains[0].keys())
        return self._varnames

    def drop(self, varnames, inplace=True):
        """
        Drop a variable from the trace.

        Arguments
        ---------
        varnames : list of strings
            names of parameters to drop from the trace.
        inplace : bool
            whether to remove the parameters in place or to return a copy of
            the trace with the parameters removed.
        """
        if isinstance(varnames, str):
            varnames = (varnames,)
        if not inplace:
            new = copy.deepcopy(self)
            new.drop(varnames, inplace=True)
            new._varnames = list(new.chains[0].keys())
            return new
        for i, chain in enumerate(self.chains):
            for varname in varnames:
                del self.chains[i][varname]
        self._varnames = list(self.chains[0].keys())

    def _validate_schema(self, chains=None):
        """
        Validate the trace to ensure that its chains are mutually consistent.
        """
        if chains is None:
            chains = self.chains
        tracked_in_each = [set(chain.keys()) for chain in chains]
        same_schema = [names == tracked_in_each[0] for names in tracked_in_each]
        try:
            assert all(same_schema)
        except AssertionError:
            bad_chains = [i for i in range(len(chains)) if not same_schema[i]]
            raise KeyError('The parameters tracked in each chain are not the'
                           ' same!\nChains {} do not have the same parameters'
                           ' as the first chain!'.format(bad_chains))

    def add_chain(self, chains, validate=True):
        """
        Add chains to a trace object.

        Parameters
        ----------
        chains : Hashmap or list of Hashmaps
            chains to merge into the trace
        validate : bool
            whether or not to validate the schema and reject the chains if
            they do not match the current trace.
        """
        if not isinstance(chains, (list, tuple)):
            chains = (chains,)
        new_chains = list(self.chains)
        for chain in chains:
            if isinstance(chain, Hashmap):
                new_chains.append(chain)
            elif isinstance(chain, Trace):
                new_chains.extend(chain.chains)
            else:
                new_chains.extend(_maybe_hashmap(chain))
        if validate:
            self._validate_schema(chains=new_chains)
        self.chains = new_chains

    def map(self, func, **func_args):
        """
        Map a function over all parameters in a chain.
        Multivariate parameters are reduced to sequences of univariate
        parameters.

        Usage
        -----
        Intended for when full-trace statistics are required. Most often, the
        trace should be sliced directly. For example, to get the mean value
        of a parameter over the last 1000 iterations with a thinning of 2:

        trace[0, 'Betas', -1000::2].mean(axis=0)

        But, to average the parameter over all recorded chains:

        trace['Betas', -1000::2].mean(axis=0).mean(axis=0)

        since the first reduction provides an array where rows are
        iterations and columns are parameters.

        trace.map(np.mean) yields the mean of each parameter within each
        chain, and is provided to make within-chain reductions easier; see
        the example after this method.

        Arguments
        ---------
        func : callable
            a function that returns a result when provided a flat vector.
        varnames : string or list of strings
            a keyword-only argument governing which parameters to map over.
        func_args : dictionary/keyword arguments
            arguments that need to be passed to the reduction.
        """
        varnames = func_args.pop('varnames', self.varnames)
        if isinstance(varnames, str):
            varnames = (varnames,)
        all_stats = []
        for i, chain in enumerate(self.chains):
            these_stats = dict()
            for var in varnames:
                data = np.squeeze(self[i, var])
                if data.ndim > 1:
                    n, p = data.shape[0:2]
                    rest = data.shape[2:]
                    if len(rest) == 0:
                        data = data.T
                    elif len(rest) == 1:
                        data = data.reshape(n, p * rest[0]).T
                    else:
                        raise Exception('Parameter "{}" shape not understood.'
                                        ' Please extract, shape it, and pass'
                                        ' it as its own chain.'.format(var))
                else:
                    data = data.reshape(1, -1)
                stats = [func(datum, **func_args) for datum in data]
                if len(stats) == 1:
                    stats = stats[0]
                these_stats.update({var: stats})
            all_stats.append(these_stats)
        return all_stats
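    # For instance, a within-chain reduction over a traced parameter
    # ('Betas' here is illustrative, not guaranteed to exist):
    #
    # >>> trace.map(np.mean, varnames='Betas')
    # [{'Betas': [...]}, ...]  # one dict per chain, one stat per element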
    @property
    def n_chains(self):
        return len(self.chains)

    @property
    def n_iters(self):
        """
        Number of raw iterations stored in the trace.
        """
        lengths = [len(chain[self.varnames[0]]) for chain in self.chains]
        if len(lengths) == 1:
            return lengths[0]
        else:
            return lengths

    def plot(self, burn=0, thin=None, varnames=None,
             kde_kwargs={}, trace_kwargs={}, figure_kwargs={}):
        """
        Make a trace plot paired with a distributional plot.

        Arguments
        ---------
        burn : int
            the number of iterations to discard from the front of the trace
        thin : int
            the interval at which to keep iterations, so that only every
            `thin`-th iteration is retained
        varnames : str or list
            name or list of names to plot.
        kde_kwargs : dictionary
            dictionary of aesthetic arguments for the kde plot
        trace_kwargs : dictionary
            dictionary of aesthetic arguments for the traceplot

        Returns
        -------
        figure, axis tuple, where axis is (len(varnames), 2)
        """
        f, ax = plot_trace(model=None, trace=self, burn=burn,
                           thin=thin, varnames=varnames,
                           kde_kwargs=kde_kwargs, trace_kwargs=trace_kwargs,
                           figure_kwargs=figure_kwargs)
        return f, ax

    def summarize(self, level=0):
        """
        Compute a summary of the trace. See also: diagnostics.summary

        Arguments
        ---------
        level : int
            0 for a summary by chain, or 1 if the summary should be computed
            by pooling over chains.
        """
        from .diagnostics import summarize
        return summarize(trace=self, level=level)
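    # Indexing examples for `__getitem__` below, assuming a traced parameter
    # named 'Betas':
    #
    # >>> trace['Betas']             # all iterations of Betas, every chain
    # >>> trace['Betas', -1000::2]   # last 1000 iterations, thinned by 2
    # >>> trace[0, 'Betas', -1000:]  # the same, restricted to the first chain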
    def __getitem__(self, key):
        """
        Getting an item from a trace can be done using at most three indices, where:

        1 index
        -------
        str/list of str: names of variates in all chains to grab. Returns list of Hashmaps
        slice/int: iterations to grab from all chains. Returns list of Hashmaps, sliced to the specification

        2 index
        -------
        (str/list of str, slice/int): first term is name(s) of variates in all chains to grab,
                                      second term specifies the slice of each chain.
                    returns: list of hashmaps with keys of first term and entries sliced by the second term.
        (slice/int, str/list of str): first term specifies which chains to retrieve,
                                      second term is name(s) of variates in those chains.
                    returns: list of hashmaps containing all iterations
        (slice/int, slice/int): first term specifies which chains to retrieve,
                                second term specifies the slice of each chain.
                    returns: list of hashmaps with entries sliced by the second term

        3 index
        -------
        (slice/int, str/list of str, slice/int): first term specifies which chains to retrieve,
                                                 second term is the name(s) of variates in those chains,
                                                 third term is the iteration slicing.
                    returns: list of hashmaps keyed on second term, with entries sliced by the third term
        """
        if isinstance(key, str):  # user wants only one name from the trace
            if self.n_chains > 1:
                result = [chain[key] for chain in self.chains]
            else:
                result = self.chains[0][key]
        elif isinstance(key, (slice, int)):  # user wants all draws past a certain index
            if self.n_chains > 1:
                return [Hashmap(**{k: v[key] for k, v in chain.items()})
                        for chain in self.chains]
            else:
                return Hashmap(**{k: v[key] for k, v in self.chains[0].items()})
        elif isinstance(key, list) and all([isinstance(val, str) for val in key]):
            # list of variates over all iterations and all chains
            if self.n_chains > 1:
                return [Hashmap(**{k: chain[k] for k in key}) for chain in self.chains]
            else:
                return Hashmap(**{k: self.chains[0][k] for k in key})
        elif isinstance(key, tuple):  # complex slicing
            if len(key) == 1:
                return self[key[0]]  # ignore empty blocks
            if len(key) == 2:
                head, tail = key
                if isinstance(head, str):  # all chains, one var, some iters
                    if self.n_chains > 1:
                        result = [_ifilter(tail, chain[head]) for chain in self.chains]
                    else:
                        result = _ifilter(tail, self.chains[0][head])
                elif isinstance(head, list) and all([isinstance(v, str) for v in head]):
                    # all chains, some vars, some iters
                    if self.n_chains > 1:
                        return [Hashmap(**{name: _ifilter(tail, chain[name])
                                           for name in head})
                                for chain in self.chains]
                    else:
                        chain = self.chains[0]
                        return Hashmap(**{name: _ifilter(tail, chain[name])
                                          for name in head})
                elif isinstance(tail, str):
                    target_chains = _ifilter(head, self.chains)
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    if len(target_chains) > 1:
                        result = [chain[tail] for chain in target_chains]
                    elif len(target_chains) == 1:
                        result = target_chains[0][tail]
                    else:
                        raise IndexError('The supplied chain index {} does not'
                                         ' match any chains in trace.chains'.format(head))
                elif isinstance(tail, list) and all([isinstance(v, str) for v in tail]):
                    target_chains = _ifilter(head, self.chains)
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    if len(target_chains) > 1:
                        return [Hashmap(**{k: chain[k] for k in tail})
                                for chain in target_chains]
                    elif len(target_chains) == 1:
                        return Hashmap(**{k: target_chains[0][k] for k in tail})
                    else:
                        raise IndexError('The supplied chain index {} does not'
                                         ' match any chains in trace.chains'.format(head))
                else:
                    target_chains = _ifilter(head, self.chains)
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    out = [Hashmap(**{k: _ifilter(tail, val) for k, val in chain.items()})
                           for chain in target_chains]
                    if len(out) == 1:
                        return out[0]
                    else:
                        return out
            elif len(key) == 3:
                chidx, varnames, iters = key
                if isinstance(chidx, int):
                    if np.abs(chidx) > self.n_chains:
                        raise IndexError('The supplied chain index {} does not'
                                         ' match any chains in trace.chains'.format(chidx))
                chains = _ifilter(chidx, self.chains)
                if isinstance(chains, Hashmap):
                    chains = [chains]
                nchains = len(chains)
                if isinstance(varnames, str):
                    varnames = [varnames]
                if varnames == slice(None, None, None):
                    varnames = self.varnames
                if len(varnames) == 1:
                    if nchains > 1:
                        result = [_ifilter(iters, chain[varnames[0]]) for chain in chains]
                    else:
                        result = _ifilter(iters, chains[0][varnames[0]])
                else:
                    if nchains > 1:
                        return [Hashmap(**{varname: _ifilter(iters, chain[varname])
                                           for varname in varnames})
                                for chain in chains]
                    else:
                        return Hashmap(**{varname: _ifilter(iters, chains[0][varname])
                                          for varname in varnames})
        else:
            raise IndexError('index not understood')

        result = np.asarray(result)
        if result.shape == ():
            result = result.tolist()
        elif result.shape in [(1, 1), (1,)]:
            result = result[0]
        return result

    ##############
    # Comparison #
    ##############

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        else:
            a = [ch1 == ch2 for ch1, ch2 in zip(other.chains, self.chains)]
            return all(a)

    def _allclose(self, other, **allclose_kw):
        try:
            self._assert_allclose(other, **allclose_kw)
        except AssertionError:
            return False
        return True

    def _assert_allclose(self, other, **allclose_kw):
        ignore_shape = allclose_kw.pop('ignore_shape', False)
        squeeze = allclose_kw.pop('squeeze', True)
        try:
            assert set(self.varnames) == set(other.varnames)
        except AssertionError:
            raise AssertionError('Variable names are different!\n'
                                 'self: {}\nother: {}'.format(
                                     self.varnames, other.varnames))
        assert isinstance(other, type(self))
        for ch1, ch2 in zip(self.chains, other.chains):
            for k, v in ch1.items():
                allclose_kw['err_msg'] = 'Failed on {}'.format(k)
                if ignore_shape:
                    A = [np.asarray(item).flatten() for item in v]
                    B = [np.asarray(item).flatten() for item in ch2[k]]
                elif squeeze:
                    A = [np.squeeze(item) for item in v]
                    B = [np.squeeze(item) for item in ch2[k]]
                else:
                    A = v
                    B = ch2[k]
                np.testing.assert_allclose(A, B, **allclose_kw)
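    # `__eq__` is exact and element-wise; `_allclose` forwards its keywords
    # to `np.testing.assert_allclose` for tolerance-based comparison, e.g.:
    #
    # >>> trace_a._allclose(trace_b, atol=1e-8)
    # True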
    ###################
    # IO and Exchange #
    ###################

    def to_df(self):
        """
        Convert the trace object to a pandas DataFrame.

        Returns
        -------
        a dataframe where each column is a parameter. Multivariate
        parameters are vectorized and spread over multiple columns.
        """
        dfs = []
        outnames = self.varnames
        to_split = [name for name in outnames
                    if np.asarray(self[0, name, 0]).size > 1]
        for chain in self.chains:
            out = OrderedDict(list(chain.items()))
            for split in to_split:
                records = np.asarray(copy.deepcopy(chain[split]))
                if len(records.shape) == 1:
                    records = records.reshape(-1, 1)
                n, k = records.shape[0:2]
                rest = records.shape[2:]
                if len(rest) == 0:
                    pass
                elif len(rest) == 1:
                    records = records.reshape(n, int(k * rest[0]))
                else:
                    raise Exception("Parameter '{}' has too many dimensions"
                                    " to be flattened.".format(split))
                records = OrderedDict([(split + '_' + str(i), record.T.tolist())
                                       for i, record in enumerate(records.T)])
                out.update(records)
                del out[split]
            df = pd.DataFrame().from_dict(out)
            dfs.append(df)
        if len(dfs) == 1:
            return dfs[0]
        else:
            return dfs

    def to_csv(self, filename, **pandas_kwargs):
        """
        Write the trace out to file, going through Trace.to_df().

        If there are multiple chains in this trace, this will write each of
        them out to 'filename_number.csv', where `number` is the number of
        the chain.

        Arguments
        ---------
        filename : string
            name of file to write the trace to.
        pandas_kwargs : keyword arguments
            arguments to pass to the pandas to_csv function.
        """
        if 'index' not in pandas_kwargs:
            pandas_kwargs['index'] = False
        dfs = self.to_df()
        if isinstance(dfs, list):
            name, ext = os.path.splitext(filename)
            for i, df in enumerate(dfs):
                df.to_csv(name + '_' + str(i) + ext, **pandas_kwargs)
        else:
            dfs.to_csv(filename, **pandas_kwargs)
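    # A hedged round-trip sketch: a multi-chain trace written to 'trace.csv'
    # fans out to 'trace_0.csv', 'trace_1.csv', ..., and can be gathered
    # back up with the multi flag:
    #
    # >>> trace.to_csv('trace.csv')
    # >>> same = Trace.from_csv('trace', multi=True)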
    @classmethod
    def from_df(cls, dfs, varnames=None, combine_suffix='_'):
        """
        Convert a dataframe into a trace object.

        Arguments
        ---------
        dfs : dataframe or list of dataframes
            pandas dataframes to convert into a trace. Each dataframe is
            assumed to be a single chain.
        varnames : string or list of strings
            names to use instead of the names in the dataframe. If None, the
            column names are split using `combine_suffix`, and the unique
            stems before the suffix are used as parameter names.
        """
        if not isinstance(dfs, (tuple, list)):
            dfs = (dfs,)
        if len(dfs) > 1:
            traces = [cls.from_df(df, varnames=varnames,
                                  combine_suffix=combine_suffix) for df in dfs]
            return cls(*[trace.chains[0] for trace in traces])
        else:
            df = dfs[0]
            if varnames is None:
                varnames = df.columns
            unique_stems = set()
            for col in varnames:
                suffix_split = col.split(combine_suffix)
                if suffix_split[0] == col:
                    unique_stems.update([col])
                else:
                    unique_stems.update([combine_suffix.join(suffix_split[:-1])])
            out = dict()
            for stem in unique_stems:
                cols = []
                for var in df.columns:
                    if var == stem:
                        cols.append(var)
                    elif combine_suffix.join(var.split(combine_suffix)[:-1]) == stem:
                        cols.append(var)
                if len(cols) == 1:
                    targets = df[cols].values.flatten().tolist()
                else:
                    # ensure the trailing ordinate sorts the columns, not
                    # string order: '1', '11', '2' would corrupt the trace
                    order = [int(st.split(combine_suffix)[-1]) for st in cols]
                    cols = np.asarray(cols)[np.argsort(order)]
                    targets = [vec for vec in df[cols].values]
                out.update({stem: targets})
            return cls(**out)

    @classmethod
    def from_pymc3(cls, pymc3trace):
        """
        Convert a PyMC3 trace to a spvcm trace.
        """
        try:
            from pymc3 import trace_to_dataframe
        except ImportError:
            raise ImportError("The 'trace_to_dataframe' function in "
                              "pymc3 is used for this feature. Pymc3 "
                              "failed to import.")
        return cls.from_df(trace_to_dataframe(pymc3trace))

    @classmethod
    def from_csv(cls, filename=None, multi=False,
                 varnames=None, combine_suffix='_', **pandas_kwargs):
        """
        Read a CSV into a trace object, by way of `Trace.from_df()`.

        Arguments
        ---------
        filename : string
            name of the file to read.
        multi : bool
            flag denoting whether the trace being read is a multitrace or
            not. If so, the filename is understood to be the prefix of many
            files that end in `filename_#.csv`.
        varnames : string or list of strings
            custom names to use for the trace. If not provided,
            `combine_suffix` is used to identify the unique stems in the
            CSVs.
        pandas_kwargs : keyword arguments
            keyword arguments to pass to the pandas functions.
        """
        if multi:
            filepath = os.path.dirname(os.path.abspath(filename))
            filestem = os.path.basename(filename)
            targets = [f for f in os.listdir(filepath)
                       if f.startswith(filestem)]
            ordinates = [int(os.path.splitext(fname)[0].split(combine_suffix)[-1])
                         for fname in targets]
            # preserve the order of the trailing ordinates
            targets = np.asarray(targets)[np.argsort(ordinates)].tolist()
            traces = [cls.from_csv(filename=os.path.join(filepath, f),
                                   multi=False) for f in targets]
            if traces == []:
                raise IOError("No such file or directory: " +
                              filepath + filestem)
            return cls(*[trace.chains[0] for trace in traces])
        else:
            df = pd.read_csv(filename, **pandas_kwargs)
            return cls.from_df(df, varnames=varnames,
                               combine_suffix=combine_suffix)
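    # `from_df` reassembles vectorized parameters by their column stem. A
    # short sketch with an illustrative two-element 'Betas' parameter:
    #
    # >>> df = pd.DataFrame({'Betas_0': [0., 1.], 'Betas_1': [2., 3.]})
    # >>> Trace.from_df(df).varnames
    # ['Betas']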
####################
# HELPER FUNCTIONS #
####################

def _ifilter(filt, iterable):
    """
    Filter an iterable using a slice, an integer index, or a collection of
    indices.
    """
    try:
        return iterable[filt]
    except (TypeError, KeyError, IndexError):
        if isinstance(filt, (int, float)):
            filt = [filt]
        return [val for i, val in enumerate(iterable) if i in filt]


def _maybe_hashmap(*collections):
    """
    Coerce each collection into a Hashmap, leaving Hashmaps untouched.
    """
    out = []
    for collection in collections:
        if isinstance(collection, Hashmap):
            out.append(collection)
        else:
            out.append(Hashmap(**collection))
    return out


def _copy_hashmaps(*hashmaps):
    """
    Create deep copies of the hashmaps passed to the function.
    """
    return [Hashmap(**{k: copy.deepcopy(v) for k, v in hashmap.items()})
            for hashmap in hashmaps]
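# `_ifilter` behavior, for reference:
#
# >>> _ifilter(0, ['a', 'b', 'c'])
# 'a'
# >>> _ifilter(slice(0, 2), ['a', 'b', 'c'])
# ['a', 'b']
# >>> _ifilter([0, 2], ['a', 'b', 'c'])
# ['a', 'c']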