1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import scipy.stats as osp_stats
16
17from jax import lax
18from jax._src.numpy.util import _wraps
19from jax._src.numpy.lax_numpy import (_promote_args_inexact, _constant_like,
20                                      where, inf)
21from jax.scipy.special import gammaln
22
23
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
def logpdf(x, a, loc=0, scale=1):
  """Log of the gamma probability density, mirroring scipy.stats.gamma.logpdf.

  With ``y = (x - loc) / scale`` the result is
  ``(a - 1) * log(y) - y - gammaln(a) - log(scale)``, and ``-inf`` wherever
  ``x < loc``.

  Args:
    x: point(s) at which to evaluate the log-density.
    a: shape parameter of the distribution.
    loc: location (shift) parameter; defaults to 0.
    scale: scale parameter; defaults to 1.

  Returns:
    The elementwise log-density, computed after the arguments have been
    promoted by ``_promote_args_inexact``.
  """
  # Function-scope import keeps this block self-contained; jax.scipy.special
  # is already imported by this module for gammaln.
  from jax.scipy.special import xlogy

  x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  one = _constant_like(x, 1)
  below_support = lax.lt(x, loc)
  # Left of the support, evaluate the formula at the harmless point y = 1 so
  # log() never sees a negative argument. The value is discarded by the final
  # `where`, but a nan there would still poison gradients.
  y = where(below_support, one, lax.div(lax.sub(x, loc), scale))
  # xlogy(0, 0) is 0, so a == 1 at x == loc yields the finite limit
  # -log(scale) instead of the nan produced by 0 * log(0).
  log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
  shape_terms = lax.add(gammaln(a), lax.log(scale))
  log_probs = lax.sub(log_linear_term, shape_terms)
  return where(below_support, -inf, log_probs)
33
@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x, a, loc=0, scale=1):
  """Gamma probability density, obtained by exponentiating ``logpdf``.

  Args:
    x: point(s) at which to evaluate the density.
    a: shape parameter of the distribution.
    loc: location (shift) parameter; defaults to 0.
    scale: scale parameter; defaults to 1.

  Returns:
    ``exp(logpdf(x, a, loc, scale))``.
  """
  log_density = logpdf(x, a, loc, scale)
  return lax.exp(log_density)
37