r"""
This file contains some useful objects for handling a finite-state
discrete-time Markov chain.

Definitions and Some Basic Facts about Markov Chains
----------------------------------------------------

Let :math:`\{X_t\}` be a Markov chain represented by an :math:`n \times
n` stochastic matrix :math:`P`. State :math:`i` *has access* to state
:math:`j`, denoted :math:`i \to j`, if :math:`i = j` or :math:`P^k[i, j]
> 0` for some :math:`k = 1, 2, \ldots`; :math:`i` and :math:`j`
*communicate*, denoted :math:`i \leftrightarrow j`, if :math:`i \to j`
and :math:`j \to i`. The binary relation :math:`\leftrightarrow` is an
equivalence relation. A *communication class* of the Markov chain
:math:`\{X_t\}`, or of the stochastic matrix :math:`P`, is an
equivalence class of :math:`\leftrightarrow`. Equivalently, a
communication class is a *strongly connected component* (SCC) in the
associated *directed graph* :math:`\Gamma(P)`, a directed graph with
:math:`n` nodes where there is an edge from :math:`i` to :math:`j` if
and only if :math:`P[i, j] > 0`. The Markov chain, or the stochastic
matrix, is *irreducible* if it admits only one communication class, or
equivalently, if :math:`\Gamma(P)` is *strongly connected*.
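
As an illustration only (numpy-based, with a hypothetical helper name
`communication_classes_sketch` that is not part of this module), the
communication classes can be read off the transitive closure of
:math:`i \to j`, since :math:`i \leftrightarrow j` holds exactly when
both :math:`i \to j` and :math:`j \to i`. The module itself delegates
this to `DiGraph`; the sketch assumes a small dense matrix.

```python
import numpy as np


def communication_classes_sketch(P):
    # Hypothetical helper for illustration; not part of this module's API.
    P = np.asarray(P)
    n = P.shape[0]
    # reach[i, j] is True if i -> j; i -> i holds by definition
    reach = np.eye(n, dtype=bool) | (P > 0)
    # Repeated boolean squaring until nothing changes yields the
    # transitive closure, i.e. access via paths of any length
    while True:
        step = reach.astype(int)
        new_reach = (step @ step) > 0
        if np.array_equal(new_reach, reach):
            break
        reach = new_reach
    # i <-> j iff i -> j and j -> i
    communicate = reach & reach.T
    classes = []
    assigned = np.zeros(n, dtype=bool)
    for i in range(n):
        if not assigned[i]:
            members = np.flatnonzero(communicate[i])
            assigned[members] = True
            classes.append(members.tolist())
    return classes


print(communication_classes_sketch([[1.0, 0.0], [0.5, 0.5]]))  # [[0], [1]]
```

A single class means the chain is irreducible. Each squaring step costs
:math:`O(n^3)` and roughly :math:`\log_2 n` steps suffice, so this is
only meant for small examples.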

A state :math:`i` is *recurrent* if :math:`i \to j` implies :math:`j \to
i`; it is *transient* if it is not recurrent. For any :math:`i, j`
contained in a communication class, :math:`i` is recurrent if and only
if :math:`j` is recurrent. Therefore, recurrence is a property of a
communication class, and a communication class is called a *recurrent
class* if it contains a recurrent state (in which case all of its states
are recurrent). Equivalently, a recurrent class is an SCC that
corresponds to a sink node in the *condensation* of the directed graph
:math:`\Gamma(P)`, where the condensation of :math:`\Gamma(P)` is a
directed graph in which each SCC is replaced with a single node and
there is an edge from one SCC :math:`C` to another SCC :math:`C'` if
:math:`C \neq C'` and there is an edge from some node in :math:`C` to
some node in :math:`C'`. A recurrent class is also called a *closed
communication class*. Since the condensation is a finite acyclic graph,
it has at least one sink node, so there exists at least one recurrent
class.
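
The definition of recurrence above translates directly into a check on
the closure of :math:`i \to j`. The sketch below (numpy only, with a
hypothetical helper `recurrent_states_sketch`) marks :math:`i` as
recurrent when every :math:`j` with :math:`i \to j` also satisfies
:math:`j \to i`; unlike the module, it does not build the condensation.

```python
import numpy as np


def recurrent_states_sketch(P):
    # Hypothetical helper for illustration; not part of this module's API.
    P = np.asarray(P)
    n = P.shape[0]
    # Reflexive transitive closure of i -> j
    reach = np.eye(n, dtype=bool) | (P > 0)
    while True:
        step = reach.astype(int)
        new_reach = (step @ step) > 0
        if np.array_equal(new_reach, reach):
            break
        reach = new_reach
    # i is recurrent iff j -> i for every j with i -> j
    return [i for i in range(n) if reach[reach[i], i].all()]


# State 3 leaks into both {0, 1} and {2} and is never re-entered
P = [[0.5, 0.5, 0.0, 0.0],
     [0.2, 0.8, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0],
     [0.25, 0.25, 0.25, 0.25]]
print(recurrent_states_sketch(P))  # [0, 1, 2]
```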

For example, if the entries of :math:`P` are all strictly positive, then
the whole state space is a communication class as well as a recurrent
class. (More generally, if there is only one communication class, then
it is a recurrent class.) As another example, consider the stochastic
matrix :math:`P = [[1, 0], [0.5, 0.5]]`. This has two communication
classes, :math:`\{0\}` and :math:`\{1\}`, and :math:`\{0\}` is the only
recurrent class.
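
This example can be checked numerically. By induction,
:math:`P^k = [[1, 0], [1 - 0.5^k, 0.5^k]]`, so state 0 never reaches
state 1 while state 1 reaches state 0 in one step. A short numpy check
(a sketch, not part of the module):

```python
import numpy as np

P = np.array([[1.0, 0.0], [0.5, 0.5]])

for k in range(1, 6):
    Pk = np.linalg.matrix_power(P, k)
    # 0 does not have access to 1: the (0, 1) entry stays zero for every k
    assert Pk[0, 1] == 0.0
    # P^k[1, 1] = 0.5**k; the remaining mass 1 - 0.5**k has moved to state 0
    assert Pk[1, 1] == 0.5 ** k

# The point mass on the recurrent class {0} is stationary: x' P = x'
x = np.array([1.0, 0.0])
print(x @ P)  # [1. 0.]
```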

A *stationary distribution* of the Markov chain :math:`\{X_t\}`, or of
the stochastic matrix :math:`P`, is a nonnegative vector :math:`x` such
that :math:`x' P = x'` and :math:`x' \mathbf{1} = 1`, where
:math:`\mathbf{1}` is the vector of ones. The Markov chain has a unique
stationary distribution if and only if it has a unique recurrent class.
More generally, each recurrent class has a unique stationary
distribution whose support equals that recurrent class. The set of all
stationary distributions is given by the convex hull of these unique
stationary distributions for the recurrent classes.
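
Following the statement above, a sketch (hypothetical helper
`stationary_for_class`) solves :math:`x' P_C = x'`,
:math:`x' \mathbf{1} = 1` on each recurrent class :math:`C`, where
:math:`P_C` is :math:`P` restricted to :math:`C`, and pads with zeros
outside :math:`C`. It uses a plain least-squares solve instead of the
GTH algorithm (`gth_solve`) that this module uses, which is numerically
more robust; the recurrent classes are assumed to be known.

```python
import numpy as np


def stationary_for_class(P, rec_class):
    # Hypothetical helper for illustration; the module itself uses gth_solve.
    P = np.asarray(P, dtype=float)
    idx = np.asarray(rec_class)
    P_C = P[np.ix_(idx, idx)]  # restriction of P to the closed class C
    m = len(idx)
    # Stack the equations (P_C' - I) x = 0 and 1' x = 1
    A = np.vstack([P_C.T - np.eye(m), np.ones((1, m))])
    b = np.zeros(m + 1)
    b[-1] = 1.0
    x_C, *_ = np.linalg.lstsq(A, b, rcond=None)
    x = np.zeros(P.shape[0])
    x[idx] = x_C  # support equals the recurrent class
    return x


P = np.array([[0.5, 0.5, 0.0, 0.0],
              [0.2, 0.8, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.25, 0.25, 0.25, 0.25]])
# Recurrent classes {0, 1} and {2}; state 3 is transient
dists = np.array([stationary_for_class(P, C) for C in ([0, 1], [2])])
print(np.round(dists, 4))  # rows approx. (2/7, 5/7, 0, 0) and (0, 0, 1, 0)

# Any convex combination of the rows is again stationary
x = 0.3 * dists[0] + 0.7 * dists[1]
print(np.allclose(x @ P, x))  # True
```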

A natural number :math:`d` is the *period* of state :math:`i` if it is
the greatest common divisor of all :math:`k \geq 1` such that
:math:`P^k[i, i] > 0`; equivalently, it is the GCD of the lengths of the
cycles in :math:`\Gamma(P)` passing through :math:`i`. For any
:math:`i, j` contained in a communication class, :math:`i` has period
:math:`d` if and only if :math:`j` has period :math:`d`. The *period* of
an irreducible Markov chain (or of an irreducible stochastic matrix) is
the period of any state. We define the period of a general (not
necessarily irreducible) Markov chain to be the least common multiple of
the periods of its recurrent classes, where the period of a recurrent
class is the period of any state in that class. A Markov chain is
*aperiodic* if its period is one. A Markov chain is irreducible and
aperiodic if and only if it is *uniformly ergodic*, i.e., there exists
some :math:`m` such that :math:`P^m[i, j] > 0` for all :math:`i, j` (in
this case, :math:`P` is also called *primitive*).
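
For an irreducible chain, the period can be computed from
breadth-first-search levels: if :math:`\ell(i)` is the length of a
shortest path from state 0 to :math:`i`, then :math:`d` is the GCD of
:math:`\ell(i) + 1 - \ell(j)` over all edges :math:`i \to j` of
:math:`\Gamma(P)`. This is a standard technique, shown here as a sketch
with hypothetical helper names; it is not necessarily how
`DiGraph.period` works. The primitivity check relies on Wielandt's bound:
an :math:`n \times n` primitive matrix satisfies :math:`P^m > 0` for
:math:`m = (n-1)^2 + 1`.

```python
from collections import deque
from math import gcd

import numpy as np


def period_sketch(P):
    # Hypothetical helper; assumes P is irreducible (all states reachable from 0)
    P = np.asarray(P)
    n = P.shape[0]
    level = [-1] * n  # BFS distance from state 0
    level[0] = 0
    queue = deque([0])
    while queue:
        i = queue.popleft()
        for j in np.flatnonzero(P[i] > 0):
            if level[j] == -1:
                level[j] = level[i] + 1
                queue.append(j)
    d = 0
    for i in range(n):
        for j in np.flatnonzero(P[i] > 0):
            # Each cycle length is a sum of these terms, and each term is
            # a difference of two closed-walk lengths through state 0
            d = gcd(d, level[i] + 1 - level[j])
    return d


def is_primitive_sketch(P):
    # Boolean power of the adjacency pattern at Wielandt's bound (n-1)^2 + 1
    A = (np.asarray(P) > 0).astype(int)
    n = A.shape[0]
    M = np.eye(n, dtype=int)
    for _ in range((n - 1) ** 2 + 1):
        M = ((M @ A) > 0).astype(int)
    return bool(M.all())


print(period_sketch([[0, 1], [1, 0]]))                # 2
print(period_sketch([[0.5, 0.5], [1.0, 0.0]]))        # 1
print(is_primitive_sketch([[0.5, 0.5], [1.0, 0.0]]))  # True
```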

Suppose that an irreducible Markov chain has period :math:`d`. Fix any
state, say state :math:`0`. For each :math:`m = 0, \ldots, d-1`, let
:math:`S_m` be the set of states :math:`i` such that :math:`P^{kd+m}[0,
i] > 0` for some :math:`k \geq 0`. These sets :math:`S_0, \ldots,
S_{d-1}` constitute a partition of the state space and are called the
*cyclic classes*. For each :math:`S_m` and each :math:`i \in S_m`, we
have :math:`\sum_{j \in S_{m+1}} P[i, j] = 1`, where :math:`S_d = S_0`.
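
Because all paths from state 0 to :math:`i` have the same length modulo
:math:`d`, the cyclic class of :math:`i` is its breadth-first-search
distance from state 0 taken modulo :math:`d`. A sketch with a
hypothetical helper `cyclic_classes_sketch`, which computes :math:`d` as
the GCD of :math:`\ell(i) + 1 - \ell(j)` over all edges :math:`i \to j`
(with :math:`\ell` the BFS distance) and then checks the property
:math:`\sum_{j \in S_{m+1}} P[i, j] = 1`:

```python
from collections import deque
from math import gcd

import numpy as np


def cyclic_classes_sketch(P):
    # Hypothetical helper; assumes P is irreducible.
    P = np.asarray(P)
    n = P.shape[0]
    level = [-1] * n  # shortest-path length from state 0
    level[0] = 0
    queue = deque([0])
    while queue:
        i = queue.popleft()
        for j in np.flatnonzero(P[i] > 0):
            if level[j] == -1:
                level[j] = level[i] + 1
                queue.append(j)
    d = 0
    for i in range(n):
        for j in np.flatnonzero(P[i] > 0):
            d = gcd(d, level[i] + 1 - level[j])
    # S_m collects the states whose distance from 0 is congruent to m mod d
    return [[i for i in range(n) if level[i] % d == m] for m in range(d)]


P = np.array([[0.0, 0.5, 0.5, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0, 1.0],
              [1.0, 0.0, 0.0, 0.0]])
S = cyclic_classes_sketch(P)
print(S)  # [[0], [1, 2], [3]]

# Each state in S_m sends all of its mass to S_{m+1} (with S_d = S_0)
d = len(S)
for m in range(d):
    for i in S[m]:
        assert np.isclose(P[i, S[(m + 1) % d]].sum(), 1.0)
```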

"""
import numbers
from math import gcd
import numpy as np
from scipy import sparse
from numba import jit

from .gth_solve import gth_solve
from ..graph_tools import DiGraph
from ..util import searchsorted, check_random_state


class MarkovChain:
    """
    Class for a finite-state discrete-time Markov chain. It stores
    useful information such as the stationary distributions and the
    communication, recurrent, and cyclic classes, and allows simulation
    of state transitions.

    Parameters
    ----------
    P : array_like or scipy sparse matrix (float, ndim=2)
        The transition matrix. Must be of shape n x n.

    state_values : array_like(default=None)
        Array_like of length n containing the values associated with the
        states, which must be homogeneous in type. If None, the values
        default to integers 0 through n-1.

    Attributes
    ----------
    P : ndarray or scipy.sparse.csr_matrix (float, ndim=2)
        See Parameters.

    stationary_distributions : ndarray(float, ndim=2)
        Array containing stationary distributions, one for each
        recurrent class, as rows.

    is_irreducible : bool
        Indicates whether the Markov chain is irreducible.

    num_communication_classes : int
        The number of communication classes.

    communication_classes_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the communication
        classes.

    communication_classes : list(ndarray)
        List of numpy arrays containing the communication classes, where
        the states are annotated with their values (if `state_values` is
        not None).

    num_recurrent_classes : int
        The number of recurrent classes.

    recurrent_classes_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the recurrent
        classes.

    recurrent_classes : list(ndarray)
        List of numpy arrays containing the recurrent classes, where the
        states are annotated with their values (if `state_values` is not
        None).

    is_aperiodic : bool
        Indicates whether the Markov chain is aperiodic.

    period : int
        The period of the Markov chain.

    cyclic_classes_indices : list(ndarray(int))
        List of numpy arrays containing the indices of the cyclic
        classes. Defined only when the Markov chain is irreducible;
        accessing it for a reducible chain raises NotImplementedError.

    cyclic_classes : list(ndarray)
        List of numpy arrays containing the cyclic classes, where the
        states are annotated with their values (if `state_values` is not
        None). Defined only when the Markov chain is irreducible;
        accessing it for a reducible chain raises NotImplementedError.

    Notes
    -----
    In computing stationary distributions, if the input matrix is a
    sparse matrix, internally it is converted to a dense matrix.

    """

    def __init__(self, P, state_values=None):
        if sparse.issparse(P):  # Sparse matrix
            self.P = sparse.csr_matrix(P)
            self.is_sparse = True
        else:  # Dense matrix
            self.P = np.asarray(P)
            self.is_sparse = False

        # Check Properties
        # Double check that P is a square matrix
        if len(self.P.shape) != 2 or self.P.shape[0] != self.P.shape[1]:
            raise ValueError('P must be a square matrix')

        # The number of states
        self.n = self.P.shape[0]

        # Double check that P is a nonnegative matrix
        if not self.is_sparse:
            data_nonnegative = (self.P >= 0)  # ndarray
        else:
            data_nonnegative = (self.P.data >= 0)  # csr_matrix
        if not np.all(data_nonnegative):
            raise ValueError('P must be nonnegative')

        # Double check that the rows of P sum to one
        row_sums = self.P.sum(axis=1)
        if self.is_sparse:  # row_sums is np.matrix (ndim=2)
            row_sums = row_sums.getA1()
        if not np.allclose(row_sums, np.ones(self.n)):
            raise ValueError('The rows of P must sum to 1')

        # Call the setter method
        self.state_values = state_values

        # To analyze the structure of P as a directed graph
        self._digraph = None

        self._stationary_dists = None
        self._cdfs = None  # For dense matrix
        self._cdfs1d = None  # For sparse matrix

    def __repr__(self):
        msg = "Markov chain with transition matrix \nP = \n{0}"

        if self._stationary_dists is None:
            return msg.format(self.P)
        else:
            msg = msg + "\nand stationary distributions \n{1}"
            return msg.format(self.P, self._stationary_dists)

    def __str__(self):
        # Call __repr__; str() of the method object itself would only
        # show "<bound method ...>" instead of the description
        return self.__repr__()

    @property
    def state_values(self):
        return self._state_values

    @state_values.setter
    def state_values(self, values):
        if values is None:
            self._state_values = None
        else:
            values = np.asarray(values)
            if (values.ndim < 1) or (values.shape[0] != self.n):
                raise ValueError(
                    'state_values must be an array_like of length n'
                )
            if np.issubdtype(values.dtype, np.object_):
                raise ValueError(
                    'data in state_values must be homogeneous in type'
                )
            self._state_values = values

    def get_index(self, value):
        """
        Return the index (or indices) of the given value (or values) in
        `state_values`.

        Parameters
        ----------
        value
            Value(s) to get the index (indices) for.

        Returns
        -------
        idx : int or ndarray(int)
            Index of `value` if `value` is a single state value; array
            of indices if `value` is an array_like of state values.

        """
        if self.state_values is None:
            state_values_ndim = 1
        else:
            state_values_ndim = self.state_values.ndim

        values = np.asarray(value)

        if values.ndim <= state_values_ndim - 1:
            return self._get_index(value)
        elif values.ndim == state_values_ndim:  # array of values
            k = values.shape[0]
            idx = np.empty(k, dtype=int)
            for i in range(k):
                idx[i] = self._get_index(values[i])
            return idx
        else:
            raise ValueError('invalid value')

    def _get_index(self, value):
        """
        Return the index of the given value in `state_values`.

        Parameters
        ----------
        value
            Value to get the index for.

        Returns
        -------
        idx : int
            Index of `value`.

        """
        error_msg = 'value {0} not found'.format(value)

        if self.state_values is None:
            if isinstance(value, numbers.Integral) and (0 <= value < self.n):
                return value
            else:
                raise ValueError(error_msg)

        # if self.state_values is not None:
        if self.state_values.ndim == 1:
            try:
                idx = np.where(self.state_values == value)[0][0]
                return idx
            except IndexError:
                raise ValueError(error_msg)
        else:
            idx = 0
            while idx < self.n:
                if np.array_equal(self.state_values[idx], value):
                    return idx
                idx += 1
            raise ValueError(error_msg)

    @property
    def digraph(self):
        if self._digraph is None:
            self._digraph = DiGraph(self.P, node_labels=self.state_values)
        return self._digraph

    @property
    def is_irreducible(self):
        return self.digraph.is_strongly_connected

    @property
    def num_communication_classes(self):
        return self.digraph.num_strongly_connected_components

    @property
    def communication_classes_indices(self):
        return self.digraph.strongly_connected_components_indices

    @property
    def communication_classes(self):
        return self.digraph.strongly_connected_components

    @property
    def num_recurrent_classes(self):
        return self.digraph.num_sink_strongly_connected_components

    @property
    def recurrent_classes_indices(self):
        return self.digraph.sink_strongly_connected_components_indices

    @property
    def recurrent_classes(self):
        return self.digraph.sink_strongly_connected_components

    @property
    def is_aperiodic(self):
        if self.is_irreducible:
            return self.digraph.is_aperiodic
        else:
            return self.period == 1

    @property
    def period(self):
        if self.is_irreducible:
            return self.digraph.period
        else:
            # DiGraph.subgraph takes node indices, so use the indices of
            # the recurrent classes rather than their state values
            rec_classes = self.recurrent_classes_indices

            # Determine the period, the LCM of the periods of rec_classes
            d = 1
            for rec_class in rec_classes:
                period = self.digraph.subgraph(rec_class).period
                d = (d * period) // gcd(d, period)

            return d

    @property
    def cyclic_classes(self):
        if not self.is_irreducible:
            raise NotImplementedError(
                'Not defined for a reducible Markov chain'
            )
        else:
            return self.digraph.cyclic_components

    @property
    def cyclic_classes_indices(self):
        if not self.is_irreducible:
            raise NotImplementedError(
                'Not defined for a reducible Markov chain'
            )
        else:
            return self.digraph.cyclic_components_indices

    def _compute_stationary(self):
        """
        Compute the stationary distributions, one for each recurrent
        class, and store them in self._stationary_dists.

        """
        if self.is_irreducible:
            if not self.is_sparse:  # Dense
                stationary_dists = gth_solve(self.P).reshape(1, self.n)
            else:  # Sparse
                stationary_dists = \
                    gth_solve(self.P.toarray(),
                              overwrite=True).reshape(1, self.n)
        else:
            rec_classes = self.recurrent_classes_indices
            stationary_dists = np.zeros((len(rec_classes), self.n))
            for i, rec_class in enumerate(rec_classes):
                P_rec_class = self.P[np.ix_(rec_class, rec_class)]
                if self.is_sparse:
                    P_rec_class = P_rec_class.toarray()
                stationary_dists[i, rec_class] = \
                    gth_solve(P_rec_class, overwrite=True)

        self._stationary_dists = stationary_dists

    @property
    def stationary_distributions(self):
        if self._stationary_dists is None:
            self._compute_stationary()
        return self._stationary_dists

    @property
    def cdfs(self):
        if (self._cdfs is None) and not self.is_sparse:
            # See issue #137#issuecomment-96128186
            cdfs = np.empty((self.n, self.n), order='C', dtype=self.P.dtype)
            np.cumsum(self.P, axis=-1, out=cdfs)
            self._cdfs = cdfs
        return self._cdfs

    @property
    def cdfs1d(self):
        if (self._cdfs1d is None) and self.is_sparse:
            data = self.P.data
            indptr = self.P.indptr

            cdfs1d = np.empty(self.P.nnz, order='C', dtype=data.dtype)
            for i in range(self.n):
                cdfs1d[indptr[i]:indptr[i+1]] = \
                    data[indptr[i]:indptr[i+1]].cumsum()
            self._cdfs1d = cdfs1d
        return self._cdfs1d

    def simulate_indices(self, ts_length, init=None, num_reps=None,
                         random_state=None):
        """
        Simulate time series of state transitions, where state indices
        are returned.

        Parameters
        ----------
        ts_length : scalar(int)
            Length of each simulation.

        init : int or array_like(int, ndim=1), optional
            Initial state(s). If None, the initial state is drawn
            uniformly at random.

        num_reps : scalar(int), optional(default=None)
            Number of repetitions of simulation.

        random_state : int or np.random.RandomState, optional
            Random seed (integer) or np.random.RandomState instance to
            set the initial state of the random number generator for
            reproducibility. If None, a randomly initialized RandomState
            is used.

        Returns
        -------
        X : ndarray(ndim=1 or 2)
            Array containing the state indices of the sample path(s).
            See the `simulate` method for more information.

        """
        random_state = check_random_state(random_state)
        dim = 1  # Dimension of the returned array: 1 or 2

        msg_out_of_range = 'index {init} is out of the state space'

        try:
            k = len(init)  # init is an array
            dim = 2
            init_states = np.asarray(init, dtype=int)
            # Check init_states are in the state space
            if (init_states >= self.n).any() or (init_states < -self.n).any():
                # Report the first offending index value
                idx = np.where(
                    (init_states >= self.n) | (init_states < -self.n)
                )[0][0]
                raise ValueError(
                    msg_out_of_range.format(init=init_states[idx])
                )
            if num_reps is not None:
                k *= num_reps
                init_states = np.tile(init_states, num_reps)
        except TypeError:  # init is a scalar(int) or None
            k = 1
            if num_reps is not None:
                dim = 2
                k = num_reps
            if init is None:
                init_states = random_state.randint(self.n, size=k)
            elif isinstance(init, numbers.Integral):
                # Check init is in the state space
                if init >= self.n or init < -self.n:
                    raise ValueError(msg_out_of_range.format(init=init))
                init_states = np.ones(k, dtype=int) * init
            else:
                raise ValueError(
                    'init must be int, array_like of ints, or None'
                )

        # === set up array to store output === #
        X = np.empty((k, ts_length), dtype=int)

        # Random values, uniformly sampled from [0, 1)
        random_values = random_state.random_sample(size=(k, ts_length-1))

        # Generate sample paths and store in X
        if not self.is_sparse:  # Dense
            _generate_sample_paths(
                self.cdfs, init_states, random_values, out=X
            )
        else:  # Sparse
            _generate_sample_paths_sparse(
                self.cdfs1d, self.P.indices, self.P.indptr, init_states,
                random_values, out=X
            )

        if dim == 1:
            return X[0]
        else:
            return X

    def simulate(self, ts_length, init=None, num_reps=None, random_state=None):
        """
        Simulate time series of state transitions, where the states are
        annotated with their values (if `state_values` is not None).

        Parameters
        ----------
        ts_length : scalar(int)
            Length of each simulation.

        init : scalar or array_like, optional(default=None)
            Initial state value(s). If None, the initial state is drawn
            uniformly at random.

        num_reps : scalar(int), optional(default=None)
            Number of repetitions of simulation.

        random_state : int or np.random.RandomState, optional
            Random seed (integer) or np.random.RandomState instance to
            set the initial state of the random number generator for
            reproducibility. If None, a randomly initialized RandomState
            is used.

        Returns
        -------
        X : ndarray(ndim=1 or 2)
            Array containing the sample path(s), of shape (ts_length,)
            if init is a scalar (integer) or None and num_reps is None;
            of shape (k, ts_length) otherwise, where k = len(init) if
            (init, num_reps) = (array, None), k = num_reps if (init,
            num_reps) = (int or None, int), and k = len(init)*num_reps
            if (init, num_reps) = (array, int).

        """
        if init is not None:
            init_idx = self.get_index(init)
        else:
            init_idx = None
        X = self.simulate_indices(ts_length, init=init_idx, num_reps=num_reps,
                                  random_state=random_state)

        # Annotate states
        if self.state_values is not None:
            X = self.state_values[X]

        return X


@jit(nopython=True)
def _generate_sample_paths(P_cdfs, init_states, random_values, out):
    """
    Generate num_reps sample paths of length ts_length, where num_reps =
    out.shape[0] and ts_length = out.shape[1].

    Parameters
    ----------
    P_cdfs : ndarray(float, ndim=2)
        Array containing as rows the CDFs of the state transitions.

    init_states : array_like(int, ndim=1)
        Array containing the initial states. Its length must be equal to
        num_reps.

    random_values : ndarray(float, ndim=2)
        Array containing random values from [0, 1). Its shape must be
        equal to (num_reps, ts_length-1).

    out : ndarray(int, ndim=2)
        Array to store the sample paths.

    Notes
    -----
    This routine is jit-compiled by Numba.

    """
    num_reps, ts_length = out.shape

    for i in range(num_reps):
        out[i, 0] = init_states[i]
        for t in range(ts_length-1):
            out[i, t+1] = searchsorted(P_cdfs[out[i, t]], random_values[i, t])


@jit(nopython=True)
def _generate_sample_paths_sparse(P_cdfs1d, indices, indptr, init_states,
                                  random_values, out):
    """
    Sparse-matrix (CSR) version of `_generate_sample_paths`.

    Generate num_reps sample paths of length ts_length, where num_reps =
    out.shape[0] and ts_length = out.shape[1].

    Parameters
    ----------
    P_cdfs1d : ndarray(float, ndim=1)
        1D array containing the CDFs of the state transitions, computed
        row by row over the nonzero entries in CSR order.

    indices : ndarray(int, ndim=1)
        CSR format index array.

    indptr : ndarray(int, ndim=1)
        CSR format index pointer array.

    init_states : array_like(int, ndim=1)
        Array containing the initial states. Its length must be equal to
        num_reps.

    random_values : ndarray(float, ndim=2)
        Array containing random values from [0, 1). Its shape must be
        equal to (num_reps, ts_length-1).

    out : ndarray(int, ndim=2)
        Array to store the sample paths.

    Notes
    -----
    This routine is jit-compiled by Numba.

    """
    num_reps, ts_length = out.shape

    for i in range(num_reps):
        out[i, 0] = init_states[i]
        for t in range(ts_length-1):
            k = searchsorted(P_cdfs1d[indptr[out[i, t]]:indptr[out[i, t]+1]],
                             random_values[i, t])
            out[i, t+1] = indices[indptr[out[i, t]]+k]


def mc_compute_stationary(P):
    """
    Compute the stationary distributions of P, one for each recurrent
    class. Any stationary distribution is a convex combination of these
    distributions.

    Parameters
    ----------
    P : array_like or scipy sparse matrix (float, ndim=2)
        A Markov transition matrix.

    Returns
    -------
    stationary_dists : ndarray(float, ndim=2)
        Array containing the stationary distributions as its rows.

    """
    return MarkovChain(P).stationary_distributions


def mc_sample_path(P, init=0, sample_size=1000, random_state=None):
    """
    Generates one sample path from the Markov chain represented by
    (n x n) transition matrix P on state space S = {0, ..., n-1}.

    Parameters
    ----------
    P : array_like(float, ndim=2)
        A Markov transition matrix.

    init : array_like(float, ndim=1) or scalar(int), optional(default=0)
        If init is an array_like, then it is treated as the initial
        distribution across states. If init is a scalar, then it is
        treated as the deterministic initial state.

    sample_size : scalar(int), optional(default=1000)
        The length of the sample path.

    random_state : int or np.random.RandomState, optional
        Random seed (integer) or np.random.RandomState instance to set
        the initial state of the random number generator for
        reproducibility. If None, a randomly initialized RandomState is
        used.

    Returns
    -------
    X : ndarray(int, ndim=1)
        The simulated sequence of states.

    """
    random_state = check_random_state(random_state)

    if isinstance(init, numbers.Integral):
        X_0 = init
    else:
        cdf0 = np.cumsum(init)
        u_0 = random_state.random_sample()
        X_0 = searchsorted(cdf0, u_0)

    mc = MarkovChain(P)
    return mc.simulate(ts_length=sample_size, init=X_0,
                       random_state=random_state)
