"""Stochastic process types.

Defines the base stochastic-process classes together with discrete and
continuous time Markov chains and the Bernoulli, Poisson, Wiener and Gamma
processes; the public names are enumerated in ``__all__``.
"""
# Kept for Python 2 compatibility; a no-op on Python 3.
from __future__ import print_function, division
import random

import itertools
from typing import (
    Sequence as tSequence,
    Union as tUnion,
    List as tList,
    Tuple as tTuple,
)

from sympy import (
    Matrix, MatrixSymbol, S, Indexed, Basic, Tuple, Range,
    Set, And, Eq, FiniteSet, ImmutableMatrix, Integer, igcd,
    Lambda, Mul, Dummy, IndexedBase, Add, Interval, oo,
    linsolve, eye, Or, Not, Intersection, factorial, Contains,
    Union, Expr, Function, exp, cacheit, sqrt, pi, gamma,
    Ge, Piecewise, Symbol, NonSquareMatrixError, EmptySet,
    ceiling, MatrixBase, ConditionSet, ones, zeros, Identity,
    Rational, Lt, Gt, Le, Ne, BlockMatrix, Sum,
)
from sympy.core.relational import Relational
from sympy.logic.boolalg import Boolean
from sympy.utilities.exceptions import SymPyDeprecationWarning
from sympy.utilities.iterables import strongly_connected_components
from sympy.stats.joint_rv import JointDistribution
from sympy.stats.joint_rv_types import JointDistributionHandmade
from sympy.stats.rv import (
    RandomIndexedSymbol, random_symbols, RandomSymbol,
    _symbol_converter, _value_check, pspace, given,
    dependent, is_random, sample_iter, Distribution,
    Density,
)
from sympy.stats.stochastic_process import StochasticPSpace
from sympy.stats.symbolic_probability import Probability, Expectation
from sympy.stats.frv_types import Bernoulli, BernoulliDistribution, FiniteRV
from sympy.stats.drv_types import Poisson, PoissonDistribution
from sympy.stats.crv_types import (
    Normal, NormalDistribution, Gamma, GammaDistribution,
)
from sympy.core.sympify import _sympify, sympify

__all__ = [
    'StochasticProcess',
    'DiscreteTimeStochasticProcess',
    'DiscreteMarkovChain',
    'TransitionMatrixOf',
    'StochasticStateSpaceOf',
    'GeneratorMatrixOf',
    'ContinuousMarkovChain',
    'BernoulliProcess',
    'PoissonProcess',
    'WienerProcess',
    'GammaProcess',
]


@is_random.register(Indexed)
def _(x):
    """Report whether an ``Indexed`` expression is random.

    The answer is delegated entirely to the indexed object's base: the
    expression ``base[i]`` is treated as random exactly when ``base`` is.
    """
    indexed_base = x.base
    return is_random(indexed_base)

51@is_random.register(RandomIndexedSymbol) # type: ignore 52def _(x): 53 return True 54 55def _set_converter(itr): 56 """ 57 Helper function for converting list/tuple/set to Set. 58 If parameter is not an instance of list/tuple/set then 59 no operation is performed. 60 61 Returns 62 ======= 63 64 Set 65 The argument converted to Set. 66 67 68 Raises 69 ====== 70 71 TypeError 72 If the argument is not an instance of list/tuple/set. 73 """ 74 if isinstance(itr, (list, tuple, set)): 75 itr = FiniteSet(*itr) 76 if not isinstance(itr, Set): 77 raise TypeError("%s is not an instance of list/tuple/set."%(itr)) 78 return itr 79 80def _state_converter(itr: tSequence) -> tUnion[Tuple, Range]: 81 """ 82 Helper function for converting list/tuple/set/Range/Tuple/FiniteSet 83 to tuple/Range. 84 """ 85 if isinstance(itr, (Tuple, set, FiniteSet)): 86 itr = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) 87 88 elif isinstance(itr, (list, tuple)): 89 # check if states are unique 90 if len(set(itr)) != len(itr): 91 raise ValueError('The state space must have unique elements.') 92 itr = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) 93 94 elif isinstance(itr, Range): 95 # the only ordered set in sympy I know of 96 # try to convert to tuple 97 try: 98 itr = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) 99 except (TypeError, ValueError): 100 pass 101 102 else: 103 raise TypeError("%s is not an instance of list/tuple/set/Range/Tuple/FiniteSet." % (itr)) 104 return itr 105 106def _sym_sympify(arg): 107 """ 108 Converts an arbitrary expression to a type that can be used inside SymPy. 109 As generally strings are unwise to use in the expressions, 110 it returns the Symbol of argument if the string type argument is passed. 111 112 Parameters 113 ========= 114 115 arg: The parameter to be converted to be used in Sympy. 116 117 Returns 118 ======= 119 120 The converted parameter. 
121 122 """ 123 if isinstance(arg, str): 124 return Symbol(arg) 125 else: 126 return _sympify(arg) 127 128def _matrix_checks(matrix): 129 if not isinstance(matrix, (Matrix, MatrixSymbol, ImmutableMatrix)): 130 raise TypeError("Transition probabilities either should " 131 "be a Matrix or a MatrixSymbol.") 132 if matrix.shape[0] != matrix.shape[1]: 133 raise NonSquareMatrixError("%s is not a square matrix"%(matrix)) 134 if isinstance(matrix, Matrix): 135 matrix = ImmutableMatrix(matrix.tolist()) 136 return matrix 137 138class StochasticProcess(Basic): 139 """ 140 Base class for all the stochastic processes whether 141 discrete or continuous. 142 143 Parameters 144 ========== 145 146 sym: Symbol or str 147 state_space: Set 148 The state space of the stochastic process, by default S.Reals. 149 For discrete sets it is zero indexed. 150 151 See Also 152 ======== 153 154 DiscreteTimeStochasticProcess 155 """ 156 157 index_set = S.Reals 158 159 def __new__(cls, sym, state_space=S.Reals, **kwargs): 160 sym = _symbol_converter(sym) 161 state_space = _set_converter(state_space) 162 return Basic.__new__(cls, sym, state_space) 163 164 @property 165 def symbol(self): 166 return self.args[0] 167 168 @property 169 def state_space(self) -> tUnion[FiniteSet, Range]: 170 if not isinstance(self.args[1], (FiniteSet, Range)): 171 return FiniteSet(*self.args[1]) 172 return self.args[1] 173 174 def _deprecation_warn_distribution(self): 175 SymPyDeprecationWarning( 176 feature="Calling distribution with RandomIndexedSymbol", 177 useinstead="distribution with just timestamp as argument", 178 issue=20078, 179 deprecated_since_version="1.7.1" 180 ).warn() 181 182 def distribution(self, key=None): 183 if key is None: 184 self._deprecation_warn_distribution() 185 return Distribution() 186 187 def density(self, x): 188 return Density() 189 190 def __call__(self, time): 191 """ 192 Overridden in ContinuousTimeStochasticProcess. 
193 """ 194 raise NotImplementedError("Use [] for indexing discrete time stochastic process.") 195 196 def __getitem__(self, time): 197 """ 198 Overridden in DiscreteTimeStochasticProcess. 199 """ 200 raise NotImplementedError("Use () for indexing continuous time stochastic process.") 201 202 def probability(self, condition): 203 raise NotImplementedError() 204 205 def joint_distribution(self, *args): 206 """ 207 Computes the joint distribution of the random indexed variables. 208 209 Parameters 210 ========== 211 212 args: iterable 213 The finite list of random indexed variables/the key of a stochastic 214 process whose joint distribution has to be computed. 215 216 Returns 217 ======= 218 219 JointDistribution 220 The joint distribution of the list of random indexed variables. 221 An unevaluated object is returned if it is not possible to 222 compute the joint distribution. 223 224 Raises 225 ====== 226 227 ValueError: When the arguments passed are not of type RandomIndexSymbol 228 or Number. 229 """ 230 args = list(args) 231 for i, arg in enumerate(args): 232 if S(arg).is_Number: 233 if self.index_set.is_subset(S.Integers): 234 args[i] = self.__getitem__(arg) 235 else: 236 args[i] = self.__call__(arg) 237 elif not isinstance(arg, RandomIndexedSymbol): 238 raise ValueError("Expected a RandomIndexedSymbol or " 239 "key not %s"%(type(arg))) 240 241 if args[0].pspace.distribution == Distribution(): 242 return JointDistribution(*args) 243 density = Lambda(tuple(args), 244 expr=Mul.fromiter(arg.pspace.process.density(arg) for arg in args)) 245 return JointDistributionHandmade(density) 246 247 def expectation(self, condition, given_condition): 248 raise NotImplementedError("Abstract method for expectation queries.") 249 250 def sample(self): 251 raise NotImplementedError("Abstract method for sampling queries.") 252 253class DiscreteTimeStochasticProcess(StochasticProcess): 254 """ 255 Base class for all discrete stochastic processes. 
256 """ 257 def __getitem__(self, time): 258 """ 259 For indexing discrete time stochastic processes. 260 261 Returns 262 ======= 263 264 RandomIndexedSymbol 265 """ 266 time = sympify(time) 267 if not time.is_symbol and time not in self.index_set: 268 raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) 269 idx_obj = Indexed(self.symbol, time) 270 pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) 271 return RandomIndexedSymbol(idx_obj, pspace_obj) 272 273class ContinuousTimeStochasticProcess(StochasticProcess): 274 """ 275 Base class for all continuous time stochastic process. 276 """ 277 def __call__(self, time): 278 """ 279 For indexing continuous time stochastic processes. 280 281 Returns 282 ======= 283 284 RandomIndexedSymbol 285 """ 286 time = sympify(time) 287 if not time.is_symbol and time not in self.index_set: 288 raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) 289 func_obj = Function(self.symbol)(time) 290 pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) 291 return RandomIndexedSymbol(func_obj, pspace_obj) 292 293class TransitionMatrixOf(Boolean): 294 """ 295 Assumes that the matrix is the transition matrix 296 of the process. 297 """ 298 299 def __new__(cls, process, matrix): 300 if not isinstance(process, DiscreteMarkovChain): 301 raise ValueError("Currently only DiscreteMarkovChain " 302 "support TransitionMatrixOf.") 303 matrix = _matrix_checks(matrix) 304 return Basic.__new__(cls, process, matrix) 305 306 process = property(lambda self: self.args[0]) 307 matrix = property(lambda self: self.args[1]) 308 309class GeneratorMatrixOf(TransitionMatrixOf): 310 """ 311 Assumes that the matrix is the generator matrix 312 of the process. 
313 """ 314 315 def __new__(cls, process, matrix): 316 if not isinstance(process, ContinuousMarkovChain): 317 raise ValueError("Currently only ContinuousMarkovChain " 318 "support GeneratorMatrixOf.") 319 matrix = _matrix_checks(matrix) 320 return Basic.__new__(cls, process, matrix) 321 322class StochasticStateSpaceOf(Boolean): 323 324 def __new__(cls, process, state_space): 325 if not isinstance(process, (DiscreteMarkovChain, ContinuousMarkovChain)): 326 raise ValueError("Currently only DiscreteMarkovChain and ContinuousMarkovChain " 327 "support StochasticStateSpaceOf.") 328 state_space = _state_converter(state_space) 329 if isinstance(state_space, Range): 330 ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) 331 else: 332 ss_size = len(state_space) 333 state_index = Range(ss_size) 334 return Basic.__new__(cls, process, state_index) 335 336 process = property(lambda self: self.args[0]) 337 state_index = property(lambda self: self.args[1]) 338 339class MarkovProcess(StochasticProcess): 340 """ 341 Contains methods that handle queries 342 common to Markov processes. 343 """ 344 345 @property 346 def number_of_states(self) -> tUnion[Integer, Symbol]: 347 """ 348 The number of states in the Markov Chain. 349 """ 350 return _sympify(self.args[2].shape[0]) 351 352 @property 353 def _state_index(self) -> Range: 354 """ 355 Returns state index as Range. 356 """ 357 return self.args[1] 358 359 @classmethod 360 def _sanity_checks(cls, state_space, trans_probs): 361 # Try to never have None as state_space or trans_probs. 362 # This helps a lot if we get it done at the start. 
363 if (state_space is None) and (trans_probs is None): 364 _n = Dummy('n', integer=True, nonnegative=True) 365 state_space = _state_converter(Range(_n)) 366 trans_probs = _matrix_checks(MatrixSymbol('_T', _n, _n)) 367 368 elif state_space is None: 369 trans_probs = _matrix_checks(trans_probs) 370 state_space = _state_converter(Range(trans_probs.shape[0])) 371 372 elif trans_probs is None: 373 state_space = _state_converter(state_space) 374 if isinstance(state_space, Range): 375 _n = ceiling((state_space.stop - state_space.start) / state_space.step) 376 else: 377 _n = len(state_space) 378 trans_probs = MatrixSymbol('_T', _n, _n) 379 380 else: 381 state_space = _state_converter(state_space) 382 trans_probs = _matrix_checks(trans_probs) 383 # Range object doesn't want to give a symbolic size 384 # so we do it ourselves. 385 if isinstance(state_space, Range): 386 ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) 387 else: 388 ss_size = len(state_space) 389 if ss_size != trans_probs.shape[0]: 390 raise ValueError('The size of the state space and the number of ' 391 'rows of the transition matrix must be the same.') 392 393 return state_space, trans_probs 394 395 def _extract_information(self, given_condition): 396 """ 397 Helper function to extract information, like, 398 transition matrix/generator matrix, state space, etc. 
399 """ 400 if isinstance(self, DiscreteMarkovChain): 401 trans_probs = self.transition_probabilities 402 state_index = self._state_index 403 elif isinstance(self, ContinuousMarkovChain): 404 trans_probs = self.generator_matrix 405 state_index = self._state_index 406 if isinstance(given_condition, And): 407 gcs = given_condition.args 408 given_condition = S.true 409 for gc in gcs: 410 if isinstance(gc, TransitionMatrixOf): 411 trans_probs = gc.matrix 412 if isinstance(gc, StochasticStateSpaceOf): 413 state_index = gc.state_index 414 if isinstance(gc, Relational): 415 given_condition = given_condition & gc 416 if isinstance(given_condition, TransitionMatrixOf): 417 trans_probs = given_condition.matrix 418 given_condition = S.true 419 if isinstance(given_condition, StochasticStateSpaceOf): 420 state_index = given_condition.state_index 421 given_condition = S.true 422 return trans_probs, state_index, given_condition 423 424 def _check_trans_probs(self, trans_probs, row_sum=1): 425 """ 426 Helper function for checking the validity of transition 427 probabilities. 428 """ 429 if not isinstance(trans_probs, MatrixSymbol): 430 rows = trans_probs.tolist() 431 for row in rows: 432 if (sum(row) - row_sum) != 0: 433 raise ValueError("Values in a row must sum to %s. " 434 "If you are using Float or floats then please use Rational."%(row_sum)) 435 436 def _work_out_state_index(self, state_index, given_condition, trans_probs): 437 """ 438 Helper function to extract state space if there 439 is a random symbol in the given condition. 440 """ 441 # if given condition is None, then there is no need to work out 442 # state_space from random variables 443 if given_condition != None: 444 rand_var = list(given_condition.atoms(RandomSymbol) - 445 given_condition.atoms(RandomIndexedSymbol)) 446 if len(rand_var) == 1: 447 state_index = rand_var[0].pspace.set 448 449 # `not None` is `True`. So the old test fails for symbolic sizes. 450 # Need to build the statement differently. 
451 sym_cond = not isinstance(self.number_of_states, (int, Integer)) 452 cond1 = not sym_cond and len(state_index) != trans_probs.shape[0] 453 if cond1: 454 raise ValueError("state space is not compatible with the transition probabilities.") 455 if not isinstance(trans_probs.shape[0], Symbol): 456 state_index = FiniteSet(*[i for i in range(trans_probs.shape[0])]) 457 return state_index 458 459 @cacheit 460 def _preprocess(self, given_condition, evaluate): 461 """ 462 Helper function for pre-processing the information. 463 """ 464 is_insufficient = False 465 466 if not evaluate: # avoid pre-processing if the result is not to be evaluated 467 return (True, None, None, None) 468 469 # extracting transition matrix and state space 470 trans_probs, state_index, given_condition = self._extract_information(given_condition) 471 472 # given_condition does not have sufficient information 473 # for computations 474 if trans_probs is None or \ 475 given_condition is None: 476 is_insufficient = True 477 else: 478 # checking transition probabilities 479 if isinstance(self, DiscreteMarkovChain): 480 self._check_trans_probs(trans_probs, row_sum=1) 481 elif isinstance(self, ContinuousMarkovChain): 482 self._check_trans_probs(trans_probs, row_sum=0) 483 484 # working out state space 485 state_index = self._work_out_state_index(state_index, given_condition, trans_probs) 486 487 return is_insufficient, trans_probs, state_index, given_condition 488 489 def replace_with_index(self, condition): 490 if isinstance(condition, Relational): 491 lhs, rhs = condition.lhs, condition.rhs 492 if not isinstance(lhs, RandomIndexedSymbol): 493 lhs, rhs = rhs, lhs 494 condition = type(condition)(self.index_of.get(lhs, lhs), 495 self.index_of.get(rhs, rhs)) 496 return condition 497 498 def probability(self, condition, given_condition=None, evaluate=True, **kwargs): 499 """ 500 Handles probability queries for Markov process. 
501 502 Parameters 503 ========== 504 505 condition: Relational 506 given_condition: Relational/And 507 508 Returns 509 ======= 510 Probability 511 If the information is not sufficient. 512 Expr 513 In all other cases. 514 515 Note 516 ==== 517 Any information passed at the time of query overrides 518 any information passed at the time of object creation like 519 transition probabilities, state space. 520 Pass the transition matrix using TransitionMatrixOf, 521 generator matrix using GeneratorMatrixOf and state space 522 using StochasticStateSpaceOf in given_condition using & or And. 523 """ 524 check, mat, state_index, new_given_condition = \ 525 self._preprocess(given_condition, evaluate) 526 527 rv = list(condition.atoms(RandomIndexedSymbol)) 528 symbolic = False 529 for sym in rv: 530 if sym.key.is_symbol: 531 symbolic = True 532 break 533 534 if check: 535 return Probability(condition, new_given_condition) 536 537 if isinstance(self, ContinuousMarkovChain): 538 trans_probs = self.transition_probabilities(mat) 539 elif isinstance(self, DiscreteMarkovChain): 540 trans_probs = mat 541 condition = self.replace_with_index(condition) 542 given_condition = self.replace_with_index(given_condition) 543 new_given_condition = self.replace_with_index(new_given_condition) 544 545 if isinstance(condition, Relational): 546 if isinstance(new_given_condition, And): 547 gcs = new_given_condition.args 548 else: 549 gcs = (new_given_condition, ) 550 min_key_rv = list(new_given_condition.atoms(RandomIndexedSymbol)) 551 552 if len(min_key_rv): 553 min_key_rv = min_key_rv[0] 554 for r in rv: 555 if min_key_rv.key.is_symbol or r.key.is_symbol: 556 continue 557 if min_key_rv.key > r.key: 558 return Probability(condition) 559 else: 560 min_key_rv = None 561 return Probability(condition) 562 563 if symbolic: 564 return self._symbolic_probability(condition, new_given_condition, rv, min_key_rv) 565 566 if len(rv) > 1: 567 rv[0] = condition.lhs 568 rv[1] = condition.rhs 569 if rv[0].key < 
rv[1].key: 570 rv[0], rv[1] = rv[1], rv[0] 571 if isinstance(condition, Gt): 572 condition = Lt(condition.lhs, condition.rhs) 573 elif isinstance(condition, Lt): 574 condition = Gt(condition.lhs, condition.rhs) 575 elif isinstance(condition, Ge): 576 condition = Le(condition.lhs, condition.rhs) 577 elif isinstance(condition, Le): 578 condition = Ge(condition.lhs, condition.rhs) 579 s = Rational(0, 1) 580 n = len(self.state_space) 581 582 if isinstance(condition, Eq) or isinstance(condition, Ne): 583 for i in range(0, n): 584 s += self.probability(Eq(rv[0], i), Eq(rv[1], i)) * self.probability(Eq(rv[1], i), new_given_condition) 585 return s if isinstance(condition, Eq) else 1 - s 586 else: 587 upper = 0 588 greater = False 589 if isinstance(condition, Ge) or isinstance(condition, Lt): 590 upper = 1 591 if isinstance(condition, Gt) or isinstance(condition, Ge): 592 greater = True 593 594 for i in range(0, n): 595 if i <= n//2: 596 for j in range(0, i + upper): 597 s += self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) 598 else: 599 s += self.probability(Eq(rv[0], i), new_given_condition) 600 for j in range(i + upper, n): 601 s -= self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) 602 return s if greater else 1 - s 603 604 rv = rv[0] 605 states = condition.as_set() 606 prob, gstate = dict(), None 607 for gc in gcs: 608 if gc.has(min_key_rv): 609 if gc.has(Probability): 610 p, gp = (gc.rhs, gc.lhs) if isinstance(gc.lhs, Probability) \ 611 else (gc.lhs, gc.rhs) 612 gr = gp.args[0] 613 gset = Intersection(gr.as_set(), state_index) 614 gstate = list(gset)[0] 615 prob[gset] = p 616 else: 617 _, gstate = (gc.lhs.key, gc.rhs) if isinstance(gc.lhs, RandomIndexedSymbol) \ 618 else (gc.rhs.key, gc.lhs) 619 620 if any((k not in self.index_set) for k in (rv.key, min_key_rv.key)): 621 raise IndexError("The timestamps of the process are not in it's index set.") 622 states = 
Intersection(states, state_index) if not isinstance(self.number_of_states, Symbol) else states 623 for state in Union(states, FiniteSet(gstate)): 624 if not isinstance(state, (int, Integer)) or Ge(state, mat.shape[0]) is True: 625 raise IndexError("No information is available for (%s, %s) in " 626 "transition probabilities of shape, (%s, %s). " 627 "State space is zero indexed." 628 %(gstate, state, mat.shape[0], mat.shape[1])) 629 if prob: 630 gstates = Union(*prob.keys()) 631 if len(gstates) == 1: 632 gstate = list(gstates)[0] 633 gprob = list(prob.values())[0] 634 prob[gstates] = gprob 635 elif len(gstates) == len(state_index) - 1: 636 gstate = list(state_index - gstates)[0] 637 gprob = S.One - sum(prob.values()) 638 prob[state_index - gstates] = gprob 639 else: 640 raise ValueError("Conflicting information.") 641 else: 642 gprob = S.One 643 644 if min_key_rv == rv: 645 return sum([prob[FiniteSet(state)] for state in states]) 646 if isinstance(self, ContinuousMarkovChain): 647 return gprob * sum([trans_probs(rv.key - min_key_rv.key).__getitem__((gstate, state)) 648 for state in states]) 649 if isinstance(self, DiscreteMarkovChain): 650 return gprob * sum([(trans_probs**(rv.key - min_key_rv.key)).__getitem__((gstate, state)) 651 for state in states]) 652 653 if isinstance(condition, Not): 654 expr = condition.args[0] 655 return S.One - self.probability(expr, given_condition, evaluate, **kwargs) 656 657 if isinstance(condition, And): 658 compute_later, state2cond, conds = [], dict(), condition.args 659 for expr in conds: 660 if isinstance(expr, Relational): 661 ris = list(expr.atoms(RandomIndexedSymbol))[0] 662 if state2cond.get(ris, None) is None: 663 state2cond[ris] = S.true 664 state2cond[ris] &= expr 665 else: 666 compute_later.append(expr) 667 ris = [] 668 for ri in state2cond: 669 ris.append(ri) 670 cset = Intersection(state2cond[ri].as_set(), state_index) 671 if len(cset) == 0: 672 return S.Zero 673 state2cond[ri] = cset.as_relational(ri) 674 sorted_ris = 
sorted(ris, key=lambda ri: ri.key) 675 prod = self.probability(state2cond[sorted_ris[0]], given_condition, evaluate, **kwargs) 676 for i in range(1, len(sorted_ris)): 677 ri, prev_ri = sorted_ris[i], sorted_ris[i-1] 678 if not isinstance(state2cond[ri], Eq): 679 raise ValueError("The process is in multiple states at %s, unable to determine the probability."%(ri)) 680 mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) 681 prod *= self.probability(state2cond[ri], state2cond[prev_ri] 682 & mat_of 683 & StochasticStateSpaceOf(self, state_index), 684 evaluate, **kwargs) 685 for expr in compute_later: 686 prod *= self.probability(expr, given_condition, evaluate, **kwargs) 687 return prod 688 689 if isinstance(condition, Or): 690 return sum([self.probability(expr, given_condition, evaluate, **kwargs) 691 for expr in condition.args]) 692 693 raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " 694 "implemented yet."%(condition, given_condition)) 695 696 def _symbolic_probability(self, condition, new_given_condition, rv, min_key_rv): 697 #Function to calculate probability for queries with symbols 698 if isinstance(condition, Relational): 699 curr_state = new_given_condition.rhs if isinstance(new_given_condition.lhs, RandomIndexedSymbol) \ 700 else new_given_condition.lhs 701 next_state = condition.rhs if isinstance(condition.lhs, RandomIndexedSymbol) \ 702 else condition.lhs 703 704 if isinstance(condition, Eq) or isinstance(condition, Ne): 705 if isinstance(self, DiscreteMarkovChain): 706 P = self.transition_probabilities**(rv[0].key - min_key_rv.key) 707 else: 708 P = exp(self.generator_matrix*(rv[0].key - min_key_rv.key)) 709 prob = P[curr_state, next_state] if isinstance(condition, Eq) else 1 - P[curr_state, next_state] 710 return Piecewise((prob, rv[0].key > min_key_rv.key), (Probability(condition), True)) 711 else: 712 upper = 1 713 greater = False 714 if isinstance(condition, 
Ge) or isinstance(condition, Lt): 715 upper = 0 716 if isinstance(condition, Gt) or isinstance(condition, Ge): 717 greater = True 718 k = Dummy('k') 719 condition = Eq(condition.lhs, k) if isinstance(condition.lhs, RandomIndexedSymbol)\ 720 else Eq(condition.rhs, k) 721 total = Sum(self.probability(condition, new_given_condition), (k, next_state + upper, self.state_space._sup)) 722 return Piecewise((total, rv[0].key > min_key_rv.key), (Probability(condition), True)) if greater\ 723 else Piecewise((1 - total, rv[0].key > min_key_rv.key), (Probability(condition), True)) 724 else: 725 return Probability(condition, new_given_condition) 726 727 def expectation(self, expr, condition=None, evaluate=True, **kwargs): 728 """ 729 Handles expectation queries for markov process. 730 731 Parameters 732 ========== 733 734 expr: RandomIndexedSymbol, Relational, Logic 735 Condition for which expectation has to be computed. Must 736 contain a RandomIndexedSymbol of the process. 737 condition: Relational, Logic 738 The given conditions under which computations should be done. 739 740 Returns 741 ======= 742 743 Expectation 744 Unevaluated object if computations cannot be done due to 745 insufficient information. 746 Expr 747 In all other cases when the computations are successful. 748 749 Note 750 ==== 751 752 Any information passed at the time of query overrides 753 any information passed at the time of object creation like 754 transition probabilities, state space. 755 756 Pass the transition matrix using TransitionMatrixOf, 757 generator matrix using GeneratorMatrixOf and state space 758 using StochasticStateSpaceOf in given_condition using & or And. 
759 """ 760 761 check, mat, state_index, condition = \ 762 self._preprocess(condition, evaluate) 763 764 if check: 765 return Expectation(expr, condition) 766 767 rvs = random_symbols(expr) 768 if isinstance(expr, Expr) and isinstance(condition, Eq) \ 769 and len(rvs) == 1: 770 # handle queries similar to E(f(X[i]), Eq(X[i-m], <some-state>)) 771 condition=self.replace_with_index(condition) 772 state_index=self.replace_with_index(state_index) 773 rv = list(rvs)[0] 774 lhsg, rhsg = condition.lhs, condition.rhs 775 if not isinstance(lhsg, RandomIndexedSymbol): 776 lhsg, rhsg = (rhsg, lhsg) 777 if rhsg not in state_index: 778 raise ValueError("%s state is not in the state space."%(rhsg)) 779 if rv.key < lhsg.key: 780 raise ValueError("Incorrect given condition is given, expectation " 781 "time %s < time %s"%(rv.key, rv.key)) 782 mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) 783 cond = condition & mat_of & \ 784 StochasticStateSpaceOf(self, state_index) 785 func = lambda s: self.probability(Eq(rv, s), cond) * expr.subs(rv, self._state_index[s]) 786 return sum([func(s) for s in state_index]) 787 788 raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " 789 "implemented yet."%(expr, condition)) 790 791class DiscreteMarkovChain(DiscreteTimeStochasticProcess, MarkovProcess): 792 """ 793 Represents a finite discrete time-homogeneous Markov chain. 794 795 This type of Markov Chain can be uniquely characterised by 796 its (ordered) state space and its one-step transition probability 797 matrix. 
798 799 Parameters 800 ========== 801 802 sym: 803 The name given to the Markov Chain 804 state_space: 805 Optional, by default, Range(n) 806 trans_probs: 807 Optional, by default, MatrixSymbol('_T', n, n) 808 809 Examples 810 ======== 811 812 >>> from sympy.stats import DiscreteMarkovChain, TransitionMatrixOf, P, E 813 >>> from sympy import Matrix, MatrixSymbol, Eq, symbols 814 >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) 815 >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) 816 >>> YS = DiscreteMarkovChain("Y") 817 818 >>> Y.state_space 819 {0, 1, 2} 820 >>> Y.transition_probabilities 821 Matrix([ 822 [0.5, 0.2, 0.3], 823 [0.2, 0.5, 0.3], 824 [0.2, 0.3, 0.5]]) 825 >>> TS = MatrixSymbol('T', 3, 3) 826 >>> P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TS)) 827 T[0, 2]*T[1, 0] + T[1, 1]*T[1, 2] + T[1, 2]*T[2, 2] 828 >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) 829 0.36 830 831 Probabilities will be calculated based on indexes rather 832 than state names. For example, with the Sunny-Cloudy-Rainy 833 model with string state names: 834 835 >>> from sympy.core.symbol import Str 836 >>> Y = DiscreteMarkovChain("Y", [Str('Sunny'), Str('Cloudy'), Str('Rainy')], T) 837 >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) 838 0.36 839 840 This gives the same answer as the ``[0, 1, 2]`` state space. 841 Currently, there is no support for state names within probability 842 and expectation statements. 
Here is a work-around using ``Str``: 843 844 >>> P(Eq(Str('Rainy'), Y[3]), Eq(Y[1], Str('Cloudy'))).round(2) 845 0.36 846 847 Symbol state names can also be used: 848 849 >>> sunny, cloudy, rainy = symbols('Sunny, Cloudy, Rainy') 850 >>> Y = DiscreteMarkovChain("Y", [sunny, cloudy, rainy], T) 851 >>> P(Eq(Y[3], rainy), Eq(Y[1], cloudy)).round(2) 852 0.36 853 854 Expectations will be calculated as follows: 855 856 >>> E(Y[3], Eq(Y[1], cloudy)) 857 0.38*Cloudy + 0.36*Rainy + 0.26*Sunny 858 859 Probability of expressions with multiple RandomIndexedSymbols 860 can also be calculated provided there is only 1 RandomIndexedSymbol 861 in the given condition. It is always better to use Rational instead 862 of floating point numbers for the probabilities in the 863 transition matrix to avoid errors. 864 865 >>> from sympy import Gt, Le, Rational 866 >>> T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]]) 867 >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) 868 >>> P(Eq(Y[3], Y[1]), Eq(Y[0], 0)).round(3) 869 0.409 870 >>> P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) 871 0.36 872 >>> P(Le(Y[15], Y[10]), Eq(Y[8], 2)).round(7) 873 0.6963328 874 875 Symbolic probability queries are also supported 876 877 >>> from sympy import symbols, Matrix, Rational, Eq, Gt 878 >>> from sympy.stats import P, DiscreteMarkovChain 879 >>> a, b, c, d = symbols('a b c d') 880 >>> T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]]) 881 >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) 882 >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) 883 >>> query.subs({a:10 ,b:2, c:5, d:1}).round(4) 884 0.3096 885 >>> P(Eq(Y[10], 2), Eq(Y[5], 1)).evalf().round(4) 886 0.3096 887 >>> query_gt = P(Gt(Y[a], b), Eq(Y[c], d)) 888 >>> query_gt.subs({a:21, b:0, c:5, d:0}).evalf().round(5) 889 0.64705 890 
>>> P(Gt(Y[21], 0), Eq(Y[5], 0)).round(5) 891 0.64705 892 893 There is limited support for arbitrarily sized states: 894 895 >>> n = symbols('n', nonnegative=True, integer=True) 896 >>> T = MatrixSymbol('T', n, n) 897 >>> Y = DiscreteMarkovChain("Y", trans_probs=T) 898 >>> Y.state_space 899 Range(0, n, 1) 900 >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) 901 >>> query.subs({a:10, b:2, c:5, d:1}) 902 (T**5)[1, 2] 903 904 References 905 ========== 906 907 .. [1] https://en.wikipedia.org/wiki/Markov_chain#Discrete-time_Markov_chain 908 .. [2] https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf 909 """ 910 index_set = S.Naturals0 911 912 def __new__(cls, sym, state_space=None, trans_probs=None): 913 # type: (Basic, tUnion[str, Symbol], tSequence, tUnion[MatrixBase, MatrixSymbol]) -> DiscreteMarkovChain 914 sym = _symbol_converter(sym) 915 916 state_space, trans_probs = MarkovProcess._sanity_checks(state_space, trans_probs) 917 918 obj = Basic.__new__(cls, sym, state_space, trans_probs) 919 indices = dict() 920 if isinstance(obj.number_of_states, Integer): 921 for index, state in enumerate(obj._state_index): 922 indices[state] = index 923 obj.index_of = indices 924 return obj 925 926 @property 927 def transition_probabilities(self) -> tUnion[MatrixBase, MatrixSymbol]: 928 """ 929 Transition probabilities of discrete Markov chain, 930 either an instance of Matrix or MatrixSymbol. 931 """ 932 return self.args[2] 933 934 def communication_classes(self) -> tList[tTuple[tList[Basic], Boolean, Integer]]: 935 """ 936 Returns the list of communication classes that partition 937 the states of the markov chain. 938 939 A communication class is defined to be a set of states 940 such that every state in that set is reachable from 941 every other state in that set. Due to its properties 942 this forms a class in the mathematical sense. 943 Communication classes are also known as recurrence 944 classes. 
945 946 Returns 947 ======= 948 949 classes 950 The ``classes`` are a list of tuples. Each 951 tuple represents a single communication class 952 with its properties. The first element in the 953 tuple is the list of states in the class, the 954 second element is whether the class is recurrent 955 and the third element is the period of the 956 communication class. 957 958 Examples 959 ======== 960 961 >>> from sympy.stats import DiscreteMarkovChain 962 >>> from sympy import Matrix 963 >>> T = Matrix([[0, 1, 0], 964 ... [1, 0, 0], 965 ... [1, 0, 0]]) 966 >>> X = DiscreteMarkovChain('X', [1, 2, 3], T) 967 >>> classes = X.communication_classes() 968 >>> for states, is_recurrent, period in classes: 969 ... states, is_recurrent, period 970 ([1, 2], True, 2) 971 ([3], False, 1) 972 973 From this we can see that states ``1`` and ``2`` 974 communicate, are recurrent and have a period 975 of 2. We can also see state ``3`` is transient 976 with a period of 1. 977 978 Notes 979 ===== 980 981 The algorithm used is of order ``O(n**2)`` where 982 ``n`` is the number of states in the markov chain. 983 It uses Tarjan's algorithm to find the classes 984 themselves and then it uses a breadth-first search 985 algorithm to find each class's periodicity. 986 Most of the algorithm's components approach ``O(n)`` 987 as the matrix becomes more and more sparse. 988 989 References 990 ========== 991 992 .. [1] http://www.columbia.edu/~ww2040/4701Sum07/4701-06-Notes-MCII.pdf 993 .. [2] http://cecas.clemson.edu/~shierd/Shier/markov.pdf 994 .. [3] https://ujcontent.uj.ac.za/vital/access/services/Download/uj:7506/CONTENT1 995 .. [4] https://www.mathworks.com/help/econ/dtmc.classify.html 996 """ 997 n = self.number_of_states 998 T = self.transition_probabilities 999 1000 if isinstance(T, MatrixSymbol): 1001 raise NotImplementedError("Cannot perform the operation with a symbolic matrix.") 1002 1003 # begin Tarjan's algorithm 1004 V = Range(n) 1005 # don't use state names. 
Rather use state 1006 # indexes since we use them for matrix 1007 # indexing here and later onward 1008 E = [(i, j) for i in V for j in V if T[i, j] != 0] 1009 classes = strongly_connected_components((V, E)) 1010 # end Tarjan's algorithm 1011 1012 recurrence = [] 1013 periods = [] 1014 for class_ in classes: 1015 # begin recurrent check (similar to self._check_trans_probs()) 1016 submatrix = T[class_, class_] # get the submatrix with those states 1017 is_recurrent = S.true 1018 rows = submatrix.tolist() 1019 for row in rows: 1020 if (sum(row) - 1) != 0: 1021 is_recurrent = S.false 1022 break 1023 recurrence.append(is_recurrent) 1024 # end recurrent check 1025 1026 # begin breadth-first search 1027 non_tree_edge_values = set() 1028 visited = {class_[0]} 1029 newly_visited = {class_[0]} 1030 level = {class_[0]: 0} 1031 current_level = 0 1032 done = False # imitate a do-while loop 1033 while not done: # runs at most len(class_) times 1034 done = len(visited) == len(class_) 1035 current_level += 1 1036 1037 # this loop and the while loop above run a combined len(class_) number of times. 1038 # so this triple nested loop runs through each of the n states once. 
1039 for i in newly_visited: 1040 1041 # the loop below runs len(class_) number of times 1042 # complexity is around about O(n * avg(len(class_))) 1043 newly_visited = {j for j in class_ if T[i, j] != 0} 1044 1045 new_tree_edges = newly_visited.difference(visited) 1046 for j in new_tree_edges: 1047 level[j] = current_level 1048 1049 new_non_tree_edges = newly_visited.intersection(visited) 1050 new_non_tree_edge_values = {level[i]-level[j]+1 for j in new_non_tree_edges} 1051 1052 non_tree_edge_values = non_tree_edge_values.union(new_non_tree_edge_values) 1053 visited = visited.union(new_tree_edges) 1054 1055 # igcd needs at least 2 arguments 1056 positive_ntev = {val_e for val_e in non_tree_edge_values if val_e > 0} 1057 if len(positive_ntev) == 0: 1058 periods.append(len(class_)) 1059 elif len(positive_ntev) == 1: 1060 periods.append(positive_ntev.pop()) 1061 else: 1062 periods.append(igcd(*positive_ntev)) 1063 # end breadth-first search 1064 1065 # convert back to the user's state names 1066 classes = [[self._state_index[i] for i in class_] for class_ in classes] 1067 1068 return sympify(list(zip(classes, recurrence, periods))) 1069 1070 def fundamental_matrix(self): 1071 """ 1072 Each entry fundamental matrix can be interpreted as 1073 the expected number of times the chains is in state j 1074 if it started in state i. 1075 1076 References 1077 ========== 1078 1079 .. 
[1] https://lips.cs.princeton.edu/the-fundamental-matrix-of-a-finite-markov-chain/ 1080 1081 """ 1082 _, _, _, Q = self.decompose() 1083 1084 if Q.shape[0] > 0: # if non-ergodic 1085 I = eye(Q.shape[0]) 1086 if (I - Q).det() == 0: 1087 raise ValueError("The fundamental matrix doesn't exist.") 1088 return (I - Q).inv().as_immutable() 1089 else: # if ergodic 1090 P = self.transition_probabilities 1091 I = eye(P.shape[0]) 1092 w = self.fixed_row_vector() 1093 W = Matrix([list(w) for i in range(0, P.shape[0])]) 1094 if (I - P + W).det() == 0: 1095 raise ValueError("The fundamental matrix doesn't exist.") 1096 return (I - P + W).inv().as_immutable() 1097 1098 def absorbing_probabilities(self): 1099 """ 1100 Computes the absorbing probabilities, i.e., 1101 the ij-th entry of the matrix denotes the 1102 probability of Markov chain being absorbed 1103 in state j starting from state i. 1104 """ 1105 _, _, R, _ = self.decompose() 1106 N = self.fundamental_matrix() 1107 if R is None or N is None: 1108 return None 1109 return N*R 1110 1111 def absorbing_probabilites(self): 1112 SymPyDeprecationWarning( 1113 feature="absorbing_probabilites", 1114 useinstead="absorbing_probabilities", 1115 issue=20042, 1116 deprecated_since_version="1.7" 1117 ).warn() 1118 return self.absorbing_probabilities() 1119 1120 def is_regular(self): 1121 tuples = self.communication_classes() 1122 if len(tuples) == 0: 1123 return S.false # not defined for a 0x0 matrix 1124 classes, _, periods = list(zip(*tuples)) 1125 return And(len(classes) == 1, periods[0] == 1) 1126 1127 def is_ergodic(self): 1128 tuples = self.communication_classes() 1129 if len(tuples) == 0: 1130 return S.false # not defined for a 0x0 matrix 1131 classes, _, _ = list(zip(*tuples)) 1132 return S(len(classes) == 1) 1133 1134 def is_absorbing_state(self, state): 1135 trans_probs = self.transition_probabilities 1136 if isinstance(trans_probs, ImmutableMatrix) and \ 1137 state < trans_probs.shape[0]: 1138 return S(trans_probs[state, 
state]) is S.One 1139 1140 def is_absorbing_chain(self): 1141 states, A, B, C = self.decompose() 1142 r = A.shape[0] 1143 return And(r > 0, A == Identity(r).as_explicit()) 1144 1145 def stationary_distribution(self, condition_set=False) -> tUnion[ImmutableMatrix, ConditionSet, Lambda]: 1146 """ 1147 The stationary distribution is any row vector, p, that solves p = pP, 1148 is row stochastic and each element in p must be nonnegative. 1149 That means in matrix form: :math:`(P-I)^T p^T = 0` and 1150 :math:`(1, ..., 1) p = 1` 1151 where ``P`` is the one-step transition matrix. 1152 1153 All time-homogeneous Markov Chains with a finite state space 1154 have at least one stationary distribution. In addition, if 1155 a finite time-homogeneous Markov Chain is irreducible, the 1156 stationary distribution is unique. 1157 1158 Parameters 1159 ========== 1160 1161 condition_set : bool 1162 If the chain has a symbolic size or transition matrix, 1163 it will return a ``Lambda`` if ``False`` and return a 1164 ``ConditionSet`` if ``True``. 1165 1166 Examples 1167 ======== 1168 1169 >>> from sympy.stats import DiscreteMarkovChain 1170 >>> from sympy import Matrix, S 1171 1172 An irreducible Markov Chain 1173 1174 >>> T = Matrix([[S(1)/2, S(1)/2, 0], 1175 ... [S(4)/5, S(1)/5, 0], 1176 ... [1, 0, 0]]) 1177 >>> X = DiscreteMarkovChain('X', trans_probs=T) 1178 >>> X.stationary_distribution() 1179 Matrix([[8/13, 5/13, 0]]) 1180 1181 A reducible Markov Chain 1182 1183 >>> T = Matrix([[S(1)/2, S(1)/2, 0], 1184 ... [S(4)/5, S(1)/5, 0], 1185 ... [0, 0, 1]]) 1186 >>> X = DiscreteMarkovChain('X', trans_probs=T) 1187 >>> X.stationary_distribution() 1188 Matrix([[8/13 - 8*tau0/13, 5/13 - 5*tau0/13, tau0]]) 1189 1190 >>> Y = DiscreteMarkovChain('Y') 1191 >>> Y.stationary_distribution() 1192 Lambda((wm, _T), Eq(wm*_T, wm)) 1193 1194 >>> Y.stationary_distribution(condition_set=True) 1195 ConditionSet(wm, Eq(wm*_T, wm)) 1196 1197 References 1198 ========== 1199 1200 .. 
[1] https://www.probabilitycourse.com/chapter11/11_2_6_stationary_and_limiting_distributions.php 1201 .. [2] https://galton.uchicago.edu/~yibi/teaching/stat317/2014/Lectures/Lecture4_6up.pdf 1202 1203 See Also 1204 ======== 1205 1206 sympy.stats.DiscreteMarkovChain.limiting_distribution 1207 """ 1208 trans_probs = self.transition_probabilities 1209 n = self.number_of_states 1210 1211 if n == 0: 1212 return ImmutableMatrix(Matrix([[]])) 1213 1214 # symbolic matrix version 1215 if isinstance(trans_probs, MatrixSymbol): 1216 wm = MatrixSymbol('wm', 1, n) 1217 if condition_set: 1218 return ConditionSet(wm, Eq(wm * trans_probs, wm)) 1219 else: 1220 return Lambda((wm, trans_probs), Eq(wm * trans_probs, wm)) 1221 1222 # numeric matrix version 1223 a = Matrix(trans_probs - Identity(n)).T 1224 a[0, 0:n] = ones(1, n) 1225 b = zeros(n, 1) 1226 b[0, 0] = 1 1227 1228 soln = list(linsolve((a, b)))[0] 1229 return ImmutableMatrix([[sol for sol in soln]]) 1230 1231 def fixed_row_vector(self): 1232 """ 1233 A wrapper for ``stationary_distribution()``. 1234 """ 1235 return self.stationary_distribution() 1236 1237 @property 1238 def limiting_distribution(self): 1239 """ 1240 The fixed row vector is the limiting 1241 distribution of a discrete Markov chain. 1242 """ 1243 return self.fixed_row_vector() 1244 1245 def decompose(self) -> tTuple[tList[Basic], ImmutableMatrix, ImmutableMatrix, ImmutableMatrix]: 1246 """ 1247 Decomposes the transition matrix into submatrices with 1248 special properties. 1249 1250 The transition matrix can be decomposed into 4 submatrices: 1251 - A - the submatrix from recurrent states to recurrent states. 1252 - B - the submatrix from transient to recurrent states. 1253 - C - the submatrix from transient to transient states. 1254 - O - the submatrix of zeros for recurrent to transient states. 
1255 1256 Returns 1257 ======= 1258 1259 states, A, B, C 1260 ``states`` - a list of state names with the first being 1261 the recurrent states and the last being 1262 the transient states in the order 1263 of the row names of A and then the row names of C. 1264 ``A`` - the submatrix from recurrent states to recurrent states. 1265 ``B`` - the submatrix from transient to recurrent states. 1266 ``C`` - the submatrix from transient to transient states. 1267 1268 Examples 1269 ======== 1270 1271 >>> from sympy.stats import DiscreteMarkovChain 1272 >>> from sympy import Matrix, S 1273 1274 One can decompose this chain for example: 1275 1276 >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], 1277 ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], 1278 ... [0, 0, 1, 0, 0], 1279 ... [0, 0, S(1)/2, S(1)/2, 0], 1280 ... [S(1)/2, 0, 0, 0, S(1)/2]]) 1281 >>> X = DiscreteMarkovChain('X', trans_probs=T) 1282 >>> states, A, B, C = X.decompose() 1283 >>> states 1284 [2, 0, 1, 3, 4] 1285 1286 >>> A # recurrent to recurrent 1287 Matrix([[1]]) 1288 1289 >>> B # transient to recurrent 1290 Matrix([ 1291 [ 0], 1292 [2/5], 1293 [1/2], 1294 [ 0]]) 1295 1296 >>> C # transient to transient 1297 Matrix([ 1298 [1/2, 1/2, 0, 0], 1299 [2/5, 1/5, 0, 0], 1300 [ 0, 0, 1/2, 0], 1301 [1/2, 0, 0, 1/2]]) 1302 1303 This means that state 2 is the only absorbing state 1304 (since A is a 1x1 matrix). B is a 4x1 matrix since 1305 the 4 remaining transient states all merge into recurrent 1306 state 2. And C is the 4x4 matrix that shows how the 1307 transient states 0, 1, 3, 4 all interact. 1308 1309 See Also 1310 ======== 1311 1312 sympy.stats.DiscreteMarkovChain.communication_classes 1313 sympy.stats.DiscreteMarkovChain.canonical_form 1314 1315 References 1316 ========== 1317 1318 .. [1] https://en.wikipedia.org/wiki/Absorbing_Markov_chain 1319 ..
[2] http://people.brandeis.edu/~igusa/Math56aS08/Math56a_S08_notes015.pdf 1320 """ 1321 trans_probs = self.transition_probabilities 1322 1323 classes = self.communication_classes() 1324 r_states = [] 1325 t_states = [] 1326 1327 for states, recurrent, period in classes: 1328 if recurrent: 1329 r_states += states 1330 else: 1331 t_states += states 1332 1333 states = r_states + t_states 1334 indexes = [self.index_of[state] for state in states] 1335 1336 A = Matrix(len(r_states), len(r_states), 1337 lambda i, j: trans_probs[indexes[i], indexes[j]]) 1338 1339 B = Matrix(len(t_states), len(r_states), 1340 lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[j]]) 1341 1342 C = Matrix(len(t_states), len(t_states), 1343 lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[len(r_states) + j]]) 1344 1345 return states, A.as_immutable(), B.as_immutable(), C.as_immutable() 1346 1347 def canonical_form(self) -> tTuple[tList[Basic], ImmutableMatrix]: 1348 """ 1349 Reorders the one-step transition matrix 1350 so that recurrent states appear first and transient 1351 states appear last. Other representations include inserting 1352 transient states first and recurrent states last. 1353 1354 Returns 1355 ======= 1356 1357 states, P_new 1358 ``states`` is the list that describes the order of the 1359 new states in the matrix 1360 so that the ith element in ``states`` is the state of the 1361 ith row of A. 1362 ``P_new`` is the new transition matrix in canonical form. 1363 1364 Examples 1365 ======== 1366 1367 >>> from sympy.stats import DiscreteMarkovChain 1368 >>> from sympy import Matrix, S 1369 1370 You can convert your chain into canonical form: 1371 1372 >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], 1373 ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], 1374 ... [0, 0, 1, 0, 0], 1375 ... [0, 0, S(1)/2, S(1)/2, 0], 1376 ... 
[S(1)/2, 0, 0, 0, S(1)/2]]) 1377 >>> X = DiscreteMarkovChain('X', list(range(1, 6)), trans_probs=T) 1378 >>> states, new_matrix = X.canonical_form() 1379 >>> states 1380 [3, 1, 2, 4, 5] 1381 1382 >>> new_matrix 1383 Matrix([ 1384 [ 1, 0, 0, 0, 0], 1385 [ 0, 1/2, 1/2, 0, 0], 1386 [2/5, 2/5, 1/5, 0, 0], 1387 [1/2, 0, 0, 1/2, 0], 1388 [ 0, 1/2, 0, 0, 1/2]]) 1389 1390 The new states are [3, 1, 2, 4, 5] and you can 1391 create a new chain with this and its canonical 1392 form will remain the same (since it is already 1393 in canonical form). 1394 1395 >>> X = DiscreteMarkovChain('X', states, new_matrix) 1396 >>> states, new_matrix = X.canonical_form() 1397 >>> states 1398 [3, 1, 2, 4, 5] 1399 1400 >>> new_matrix 1401 Matrix([ 1402 [ 1, 0, 0, 0, 0], 1403 [ 0, 1/2, 1/2, 0, 0], 1404 [2/5, 2/5, 1/5, 0, 0], 1405 [1/2, 0, 0, 1/2, 0], 1406 [ 0, 1/2, 0, 0, 1/2]]) 1407 1408 This is not limited to absorbing chains: 1409 1410 >>> T = Matrix([[0, 5, 5, 0, 0], 1411 ... [0, 0, 0, 10, 0], 1412 ... [5, 0, 5, 0, 0], 1413 ... [0, 10, 0, 0, 0], 1414 ... [0, 3, 0, 3, 4]])/10 1415 >>> X = DiscreteMarkovChain('X', trans_probs=T) 1416 >>> states, new_matrix = X.canonical_form() 1417 >>> states 1418 [1, 3, 0, 2, 4] 1419 1420 >>> new_matrix 1421 Matrix([ 1422 [ 0, 1, 0, 0, 0], 1423 [ 1, 0, 0, 0, 0], 1424 [ 1/2, 0, 0, 1/2, 0], 1425 [ 0, 0, 1/2, 1/2, 0], 1426 [3/10, 3/10, 0, 0, 2/5]]) 1427 1428 See Also 1429 ======== 1430 1431 sympy.stats.DiscreteMarkovChain.communication_classes 1432 sympy.stats.DiscreteMarkovChain.decompose 1433 1434 References 1435 ========== 1436 1437 .. [1] https://onlinelibrary.wiley.com/doi/pdf/10.1002/9780470316887.app1 1438 .. 
[2] http://www.columbia.edu/~ww2040/6711F12/lect1023big.pdf 1439 """ 1440 states, A, B, C = self.decompose() 1441 O = zeros(A.shape[0], C.shape[1]) 1442 return states, BlockMatrix([[A, O], [B, C]]).as_explicit() 1443 1444 def sample(self): 1445 """ 1446 Returns 1447 ======= 1448 1449 sample: iterator object 1450 iterator object containing the sample 1451 1452 """ 1453 if not isinstance(self.transition_probabilities, (Matrix, ImmutableMatrix)): 1454 raise ValueError("Transition Matrix must be provided for sampling") 1455 Tlist = self.transition_probabilities.tolist() 1456 samps = [random.choice(list(self.state_space))] 1457 yield samps[0] 1458 time = 1 1459 densities = {} 1460 for state in self.state_space: 1461 states = list(self.state_space) 1462 densities[state] = {states[i]: Tlist[state][i] 1463 for i in range(len(states))} 1464 while time < S.Infinity: 1465 samps.append((next(sample_iter(FiniteRV("_", densities[samps[time - 1]]))))) 1466 yield samps[time] 1467 time += 1 1468 1469class ContinuousMarkovChain(ContinuousTimeStochasticProcess, MarkovProcess): 1470 """ 1471 Represents continuous time Markov chain. 
1472 1473 Parameters 1474 ========== 1475 1476 sym: Symbol/str 1477 state_space: Set 1478 Optional, by default, S.Reals 1479 gen_mat: Matrix/ImmutableMatrix/MatrixSymbol 1480 Optional, by default, None 1481 1482 Examples 1483 ======== 1484 1485 >>> from sympy.stats import ContinuousMarkovChain, P 1486 >>> from sympy import Matrix, S, Eq, Gt 1487 >>> G = Matrix([[-S(1), S(1)], [S(1), -S(1)]]) 1488 >>> C = ContinuousMarkovChain('C', state_space=[0, 1], gen_mat=G) 1489 >>> C.limiting_distribution() 1490 Matrix([[1/2, 1/2]]) 1491 >>> C.state_space 1492 {0, 1} 1493 >>> C.generator_matrix 1494 Matrix([ 1495 [-1, 1], 1496 [ 1, -1]]) 1497 1498 Probability queries are supported 1499 1500 >>> P(Eq(C(1.96), 0), Eq(C(0.78), 1)).round(5) 1501 0.45279 1502 >>> P(Gt(C(1.7), 0), Eq(C(0.82), 1)).round(5) 1503 0.58602 1504 1505 Probability of expressions with multiple RandomIndexedSymbols 1506 can also be calculated provided there is only 1 RandomIndexedSymbol 1507 in the given condition. It is always better to use Rational instead 1508 of floating point numbers for the probabilities in the 1509 generator matrix to avoid errors. 
1510 1511 >>> from sympy import Gt, Le, Rational 1512 >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) 1513 >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) 1514 >>> P(Eq(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) 1515 0.37933 1516 >>> P(Gt(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) 1517 0.34211 1518 >>> P(Le(C(1.57), C(3.14)), Eq(C(1.22), 1)).round(4) 1519 0.7143 1520 1521 Symbolic probability queries are also supported 1522 1523 >>> from sympy import S, symbols, Matrix, Rational, Eq, Gt 1524 >>> from sympy.stats import P, ContinuousMarkovChain 1525 >>> a,b,c,d = symbols('a b c d') 1526 >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) 1527 >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) 1528 >>> query = P(Eq(C(a), b), Eq(C(c), d)) 1529 >>> query.subs({a:3.65 ,b:2, c:1.78, d:1}).evalf().round(10) 1530 0.4002723175 1531 >>> P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10) 1532 0.4002723175 1533 >>> query_gt = P(Gt(C(a), b), Eq(C(c), d)) 1534 >>> query_gt.subs({a:43.2 ,b:0, c:3.29, d:2}).evalf().round(10) 1535 0.6832579186 1536 >>> P(Gt(C(43.2), 0), Eq(C(3.29), 2)).round(10) 1537 0.6832579186 1538 1539 References 1540 ========== 1541 1542 .. [1] https://en.wikipedia.org/wiki/Markov_chain#Continuous-time_Markov_chain 1543 .. 
[2] http://u.math.biu.ac.il/~amirgi/CTMCnotes.pdf 1544 """ 1545 index_set = S.Reals 1546 1547 def __new__(cls, sym, state_space=None, gen_mat=None): 1548 sym = _symbol_converter(sym) 1549 state_space, gen_mat = MarkovProcess._sanity_checks(state_space, gen_mat) 1550 obj = Basic.__new__(cls, sym, state_space, gen_mat) 1551 indices = dict() 1552 if isinstance(obj.number_of_states, Integer): 1553 for index, state in enumerate(obj.state_space): 1554 indices[state] = index 1555 obj.index_of = indices 1556 return obj 1557 1558 @property 1559 def generator_matrix(self): 1560 return self.args[2] 1561 1562 @cacheit 1563 def transition_probabilities(self, gen_mat=None): 1564 t = Dummy('t') 1565 if isinstance(gen_mat, (Matrix, ImmutableMatrix)) and \ 1566 gen_mat.is_diagonalizable(): 1567 # for faster computation use diagonalized generator matrix 1568 Q, D = gen_mat.diagonalize() 1569 return Lambda(t, Q*exp(t*D)*Q.inv()) 1570 if gen_mat != None: 1571 return Lambda(t, exp(t*gen_mat)) 1572 1573 def limiting_distribution(self): 1574 gen_mat = self.generator_matrix 1575 if gen_mat is None: 1576 return None 1577 if isinstance(gen_mat, MatrixSymbol): 1578 wm = MatrixSymbol('wm', 1, gen_mat.shape[0]) 1579 return Lambda((wm, gen_mat), Eq(wm*gen_mat, wm)) 1580 w = IndexedBase('w') 1581 wi = [w[i] for i in range(gen_mat.shape[0])] 1582 wm = Matrix([wi]) 1583 eqs = (wm*gen_mat).tolist()[0] 1584 eqs.append(sum(wi) - 1) 1585 soln = list(linsolve(eqs, wi))[0] 1586 return ImmutableMatrix([[sol for sol in soln]]) 1587 1588 1589class BernoulliProcess(DiscreteTimeStochasticProcess): 1590 """ 1591 The Bernoulli process consists of repeated 1592 independent Bernoulli process trials with the same parameter `p`. 1593 It's assumed that the probability `p` applies to every 1594 trial and that the outcomes of each trial 1595 are independent of all the rest. Therefore Bernoulli Processs 1596 is Discrete State and Discrete Time Stochastic Process. 
1597 1598 Parameters 1599 ========== 1600 1601 sym: Symbol/str 1602 success: Integer/str 1603 The event which is considered to be success, by default is 1. 1604 failure: Integer/str 1605 The event which is considered to be failure, by default is 0. 1606 p: Real Number between 0 and 1 1607 Represents the probability of getting success. 1608 1609 Examples 1610 ======== 1611 1612 >>> from sympy.stats import BernoulliProcess, P, E 1613 >>> from sympy import Eq, Gt 1614 >>> B = BernoulliProcess("B", p=0.7, success=1, failure=0) 1615 >>> B.state_space 1616 {0, 1} 1617 >>> (B.p).round(2) 1618 0.70 1619 >>> B.success 1620 1 1621 >>> B.failure 1622 0 1623 >>> X = B[1] + B[2] + B[3] 1624 >>> P(Eq(X, 0)).round(2) 1625 0.03 1626 >>> P(Eq(X, 2)).round(2) 1627 0.44 1628 >>> P(Eq(X, 4)).round(2) 1629 0 1630 >>> P(Gt(X, 1)).round(2) 1631 0.78 1632 >>> P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2) 1633 0.04 1634 >>> B.joint_distribution(B[1], B[2]) 1635 JointDistributionHandmade(Lambda((B[1], B[2]), Piecewise((0.7, Eq(B[1], 1)), 1636 (0.3, Eq(B[1], 0)), (0, True))*Piecewise((0.7, Eq(B[2], 1)), (0.3, Eq(B[2], 0)), 1637 (0, True)))) 1638 >>> E(2*B[1] + B[2]).round(2) 1639 2.10 1640 >>> P(B[1] < 1).round(2) 1641 0.30 1642 1643 References 1644 ========== 1645 1646 .. [1] https://en.wikipedia.org/wiki/Bernoulli_process 1647 .. 
[2] https://mathcs.clarku.edu/~djoyce/ma217/bernoulli.pdf 1648 1649 """ 1650 1651 index_set = S.Naturals0 1652 1653 def __new__(cls, sym, p, success=1, failure=0): 1654 _value_check(p >= 0 and p <= 1, 'Value of p must be between 0 and 1.') 1655 sym = _symbol_converter(sym) 1656 p = _sympify(p) 1657 success = _sym_sympify(success) 1658 failure = _sym_sympify(failure) 1659 return Basic.__new__(cls, sym, p, success, failure) 1660 1661 @property 1662 def symbol(self): 1663 return self.args[0] 1664 1665 @property 1666 def p(self): 1667 return self.args[1] 1668 1669 @property 1670 def success(self): 1671 return self.args[2] 1672 1673 @property 1674 def failure(self): 1675 return self.args[3] 1676 1677 @property 1678 def state_space(self): 1679 return _set_converter([self.success, self.failure]) 1680 1681 def distribution(self, key=None): 1682 if key is None: 1683 self._deprecation_warn_distribution() 1684 return BernoulliDistribution(self.p) 1685 return BernoulliDistribution(self.p, self.success, self.failure) 1686 1687 def simple_rv(self, rv): 1688 return Bernoulli(rv.name, p=self.p, 1689 succ=self.success, fail=self.failure) 1690 1691 def expectation(self, expr, condition=None, evaluate=True, **kwargs): 1692 """ 1693 Computes expectation. 1694 1695 Parameters 1696 ========== 1697 1698 expr: RandomIndexedSymbol, Relational, Logic 1699 Condition for which expectation has to be computed. Must 1700 contain a RandomIndexedSymbol of the process. 1701 condition: Relational, Logic 1702 The given conditions under which computations should be done. 1703 1704 Returns 1705 ======= 1706 1707 Expectation of the RandomIndexedSymbol. 1708 1709 """ 1710 1711 return _SubstituteRV._expectation(expr, condition, evaluate, **kwargs) 1712 1713 def probability(self, condition, given_condition=None, evaluate=True, **kwargs): 1714 """ 1715 Computes probability. 1716 1717 Parameters 1718 ========== 1719 1720 condition: Relational 1721 Condition for which probability has to be computed. 
Must 1722 contain a RandomIndexedSymbol of the process. 1723 given_condition: Relational/And 1724 The given conditions under which computations should be done. 1725 1726 Returns 1727 ======= 1728 1729 Probability of the condition. 1730 1731 """ 1732 1733 return _SubstituteRV._probability(condition, given_condition, evaluate, **kwargs) 1734 1735 def density(self, x): 1736 return Piecewise((self.p, Eq(x, self.success)), 1737 (1 - self.p, Eq(x, self.failure)), 1738 (S.Zero, True)) 1739 1740class _SubstituteRV: 1741 """ 1742 Internal class to handle the queries of expectation and probability 1743 by substitution. 1744 """ 1745 1746 @staticmethod 1747 def _rvindexed_subs(expr, condition=None): 1748 """ 1749 Substitutes the RandomIndexedSymbol with the RandomSymbol with 1750 same name, distribution and probability as RandomIndexedSymbol. 1751 1752 Parameters 1753 ========== 1754 1755 expr: RandomIndexedSymbol, Relational, Logic 1756 Condition for which expectation has to be computed. Must 1757 contain a RandomIndexedSymbol of the process. 1758 condition: Relational, Logic 1759 The given conditions under which computations should be done. 1760 1761 """ 1762 1763 rvs_expr = random_symbols(expr) 1764 if len(rvs_expr) != 0: 1765 swapdict_expr = {} 1766 for rv in rvs_expr: 1767 if isinstance(rv, RandomIndexedSymbol): 1768 newrv = rv.pspace.process.simple_rv(rv) # substitute with equivalent simple rv 1769 swapdict_expr[rv] = newrv 1770 expr = expr.subs(swapdict_expr) 1771 rvs_cond = random_symbols(condition) 1772 if len(rvs_cond)!=0: 1773 swapdict_cond = {} 1774 for rv in rvs_cond: 1775 if isinstance(rv, RandomIndexedSymbol): 1776 newrv = rv.pspace.process.simple_rv(rv) 1777 swapdict_cond[rv] = newrv 1778 condition = condition.subs(swapdict_cond) 1779 return expr, condition 1780 1781 @classmethod 1782 def _expectation(self, expr, condition=None, evaluate=True, **kwargs): 1783 """ 1784 Internal method for computing expectation of indexed RV. 
1785 1786 Parameters 1787 ========== 1788 1789 expr: RandomIndexedSymbol, Relational, Logic 1790 Condition for which expectation has to be computed. Must 1791 contain a RandomIndexedSymbol of the process. 1792 condition: Relational, Logic 1793 The given conditions under which computations should be done. 1794 1795 Returns 1796 ======= 1797 1798 Expectation of the RandomIndexedSymbol. 1799 1800 """ 1801 new_expr, new_condition = self._rvindexed_subs(expr, condition) 1802 1803 if not is_random(new_expr): 1804 return new_expr 1805 new_pspace = pspace(new_expr) 1806 if new_condition is not None: 1807 new_expr = given(new_expr, new_condition) 1808 if new_expr.is_Add: # As E is Linear 1809 return Add(*[new_pspace.compute_expectation( 1810 expr=arg, evaluate=evaluate, **kwargs) 1811 for arg in new_expr.args]) 1812 return new_pspace.compute_expectation( 1813 new_expr, evaluate=evaluate, **kwargs) 1814 1815 @classmethod 1816 def _probability(self, condition, given_condition=None, evaluate=True, **kwargs): 1817 """ 1818 Internal method for computing probability of indexed RV 1819 1820 Parameters 1821 ========== 1822 1823 condition: Relational 1824 Condition for which probability has to be computed. Must 1825 contain a RandomIndexedSymbol of the process. 1826 given_condition: Relational/And 1827 The given conditions under which computations should be done. 1828 1829 Returns 1830 ======= 1831 1832 Probability of the condition. 
1833 1834 """ 1835 new_condition, new_givencondition = self._rvindexed_subs(condition, given_condition) 1836 1837 if isinstance(new_givencondition, RandomSymbol): 1838 condrv = random_symbols(new_condition) 1839 if len(condrv) == 1 and condrv[0] == new_givencondition: 1840 return BernoulliDistribution(self._probability(new_condition), 0, 1) 1841 1842 if any([dependent(rv, new_givencondition) for rv in condrv]): 1843 return Probability(new_condition, new_givencondition) 1844 else: 1845 return self._probability(new_condition) 1846 1847 if new_givencondition is not None and \ 1848 not isinstance(new_givencondition, (Relational, Boolean)): 1849 raise ValueError("%s is not a relational or combination of relationals" 1850 % (new_givencondition)) 1851 if new_givencondition == False or new_condition == False: 1852 return S.Zero 1853 if new_condition == True: 1854 return S.One 1855 if not isinstance(new_condition, (Relational, Boolean)): 1856 raise ValueError("%s is not a relational or combination of relationals" 1857 % (new_condition)) 1858 1859 if new_givencondition is not None: # If there is a condition 1860 # Recompute on new conditional expr 1861 return self._probability(given(new_condition, new_givencondition, **kwargs), **kwargs) 1862 result = pspace(new_condition).probability(new_condition, **kwargs) 1863 if evaluate and hasattr(result, 'doit'): 1864 return result.doit() 1865 else: 1866 return result 1867 1868def get_timerv_swaps(expr, condition): 1869 """ 1870 Finds the appropriate interval for each time stamp in expr by parsing 1871 the given condition and returns intervals for each timestamp and 1872 dictionary that maps variable time-stamped Random Indexed Symbol to its 1873 corresponding Random Indexed variable with fixed time stamp. 
1874 1875 Parameters 1876 ========== 1877 1878 expr: Sympy Expression 1879 Expression containing Random Indexed Symbols with variable time stamps 1880 condition: Relational/Boolean Expression 1881 Expression containing time bounds of variable time stamps in expr 1882 1883 Examples 1884 ======== 1885 1886 >>> from sympy.stats.stochastic_process_types import get_timerv_swaps, PoissonProcess 1887 >>> from sympy import symbols, Contains, Interval 1888 >>> x, t, d = symbols('x t d', positive=True) 1889 >>> X = PoissonProcess("X", 3) 1890 >>> get_timerv_swaps(x*X(t), Contains(t, Interval.Lopen(0, 1))) 1891 ([Interval.Lopen(0, 1)], {X(t): X(1)}) 1892 >>> get_timerv_swaps((X(t)**2 + X(d)**2), Contains(t, Interval.Lopen(0, 1)) 1893 ... & Contains(d, Interval.Ropen(1, 4))) # doctest: +SKIP 1894 ([Interval.Ropen(1, 4), Interval.Lopen(0, 1)], {X(d): X(3), X(t): X(1)}) 1895 1896 Returns 1897 ======= 1898 1899 intervals: list 1900 List of Intervals/FiniteSet on which each time stamp is defined 1901 rv_swap: dict 1902 Dictionary mapping variable time Random Indexed Symbol to constant time 1903 Random Indexed Variable 1904 1905 """ 1906 1907 if not isinstance(condition, (Relational, Boolean)): 1908 raise ValueError("%s is not a relational or combination of relationals" 1909 % (condition)) 1910 expr_syms = list(expr.atoms(RandomIndexedSymbol)) 1911 if isinstance(condition, (And, Or)): 1912 given_cond_args = condition.args 1913 else: # single condition 1914 given_cond_args = (condition, ) 1915 rv_swap = {} 1916 intervals = [] 1917 for expr_sym in expr_syms: 1918 for arg in given_cond_args: 1919 if arg.has(expr_sym.key) and isinstance(expr_sym.key, Symbol): 1920 intv = _set_converter(arg.args[1]) 1921 diff_key = intv._sup - intv._inf 1922 if diff_key == oo: 1923 raise ValueError("%s should have finite bounds" % str(expr_sym.name)) 1924 elif diff_key == S.Zero: # has singleton set 1925 diff_key = intv._sup 1926 rv_swap[expr_sym] = expr_sym.subs({expr_sym.key: diff_key}) 1927 
intervals.append(intv) 1928 return intervals, rv_swap 1929 1930 1931class CountingProcess(ContinuousTimeStochasticProcess): 1932 """ 1933 This class handles the common methods of the Counting Processes 1934 such as Poisson, Wiener and Gamma Processes 1935 """ 1936 index_set = _set_converter(Interval(0, oo)) 1937 1938 @property 1939 def symbol(self): 1940 return self.args[0] 1941 1942 def expectation(self, expr, condition=None, evaluate=True, **kwargs): 1943 """ 1944 Computes expectation 1945 1946 Parameters 1947 ========== 1948 1949 expr: RandomIndexedSymbol, Relational, Logic 1950 Condition for which expectation has to be computed. Must 1951 contain a RandomIndexedSymbol of the process. 1952 condition: Relational, Boolean 1953 The given conditions under which computations should be done, i.e, 1954 the intervals on which each variable time stamp in expr is defined 1955 1956 Returns 1957 ======= 1958 1959 Expectation of the given expr 1960 1961 """ 1962 if condition is not None: 1963 intervals, rv_swap = get_timerv_swaps(expr, condition) 1964 # they are independent when they have non-overlapping intervals 1965 if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet 1966 for intv_comb in itertools.combinations(intervals, 2)): 1967 if expr.is_Add: 1968 return Add.fromiter(self.expectation(arg, condition) 1969 for arg in expr.args) 1970 expr = expr.subs(rv_swap) 1971 else: 1972 return Expectation(expr, condition) 1973 1974 return _SubstituteRV._expectation(expr, evaluate=evaluate, **kwargs) 1975 1976 def _solve_argwith_tworvs(self, arg): 1977 if arg.args[0].key >= arg.args[1].key or isinstance(arg, Eq): 1978 diff_key = abs(arg.args[0].key - arg.args[1].key) 1979 rv = arg.args[0] 1980 arg = arg.__class__(rv.pspace.process(diff_key), 0) 1981 else: 1982 diff_key = arg.args[1].key - arg.args[0].key 1983 rv = arg.args[1] 1984 arg = arg.__class__(rv.pspace.process(diff_key), 0) 1985 return arg 1986 1987 def _solve_numerical(self, condition, given_condition=None): 
1988 if isinstance(condition, And): 1989 args_list = list(condition.args) 1990 else: 1991 args_list = [condition] 1992 if given_condition is not None: 1993 if isinstance(given_condition, And): 1994 args_list.extend(list(given_condition.args)) 1995 else: 1996 args_list.extend([given_condition]) 1997 # sort the args based on timestamp to get the independent increments in 1998 # each segment using all the condition args as well as given_condition args 1999 args_list = sorted(args_list, key=lambda x: x.args[0].key) 2000 result = [] 2001 cond_args = list(condition.args) if isinstance(condition, And) else [condition] 2002 if args_list[0] in cond_args and not (is_random(args_list[0].args[0]) 2003 and is_random(args_list[0].args[1])): 2004 result.append(_SubstituteRV._probability(args_list[0])) 2005 2006 if is_random(args_list[0].args[0]) and is_random(args_list[0].args[1]): 2007 arg = self._solve_argwith_tworvs(args_list[0]) 2008 result.append(_SubstituteRV._probability(arg)) 2009 2010 for i in range(len(args_list) - 1): 2011 curr, nex = args_list[i], args_list[i + 1] 2012 diff_key = nex.args[0].key - curr.args[0].key 2013 working_set = curr.args[0].pspace.process.state_space 2014 if curr.args[1] > nex.args[1]: #impossible condition so return 0 2015 result.append(0) 2016 break 2017 if isinstance(curr, Eq): 2018 working_set = Intersection(working_set, Interval.Lopen(curr.args[1], oo)) 2019 else: 2020 working_set = Intersection(working_set, curr.as_set()) 2021 if isinstance(nex, Eq): 2022 working_set = Intersection(working_set, Interval(-oo, nex.args[1])) 2023 else: 2024 working_set = Intersection(working_set, nex.as_set()) 2025 if working_set == EmptySet: 2026 rv = Eq(curr.args[0].pspace.process(diff_key), 0) 2027 result.append(_SubstituteRV._probability(rv)) 2028 else: 2029 if working_set.is_finite_set: 2030 if isinstance(curr, Eq) and isinstance(nex, Eq): 2031 rv = Eq(curr.args[0].pspace.process(diff_key), len(working_set)) 2032 
result.append(_SubstituteRV._probability(rv)) 2033 elif isinstance(curr, Eq) ^ isinstance(nex, Eq): 2034 result.append(Add.fromiter(_SubstituteRV._probability(Eq( 2035 curr.args[0].pspace.process(diff_key), x)) 2036 for x in range(len(working_set)))) 2037 else: 2038 n = len(working_set) 2039 result.append(Add.fromiter((n - x)*_SubstituteRV._probability(Eq( 2040 curr.args[0].pspace.process(diff_key), x)) for x in range(n))) 2041 else: 2042 result.append(_SubstituteRV._probability( 2043 curr.args[0].pspace.process(diff_key) <= working_set._sup - working_set._inf)) 2044 return Mul.fromiter(result) 2045 2046 2047 def probability(self, condition, given_condition=None, evaluate=True, **kwargs): 2048 """ 2049 Computes probability. 2050 2051 Parameters 2052 ========== 2053 2054 condition: Relational 2055 Condition for which probability has to be computed. Must 2056 contain a RandomIndexedSymbol of the process. 2057 given_condition: Relational, Boolean 2058 The given conditions under which computations should be done, i.e, 2059 the intervals on which each variable time stamp in expr is defined 2060 2061 Returns 2062 ======= 2063 2064 Probability of the condition 2065 2066 """ 2067 check_numeric = True 2068 if isinstance(condition, (And, Or)): 2069 cond_args = condition.args 2070 else: 2071 cond_args = (condition, ) 2072 # check that condition args are numeric or not 2073 if not all(arg.args[0].key.is_number for arg in cond_args): 2074 check_numeric = False 2075 if given_condition is not None: 2076 check_given_numeric = True 2077 if isinstance(given_condition, (And, Or)): 2078 given_cond_args = given_condition.args 2079 else: 2080 given_cond_args = (given_condition, ) 2081 # check that given condition args are numeric or not 2082 if given_condition.has(Contains): 2083 check_given_numeric = False 2084 # Handle numerical queries 2085 if check_numeric and check_given_numeric: 2086 res = [] 2087 if isinstance(condition, Or): 2088 
res.append(Add.fromiter(self._solve_numerical(arg, given_condition) 2089 for arg in condition.args)) 2090 if isinstance(given_condition, Or): 2091 res.append(Add.fromiter(self._solve_numerical(condition, arg) 2092 for arg in given_condition.args)) 2093 if res: 2094 return Add.fromiter(res) 2095 return self._solve_numerical(condition, given_condition) 2096 2097 # No numeric queries, go by Contains?... then check that all the 2098 # given condition are in form of `Contains` 2099 if not all(arg.has(Contains) for arg in given_cond_args): 2100 raise ValueError("If given condition is passed with `Contains`, then " 2101 "please pass the evaluated condition with its corresponding information " 2102 "in terms of intervals of each time stamp to be passed in given condition.") 2103 2104 intervals, rv_swap = get_timerv_swaps(condition, given_condition) 2105 # they are independent when they have non-overlapping intervals 2106 if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet 2107 for intv_comb in itertools.combinations(intervals, 2)): 2108 if isinstance(condition, And): 2109 return Mul.fromiter(self.probability(arg, given_condition) 2110 for arg in condition.args) 2111 elif isinstance(condition, Or): 2112 return Add.fromiter(self.probability(arg, given_condition) 2113 for arg in condition.args) 2114 condition = condition.subs(rv_swap) 2115 else: 2116 return Probability(condition, given_condition) 2117 if check_numeric: 2118 return self._solve_numerical(condition) 2119 return _SubstituteRV._probability(condition, evaluate=evaluate, **kwargs) 2120 2121class PoissonProcess(CountingProcess): 2122 """ 2123 The Poisson process is a counting process. It is usually used in scenarios 2124 where we are counting the occurrences of certain events that appear 2125 to happen at a certain rate, but completely at random. 
2126 2127 Parameters 2128 ========== 2129 2130 sym: Symbol/str 2131 lamda: Positive number 2132 Rate of the process, ``lamda > 0`` 2133 2134 Examples 2135 ======== 2136 2137 >>> from sympy.stats import PoissonProcess, P, E 2138 >>> from sympy import symbols, Eq, Ne, Contains, Interval 2139 >>> X = PoissonProcess("X", lamda=3) 2140 >>> X.state_space 2141 Naturals0 2142 >>> X.lamda 2143 3 2144 >>> t1, t2 = symbols('t1 t2', positive=True) 2145 >>> P(X(t1) < 4) 2146 (9*t1**3/2 + 9*t1**2/2 + 3*t1 + 1)*exp(-3*t1) 2147 >>> P(Eq(X(t1), 2) | Ne(X(t1), 4), Contains(t1, Interval.Ropen(2, 4))) 2148 1 - 36*exp(-6) 2149 >>> P(Eq(X(t1), 2) & Eq(X(t2), 3), Contains(t1, Interval.Lopen(0, 2)) 2150 ... & Contains(t2, Interval.Lopen(2, 4))) 2151 648*exp(-12) 2152 >>> E(X(t1)) 2153 3*t1 2154 >>> E(X(t1)**2 + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) 2155 ... & Contains(t2, Interval.Lopen(1, 2))) 2156 18 2157 >>> P(X(3) < 1, Eq(X(1), 0)) 2158 exp(-6) 2159 >>> P(Eq(X(4), 3), Eq(X(2), 3)) 2160 exp(-6) 2161 >>> P(X(2) <= 3, X(1) > 1) 2162 5*exp(-3) 2163 2164 Merging two Poisson Processes 2165 2166 >>> Y = PoissonProcess("Y", lamda=4) 2167 >>> Z = X + Y 2168 >>> Z.lamda 2169 7 2170 2171 Splitting a Poisson Process into two independent Poisson Processes 2172 2173 >>> N, M = Z.split(l1=2, l2=5) 2174 >>> N.lamda, M.lamda 2175 (2, 5) 2176 2177 References 2178 ========== 2179 2180 .. [1] https://www.probabilitycourse.com/chapter11/11_0_0_intro.php 2181 .. 
[2] https://en.wikipedia.org/wiki/Poisson_point_process 2182 2183 """ 2184 2185 def __new__(cls, sym, lamda): 2186 _value_check(lamda > 0, 'lamda should be a positive number.') 2187 sym = _symbol_converter(sym) 2188 lamda = _sympify(lamda) 2189 return Basic.__new__(cls, sym, lamda) 2190 2191 @property 2192 def lamda(self): 2193 return self.args[1] 2194 2195 @property 2196 def state_space(self): 2197 return S.Naturals0 2198 2199 def distribution(self, key): 2200 if isinstance(key, RandomIndexedSymbol): 2201 self._deprecation_warn_distribution() 2202 return PoissonDistribution(self.lamda*key.key) 2203 return PoissonDistribution(self.lamda*key) 2204 2205 def density(self, x): 2206 return (self.lamda*x.key)**x / factorial(x) * exp(-(self.lamda*x.key)) 2207 2208 def simple_rv(self, rv): 2209 return Poisson(rv.name, lamda=self.lamda*rv.key) 2210 2211 def __add__(self, other): 2212 if not isinstance(other, PoissonProcess): 2213 raise ValueError("Only instances of Poisson Process can be merged") 2214 return PoissonProcess(Dummy(self.symbol.name + other.symbol.name), 2215 self.lamda + other.lamda) 2216 2217 def split(self, l1, l2): 2218 if _sympify(l1 + l2) != self.lamda: 2219 raise ValueError("Sum of l1 and l2 should be %s" % str(self.lamda)) 2220 return PoissonProcess(Dummy("l1"), l1), PoissonProcess(Dummy("l2"), l2) 2221 2222class WienerProcess(CountingProcess): 2223 """ 2224 The Wiener process is a real valued continuous-time stochastic process. 2225 In physics it is used to study Brownian motion and therefore also known as 2226 Brownian Motion. 
2227 2228 Parameters 2229 ========== 2230 2231 sym: Symbol/str 2232 2233 Examples 2234 ======== 2235 2236 >>> from sympy.stats import WienerProcess, P, E 2237 >>> from sympy import symbols, Contains, Interval 2238 >>> X = WienerProcess("X") 2239 >>> X.state_space 2240 Reals 2241 >>> t1, t2 = symbols('t1 t2', positive=True) 2242 >>> P(X(t1) < 7).simplify() 2243 erf(7*sqrt(2)/(2*sqrt(t1)))/2 + 1/2 2244 >>> P((X(t1) > 2) | (X(t1) < 4), Contains(t1, Interval.Ropen(2, 4))).simplify() 2245 -erf(1)/2 + erf(2)/2 + 1 2246 >>> E(X(t1)) 2247 0 2248 >>> E(X(t1) + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) 2249 ... & Contains(t2, Interval.Lopen(1, 2))) 2250 0 2251 2252 References 2253 ========== 2254 2255 .. [1] https://www.probabilitycourse.com/chapter11/11_4_0_brownian_motion_wiener_process.php 2256 .. [2] https://en.wikipedia.org/wiki/Wiener_process 2257 2258 """ 2259 def __new__(cls, sym): 2260 sym = _symbol_converter(sym) 2261 return Basic.__new__(cls, sym) 2262 2263 @property 2264 def state_space(self): 2265 return S.Reals 2266 2267 def distribution(self, key): 2268 if isinstance(key, RandomIndexedSymbol): 2269 self._deprecation_warn_distribution() 2270 return NormalDistribution(0, sqrt(key.key)) 2271 return NormalDistribution(0, sqrt(key)) 2272 2273 def density(self, x): 2274 return exp(-x**2/(2*x.key)) / (sqrt(2*pi)*sqrt(x.key)) 2275 2276 def simple_rv(self, rv): 2277 return Normal(rv.name, 0, sqrt(rv.key)) 2278 2279 2280class GammaProcess(CountingProcess): 2281 """ 2282 A Gamma process is a random process with independent gamma distributed 2283 increments. It is a pure-jump increasing Levy process. 
2284 2285 Parameters 2286 ========== 2287 2288 sym: Symbol/str 2289 lamda: Positive number 2290 Jump size of the process, ``lamda > 0`` 2291 gamma: Positive number 2292 Rate of jump arrivals, ``gamma > 0`` 2293 2294 Examples 2295 ======== 2296 2297 >>> from sympy.stats import GammaProcess, E, P, variance 2298 >>> from sympy import symbols, Contains, Interval, Not 2299 >>> t, d, x, l, g = symbols('t d x l g', positive=True) 2300 >>> X = GammaProcess("X", l, g) 2301 >>> E(X(t)) 2302 g*t/l 2303 >>> variance(X(t)).simplify() 2304 g*t/l**2 2305 >>> X = GammaProcess('X', 1, 2) 2306 >>> P(X(t) < 1).simplify() 2307 lowergamma(2*t, 1)/gamma(2*t) 2308 >>> P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & 2309 ... Contains(d, Interval.Lopen(7, 8))).simplify() 2310 -4*exp(-3) + 472*exp(-8)/3 + 1 2311 >>> E(X(2) + x*E(X(5))) 2312 10*x + 4 2313 2314 References 2315 ========== 2316 2317 .. [1] https://en.wikipedia.org/wiki/Gamma_process 2318 2319 """ 2320 def __new__(cls, sym, lamda, gamma): 2321 _value_check(lamda > 0, 'lamda should be a positive number') 2322 _value_check(gamma > 0, 'gamma should be a positive number') 2323 sym = _symbol_converter(sym) 2324 gamma = _sympify(gamma) 2325 lamda = _sympify(lamda) 2326 return Basic.__new__(cls, sym, lamda, gamma) 2327 2328 @property 2329 def lamda(self): 2330 return self.args[1] 2331 2332 @property 2333 def gamma(self): 2334 return self.args[2] 2335 2336 @property 2337 def state_space(self): 2338 return _set_converter(Interval(0, oo)) 2339 2340 def distribution(self, key): 2341 if isinstance(key, RandomIndexedSymbol): 2342 self._deprecation_warn_distribution() 2343 return GammaDistribution(self.gamma*key.key, 1/self.lamda) 2344 return GammaDistribution(self.gamma*key, 1/self.lamda) 2345 2346 def density(self, x): 2347 k = self.gamma*x.key 2348 theta = 1/self.lamda 2349 return x**(k - 1) * exp(-x/theta) / (gamma(k)*theta**k) 2350 2351 def simple_rv(self, rv): 2352 return Gamma(rv.name, self.gamma*rv.key, 
1/self.lamda) 2353