from __future__ import absolute_import, print_function, division
# Definitions of theano.scalar ops that have their Python implementations
# taken from SciPy. As SciPy is not always available, we treat them
# separately.

import os

import numpy as np

import theano
from theano.gradient import grad_not_implemented
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
                                 exp, upgrade_to_float,
                                 upgrade_to_float64,
                                 float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex,
                                 complex_types, discrete_types,
                                 upcast)

imported_scipy_special = False
try:
    import scipy.special
    import scipy.stats
    imported_scipy_special = True
# Importing scipy.special may raise ValueError.
# See http://projects.scipy.org/scipy/ticket/1739
except (ImportError, ValueError):
    pass


class Erf(UnaryScalarOp):
    nfunc_spec = ('scipy.special.erf', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erf(x)
        else:
            return super(Erf, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        cst = np.asarray(2. / np.sqrt(np.pi),
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return gz * cst * exp(-x * x),

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = erf((%(cast)s)%(x)s);" % locals()

erf = Erf(upgrade_to_float, name='erf')


class Erfc(UnaryScalarOp):
    nfunc_spec = ('scipy.special.erfc', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erfc(x)
        else:
            return super(Erfc, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        cst = np.asarray(2. / np.sqrt(np.pi),
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return -gz * cst * exp(-x * x),

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = erfc((%(cast)s)%(x)s);" % locals()

# scipy.special.erfc does not support complex arguments. Why?
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
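

# Illustrative sketch (not part of the Theano API): the L_op methods above
# encode d/dx erf(x) = 2/sqrt(pi) * exp(-x**2), and its negation for erfc.
# `_check_erf_grad_demo` is a hypothetical helper added for illustration
# only; it assumes SciPy is available and compares the analytic gradient
# against a central finite difference.
def _check_erf_grad_demo(x=0.5, eps=1e-6):
    numeric = (scipy.special.erf(x + eps) -
               scipy.special.erf(x - eps)) / (2 * eps)
    analytic = 2. / np.sqrt(np.pi) * np.exp(-x ** 2)
    return abs(numeric - analytic) < 1e-6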


class Erfcx(UnaryScalarOp):
    """
    Implements the scaled complementary error function exp(x**2)*erfc(x) in a
    numerically stable way for large x. This is useful for calculating
    quantities like log(erfc(x)) = log(erfcx(x)) - x**2 without underflow.
    It should only be used if x is known to be large and positive, as
    erfcx(x) for large negative x may instead overflow.

    Notes
    -----
    This op can still be executed on GPU, despite not having c_code. When
    running on GPU, an optimization will replace it with a GPU version.

    """
    nfunc_spec = ('scipy.special.erfcx', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erfcx(x)
        else:
            return super(Erfcx, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        cst = np.asarray(2. / np.sqrt(np.pi),
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return gz * (-cst + (2. * x) * erfcx(x)),

erfcx = Erfcx(upgrade_to_float_no_complex, name='erfcx')
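

# A minimal sketch of the trick the Erfcx docstring describes, assuming
# SciPy is available. `_log_erfc_demo` is a hypothetical helper added for
# illustration only: in float64, erfc(x) underflows to 0 around x = 28,
# so log(erfc(x)) becomes -inf, while the erfcx form stays finite
# (e.g. _log_erfc_demo(30.0) is approximately -904, not -inf).
def _log_erfc_demo(x):
    # log(erfc(x)) computed stably as log(erfcx(x)) - x**2.
    return np.log(scipy.special.erfcx(x)) - x ** 2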


class Erfinv(UnaryScalarOp):
    """
    Implements the inverse error function.

    Notes
    -----
    This op can still be executed on GPU, despite not having c_code. When
    running on GPU, an optimization will replace it with a GPU version.

    (TODO) Find a C implementation of erfinv for CPU.
    """
    nfunc_spec = ('scipy.special.erfinv', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erfinv(x)
        else:
            return super(Erfinv, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        cst = np.asarray(np.sqrt(np.pi) / 2.,
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return gz * cst * exp(erfinv(x) ** 2),

    # TODO: erfinv() is not provided by the C standard library
    # def c_code(self, node, name, inp, out, sub):
    #     x, = inp
    #     z, = out
    #     if node.inputs[0].type in complex_types:
    #         raise NotImplementedError('type not supported',
    #                                   node.inputs[0].type)
    #     return "%(z)s = erfinv(%(x)s);" % locals()

erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')


class Erfcinv(UnaryScalarOp):
    nfunc_spec = ('scipy.special.erfcinv', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erfcinv(x)
        else:
            return super(Erfcinv, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        cst = np.asarray(np.sqrt(np.pi) / 2.,
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return -gz * cst * exp(erfcinv(x) ** 2),

    # TODO: erfcinv() is not provided by the C standard library
    # def c_code(self, node, name, inp, out, sub):
    #     x, = inp
    #     z, = out
    #     if node.inputs[0].type in complex_types:
    #         raise NotImplementedError('type not supported',
    #                                   node.inputs[0].type)
    #     return "%(z)s = erfcinv(%(x)s);" % locals()

erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')


class Gamma(UnaryScalarOp):
    nfunc_spec = ('scipy.special.gamma', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.gamma(x)

    def impl(self, x):
        if imported_scipy_special:
            return Gamma.st_impl(x)
        else:
            return super(Gamma, self).impl(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * gamma(x) * psi(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in float_types:
            return """%(z)s = tgamma(%(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

gamma = Gamma(upgrade_to_float, name='gamma')


class GammaLn(UnaryScalarOp):
    """
    Log gamma function.

    """
    nfunc_spec = ('scipy.special.gammaln', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.gammaln(x)

    def impl(self, x):
        if imported_scipy_special:
            return GammaLn.st_impl(x)
        else:
            return super(GammaLn, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return [gz * psi(x)]

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        # No C code for complex; [u]int* inputs are cast to float64
        # before computation.
        if node.inputs[0].type in complex_types:
            raise NotImplementedError(
                'gammaln complex c code is not implemented')
        # For some reason, on the GPU, uint64 inputs are not cast
        # automatically to float64, which makes the compilation crash.
        cast = node.outputs[0].type.dtype_specs()[1]
        return """%(z)s = lgamma((%(cast)s)%(x)s);""" % locals()

gammaln = GammaLn(upgrade_to_float, name='gammaln')
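

# Hedged sanity check (illustrative only, assumes SciPy): GammaLn.L_op above
# relies on d/dx gammaln(x) = psi(x), and Gamma.L_op on
# d/dx gamma(x) = gamma(x) * psi(x). `_check_gammaln_grad_demo` is a
# hypothetical helper comparing the first identity against a central finite
# difference.
def _check_gammaln_grad_demo(x=3.7, eps=1e-6):
    numeric = (scipy.special.gammaln(x + eps) -
               scipy.special.gammaln(x - eps)) / (2 * eps)
    return abs(numeric - scipy.special.psi(x)) < 1e-5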


class Psi(UnaryScalarOp):
    """
    Derivative of the log gamma function (digamma function).

    """
    nfunc_spec = ('scipy.special.psi', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.psi(x)

    def impl(self, x):
        if imported_scipy_special:
            return Psi.st_impl(x)
        else:
            return super(Psi, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return [gz * tri_gamma(x)]

    def c_support_code(self):
        return (
            """
            // For GPU support
            #ifdef WITHIN_KERNEL
            #define DEVICE WITHIN_KERNEL
            #else
            #define DEVICE
            #endif

            #ifndef ga_double
            #define ga_double double
            #endif

            #ifndef _PSIFUNCDEFINED
            #define _PSIFUNCDEFINED
            DEVICE double _psi(ga_double x) {

                /* Taken from
                   Bernardo, J. M. (1976). Algorithm AS 103:
                   Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
                   http://www.uv.es/~bernardo/1976AppStatist.pdf */

                ga_double y, R, psi_ = 0;
                ga_double S = 1.0e-5;
                ga_double C = 8.5;
                ga_double S3 = 8.333333333e-2;
                ga_double S4 = 8.333333333e-3;
                ga_double S5 = 3.968253968e-3;
                ga_double D1 = -0.5772156649;

                y = x;

                if (y <= 0.0)
                    return psi_;

                if (y <= S)
                    return D1 - 1.0/y;

                while (y < C) {
                    psi_ = psi_ - 1.0 / y;
                    y = y + 1;
                }

                R = 1.0 / y;
                psi_ = psi_ + log(y) - .5 * R;
                R = R * R;
                psi_ = psi_ - R * (S3 - R * (S4 - R * S5));

                return psi_;
            }
            #endif
            """)

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in float_types:
            return """%(z)s =
                _psi(%(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

psi = Psi(upgrade_to_float, name='psi')


class TriGamma(UnaryScalarOp):
    """
    Second derivative of the log gamma function.

    """

    @staticmethod
    def st_impl(x):
        return scipy.special.polygamma(1, x)

    def impl(self, x):
        if imported_scipy_special:
            return TriGamma.st_impl(x)
        else:
            return super(TriGamma, self).impl(x)

    def grad(self, inputs, outputs_gradients):
        raise NotImplementedError()

    def c_support_code(self):
        # The implementation has been copied from
        # http://people.sc.fsu.edu/~jburkardt/cpp_src/asa121/asa121.html
        return (
            """
            // For GPU support
            #ifdef WITHIN_KERNEL
            #define DEVICE WITHIN_KERNEL
            #else
            #define DEVICE
            #endif

            #ifndef ga_double
            #define ga_double double
            #endif

            #ifndef _TRIGAMMAFUNCDEFINED
            #define _TRIGAMMAFUNCDEFINED

            DEVICE double _tri_gamma(ga_double x) {

                double a = 0.0001;
                double b = 5.0;
                double b2 = 0.1666666667;
                double b4 = -0.03333333333;
                double b6 = 0.02380952381;
                double b8 = -0.03333333333;
                double value;
                double y;
                double z;

                if (x <= 0) {
                    return 0.0;
                }

                if (x <= a) {
                    value = 1.0 / x / x;
                    return value;
                }

                value = 0.0;
                z = x;

                while (z < b) {
                    value += 1.0 / z / z;
                    z += 1.0;
                }

                y = 1.0 / z / z;

                value += 0.5 * y + (1.0 + y * (b2 + y * (b4 + y * (b6 + y * b8)))) / z;

                return value;
            }
            #endif
            """)

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in float_types:
            return """%(z)s =
                _tri_gamma(%(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')


tri_gamma = TriGamma(upgrade_to_float, name='tri_gamma')
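

# Illustrative pure-Python transcription of the AS 103 digamma routine used
# in Psi.c_support_code above (a hypothetical helper, added only to make the
# C algorithm easier to follow; the op itself does not use it). The
# recurrence psi(y) = psi(y + 1) - 1/y shifts small arguments up past 8.5,
# where a short asymptotic series in 1/y**2 is accurate.
def _psi_demo(x):
    import math
    S, C = 1.0e-5, 8.5
    S3, S4, S5 = 8.333333333e-2, 8.333333333e-3, 3.968253968e-3
    D1 = -0.5772156649  # minus the Euler-Mascheroni constant
    if x <= 0.0:
        return 0.0
    if x <= S:
        return D1 - 1.0 / x
    psi_, y = 0.0, x
    while y < C:  # shift the argument into the asymptotic region
        psi_ -= 1.0 / y
        y += 1.0
    r = 1.0 / y
    psi_ += math.log(y) - 0.5 * r
    r *= r
    return psi_ - r * (S3 - r * (S4 - r * S5))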


class Chi2SF(BinaryScalarOp):
    """
    Compute (1 - chi2_cdf(x)), i.e. the chi2 p-value (the chi2 'survival
    function').
    """
    nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)

    @staticmethod
    def st_impl(x, k):
        return scipy.stats.chi2.sf(x, k)

    def impl(self, x, k):
        if imported_scipy_special:
            return Chi2SF.st_impl(x, k)
        else:
            return super(Chi2SF, self).impl(x, k)

    def c_support_code(self):
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
        return raw

    def c_code(self, node, name, inp, out, sub):
        x, k = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);""" % locals()
        raise NotImplementedError('only floating point is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


chi2sf = Chi2SF(upgrade_to_float64, name='chi2sf')


class GammaInc(BinaryScalarOp):
    """
    Compute the regularized lower gamma function (P).
    """
    nfunc_spec = ('scipy.special.gammainc', 2, 1)

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammainc(k, x)

    def impl(self, k, x):
        if imported_scipy_special:
            return GammaInc.st_impl(k, x)
        else:
            return super(GammaInc, self).impl(k, x)

    def c_support_code(self):
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
        return raw

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) GammaP(%(k)s, %(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


gammainc = GammaInc(upgrade_to_float, name='gammainc')


class GammaIncC(BinaryScalarOp):
    """
    Compute the regularized upper gamma function (Q).
    """
    nfunc_spec = ('scipy.special.gammaincc', 2, 1)

    @staticmethod
    def st_impl(k, x):
        # The argument order is (k, x), matching scipy.special.gammaincc.
        return scipy.special.gammaincc(k, x)

    def impl(self, k, x):
        if imported_scipy_special:
            return GammaIncC.st_impl(k, x)
        else:
            return super(GammaIncC, self).impl(k, x)

    def c_support_code(self):
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
        return raw

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) GammaQ(%(k)s, %(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


gammaincc = GammaIncC(upgrade_to_float, name='gammaincc')
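

# Hedged consistency sketch (illustrative only, assumes SciPy): the
# regularized functions above satisfy P(k, x) + Q(k, x) = 1, and Chi2SF's
# C code uses chi2_sf(x, k) = Q(k/2, x/2). `_check_gamma_identities_demo`
# is a hypothetical helper spot-checking both identities.
def _check_gamma_identities_demo(k=3.0, x=2.5):
    p_plus_q = (scipy.special.gammainc(k, x) +
                scipy.special.gammaincc(k, x))
    sf = scipy.stats.chi2.sf(x, k)
    q_half = scipy.special.gammaincc(k / 2., x / 2.)
    return abs(p_plus_q - 1.) < 1e-12 and abs(sf - q_half) < 1e-12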


class GammaU(BinaryScalarOp):
    """
    Compute the upper incomplete gamma function (unregularized).
    """
    # Note: SciPy has no unregularized version, so there is no nfunc_spec.

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)

    def impl(self, k, x):
        if imported_scipy_special:
            return GammaU.st_impl(k, x)
        else:
            return super(GammaU, self).impl(k, x)

    def c_support_code(self):
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
        return raw

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) upperGamma(%(k)s, %(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


gammau = GammaU(upgrade_to_float, name='gammau')


class GammaL(BinaryScalarOp):
    """
    Compute the lower incomplete gamma function (unregularized).
    """
    # Note: SciPy has no unregularized version, so there is no nfunc_spec.

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammainc(k, x) * scipy.special.gamma(k)

    def impl(self, k, x):
        if imported_scipy_special:
            return GammaL.st_impl(k, x)
        else:
            return super(GammaL, self).impl(k, x)

    def c_support_code(self):
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
        return raw

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) lowerGamma(%(k)s, %(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


gammal = GammaL(upgrade_to_float, name='gammal')


class Jv(BinaryScalarOp):
    """
    Bessel function of the first kind of order v (real).
    """
    nfunc_spec = ('scipy.special.jv', 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.jv(v, x)

    def impl(self, v, x):
        if imported_scipy_special:
            return self.st_impl(v, x)
        else:
            return super(Jv, self).impl(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        gz, = grads
        return [grad_not_implemented(self, 0, v),
                gz * (jv(v - 1, x) - jv(v + 1, x)) / 2.]

jv = Jv(upgrade_to_float, name='jv')


class J1(UnaryScalarOp):
    """
    Bessel function of the first kind of order 1.
    """
    nfunc_spec = ('scipy.special.j1', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.j1(x)

    def impl(self, x):
        if imported_scipy_special:
            return self.st_impl(x)
        else:
            return super(J1, self).impl(x)

    def grad(self, inputs, grads):
        x, = inputs
        gz, = grads
        return [gz * (j0(x) - jv(2, x)) / 2.]

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in float_types:
            return """%(z)s =
                j1(%(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

j1 = J1(upgrade_to_float, name='j1')
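

# Hedged sketch (illustrative only, assumes SciPy): Jv.grad and J1.grad
# above use the standard recurrence
# d/dx J_v(x) = (J_{v-1}(x) - J_{v+1}(x)) / 2. `_check_jv_grad_demo` is a
# hypothetical helper verifying it by central finite difference.
def _check_jv_grad_demo(v=2.0, x=1.3, eps=1e-6):
    numeric = (scipy.special.jv(v, x + eps) -
               scipy.special.jv(v, x - eps)) / (2 * eps)
    analytic = (scipy.special.jv(v - 1, x) -
                scipy.special.jv(v + 1, x)) / 2.
    return abs(numeric - analytic) < 1e-6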


class J0(UnaryScalarOp):
    """
    Bessel function of the first kind of order 0.
    """
    nfunc_spec = ('scipy.special.j0', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.j0(x)

    def impl(self, x):
        if imported_scipy_special:
            return self.st_impl(x)
        else:
            return super(J0, self).impl(x)

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        return [-gz * j1(x)]

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in float_types:
            return """%(z)s =
                j0(%(x)s);""" % locals()
        raise NotImplementedError('only floating point is implemented')

j0 = J0(upgrade_to_float, name='j0')


class Iv(BinaryScalarOp):
    """
    Modified Bessel function of the first kind of order v (real).
    """
    nfunc_spec = ('scipy.special.iv', 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.iv(v, x)

    def impl(self, v, x):
        if imported_scipy_special:
            return self.st_impl(v, x)
        else:
            return super(Iv, self).impl(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        gz, = grads
        return [grad_not_implemented(self, 0, v),
                gz * (iv(v - 1, x) + iv(v + 1, x)) / 2.]

iv = Iv(upgrade_to_float, name='iv')


class I1(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 1.
    """
    nfunc_spec = ('scipy.special.i1', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.i1(x)

    def impl(self, x):
        if imported_scipy_special:
            return self.st_impl(x)
        else:
            return super(I1, self).impl(x)

    def grad(self, inputs, grads):
        x, = inputs
        gz, = grads
        return [gz * (i0(x) + iv(2, x)) / 2.]

i1 = I1(upgrade_to_float, name='i1')


class I0(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 0.
    """
    nfunc_spec = ('scipy.special.i0', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.i0(x)

    def impl(self, x):
        if imported_scipy_special:
            return self.st_impl(x)
        else:
            return super(I0, self).impl(x)

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        return [gz * i1(x)]

i0 = I0(upgrade_to_float, name='i0')
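

# Hedged usage sketch: these scalar ops are normally reached through their
# elementwise tensor counterparts (e.g. theano.tensor.erf, theano.tensor.psi)
# rather than instantiated directly. The demo below assumes a working Theano
# install with SciPy available; it only runs when this file is executed as a
# script.
if __name__ == "__main__":
    import theano.tensor as tt
    x_ = tt.dscalar('x')
    f_ = theano.function([x_], theano.grad(tt.erf(x_), x_))
    # d/dx erf(x) at 0 is 2 / sqrt(pi), approximately 1.1283791670955126.
    print(f_(0.0))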