1from sympy import Basic, Sum, Dummy, Lambda, Integral
2from sympy.stats.rv import (NamedArgsMixin, random_symbols, _symbol_converter,
3                        PSpace, RandomSymbol, is_random, Distribution)
4from sympy.stats.crv import ContinuousDistribution, SingleContinuousPSpace
5from sympy.stats.drv import DiscreteDistribution, SingleDiscretePSpace
6from sympy.stats.frv import SingleFiniteDistribution, SingleFinitePSpace
7from sympy.stats.crv_types import ContinuousDistributionHandmade
8from sympy.stats.drv_types import DiscreteDistributionHandmade
9from sympy.stats.frv_types import FiniteDistributionHandmade
10
11
class CompoundPSpace(PSpace):
    """
    A temporary Probability Space for the Compound Distribution. After
    Marginalization, this returns the corresponding Probability Space of the
    parent distribution.

    Every ``compute_*``/``probability``/``conditional_space`` call builds a
    fresh single-variable probability space from the marginalised density
    (see ``_get_newpspace``), substitutes this space's random symbol with the
    new space's random symbol, and delegates to the new space.
    """

    def __new__(cls, s, distribution):
        """
        Create the probability space for ``distribution`` on symbol ``s``.

        Continuous, discrete and single finite distributions are returned as
        the matching ``Single*PSpace`` directly; only a
        ``CompoundDistribution`` produces a ``CompoundPSpace``.

        Raises
        ======

        ValueError
            If ``distribution`` is none of the accepted distribution types.
        """
        s = _symbol_converter(s)
        if isinstance(distribution, ContinuousDistribution):
            return SingleContinuousPSpace(s, distribution)
        if isinstance(distribution, DiscreteDistribution):
            return SingleDiscretePSpace(s, distribution)
        if isinstance(distribution, SingleFiniteDistribution):
            return SingleFinitePSpace(s, distribution)
        if not isinstance(distribution, CompoundDistribution):
            # Wrap in a one-element tuple: with a bare ``% distribution`` a
            # tuple-valued argument would be unpacked by the ``%`` operator
            # and raise TypeError instead of the intended ValueError.
            raise ValueError("%s should be an isinstance of "
                        "CompoundDistribution" % (distribution,))
        return Basic.__new__(cls, s, distribution)

    @property
    def value(self):
        """Random symbol built from this space's symbol and the space itself."""
        return RandomSymbol(self.symbol, self)

    @property
    def symbol(self):
        """First constructor argument (the converted symbol)."""
        return self.args[0]

    @property
    def is_Continuous(self):
        return self.distribution.is_Continuous

    @property
    def is_Finite(self):
        return self.distribution.is_Finite

    @property
    def is_Discrete(self):
        return self.distribution.is_Discrete

    @property
    def distribution(self):
        """Second constructor argument (the compound distribution)."""
        return self.args[1]

    @property
    def pdf(self):
        """Density of the compound distribution evaluated at ``self.symbol``."""
        return self.distribution.pdf(self.symbol)

    @property
    def set(self):
        return self.distribution.set

    @property
    def domain(self):
        """Domain of the marginalised (unevaluated) probability space."""
        return self._get_newpspace().domain

    def _get_newpspace(self, evaluate=False):
        """
        Build the single-variable probability space for the marginalised
        density.

        Parameters
        ==========

        evaluate : bool
            Forwarded to the compound distribution's ``pdf``.

        Raises
        ======

        NotImplementedError
            If ``_transform_pspace`` does not support the parent distribution.
        """
        x = Dummy('x')
        parent_dist = self.distribution.args[0]
        func = Lambda(x, self.distribution.pdf(x, evaluate))
        new_pspace = self._transform_pspace(self.symbol, parent_dist, func)
        if new_pspace is not None:
            return new_pspace
        message = ("Compound Distribution for %s is not implemented yet" % str(parent_dist))
        raise NotImplementedError(message)

    def _transform_pspace(self, sym, dist, pdf):
        """
        This function returns the new pspace of the distribution using handmade
        Distributions and their corresponding pspace.

        Returns ``None`` when ``dist`` is not a continuous, discrete or single
        finite distribution.
        """
        pdf = Lambda(sym, pdf(sym))
        _set = dist.set
        if isinstance(dist, ContinuousDistribution):
            return SingleContinuousPSpace(sym, ContinuousDistributionHandmade(pdf, _set))
        elif isinstance(dist, DiscreteDistribution):
            return SingleDiscretePSpace(sym, DiscreteDistributionHandmade(pdf, _set))
        elif isinstance(dist, SingleFiniteDistribution):
            # Finite case: tabulate the density over every element of the set.
            dens = {k: pdf(k) for k in _set}
            return SingleFinitePSpace(sym, FiniteDistributionHandmade(dens))
        # Unsupported parent distribution; the caller turns this into an error.
        return None

    def compute_density(self, expr, *, compound_evaluate=True, **kwargs):
        """
        Density of ``expr``, computed in the marginalised probability space.

        ``compound_evaluate`` is passed to ``_get_newpspace``; remaining
        keyword arguments are forwarded to the new space's
        ``compute_density``.
        """
        new_pspace = self._get_newpspace(compound_evaluate)
        expr = expr.subs({self.value: new_pspace.value})
        return new_pspace.compute_density(expr, **kwargs)

    def compute_cdf(self, expr, *, compound_evaluate=True, **kwargs):
        """
        CDF of ``expr``, computed in the marginalised probability space.

        ``compound_evaluate`` is passed to ``_get_newpspace``; remaining
        keyword arguments are forwarded to the new space's ``compute_cdf``.
        """
        new_pspace = self._get_newpspace(compound_evaluate)
        expr = expr.subs({self.value: new_pspace.value})
        return new_pspace.compute_cdf(expr, **kwargs)

    def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
        """
        Expectation of ``expr``, computed in the marginalised probability
        space.

        Parameters
        ==========

        expr : expression in this space's random symbol
        rvs : optional
            Random symbols to take the expectation over. Either an object
            providing ``subs`` or a plain iterable of such objects.
        evaluate : bool
            Passed to ``_get_newpspace``; also forwarded to the new space's
            ``compute_expectation`` unless that space is a
            ``SingleFinitePSpace``.
        """
        new_pspace = self._get_newpspace(evaluate)
        swap = {self.value: new_pspace.value}
        expr = expr.subs(swap)
        if rvs:
            if hasattr(rvs, 'subs'):
                rvs = rvs.subs(swap)
            else:
                # Plain Python containers (tuple, list, frozenset, ...) have
                # no ``subs``; substitute element-wise instead of failing
                # with AttributeError.
                rvs = tuple(rv.subs(swap) for rv in rvs)
        if isinstance(new_pspace, SingleFinitePSpace):
            # The finite space is called without the positional ``evaluate``.
            return new_pspace.compute_expectation(expr, rvs, **kwargs)
        return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs)

    def probability(self, condition, *, compound_evaluate=True, **kwargs):
        """
        Probability of ``condition`` in the marginalised probability space.

        Extra keyword arguments are accepted but not forwarded.
        """
        new_pspace = self._get_newpspace(compound_evaluate)
        condition = condition.subs({self.value: new_pspace.value})
        return new_pspace.probability(condition)

    def conditional_space(self, condition, *, compound_evaluate=True, **kwargs):
        """
        Conditional space for ``condition`` in the marginalised probability
        space.

        Extra keyword arguments are accepted but not forwarded.
        """
        new_pspace = self._get_newpspace(compound_evaluate)
        condition = condition.subs({self.value: new_pspace.value})
        return new_pspace.conditional_space(condition)
121
122
class CompoundDistribution(Distribution, NamedArgsMixin):
    """
    Class for Compound Distributions.

    Parameters
    ==========

    dist : Distribution
        Distribution must contain a random parameter

    Examples
    ========

    >>> from sympy.stats.compound_rv import CompoundDistribution
    >>> from sympy.stats.crv_types import NormalDistribution
    >>> from sympy.stats import Normal
    >>> from sympy.abc import x
    >>> X = Normal('X', 2, 4)
    >>> N = NormalDistribution(X, 4)
    >>> C = CompoundDistribution(N)
    >>> C.set
    Interval(-oo, oo)
    >>> C.pdf(x, evaluate=True).simplify()
    exp(-x**2/64 + x/16 - 1/16)/(8*sqrt(pi))

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Compound_probability_distribution

    """

    def __new__(cls, dist):
        """
        Wrap ``dist`` as a compound distribution.

        If ``dist`` has no random parameters it is returned unchanged.

        Raises
        ======

        NotImplementedError
            If ``dist`` is not a continuous, single finite or discrete
            distribution.
        """
        if not isinstance(dist, (ContinuousDistribution,
                SingleFiniteDistribution, DiscreteDistribution)):
            message = "Compound Distribution for %s is not implemented yet" % str(dist)
            raise NotImplementedError(message)
        if not cls._compound_check(dist):
            # Nothing to marginalise: hand back the plain distribution.
            return dist
        return Basic.__new__(cls, dist)

    @property
    def set(self):
        """Set of the wrapped (parent) distribution."""
        return self.args[0].set

    @property
    def is_Continuous(self):
        return isinstance(self.args[0], ContinuousDistribution)

    @property
    def is_Finite(self):
        return isinstance(self.args[0], SingleFiniteDistribution)

    @property
    def is_Discrete(self):
        return isinstance(self.args[0], DiscreteDistribution)

    def pdf(self, x, evaluate=False):
        """
        Density of the compound distribution at ``x``.

        The parent distribution's density (``pmf`` for a single finite
        distribution, ``pdf`` otherwise) is marginalised over each of its
        arguments for which ``is_random`` is true.

        Parameters
        ==========

        x : point at which the density is evaluated
        evaluate : bool
            If true, each marginalising sum/integral is evaluated with
            ``doit``; otherwise it is left unevaluated.
        """
        dist = self.args[0]
        # NOTE(review): each selected argument is later accessed through
        # ``rv.pspace`` in ``_marginalise`` -- presumably arguments are bare
        # random symbols rather than compound expressions; confirm.
        randoms = [rv for rv in dist.args if is_random(rv)]
        if isinstance(dist, SingleFiniteDistribution):
            y = Dummy('y', integer=True, negative=False)
            expr = dist.pmf(y)
        else:
            y = Dummy('y')
            expr = dist.pdf(y)
        for rv in randoms:
            expr = self._marginalise(expr, rv, evaluate)
        return Lambda(y, expr)(x)

    def _marginalise(self, expr, rv, evaluate):
        """
        Marginalise ``expr`` over the random parameter ``rv``.

        ``expr`` is multiplied by the density of ``rv`` and summed (discrete
        or finite ``rv``) or integrated (otherwise) between the infimum and
        supremum of the domain set of ``rv``. The result is evaluated with
        ``doit`` only when ``evaluate`` is true.
        """
        if isinstance(rv.pspace.distribution, SingleFiniteDistribution):
            rv_dens = rv.pspace.distribution.pmf(rv)
        else:
            rv_dens = rv.pspace.distribution.pdf(rv)
        rv_dom = rv.pspace.domain.set
        if rv.pspace.is_Discrete or rv.pspace.is_Finite:
            expr = Sum(expr*rv_dens, (rv, rv_dom._inf,
                    rv_dom._sup))
        else:
            expr = Integral(expr*rv_dens, (rv, rv_dom._inf,
                    rv_dom._sup))
        if evaluate:
            return expr.doit()
        return expr

    @classmethod
    def _compound_check(cls, dist):
        """
        Checks if the given distribution contains random parameters.

        Returns ``True`` as soon as any argument of ``dist`` yields at least
        one random symbol, ``False`` otherwise.
        """
        return any(random_symbols(arg) for arg in dist.args)
220