1import warnings
2from datetime import datetime as dt
3import numpy as np
4import copy
5import multiprocessing as mp
6import pandas as pd
7import os
8
9from .sqlite import head_to_sql, start_sql
10from .plotting import plot_trace
11from collections import OrderedDict
12
13import six
if not six.PY3:
    # Python 2 compatibility shim: rebind ``range`` to ``xrange`` so the
    # ``range(...)`` loops used throughout this module are lazy there too.
    range = xrange


# Names exported by ``from <this module> import *``.
__all__ = ['Sampler_Mixin', 'Hashmap', 'Trace']
19
20######################
21# SAMPLER MECHANISMS #
22######################
23
24
class Sampler_Mixin(object):
    """
    A Mixin class designed to facilitate code reuse. This should be the parent
    class of anything that uses the sampling framework in this package.

    The methods below read and write these attributes, which concrete
    samplers are expected to provide: ``state``, ``trace``,
    ``traced_params``, ``cycles``, ``_verbose`` and (optionally)
    ``configs``. They must also provide an ``_iteration()`` method that
    advances ``state`` by one sweep, plus concrete versions of the abstract
    ``_finalize`` / ``_setup_*`` hooks.
    """
    def __init__(self):
        super(Sampler_Mixin, self).__init__()

    def sample(self, n_samples, n_jobs=1):
        """
        Sample from the joint posterior distribution defined by all of the
        parameters in the gibbs sampler.

        Parameters
        ----------
        n_samples   :   int
                        number of samples from the joint posterior density to take
        n_jobs      :   int
                        number of parallel chains to run. If ``state`` is
                        already a list (one state per chain), sampling runs
                        in parallel with one job per state regardless of
                        this value.

        Returns
        -------
        Implicitly updates all values in place, returns None
        """
        try:
            from tqdm import tqdm
        except ImportError:
            from .utils import thru_op
            tqdm = thru_op
            msg = '`tqdm` is not available. '
            msg += 'Using `spvcm.utils.thru_op` in place of `tqdm`.'
            warnings.warn(msg, stacklevel=2)

        if n_jobs > 1:
            self._parallel_sample(n_samples, n_jobs)
            return
        elif isinstance(self.state, list):
            self._parallel_sample(n_samples, n_jobs=len(self.state))
            return
        _start = dt.now()
        try:
            for i in tqdm(range(n_samples)):
                # Progress message every 100 draws, counting the draws that
                # are still left to take in this call.
                remaining = n_samples - i
                if (self._verbose > 1) and (remaining % 100 == 0):
                    print('{} Draws to go'.format(remaining))
                self.draw()
        except KeyboardInterrupt:
            # `cycles` counts every draw this sampler has ever taken.
            warnings.warn('Sampling interrupted, drew {} samples'.format(self.cycles))
        finally:
            # Accumulate wall-clock time across repeated calls to sample().
            _stop = dt.now()
            if not hasattr(self, 'total_sample_time'):
                self.total_sample_time = _stop - _start
            else:
                self.total_sample_time += _stop - _start

    def draw(self):
        """
        Take exactly one sample from the joint posterior distribution.

        On the very first draw (``cycles == 0``), ``_finalize`` is called
        before iterating. After the iteration, the current value of every
        traced parameter is appended to the first chain of the trace. When a
        database is attached, the sampler is handed to ``head_to_sql`` and
        each in-memory chain is cut back to its latest value.
        """
        if self.cycles == 0:
            self._finalize()
        self._iteration()
        self.cycles += 1
        for param in self.traced_params:
            self.trace.chains[0][param].append(self.state[param])
        if self.database is not None:
            head_to_sql(self, self._cur, self._cxn)
            for param in self.traced_params:
                self.trace.chains[0][param] = [self.trace[param,-1]]

    def _parallel_sample(self, n_samples, n_jobs):
        """
        Run n_jobs parallel samples of a given model.
        Not intended to be called directly, and should be called by model.sample.

        Each job samples its own deep copy of this model in a worker process.
        Afterwards, this model holds one chain, one state and (if present)
        one config per job.
        """
        models = [copy.deepcopy(self) for _ in range(n_jobs)]
        for i, model in enumerate(models):
            # List-valued state/configs hold one entry per chain: give each
            # job its own entry.
            if isinstance(model.state, list):
                model.state = copy.deepcopy(self.state[i])
            if hasattr(model, 'configs'):
                if isinstance(model.configs, list):
                    model.configs = copy.deepcopy(self.configs[i])
            if self.database is not None:
                model.database = self.database + str(i)
            model.trace = Trace(**{k:[] for k in model.trace.varnames})
            if self.cycles == 0:
                model._fuzz_starting_values()
        n_samples = [n_samples] * n_jobs
        _start = dt.now()
        # One seed per job so workers do not draw identical streams.
        # NOTE(review): seeds are drawn from 0..9999 and may collide.
        seed = np.random.randint(0,10000, size=n_jobs).tolist()
        pool = mp.Pool(n_jobs)
        try:
            results = pool.map(_reflexive_sample, zip(models, n_samples, seed))
        finally:
            # Always release the worker processes, including when a worker
            # raises (this mirrors what Pool.__exit__ does on Python 3).
            pool.terminate()
            pool.join()
        _stop = dt.now()
        if self.cycles > 0:
            # Extend the chains already held by this model with the new draws.
            # NOTE(review): assumes self.trace already has >= n_jobs chains.
            new_traces = []
            for i, model in enumerate(results):
                # model.trace.chains is always single-chain, since we've broken everything into single chains
                new_traces.append(Hashmap(**{k:param + model.trace.chains[0][k]
                                             for k, param in self.trace.chains[i].items()}))
            new_trace = Trace(*new_traces)
        else:
            new_trace = Trace(*[model.trace.chains[0] for model in results])
        self.trace = new_trace
        self.state = [model.state for model in results]
        self.cycles += n_samples[0]
        # `configs` is optional (see the hasattr check above); only gather it
        # when the model actually has it.
        if hasattr(self, 'configs'):
            self.configs = [model.configs for model in results]
        if hasattr(self, 'total_sample_time'):
            self.total_sample_time += _stop - _start
        else:
            self.total_sample_time = _stop - _start

    def _fuzz_starting_values(self, state=None):
        """
        Function to overdisperse starting values used in the package.

        Perturbs ``self.state`` in place:
        - ``Betas`` and ``Alphas`` get normal noise with standard deviation 5;
        - ``Sigma2`` and ``Tau2`` get U(0, 5) added;
        - ``Lambda`` and ``Rho`` get U(-.25, .25) added.

        Only attributes present on the state are touched. The ``state``
        argument is unused.
        """
        st = self.state
        if hasattr(st, 'Betas'):
            st.Betas += np.random.normal(0,5, size=st.Betas.shape)
        if hasattr(st, 'Alphas'):
            st.Alphas += np.random.normal(0,5,size=st.Alphas.shape)
        if hasattr(st, 'Sigma2'):
            st.Sigma2 += np.random.uniform(0,5)
        if hasattr(st, 'Tau2'):
            st.Tau2 += np.random.uniform(0,5)
        if hasattr(st, 'Lambda'):
            st.Lambda += np.random.uniform(-.25,.25)
        if hasattr(st, 'Rho'):
            st.Rho += np.random.uniform(-.25,.25)

    def _finalize(self, **args):
        """
        Abstract function to ensure inheritors define a finalize method. This
        method should compute all derived quantities used in the _iteration()
        function that would change if the user changed priors, starting values,
        or other information. This ensures that, if the user initializes the
        sampler with n_samples=0 and then changes the state, the derived
        quantities used in sampling are correct.
        """
        raise NotImplementedError

    def _setup_priors(self, **args):
        """
        Abstract function to ensure inheritors define a _setup_priors method.
        This method should assign into the state all of the correct priors for
        all parameters in the model.
        """
        raise NotImplementedError

    def _setup_truncation(self, **args):
        """
        Abstract function to ensure inheritors define a _setup_truncation
        method. This method should truncate parameter space to given arbitrary
        bounds.
        """
        raise NotImplementedError

    def _setup_starting_values(self, **args):
        """
        Abstract function to ensure that inheritors define a
        _setup_starting_values method. This method should assign the correct
        values for each of the parameters into model.state.
        """
        raise NotImplementedError

    @property
    def database(self):
        """
        the database used for the model (None when no database is attached).
        """
        return getattr(self, '_db', None)

    @database.setter
    def database(self, filename):
        self._cxn, self._cur = start_sql(self, tracename=filename)
        self._db = filename
        from .sqlite import trace_from_sql
        def load_sqlite():
            return trace_from_sql(filename)
        # Expose a zero-argument loader on the trace that re-reads this database.
        self.trace.load_sqlite = load_sqlite
192
def _reflexive_sample(tup):
    """
    Worker entry point used by ``Sampler_Mixin._parallel_sample``.

    Unpacks ``tup`` as ``(model, n_samples, seed)``, re-seeds numpy's global
    RNG with ``seed``, runs ``model.sample(n_samples=n_samples)`` and returns
    the (now sampled) model so the parent process can collect it.
    """
    chain_model, draw_count, chain_seed = tup
    # Each worker gets its own seed, chosen by the parent process.
    np.random.seed(chain_seed)
    chain_model.sample(n_samples=draw_count)
    return chain_model
207
def _noop(*args, **kwargs):
    """
    Do nothing: accept and discard any positional or keyword arguments and
    return None.
    """
    return None
210
211#######################
212# MAPS AND CONTAINERS #
213#######################
214
class Hashmap(dict):
    """
    A dictionary with dot access on attributes.

    Every key is mirrored into the instance ``__dict__`` so that ``h.key`` and
    ``h['key']`` refer to the same value. All mutating ``dict`` methods are
    routed through ``__setitem__`` / ``__delitem__`` (or clear the mirror
    directly), so attribute access cannot drift out of sync with the mapping.
    """
    def __init__(self, **kw):
        super(Hashmap, self).__init__()
        # dict.__init__ would bypass __setitem__; go through it so the
        # attribute mirror is populated as well.
        for k in kw:
            self[k] = kw[k]

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup (instance __dict__, then
        # the class) fails.
        try:
            return self[attr]
        except KeyError:
            pass
        try:
            return getattr(super(Hashmap, self), attr)
        except AttributeError:
            raise AttributeError("'{}' object has no attribute '{}'"
                                 .format(type(self).__name__, attr))

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Hashmap, self).__setitem__(key,value)
        self.__dict__.update({key:value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Hashmap, self).__delitem__(key)
        del self.__dict__[key]

    def update(self, *args, **kwargs):
        """Same as ``dict.update``, but keeps attribute access in sync."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Same as ``dict.setdefault``, but keeps attribute access in sync."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        """Same as ``dict.pop``, but also drops the mirrored attribute."""
        value = super(Hashmap, self).pop(key, *default)
        self.__dict__.pop(key, None)
        return value

    def popitem(self):
        """Same as ``dict.popitem``, but also drops the mirrored attribute."""
        key, value = super(Hashmap, self).popitem()
        self.__dict__.pop(key, None)
        return key, value

    def clear(self):
        """Same as ``dict.clear``, but also clears the attribute mirror."""
        super(Hashmap, self).clear()
        self.__dict__.clear()
249
250class Trace(object):
251    """
252    Object to contain results from sampling.
253
254    Arguments
255    ---------
256    chains  :   a chain or comma-separated sequence of chains
257                a chain is a dict-like collection, where keys are the parameter name and the values are the values of the chain.
258    kwargs  :   a dictionary splatted into keyword arguments
259                the name of the argument is taken to the be the parameter name, and the value is taken to be a chain of that parameter.
260
261    Examples
262    ---------
263    >>> Trace(a=[1,2,3], b=[4,2,5], c=[1,9,23]) #Trace with one chain
264    >>> Trace([{'a':[1,2,3], 'b':[4,2,5], 'c':[1,9,23]},
265               {'a':[2,5,1], 'b':[2,9,1], 'c':[9,21,1]}]) #Trace with two chains
266    """
267    def __init__(self, *chains, **kwargs):
268        if chains is () and kwargs != dict():
269            self.chains = _maybe_hashmap(kwargs)
270        if chains is not ():
271            self.chains = _maybe_hashmap(*chains)
272            if kwargs != dict():
273                self.chains.extend(_maybe_hashmap(kwargs))
274        self._validate_schema()
275
276    @property
277    def varnames(self, chain=None):
278        """
279        Names of variables contained in the trace.
280        """
281        try:
282            return self._varnames
283        except AttributeError:
284            try:
285                self._validate_schema()
286            except KeyError:
287                if chain is None:
288                    raise Exception('Variable names are heterogeneous in chains and no default index provided.')
289                else:
290                    warnings.warn('Variable names are heterogeneous in chains!', stacklevel=2)
291                    return list(self.chains[chain].keys())
292            self._varnames = list(self.chains[0].keys())
293            return self._varnames
294
295    def drop(self, varnames, inplace=True):
296        """
297        Drop a variable from the trace.
298
299        Arguments
300        ---------
301        varnames    :   list of strings
302                        names of parameters to drop from the trace.
303        inplace     :   bool
304                        whether to return a copy of the trace with parameters removed, or remove them inplace.
305        """
306        if isinstance(varnames, str):
307            varnames = (varnames,)
308        if not inplace:
309            new = copy.deepcopy(self)
310            new.drop(varnames, inplace=True)
311            new._varnames = list(new.chains[0].keys())
312            return new
313        for i, chain in enumerate(self.chains):
314            for varname in varnames:
315                del self.chains[i][varname]
316        self._varnames = list(self.chains[0].keys())
317
318    def _validate_schema(self, chains=None):
319        """
320        Validates the trace to ensure that the chain is self-consistent.
321        """
322        if chains is None:
323            chains = self.chains
324        tracked_in_each = [set(chain.keys()) for chain in chains]
325        same_schema = [names == tracked_in_each[0] for names in tracked_in_each]
326        try:
327            assert all(same_schema)
328        except AssertionError:
329            bad_chains = [i for i in range(len(chains)) if same_schema[i]]
330            KeyError('The parameters tracked in each chain are not the same!'
331                     '\nChains {} do not have the same parameters as chain 1!'.format(bad_chains))
332
333    def add_chain(self, chains, validate=True):
334        """
335        Add chains to a trace object
336
337        Parameters
338        ----------
339        chains  :   Hashmap or list of hashmaps
340                    chains to merge into the trace
341        validate:   bool
342                    whether or not to validate the schema and reject the chain if it does not match the current trace.
343        """
344        if not isinstance(chains, (list, tuple)):
345            chains = (chains,)
346        new_chains = [self.chains]
347        for chain in chains:
348            if isinstance(chain, Hashmap):
349                new_chains.append(chain)
350            elif isinstance(chain, Trace):
351                new_chains.extend(chain.chains)
352            else:
353                new_chains.extend(_maybe_hashmap(chain))
354        if validate:
355            self._validate_schema(chains=new_chains)
356        self.chains = new_chains
357
358    def map(self, func, **func_args):
359        """
360        Map a function over all parameters in a chain.
361        Multivariate parameters are reduced to sequences of univariate parameters.
362
363        Usage
364        -------
365        Intended when full-trace statistics are required. Most often,
366        the trace should be sliced directly. For example, to get the mean value of a
367        parameter over the last -1000 iterations with a thinning of 2:
368
369        trace[0, 'Betas', -1000::2].mean(axis=0)
370
371        but, to average of the parameter over all recorded chains:
372
373        trace['Betas', -1000::2].mean(axis=0).mean(axis=0)
374
375        since the first reduction provides an array where rows
376        are iterations and columns are parameters.
377
378        trace.map(np.mean) yields the mean of each parameter within each chain, and is
379        provided to make within-chain reductions easier.
380
381        Arguments
382        ---------
383        func        :   callable
384                        a function that returns a result when provided a flat vector.
385        varnames    :   string or list of strings
386                        a keyword only argument governing which parameters to map over.
387        func_args   :   dictionary/keyword arguments
388                        arguments needed to be passed to the reduction
389        """
390        varnames = func_args.pop('varnames', self.varnames)
391        if isinstance(varnames, str):
392            varnames = (varnames, )
393        all_stats = []
394        for i, chain in enumerate(self.chains):
395            these_stats=dict()
396            for var in varnames:
397                data = np.squeeze(self[i,var])
398                if data.ndim > 1:
399                    n,p = data.shape[0:2]
400                    rest = data.shape[2:0]
401                    if len(rest) == 0:
402                        data = data.T
403                    elif len(rest) == 1:
404                        data = data.reshape(n,p*rest[0]).T
405                    else:
406                        raise Exception('Parameter "{}" shape not understood.'                  ' Please extract, shape it, and pass '
407                                        ' as its own chain. '.format(var))
408                else:
409                    data = data.reshape(1,-1)
410                stats = [func(datum, **func_args) for datum in data]
411                if len(stats) == 1:
412                    stats = stats[0]
413                these_stats.update({var:stats})
414            all_stats.append(these_stats)
415        return all_stats
416
417    @property
418    def n_chains(self):
419        return len(self.chains)
420
421    @property
422    def n_iters(self):
423        """
424        Number of raw iterations stored in the trace.
425        """
426        lengths = [len(chain[self.varnames[0]]) for chain in self.chains]
427        if len(lengths) == 1:
428            return lengths[0]
429        else:
430            return lengths
431
432    def plot(self, burn=0, thin=None, varnames=None,
433             kde_kwargs={}, trace_kwargs={}, figure_kwargs={}):
434        """
435        Make a trace plot paired with a distributional plot.
436
437        Arguments
438        -----------
439        trace   :   namespace
440                    a namespace whose variables are contained in varnames
441        burn    :   int
442                    the number of iterations to discard from the front of the trace
443        thin    :   int
444                    the number of iterations to discard between iterations
445        varnames :  str or list
446                    name or list of names to plot.
447        kde_kwargs : dictionary
448                     dictionary of aesthetic arguments for the kde plot
449        trace_kwargs : dictionary
450                       dictinoary of aesthetic arguments for the traceplot
451
452        Returns
453        -------
454        figure, axis tuple, where axis is (len(varnames), 2)
455        """
456        f, ax = plot_trace(model=None, trace=self, burn=burn,
457                           thin=thin, varnames=varnames,
458                      kde_kwargs=kde_kwargs, trace_kwargs=trace_kwargs,
459                      figure_kwargs=figure_kwargs)
460        return f,ax
461
462    def summarize(self, level=0):
463        """
464        Compute a summary of the trace. See Also: diagnostics.summary
465
466        Arguments
467        ------------
468        level   :   int
469                    0 for a summary by chain or 1 if the summary should be computed by pooling over chains.
470        """
471        from .diagnostics import summarize
472        return summarize(trace=self, level=level)
473
    def __getitem__(self, key):
        """
        Getting an item from a trace can be done using at most three indices, where:

        1 index
        --------
            str: name of one variate. Returns its values, one entry per chain
                 when the trace holds several chains.
            list of str: names of variates in all chains to grab. Returns a
                 Hashmap per chain holding all iterations of those variates.
            slice/int: iterations to grab from all chains. Returns a Hashmap
                 per chain, sliced to the specification.

        2 index
        -------
            (str, slice/int): values of one variate in all chains, with the
                              iterations selected by the second term.
            (list of str, slice/int): first term is names of variates in all chains to grab,
                                      second term specifies the slice of each chain.
                                      returns: hashmaps with keys of first term and entries sliced by the second term.
            (slice/int, str): first term specifies which chains to retrieve,
                              second term is the name of a variate in those chains;
                              returns its values over all iterations.
            (slice/int, list of str): first term specifies which chains to retrieve,
                                      second term is names of variates in those chains
                                      returns: hashmaps containing all iterations
            (slice/int, slice/int): first term specifies which chains to retrieve,
                                    second term specifies the slice of each chain.
                                    returns: hashmaps with entries sliced by the second term
        3 index
        --------
            (slice/int, str/list of str/`:`, slice/int) : first term specifies which chains to retrieve,
                                                      second term is the name(s) of variates in those chains
                                                      (a bare `:` selects every variate),
                                                      third term is the iteration slicing.
                                                      returns: hashmaps keyed on second term, with entries
                                                      sliced by the third term; a single name returns its
                                                      values directly.

        Notes
        -----
        Wherever "hashmaps" are returned, a single Hashmap (not a list) is
        returned when only one chain is involved.
        Results that are not hashmaps are passed through ``np.asarray``: a
        0-d result is converted back with ``tolist()``, and a result of
        shape ``(1,)`` or ``(1, 1)`` is unwrapped to its first element.
        Chain/iteration selection for tuple keys is delegated to the helper
        ``_ifilter``.
        """
        if isinstance(key, str): # one variate, all iterations, all chains
            if self.n_chains  > 1:
                result = ([chain[key] for chain in self.chains])
            else:
                result = (self.chains[0][key])
        elif isinstance(key, (slice, int)): # index the iterations of every variate in every chain
            if self.n_chains > 1:
                return [Hashmap(**{k:v[key] for k,v in chain.items()}) for chain in self.chains]
            else:
                return Hashmap(**{k:v[key] for k,v in self.chains[0].items()})
        elif isinstance(key, list) and all([isinstance(val, str) for val in key]): # several variates, all iterations, all chains
                if self.n_chains > 1:
                    return [Hashmap(**{k:chain[k] for k in key}) for chain in self.chains]
                else:
                    return Hashmap(**{k:self.chains[0][k] for k in key})
        elif isinstance(key, tuple): # complex slicing
            if len(key) == 1:
                return self[key[0]] # 1-tuple: index with its only element
            if len(key) == 2:
                head, tail = key
                if isinstance(head, str): # all chains, one var, some iters
                    if self.n_chains > 1:
                        result = ([_ifilter(tail, chain[head]) for chain in self.chains])
                    else:
                        result = (_ifilter(tail, self.chains[0][head]))
                elif isinstance(head, list) and all([isinstance(v, str) for v in head]): # all chains, some vars, some iters
                    if self.n_chains > 1:
                        return [Hashmap(**{name:_ifilter(tail, chain[name]) for name in head})
                                   for chain in self.chains]
                    else:
                        chain = self.chains[0]
                        return Hashmap(**{name:_ifilter(tail, chain[name]) for name in head})
                elif isinstance(tail, str):
                    # some chains, one var, all iters
                    target_chains = _ifilter(head, self.chains)
                    # a single selected chain comes back unwrapped; re-wrap it
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    if len(target_chains) > 1:
                        result = ([chain[tail] for chain in target_chains])
                    elif len(target_chains) == 1:
                        result = (target_chains[0][tail])
                    else:
                        raise IndexError('The supplied chain index {} does not'
                                        ' match any chains in trace.chains'.format(head))
                elif isinstance(tail, list) and all([isinstance(v, str) for v in tail]):
                    # some chains, some vars, all iters
                    target_chains = _ifilter(head, self.chains)
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    if len(target_chains) > 1:
                        return [Hashmap(**{k:chain[k] for k in tail}) for chain in target_chains]
                    elif len(target_chains) == 1:
                        return Hashmap(**{k:target_chains[0][k] for k in tail})
                    else:
                        raise IndexError('The supplied chain index {} does not'
                                         ' match any chains in trace.chains'.format(head))
                else:
                    # some chains, all vars, some iters
                    target_chains = _ifilter(head, self.chains)
                    if isinstance(target_chains, Hashmap):
                        target_chains = [target_chains]
                    out = [Hashmap(**{k:_ifilter(tail, val) for k,val in chain.items()})
                            for chain in target_chains]
                    if len(out) == 1:
                        return out[0]
                    else:
                        return out
            elif len(key) == 3:
                chidx, varnames, iters = key
                if isinstance(chidx, int):
                    # NOTE(review): off by one -- chidx == n_chains passes this
                    # check, so that case is left to _ifilter to reject.
                    if np.abs(chidx) > self.n_chains:
                        raise IndexError('The supplied chain index {} does not'
                                         ' match any chains in trace.chains'.format(chidx))
                # a bare `:` in the variate position selects every variate
                if varnames == slice(None, None, None):
                    varnames = self.varnames
                chains = _ifilter(chidx, self.chains)
                if isinstance(chains, Hashmap):
                    chains = [chains]
                nchains = len(chains)
                if isinstance(varnames, str):
                    varnames = [varnames]
                # NOTE(review): `is` against a freshly built slice is always
                # False, so this never runs; the `==` test above covers it.
                if varnames is slice(None, None, None):
                    varnames = self.varnames
                if len(varnames) == 1:
                    if nchains > 1:
                        result = ([_ifilter(iters, chain[varnames[0]]) for chain in chains])
                    else:
                        result = (_ifilter(iters, chains[0][varnames[0]]))
                else:
                    if nchains > 1:
                        return [Hashmap(**{varname:_ifilter(iters, chain[varname])
                                        for varname in varnames})
                                for chain in chains]
                    else:
                        return Hashmap(**{varname:_ifilter(iters, chains[0][varname]) for varname in varnames})
            # NOTE(review): tuples of length 0 or > 3 match no branch above,
            # leaving `result` unbound below (UnboundLocalError).
        else:
            raise IndexError('index not understood')

        # Normalize non-Hashmap results: unwrap 0-d and single-element arrays.
        result = np.asarray(result)
        if result.shape == ():
            result = result.tolist()
        elif result.shape in [(1,1), (1,)]:
            result = result[0]
        return result
602
603    ##############
604    # Comparison #
605    ##############
606
607    def __eq__(self, other):
608        if not isinstance(other, type(self)):
609            return False
610        else:
611            a = [ch1==ch2 for ch1,ch2 in zip(other.chains, self.chains)]
612            return all(a)
613
614    def _allclose(self, other, **allclose_kw):
615        try:
616            self._assert_allclose(other, **allclose_kw)
617        except AssertionError:
618            return False
619        return True
620
621    def _assert_allclose(self, other, **allclose_kw):
622        ignore_shape = allclose_kw.pop('ignore_shape', False)
623        squeeze = allclose_kw.pop('squeeze', True)
624        try:
625            assert set(self.varnames) == set(other.varnames)
626        except AssertionError:
627            raise AssertionError('Variable names are different!\n'
628                                 'self: {}\nother:{}'.format(
629                                     self.varnames, other.varnames))
630        assert isinstance(other, type(self))
631        for ch1, ch2 in zip(self.chains, other.chains):
632            for k,v in ch1.items():
633                allclose_kw['err_msg'] = 'Failed on {}'.format(k)
634                if ignore_shape:
635                    A = [np.asarray(item).flatten() for item in v]
636                    B = [np.asarray(item).flatten() for item in ch2[k]]
637                elif squeeze:
638                    A = [np.squeeze(item) for item in v]
639                    B = [np.squeeze(item) for item in ch2[k]]
640                else:
641                    A = v
642                    B = ch2[k]
643                np.testing.assert_allclose(A,B,**allclose_kw)
644
645
646    ###################
647    # IO and Exchange #
648    ###################
649
650    def to_df(self):
651        """
652        Convert the trace object to a Pandas Dataframe.
653
654        Returns
655        -------
656        a dataframe where each column is a parameter. Multivariate parameters are vectorized and stuffed into a column.
657        """
658        dfs = []
659        outnames = self.varnames
660        to_split = [name for name in outnames if np.asarray(self[0,name,0]).size > 1]
661        for chain in self.chains:
662            out = OrderedDict(list(chain.items()))
663            for split in to_split:
664                records = np.asarray(copy.deepcopy(chain[split]))
665                if len(records.shape) == 1:
666                    records = records.reshape(-1,1)
667                n,k = records.shape[0:2]
668                rest = records.shape[2:]
669                if len(rest) == 0:
670                    pass
671                elif len(rest) == 1:
672                    records = records.reshape(n,int(k*rest[0]))
673                else:
674                    raise Exception("Parameter '{}' has too many dimensions"
675                                    " to flatten able to be flattend?"               .format(split))
676                records = OrderedDict([(split+'_'+str(i),record.T.tolist())
677                                        for i,record in enumerate(records.T)])
678                out.update(records)
679                del out[split]
680            df = pd.DataFrame().from_dict(out)
681            dfs.append(df)
682        if len(dfs) == 1:
683            return dfs[0]
684        else:
685            return dfs
686
687    def to_csv(self, filename, **pandas_kwargs):
688        """
689        Write trace out to file, going through Trace.to_df()
690
691        If there are multiple chains in this trace, this will write
692        them each out to 'filename_number.csv', where `number` is the
693            number of the trace.
694
695        Arguments
696        ---------
697        filename    :   string
698                        name of file to write the trace to.
699        pandas_kwargs:  keyword arguments
700                        arguments to pass to the pandas to_csv function.
701        """
702        if 'index' not in pandas_kwargs:
703            pandas_kwargs['index'] = False
704        dfs = self.to_df()
705        if isinstance(dfs, list):
706            name, ext = os.path.splitext(filename)
707            for i, df in enumerate(dfs):
708                df.to_csv(name + '_' + str(i) + ext, **pandas_kwargs)
709        else:
710            dfs.to_csv(filename, **pandas_kwargs)
711
712    @classmethod
713    def from_df(cls, dfs, varnames=None, combine_suffix='_'):
714        """
715        Convert a dataframe into a trace object.
716
717        Arguments
718        ----------
719        dfs     :   dataframe or list of dataframes
720                    pandas dataframes to convert into a trace. Each dataframe is assumed to be a single chain.
721        varnames:   string or list of strings
722                    names to use instead of the names in the dataframe. If none, the column
723                    names are split using `combine_suffix`, and the unique things before the suffix are used as parameter names.
724        """
725        if not isinstance(dfs, (tuple, list)):
726            dfs = (dfs,)
727        if len(dfs) > 1:
728            traces = ([cls.from_df(df, varnames=varnames,
729                        combine_suffix=combine_suffix) for df in dfs])
730            return cls(*[trace.chains[0] for trace in traces])
731        else:
732            df = dfs[0]
733        if varnames is None:
734            varnames = df.columns
735        unique_stems = set()
736        for col in varnames:
737            suffix_split = col.split(combine_suffix)
738            if suffix_split[0] == col:
739                unique_stems.update([col])
740            else:
741                unique_stems.update(['_'.join(suffix_split[:-1])])
742        out = dict()
743        for stem in unique_stems:
744            cols = []
745            for var in df.columns:
746                if var == stem:
747                    cols.append(var)
748                elif '_'.join(var.split('_')[:-1]) == stem:
749                    cols.append(var)
750            if len(cols) == 1:
751                targets = df[cols].values.flatten().tolist()
752            else:
753                # ensure the tail ordinate sorts the columns, not string order
754                # '1','11','2' will corrupt the trace
755                order = [int(st.split(combine_suffix)[-1]) for st in cols]
756                cols = np.asarray(cols)[np.argsort(order)]
757                targets = [vec for vec in df[cols].values]
758            out.update({stem:targets})
759        return cls(**out)
760
761    @classmethod
762    def from_pymc3(cls, pymc3trace):
763        """
764        Convert a PyMC3 trace to a spvcm trace
765        """
766        try:
767            from pymc3 import trace_to_dataframe
768        except ImportError:
769            raise ImportError("The 'trace_to_dataframe' function in "
770                              "pymc3 is used for this feature. Pymc3 "
771                              "failed to import.")
772        return cls.from_df(mc.trace_to_dataframe(pymc3trace))
773
774    @classmethod
775    def from_csv(cls, filename=None, multi=False,
776                      varnames=None, combine_suffix='_', **pandas_kwargs):
777        """
778        Read a CSV into a trace object, by way of `Trace.from_df()`
779
780        Arguments
781        ----------
782        filename    :   string
783                        string containing the name of the file to read.
784        multi       :   bool
785                        flag denoting whether the trace being read is a multitrace or not. If so, the filename is understood to be the prefix of many files that end in `filename_#.csv`
786        varnames    :   string or list of strings
787                        custom names to use for the trace. If not provided, combine suffix is used to identify the unique prefixes in the csvs.
788        pandas_kawrgs:  keyword arguments
789                        keyword arguments to pass to the pandas functions.
790        """
791        if multi:
792            filepath = os.path.dirname(os.path.abspath(filename))
793            filestem = os.path.basename(filename)
794            targets = [f for f in os.listdir(filepath)
795                         if f.startswith(filestem)]
796            ordinates = [int(os.path.splitext(fname)[0].split(combine_suffix)[-1])
797                         for fname in targets]
798            # preserve the order of the trailing ordinates
799            targets = np.asarray(targets)[np.argsort(ordinates)].tolist()
800            traces = ([cls.from_csv(filename=os.path.join(filepath, f)
801                                    ,multi=False) for f in targets])
802            if traces == []:
803                raise IOError("No such file or directory: " +
804                                        filepath + filestem)
805
806            return cls(*[trace.chains[0] for trace in traces])
807        else:
808            df = pd.read_csv(filename, **pandas_kwargs)
809            return cls.from_df(df, varnames=varnames,
810                               combine_suffix=combine_suffix)
811
812
813####################
814# HELPER FUNCTIONS #
815####################
816
817def _ifilter(filt,iterable):
818    """
819    Filter an iterable by whether or not each item is in the filt
820    """
821    try:
822        return iterable[filt]
823    except:
824        if isinstance(filt, (int, float)):
825            filt = [filt]
826        return [val for i,val in enumerate(iterable) if i in filt]
827
def _maybe_hashmap(*collections):
    """
    Coerce each of the given collections into a Hashmap.

    Arguments that already are Hashmaps are passed through as-is (not copied).
    Every other argument is expanded as keyword arguments into a new Hashmap,
    so it must be a mapping with string keys; otherwise a TypeError is raised.

    Returns
    -------
    list holding one Hashmap per argument, in the order given.
    """
    return [collection if isinstance(collection, Hashmap)
            else Hashmap(**collection)
            for collection in collections]
839
def _copy_hashmaps(*hashmaps):
    """
    Build an independent deep copy of every hashmap passed in.

    Each value is duplicated separately with `copy.deepcopy`. The copies are
    packed into a fresh Hashmap under the same keys.

    Returns
    -------
    list of new Hashmaps, one per argument, in the order given.
    """
    copies = []
    for hashmap in hashmaps:
        duplicated = {key: copy.deepcopy(value)
                      for key, value in hashmap.items()}
        copies.append(Hashmap(**duplicated))
    return copies
846