#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import warnings

import numpy as np
import theano.tensor as tt

from scipy.special import logit as nplogit

from pymc3.distributions import distribution
from pymc3.distributions.distribution import draw_values
from pymc3.math import invlogit, logit, logsumexp
from pymc3.model import FreeRV
from pymc3.theanof import floatX, gradient

__all__ = [
    "Transform",
    "transform",
    "stick_breaking",
    "logodds",
    "interval",
    "log_exp_m1",
    "lowerbound",
    "upperbound",
    "ordered",
    "log",
    "sum_to_1",
    "circular",
    "CholeskyCovPacked",
    "Chain",
]


class Transform:
    """A transformation of a random variable from one space into another.

    Attributes
    ----------
    name: str
    """

    name = ""

    def forward(self, x):
        """Applies transformation forward to input variable `x`.
        When transform is used on some distribution `p`, it will transform the random variable `x` after sampling
        from `p`.

        Parameters
        ----------
        x: tensor
            Input tensor to be transformed.

        Returns
        -------
        tensor
            Transformed tensor.
        """
        raise NotImplementedError

    def forward_val(self, x, point=None):
        """Applies transformation forward to input array `x`.
        Similar to `forward` but for constant data.

        Parameters
        ----------
        x: array_like
            Input array to be transformed.
        point: array_like, optional
            Test value used to draw (fix) bounds-like transformations.
            Defaults to None.

        Returns
        -------
        array_like
            Transformed array.
        """
        raise NotImplementedError

    def backward(self, z):
        """Applies inverse of transformation to input variable `z`.
        When transform is used on some distribution `p`, which has observed values `z`, it is used to
        transform the values of `z` correctly to the support of `p`.

        Parameters
        ----------
        z: tensor
            Input tensor to be inverse transformed.

        Returns
        -------
        tensor
            Inverse transformed tensor.
        """
        raise NotImplementedError

    def jacobian_det(self, x):
        """Calculates logarithm of the absolute value of the Jacobian determinant
        of the backward transformation for input `x`.

        Parameters
        ----------
        x: tensor
            Input to calculate Jacobian determinant of.

        Returns
        -------
        tensor
            The log abs Jacobian determinant of `x` w.r.t. this transform.
        """
        raise NotImplementedError

    def apply(self, dist):
        """Wrap `dist` in a ``TransformedDistribution`` that uses this transform."""
        # avoid circular import
        return TransformedDistribution.dist(dist, self)

    def __str__(self):
        return self.name + " transform"


class ElemwiseTransform(Transform):
    """Base class for element-wise transforms.

    The default ``jacobian_det`` differentiates ``sum(backward(x))`` with
    respect to ``x``. For an element-wise ``backward`` this gradient equals the
    diagonal of the Jacobian, so its log-abs value is the per-element
    log-Jacobian.
    """

    def jacobian_det(self, x):
        grad = tt.reshape(gradient(tt.sum(self.backward(x)), [x]), x.shape)
        return tt.log(tt.abs_(grad))


class TransformedDistribution(distribution.Distribution):
    """A distribution that has been transformed from one space into another."""

    def __init__(self, dist, transform, *args, **kwargs):
        """
        Parameters
        ----------
        dist: Distribution
        transform: Transform
        args, kwargs
            arguments to Distribution"""
        forward = transform.forward
        testval = forward(dist.default())

        self.dist = dist
        self.transform_used = transform
        v = forward(FreeRV(name="v", distribution=dist))
        self.type = v.type

        super().__init__(v.shape.tag.test_value, v.dtype, testval, dist.defaults, *args, **kwargs)

        if transform.name == "stickbreaking":
            b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
            # force the last dim not broadcastable
            self.type = tt.TensorType(v.dtype, b)

    def logp(self, x):
        """
        Calculate log-probability of Transformed distribution at specified value.

        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        logp_nojac = self.logp_nojac(x)
        jacobian_det = self.transform_used.jacobian_det(x)
        # Vector-valued transforms return one Jacobian term per vector, so the
        # element-wise logp is reduced over the last axis to match.
        if logp_nojac.ndim > jacobian_det.ndim:
            logp_nojac = logp_nojac.sum(axis=-1)
        return logp_nojac + jacobian_det

    def logp_nojac(self, x):
        """
        Calculate log-probability of Transformed distribution at specified value
        without jacobian term for transforms.

        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        return self.dist.logp(self.transform_used.backward(x))

    def _repr_latex_(self, **kwargs):
        # prevent TransformedDistributions from ending up in LaTeX representations
        # of models
        return None

    def _distr_parameters_for_repr(self):
        return []


transform = Transform


class Log(ElemwiseTransform):
    """Map positive values to the real line with ``log``; ``backward`` is ``exp``."""

    name = "log"

    def backward(self, x):
        return tt.exp(x)

    def forward(self, x):
        return tt.log(x)

    def forward_val(self, x, point=None):
        return np.log(x)

    def jacobian_det(self, x):
        # log|d/dx exp(x)| = x
        return x


log = Log()


class LogExpM1(ElemwiseTransform):
    """Inverse-softplus transform: ``forward`` is log(exp(x) - 1), ``backward`` is softplus."""

    name = "log_exp_m1"

    def backward(self, x):
        return tt.nnet.softplus(x)

    def forward(self, x):
        """Inverse operation of softplus.

        y = Log(Exp(x) - 1)
          = Log(1 - Exp(-x)) + x
        """
        return tt.log(1.0 - tt.exp(-x)) + x

    def forward_val(self, x, point=None):
        return np.log(1.0 - np.exp(-x)) + x

    def jacobian_det(self, x):
        # d/dx softplus(x) = sigmoid(x) and log(sigmoid(x)) = -softplus(-x)
        return -tt.nnet.softplus(-x)


log_exp_m1 = LogExpM1()


class LogOdds(ElemwiseTransform):
    """Map values in (0, 1) to the real line with ``logit``; ``backward`` is ``invlogit``."""

    name = "logodds"

    def backward(self, x):
        return invlogit(x, 0.0)

    def forward(self, x):
        return logit(x)

    def forward_val(self, x, point=None):
        return nplogit(x)


logodds = LogOdds()


class Interval(ElemwiseTransform):
    """Transform from real line interval [a,b] to whole real line."""

    name = "interval"

    def __init__(self, a, b):
        self.a = tt.as_tensor_variable(a)
        self.b = tt.as_tensor_variable(b)

    def backward(self, x):
        a, b = self.a, self.b
        sigmoid_x = tt.nnet.sigmoid(x)
        r = sigmoid_x * b + (1 - sigmoid_x) * a
        return r

    def forward(self, x):
        a, b = self.a, self.b
        return tt.log(x - a) - tt.log(b - x)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.a-0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        a, b = draw_values([self.a - 0.0, self.b - 0.0], point=point)
        return floatX(np.log(x - a) - np.log(b - x))

    def jacobian_det(self, x):
        # log((b - a) * sigmoid(x) * sigmoid(-x)), rewritten with softplus for
        # numerical stability.
        s = tt.nnet.softplus(-x)
        return tt.log(self.b - self.a) - 2 * s - x


interval = Interval


class LowerBound(ElemwiseTransform):
    """Transform from real line interval [a,inf] to whole real line."""

    name = "lowerbound"

    def __init__(self, a):
        self.a = tt.as_tensor_variable(a)

    def backward(self, x):
        a = self.a
        r = tt.exp(x) + a
        return r

    def forward(self, x):
        a = self.a
        return tt.log(x - a)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.a-0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        a = draw_values([self.a - 0.0], point=point)[0]
        return floatX(np.log(x - a))

    def jacobian_det(self, x):
        return x


lowerbound = LowerBound
"""
Alias for ``LowerBound`` (:class: LowerBound) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""


class UpperBound(ElemwiseTransform):
    """Transform from real line interval [-inf,b] to whole real line."""

    name = "upperbound"

    def __init__(self, b):
        self.b = tt.as_tensor_variable(b)

    def backward(self, x):
        b = self.b
        r = b - tt.exp(x)
        return r

    def forward(self, x):
        b = self.b
        return tt.log(b - x)

    def forward_val(self, x, point=None):
        # 2017-06-19
        # the `self.b-0.` below is important for the testval to propagate
        # For an explanation see pull/2328#issuecomment-309303811
        b = draw_values([self.b - 0.0], point=point)[0]
        return floatX(np.log(b - x))

    def jacobian_det(self, x):
        return x


upperbound = UpperBound
"""
Alias for ``UpperBound`` (:class: UpperBound) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""


class Ordered(Transform):
    """Map vectors increasing along the last axis to unconstrained vectors.

    ``forward`` keeps the first element and replaces each later element by the
    log of its difference from the previous one; ``backward`` inverts this with
    ``exp`` followed by a cumulative sum along the last axis.
    """

    name = "ordered"

    def backward(self, y):
        x = tt.zeros(y.shape)
        x = tt.inc_subtensor(x[..., 0], y[..., 0])
        x = tt.inc_subtensor(x[..., 1:], tt.exp(y[..., 1:]))
        return tt.cumsum(x, axis=-1)

    def forward(self, x):
        y = tt.zeros(x.shape)
        y = tt.inc_subtensor(y[..., 0], x[..., 0])
        y = tt.inc_subtensor(y[..., 1:], tt.log(x[..., 1:] - x[..., :-1]))
        return y

    def forward_val(self, x, point=None):
        y = np.zeros_like(x)
        y[..., 0] = x[..., 0]
        y[..., 1:] = np.log(x[..., 1:] - x[..., :-1])
        return y

    def jacobian_det(self, y):
        return tt.sum(y[..., 1:], axis=-1)


ordered = Ordered()
"""
Instantiation of ``Ordered`` (:class: Ordered) Transform (:class: Transform) class
for use in the ``transform`` argument of a random variable.
"""


class SumTo1(Transform):
    """
    Transforms K - 1 dimensional simplex space (K values in [0,1] that sum to 1) to a K - 1 vector of values in [0,1].
    This Transformation operates on the last dimension of the input tensor.
    """

    name = "sumto1"

    def backward(self, y):
        remaining = 1 - tt.sum(y[..., :], axis=-1, keepdims=True)
        return tt.concatenate([y[..., :], remaining], axis=-1)

    def forward(self, x):
        return x[..., :-1]

    def forward_val(self, x, point=None):
        return x[..., :-1]

    def jacobian_det(self, x):
        # The Jacobian term is zero; summing zeros over the last axis gives it
        # the same shape as a per-vector term.
        y = tt.zeros(x.shape)
        return tt.sum(y, axis=-1)


sum_to_1 = SumTo1()


class StickBreaking(Transform):
    """
    Transforms K - 1 dimensional simplex space (K values in [0,1] that sum to 1) to a K - 1 vector of real values.
    This is a variant of the isometric log-ratio transformation ::

        Egozcue, J.J., Pawlowsky-Glahn, V., Mateu-Figueras, G. et al.
        Isometric Logratio Transformations for Compositional Data Analysis.
        Mathematical Geology 35, 279–300 (2003). https://doi.org/10.1023/A:1023818214614
    """

    name = "stickbreaking"

    def __init__(self, eps=None):
        if eps is not None:
            warnings.warn(
                "The argument `eps` is deprecated and will not be used.", DeprecationWarning
            )

    def forward(self, x_):
        x = x_.T
        n = x.shape[0]
        lx = tt.log(x)
        shift = tt.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def forward_val(self, x_, point=None):
        x = x_.T
        n = x.shape[0]
        lx = np.log(x)
        shift = np.sum(lx, 0, keepdims=True) / n
        y = lx[:-1] - shift
        return floatX(y.T)

    def backward(self, y_):
        y = y_.T
        y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)])
        # "softmax" with vector support and no deprecation warning:
        e_y = tt.exp(y - tt.max(y, 0, keepdims=True))
        x = e_y / tt.sum(e_y, 0, keepdims=True)
        return floatX(x.T)

    def jacobian_det(self, y_):
        y = y_.T
        Km1 = y.shape[0] + 1
        sy = tt.sum(y, 0, keepdims=True)
        r = tt.concatenate([y + sy, tt.zeros(sy.shape)])
        sr = logsumexp(r, 0, keepdims=True)
        d = tt.log(Km1) + (Km1 * sy) - (Km1 * sr)
        return tt.sum(d, 0).T


stick_breaking = StickBreaking()


class Circular(ElemwiseTransform):
    """Transforms a linear space into a circular one.

    ``forward`` is the identity; ``backward`` wraps values onto [-pi, pi]
    via ``arctan2(sin(y), cos(y))``, and the Jacobian term is zero.
    """

    name = "circular"

    def backward(self, y):
        return tt.arctan2(tt.sin(y), tt.cos(y))

    def forward(self, x):
        return tt.as_tensor_variable(x)

    def forward_val(self, x, point=None):
        return x

    def jacobian_det(self, x):
        return tt.zeros(x.shape)


circular = Circular()


class CholeskyCovPacked(Transform):
    """Log-transform the diagonal entries of a packed Cholesky factor.

    Only the entries at ``diag_idxs`` are transformed (``log`` forward, ``exp``
    backward); all other entries pass through unchanged.

    Parameters
    ----------
    n: int
        Dimension of the (square) factor. ``diag_idxs`` is
        ``cumsum(1..n) - 1`` = 0, 2, 5, 9, ..., which are the diagonal
        positions when a lower triangle is packed row by row (presumably the
        packing used by callers -- confirm).
    """

    name = "cholesky-cov-packed"

    def __init__(self, n):
        self.diag_idxs = np.arange(1, n + 1).cumsum() - 1

    def backward(self, x):
        return tt.advanced_set_subtensor1(x, tt.exp(x[self.diag_idxs]), self.diag_idxs)

    def forward(self, y):
        return tt.advanced_set_subtensor1(y, tt.log(y[self.diag_idxs]), self.diag_idxs)

    def forward_val(self, y, point=None):
        # Work on a copy: assigning into `y` directly would overwrite the
        # caller's array (e.g. an untransformed test value) in place.
        y = np.array(y, copy=True)
        y[..., self.diag_idxs] = np.log(y[..., self.diag_idxs])
        return y

    def jacobian_det(self, y):
        # log|d/dy exp(y)| = y on the transformed diagonal; other entries
        # contribute nothing.
        return tt.sum(y[self.diag_idxs])


class Chain(Transform):
    """Compose several transforms into one.

    ``forward`` applies the transforms in list order; ``backward`` applies
    their inverses in reverse order. The transform name is the member names
    joined with ``"+"``.

    Parameters
    ----------
    transform_list: list of Transform
        Transforms to apply, in forward order.
    """

    def __init__(self, transform_list):
        self.transform_list = transform_list
        self.name = "+".join([transf.name for transf in self.transform_list])

    def forward(self, x):
        y = x
        for transf in self.transform_list:
            y = transf.forward(y)
        return y

    def forward_val(self, x, point=None):
        y = x
        for transf in self.transform_list:
            # Propagate `point` so bounds-like members (Interval, LowerBound,
            # UpperBound) can draw their bounds from it.
            y = transf.forward_val(y, point)
        return y

    def backward(self, y):
        x = y
        for transf in reversed(self.transform_list):
            x = transf.backward(x)
        return x

    def jacobian_det(self, y):
        """Sum the log-Jacobian terms of all chained transforms.

        Each member's term is evaluated at the value produced by the backward
        passes of the members after it. Terms with more dimensions than the
        lowest-rank term are summed over their last axis before being added.
        """
        y = tt.as_tensor_variable(y)
        det_list = []
        ndim0 = y.ndim
        for transf in reversed(self.transform_list):
            det_ = transf.jacobian_det(y)
            det_list.append(det_)
            y = transf.backward(y)
            ndim0 = min(ndim0, det_.ndim)
        # match the shape of the smallest jacobian_det
        det = 0.0
        for det_ in det_list:
            if det_.ndim > ndim0:
                det += det_.sum(axis=-1)
            else:
                det += det_
        return det