1from __future__ import absolute_import, print_function, division
2# Definitions of theano.scalar ops that have their python implementation taken
3# from SciPy. As SciPy is not always available, we treat them separately.
4
5import numpy as np
6import os
7
8import theano
9from theano.gradient import grad_not_implemented
10from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
11                                 exp, upgrade_to_float,
12                                 upgrade_to_float64,
13                                 float_types)
14from theano.scalar.basic import (upgrade_to_float_no_complex,
15                                 complex_types, discrete_types,
16                                 upcast)
17
18imported_scipy_special = False
19try:
20    import scipy.special
21    import scipy.stats
22    imported_scipy_special = True
23# Importing scipy.special may raise ValueError.
24# See http://projects.scipy.org/scipy/ticket/1739
25except (ImportError, ValueError):
26    pass
27
28
class Erf(UnaryScalarOp):
    """
    Error function, erf(x) = 2 / sqrt(pi) * integral_0^x exp(-t**2) dt.

    The Python implementation delegates to ``scipy.special.erf`` (when SciPy
    could be imported) and the C implementation to the C library ``erf``.
    Complex inputs are rejected by both the gradient and the C code.

    """
    nfunc_spec = ('scipy.special.erf', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erf(x)
        # No SciPy: fall back to the base class implementation (presumably
        # raises because no pure-Python version exists).
        return super(Erf, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        # A discrete output carries no gradient; return zeros, in floatX
        # when the input itself is discrete.
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        # d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
        cst = np.asarray(2. / np.sqrt(np.pi),
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return gz * cst * exp(-x * x),

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            # Report the actual offending input type (the builtin ``type``
            # used to be passed here by mistake).
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        # Cast the input to the output's C type (dtype_specs()[1]) before
        # calling erf, so integer inputs are converted first.
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = erf((%(cast)s)%(x)s);" % locals()


erf = Erf(upgrade_to_float, name='erf')
61
62
class Erfc(UnaryScalarOp):
    """
    Complementary error function, erfc(x) = 1 - erf(x).

    The Python implementation delegates to ``scipy.special.erfc`` (when SciPy
    could be imported) and the C implementation to the C library ``erfc``.
    Complex inputs are rejected by both the gradient and the C code.

    """
    nfunc_spec = ('scipy.special.erfc', 1, 1)

    def impl(self, x):
        if imported_scipy_special:
            return scipy.special.erfc(x)
        # No SciPy: fall back to the base class implementation (presumably
        # raises because no pure-Python version exists).
        return super(Erfc, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        # A discrete output carries no gradient; return zeros, in floatX
        # when the input itself is discrete.
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        # d/dx erfc(x) = -2 / sqrt(pi) * exp(-x**2)
        cst = np.asarray(2. / np.sqrt(np.pi),
                         dtype=upcast(x.type.dtype, gz.type.dtype))
        return - gz * cst * exp(-x * x),

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            # Report the actual offending input type (the builtin ``type``
            # used to be passed here by mistake).
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        # Cast the input to the output's C type (dtype_specs()[1]) before
        # calling erfc, so integer inputs are converted first.
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = erfc((%(cast)s)%(x)s);" % locals()


# Restricted to real inputs. The historical note here says
# scipy.special.erfc does not support complex arguments.
# NOTE(review): newer SciPy releases may accept complex input; confirm
# before relaxing the type preference.
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
97
98
class Erfcx(UnaryScalarOp):
    """
    Scaled complementary error function, erfcx(x) = exp(x**2) * erfc(x).

    Evaluating erfcx directly stays accurate for large positive x, where
    erfc(x) alone underflows, so e.g. log(erfc(x)) can be computed as
    log(erfcx(x)) - x ** 2. For large negative x the value itself grows
    quickly and may overflow, so the op is best used when x is known to be
    large and positive.

    Notes
    -----
    This op can still be executed on GPU, despite not having c_code. When
    running on GPU an optimization will replace it with a gpu version.

    """
    nfunc_spec = ('scipy.special.erfcx', 1, 1)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Erfcx, self).impl(x)
            return None
        return scipy.special.erfcx(x)

    def L_op(self, inputs, outputs, grads):
        [x] = inputs
        [g_out] = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # No gradient flows through a discrete output.
            if x.type not in discrete_types:
                return [x.zeros_like()]
            return [x.zeros_like(dtype=theano.config.floatX)]

        # d/dx erfcx(x) = 2 * x * erfcx(x) - 2 / sqrt(pi)
        two_over_sqrt_pi = np.asarray(
            2. / np.sqrt(np.pi),
            dtype=upcast(x.type.dtype, g_out.type.dtype))
        slope = -two_over_sqrt_pi + (2. * x) * erfcx(x)
        return (g_out * slope,)


erfcx = Erfcx(upgrade_to_float_no_complex, name='erfcx')
137
138
class Erfinv(UnaryScalarOp):
    """
    Inverse of the error function on (-1, 1): erfinv(erf(x)) == x.

    Notes
    -----
    This op can still be executed on GPU, despite not having c_code. When
    running on GPU, an optimization will replace it with a GPU version.

    (TODO) Find a C implementation of erfinv for CPU.
    """
    nfunc_spec = ('scipy.special.erfinv', 1, 1)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Erfinv, self).impl(x)
            return None
        return scipy.special.erfinv(x)

    def L_op(self, inputs, outputs, grads):
        [x] = inputs
        [g_out] = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # No gradient flows through a discrete output.
            if x.type not in discrete_types:
                return [x.zeros_like()]
            return [x.zeros_like(dtype=theano.config.floatX)]

        # d/dx erfinv(x) = sqrt(pi) / 2 * exp(erfinv(x) ** 2)
        half_sqrt_pi = np.asarray(
            np.sqrt(np.pi) / 2.,
            dtype=upcast(x.type.dtype, g_out.type.dtype))
        return (g_out * half_sqrt_pi * exp(erfinv(x) ** 2),)

    # TODO: no c_code yet. erfinv() is not provided by the C standard
    # library, so only the SciPy-backed Python implementation exists.


erfinv = Erfinv(upgrade_to_float_no_complex, name='erfinv')
182
183
class Erfcinv(UnaryScalarOp):
    """
    Inverse of the complementary error function on (0, 2):
    erfcinv(erfc(x)) == x.

    Only a SciPy-backed Python implementation is available.
    """
    nfunc_spec = ('scipy.special.erfcinv', 1, 1)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Erfcinv, self).impl(x)
            return None
        return scipy.special.erfcinv(x)

    def L_op(self, inputs, outputs, grads):
        [x] = inputs
        [g_out] = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # No gradient flows through a discrete output.
            if x.type not in discrete_types:
                return [x.zeros_like()]
            return [x.zeros_like(dtype=theano.config.floatX)]

        # d/dx erfcinv(x) = -sqrt(pi) / 2 * exp(erfcinv(x) ** 2)
        half_sqrt_pi = np.asarray(
            np.sqrt(np.pi) / 2.,
            dtype=upcast(x.type.dtype, g_out.type.dtype))
        return (-g_out * half_sqrt_pi * exp(erfcinv(x) ** 2),)

    # TODO: no c_code yet. erfcinv() is not provided by the C standard
    # library, so only the SciPy-backed Python implementation exists.


erfcinv = Erfcinv(upgrade_to_float_no_complex, name='erfcinv')
217
218
class Gamma(UnaryScalarOp):
    """
    Gamma function.

    Python implementation: ``scipy.special.gamma``. C implementation: the C
    library ``tgamma``, for floating point inputs only.

    """
    nfunc_spec = ('scipy.special.gamma', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.gamma(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Gamma, self).impl(x)
            return None
        return Gamma.st_impl(x)

    def L_op(self, inputs, outputs, gout):
        [x] = inputs
        [gz] = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # No gradient flows through a discrete output.
            if x.type not in discrete_types:
                return [x.zeros_like()]
            return [x.zeros_like(dtype=theano.config.floatX)]

        # d/dx gamma(x) = gamma(x) * psi(x)
        return (gz * gamma(x) * psi(x),)

    def c_code(self, node, name, inputs, outputs, sub):
        [x] = inputs
        [z] = outputs
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floating point is implemented')
        return "%(z)s = tgamma(%(x)s);" % {'z': z, 'x': x}


gamma = Gamma(upgrade_to_float, name='gamma')
252
253
class GammaLn(UnaryScalarOp):
    """
    Log gamma function.

    Python implementation: ``scipy.special.gammaln``. C implementation: the
    C library ``lgamma``, applied after casting the input to the output's C
    type. Complex inputs are not supported in C.

    """
    nfunc_spec = ('scipy.special.gammaln', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.gammaln(x)

    def impl(self, x):
        if imported_scipy_special:
            return GammaLn.st_impl(x)
        # No SciPy: fall back to the base class implementation (presumably
        # raises because no pure-Python version exists).
        return super(GammaLn, self).impl(x)

    def L_op(self, inputs, outputs, grads):
        x, = inputs
        gz, = grads
        if x.type in complex_types:
            raise NotImplementedError()
        # A discrete output carries no gradient; return zeros, in floatX
        # when the input itself is discrete.
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        # d/dx log(gamma(x)) = psi(x)
        return [gz * psi(x)]

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        # No C code for complex inputs.
        if node.inputs[0].type in complex_types:
            raise NotImplementedError(
                'gammaln complex c code is not implemented')
        # Explicitly cast the input to the output's C type, so that [u]int*
        # inputs are converted before lgamma is called. On the GPU, uint64
        # inputs were not cast to float64 automatically, which made
        # compilation crash.
        cast = node.outputs[0].type.dtype_specs()[1]
        return """%(z)s = lgamma((%(cast)s)%(x)s);""" % locals()


gammaln = GammaLn(upgrade_to_float, name='gammaln')
298
299
class Psi(UnaryScalarOp):
    """
    Digamma function: the derivative of the log gamma function.

    Python implementation: ``scipy.special.psi``. The C implementation
    (``_psi`` in c_support_code) follows Algorithm AS 103 (Bernardo, 1976).
    It returns 0 for x <= 0, whereas the SciPy version is defined for
    negative non-integer x.

    """
    nfunc_spec = ('scipy.special.psi', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.psi(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Psi, self).impl(x)
            return None
        return Psi.st_impl(x)

    def L_op(self, inputs, outputs, grads):
        [x] = inputs
        [g_out] = grads
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            # No gradient flows through a discrete output.
            if x.type not in discrete_types:
                return [x.zeros_like()]
            return [x.zeros_like(dtype=theano.config.floatX)]

        # d/dx psi(x) = trigamma(x)
        return [g_out * tri_gamma(x)]

    def c_support_code(self):
        # DEVICE / ga_double shims let the same source compile for GPU
        # kernels; the include guard avoids duplicate definitions.
        return (
            """
            // For GPU support
            #ifdef WITHIN_KERNEL
            #define DEVICE WITHIN_KERNEL
            #else
            #define DEVICE
            #endif

            #ifndef ga_double
            #define ga_double double
            #endif

            #ifndef _PSIFUNCDEFINED
            #define _PSIFUNCDEFINED
            DEVICE double _psi(ga_double x) {

            /*taken from
            Bernardo, J. M. (1976). Algorithm AS 103:
            Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
            http://www.uv.es/~bernardo/1976AppStatist.pdf */

            ga_double y, R, psi_ = 0;
            ga_double S  = 1.0e-5;
            ga_double C = 8.5;
            ga_double S3 = 8.333333333e-2;
            ga_double S4 = 8.333333333e-3;
            ga_double S5 = 3.968253968e-3;
            ga_double D1 = -0.5772156649;

            y = x;

            if (y <= 0.0)
               return psi_;

            if (y <= S)
                return D1 - 1.0/y;

            while (y < C) {
                psi_ = psi_ - 1.0 / y;
                y = y + 1;
            }

            R = 1.0 / y;
            psi_ = psi_ + log(y) - .5 * R ;
            R= R*R;
            psi_ = psi_ - R * (S3 - R * (S4 - R * S5));

            return psi_;
            }
            #endif
            """)

    def c_code(self, node, name, inp, out, sub):
        [x] = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floating point is implemented')
        return """%(z)s =
                _psi(%(x)s);""" % {'z': z, 'x': x}


psi = Psi(upgrade_to_float, name='psi')
392
393
class TriGamma(UnaryScalarOp):
    """
    Trigamma function: the second derivative of the log gamma function.

    Python implementation: ``scipy.special.polygamma(1, x)``. The C
    implementation (``_tri_gamma`` in c_support_code) follows ASA121 and
    returns 0 for x <= 0. There is no gradient.

    """

    @staticmethod
    def st_impl(x):
        return scipy.special.polygamma(1, x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(TriGamma, self).impl(x)
            return None
        return TriGamma.st_impl(x)

    def grad(self, inputs, outputs_gradients):
        # The derivative of trigamma is not implemented.
        # NOTE(review): Jv uses grad_not_implemented(self, 0, x) for this
        # situation; switching would change the error callers observe.
        raise NotImplementedError()

    def c_support_code(self):
        # The implementation has been copied from
        # http://people.sc.fsu.edu/~jburkardt/cpp_src/asa121/asa121.html
        # (recurrence up to z >= 5, then an asymptotic series).
        return (
            """
            // For GPU support
            #ifdef WITHIN_KERNEL
            #define DEVICE WITHIN_KERNEL
            #else
            #define DEVICE
            #endif

            #ifndef ga_double
            #define ga_double double
            #endif

            #ifndef _TRIGAMMAFUNCDEFINED
            #define _TRIGAMMAFUNCDEFINED

            DEVICE double _tri_gamma(ga_double x) {

                double a = 0.0001;
                double b = 5.0;
                double b2 =  0.1666666667;
                double b4 = -0.03333333333;
                double b6 =  0.02380952381;
                double b8 = -0.03333333333;
                double value;
                double y;
                double z;

                if (x <= 0) {
                    return 0.0;
                }

                if ( x <= a ) {
                    value = 1.0 / x / x;
                    return value;
                }

                value = 0.0;
                z = x;

                while ( z < b ) {
                    value += 1.0 / z / z;
                    z += 1.0;
                }

                y = 1.0 / z / z;

                value +=  0.5 * y + (1.0 + y * (b2 + y * (b4 + y * (b6 + y * b8 )))) / z;

                return value;
            }
            #endif
            """)

    def c_code(self, node, name, inp, out, sub):
        [x] = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floating point is implemented')
        return """%(z)s =
                _tri_gamma(%(x)s);""" % {'z': z, 'x': x}


tri_gamma = TriGamma(upgrade_to_float, name='tri_gamma')
480
481
class Chi2SF(BinaryScalarOp):
    """
    Compute (1 - chi2_cdf(x))
        ie. chi2 pvalue (chi2 'survival function')

    Inputs are (x, k), where k is the degrees of freedom. The C code evaluates
    1 - P(k / 2, x / 2), with P the regularized lower incomplete gamma
    function ``GammaP`` from c_code/gamma.c.
    """
    nfunc_spec = ('scipy.stats.chi2.sf', 2, 1)

    @staticmethod
    def st_impl(x, k):
        return scipy.stats.chi2.sf(x, k)

    def impl(self, x, k):
        if imported_scipy_special:
            return Chi2SF.st_impl(x, k)
        # No SciPy: fall back to the base class implementation (presumably
        # raises because no pure-Python version exists).
        return super(Chi2SF, self).impl(x, k)

    def c_support_code(self):
        # gamma.c (shipped next to this module) provides GammaP.
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
            return raw

    def c_code(self, node, name, inp, out, sub):
        x, k = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            # Parenthesise so the cast applies to the whole difference, not
            # only to the literal 1 (a C cast binds tighter than '-').
            return """%(z)s =
                (%(dtype)s) (1 - GammaP(%(k)s/2., %(x)s/2.));""" % locals()
        raise NotImplementedError('only floatingpoint is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


chi2sf = Chi2SF(upgrade_to_float64, name='chi2sf')
524
525
class GammaInc(BinaryScalarOp):
    """
    Compute the regularized lower gamma function (P).

    Inputs are (k, x). The C code calls ``GammaP(k, x)`` from
    c_code/gamma.c.
    """
    nfunc_spec = ('scipy.special.gammainc', 2, 1)

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammainc(k, x)

    def impl(self, k, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(GammaInc, self).impl(k, x)
            return None
        return GammaInc.st_impl(k, x)

    def c_support_code(self):
        source_path = os.path.join(os.path.dirname(__file__),
                                   'c_code', 'gamma.c')
        with open(source_path) as handle:
            return handle.read()

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floatingpoint is implemented')
        dtype = 'npy_' + node.outputs[0].dtype
        template = """%(z)s =
                (%(dtype)s) GammaP(%(k)s, %(x)s);"""
        return template % {'z': z, 'dtype': dtype, 'k': k, 'x': x}

    def __eq__(self, other):
        return type(other) == type(self)

    def __hash__(self):
        return hash(type(self))


gammainc = GammaInc(upgrade_to_float, name='gammainc')
567
568
class GammaIncC(BinaryScalarOp):
    """
    Compute the regularized upper gamma function (Q).

    Inputs are (k, x): Q(k, x) = 1 - P(k, x). Both the Python
    implementation (``scipy.special.gammaincc(k, x)``) and the C code
    (``GammaQ(k, x)`` from c_code/gamma.c) take the shape parameter first.
    """
    nfunc_spec = ('scipy.special.gammaincc', 2, 1)

    @staticmethod
    def st_impl(k, x):
        # Argument order must match scipy.special.gammaincc(a, x) and the C
        # GammaQ(k, x); the two were previously swapped here.
        return scipy.special.gammaincc(k, x)

    def impl(self, k, x):
        if imported_scipy_special:
            return GammaIncC.st_impl(k, x)
        # No SciPy: fall back to the base class implementation (presumably
        # raises because no pure-Python version exists).
        return super(GammaIncC, self).impl(k, x)

    def c_support_code(self):
        # gamma.c (shipped next to this module) provides GammaQ.
        with open(os.path.join(
                os.path.dirname(__file__),
                'c_code',
                'gamma.c')) as f:
            raw = f.read()
            return raw

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        z, = out
        if node.inputs[0].type in float_types:
            dtype = 'npy_' + node.outputs[0].dtype
            return """%(z)s =
                (%(dtype)s) GammaQ(%(k)s, %(x)s);""" % locals()
        raise NotImplementedError('only floatingpoint is implemented')

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))


gammaincc = GammaIncC(upgrade_to_float, name='gammaincc')
610
611
class GammaU(BinaryScalarOp):
    """
    compute the upper incomplete gamma function.

    Inputs are (k, x). The Python version evaluates Q(k, x) * gamma(k) via
    SciPy; the C code calls ``upperGamma(k, x)`` from c_code/gamma.c.
    """
    # Note there is no basic SciPy version so no nfunc_spec.

    @staticmethod
    def st_impl(k, x):
        regularized = scipy.special.gammaincc(k, x)
        return regularized * scipy.special.gamma(k)

    def impl(self, k, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(GammaU, self).impl(k, x)
            return None
        return GammaU.st_impl(k, x)

    def c_support_code(self):
        source_path = os.path.join(os.path.dirname(__file__),
                                   'c_code', 'gamma.c')
        with open(source_path) as handle:
            return handle.read()

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floatingpoint is implemented')
        dtype = 'npy_' + node.outputs[0].dtype
        template = """%(z)s =
                (%(dtype)s) upperGamma(%(k)s, %(x)s);"""
        return template % {'z': z, 'dtype': dtype, 'k': k, 'x': x}

    def __eq__(self, other):
        return type(other) == type(self)

    def __hash__(self):
        return hash(type(self))


gammau = GammaU(upgrade_to_float, name='gammau')
653
654
class GammaL(BinaryScalarOp):
    """
    Compute the lower incomplete gamma function.

    Inputs are (k, x). The Python version evaluates P(k, x) * gamma(k) via
    SciPy; the C code calls ``lowerGamma(k, x)`` from c_code/gamma.c.
    """
    # Note there is no basic SciPy version so no nfunc_spec.

    @staticmethod
    def st_impl(k, x):
        regularized = scipy.special.gammainc(k, x)
        return regularized * scipy.special.gamma(k)

    def impl(self, k, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(GammaL, self).impl(k, x)
            return None
        return GammaL.st_impl(k, x)

    def c_support_code(self):
        source_path = os.path.join(os.path.dirname(__file__),
                                   'c_code', 'gamma.c')
        with open(source_path) as handle:
            return handle.read()

    def c_code(self, node, name, inp, out, sub):
        k, x = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floatingpoint is implemented')
        dtype = 'npy_' + node.outputs[0].dtype
        template = """%(z)s =
                (%(dtype)s) lowerGamma(%(k)s, %(x)s);"""
        return template % {'z': z, 'dtype': dtype, 'k': k, 'x': x}

    def __eq__(self, other):
        return type(other) == type(self)

    def __hash__(self):
        return hash(type(self))


gammal = GammaL(upgrade_to_float, name='gammal')
696
697
class Jv(BinaryScalarOp):
    """
    Bessel function of the first kind of order v (real).

    Inputs are (v, x). Only the gradient with respect to x is implemented:
    d/dx J_v(x) = (J_{v-1}(x) - J_{v+1}(x)) / 2.
    """
    nfunc_spec = ('scipy.special.jv', 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.jv(v, x)

    def impl(self, v, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Jv, self).impl(v, x)
            return None
        return self.st_impl(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        [g_out] = grads
        neighbour_diff = jv(v - 1, x) - jv(v + 1, x)
        return [grad_not_implemented(self, 0, v),
                g_out * neighbour_diff / 2.]


jv = Jv(upgrade_to_float, name='jv')
721
722
class J1(UnaryScalarOp):
    """
    Bessel function of the first kind of order 1.

    Gradient: d/dx J_1(x) = (J_0(x) - J_2(x)) / 2. The C code calls the
    ``j1`` function and supports floating point inputs only.
    """
    nfunc_spec = ('scipy.special.j1', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.j1(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(J1, self).impl(x)
            return None
        return self.st_impl(x)

    def grad(self, inputs, grads):
        [x] = inputs
        [g_out] = grads
        neighbour_diff = j0(x) - jv(2, x)
        return [g_out * neighbour_diff / 2.]

    def c_code(self, node, name, inp, out, sub):
        [x] = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floating point is implemented')
        return """%(z)s =
                j1(%(x)s);""" % {'z': z, 'x': x}


j1 = J1(upgrade_to_float, name='j1')
753
754
class J0(UnaryScalarOp):
    """
    Bessel function of the first kind of order 0.

    Gradient: d/dx J_0(x) = -J_1(x). The C code calls the ``j0`` function
    and supports floating point inputs only.
    """
    nfunc_spec = ('scipy.special.j0', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.j0(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(J0, self).impl(x)
            return None
        return self.st_impl(x)

    def grad(self, inp, grads):
        [x] = inp
        [g_out] = grads
        return [g_out * -1 * j1(x)]

    def c_code(self, node, name, inp, out, sub):
        [x] = inp
        [z] = out
        if node.inputs[0].type not in float_types:
            raise NotImplementedError('only floating point is implemented')
        return """%(z)s =
                j0(%(x)s);""" % {'z': z, 'x': x}


j0 = J0(upgrade_to_float, name='j0')
785
786
class Iv(BinaryScalarOp):
    """
    Modified Bessel function of the first kind of order v (real).

    Inputs are (v, x). Only the gradient with respect to x is implemented:
    d/dx I_v(x) = (I_{v-1}(x) + I_{v+1}(x)) / 2.
    """
    nfunc_spec = ('scipy.special.iv', 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.iv(v, x)

    def impl(self, v, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(Iv, self).impl(v, x)
            return None
        return self.st_impl(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        [g_out] = grads
        neighbour_sum = iv(v - 1, x) + iv(v + 1, x)
        return [grad_not_implemented(self, 0, v),
                g_out * neighbour_sum / 2.]


iv = Iv(upgrade_to_float, name='iv')
810
811
class I1(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 1.

    Gradient: d/dx I_1(x) = (I_0(x) + I_2(x)) / 2.
    """
    nfunc_spec = ('scipy.special.i1', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.i1(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(I1, self).impl(x)
            return None
        return self.st_impl(x)

    def grad(self, inputs, grads):
        [x] = inputs
        [g_out] = grads
        neighbour_sum = i0(x) + iv(2, x)
        return [g_out * neighbour_sum / 2.]


i1 = I1(upgrade_to_float, name='i1')
834
835
class I0(UnaryScalarOp):
    """
    Modified Bessel function of the first kind of order 0.

    Gradient: d/dx I_0(x) = I_1(x).
    """
    nfunc_spec = ('scipy.special.i0', 1, 1)

    @staticmethod
    def st_impl(x):
        return scipy.special.i0(x)

    def impl(self, x):
        if not imported_scipy_special:
            # No SciPy: defer to the base class.
            super(I0, self).impl(x)
            return None
        return self.st_impl(x)

    def grad(self, inp, grads):
        [x] = inp
        [g_out] = grads
        return [g_out * i1(x)]


i0 = I0(upgrade_to_float, name='i0')
858