1from __future__ import print_function, division
2import random
3
4import itertools
5from typing import Sequence as tSequence, Union as tUnion, List as tList, Tuple as tTuple
6
7from sympy import (Matrix, MatrixSymbol, S, Indexed, Basic, Tuple, Range,
8                   Set, And, Eq, FiniteSet, ImmutableMatrix, Integer, igcd,
9                   Lambda, Mul, Dummy, IndexedBase, Add, Interval, oo,
10                   linsolve, eye, Or, Not, Intersection, factorial, Contains,
11                   Union, Expr, Function, exp, cacheit, sqrt, pi, gamma,
12                   Ge, Piecewise, Symbol, NonSquareMatrixError, EmptySet,
13                   ceiling, MatrixBase, ConditionSet, ones, zeros, Identity,
14                   Rational, Lt, Gt, Le, Ne, BlockMatrix, Sum)
15from sympy.core.relational import Relational
16from sympy.logic.boolalg import Boolean
17from sympy.utilities.exceptions import SymPyDeprecationWarning
18from sympy.utilities.iterables import strongly_connected_components
19from sympy.stats.joint_rv import JointDistribution
20from sympy.stats.joint_rv_types import JointDistributionHandmade
21from sympy.stats.rv import (RandomIndexedSymbol, random_symbols, RandomSymbol,
22                            _symbol_converter, _value_check, pspace, given,
23                           dependent, is_random, sample_iter, Distribution,
24                           Density)
25from sympy.stats.stochastic_process import StochasticPSpace
26from sympy.stats.symbolic_probability import Probability, Expectation
27from sympy.stats.frv_types import Bernoulli, BernoulliDistribution, FiniteRV
28from sympy.stats.drv_types import Poisson, PoissonDistribution
29from sympy.stats.crv_types import Normal, NormalDistribution, Gamma, GammaDistribution
30from sympy.core.sympify import _sympify, sympify
31
32__all__ = [
33    'StochasticProcess',
34    'DiscreteTimeStochasticProcess',
35    'DiscreteMarkovChain',
36    'TransitionMatrixOf',
37    'StochasticStateSpaceOf',
38    'GeneratorMatrixOf',
39    'ContinuousMarkovChain',
40    'BernoulliProcess',
41    'PoissonProcess',
42    'WienerProcess',
43    'GammaProcess'
44]
45
46
@is_random.register(Indexed)
def _(x):
    # An Indexed expression is reported as random exactly when its base is:
    # the question is delegated to ``is_random`` on ``x.base``.
    return is_random(x.base)
50
@is_random.register(RandomIndexedSymbol)  # type: ignore
def _(x):
    # Every RandomIndexedSymbol is reported as random, unconditionally.
    return True
54
def _set_converter(itr):
    """
    Helper function for converting list/tuple/set to Set.

    A Python list, tuple or set is wrapped into a ``FiniteSet``.
    An argument that already is a ``Set`` is returned unchanged.

    Returns
    =======

    Set
        The argument converted to Set.


    Raises
    ======

    TypeError
        If the argument is neither a list/tuple/set nor a ``Set``.
    """
    converted = FiniteSet(*itr) if isinstance(itr, (list, tuple, set)) else itr
    if isinstance(converted, Set):
        return converted
    raise TypeError("%s is not an instance of list/tuple/set."%(converted))
79
def _state_converter(itr: tSequence) -> tUnion[Tuple, Range]:
    """
    Helper function for converting list/tuple/set/Range/Tuple/FiniteSet
    to tuple/Range.

    String elements are passed through ``sympify``; all other elements
    are used as they are.

    Raises
    ======

    ValueError
        If a list/tuple argument contains repeated elements.
    TypeError
        If the argument is of none of the supported types.
    """
    def _as_tuple(states):
        # Build the SymPy Tuple, sympifying only the string entries.
        return Tuple(*(sympify(s) if isinstance(s, str) else s for s in states))

    if isinstance(itr, (Tuple, set, FiniteSet)):
        return _as_tuple(itr)

    if isinstance(itr, (list, tuple)):
        # check if states are unique
        if len(set(itr)) != len(itr):
            raise ValueError('The state space must have unique elements.')
        return _as_tuple(itr)

    if isinstance(itr, Range):
        # Best effort: a Range that cannot be expanded into a Tuple
        # is handed back unchanged.
        try:
            return _as_tuple(itr)
        except (TypeError, ValueError):
            return itr

    raise TypeError("%s is not an instance of list/tuple/set/Range/Tuple/FiniteSet." % (itr))
105
def _sym_sympify(arg):
    """
    Converts an arbitrary expression to a type that can be used inside SymPy.
    As generally strings are unwise to use in the expressions,
    a string argument is turned into the ``Symbol`` of that name.

    Parameters
    ==========

    arg: The parameter to be converted to be used in Sympy.

    Returns
    =======

    The converted parameter: ``Symbol(arg)`` for a string,
    ``_sympify(arg)`` for anything else.

    """
    return Symbol(arg) if isinstance(arg, str) else _sympify(arg)
127
def _matrix_checks(matrix):
    """
    Validates a matrix argument and returns it in the form used internally.

    Returns
    =======

    The argument itself, except that an instance of ``Matrix`` is rebuilt
    as an ``ImmutableMatrix`` from its list of rows.

    Raises
    ======

    TypeError
        If the argument is not a Matrix, MatrixSymbol or ImmutableMatrix.
    NonSquareMatrixError
        If the two entries of ``matrix.shape`` differ.
    """
    if not isinstance(matrix, (Matrix, MatrixSymbol, ImmutableMatrix)):
        raise TypeError("Transition probabilities either should "
                            "be a Matrix or a MatrixSymbol.")
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    if n_rows != n_cols:
        raise NonSquareMatrixError("%s is not a square matrix"%(matrix))
    return ImmutableMatrix(matrix.tolist()) if isinstance(matrix, Matrix) else matrix
137
class StochasticProcess(Basic):
    """
    Base class for all the stochastic processes whether
    discrete or continuous.

    Parameters
    ==========

    sym: Symbol or str
    state_space: Set
        The state space of the stochastic process, by default S.Reals.
        For discrete sets it is zero indexed.

    See Also
    ========

    DiscreteTimeStochasticProcess
    """

    index_set = S.Reals

    def __new__(cls, sym, state_space=S.Reals, **kwargs):
        # Both arguments are normalised before being stored as ``args``.
        return Basic.__new__(cls, _symbol_converter(sym), _set_converter(state_space))

    @property
    def symbol(self):
        """The first stored argument, i.e. the converted ``sym``."""
        return self.args[0]

    @property
    def state_space(self) -> tUnion[FiniteSet, Range]:
        """
        The second stored argument, rebuilt as a ``FiniteSet`` when it is
        neither a ``FiniteSet`` nor a ``Range``.
        """
        space = self.args[1]
        if isinstance(space, (FiniteSet, Range)):
            return space
        return FiniteSet(*space)

    def _deprecation_warn_distribution(self):
        """Emit the deprecation warning used by ``distribution``."""
        warning = SymPyDeprecationWarning(
            feature="Calling distribution with RandomIndexedSymbol",
            useinstead="distribution with just timestamp as argument",
            issue=20078,
            deprecated_since_version="1.7.1"
        )
        warning.warn()

    def distribution(self, key=None):
        """
        Return a generic ``Distribution()``.

        A deprecation warning is emitted when ``key`` is None.
        """
        if key is None:
            self._deprecation_warn_distribution()
        return Distribution()

    def density(self, x):
        """Return a generic ``Density()``; ``x`` is not used here."""
        return Density()

    def __call__(self, time):
        """
        Overridden in ContinuousTimeStochasticProcess.
        """
        raise NotImplementedError("Use [] for indexing discrete time stochastic process.")

    def __getitem__(self, time):
        """
        Overridden in DiscreteTimeStochasticProcess.
        """
        raise NotImplementedError("Use () for indexing continuous time stochastic process.")

    def probability(self, condition):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def joint_distribution(self, *args):
        """
        Computes the joint distribution of the random indexed variables.

        Parameters
        ==========

        args: iterable
            The finite list of random indexed variables/the key of a stochastic
            process whose joint distribution has to be computed.

        Returns
        =======

        JointDistribution
            The joint distribution of the list of random indexed variables.
            An unevaluated object is returned if it is not possible to
            compute the joint distribution.

        Raises
        ======

        ValueError: When the arguments passed are not of type
        RandomIndexedSymbol or Number.
        """
        resolved = []
        for arg in args:
            if S(arg).is_Number:
                # A numeric key is turned into the random variable at that
                # time, using [] or () depending on the index set.
                if self.index_set.is_subset(S.Integers):
                    arg = self.__getitem__(arg)
                else:
                    arg = self.__call__(arg)
            elif not isinstance(arg, RandomIndexedSymbol):
                raise ValueError("Expected a RandomIndexedSymbol or "
                                "key not  %s"%(type(arg)))
            resolved.append(arg)

        # With only the generic distribution available the result is
        # left as a JointDistribution of the variables.
        if resolved[0].pspace.distribution == Distribution():
            return JointDistribution(*resolved)

        product = Mul.fromiter(rv.pspace.process.density(rv) for rv in resolved)
        return JointDistributionHandmade(Lambda(tuple(resolved), expr=product))

    def expectation(self, condition, given_condition):
        """Abstract; raises NotImplementedError."""
        raise NotImplementedError("Abstract method for expectation queries.")

    def sample(self):
        """Abstract; raises NotImplementedError."""
        raise NotImplementedError("Abstract method for sampling queries.")
252
class DiscreteTimeStochasticProcess(StochasticProcess):
    """
    Base class for all discrete stochastic processes.
    """
    def __getitem__(self, time):
        """
        For indexing discrete time stochastic processes.

        Returns
        =======

        RandomIndexedSymbol

        Raises
        ======

        IndexError
            If ``time`` is not a symbol and is not in the index set.
        """
        time = sympify(time)
        # Symbols are accepted as they are; anything else must lie in
        # the index set.
        if not (time.is_symbol or time in self.index_set):
            raise IndexError("%s is not in the index set of %s"%(time, self.symbol))
        return RandomIndexedSymbol(
            Indexed(self.symbol, time),
            StochasticPSpace(self.symbol, self, self.distribution(time)))
272
class ContinuousTimeStochasticProcess(StochasticProcess):
    """
    Base class for all continuous time stochastic process.
    """
    def __call__(self, time):
        """
        For indexing continuous time stochastic processes.

        Returns
        =======

        RandomIndexedSymbol

        Raises
        ======

        IndexError
            If ``time`` is not a symbol and is not in the index set.
        """
        time = sympify(time)
        # Symbols are accepted as they are; anything else must lie in
        # the index set.
        if not (time.is_symbol or time in self.index_set):
            raise IndexError("%s is not in the index set of %s"%(time, self.symbol))
        return RandomIndexedSymbol(
            Function(self.symbol)(time),
            StochasticPSpace(self.symbol, self, self.distribution(time)))
292
class TransitionMatrixOf(Boolean):
    """
    Assumes that the matrix is the transition matrix
    of the process.
    """

    def __new__(cls, process, matrix):
        if isinstance(process, DiscreteMarkovChain):
            return Basic.__new__(cls, process, _matrix_checks(matrix))
        raise ValueError("Currently only DiscreteMarkovChain "
                            "support TransitionMatrixOf.")

    @property
    def process(self):
        # first stored argument
        return self.args[0]

    @property
    def matrix(self):
        # second stored argument (already passed through _matrix_checks)
        return self.args[1]
308
class GeneratorMatrixOf(TransitionMatrixOf):
    """
    Assumes that the matrix is the generator matrix
    of the process.
    """

    def __new__(cls, process, matrix):
        if isinstance(process, ContinuousMarkovChain):
            return Basic.__new__(cls, process, _matrix_checks(matrix))
        raise ValueError("Currently only ContinuousMarkovChain "
                            "support GeneratorMatrixOf.")
321
class StochasticStateSpaceOf(Boolean):
    """
    Attaches a state space to a process.

    Only the size of the given state space is kept: the stored value is
    ``Range(size)``.
    """

    def __new__(cls, process, state_space):
        if not isinstance(process, (DiscreteMarkovChain, ContinuousMarkovChain)):
            raise ValueError("Currently only DiscreteMarkovChain and ContinuousMarkovChain "
                                "support StochasticStateSpaceOf.")
        states = _state_converter(state_space)
        if isinstance(states, Range):
            # The size of a Range is computed from its bounds and step.
            span = states.stop - states.start
            n_states = ceiling(span / states.step)
        else:
            n_states = len(states)
        return Basic.__new__(cls, process, Range(n_states))

    @property
    def process(self):
        # first stored argument
        return self.args[0]

    @property
    def state_index(self):
        # second stored argument: Range over the number of states
        return self.args[1]
338
339class MarkovProcess(StochasticProcess):
340    """
341    Contains methods that handle queries
342    common to Markov processes.
343    """
344
345    @property
346    def number_of_states(self) -> tUnion[Integer, Symbol]:
347        """
348        The number of states in the Markov Chain.
349        """
350        return _sympify(self.args[2].shape[0])
351
    @property
    def _state_index(self) -> Range:
        """
        Returns state index as Range.

        This is the second stored argument of the process.
        """
        return self.args[1]
358
359    @classmethod
360    def _sanity_checks(cls, state_space, trans_probs):
361        # Try to never have None as state_space or trans_probs.
362        # This helps a lot if we get it done at the start.
363        if (state_space is None) and (trans_probs is None):
364            _n = Dummy('n', integer=True, nonnegative=True)
365            state_space = _state_converter(Range(_n))
366            trans_probs = _matrix_checks(MatrixSymbol('_T', _n, _n))
367
368        elif state_space is None:
369            trans_probs = _matrix_checks(trans_probs)
370            state_space = _state_converter(Range(trans_probs.shape[0]))
371
372        elif trans_probs is None:
373            state_space = _state_converter(state_space)
374            if isinstance(state_space, Range):
375                _n = ceiling((state_space.stop - state_space.start) / state_space.step)
376            else:
377                _n = len(state_space)
378            trans_probs = MatrixSymbol('_T', _n, _n)
379
380        else:
381            state_space = _state_converter(state_space)
382            trans_probs = _matrix_checks(trans_probs)
383            # Range object doesn't want to give a symbolic size
384            # so we do it ourselves.
385            if isinstance(state_space, Range):
386                ss_size = ceiling((state_space.stop - state_space.start) / state_space.step)
387            else:
388                ss_size = len(state_space)
389            if ss_size != trans_probs.shape[0]:
390                raise ValueError('The size of the state space and the number of '
391                                 'rows of the transition matrix must be the same.')
392
393        return state_space, trans_probs
394
395    def _extract_information(self, given_condition):
396        """
397        Helper function to extract information, like,
398        transition matrix/generator matrix, state space, etc.
399        """
400        if isinstance(self, DiscreteMarkovChain):
401            trans_probs = self.transition_probabilities
402            state_index = self._state_index
403        elif isinstance(self, ContinuousMarkovChain):
404            trans_probs = self.generator_matrix
405            state_index = self._state_index
406        if isinstance(given_condition, And):
407            gcs = given_condition.args
408            given_condition = S.true
409            for gc in gcs:
410                if isinstance(gc, TransitionMatrixOf):
411                    trans_probs = gc.matrix
412                if isinstance(gc, StochasticStateSpaceOf):
413                    state_index = gc.state_index
414                if isinstance(gc, Relational):
415                    given_condition = given_condition & gc
416        if isinstance(given_condition, TransitionMatrixOf):
417            trans_probs = given_condition.matrix
418            given_condition = S.true
419        if isinstance(given_condition, StochasticStateSpaceOf):
420            state_index = given_condition.state_index
421            given_condition = S.true
422        return trans_probs, state_index, given_condition
423
424    def _check_trans_probs(self, trans_probs, row_sum=1):
425        """
426        Helper function for checking the validity of transition
427        probabilities.
428        """
429        if not isinstance(trans_probs, MatrixSymbol):
430            rows = trans_probs.tolist()
431            for row in rows:
432                if (sum(row) - row_sum) != 0:
433                    raise ValueError("Values in a row must sum to %s. "
434                    "If you are using Float or floats then please use Rational."%(row_sum))
435
436    def _work_out_state_index(self, state_index, given_condition, trans_probs):
437        """
438        Helper function to extract state space if there
439        is a random symbol in the given condition.
440        """
441        # if given condition is None, then there is no need to work out
442        # state_space from random variables
443        if given_condition != None:
444            rand_var = list(given_condition.atoms(RandomSymbol) -
445                        given_condition.atoms(RandomIndexedSymbol))
446            if len(rand_var) == 1:
447                state_index = rand_var[0].pspace.set
448
449        # `not None` is `True`. So the old test fails for symbolic sizes.
450        # Need to build the statement differently.
451        sym_cond = not isinstance(self.number_of_states, (int, Integer))
452        cond1 = not sym_cond and len(state_index) != trans_probs.shape[0]
453        if cond1:
454            raise ValueError("state space is not compatible with the transition probabilities.")
455        if not isinstance(trans_probs.shape[0], Symbol):
456            state_index = FiniteSet(*[i for i in range(trans_probs.shape[0])])
457        return state_index
458
459    @cacheit
460    def _preprocess(self, given_condition, evaluate):
461        """
462        Helper function for pre-processing the information.
463        """
464        is_insufficient = False
465
466        if not evaluate: # avoid pre-processing if the result is not to be evaluated
467            return (True, None, None, None)
468
469        # extracting transition matrix and state space
470        trans_probs, state_index, given_condition = self._extract_information(given_condition)
471
472        # given_condition does not have sufficient information
473        # for computations
474        if trans_probs is None or \
475            given_condition is None:
476            is_insufficient = True
477        else:
478            # checking transition probabilities
479            if isinstance(self, DiscreteMarkovChain):
480                self._check_trans_probs(trans_probs, row_sum=1)
481            elif isinstance(self, ContinuousMarkovChain):
482                self._check_trans_probs(trans_probs, row_sum=0)
483
484            # working out state space
485            state_index = self._work_out_state_index(state_index, given_condition, trans_probs)
486
487        return is_insufficient, trans_probs, state_index, given_condition
488
489    def replace_with_index(self, condition):
490        if isinstance(condition, Relational):
491            lhs, rhs = condition.lhs, condition.rhs
492            if not isinstance(lhs, RandomIndexedSymbol):
493                lhs, rhs = rhs, lhs
494            condition = type(condition)(self.index_of.get(lhs, lhs),
495                                        self.index_of.get(rhs, rhs))
496        return condition
497
498    def probability(self, condition, given_condition=None, evaluate=True, **kwargs):
499        """
500        Handles probability queries for Markov process.
501
502        Parameters
503        ==========
504
505        condition: Relational
506        given_condition: Relational/And
507
508        Returns
509        =======
510        Probability
511            If the information is not sufficient.
512        Expr
513            In all other cases.
514
515        Note
516        ====
517        Any information passed at the time of query overrides
518        any information passed at the time of object creation like
519        transition probabilities, state space.
520        Pass the transition matrix using TransitionMatrixOf,
521        generator matrix using GeneratorMatrixOf and state space
522        using StochasticStateSpaceOf in given_condition using & or And.
523        """
524        check, mat, state_index, new_given_condition = \
525            self._preprocess(given_condition, evaluate)
526
527        rv = list(condition.atoms(RandomIndexedSymbol))
528        symbolic = False
529        for sym in rv:
530            if sym.key.is_symbol:
531                symbolic = True
532                break
533
534        if check:
535            return Probability(condition, new_given_condition)
536
537        if isinstance(self, ContinuousMarkovChain):
538            trans_probs = self.transition_probabilities(mat)
539        elif isinstance(self, DiscreteMarkovChain):
540            trans_probs = mat
541        condition = self.replace_with_index(condition)
542        given_condition = self.replace_with_index(given_condition)
543        new_given_condition = self.replace_with_index(new_given_condition)
544
545        if isinstance(condition, Relational):
546            if isinstance(new_given_condition, And):
547                gcs = new_given_condition.args
548            else:
549                gcs = (new_given_condition, )
550            min_key_rv = list(new_given_condition.atoms(RandomIndexedSymbol))
551
552            if len(min_key_rv):
553                min_key_rv = min_key_rv[0]
554                for r in rv:
555                    if min_key_rv.key.is_symbol or r.key.is_symbol:
556                        continue
557                    if min_key_rv.key > r.key:
558                        return Probability(condition)
559            else:
560                min_key_rv = None
561                return Probability(condition)
562
563            if symbolic:
564                return self._symbolic_probability(condition, new_given_condition, rv, min_key_rv)
565
566            if len(rv) > 1:
567                rv[0] = condition.lhs
568                rv[1] = condition.rhs
569                if rv[0].key < rv[1].key:
570                        rv[0], rv[1] = rv[1], rv[0]
571                        if isinstance(condition, Gt):
572                            condition = Lt(condition.lhs, condition.rhs)
573                        elif isinstance(condition, Lt):
574                            condition = Gt(condition.lhs, condition.rhs)
575                        elif isinstance(condition, Ge):
576                            condition = Le(condition.lhs, condition.rhs)
577                        elif isinstance(condition, Le):
578                            condition = Ge(condition.lhs, condition.rhs)
579                s = Rational(0, 1)
580                n = len(self.state_space)
581
582                if isinstance(condition, Eq) or isinstance(condition, Ne):
583                    for i in range(0, n):
584                        s += self.probability(Eq(rv[0], i), Eq(rv[1], i)) * self.probability(Eq(rv[1], i), new_given_condition)
585                    return s if isinstance(condition, Eq) else 1 - s
586                else:
587                    upper = 0
588                    greater = False
589                    if isinstance(condition, Ge) or isinstance(condition, Lt):
590                        upper = 1
591                    if isinstance(condition, Gt) or isinstance(condition, Ge):
592                        greater = True
593
594                    for i in range(0, n):
595                        if i <= n//2:
596                            for j in range(0, i + upper):
597                                s += self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition)
598                        else:
599                            s += self.probability(Eq(rv[0], i), new_given_condition)
600                            for j in range(i + upper, n):
601                                s -= self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition)
602                    return s if greater else 1 - s
603
604            rv = rv[0]
605            states = condition.as_set()
606            prob, gstate = dict(), None
607            for gc in gcs:
608                if gc.has(min_key_rv):
609                    if gc.has(Probability):
610                        p, gp = (gc.rhs, gc.lhs) if isinstance(gc.lhs, Probability) \
611                                    else (gc.lhs, gc.rhs)
612                        gr = gp.args[0]
613                        gset = Intersection(gr.as_set(), state_index)
614                        gstate = list(gset)[0]
615                        prob[gset] = p
616                    else:
617                        _, gstate = (gc.lhs.key, gc.rhs) if isinstance(gc.lhs, RandomIndexedSymbol) \
618                                    else (gc.rhs.key, gc.lhs)
619
620            if any((k not in self.index_set) for k in (rv.key, min_key_rv.key)):
621                raise IndexError("The timestamps of the process are not in it's index set.")
622            states = Intersection(states, state_index) if not isinstance(self.number_of_states, Symbol) else states
623            for state in Union(states, FiniteSet(gstate)):
624                if not isinstance(state, (int, Integer)) or Ge(state, mat.shape[0]) is True:
625                    raise IndexError("No information is available for (%s, %s) in "
626                        "transition probabilities of shape, (%s, %s). "
627                        "State space is zero indexed."
628                        %(gstate, state, mat.shape[0], mat.shape[1]))
629            if prob:
630                gstates = Union(*prob.keys())
631                if len(gstates) == 1:
632                    gstate = list(gstates)[0]
633                    gprob = list(prob.values())[0]
634                    prob[gstates] = gprob
635                elif len(gstates) == len(state_index) - 1:
636                    gstate = list(state_index - gstates)[0]
637                    gprob = S.One - sum(prob.values())
638                    prob[state_index - gstates] = gprob
639                else:
640                    raise ValueError("Conflicting information.")
641            else:
642                gprob = S.One
643
644            if min_key_rv == rv:
645                return sum([prob[FiniteSet(state)] for state in states])
646            if isinstance(self, ContinuousMarkovChain):
647                return gprob * sum([trans_probs(rv.key - min_key_rv.key).__getitem__((gstate, state))
648                                    for state in states])
649            if isinstance(self, DiscreteMarkovChain):
650                return gprob * sum([(trans_probs**(rv.key - min_key_rv.key)).__getitem__((gstate, state))
651                                    for state in states])
652
653        if isinstance(condition, Not):
654            expr = condition.args[0]
655            return S.One - self.probability(expr, given_condition, evaluate, **kwargs)
656
657        if isinstance(condition, And):
658            compute_later, state2cond, conds = [], dict(), condition.args
659            for expr in conds:
660                if isinstance(expr, Relational):
661                    ris = list(expr.atoms(RandomIndexedSymbol))[0]
662                    if state2cond.get(ris, None) is None:
663                        state2cond[ris] = S.true
664                    state2cond[ris] &= expr
665                else:
666                    compute_later.append(expr)
667            ris = []
668            for ri in state2cond:
669                ris.append(ri)
670                cset = Intersection(state2cond[ri].as_set(), state_index)
671                if len(cset) == 0:
672                    return S.Zero
673                state2cond[ri] = cset.as_relational(ri)
674            sorted_ris = sorted(ris, key=lambda ri: ri.key)
675            prod = self.probability(state2cond[sorted_ris[0]], given_condition, evaluate, **kwargs)
676            for i in range(1, len(sorted_ris)):
677                ri, prev_ri = sorted_ris[i], sorted_ris[i-1]
678                if not isinstance(state2cond[ri], Eq):
679                    raise ValueError("The process is in multiple states at %s, unable to determine the probability."%(ri))
680                mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat)
681                prod *= self.probability(state2cond[ri], state2cond[prev_ri]
682                                 & mat_of
683                                 & StochasticStateSpaceOf(self, state_index),
684                                 evaluate, **kwargs)
685            for expr in compute_later:
686                prod *= self.probability(expr, given_condition, evaluate, **kwargs)
687            return prod
688
689        if isinstance(condition, Or):
690            return sum([self.probability(expr, given_condition, evaluate, **kwargs)
691                        for expr in condition.args])
692
693        raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been "
694                                "implemented yet."%(condition, given_condition))
695
696    def _symbolic_probability(self, condition, new_given_condition, rv, min_key_rv):
697        #Function to calculate probability for queries with symbols
698        if isinstance(condition, Relational):
699            curr_state = new_given_condition.rhs if isinstance(new_given_condition.lhs, RandomIndexedSymbol) \
700                    else new_given_condition.lhs
701            next_state = condition.rhs if isinstance(condition.lhs, RandomIndexedSymbol) \
702                else condition.lhs
703
704            if isinstance(condition, Eq) or isinstance(condition, Ne):
705                if isinstance(self, DiscreteMarkovChain):
706                    P = self.transition_probabilities**(rv[0].key - min_key_rv.key)
707                else:
708                    P = exp(self.generator_matrix*(rv[0].key - min_key_rv.key))
709                prob = P[curr_state, next_state] if isinstance(condition, Eq) else 1 - P[curr_state, next_state]
710                return Piecewise((prob, rv[0].key > min_key_rv.key), (Probability(condition), True))
711            else:
712                upper = 1
713                greater = False
714                if isinstance(condition, Ge) or isinstance(condition, Lt):
715                    upper = 0
716                if isinstance(condition, Gt) or isinstance(condition, Ge):
717                    greater = True
718                k = Dummy('k')
719                condition = Eq(condition.lhs, k) if isinstance(condition.lhs, RandomIndexedSymbol)\
720                    else Eq(condition.rhs, k)
721                total = Sum(self.probability(condition, new_given_condition), (k, next_state + upper, self.state_space._sup))
722                return Piecewise((total, rv[0].key > min_key_rv.key), (Probability(condition), True)) if greater\
723                    else Piecewise((1 - total, rv[0].key > min_key_rv.key), (Probability(condition), True))
724        else:
725            return Probability(condition, new_given_condition)
726
727    def expectation(self, expr, condition=None, evaluate=True, **kwargs):
728        """
729        Handles expectation queries for markov process.
730
731        Parameters
732        ==========
733
734        expr: RandomIndexedSymbol, Relational, Logic
735            Condition for which expectation has to be computed. Must
736            contain a RandomIndexedSymbol of the process.
737        condition: Relational, Logic
738            The given conditions under which computations should be done.
739
740        Returns
741        =======
742
743        Expectation
744            Unevaluated object if computations cannot be done due to
745            insufficient information.
746        Expr
747            In all other cases when the computations are successful.
748
749        Note
750        ====
751
752        Any information passed at the time of query overrides
753        any information passed at the time of object creation like
754        transition probabilities, state space.
755
756        Pass the transition matrix using TransitionMatrixOf,
757        generator matrix using GeneratorMatrixOf and state space
758        using StochasticStateSpaceOf in given_condition using & or And.
759        """
760
761        check, mat, state_index, condition = \
762            self._preprocess(condition, evaluate)
763
764        if check:
765            return Expectation(expr, condition)
766
767        rvs = random_symbols(expr)
768        if isinstance(expr, Expr) and isinstance(condition, Eq) \
769            and len(rvs) == 1:
770            # handle queries similar to E(f(X[i]), Eq(X[i-m], <some-state>))
771            condition=self.replace_with_index(condition)
772            state_index=self.replace_with_index(state_index)
773            rv = list(rvs)[0]
774            lhsg, rhsg = condition.lhs, condition.rhs
775            if not isinstance(lhsg, RandomIndexedSymbol):
776                lhsg, rhsg = (rhsg, lhsg)
777            if rhsg not in state_index:
778                raise ValueError("%s state is not in the state space."%(rhsg))
779            if rv.key < lhsg.key:
780                raise ValueError("Incorrect given condition is given, expectation "
781                    "time %s < time %s"%(rv.key, rv.key))
782            mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat)
783            cond = condition & mat_of & \
784                    StochasticStateSpaceOf(self, state_index)
785            func = lambda s: self.probability(Eq(rv, s), cond) * expr.subs(rv, self._state_index[s])
786            return sum([func(s) for s in state_index])
787
788        raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been "
789                                "implemented yet."%(expr, condition))
790
791class DiscreteMarkovChain(DiscreteTimeStochasticProcess, MarkovProcess):
792    """
793    Represents a finite discrete time-homogeneous Markov chain.
794
795    This type of Markov Chain can be uniquely characterised by
796    its (ordered) state space and its one-step transition probability
797    matrix.
798
799    Parameters
800    ==========
801
802    sym:
803        The name given to the Markov Chain
804    state_space:
805        Optional, by default, Range(n)
806    trans_probs:
807        Optional, by default, MatrixSymbol('_T', n, n)
808
809    Examples
810    ========
811
812    >>> from sympy.stats import DiscreteMarkovChain, TransitionMatrixOf, P, E
813    >>> from sympy import Matrix, MatrixSymbol, Eq, symbols
814    >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]])
815    >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
816    >>> YS = DiscreteMarkovChain("Y")
817
818    >>> Y.state_space
819    {0, 1, 2}
820    >>> Y.transition_probabilities
821    Matrix([
822    [0.5, 0.2, 0.3],
823    [0.2, 0.5, 0.3],
824    [0.2, 0.3, 0.5]])
825    >>> TS = MatrixSymbol('T', 3, 3)
826    >>> P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TS))
827    T[0, 2]*T[1, 0] + T[1, 1]*T[1, 2] + T[1, 2]*T[2, 2]
828    >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2)
829    0.36
830
831    Probabilities will be calculated based on indexes rather
832    than state names. For example, with the Sunny-Cloudy-Rainy
833    model with string state names:
834
835    >>> from sympy.core.symbol import Str
836    >>> Y = DiscreteMarkovChain("Y", [Str('Sunny'), Str('Cloudy'), Str('Rainy')], T)
837    >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2)
838    0.36
839
840    This gives the same answer as the ``[0, 1, 2]`` state space.
841    Currently, there is no support for state names within probability
842    and expectation statements. Here is a work-around using ``Str``:
843
844    >>> P(Eq(Str('Rainy'), Y[3]), Eq(Y[1], Str('Cloudy'))).round(2)
845    0.36
846
847    Symbol state names can also be used:
848
849    >>> sunny, cloudy, rainy = symbols('Sunny, Cloudy, Rainy')
850    >>> Y = DiscreteMarkovChain("Y", [sunny, cloudy, rainy], T)
851    >>> P(Eq(Y[3], rainy), Eq(Y[1], cloudy)).round(2)
852    0.36
853
854    Expectations will be calculated as follows:
855
856    >>> E(Y[3], Eq(Y[1], cloudy))
857    0.38*Cloudy + 0.36*Rainy + 0.26*Sunny
858
859    Probability of expressions with multiple RandomIndexedSymbols
860    can also be calculated provided there is only 1 RandomIndexedSymbol
861    in the given condition. It is always better to use Rational instead
862    of floating point numbers for the probabilities in the
863    transition matrix to avoid errors.
864
865    >>> from sympy import Gt, Le, Rational
866    >>> T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]])
867    >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
868    >>> P(Eq(Y[3], Y[1]), Eq(Y[0], 0)).round(3)
869    0.409
870    >>> P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2)
871    0.36
872    >>> P(Le(Y[15], Y[10]), Eq(Y[8], 2)).round(7)
873    0.6963328
874
875    Symbolic probability queries are also supported
876
877    >>> from sympy import symbols, Matrix, Rational, Eq, Gt
878    >>> from sympy.stats import P, DiscreteMarkovChain
879    >>> a, b, c, d = symbols('a b c d')
880    >>> T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]])
881    >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
882    >>> query = P(Eq(Y[a], b), Eq(Y[c], d))
883    >>> query.subs({a:10 ,b:2, c:5, d:1}).round(4)
884    0.3096
885    >>> P(Eq(Y[10], 2), Eq(Y[5], 1)).evalf().round(4)
886    0.3096
887    >>> query_gt = P(Gt(Y[a], b), Eq(Y[c], d))
888    >>> query_gt.subs({a:21, b:0, c:5, d:0}).evalf().round(5)
889    0.64705
890    >>> P(Gt(Y[21], 0), Eq(Y[5], 0)).round(5)
891    0.64705
892
893    There is limited support for arbitrarily sized states:
894
895    >>> n = symbols('n', nonnegative=True, integer=True)
896    >>> T = MatrixSymbol('T', n, n)
897    >>> Y = DiscreteMarkovChain("Y", trans_probs=T)
898    >>> Y.state_space
899    Range(0, n, 1)
900    >>> query = P(Eq(Y[a], b), Eq(Y[c], d))
901    >>> query.subs({a:10, b:2, c:5, d:1})
902    (T**5)[1, 2]
903
904    References
905    ==========
906
907    .. [1] https://en.wikipedia.org/wiki/Markov_chain#Discrete-time_Markov_chain
908    .. [2] https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf
909    """
910    index_set = S.Naturals0
911
912    def __new__(cls, sym, state_space=None, trans_probs=None):
913        # type: (Basic, tUnion[str, Symbol], tSequence, tUnion[MatrixBase, MatrixSymbol]) -> DiscreteMarkovChain
914        sym = _symbol_converter(sym)
915
916        state_space, trans_probs = MarkovProcess._sanity_checks(state_space, trans_probs)
917
918        obj = Basic.__new__(cls, sym, state_space, trans_probs)
919        indices = dict()
920        if isinstance(obj.number_of_states, Integer):
921            for index, state in enumerate(obj._state_index):
922                indices[state] = index
923        obj.index_of = indices
924        return obj
925
926    @property
927    def transition_probabilities(self) -> tUnion[MatrixBase, MatrixSymbol]:
928        """
929        Transition probabilities of discrete Markov chain,
930        either an instance of Matrix or MatrixSymbol.
931        """
932        return self.args[2]
933
934    def communication_classes(self) -> tList[tTuple[tList[Basic], Boolean, Integer]]:
935        """
936        Returns the list of communication classes that partition
937        the states of the markov chain.
938
939        A communication class is defined to be a set of states
940        such that every state in that set is reachable from
941        every other state in that set. Due to its properties
942        this forms a class in the mathematical sense.
943        Communication classes are also known as recurrence
944        classes.
945
946        Returns
947        =======
948
949        classes
950            The ``classes`` are a list of tuples. Each
951            tuple represents a single communication class
952            with its properties. The first element in the
953            tuple is the list of states in the class, the
954            second element is whether the class is recurrent
955            and the third element is the period of the
956            communication class.
957
958        Examples
959        ========
960
961        >>> from sympy.stats import DiscreteMarkovChain
962        >>> from sympy import Matrix
963        >>> T = Matrix([[0, 1, 0],
964        ...             [1, 0, 0],
965        ...             [1, 0, 0]])
966        >>> X = DiscreteMarkovChain('X', [1, 2, 3], T)
967        >>> classes = X.communication_classes()
968        >>> for states, is_recurrent, period in classes:
969        ...     states, is_recurrent, period
970        ([1, 2], True, 2)
971        ([3], False, 1)
972
973        From this we can see that states ``1`` and ``2``
974        communicate, are recurrent and have a period
975        of 2. We can also see state ``3`` is transient
976        with a period of 1.
977
978        Notes
979        =====
980
981        The algorithm used is of order ``O(n**2)`` where
982        ``n`` is the number of states in the markov chain.
983        It uses Tarjan's algorithm to find the classes
984        themselves and then it uses a breadth-first search
985        algorithm to find each class's periodicity.
986        Most of the algorithm's components approach ``O(n)``
987        as the matrix becomes more and more sparse.
988
989        References
990        ==========
991
992        .. [1] http://www.columbia.edu/~ww2040/4701Sum07/4701-06-Notes-MCII.pdf
993        .. [2] http://cecas.clemson.edu/~shierd/Shier/markov.pdf
994        .. [3] https://ujcontent.uj.ac.za/vital/access/services/Download/uj:7506/CONTENT1
995        .. [4] https://www.mathworks.com/help/econ/dtmc.classify.html
996        """
997        n = self.number_of_states
998        T = self.transition_probabilities
999
1000        if isinstance(T, MatrixSymbol):
1001            raise NotImplementedError("Cannot perform the operation with a symbolic matrix.")
1002
1003        # begin Tarjan's algorithm
1004        V = Range(n)
1005        # don't use state names. Rather use state
1006        # indexes since we use them for matrix
1007        # indexing here and later onward
1008        E = [(i, j) for i in V for j in V if T[i, j] != 0]
1009        classes = strongly_connected_components((V, E))
1010        # end Tarjan's algorithm
1011
1012        recurrence = []
1013        periods = []
1014        for class_ in classes:
1015            # begin recurrent check (similar to self._check_trans_probs())
1016            submatrix = T[class_, class_]  # get the submatrix with those states
1017            is_recurrent = S.true
1018            rows = submatrix.tolist()
1019            for row in rows:
1020                if (sum(row) - 1) != 0:
1021                    is_recurrent = S.false
1022                    break
1023            recurrence.append(is_recurrent)
1024            # end recurrent check
1025
1026            # begin breadth-first search
1027            non_tree_edge_values = set()
1028            visited = {class_[0]}
1029            newly_visited = {class_[0]}
1030            level = {class_[0]: 0}
1031            current_level = 0
1032            done = False  # imitate a do-while loop
1033            while not done:  # runs at most len(class_) times
1034                done = len(visited) == len(class_)
1035                current_level += 1
1036
1037                # this loop and the while loop above run a combined len(class_) number of times.
1038                # so this triple nested loop runs through each of the n states once.
1039                for i in newly_visited:
1040
1041                    # the loop below runs len(class_) number of times
1042                    # complexity is around about O(n * avg(len(class_)))
1043                    newly_visited = {j for j in class_ if T[i, j] != 0}
1044
1045                    new_tree_edges = newly_visited.difference(visited)
1046                    for j in new_tree_edges:
1047                        level[j] = current_level
1048
1049                    new_non_tree_edges = newly_visited.intersection(visited)
1050                    new_non_tree_edge_values = {level[i]-level[j]+1 for j in new_non_tree_edges}
1051
1052                    non_tree_edge_values = non_tree_edge_values.union(new_non_tree_edge_values)
1053                    visited = visited.union(new_tree_edges)
1054
1055            # igcd needs at least 2 arguments
1056            positive_ntev = {val_e for val_e in non_tree_edge_values if val_e > 0}
1057            if len(positive_ntev) == 0:
1058                periods.append(len(class_))
1059            elif len(positive_ntev) == 1:
1060                periods.append(positive_ntev.pop())
1061            else:
1062                periods.append(igcd(*positive_ntev))
1063            # end breadth-first search
1064
1065        # convert back to the user's state names
1066        classes = [[self._state_index[i] for i in class_] for class_ in classes]
1067
1068        return sympify(list(zip(classes, recurrence, periods)))
1069
1070    def fundamental_matrix(self):
1071        """
1072        Each entry fundamental matrix can be interpreted as
1073        the expected number of times the chains is in state j
1074        if it started in state i.
1075
1076        References
1077        ==========
1078
1079        .. [1] https://lips.cs.princeton.edu/the-fundamental-matrix-of-a-finite-markov-chain/
1080
1081        """
1082        _, _, _, Q = self.decompose()
1083
1084        if Q.shape[0] > 0:  # if non-ergodic
1085            I = eye(Q.shape[0])
1086            if (I - Q).det() == 0:
1087                raise ValueError("The fundamental matrix doesn't exist.")
1088            return (I - Q).inv().as_immutable()
1089        else:  # if ergodic
1090            P = self.transition_probabilities
1091            I = eye(P.shape[0])
1092            w = self.fixed_row_vector()
1093            W = Matrix([list(w) for i in range(0, P.shape[0])])
1094            if (I - P + W).det() == 0:
1095                raise ValueError("The fundamental matrix doesn't exist.")
1096            return (I - P + W).inv().as_immutable()
1097
1098    def absorbing_probabilities(self):
1099        """
1100        Computes the absorbing probabilities, i.e.,
1101        the ij-th entry of the matrix denotes the
1102        probability of Markov chain being absorbed
1103        in state j starting from state i.
1104        """
1105        _, _, R, _ = self.decompose()
1106        N = self.fundamental_matrix()
1107        if R is None or N is None:
1108            return None
1109        return N*R
1110
1111    def absorbing_probabilites(self):
1112        SymPyDeprecationWarning(
1113            feature="absorbing_probabilites",
1114            useinstead="absorbing_probabilities",
1115            issue=20042,
1116            deprecated_since_version="1.7"
1117        ).warn()
1118        return self.absorbing_probabilities()
1119
1120    def is_regular(self):
1121        tuples = self.communication_classes()
1122        if len(tuples) == 0:
1123            return S.false  # not defined for a 0x0 matrix
1124        classes, _, periods = list(zip(*tuples))
1125        return And(len(classes) == 1, periods[0] == 1)
1126
1127    def is_ergodic(self):
1128        tuples = self.communication_classes()
1129        if len(tuples) == 0:
1130            return S.false  # not defined for a 0x0 matrix
1131        classes, _, _ = list(zip(*tuples))
1132        return S(len(classes) == 1)
1133
1134    def is_absorbing_state(self, state):
1135        trans_probs = self.transition_probabilities
1136        if isinstance(trans_probs, ImmutableMatrix) and \
1137            state < trans_probs.shape[0]:
1138            return S(trans_probs[state, state]) is S.One
1139
1140    def is_absorbing_chain(self):
1141        states, A, B, C = self.decompose()
1142        r = A.shape[0]
1143        return And(r > 0, A == Identity(r).as_explicit())
1144
1145    def stationary_distribution(self, condition_set=False) -> tUnion[ImmutableMatrix, ConditionSet, Lambda]:
1146        """
1147        The stationary distribution is any row vector, p, that solves p = pP,
1148        is row stochastic and each element in p must be nonnegative.
1149        That means in matrix form: :math:`(P-I)^T p^T = 0` and
1150        :math:`(1, ..., 1) p = 1`
1151        where ``P`` is the one-step transition matrix.
1152
1153        All time-homogeneous Markov Chains with a finite state space
1154        have at least one stationary distribution. In addition, if
1155        a finite time-homogeneous Markov Chain is irreducible, the
1156        stationary distribution is unique.
1157
1158        Parameters
1159        ==========
1160
1161        condition_set : bool
1162            If the chain has a symbolic size or transition matrix,
1163            it will return a ``Lambda`` if ``False`` and return a
1164            ``ConditionSet`` if ``True``.
1165
1166        Examples
1167        ========
1168
1169        >>> from sympy.stats import DiscreteMarkovChain
1170        >>> from sympy import Matrix, S
1171
1172        An irreducible Markov Chain
1173
1174        >>> T = Matrix([[S(1)/2, S(1)/2, 0],
1175        ...             [S(4)/5, S(1)/5, 0],
1176        ...             [1, 0, 0]])
1177        >>> X = DiscreteMarkovChain('X', trans_probs=T)
1178        >>> X.stationary_distribution()
1179        Matrix([[8/13, 5/13, 0]])
1180
1181        A reducible Markov Chain
1182
1183        >>> T = Matrix([[S(1)/2, S(1)/2, 0],
1184        ...             [S(4)/5, S(1)/5, 0],
1185        ...             [0, 0, 1]])
1186        >>> X = DiscreteMarkovChain('X', trans_probs=T)
1187        >>> X.stationary_distribution()
1188        Matrix([[8/13 - 8*tau0/13, 5/13 - 5*tau0/13, tau0]])
1189
1190        >>> Y = DiscreteMarkovChain('Y')
1191        >>> Y.stationary_distribution()
1192        Lambda((wm, _T), Eq(wm*_T, wm))
1193
1194        >>> Y.stationary_distribution(condition_set=True)
1195        ConditionSet(wm, Eq(wm*_T, wm))
1196
1197        References
1198        ==========
1199
1200        .. [1] https://www.probabilitycourse.com/chapter11/11_2_6_stationary_and_limiting_distributions.php
1201        .. [2] https://galton.uchicago.edu/~yibi/teaching/stat317/2014/Lectures/Lecture4_6up.pdf
1202
1203        See Also
1204        ========
1205
1206        sympy.stats.DiscreteMarkovChain.limiting_distribution
1207        """
1208        trans_probs = self.transition_probabilities
1209        n = self.number_of_states
1210
1211        if n == 0:
1212            return ImmutableMatrix(Matrix([[]]))
1213
1214        # symbolic matrix version
1215        if isinstance(trans_probs, MatrixSymbol):
1216            wm = MatrixSymbol('wm', 1, n)
1217            if condition_set:
1218                return ConditionSet(wm, Eq(wm * trans_probs, wm))
1219            else:
1220                return Lambda((wm, trans_probs), Eq(wm * trans_probs, wm))
1221
1222        # numeric matrix version
1223        a = Matrix(trans_probs - Identity(n)).T
1224        a[0, 0:n] = ones(1, n)
1225        b = zeros(n, 1)
1226        b[0, 0] = 1
1227
1228        soln = list(linsolve((a, b)))[0]
1229        return ImmutableMatrix([[sol for sol in soln]])
1230
1231    def fixed_row_vector(self):
1232        """
1233        A wrapper for ``stationary_distribution()``.
1234        """
1235        return self.stationary_distribution()
1236
1237    @property
1238    def limiting_distribution(self):
1239        """
1240        The fixed row vector is the limiting
1241        distribution of a discrete Markov chain.
1242        """
1243        return self.fixed_row_vector()
1244
1245    def decompose(self) -> tTuple[tList[Basic], ImmutableMatrix, ImmutableMatrix, ImmutableMatrix]:
1246        """
1247        Decomposes the transition matrix into submatrices with
1248        special properties.
1249
1250        The transition matrix can be decomposed into 4 submatrices:
1251        - A - the submatrix from recurrent states to recurrent states.
1252        - B - the submatrix from transient to recurrent states.
1253        - C - the submatrix from transient to transient states.
1254        - O - the submatrix of zeros for recurrent to transient states.
1255
1256        Returns
1257        =======
1258
1259        states, A, B, C
1260            ``states`` - a list of state names with the first being
1261            the recurrent states and the last being
1262            the transient states in the order
1263            of the row names of A and then the row names of C.
1264            ``A`` - the submatrix from recurrent states to recurrent states.
1265            ``B`` - the submatrix from transient to recurrent states.
1266            ``C`` - the submatrix from transient to transient states.
1267
1268        Examples
1269        ========
1270
1271        >>> from sympy.stats import DiscreteMarkovChain
1272        >>> from sympy import Matrix, S
1273
1274        One can decompose this chain for example:
1275
1276        >>> T = Matrix([[S(1)/2, S(1)/2, 0,      0,      0],
1277        ...             [S(2)/5, S(1)/5, S(2)/5, 0,      0],
1278        ...             [0,      0,      1,      0,      0],
1279        ...             [0,      0,      S(1)/2, S(1)/2, 0],
1280        ...             [S(1)/2, 0,      0,      0, S(1)/2]])
1281        >>> X = DiscreteMarkovChain('X', trans_probs=T)
1282        >>> states, A, B, C = X.decompose()
1283        >>> states
1284        [2, 0, 1, 3, 4]
1285
1286        >>> A   # recurrent to recurrent
1287        Matrix([[1]])
1288
1289        >>> B  # transient to recurrent
1290        Matrix([
1291        [  0],
1292        [2/5],
1293        [1/2],
1294        [  0]])
1295
1296        >>> C  # transient to transient
1297        Matrix([
1298        [1/2, 1/2,   0,   0],
1299        [2/5, 1/5,   0,   0],
1300        [  0,   0, 1/2,   0],
1301        [1/2,   0,   0, 1/2]])
1302
1303        This means that state 2 is the only absorbing state
1304        (since A is a 1x1 matrix). B is a 4x1 matrix since
1305        the 4 remaining transient states all merge into reccurent
1306        state 2. And C is the 4x4 matrix that shows how the
1307        transient states 0, 1, 3, 4 all interact.
1308
1309        See Also
1310        ========
1311
1312        sympy.stats.DiscreteMarkovChain.communication_classes
1313        sympy.stats.DiscreteMarkovChain.canonical_form
1314
1315        References
1316        ==========
1317
1318        .. [1] https://en.wikipedia.org/wiki/Absorbing_Markov_chain
1319        .. [2] http://people.brandeis.edu/~igusa/Math56aS08/Math56a_S08_notes015.pdf
1320        """
1321        trans_probs = self.transition_probabilities
1322
1323        classes = self.communication_classes()
1324        r_states = []
1325        t_states = []
1326
1327        for states, recurrent, period in classes:
1328            if recurrent:
1329                r_states += states
1330            else:
1331                t_states += states
1332
1333        states = r_states + t_states
1334        indexes = [self.index_of[state] for state in states]
1335
1336        A = Matrix(len(r_states), len(r_states),
1337                   lambda i, j: trans_probs[indexes[i], indexes[j]])
1338
1339        B = Matrix(len(t_states), len(r_states),
1340                   lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[j]])
1341
1342        C = Matrix(len(t_states), len(t_states),
1343                   lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[len(r_states) + j]])
1344
1345        return states, A.as_immutable(), B.as_immutable(), C.as_immutable()
1346
1347    def canonical_form(self) -> tTuple[tList[Basic], ImmutableMatrix]:
1348        """
1349        Reorders the one-step transition matrix
1350        so that recurrent states appear first and transient
1351        states appear last. Other representations include inserting
1352        transient states first and recurrent states last.
1353
1354        Returns
1355        =======
1356
1357        states, P_new
1358            ``states`` is the list that describes the order of the
1359            new states in the matrix
1360            so that the ith element in ``states`` is the state of the
1361            ith row of A.
1362            ``P_new`` is the new transition matrix in canonical form.
1363
1364        Examples
1365        ========
1366
1367        >>> from sympy.stats import DiscreteMarkovChain
1368        >>> from sympy import Matrix, S
1369
1370        You can convert your chain into canonical form:
1371
1372        >>> T = Matrix([[S(1)/2, S(1)/2, 0,      0,      0],
1373        ...             [S(2)/5, S(1)/5, S(2)/5, 0,      0],
1374        ...             [0,      0,      1,      0,      0],
1375        ...             [0,      0,      S(1)/2, S(1)/2, 0],
1376        ...             [S(1)/2, 0,      0,      0, S(1)/2]])
1377        >>> X = DiscreteMarkovChain('X', list(range(1, 6)), trans_probs=T)
1378        >>> states, new_matrix = X.canonical_form()
1379        >>> states
1380        [3, 1, 2, 4, 5]
1381
1382        >>> new_matrix
1383        Matrix([
1384        [  1,   0,   0,   0,   0],
1385        [  0, 1/2, 1/2,   0,   0],
1386        [2/5, 2/5, 1/5,   0,   0],
1387        [1/2,   0,   0, 1/2,   0],
1388        [  0, 1/2,   0,   0, 1/2]])
1389
1390        The new states are [3, 1, 2, 4, 5] and you can
1391        create a new chain with this and its canonical
1392        form will remain the same (since it is already
1393        in canonical form).
1394
1395        >>> X = DiscreteMarkovChain('X', states, new_matrix)
1396        >>> states, new_matrix = X.canonical_form()
1397        >>> states
1398        [3, 1, 2, 4, 5]
1399
1400        >>> new_matrix
1401        Matrix([
1402        [  1,   0,   0,   0,   0],
1403        [  0, 1/2, 1/2,   0,   0],
1404        [2/5, 2/5, 1/5,   0,   0],
1405        [1/2,   0,   0, 1/2,   0],
1406        [  0, 1/2,   0,   0, 1/2]])
1407
1408        This is not limited to absorbing chains:
1409
1410        >>> T = Matrix([[0, 5,  5, 0,  0],
1411        ...             [0, 0,  0, 10, 0],
1412        ...             [5, 0,  5, 0,  0],
1413        ...             [0, 10, 0, 0,  0],
1414        ...             [0, 3,  0, 3,  4]])/10
1415        >>> X = DiscreteMarkovChain('X', trans_probs=T)
1416        >>> states, new_matrix = X.canonical_form()
1417        >>> states
1418        [1, 3, 0, 2, 4]
1419
1420        >>> new_matrix
1421        Matrix([
1422        [   0,    1,   0,   0,   0],
1423        [   1,    0,   0,   0,   0],
1424        [ 1/2,    0,   0, 1/2,   0],
1425        [   0,    0, 1/2, 1/2,   0],
1426        [3/10, 3/10,   0,   0, 2/5]])
1427
1428        See Also
1429        ========
1430
1431        sympy.stats.DiscreteMarkovChain.communication_classes
1432        sympy.stats.DiscreteMarkovChain.decompose
1433
1434        References
1435        ==========
1436
1437        .. [1] https://onlinelibrary.wiley.com/doi/pdf/10.1002/9780470316887.app1
1438        .. [2] http://www.columbia.edu/~ww2040/6711F12/lect1023big.pdf
1439        """
1440        states, A, B, C = self.decompose()
1441        O = zeros(A.shape[0], C.shape[1])
1442        return states, BlockMatrix([[A, O], [B, C]]).as_explicit()
1443
1444    def sample(self):
1445        """
1446        Returns
1447        =======
1448
1449        sample: iterator object
1450            iterator object containing the sample
1451
1452        """
1453        if not isinstance(self.transition_probabilities, (Matrix, ImmutableMatrix)):
1454            raise ValueError("Transition Matrix must be provided for sampling")
1455        Tlist = self.transition_probabilities.tolist()
1456        samps = [random.choice(list(self.state_space))]
1457        yield samps[0]
1458        time = 1
1459        densities = {}
1460        for state in self.state_space:
1461            states = list(self.state_space)
1462            densities[state] = {states[i]: Tlist[state][i]
1463                        for i in range(len(states))}
1464        while time < S.Infinity:
1465            samps.append((next(sample_iter(FiniteRV("_", densities[samps[time - 1]])))))
1466            yield samps[time]
1467            time += 1
1468
1469class ContinuousMarkovChain(ContinuousTimeStochasticProcess, MarkovProcess):
1470    """
1471    Represents continuous time Markov chain.
1472
1473    Parameters
1474    ==========
1475
1476    sym: Symbol/str
1477    state_space: Set
1478        Optional, by default, S.Reals
1479    gen_mat: Matrix/ImmutableMatrix/MatrixSymbol
1480        Optional, by default, None
1481
1482    Examples
1483    ========
1484
1485    >>> from sympy.stats import ContinuousMarkovChain, P
1486    >>> from sympy import Matrix, S, Eq, Gt
1487    >>> G = Matrix([[-S(1), S(1)], [S(1), -S(1)]])
1488    >>> C = ContinuousMarkovChain('C', state_space=[0, 1], gen_mat=G)
1489    >>> C.limiting_distribution()
1490    Matrix([[1/2, 1/2]])
1491    >>> C.state_space
1492    {0, 1}
1493    >>> C.generator_matrix
1494    Matrix([
1495    [-1,  1],
1496    [ 1, -1]])
1497
1498    Probability queries are supported
1499
1500    >>> P(Eq(C(1.96), 0), Eq(C(0.78), 1)).round(5)
1501    0.45279
1502    >>> P(Gt(C(1.7), 0), Eq(C(0.82), 1)).round(5)
1503    0.58602
1504
1505    Probability of expressions with multiple RandomIndexedSymbols
1506    can also be calculated provided there is only 1 RandomIndexedSymbol
1507    in the given condition. It is always better to use Rational instead
1508    of floating point numbers for the probabilities in the
1509    generator matrix to avoid errors.
1510
1511    >>> from sympy import Gt, Le, Rational
1512    >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]])
1513    >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G)
1514    >>> P(Eq(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5)
1515    0.37933
1516    >>> P(Gt(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5)
1517    0.34211
1518    >>> P(Le(C(1.57), C(3.14)), Eq(C(1.22), 1)).round(4)
1519    0.7143
1520
1521    Symbolic probability queries are also supported
1522
1523    >>> from sympy import S, symbols, Matrix, Rational, Eq, Gt
1524    >>> from sympy.stats import P, ContinuousMarkovChain
1525    >>> a,b,c,d = symbols('a b c d')
1526    >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]])
1527    >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G)
1528    >>> query = P(Eq(C(a), b), Eq(C(c), d))
1529    >>> query.subs({a:3.65 ,b:2, c:1.78, d:1}).evalf().round(10)
1530    0.4002723175
1531    >>> P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10)
1532    0.4002723175
1533    >>> query_gt = P(Gt(C(a), b), Eq(C(c), d))
1534    >>> query_gt.subs({a:43.2 ,b:0, c:3.29, d:2}).evalf().round(10)
1535    0.6832579186
1536    >>> P(Gt(C(43.2), 0), Eq(C(3.29), 2)).round(10)
1537    0.6832579186
1538
1539    References
1540    ==========
1541
1542    .. [1] https://en.wikipedia.org/wiki/Markov_chain#Continuous-time_Markov_chain
1543    .. [2] http://u.math.biu.ac.il/~amirgi/CTMCnotes.pdf
1544    """
1545    index_set = S.Reals
1546
1547    def __new__(cls, sym, state_space=None, gen_mat=None):
1548        sym = _symbol_converter(sym)
1549        state_space, gen_mat = MarkovProcess._sanity_checks(state_space, gen_mat)
1550        obj = Basic.__new__(cls, sym, state_space, gen_mat)
1551        indices = dict()
1552        if isinstance(obj.number_of_states, Integer):
1553            for index, state in enumerate(obj.state_space):
1554                indices[state] = index
1555        obj.index_of = indices
1556        return obj
1557
1558    @property
1559    def generator_matrix(self):
1560        return self.args[2]
1561
1562    @cacheit
1563    def transition_probabilities(self, gen_mat=None):
1564        t = Dummy('t')
1565        if isinstance(gen_mat, (Matrix, ImmutableMatrix)) and \
1566                gen_mat.is_diagonalizable():
1567            # for faster computation use diagonalized generator matrix
1568            Q, D = gen_mat.diagonalize()
1569            return Lambda(t, Q*exp(t*D)*Q.inv())
1570        if gen_mat != None:
1571            return Lambda(t, exp(t*gen_mat))
1572
1573    def limiting_distribution(self):
1574        gen_mat = self.generator_matrix
1575        if gen_mat is None:
1576            return None
1577        if isinstance(gen_mat, MatrixSymbol):
1578            wm = MatrixSymbol('wm', 1, gen_mat.shape[0])
1579            return Lambda((wm, gen_mat), Eq(wm*gen_mat, wm))
1580        w = IndexedBase('w')
1581        wi = [w[i] for i in range(gen_mat.shape[0])]
1582        wm = Matrix([wi])
1583        eqs = (wm*gen_mat).tolist()[0]
1584        eqs.append(sum(wi) - 1)
1585        soln = list(linsolve(eqs, wi))[0]
1586        return ImmutableMatrix([[sol for sol in soln]])
1587
1588
class BernoulliProcess(DiscreteTimeStochasticProcess):
    """
    The Bernoulli process consists of repeated
    independent Bernoulli trials with the same parameter `p`.
    It's assumed that the probability `p` applies to every
    trial and that the outcomes of each trial
    are independent of all the rest. Therefore a Bernoulli Process
    is a Discrete State and Discrete Time Stochastic Process.

    Parameters
    ==========

    sym: Symbol/str
    success: Integer/str
            The event which is considered to be success, by default is 1.
    failure: Integer/str
            The event which is considered to be failure, by default is 0.
    p: Real Number between 0 and 1
            Represents the probability of getting success.

    Examples
    ========

    >>> from sympy.stats import BernoulliProcess, P, E
    >>> from sympy import Eq, Gt
    >>> B = BernoulliProcess("B", p=0.7, success=1, failure=0)
    >>> B.state_space
    {0, 1}
    >>> (B.p).round(2)
    0.70
    >>> B.success
    1
    >>> B.failure
    0
    >>> X = B[1] + B[2] + B[3]
    >>> P(Eq(X, 0)).round(2)
    0.03
    >>> P(Eq(X, 2)).round(2)
    0.44
    >>> P(Eq(X, 4)).round(2)
    0
    >>> P(Gt(X, 1)).round(2)
    0.78
    >>> P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2)
    0.04
    >>> B.joint_distribution(B[1], B[2])
    JointDistributionHandmade(Lambda((B[1], B[2]), Piecewise((0.7, Eq(B[1], 1)),
    (0.3, Eq(B[1], 0)), (0, True))*Piecewise((0.7, Eq(B[2], 1)), (0.3, Eq(B[2], 0)),
    (0, True))))
    >>> E(2*B[1] + B[2]).round(2)
    2.10
    >>> P(B[1] < 1).round(2)
    0.30

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Bernoulli_process
    .. [2] https://mathcs.clarku.edu/~djoyce/ma217/bernoulli.pdf

    """

    index_set = S.Naturals0

    def __new__(cls, sym, p, success=1, failure=0):
        """
        Create the process after validating ``p``.

        Raises
        ======

        ValueError
            If ``p`` is known to lie outside ``[0, 1]``.

        """
        p = _sympify(p)
        # Pass the two bounds separately instead of joining them with the
        # Python ``and`` operator: ``and`` forces a truth value and therefore
        # fails for a symbolic ``p`` whose comparison stays unevaluated.
        _value_check((p >= 0, p <= 1), 'Value of p must be between 0 and 1.')
        sym = _symbol_converter(sym)
        success = _sym_sympify(success)
        failure = _sym_sympify(failure)
        return Basic.__new__(cls, sym, p, success, failure)

    @property
    def symbol(self):
        """Symbol naming the process."""
        return self.args[0]

    @property
    def p(self):
        """Probability of success."""
        return self.args[1]

    @property
    def success(self):
        """Value representing a success."""
        return self.args[2]

    @property
    def failure(self):
        """Value representing a failure."""
        return self.args[3]

    @property
    def state_space(self):
        """Set made of the success and failure values."""
        return _set_converter([self.success, self.failure])

    def distribution(self, key=None):
        """
        Return the Bernoulli distribution of a single trial.

        Calling without ``key`` triggers the deprecation warning and
        returns a distribution built from ``p`` only.
        """
        if key is None:
            self._deprecation_warn_distribution()
            return BernoulliDistribution(self.p)
        return BernoulliDistribution(self.p, self.success, self.failure)

    def simple_rv(self, rv):
        """Return a Bernoulli random variable named like ``rv``."""
        return Bernoulli(rv.name, p=self.p,
                succ=self.success, fail=self.failure)

    def expectation(self, expr, condition=None, evaluate=True, **kwargs):
        """
        Computes expectation.

        Parameters
        ==========

        expr: RandomIndexedSymbol, Relational, Logic
            Expression whose expectation has to be computed. Must
            contain a RandomIndexedSymbol of the process.
        condition: Relational, Logic
            The given conditions under which computations should be done.

        Returns
        =======

        Expectation of the RandomIndexedSymbol.

        """

        return _SubstituteRV._expectation(expr, condition, evaluate, **kwargs)

    def probability(self, condition, given_condition=None, evaluate=True, **kwargs):
        """
        Computes probability.

        Parameters
        ==========

        condition: Relational
                Condition for which probability has to be computed. Must
                contain a RandomIndexedSymbol of the process.
        given_condition: Relational/And
                The given conditions under which computations should be done.

        Returns
        =======

        Probability of the condition.

        """

        return _SubstituteRV._probability(condition, given_condition, evaluate, **kwargs)

    def density(self, x):
        """
        Return the probability mass at ``x``: ``p`` at the success value,
        ``1 - p`` at the failure value and zero elsewhere.
        """
        return Piecewise((self.p, Eq(x, self.success)),
                         (1 - self.p, Eq(x, self.failure)),
                         (S.Zero, True))
1739
class _SubstituteRV:
    """
    Internal class to handle the queries of expectation and probability
    by substitution.
    """

    @staticmethod
    def _rvindexed_subs(expr, condition=None):
        """
        Substitutes the RandomIndexedSymbol with the RandomSymbol with
        same name, distribution and probability as RandomIndexedSymbol.

        Parameters
        ==========

        expr: RandomIndexedSymbol, Relational, Logic
            Expression in which the substitution is made. Must
            contain a RandomIndexedSymbol of the process.
        condition: Relational, Logic
            The given conditions under which computations should be done.

        Returns
        =======

        Tuple ``(expr, condition)`` after substitution.

        """

        def swap_indexed(obj):
            # Objects in which no random symbols are found (this includes
            # the default ``condition=None``) are returned untouched.
            rvs = random_symbols(obj)
            if not rvs:
                return obj
            # substitute each indexed rv with its equivalent simple rv
            swapdict = {rv: rv.pspace.process.simple_rv(rv)
                        for rv in rvs if isinstance(rv, RandomIndexedSymbol)}
            return obj.subs(swapdict)

        return swap_indexed(expr), swap_indexed(condition)

    @classmethod
    def _expectation(cls, expr, condition=None, evaluate=True, **kwargs):
        """
        Internal method for computing expectation of indexed RV.

        Parameters
        ==========

        expr: RandomIndexedSymbol, Relational, Logic
            Expression whose expectation has to be computed. Must
            contain a RandomIndexedSymbol of the process.
        condition: Relational, Logic
            The given conditions under which computations should be done.

        Returns
        =======

        Expectation of the RandomIndexedSymbol.

        """
        new_expr, new_condition = cls._rvindexed_subs(expr, condition)

        if not is_random(new_expr):
            return new_expr
        new_pspace = pspace(new_expr)
        if new_condition is not None:
            new_expr = given(new_expr, new_condition)
        if new_expr.is_Add:  # As E is Linear
            return Add(*[new_pspace.compute_expectation(
                        expr=arg, evaluate=evaluate, **kwargs)
                        for arg in new_expr.args])
        return new_pspace.compute_expectation(
                new_expr, evaluate=evaluate, **kwargs)

    @classmethod
    def _probability(cls, condition, given_condition=None, evaluate=True, **kwargs):
        """
        Internal method for computing probability of indexed RV

        Parameters
        ==========

        condition: Relational
                Condition for which probability has to be computed. Must
                contain a RandomIndexedSymbol of the process.
        given_condition: Relational/And
                The given conditions under which computations should be done.
        evaluate: bool
                When true, ``doit`` is called on the result if available.
                The flag is forwarded to every recursive call.

        Returns
        =======

        Probability of the condition.

        Raises
        ======

        ValueError
            If the condition or the given condition is not a relational
            or a combination of relationals.

        """
        new_condition, new_givencondition = cls._rvindexed_subs(condition, given_condition)

        if isinstance(new_givencondition, RandomSymbol):
            condrv = random_symbols(new_condition)
            if len(condrv) == 1 and condrv[0] == new_givencondition:
                return BernoulliDistribution(
                    cls._probability(new_condition, evaluate=evaluate, **kwargs), 0, 1)

            if any(dependent(rv, new_givencondition) for rv in condrv):
                return Probability(new_condition, new_givencondition)
            else:
                return cls._probability(new_condition, evaluate=evaluate, **kwargs)

        if new_givencondition is not None and \
                not isinstance(new_givencondition, (Relational, Boolean)):
            raise ValueError("%s is not a relational or combination of relationals"
                    % (new_givencondition))
        if new_givencondition == False or new_condition == False:
            return S.Zero
        if new_condition == True:
            return S.One
        if not isinstance(new_condition, (Relational, Boolean)):
            raise ValueError("%s is not a relational or combination of relationals"
                    % (new_condition))

        if new_givencondition is not None:  # If there is a condition
            # Recompute on new conditional expr, keeping the caller's
            # ``evaluate`` choice instead of falling back to the default.
            return cls._probability(
                given(new_condition, new_givencondition, **kwargs),
                evaluate=evaluate, **kwargs)
        result = pspace(new_condition).probability(new_condition, **kwargs)
        if evaluate and hasattr(result, 'doit'):
            return result.doit()
        else:
            return result
1867
def get_timerv_swaps(expr, condition):
    """
    Parses the given condition to find the interval attached to every
    variable time stamp used in expr. Returns those intervals together with
    a dictionary that maps each variable time-stamped Random Indexed Symbol
    to its corresponding Random Indexed variable with fixed time stamp.

    Parameters
    ==========

    expr: Sympy Expression
        Expression containing Random Indexed Symbols with variable time stamps
    condition: Relational/Boolean Expression
        Expression containing time bounds of variable time stamps in expr

    Examples
    ========

    >>> from sympy.stats.stochastic_process_types import get_timerv_swaps, PoissonProcess
    >>> from sympy import symbols, Contains, Interval
    >>> x, t, d = symbols('x t d', positive=True)
    >>> X = PoissonProcess("X", 3)
    >>> get_timerv_swaps(x*X(t), Contains(t, Interval.Lopen(0, 1)))
    ([Interval.Lopen(0, 1)], {X(t): X(1)})
    >>> get_timerv_swaps((X(t)**2 + X(d)**2), Contains(t, Interval.Lopen(0, 1))
    ... & Contains(d, Interval.Ropen(1, 4))) # doctest: +SKIP
    ([Interval.Ropen(1, 4), Interval.Lopen(0, 1)], {X(d): X(3), X(t): X(1)})

    Returns
    =======

    intervals: list
        List of Intervals/FiniteSet on which each time stamp is defined
    rv_swap: dict
        Dictionary mapping variable time Random Indexed Symbol to constant time
        Random Indexed Variable

    Raises
    ======

    ValueError
        If condition is not a relational or combination of relationals, or
        if the set found for a time stamp has infinite width.

    """

    if not isinstance(condition, (Relational, Boolean)):
        raise ValueError("%s is not a relational or combination of relationals"
            % (condition))
    indexed_rvs = list(expr.atoms(RandomIndexedSymbol))
    # a compound condition is split in its parts, a single one is wrapped
    cond_parts = condition.args if isinstance(condition, (And, Or)) else (condition, )
    intervals = []
    rv_swap = {}
    for indexed_rv in indexed_rvs:
        time_key = indexed_rv.key
        if not isinstance(time_key, Symbol):
            # only variable (symbolic) time stamps are swapped
            continue
        for part in cond_parts:
            if not part.has(time_key):
                continue
            time_set = _set_converter(part.args[1])
            width = time_set._sup - time_set._inf
            if width == oo:
                raise ValueError("%s should have finite bounds" % str(indexed_rv.name))
            if width == S.Zero: # has singleton set
                width = time_set._sup
            rv_swap[indexed_rv] = indexed_rv.subs({time_key: width})
            intervals.append(time_set)
    return intervals, rv_swap
1929
1930
class CountingProcess(ContinuousTimeStochasticProcess):
    """
    This class handles the common methods of the Counting Processes
    such as Poisson, Wiener and Gamma Processes
    """
    index_set = _set_converter(Interval(0, oo))

    @property
    def symbol(self):
        """Symbol naming the process."""
        return self.args[0]

    def expectation(self, expr, condition=None, evaluate=True, **kwargs):
        """
        Computes expectation

        Parameters
        ==========

        expr: RandomIndexedSymbol, Relational, Logic
            Expression whose expectation has to be computed. Must
            contain a RandomIndexedSymbol of the process.
        condition: Relational, Boolean
            The given conditions under which computations should be done, i.e,
            the intervals on which each variable time stamp in expr is defined

        Returns
        =======

        Expectation of the given expr. An unevaluated ``Expectation`` is
        returned when the intervals of the time stamps overlap.

        """
        if condition is not None:
            intervals, rv_swap = get_timerv_swaps(expr, condition)
            # they are independent when they have non-overlapping intervals
            if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet
                for intv_comb in itertools.combinations(intervals, 2)):
                if expr.is_Add:
                    return Add.fromiter(self.expectation(arg, condition)
                            for arg in expr.args)
                expr = expr.subs(rv_swap)
            else:
                return Expectation(expr, condition)

        return _SubstituteRV._expectation(expr, evaluate=evaluate, **kwargs)

    def _solve_argwith_tworvs(self, arg):
        """
        Rewrite a relational between two random variables of the process as
        a relational between one increment of the process and zero.

        Parameters
        ==========

        arg: Relational
            Relational whose two sides are both random variables of the
            process with comparable time stamps.

        Returns
        =======

        Relational comparing ``process(|t0 - t1|)`` with ``0``.

        """
        lhs, rhs = arg.args
        if lhs.key >= rhs.key or isinstance(arg, Eq):
            # lhs is the later (or same-time) variable, so the increment
            # lhs - rhs keeps the direction of the relation.
            diff_key = abs(lhs.key - rhs.key)
            return arg.__class__(lhs.pspace.process(diff_key), 0)
        # lhs is the earlier variable: ``lhs R rhs`` is the same statement as
        # ``(rhs - lhs) R' 0`` where R' is R with its sides swapped
        # (``<`` becomes ``>``, ``<=`` becomes ``>=`` and so on).
        diff_key = rhs.key - lhs.key
        return arg.reversed.__class__(rhs.pspace.process(diff_key), 0)

    def _solve_numerical(self, condition, given_condition=None):
        """
        Compute the probability of a query whose time stamps are all
        numeric, as a product of one factor per segment between consecutive
        time stamps.

        Parameters
        ==========

        condition: Relational/And
            Condition for which probability has to be computed.
        given_condition: Relational/And
            The given conditions under which computations should be done.

        Returns
        =======

        Product of the factors collected for every segment.

        """
        if isinstance(condition, And):
            args_list = list(condition.args)
        else:
            args_list = [condition]
        if given_condition is not None:
            if isinstance(given_condition, And):
                args_list.extend(list(given_condition.args))
            else:
                args_list.extend([given_condition])
        # sort the args based on timestamp to get the independent increments in
        # each segment using all the condition args as well as given_condition args
        args_list = sorted(args_list, key=lambda x: x.args[0].key)
        result = []
        cond_args = list(condition.args) if isinstance(condition, And) else [condition]
        # The earliest relational contributes its own probability only when it
        # comes from ``condition`` and does not compare two random variables.
        if args_list[0] in cond_args and not (is_random(args_list[0].args[0])
                        and is_random(args_list[0].args[1])):
            result.append(_SubstituteRV._probability(args_list[0]))

        # A relational between two random variables is first rewritten in
        # terms of a single increment.
        if is_random(args_list[0].args[0]) and is_random(args_list[0].args[1]):
            arg = self._solve_argwith_tworvs(args_list[0])
            result.append(_SubstituteRV._probability(arg))

        for i in range(len(args_list) - 1):
            curr, nex = args_list[i], args_list[i + 1]
            # length of the segment between the two consecutive time stamps
            diff_key = nex.args[0].key - curr.args[0].key
            working_set = curr.args[0].pspace.process.state_space
            if curr.args[1] > nex.args[1]: #impossible condition so return 0
                result.append(0)
                break
            # Narrow the state space with the set implied by each relational.
            if isinstance(curr, Eq):
                working_set = Intersection(working_set, Interval.Lopen(curr.args[1], oo))
            else:
                working_set = Intersection(working_set, curr.as_set())
            if isinstance(nex, Eq):
                working_set = Intersection(working_set, Interval(-oo, nex.args[1]))
            else:
                working_set = Intersection(working_set, nex.as_set())
            if working_set == EmptySet:
                rv = Eq(curr.args[0].pspace.process(diff_key), 0)
                result.append(_SubstituteRV._probability(rv))
            else:
                if working_set.is_finite_set:
                    if isinstance(curr, Eq) and isinstance(nex, Eq):
                        rv = Eq(curr.args[0].pspace.process(diff_key), len(working_set))
                        result.append(_SubstituteRV._probability(rv))
                    elif isinstance(curr, Eq) ^ isinstance(nex, Eq):
                        result.append(Add.fromiter(_SubstituteRV._probability(Eq(
                        curr.args[0].pspace.process(diff_key), x))
                                for x in range(len(working_set))))
                    else:
                        n = len(working_set)
                        result.append(Add.fromiter((n - x)*_SubstituteRV._probability(Eq(
                        curr.args[0].pspace.process(diff_key), x)) for x in range(n)))
                else:
                    result.append(_SubstituteRV._probability(
                    curr.args[0].pspace.process(diff_key) <= working_set._sup - working_set._inf))
        return Mul.fromiter(result)


    def probability(self, condition, given_condition=None, evaluate=True, **kwargs):
        """
        Computes probability.

        Parameters
        ==========

        condition: Relational
            Condition for which probability has to be computed. Must
            contain a RandomIndexedSymbol of the process.
        given_condition: Relational, Boolean
            The given conditions under which computations should be done, i.e,
            the intervals on which each variable time stamp in expr is defined

        Returns
        =======

        Probability of the condition. An unevaluated ``Probability`` is
        returned when the intervals of the time stamps overlap.

        Raises
        ======

        ValueError
            If the query is not fully numeric and some part of the given
            condition does not involve ``Contains``.

        """
        check_numeric = True
        if isinstance(condition, (And, Or)):
            cond_args = condition.args
        else:
            cond_args = (condition, )
        # check that condition args are numeric or not
        if not all(arg.args[0].key.is_number for arg in cond_args):
            check_numeric = False
        if given_condition is not None:
            check_given_numeric = True
            if isinstance(given_condition, (And, Or)):
                given_cond_args = given_condition.args
            else:
                given_cond_args = (given_condition, )
            # check that given condition args are numeric or not
            if given_condition.has(Contains):
                check_given_numeric = False
            # Handle numerical queries
            if check_numeric and check_given_numeric:
                res = []
                # an ``Or`` on either side is handled by summing over its parts
                if isinstance(condition, Or):
                    res.append(Add.fromiter(self._solve_numerical(arg, given_condition)
                            for arg in condition.args))
                if isinstance(given_condition, Or):
                    res.append(Add.fromiter(self._solve_numerical(condition, arg)
                            for arg in given_condition.args))
                if res:
                    return Add.fromiter(res)
                return self._solve_numerical(condition, given_condition)

            # No numeric queries, go by Contains?... then check that all the
            # given condition are in form of `Contains`
            if not all(arg.has(Contains) for arg in given_cond_args):
                raise ValueError("If given condition is passed with `Contains`, then "
                "please pass the evaluated condition with its corresponding information "
                "in terms of intervals of each time stamp to be passed in given condition.")

            intervals, rv_swap = get_timerv_swaps(condition, given_condition)
            # they are independent when they have non-overlapping intervals
            if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet
                for intv_comb in itertools.combinations(intervals, 2)):
                if isinstance(condition, And):
                    return Mul.fromiter(self.probability(arg, given_condition)
                            for arg in condition.args)
                elif isinstance(condition, Or):
                    return Add.fromiter(self.probability(arg, given_condition)
                            for arg in condition.args)
                condition = condition.subs(rv_swap)
            else:
                return Probability(condition, given_condition)
        if check_numeric:
            return self._solve_numerical(condition)
        return _SubstituteRV._probability(condition, evaluate=evaluate, **kwargs)
2120
class PoissonProcess(CountingProcess):
    """
    The Poisson process is a counting process. It is typically used when
    counting the occurrences of events that appear to happen at a certain
    rate, but completely at random.

    Parameters
    ==========

    sym: Symbol/str
    lamda: Positive number
        Rate of the process, ``lamda > 0``

    Examples
    ========

    >>> from sympy.stats import PoissonProcess, P, E
    >>> from sympy import symbols, Eq, Ne, Contains, Interval
    >>> X = PoissonProcess("X", lamda=3)
    >>> X.state_space
    Naturals0
    >>> X.lamda
    3
    >>> t1, t2 = symbols('t1 t2', positive=True)
    >>> P(X(t1) < 4)
    (9*t1**3/2 + 9*t1**2/2 + 3*t1 + 1)*exp(-3*t1)
    >>> P(Eq(X(t1), 2) | Ne(X(t1), 4), Contains(t1, Interval.Ropen(2, 4)))
    1 - 36*exp(-6)
    >>> P(Eq(X(t1), 2) & Eq(X(t2), 3), Contains(t1, Interval.Lopen(0, 2))
    ... & Contains(t2, Interval.Lopen(2, 4)))
    648*exp(-12)
    >>> E(X(t1))
    3*t1
    >>> E(X(t1)**2 + 2*X(t2),  Contains(t1, Interval.Lopen(0, 1))
    ... & Contains(t2, Interval.Lopen(1, 2)))
    18
    >>> P(X(3) < 1, Eq(X(1), 0))
    exp(-6)
    >>> P(Eq(X(4), 3), Eq(X(2), 3))
    exp(-6)
    >>> P(X(2) <= 3, X(1) > 1)
    5*exp(-3)

    Merging two Poisson Processes

    >>> Y = PoissonProcess("Y", lamda=4)
    >>> Z = X + Y
    >>> Z.lamda
    7

    Splitting a Poisson Process into two independent Poisson Processes

    >>> N, M = Z.split(l1=2, l2=5)
    >>> N.lamda, M.lamda
    (2, 5)

    References
    ==========

    .. [1] https://www.probabilitycourse.com/chapter11/11_0_0_intro.php
    .. [2] https://en.wikipedia.org/wiki/Poisson_point_process

    """

    def __new__(cls, sym, lamda):
        _value_check(lamda > 0, 'lamda should be a positive number.')
        sym = _symbol_converter(sym)
        lamda = _sympify(lamda)
        return Basic.__new__(cls, sym, lamda)

    @property
    def lamda(self):
        """Rate of the process."""
        return self.args[1]

    @property
    def state_space(self):
        """State space of the process, ``S.Naturals0``."""
        return S.Naturals0

    def distribution(self, key):
        """
        Poisson distribution with parameter ``lamda*key``. Passing a
        RandomIndexedSymbol triggers the deprecation warning and its own
        ``key`` is used instead.
        """
        if isinstance(key, RandomIndexedSymbol):
            self._deprecation_warn_distribution()
            key = key.key
        return PoissonDistribution(self.lamda*key)

    def density(self, x):
        """Poisson mass function with parameter ``lamda*x.key`` at ``x``."""
        mean = self.lamda*x.key
        return mean**x / factorial(x) * exp(-mean)

    def simple_rv(self, rv):
        """Poisson random variable named like ``rv`` with parameter ``lamda*rv.key``."""
        mean = self.lamda*rv.key
        return Poisson(rv.name, lamda=mean)

    def __add__(self, other):
        """Merge two Poisson processes; the rates are added."""
        if isinstance(other, PoissonProcess):
            merged_sym = Dummy(self.symbol.name + other.symbol.name)
            merged_rate = self.lamda + other.lamda
            return PoissonProcess(merged_sym, merged_rate)
        raise ValueError("Only instances of Poisson Process can be merged")

    def split(self, l1, l2):
        """
        Split into two Poisson processes with rates ``l1`` and ``l2``;
        a ValueError is raised unless the two rates add up to ``lamda``.
        """
        total = _sympify(l1 + l2)
        if total != self.lamda:
            raise ValueError("Sum of l1 and l2 should be %s" % str(self.lamda))
        first = PoissonProcess(Dummy("l1"), l1)
        second = PoissonProcess(Dummy("l2"), l2)
        return first, second
2221
class WienerProcess(CountingProcess):
    """
    The Wiener process is a real valued continuous-time stochastic process.
    It is used in physics to study Brownian motion, which is why it is also
    known as Brownian Motion.

    Parameters
    ==========

    sym: Symbol/str

    Examples
    ========

    >>> from sympy.stats import WienerProcess, P, E
    >>> from sympy import symbols, Contains, Interval
    >>> X = WienerProcess("X")
    >>> X.state_space
    Reals
    >>> t1, t2 = symbols('t1 t2', positive=True)
    >>> P(X(t1) < 7).simplify()
    erf(7*sqrt(2)/(2*sqrt(t1)))/2 + 1/2
    >>> P((X(t1) > 2) | (X(t1) < 4), Contains(t1, Interval.Ropen(2, 4))).simplify()
    -erf(1)/2 + erf(2)/2 + 1
    >>> E(X(t1))
    0
    >>> E(X(t1) + 2*X(t2),  Contains(t1, Interval.Lopen(0, 1))
    ... & Contains(t2, Interval.Lopen(1, 2)))
    0

    References
    ==========

    .. [1] https://www.probabilitycourse.com/chapter11/11_4_0_brownian_motion_wiener_process.php
    .. [2] https://en.wikipedia.org/wiki/Wiener_process

    """
    def __new__(cls, sym):
        converted = _symbol_converter(sym)
        return Basic.__new__(cls, converted)

    @property
    def state_space(self):
        """State space of the process, ``S.Reals``."""
        return S.Reals

    def distribution(self, key):
        """
        Normal distribution with mean 0 and standard deviation
        ``sqrt(key)``. Passing a RandomIndexedSymbol triggers the
        deprecation warning and its own ``key`` is used instead.
        """
        if isinstance(key, RandomIndexedSymbol):
            self._deprecation_warn_distribution()
            key = key.key
        return NormalDistribution(0, sqrt(key))

    def density(self, x):
        """Normal density with mean 0 and variance ``x.key`` at ``x``."""
        time = x.key
        return exp(-x**2/(2*time)) / (sqrt(2*pi)*sqrt(time))

    def simple_rv(self, rv):
        """Normal random variable named like ``rv`` with standard deviation ``sqrt(rv.key)``."""
        std = sqrt(rv.key)
        return Normal(rv.name, 0, std)
2278
2279
class GammaProcess(CountingProcess):
    """
    A Gamma process is a random process whose increments are independent
    and gamma distributed.  It is a pure-jump increasing Levy process.

    Parameters
    ==========

    sym: Symbol/str
    lamda: Positive number
        Jump size of the process, ``lamda > 0``
    gamma: Positive number
        Rate of jump arrivals, ``gamma > 0``

    Examples
    ========

    >>> from sympy.stats import GammaProcess, E, P, variance
    >>> from sympy import symbols, Contains, Interval, Not
    >>> t, d, x, l, g = symbols('t d x l g', positive=True)
    >>> X = GammaProcess("X", l, g)
    >>> E(X(t))
    g*t/l
    >>> variance(X(t)).simplify()
    g*t/l**2
    >>> X = GammaProcess('X', 1, 2)
    >>> P(X(t) < 1).simplify()
    lowergamma(2*t, 1)/gamma(2*t)
    >>> P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) &
    ... Contains(d, Interval.Lopen(7, 8))).simplify()
    -4*exp(-3) + 472*exp(-8)/3 + 1
    >>> E(X(2) + x*E(X(5)))
    10*x + 4

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Gamma_process

    """
    def __new__(cls, sym, lamda, gamma):
        # both parameters are validated before any conversion takes place
        _value_check(lamda > 0, 'lamda should be a positive number')
        _value_check(gamma > 0, 'gamma should be a positive number')
        sym = _symbol_converter(sym)
        gamma = _sympify(gamma)
        lamda = _sympify(lamda)
        return Basic.__new__(cls, sym, lamda, gamma)

    @property
    def lamda(self):
        """The ``lamda`` parameter (second argument)."""
        return self.args[1]

    @property
    def gamma(self):
        """The ``gamma`` parameter (third argument)."""
        return self.args[2]

    @property
    def state_space(self):
        """State space of the process, built from ``Interval(0, oo)``."""
        return _set_converter(Interval(0, oo))

    def distribution(self, key):
        """
        Gamma distribution with parameters ``gamma*key`` and ``1/lamda``.
        Passing a RandomIndexedSymbol triggers the deprecation warning and
        its own ``key`` is used instead.
        """
        if isinstance(key, RandomIndexedSymbol):
            self._deprecation_warn_distribution()
            key = key.key
        return GammaDistribution(self.gamma*key, 1/self.lamda)

    def density(self, x):
        """Gamma density with shape ``gamma*x.key`` and scale ``1/lamda`` at ``x``."""
        shape = self.gamma*x.key
        scale = 1/self.lamda
        # ``gamma`` below is the module-level gamma function, not the property
        return x**(shape - 1) * exp(-x/scale) / (gamma(shape)*scale**shape)

    def simple_rv(self, rv):
        """Gamma random variable named like ``rv`` with parameters ``gamma*rv.key`` and ``1/lamda``."""
        shape = self.gamma*rv.key
        return Gamma(rv.name, shape, 1/self.lamda)
2353