1# Copyright 2020 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import scipy.stats as osp_stats 16from jax.scipy.special import expit, logit 17 18from jax import lax 19from jax._src.numpy.util import _wraps 20 21 22@_wraps(osp_stats.logistic.logpdf, update_doc=False) 23def logpdf(x): 24 return lax.neg(x) - 2. * lax.log1p(lax.exp(lax.neg(x))) 25 26@_wraps(osp_stats.logistic.pdf, update_doc=False) 27def pdf(x): 28 return lax.exp(logpdf(x)) 29 30@_wraps(osp_stats.logistic.ppf, update_doc=False) 31def ppf(x): 32 return logit(x) 33 34@_wraps(osp_stats.logistic.sf, update_doc=False) 35def sf(x): 36 return expit(lax.neg(x)) 37 38@_wraps(osp_stats.logistic.isf, update_doc=False) 39def isf(x): 40 return -logit(x) 41 42@_wraps(osp_stats.logistic.cdf, update_doc=False) 43def cdf(x): 44 return expit(x) 45