1"""Truncated exponential distribution."""
2import numpy
3from scipy import special
4
5from ..baseclass import SimpleDistribution, ShiftScaleDistribution
6
7
class truncexpon(SimpleDistribution):
    """
    Truncated exponential distribution.

    Standardized kernel: a unit-rate exponential distribution whose support
    is truncated to the interval ``[0, b]``.

    Args:
        b:
            Upper truncation point of the standardized support
            (must presumably be positive for a valid distribution).
    """

    def __init__(self, b):
        super(truncexpon, self).__init__(dict(b=b))

    def _pdf(self, x, b):
        # Normalising constant 1-exp(-b) is written as -expm1(-b): it is
        # algebraically identical but does not lose precision to
        # cancellation when b is small.
        return numpy.exp(-x)/(-numpy.expm1(-b))

    def _cdf(self, x, b):
        # (1-exp(-x))/(1-exp(-b)) == expm1(-x)/expm1(-b); the expm1 form
        # stays accurate for small x and b.
        return numpy.expm1(-x)/numpy.expm1(-b)

    def _ppf(self, q, b):
        # Inverse of the CDF. 1-q+q*exp(-b) == 1+q*expm1(-b), so log1p
        # keeps full precision when q*(1-exp(-b)) is tiny.
        return -numpy.log1p(q*numpy.expm1(-b))

    def _lower(self, b):
        # Support starts at zero.
        return 0.

    def _upper(self, b):
        # Support ends at the truncation point.
        return b
28
29
class TruncExponential(ShiftScaleDistribution):
    """
    Truncated exponential distribution.

    An exponential distribution with support cut off at ``upper``. It is
    shifted by ``shift`` and stretched by ``scale``.

    Args:
        upper (float, Distribution):
            Position of the upper truncation threshold, given in the
            original (unscaled, unshifted) coordinates.
        scale (float, Distribution):
            Scale factor applied to the exponential distribution.
        shift (float, Distribution):
            Offset (location) of the distribution.

    Examples:
        >>> distribution = chaospy.TruncExponential(1.5)
        >>> distribution
        TruncExponential(1.5)
        >>> uloc = numpy.linspace(0, 1, 6)
        >>> uloc
        array([0. , 0.2, 0.4, 0.6, 0.8, 1. ])
        >>> xloc = distribution.inv(uloc)
        >>> xloc.round(3)
        array([0.   , 0.169, 0.372, 0.628, 0.972, 1.5  ])
        >>> numpy.allclose(distribution.fwd(xloc), uloc)
        True
        >>> distribution.pdf(xloc).round(3)
        array([1.287, 1.087, 0.887, 0.687, 0.487, 0.287])
        >>> distribution.sample(4).round(3)
        array([0.709, 0.094, 1.34 , 0.469])

    """

    def __init__(self, upper=1, scale=1, shift=0):
        # Express the user-facing threshold in the coordinates of the
        # standardized kernel, i.e. undo the shift and the scale.
        standardized_upper = (upper-shift)*1./scale
        kernel = truncexpon(standardized_upper)
        super(TruncExponential, self).__init__(
            dist=kernel,
            scale=scale,
            shift=shift,
            repr_args=[upper],
        )
68