1"""
2.. warning::
3
    This directory is for the internals of Theano.

    You are strongly advised not to use it unless you know
    what you are doing!

    If you want to use a scalar variable in a Theano graph,
    you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
11"""
12from __future__ import absolute_import, print_function, division
13
14from itertools import chain
15import math
16import warnings
17from copy import copy
18from textwrap import dedent
19
20import numpy as np
21import six
22from six.moves import xrange
23
24import theano
25from theano.compat import imap, izip, Callable
26from theano import gof, printing
27from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
28                        FunctionGraph)
29from functools import partial
30from theano import config
31
32from theano.gradient import DisconnectedType
33from theano.gradient import grad_undefined
34
35from theano.printing import pprint
36
37builtin_bool = bool
38builtin_complex = complex
39builtin_int = int
40builtin_float = float
41
42
43class ComplexError(NotImplementedError):
44    """
45    Raised if complex numbers are used in an unsupported operation.
46
47    """
48    pass
49
50
51class IntegerDivisionError(Exception):
52    """
53    Raised if someone tries to divide integers with '/' instead of '//'.
54
55    """
56    pass
57
58
59def upcast(dtype, *dtypes):
    # This tries to keep data in floatX or lower precision, unless we
    # explicitly request a higher precision datatype.
62    keep_float32 = [(config.cast_policy == 'numpy+floatX' and
63                     config.floatX == 'float32')]
64    keep_float16 = [(config.cast_policy == 'numpy+floatX' and
65                     config.floatX == 'float16')]
66
67    def make_array(dt):
68        if dt == 'float64':
69            # There is an explicit float64 dtype: we cannot keep float32.
70            keep_float32[0] = False
71            keep_float16[0] = False
72        if dt == 'float32':
73            keep_float16[0] = False
74        return np.zeros((), dtype=dt)
75    z = make_array(dtype)
76    for dt in dtypes:
77        z = z + make_array(dt=dt)
78    rval = str(z.dtype)
79    if rval == 'float64':
80        if keep_float16[0]:
81            return 'float16'
82        if keep_float32[0]:
83            return 'float32'
84    elif rval == 'float32':
85        if keep_float16[0]:
86            return 'float16'
87    return rval
88
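# A few illustrative results (comments only, not executed at import). The
# floatX-dependent case is only a sketch: its outcome depends on the values
# of ``config.cast_policy`` and ``config.floatX``.
#
#   >>> upcast('int8', 'int16')
#   'int16'
#   >>> upcast('float32', 'int8')
#   'float32'
#   >>> upcast('float64', 'float32')  # an explicit float64 always wins
#   'float64'
#   >>> # With cast_policy == 'numpy+floatX' and floatX == 'float32',
#   >>> # upcast('float32', 'int64') returns 'float32' instead of the
#   >>> # 'float64' that plain NumPy promotion would give.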
89
90def as_common_dtype(*vars):
91    """
    Works for theano.scalar.Scalar and TensorVariable.
93    """
94    dtype = upcast(*[v.dtype for v in vars])
95    return (v.astype(dtype) for v in vars)
96
97
98def get_scalar_type(dtype):
99    """
100    Return a Scalar(dtype) object.
101
102    This caches objects to save allocation and run time.
103
104    """
105    if dtype not in get_scalar_type.cache:
106        get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
107    return get_scalar_type.cache[dtype]
108get_scalar_type.cache = {}
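# Because of this cache, asking twice for the same dtype returns the very
# same Scalar instance (illustrative, not executed at import):
#
#   >>> get_scalar_type('float64') is get_scalar_type('float64')
#   True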
109
110
111def as_scalar(x, name=None):
112    from ..tensor import TensorType, scalar_from_tensor
113    if isinstance(x, gof.Apply):
114        if len(x.outputs) != 1:
115            raise ValueError("It is ambiguous which output of a multi-output"
116                             " Op has to be fetched.", x)
117        else:
118            x = x.outputs[0]
119    if isinstance(x, Variable):
120        if isinstance(x.type, Scalar):
121            return x
122        elif isinstance(x.type, TensorType) and x.ndim == 0:
123            return scalar_from_tensor(x)
124        else:
125            raise TypeError("Variable type field must be a Scalar.", x, x.type)
126    try:
127        return constant(x)
128    except TypeError:
129        raise TypeError("Cannot convert %s to Scalar" % x, type(x))
130
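# Illustrative sketch of `as_scalar` (comments only, not executed at
# import): Scalar variables pass through unchanged, 0-d tensors go through
# `scalar_from_tensor`, and plain Python numbers become ScalarConstants.
#
#   >>> x = get_scalar_type('float64')()   # a fresh ScalarVariable
#   >>> as_scalar(x) is x
#   True
#   >>> as_scalar(2)   # a ScalarConstant whose dtype follows the cast policy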
131
132class NumpyAutocaster(object):
133    """
134    This class is used to cast python ints and floats to numpy arrays.
135
136    The behavior when called on scalar `x` depends on `config.cast_policy`:
137        - 'numpy' will simply use the same type as found by `numpy.asarray(x)`.
138        - 'numpy+floatX' will do the same, except it will use float32 instead
139          of float64 if `x` is a Python float and `config.floatX` is set to
140          'float32' (note that if `x` is a numpy scalar whose data type is
141          float64, it is not modified since we assume the user is purposely
142          using float64).
143        - 'custom' lets one define a tuple of data types such that:
144            - if `x` is already a numpy scalar and its data type is in this
145              tuple, then it is returned unchanged;
146            - otherwise, the first data type in this tuple that can represent
147              `x` without loss of precision will be used, unless `x` is a float
148              and 'float32' is in the tuple (in which case `x` is cast as a
149              float32);
150            - if no data type can represent `x` without loss of precision, then
151              the last data type in the tuple will be used.
152
153
154    Parameters
155    ----------
156    dtypes: tuple of strings
157        The ordered list of preferred data types (only used when
158        `config.cast_policy` is set to 'custom', see the `NumpyAutocaster`
159        help for details).
160
161    """
162
163    def __init__(self, dtypes):
164        self.dtypes = tuple(dtypes)
165
166    def __call__(self, x):
167        # Make sure we only deal with scalars.
168        assert (isinstance(x, six.integer_types) or
169                isinstance(x, builtin_float) or
170                (isinstance(x, np.ndarray) and x.ndim == 0))
171
172        if config.cast_policy == 'numpy':
173            return np.asarray(x)
174        elif config.cast_policy == 'numpy+floatX':
175            rval = np.asarray(x)
176            if ((not hasattr(x, 'dtype') and
177                 rval.dtype in ('float64', 'float32') and
178                 rval.dtype != config.floatX)):
179                rval = theano._asarray(rval, dtype=config.floatX)
180            return rval
181
182        # The following is the original code, corresponding to the 'custom'
183        # option for `config.cast_policy`.
184        assert config.cast_policy == 'custom'
185
186        try:
            # Pass through numpy scalars, since they are typically
            # already typed on purpose.
189            if str(x.dtype) in self.dtypes:
190                # No need to cast `x` into a new dtype. Note that we still
191                # need to convert it into an array, because it may not be
192                # one already (e.g. if x == numpy.float64(1.1)).
193                return np.asarray(x)
194        except AttributeError:
195            # Means `x` has no 'dtype' attribute.
196            pass
197
        # Unsafe downcast of float64 variables when config.floatX == 'float32'.
        # Recall that a Python float has float64 precision.
200        if ((isinstance(x, float) and
201             config.floatX in self.dtypes and
202             config.floatX != 'float64')):
203            return theano._asarray(x, dtype=config.floatX)
204
205        # Don't autocast to float16 unless config.floatX is float16
206        try_dtypes = [d for d in self.dtypes
207                      if config.floatX == 'float16' or d != 'float16']
208
209        for dtype in try_dtypes:
210            x_ = theano._asarray(x, dtype=dtype)
211            if np.all(x == x_):
212                break
        # Return either an exact x_ == x, or the last attempted cast x_.
214        return x_
215
216autocast_int = NumpyAutocaster(('int8', 'int16', 'int32', 'int64'))
217# autocast_float dtypes might be manipulated in tensor.*
218autocast_float = NumpyAutocaster(('float16', 'float32', 'float64'))
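# Illustrative behaviour under the default 'custom' cast policy, assuming
# config.floatX == 'float64' (comments only; other settings pick other
# dtypes):
#
#   >>> autocast_int(127).dtype        # smallest integer dtype that fits
#   dtype('int8')
#   >>> autocast_float(1.1).dtype      # 1.1 is not exact in float32
#   dtype('float64')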
219
220
221class autocast_float_as(object):
222    """
223    Temporarily adjust autocasting behavior.
224
225    This class makes it possible to temporarily and locally adjust autocasting
226    behavior when `config.cast_policy` is set to 'custom'.
227    If `config.cast_policy` is not 'custom', an exception is raised.
228    This class might be convenient in some code, but it definitely
229    helps to test the autocasting mechanism.
230
231    Examples
232    --------
233    >>> with autocast_float_as('float32'):
234    ...    assert (fvector() + 1.1).dtype == 'float32'  # temporary downcasting
235    >>> assert (fvector() + 1.1).dtype == 'float64' # back to default behaviour
236
237    """
238    def __init__(self, *dtypes):
239        self.dtypes = dtypes
240        assert config.cast_policy == 'custom'
241
242    def __enter__(self):
243        assert config.cast_policy == 'custom'
244        self.old_dtypes = autocast_float.dtypes
245        autocast_float.dtypes = self.dtypes
246
247    def __exit__(self, *args):
248        assert config.cast_policy == 'custom'
249        autocast_float.dtypes = self.old_dtypes
250
251
252def convert(x, dtype=None):
253    """
    Convert the input to a properly typed numpy value according to the
    current casting policy. Works with scalars and tensors.
256
257    """
258    if dtype is not None:
259        # in this case, the semantics are that the caller is forcing the dtype
260        x_ = theano._asarray(x, dtype=dtype)
261    else:
262        # In this case, this function should infer the dtype according to the
263        # autocasting rules. See autocasting above.
264        x_ = None
265        if isinstance(x, six.integer_types):
266            try:
267                x_ = autocast_int(x)
268            except OverflowError:
269                # This is to imitate numpy behavior which tries to fit
270                # bigger numbers into a uint64.
271                x_ = theano._asarray(x, dtype='uint64')
272        elif isinstance(x, builtin_float):
273            x_ = autocast_float(x)
274        elif isinstance(x, np.ndarray):
275            x_ = x
276        else:
277            # Here x is probably a list or a tuple. If it contains a
278            # long, we will behave like the current NumPy version: it
279            # will work if the long fits in int64 or uint64.
280            x_ = np.asarray(x)
281            if x_.size == 0 and not hasattr(x, 'dtype'):
282                x_ = np.asarray(x, dtype=config.floatX)
283    assert type(x_) in [np.ndarray, np.memmap]
284    return x_
285
286
287def constant(x, name=None, dtype=None):
288    x = convert(x, dtype=dtype)
289    assert x.ndim == 0
290    return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
291
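# Illustrative sketch of `constant` (comments only, not executed at
# import). Without an explicit dtype, the autocasting rules above pick one
# (here assuming the default 'custom' cast policy):
#
#   >>> constant(1).type.dtype
#   'int8'
#   >>> constant(1, dtype='float64').type.dtype
#   'float64'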
292
293class Scalar(Type):
294
295    """
296    Internal class, should not be used by clients.
297
298    Primarily used by tensor.elemwise and tensor.reduce.
299    Analogous to TensorType, but for zero-dimensional objects.
300    Maps directly to C primitives.
301
302    TODO: refactor to be named ScalarType for consistency with TensorType.
303
304    """
305    __props__ = ('dtype',)
306    ndim = 0
307
308    def __init__(self, dtype):
309        if dtype == 'floatX':
310            dtype = config.floatX
311        self.dtype = dtype
312        self.dtype_specs()  # error checking
313
314    @staticmethod
315    def may_share_memory(a, b):
        # This class represents a basic C type, represented in Python
        # by a numpy scalar. These values are read-only, so from Python
        # they can never share memory.
319        return False
320
321    def filter(self, data, strict=False, allow_downcast=None):
322        py_type = self.dtype_specs()[0]
323        if strict and not isinstance(data, py_type):
324            raise TypeError("%s expected a %s, got %s of type %s" % (
325                self, py_type, data, type(data)), data)
326        try:
327            converted_data = py_type(data)
328            if (allow_downcast or
329                    (allow_downcast is None and
330                        type(data) is float and
331                        self.dtype == theano.config.floatX) or
332                    data == converted_data):
333                return py_type(data)
334            else:
335                raise TypeError('Value cannot accurately be converted to dtype'
336                                ' (%s) and allow_downcast is not True' %
337                                self.dtype)
338        except Exception as e:
339            raise TypeError("Could not convert %s (value=%s) to %s" % (
340                type(data), data, self.dtype), e)
341
342    def values_eq_approx(self, a, b, tolerance=1e-4):
        # The arithmetic below has a risk of overflow, especially with [u]int8.
344        if self.dtype == 'bool':
345            return a == b
346        diff = a - b
347        if diff == 0:
348            return True
349        return abs(diff) <= (abs(a) * tolerance) + (abs(b) * tolerance)
350
351    def c_element_type(self):
352        return self.dtype_specs()[1]
353
354    def c_headers(self, c_compiler):
355        l = ['<math.h>']
        # These includes are needed by Scalar and TensorType;
        # we declare them here so they can be re-used by TensorType.
358        l.append('<numpy/arrayobject.h>')
359        l.append('<numpy/arrayscalars.h>')
360        if config.lib.amdlibm and c_compiler.supports_amdlibm:
361            l += ['<amdlibm.h>']
362        return l
363
364    def c_libraries(self, c_compiler):
365        l = []
366        if config.lib.amdlibm and c_compiler.supports_amdlibm:
367            l += ['amdlibm']
368        return l
369
370    def c_compile_args(self, c_compiler):
371        if config.lib.amdlibm and c_compiler.supports_amdlibm:
372            return ['-DREPLACE_WITH_AMDLIBM']
373        else:
374            return []
375
376    def dtype_specs(self):
377        try:
            # To help debug dtype/typenum problems, here is code to get
            # the list of numpy typenums.  This list changes between 32-
            # and 64-bit platforms, and probably also between Windows
            # and Linux.
            # NOTE: equivalent types on a platform can have different typenums.
            #     This is the source of all dtype/typenum problems found up to
            #     now, as Theano always expects the exact typenum that
            #     corresponds to our supported dtypes.
386            """
387            for dtype in ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc',
388                          'uintc',
389                          'longlong', 'ulonglong', 'single', 'double',
390                          'longdouble', 'csingle', 'cdouble', 'clongdouble',
391                          'float32', 'float64', 'int8', 'int16', 'int32',
392                          'int64', 'uint8', 'uint16', 'uint32', 'uint64',
393                          'complex64', 'complex128', 'float', 'double',
394                          'int', 'uint']:
395                print(dtype, np.zeros(1, dtype=dtype).dtype.num)
396            """
397            return {  # dtype: (py_type, c_type, cls_name)
398                'float16': (np.float16, 'npy_float16', 'Float16'),
399                'float32': (np.float32, 'npy_float32', 'Float32'),
400                'float64': (np.float64, 'npy_float64', 'Float64'),
401                'complex128': (np.complex128, 'theano_complex128',
402                               'Complex128'),
403                'complex64': (np.complex64, 'theano_complex64', 'Complex64'),
404                'bool': (np.bool_, 'npy_bool', 'Bool'),
405                'uint8': (np.uint8, 'npy_uint8', 'UInt8'),
406                'int8': (np.int8, 'npy_int8', 'Int8'),
407                'uint16': (np.uint16, 'npy_uint16', 'UInt16'),
408                'int16': (np.int16, 'npy_int16', 'Int16'),
409                'uint32': (np.uint32, 'npy_uint32', 'UInt32'),
410                'int32': (np.int32, 'npy_int32', 'Int32'),
411                'uint64': (np.uint64, 'npy_uint64', 'UInt64'),
412                'int64': (np.int64, 'npy_int64', 'Int64')
413            }[self.dtype]
414        except KeyError:
415            raise TypeError("Unsupported dtype for %s: %s" % (
416                self.__class__.__name__, self.dtype))
417
418    def upcast(self, *others):
419        return upcast(*[x.dtype for x in [self] + list(others)])
420
421    def make_variable(self, name=None):
422        return ScalarVariable(self, name=name)
423
424    def __str__(self):
425        return str(self.dtype)
426
427    def __repr__(self):
428        return "Scalar(%s)" % self.dtype
429
430    def c_literal(self, data):
431        if 'complex' in self.dtype:
432            raise NotImplementedError("No literal for complex values.")
433        if self.dtype == 'bool':
434            return '1' if data else '0'
435        return str(data)
436
437    def c_declare(self, name, sub, check_input=True):
        if check_input:
439            pre = """
440                typedef %(dtype)s dtype_%(name)s;
441            """ % dict(name=name, dtype=self.dtype_specs()[1])
442        else:
443            pre = ""
444        return pre + """
445        %(dtype)s %(name)s;
446        """ % dict(name=name, dtype=self.dtype_specs()[1])
447
448    def c_init(self, name, sub):
449        return """
450        %(name)s = 0;
451        """ % locals()
452
453    def c_extract(self, name, sub, check_input=True):
454        if self.dtype == 'float16':
455            # This doesn't work at the numpy level
456            raise NotImplementedError('float16')
457        specs = self.dtype_specs()
        if check_input:
459            pre = """
460            if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
461            {
462                PyErr_Format(PyExc_ValueError,
463                    "Scalar check failed (%(dtype)s)");
464                %(fail)s
465            }
466            """ % dict(sub,
467                       name=name,
468                       dtype=specs[1],
469                       pyarr_type='Py%sArrType_Type' % specs[2])
470        else:
471            pre = ""
472        return pre + """
473        PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
474        """ % dict(sub, name=name)
475
476    def c_sync(self, name, sub):
477        specs = self.dtype_specs()
478        return """
479        Py_XDECREF(py_%(name)s);
480        py_%(name)s = PyArrayScalar_New(%(cls)s);
481        if (!py_%(name)s)
482        {
483            Py_XINCREF(Py_None);
484            py_%(name)s = Py_None;
485            PyErr_Format(PyExc_MemoryError,
486                "Instantiation of new Python scalar failed (%(dtype)s)");
487            %(fail)s
488        }
489        PyArrayScalar_ASSIGN(py_%(name)s, %(cls)s, %(name)s);
490        """ % dict(sub,
491                   name=name,
492                   dtype=specs[1],
493                   cls=specs[2])
494
495    def c_cleanup(self, name, sub):
496        return ""
497
498    def c_support_code(self):
499
500        if self.dtype.startswith('complex'):
501            cplx_types = ['theano_complex64', 'theano_complex128']
502            real_types = ['npy_int8', 'npy_int16', 'npy_int32', 'npy_int64',
503                          'npy_float32', 'npy_float64']
504            # If the 'int' C type is not exactly the same as an existing
505            # 'npy_intX', some C code may not compile, e.g. when assigning
506            # the value 0 (cast to 'int' in C) to a theano_complex64.
507            if (np.dtype('intc').num not in
508                    [np.dtype(d[4:]).num for d in real_types]):
509                # In that case we add the 'int' type to the real types.
510                real_types.append('int')
511
512            template = """
513            struct theano_complex%(nbits)s : public npy_complex%(nbits)s
514            {
515                typedef theano_complex%(nbits)s complex_type;
516                typedef npy_float%(half_nbits)s scalar_type;
517
518                complex_type operator +(const complex_type &y) const {
519                    complex_type ret;
520                    ret.real = this->real + y.real;
521                    ret.imag = this->imag + y.imag;
522                    return ret;
523                }
524
525                complex_type operator -() const {
526                    complex_type ret;
527                    ret.real = -this->real;
528                    ret.imag = -this->imag;
529                    return ret;
530                }
531                bool operator ==(const complex_type &y) const {
532                    return (this->real == y.real) && (this->imag == y.imag);
533                }
534                bool operator ==(const scalar_type &y) const {
535                    return (this->real == y) && (this->imag == 0);
536                }
537                complex_type operator -(const complex_type &y) const {
538                    complex_type ret;
539                    ret.real = this->real - y.real;
540                    ret.imag = this->imag - y.imag;
541                    return ret;
542                }
543                complex_type operator *(const complex_type &y) const {
544                    complex_type ret;
545                    ret.real = this->real * y.real - this->imag * y.imag;
546                    ret.imag = this->real * y.imag + this->imag * y.real;
547                    return ret;
548                }
549                complex_type operator /(const complex_type &y) const {
550                    complex_type ret;
551                    scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
552                    ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
553                    ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
554                    return ret;
555                }
556                template <typename T>
557                complex_type& operator =(const T& y);
558
559                theano_complex%(nbits)s() {}
560
561                template <typename T>
562                theano_complex%(nbits)s(const T& y) { *this = y; }
563
564                template <typename TR, typename TI>
565                theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
566            };
567            """
568
569            def operator_eq_real(mytype, othertype):
570                return '''
571                template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
572                { this->real=y; this->imag=0; return *this; }
573                ''' % dict(mytype=mytype, othertype=othertype)
574
575            def operator_eq_cplx(mytype, othertype):
576                return '''
577                template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
578                { this->real=y.real; this->imag=y.imag; return *this; }
579                ''' % dict(mytype=mytype, othertype=othertype)
580
581            operator_eq = (''.join(operator_eq_real(ctype, rtype)
582                                   for ctype in cplx_types
583                                   for rtype in real_types) +
584                           ''.join(operator_eq_cplx(ctype1, ctype2)
585                                   for ctype1 in cplx_types
586                                   for ctype2 in cplx_types))
587
588            # We are not using C++ generic templating here, because this would
589            # generate two different functions for adding a complex64 and a
590            # complex128, one returning a complex64, the other a complex128,
591            # and the compiler complains it is ambiguous.
592            # Instead, we generate code for known and safe types only.
593
594            def operator_plus_real(mytype, othertype):
595                return '''
596                const %(mytype)s operator+(const %(mytype)s &x, const %(othertype)s &y)
597                { return %(mytype)s(x.real+y, x.imag); }
598
599                const %(mytype)s operator+(const %(othertype)s &y, const %(mytype)s &x)
600                { return %(mytype)s(x.real+y, x.imag); }
601                ''' % dict(mytype=mytype, othertype=othertype)
602
603            operator_plus = ''.join(operator_plus_real(ctype, rtype)
604                                    for ctype in cplx_types
605                                    for rtype in real_types)
606
607            def operator_minus_real(mytype, othertype):
608                return '''
609                const %(mytype)s operator-(const %(mytype)s &x, const %(othertype)s &y)
610                { return %(mytype)s(x.real-y, x.imag); }
611
612                const %(mytype)s operator-(const %(othertype)s &y, const %(mytype)s &x)
613                { return %(mytype)s(y-x.real, -x.imag); }
614                ''' % dict(mytype=mytype, othertype=othertype)
615
616            operator_minus = ''.join(operator_minus_real(ctype, rtype)
617                                     for ctype in cplx_types
618                                     for rtype in real_types)
619
620            def operator_mul_real(mytype, othertype):
621                return '''
622                const %(mytype)s operator*(const %(mytype)s &x, const %(othertype)s &y)
623                { return %(mytype)s(x.real*y, x.imag*y); }
624
625                const %(mytype)s operator*(const %(othertype)s &y, const %(mytype)s &x)
626                { return %(mytype)s(x.real*y, x.imag*y); }
627                ''' % dict(mytype=mytype, othertype=othertype)
628
629            operator_mul = ''.join(operator_mul_real(ctype, rtype)
630                                   for ctype in cplx_types
631                                   for rtype in real_types)
632
633            return (template % dict(nbits=64, half_nbits=32) +
634                    template % dict(nbits=128, half_nbits=64) +
635                    operator_eq +
636                    operator_plus +
637                    operator_minus +
638                    operator_mul)
639
640        else:
641            return ""
642
643    def c_init_code(self):
644        return ["import_array();"]
645
646    def c_code_cache_version(self):
647        return (13, np.__version__)
648
649    def get_shape_info(self, obj):
650        return obj.itemsize
651
652    def get_size(self, shape_info):
653        return shape_info
654
655# Register C code for ViewOp on Scalars.
656theano.compile.register_view_op_c_code(
657    Scalar,
658    """
659    %(oname)s = %(iname)s;
660    """,
661    1)
662
663
664bool = get_scalar_type('bool')
665int8 = get_scalar_type('int8')
666int16 = get_scalar_type('int16')
667int32 = get_scalar_type('int32')
668int64 = get_scalar_type('int64')
669uint8 = get_scalar_type('uint8')
670uint16 = get_scalar_type('uint16')
671uint32 = get_scalar_type('uint32')
672uint64 = get_scalar_type('uint64')
673float16 = get_scalar_type('float16')
674float32 = get_scalar_type('float32')
675float64 = get_scalar_type('float64')
676complex64 = get_scalar_type('complex64')
677complex128 = get_scalar_type('complex128')
678
679int_types = int8, int16, int32, int64
680uint_types = uint8, uint16, uint32, uint64
681float_types = float16, float32, float64
682complex_types = complex64, complex128
683
684integer_types = int_types + uint_types
685discrete_types = (bool,) + integer_types
686continuous_types = float_types + complex_types
687all_types = discrete_types + continuous_types
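# Illustrative use of the predefined Scalar instances above (comments only,
# not executed at import):
#
#   >>> int8.upcast(float32)          # Scalar.upcast defers to upcast()
#   'float32'
#   >>> x = float64.make_variable('x')
#   >>> x.type
#   Scalar(float64)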
688
689
690class _scalar_py_operators:
691    # So that we can simplify checking code when we have a mixture of Scalar
692    # variables and Tensor variables
693    ndim = 0
694
695    dtype = property(lambda self: self.type.dtype)
696    """The dtype of this scalar."""
697
698    # UNARY
699    def __abs__(self):
700        return abs_(self)
701
702    def __neg__(self):
703        return neg(self)
704
705    # CASTS
706    # def __int__(self): return AsInt(self).out
707    # def __float__(self): return AsDouble(self).out
708    # def __complex__(self): return AsComplex(self).out
709
710    # BITWISE
711    def __invert__(self):
712        return invert(self)
713
714    def __and__(self, other):
715        return and_(self, other)
716
717    def __or__(self, other):
718        return or_(self, other)
719
720    def __xor__(self, other):
721        return xor(self, other)
722
723    def __rand__(self, other):
724        return and_(other, self)
725
726    def __ror__(self, other):
727        return or_(other, self)
728
729    def __rxor__(self, other):
730        return xor(other, self)
731
732    # COMPARISONS
733    def __lt__(self, other):
734        return lt(self, other)
735
736    def __le__(self, other):
737        return le(self, other)
738
739    def __gt__(self, other):
740        return gt(self, other)
741
742    def __ge__(self, other):
743        return ge(self, other)
744
745    # ARITHMETIC - NORMAL
746    def __add__(self, other):
747        return add(self, other)
748
749    def __sub__(self, other):
750        return sub(self, other)
751
752    def __mul__(self, other):
753        return mul(self, other)
754
755    def __truediv__(self, other):
756        return div_proxy(self, other)
757
758    def __div__(self, other):
759        return div_proxy(self, other)
760
761    def __floordiv__(self, other):
762        return int_div(self, other)
763
764    def __mod__(self, other):
765        return mod_check(self, other)
766
767    def __pow__(self, other):
768        return pow(self, other)
769
770    # ARITHMETIC - RIGHT-OPERAND
771    def __radd__(self, other):
772        return add(other, self)
773
774    def __rsub__(self, other):
775        return sub(other, self)
776
777    def __rmul__(self, other):
778        return mul(other, self)
779
780    def __rdiv__(self, other):
781        return div_proxy(other, self)
782
783    def __rmod__(self, other):
784        return mod(other, self)
785
786    def __rpow__(self, other):
787        return pow(other, self)
788
789    def zeros_like(self, dtype=None):
        # Using `second` here is needed for Elemwise ops to work right.
791        if dtype is None:
792            dtype = str(self.type.dtype)
793        return second(self, ScalarConstant(get_scalar_type(dtype), 0))
794
795    def ones_like(self, dtype=None):
        # Using `second` here is needed for Elemwise ops to work right.
797        if dtype is None:
798            dtype = str(self.type.dtype)
799        return second(self, ScalarConstant(get_scalar_type(dtype), 1))
800
801    def astype(self, dtype):
802        return cast(self, dtype)
803
804
805class ScalarVariable(_scalar_py_operators, Variable):
806    pass
807
808
809class ScalarConstant(_scalar_py_operators, Constant):
810    pass
811
812# Register ScalarConstant as the type of Constant corresponding to Scalar
813Scalar.Constant = ScalarConstant
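# Illustrative sketch of the operator support added by _scalar_py_operators
# (comments only, not executed at import): the usual Python operators build
# scalar graphs through the ops defined later in this file.
#
#   >>> x = float64('x')
#   >>> y = float64('y')
#   >>> z = x * y + x ** 2
#   >>> z.dtype
#   'float64'
#   >>> (x < y).dtype                 # comparisons produce bool scalars
#   'bool'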
814
815
816# Easy constructors
817
818def _multi(*fns):
819    def f2(f, names):
820        if len(names) == 1:
821            return f(names)
822        else:
823            return [f(name) for name in names]
824    if len(fns) == 1:
825        return partial(f2, fns[0])
826    else:
827        return [partial(f2, f) for f in fns]
828
829ints = _multi(int64)
830floats = _multi(float64)
831complexs = _multi(complex128)
832complexs64 = _multi(complex64)
833complexs128 = _multi(complex128)
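# Illustrative use of the convenience constructors above (comments only,
# not executed at import): a single name gives one variable, several names
# give a list.
#
#   >>> x, y = floats('xy')           # two float64 ScalarVariables
#   >>> n = ints('n')                 # one int64 ScalarVariable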
834
835
836def upcast_out(*types):
837    dtype = Scalar.upcast(*types)
838    return get_scalar_type(dtype),
839
840
841def upcast_out_nobool(*types):
842    type = upcast_out(*types)
843    if type[0] == bool:
844        raise TypeError("bool output not supported")
845    return type
846
847
848def upcast_out_min8(*types):
849    type = upcast_out(*types)
850    if type[0] == bool:
851        return int8,
852    return type
853
854
855def upgrade_to_float(*types):
856    """
857    Upgrade any int types to float32 or float64 to avoid losing precision.
858
859    """
860    conv = {bool: float32,
861            int8: float32,
862            int16: float32,
863            int32: float64,
864            int64: float64,
865            uint8: float32,
866            uint16: float32,
867            uint32: float64,
868            uint64: float64}
869    return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
870                                           for type in types])),
871
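# Illustrative results of the mapping above (comments only, not executed at
# import): narrow integer types go to float32, wide ones to float64, and
# float inputs are resolved by the usual upcast.
#
#   >>> upgrade_to_float(int8)
#   (Scalar(float32),)
#   >>> upgrade_to_float(int64, float32)
#   (Scalar(float64),)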
872
873def upgrade_to_float64(*types):
874    """
    Upgrade any int and float32 to float64, to match SciPy's behavior.
876
877    """
878    return get_scalar_type('float64'),
879
880
881def same_out(type):
882    return type,
883
884
885def same_out_nobool(type):
886    if type == bool:
887        raise TypeError("bool input not supported")
888    return type,
889
890
891def same_out_min8(type):
892    if type == bool:
893        return int8,
894    return type,
895
896
897def upcast_out_no_complex(*types):
    if any(type in complex_types for type in types):
        raise TypeError('complex types are not supported')
900    return get_scalar_type(dtype=Scalar.upcast(*types)),
901
902
903def same_out_float_only(type):
904    if type not in float_types:
        raise TypeError('only float types are supported')
906    return type,
907
908
909class transfer_type(gof.utils.object2):
910    __props__ = ('transfer',)
911
912    def __init__(self, *transfer):
913        assert all(type(x) in [int, str] or x is None for x in transfer)
914        self.transfer = transfer
915
916    def __str__(self):
917        return 'transfer_type{%s}' % self.transfer
918
919    def __call__(self, *types):
920        upcast = upcast_out(*types)
921        retval = []
922        for i in self.transfer:
923            if i is None:
924                retval += [upcast]
925            elif isinstance(i, str):
926                retval += [i]
927            else:
928                retval += [types[i]]
929        return retval
930        # return [upcast if i is None else types[i] for i in self.transfer]
931
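# Illustrative sketch of `transfer_type` (comments only, not executed at
# import): each entry of `transfer` selects either the upcast of all input
# types (None), a fixed dtype name (a str) or the type of the i-th input
# (an int).
#
#   >>> f = transfer_type(0)          # output type follows the first input
#   >>> f(float32, int64)
#   [Scalar(float32)]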
932
933class specific_out(gof.utils.object2):
934    __props__ = ('spec',)
935
936    def __init__(self, *spec):
937        self.spec = spec
938
939    def __call__(self, *types):
940        return self.spec
941
942
943def int_out(*types):
944    return int64,
945
946
947def float_out(*types):
948    return float64,
949
950
951def upgrade_to_float_no_complex(*types):
952    """
953    Don't accept complex, otherwise call upgrade_to_float().
954
955    """
956    for type in types:
957        if type in complex_types:
958            raise TypeError('complex argument not supported')
959    return upgrade_to_float(*types)
960
961
962def same_out_nocomplex(type):
963    if type in complex_types:
964        raise TypeError('complex argument not supported')
965    return type,
966
967
968def int_out_nocomplex(*types):
969    for type in types:
970        if type in complex_types:
971            raise TypeError('complex argument not supported')
972    return int64,
973
974
975def float_out_nocomplex(*types):
976    for type in types:
977        if type in complex_types:
978            raise TypeError('complex argument not supported')
979    return float64,
980
981
982class unary_out_lookup(gof.utils.object2):
983    """
    Get an output_types_preference object by passing a dictionary:

    unary_out_lookup({int8:int32, float32:complex128})

    The result is an op that maps int8 to int32 and float32 to
    complex128; any other input type leads to a TypeError.
990
991    """
992    def __init__(self, type_table):
993        self.tbl = type_table
994
995    def __call__(self, *types):
996        if len(types) == 1:
997            types = types[0]
998        try:
999            rval = self.tbl[types]
1000        except Exception:
1001            raise TypeError(types)
1002        if isinstance(types, (list, tuple)):
1003            return rval
1004        else:
1005            return [rval]
1006
1007    def __eq__(self, other):
1008        return type(self) == type(other) and self.tbl == other.tbl
1009
1010    def __hash__(self):
1011        return hash(type(self))  # ignore hash of table
1012
1013
1014def real_out(type):
1015    if type == complex64:
1016        return float32,
1017    if type == complex128:
1018        return float64,
1019    return type,
1020
1021
1022class ScalarOp(Op):
1023
1024    nin = -1
1025    nout = 1
1026
1027    def __init__(self, output_types_preference=None, name=None):
1028        self.name = name
1029        if output_types_preference is not None:
1030            if not isinstance(output_types_preference, Callable):
1031                raise TypeError(
1032                    "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
1033                    (self.__class__, output_types_preference))
1034            self.output_types_preference = output_types_preference
1035
1036    def make_node(self, *inputs):
1037        if self.nin >= 0:
1038            if len(inputs) != self.nin:
1039                raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
1040                                (self, len(inputs), str(inputs), self.nin))
1041        inputs = [as_scalar(input) for input in inputs]
1042        outputs = [t() for t in self.output_types([input.type
1043                                                   for input in inputs])]
1044        if len(outputs) != self.nout:
1045            raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
1046                            % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
1047        return Apply(self, inputs, outputs)
1048
1049    def output_types(self, types):
1050        if hasattr(self, 'output_types_preference'):
1051            variables = self.output_types_preference(*types)
1052            if not isinstance(variables, (list, tuple)) or any(not isinstance(x, Type) for x in variables):
1053                raise TypeError(
1054                    "output_types_preference should return a list or a tuple of types", self.output_types_preference, variables)
1055            if len(variables) != self.nout:
1056                raise TypeError("Not the right number of outputs types produced for %s(%s) by %s. Expected %s, got %s."
1057                                % (self, ", ".join(str(type) for type in variables),
1058                                   self.output_types_preference, self.nout, len(variables)))
1059            return variables
1060        else:
1061            raise NotImplementedError(
1062                "Cannot calculate the output types for %s" % self)
1063
1064    def perform(self, node, inputs, output_storage):
1065        if self.nout == 1:
1066            output_storage[0][0] = self.impl(*inputs)
1067        else:
1068            variables = utils.from_return_values(self.impl(*inputs))
1069            assert len(variables) == len(output_storage)
1070            for storage, variable in zip(output_storage, variables):
1071                storage[0] = variable
1072
1073    def impl(self, *inputs):
1074        raise utils.MethodNotDefined("impl", type(self),
1075                                     self.__class__.__name__)
1076
1077    def grad(self, inputs, output_gradients):
1078        raise utils.MethodNotDefined("grad", type(self),
1079                                     self.__class__.__name__)
1080
1081    def L_op(self, inputs, outputs, output_gradients):
1082        return self.grad(inputs, output_gradients)
1083
1084    def __eq__(self, other):
1085        test = (type(self) == type(other) and
1086                getattr(self, 'output_types_preference', None) ==
1087                getattr(other, 'output_types_preference', None))
1088        return test
1089
1090    def __hash__(self):
1091        return hash((type(self),
1092                     getattr(self, 'output_types_preference', 0)))
1093
1094    def __str__(self):
1095        if hasattr(self, 'name') and self.name:
1096            return self.name
1097        else:
1098            param = [(k, v) for k, v in self.__dict__.items()
1099                     if k not in ["name", "_op_use_c_code", "bool",
1100                                  "output_types_preference"]]
1101            if param:
1102                return "%s{%s}" % (self.__class__.__name__,
1103                                   ", ".join("%s=%s" % (k, v)
1104                                             for k, v in param))
1105            else:
1106                return self.__class__.__name__
1107
1108    def c_code_cache_version(self):
1109        return (4,)
1110
1111    def c_code_contiguous(self, node, name, inp, out, sub):
1112        """
        This function is called by Elemwise when all inputs and outputs are
        c_contiguous. This allows the SIMD version of this op to be used.
1115
1116        The inputs are the same as c_code except that:
1117
1118            - inp and out must be the names of the variables associated to the
1119              ndarrays in the C code
1120            - node must be the elemwise node (this is needed to know
1121              the inputs/outputs types)
1122
1123        """
1124        raise theano.gof.utils.MethodNotDefined()
1125
1126    def supports_c_code(self, inputs, outputs):
1127        """Returns True if the current op has functioning C code for
1128        the given Elemwise inputs, outputs.
1129
1130        """
1131        try:
1132            tmp_s_input = []
1133            # To keep the same aliasing between inputs
1134            mapping = dict()
1135            for ii in inputs:
1136                if ii in mapping:
1137                    tmp_s_input.append(mapping[ii])
1138                else:
1139                    tmp = get_scalar_type(ii.dtype).make_variable()
1140                    tmp_s_input.append(tmp)
1141                    mapping[ii] = tmp_s_input[-1]
1142
1143            with theano.change_flags(compute_test_value='ignore'):
1144                s_op = self(*tmp_s_input, return_list=True)
1145
            # If the scalar_op doesn't have a C implementation,
            # we skip its fusion to allow the fusion of the
            # other ops.
1149            self.c_code(s_op[0].owner,
1150                        "test_presence_of_c_code",
1151                        ["x" for x in inputs],
1152                        ["z" for z in outputs],
1153                        {"fail": "%(fail)s"})
1154        except (theano.gof.utils.MethodNotDefined, NotImplementedError):
1155            return False
1156        return True
1157
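# A minimal, hypothetical ScalarOp subclass, shown only to illustrate what
# make_node/output_types/perform above expect from a concrete op. `Square`
# is not part of this module; `same_out` and `UnaryScalarOp` (defined just
# below) are.
#
#   class Square(UnaryScalarOp):
#       def impl(self, x):
#           return x * x
#
#       def c_code(self, node, name, inputs, outputs, sub):
#           (x,), (z,) = inputs, outputs
#           return "%(z)s = %(x)s * %(x)s;" % locals()
#
#   square = Square(same_out, name='square')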
1158
1159class UnaryScalarOp(ScalarOp):
1160    nin = 1
1161    amd_float32 = None
1162    amd_float64 = None
1163
1164    def c_code_contiguous(self, node, name, inputs, outputs, sub):
1165        (x,) = inputs
1166        (z,) = outputs
1167        if (not theano.config.lib.amdlibm or
1168                # We compare the dtype AND the broadcast flag
                # as this function does not broadcast
1170                node.inputs[0].type != node.outputs[0].type):
1171            raise theano.gof.utils.MethodNotDefined()
1172
1173        dtype = node.inputs[0].type.dtype_specs()[1]
1174        fct_call = self.c_code_contiguous_raw(dtype, 'n', 'x', 'z')
1175        return """
1176{
1177        npy_intp n = PyArray_SIZE(%(z)s);
1178        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
1179        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
1180        %(fct_call)s;
1181}
1182        """ % locals()
1183
1184    def c_code_contiguous_raw(self, dtype, n, i, o):
1185        if not config.lib.amdlibm:
1186            raise theano.gof.utils.MethodNotDefined()
1187        if dtype.startswith('npy_'):
1188            dtype = dtype[4:]
1189        if dtype == 'float32' and self.amd_float32 is not None:
1190            dtype = 'float'
1191            fct = self.amd_float32
1192        elif dtype == 'float64' and self.amd_float64 is not None:
1193            dtype = 'double'
1194            fct = self.amd_float64
1195        else:
1196            raise theano.gof.utils.MethodNotDefined()
1197        return "%(fct)s(%(n)s, %(i)s, %(o)s)" % locals()
1198
1199
1200class BinaryScalarOp(ScalarOp):
1201    # One may define in subclasses the following fields:
1202    #   - `identity`: for an associative operation, identity corresponds to
1203    #     the neutral element. For instance, it will be 0 for addition, 1 for
1204    #     multiplication, True for "and", False for "or".
1205    #   - `commutative`: whether op(a, b) == op(b, a)
1206    #   - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
1207    nin = 2
1208
1209
1210###############
1211# Comparisons
1212###############
1213
1214class LogicalComparison(BinaryScalarOp):
1215    def __init__(self, *args, **kwargs):
1216        BinaryScalarOp.__init__(self, *args, **kwargs)
1217        # This is for compat with old pickles.
1218        self.bool = True
1219
1220    def __eq__(self, other):
1221        return (BinaryScalarOp.__eq__(self, other) and
1222                getattr(self, 'bool', False) == getattr(other, 'bool', False))
1223
1224    def __hash__(self):
1225        # bool should always be True
1226        return BinaryScalarOp.__hash__(self)
1227
1228    def output_types(self, *input_dtypes):
1229        return [bool] if getattr(self, 'bool', False) else [int8]
1230
1231    def L_op(self, inputs, outputs, output_gradients):
1232        x, y = inputs
1233        assert outputs[0].type == bool
1234        return [x.zeros_like().astype(theano.config.floatX),
1235                y.zeros_like().astype(theano.config.floatX)]
1236
1237    def c_code_cache_version(self):
1238        super_version = super(LogicalComparison, self).c_code_cache_version()
1239        return super_version + (0,)
1240
1241
1242class FixedLogicalComparison(UnaryScalarOp):
1243    """
1244    Comparison to a fixed value.
1245
1246    """
1247    def __init__(self, *args, **kwargs):
1248        UnaryScalarOp.__init__(self, *args, **kwargs)
1249        # This is for compat with old pickles
1250        self.bool = True
1251
1252    def __eq__(self, other):
1253        return (UnaryScalarOp.__eq__(self, other) and
1254                getattr(self, 'bool', False) == getattr(other, 'bool', False))
1255
1256    def __hash__(self):
1257        # bool should always be True
1258        return UnaryScalarOp.__hash__(self)
1259
1260    def output_types(self, *input_dtypes):
1261        return [bool] if getattr(self, 'bool', False) else [int8]
1262
1263    def L_op(self, inputs, outputs, output_gradients):
1264        x, = inputs
1265        assert outputs[0].type == bool
1266        return [x.zeros_like().astype(theano.config.floatX)]
1267
1268    def c_code_cache_version(self):
1269        super_version = super(FixedLogicalComparison, self).c_code_cache_version()
1270        return super_version + (0,)
1271
1272
1273class LT(LogicalComparison):
1274    identity = False
1275    commutative = False
1276    associative = False
1277    nfunc_spec = ('less', 2, 1)
1278
1279    def impl(self, x, y):
        # The built-in < does not support complex.
1281        return np.less(x, y)
1282
1283    def c_code(self, node, name, inputs, outputs, sub):
1284        (x, y) = inputs
1285        (z,) = outputs
1286        if node.inputs[0].type in complex_types:
1287            raise NotImplementedError()
1288        return "%(z)s = (%(x)s < %(y)s);" % locals()
1289lt = LT()
1290
1291
1292class GT(LogicalComparison):
1293    identity = False
1294    commutative = False
1295    associative = False
1296    nfunc_spec = ('greater', 2, 1)
1297
1298    def impl(self, x, y):
        # The built-in > does not support complex.
1300        return np.greater(x, y)
1301
1302    def c_code(self, node, name, inputs, outputs, sub):
1303        (x, y) = inputs
1304        (z,) = outputs
1305        if node.inputs[0].type in complex_types:
1306            raise NotImplementedError()
1307        return "%(z)s = (%(x)s > %(y)s);" % locals()
1308gt = GT()
1309
1310
1311class LE(LogicalComparison):
1312    identity = False
1313    commutative = False
1314    associative = False
1315    nfunc_spec = ('less_equal', 2, 1)
1316
1317    def impl(self, x, y):
        # The built-in <= does not support complex.
1319        return np.less_equal(x, y)
1320
1321    def c_code(self, node, name, inputs, outputs, sub):
1322        (x, y) = inputs
1323        (z,) = outputs
1324        if node.inputs[0].type in complex_types:
1325            raise NotImplementedError()
1326        return "%(z)s = (%(x)s <= %(y)s);" % locals()
1327le = LE()
1328
1329
1330class GE(LogicalComparison):
1331    identity = False
1332    commutative = False
1333    associative = False
1334    nfunc_spec = ('greater_equal', 2, 1)
1335
1336    def impl(self, x, y):
        # The built-in >= does not support complex.
1338        return np.greater_equal(x, y)
1339
1340    def c_code(self, node, name, inputs, outputs, sub):
1341        (x, y) = inputs
1342        (z,) = outputs
1343        if node.inputs[0].type in complex_types:
1344            raise NotImplementedError()
1345        return "%(z)s = (%(x)s >= %(y)s);" % locals()
1346ge = GE()
1347
1348
1349class EQ(LogicalComparison):
1350    identity = False
1351    commutative = True
1352    associative = False
1353    nfunc_spec = ('equal', 2, 1)
1354
1355    def impl(self, x, y):
1356        return x == y
1357
1358    def c_code(self, node, name, inputs, outputs, sub):
1359        (x, y) = inputs
1360        (z,) = outputs
1361        return "%(z)s = (%(x)s == %(y)s);" % locals()
1362eq = EQ()
1363
1364
1365class NEQ(LogicalComparison):
1366    identity = False
1367    commutative = True
1368    associative = False
1369    nfunc_spec = ('not_equal', 2, 1)
1370
1371    def impl(self, x, y):
1372        return x != y
1373
1374    def c_code(self, node, name, inputs, outputs, sub):
1375        (x, y) = inputs
1376        (z,) = outputs
1377        if node.inputs[0].type in complex_types:
1378            raise NotImplementedError()
1379        return "%(z)s = (%(x)s != %(y)s);" % locals()
1380neq = NEQ()
1381
1382
1383class IsNan(FixedLogicalComparison):
1384    nfunc_spec = ('isnan', 1, 1)
1385
1386    def impl(self, x):
1387        return np.isnan(x)
1388
1389    def c_code(self, node, name, inputs, outputs, sub):
1390        (x,) = inputs
1391        (z,) = outputs
1392        if node.inputs[0].type in complex_types:
1393            raise NotImplementedError()
        # Discrete types can never be NaN.
1395        if node.inputs[0].type in discrete_types:
1396            return "%(z)s = false;" % locals()
1397
        # Windows tries to be different and sometimes returns -1, but we want
        # to be consistent with numpy (which returns True), hence the "abs".
1400        return "%(z)s = abs(isnan(%(x)s));" % locals()
1401
1402    def c_code_cache_version(self):
1403        scalarop_version = super(IsNan, self).c_code_cache_version()
1404        return tuple(scalarop_version) + (3,)
1405isnan = IsNan()
1406
1407
1408class IsInf(FixedLogicalComparison):
1409    nfunc_spec = ('isinf', 1, 1)
1410
1411    def impl(self, x):
1412        return np.isinf(x)
1413
1414    def c_code(self, node, name, inputs, outputs, sub):
1415        (x,) = inputs
1416        (z,) = outputs
1417        if node.inputs[0].type in complex_types:
1418            raise NotImplementedError()
        # Discrete types can never be inf.
1420        if node.inputs[0].type in discrete_types:
1421            return "%(z)s = false;" % locals()
1422
1423        # Note that the C isinf returns -1 for -Inf and +1 for +Inf, while
1424        # numpy simply returns True: we mimic numpy's behavior here, thus
1425        # the absolute value.
1426        return "%(z)s = abs(isinf(%(x)s));" % locals()
1427
1428    def c_code_cache_version(self):
1429        scalarop_version = super(IsInf, self).c_code_cache_version()
1430        return tuple(scalarop_version) + (3,)
1431isinf = IsInf()
1432
1433
1434class InRange(LogicalComparison):
1435    nin = 3
1436
1437    def __init__(self, openlow, openhi):
1438        self.openlow = openlow
1439        self.openhi = openhi
1440
1441    def impl(self, x, low, hi):
1442        if self.openlow and x <= low:
1443            return False
1444        elif not self.openlow and x < low:
1445            return False
1446        if self.openhi and x >= hi:
1447            return False
1448        elif not self.openhi and x > hi:
1449            return False
1450        return True
1451
1452    def c_code(self, node, name, inputs, outputs, sub):
1453        (x, low, hi) = inputs
1454        (z,) = outputs
1455
1456        cmp1 = '>' if self.openlow else '>='
1457        cmp2 = '<' if self.openhi else '<='
1458
1459        return ("%(z)s = %(x)s %(cmp1)s %(low)s &&"
1460                " %(x)s %(cmp2)s %(hi)s;" % locals())
1461
1462    def get_grad(self, elem):
1463        if elem.type in complex_types:
1464            msg = ("No gradient implemented for complex numbers in "
1465                   "class scalar.basic.InRange")
1466            raise NotImplementedError(msg)
1467        elif elem.type in discrete_types:
1468            return elem.zeros_like().astype(theano.config.floatX)
1469        else:
1470            return elem.zeros_like()
1471
1472    def L_op(self, inputs, outputs, gout):
1473        (x, low, hi) = inputs
1474        (gz,) = gout
1475        grads = []
1476        for elem in [x, low, hi]:
1477            grads.append(self.get_grad(elem))
1478        return grads
1479
1480inopenrange = InRange(True, True)
1481inclosedrange = InRange(False, False)
1482
1483
1484class Switch(ScalarOp):
1485    nin = 3
1486    nfunc_spec = ('where', 3, 1)
1487
1488    def impl(self, cond, ift, iff):
1489        return ift if cond else iff
1490
1491    def c_code(self, node, name, inputs, outputs, sub):
1492        (cond, ift, iff) = inputs
1493        (z,) = outputs
1494        return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
1495
1496    def L_op(self, inputs, outputs, gout):
1497        (cond, ift, iff) = inputs
1498        (gz,) = gout
1499        first_part = switch(cond, gz, 0.)
1500        second_part = switch(cond, 0., gz)
1501
1502        if (outputs[0].type.dtype in discrete_types):
1503            first_part = 0.
1504            second_part = 0.
1505
        # cond does affect the elements of the output, so it is connected.
        # For the sake of making the gradient convenient, we assume that
        # condition + epsilon always triggers the same branch as condition.
1509        condition_grad = cond.zeros_like().astype(theano.config.floatX)
1510
1511        return (condition_grad, first_part, second_part)
1512
1513    def output_types(self, types):
1514        (cond_t, ift_t, iff_t) = types
1515        return upcast_out(ift_t, iff_t)
1516switch = Switch()
1517
1518####################
1519# BIT-WISE OPERATORS
1520####################
1521
1522
1523class UnaryBitOp(UnaryScalarOp):
1524    def output_types(self, *input_types):
1525        for i in input_types[0]:
1526            if i not in discrete_types:
                raise TypeError('input to a BitOp must have type (u)int8, '
                                '(u)int16, (u)int32, (u)int64 or bool, not %s' % i)
1529        return upcast_out(*input_types[0])
1530
1531    def grad(self, inputs, output_gradients):
1532        return [inputs[0].zeros_like().astype(theano.config.floatX)]
1533
1534
1535class BinaryBitOp(BinaryScalarOp):
1536    def output_types(self, *input_types):
1537        t0, t1 = input_types[0]
1538        if t0 == bool and t1 == bool:
1539            return [bool]
1540        for i in input_types[0]:
1541            if i not in integer_types:
                raise TypeError('input to a BitOp must have type (u)int8, '
                                '(u)int16, (u)int32 or (u)int64 or '
                                'be all bools, not %s' % i)
1545        return upcast_out(*input_types[0])
1546
1547    def grad(self, inputs, output_gradients):
1548        a, b = inputs
1549        return [a.zeros_like().astype(theano.config.floatX),
1550                b.zeros_like().astype(theano.config.floatX)]
1551
1552
1553class OR(BinaryBitOp):
1554    identity = 0
1555    commutative = True
1556    associative = True
1557    nfunc_spec = ('bitwise_or', 2, 1)
1558
1559    def impl(self, x, y):
1560        return x | y
1561
1562    def c_code(self, node, name, inputs, outputs, sub):
1563        (x, y) = inputs
1564        (z,) = outputs
1565        return "%(z)s = (%(x)s | %(y)s);" % locals()
1566or_ = OR()
1567
1568
1569class XOR(BinaryBitOp):
1570    identity = 0
1571    commutative = True
1572    associative = True
1573    nfunc_spec = ('bitwise_xor', 2, 1)
1574
1575    def impl(self, x, y):
1576        return x ^ y
1577
1578    def c_code(self, node, name, inputs, outputs, sub):
1579        (x, y) = inputs
1580        (z,) = outputs
1581        return "%(z)s = (%(x)s ^ %(y)s);" % locals()
1582xor = XOR()
1583
1584
1585class AND(BinaryBitOp):
1586    identity = -1
1587    commutative = True
1588    associative = True
1589    nfunc_spec = ('bitwise_and', 2, 1)
1590
1591    def impl(self, x, y):
1592        return x & y
1593
1594    def c_code(self, node, name, inputs, outputs, sub):
1595        (x, y) = inputs
1596        (z,) = outputs
1597        return "%(z)s = (%(x)s & %(y)s);" % locals()
1598
1599    def c_code_cache_version(self):
1600        super_version = super(AND, self).c_code_cache_version()
1601        return super_version + (3,)
1602and_ = AND()
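
# The `identity` attributes above follow from the bit patterns involved:
# 0 has no bits set, so x | 0 == x and x ^ 0 == x, while -1 has every bit set
# in two's complement, so x & -1 == x, hence identity = -1 for AND.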
1603
1604
1605class Invert(UnaryBitOp):
1606    nfunc_spec = ('invert', 1, 1)
1607
1608    def impl(self, x):
1609        return ~x
1610
1611    def c_code(self, node, name, inputs, outputs, sub):
1612        (x,) = inputs
1613        (z,) = outputs
1614        if node.outputs[0].type == bool:
1615            return "%(z)s = (!%(x)s);" % locals()
1616        return "%(z)s = (~%(x)s);" % locals()
1617invert = Invert()
1618
1619
1620##############
1621# Arithmetic
1622##############
1623class Maximum(BinaryScalarOp):
1624    commutative = True
1625    associative = True
1626    nfunc_spec = ('maximum', 2, 1)
1627
1628    def impl(self, *inputs):
        # The built-in max function doesn't support complex types.
1630        return np.maximum(*inputs)
1631
1632    def c_code(self, node, name, inputs, outputs, sub):
1633        (x, y) = inputs
1634        (z,) = outputs
1635        if any([i.type in complex_types for i in node.inputs]):
1636            raise NotImplementedError()
1637        # Test for both y>x and x>=y to detect NaN
1638        return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
1639                '((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
1640
1641    def L_op(self, inputs, outputs, gout):
1642        (x, y) = inputs
1643        (gz,) = gout
1644        if gz.type in complex_types:
1645            # max is currently defined for complex_types,
1646            # but the gradient for complex is not.
1647            raise NotImplementedError()
1648
1649        if outputs[0].type in discrete_types:
1650            return [x.zeros_like().astype(theano.config.floatX),
1651                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same:
        # gx will then be gz and gy will be 0.
1654        e = eq(outputs[0], x)
1655        gx = e * gz
1656        gy = (constant(1, dtype=gz.dtype) - e) * gz
1657        return (gx, gy)
1658
1659maximum = Maximum(upcast_out, name='maximum')
1660
1661
1662class Minimum(BinaryScalarOp):
1663    commutative = True
1664    associative = True
1665    nfunc_spec = ('minimum', 2, 1)
1666
1667    def impl(self, *inputs):
        # The built-in min function doesn't support complex types.
1669        return np.minimum(*inputs)
1670
1671    def c_code(self, node, name, inputs, outputs, sub):
1672        (x, y) = inputs
1673        (z,) = outputs
1674        if any([i.type in complex_types for i in node.inputs]):
1675            raise NotImplementedError()
1676        return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
1677                '((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
1678
1679    def L_op(self, inputs, outputs, gout):
1680        (x, y) = inputs
1681        (gz,) = gout
1682        if gz.type in complex_types:
1683            # min is currently defined for complex_types,
1684            # but the gradient for complex is not.
1685            raise NotImplementedError()
1686
1687        if outputs[0].type in discrete_types:
1688            return [x.zeros_like().astype(theano.config.floatX),
1689                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same:
        # gx will then be gz and gy will be 0.
1692        e = eq(outputs[0], x)
1693        gx = e * gz
1694        gy = (constant(1, dtype=gz.dtype) - e) * gz
1695        return (gx, gy)
1696minimum = Minimum(upcast_out, name='minimum')
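
# When the two inputs of Maximum/Minimum are equal, the gradients above route
# everything to the first input: e = eq(output, x) is 1, so gx = gz and gy = 0.
# A small sketch (assuming the standard `theano.tensor` API):
#
#   import theano
#   import theano.tensor as tt
#   x, y = tt.dscalars('x', 'y')
#   gx, gy = theano.grad(tt.maximum(x, y), [x, y])
#   theano.function([x, y], [gx, gy])(2., 2.)   # -> [1., 0.]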
1697
1698
1699class Add(ScalarOp):
1700    identity = 0
1701    commutative = True
1702    associative = True
1703    nfunc_spec = ('add', 2, 1)
1704
1705    def impl(self, *inputs):
1706        return sum(inputs)
1707
1708    def c_code(self, node, name, inputs, outputs, sub):
1709        (z,) = outputs
1710        op = " + "
1711        if node.outputs[0].type == bool:
1712            op = " || "
1713        if not inputs:
1714            return z + " = 0;"
1715        else:
1716            return z + " = " + op.join(inputs) + ";"
1717
1718    def L_op(self, inputs, outputs, gout):
1719        (gz,) = gout
1720        if gz.type in complex_types:
1721            raise NotImplementedError()
1722        if (outputs[0].type in discrete_types):
1723            assert gz is not None
1724            retval = []
1725            for ii, inp in enumerate(inputs):
1726                if hasattr(inp, 'zeros_like'):
1727                    retval.append(
1728                        inp.zeros_like().astype(theano.config.floatX))
1729                else:
1730                    retval.append(grad_undefined(self, ii, inp))
1731        else:
1732            retval = []
            for i in inputs:
                retval += [gz]
1735        return retval
1736
1737
1738add = Add(upcast_out, name='add')
1739
1740
1741class Mul(ScalarOp):
1742    identity = 1
1743    commutative = True
1744    associative = True
1745    nfunc_spec = ('multiply', 2, 1)
1746
1747    def impl(self, *inputs):
        return np.prod(inputs)
1749
1750    def c_code(self, node, name, inputs, outputs, sub):
1751        (z,) = outputs
1752        op = " * "
1753        if node.outputs[0].type == bool:
1754            op = " && "
1755        if not inputs:
1756            return z + " = 1;"
1757        else:
1758            return z + " = " + op.join(inputs) + ";"
1759
1760    def grad(self, inputs, gout):
1761        (gz,) = gout
1762        retval = []
1763
        # The following 3 lines verify that gz is complex when the
        # output is complex. The rest of this function makes this assumption.
1766        output_type = self.output_types([i.type for i in inputs])[0]
1767        if output_type in complex_types:
1768            if gz.type not in complex_types:
1769                raise TypeError(
1770                    'Mul with output_type ' + str(output_type) +
1771                    ' expected gz type to be complex, got gz with type ' +
1772                    str(gz.type))
1773
1774        if output_type in discrete_types:
1775            return [ipt.zeros_like().astype(theano.config.floatX)
1776                    for ipt in inputs]
1777
1778        for input in inputs:
1779            if gz.type in complex_types:
                # zr + zi*i = (xr + xi*i)(yr + yi*i)
                #           = (xr*yr - xi*yi) + (xr*yi + xi*yr)*i
1782                otherprod = mul(*(utils.difference(inputs, [input])))
1783                yr = real(otherprod)
1784                yi = imag(otherprod)
1785                if input.type in complex_types:
1786                    retval += [complex(yr * real(gz) + yi * imag(gz),
1787                                       yr * imag(gz) - yi * real(gz))]
1788                else:
1789                    retval += [yr * real(gz) + yi * imag(gz)]
1790            else:
1791                retval += [mul(*([gz] + utils.difference(inputs,
1792                                                         [input])))]
1793        return retval
1794
1795
1796mul = Mul(upcast_out, name='mul')
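
# For real-valued inputs the gradient above is just the product rule: the
# gradient w.r.t. each input is gz times the product of all other inputs.
# A small sketch (assuming the standard `theano.tensor` API):
#
#   import theano
#   import theano.tensor as tt
#   a, b, c = tt.dscalars('a', 'b', 'c')
#   g = theano.grad(a * b * c, a)               # symbolically b * c
#   theano.function([a, b, c], g)(2., 3., 4.)   # -> 12.0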
1797
1798
1799class Sub(BinaryScalarOp):
1800    nfunc_spec = ('subtract', 2, 1)
1801
1802    def impl(self, x, y):
1803        return x - y
1804
1805    def c_code(self, node, name, inputs, outputs, sub):
1806        (x, y) = inputs
1807        (z,) = outputs
1808        return "%(z)s = %(x)s - %(y)s;" % locals()
1809
1810    def L_op(self, inputs, outputs, gout):
1811        (x, y) = inputs
1812        (gz,) = gout
1813        if gz.type in complex_types:
1814            raise NotImplementedError()
1815        if outputs[0].type in discrete_types:
1816            return [x.zeros_like().astype(theano.config.floatX),
1817                    y.zeros_like().astype(theano.config.floatX)]
1818
1819        first_part = gz
1820        second_part = -gz
1821
1822        return first_part, second_part
1823sub = Sub(upcast_out_nobool, name='sub')
1824
1825
1826def int_or_true_div(x_discrete, y_discrete):
1827    """
1828    Return 'int' or 'true' depending on the type of division used for x / y.
1829
1830    Parameters
1831    ----------
1832    x_discrete : bool
1833        True if `x` is discrete ([unsigned] integer).
1834    y_discrete : bool
1835        True if `y` is discrete ([unsigned] integer).
1836
1837    Returns
1838    -------
1839    str
        'int' if `x / y` should be an integer division, or 'true' if it
        should be a true division.
1842
1843    Raises
1844    ------
1845    IntegerDivisionError
1846        If both `x_discrete` and `y_discrete` are True and `config.int_division`
1847        is set to 'raise'.
1848
1849    Notes
1850    -----
1851    This function is used by both scalar/basic.py and tensor/basic.py.
1852
1853    """
1854    if (x_discrete and y_discrete):
1855        if config.int_division == 'raise':
1856            raise IntegerDivisionError(
1857                "With `config.int_division` set to 'raise', dividing two "
1858                "integer types with '/' is forbidden to avoid confusion "
1859                "between integer and floating point divisions. Please "
1860                "use // for integer division, or if you want a float result "
1861                "either cast one of the arguments to a float or directly call "
1862                "`x.__truediv__(y)`.")
1863        elif config.int_division == 'int':
1864            warnings.warn(
1865                "Division of two integer types with x / y is deprecated, "
1866                "please use x // y for an integer division.",
1867                DeprecationWarning,
1868                stacklevel=4)
1869            return int_div
1870        elif config.int_division == 'floatX':
1871            return true_div
1872        else:
1873            raise NotImplementedError(config.int_division)
1874    else:
1875        return true_div
1876
1877
1878def div_proxy(x, y):
1879    """
1880    Proxy for either true_div or int_div, depending on types of x, y.
1881
1882    """
1883    f = int_or_true_div(as_scalar(x).type in discrete_types,
1884                        as_scalar(y).type in discrete_types)
1885    return f(x, y)
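
# A minimal sketch of how the dispatch above plays out at the tensor level
# (assuming the standard `theano.tensor` API; the exact outcome of `x / y` on
# two integer arguments depends on `config.int_division`):
#
#   import theano.tensor as tt
#   x, y = tt.iscalars('x', 'y')
#   q = x / y     # goes through div_proxy: int_div, true_div, or an error
#   r = x // y    # always int_div (floor division)
#   s = x / 2.0   # mixed int/float: always true_div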
1886
1887
1888class TrueDiv(BinaryScalarOp):
1889    nfunc_spec = ('true_divide', 2, 1)
1890
1891    def output_types(self, types):
1892        if all(t in discrete_types for t in types):
1893            return [get_scalar_type(config.floatX)]
1894        else:
1895            return super(TrueDiv, self).output_types(types)
1896
1897    def impl(self, x, y):
1898        x = np.asarray(x)
1899        y = np.asarray(y)
1900        if all(a.dtype in discrete_types for a in (x, y)):
1901            return np.sctypeDict[config.floatX](float(x) / y)
1902        else:
1903            return x / y
1904
1905    def c_code(self, node, name, inputs, outputs, sub):
        # We only generate C code when neither or both inputs are complex.
1907        (x, y) = inputs
1908        (z,) = outputs
1909        if sum([node.inputs[0].type in complex_types,
1910                node.inputs[1].type in complex_types]) == 1:
1911            raise NotImplementedError('type not supported', type)
1912        if (node.inputs[0].type in discrete_types and
1913                node.inputs[1].type in discrete_types):
1914            return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
1915        return "%(z)s = %(x)s / %(y)s;" % locals()
1916
1917    def grad(self, inputs, gout):
1918
1919        (x, y) = inputs
1920        (gz,) = gout
1921        if x.type in complex_types:
1922            raise NotImplementedError()
1923
1924        # If the output of this op is discrete, then it
        # is locally flat everywhere, so the gradient
1926        # through it is 0.
1927        # This is different from it not being connected
1928        # to the output; x/y is still a function of x
1929        # and y; it's just a step function.
1930        if all(a.dtype in discrete_types for a in (x, y)):
1931            return [x.zeros_like(), y.zeros_like()]
1932
1933        first_part = gz / y
1934
1935        if y.type in complex_types:
1936            raise NotImplementedError()
1937
1938        second_part = -(gz * x) / (y * y)
1939
1940        return first_part, second_part
1941
1942true_div = TrueDiv(upcast_out, name='true_div')
1943
1944
1945class IntDiv(BinaryScalarOp):
1946    nfunc_spec = ('floor_divide', 2, 1)
1947    complex_error = ComplexError(
1948        "Theano does not support integer division (//) on "
1949        "complex numbers, since numpy deprecated it.")
1950
1951    def impl(self, x, y):
1952        return x // y
1953
1954    def c_support_code(self):
        # We use a macro because Python treats % as a special
        # string-formatting character, and the output of c_code may be run
        # through another level of string formatting.
1958        return "#define THEANO_MACRO_MOD(x,y) (x % y)"
1959
1960    def c_code(self, node, name, inputs, outputs, sub):
1961        (x, y) = inputs
1962        (z,) = outputs
1963        fail = sub['fail']
1964
1965        t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
1966        if t in imap(str, discrete_types):
1967            x_div_y_pp = '(%(x)s / %(y)s)' % locals()
1968            x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
1969            x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
1970            x_div_y_pm = '(%(x)s / (-%(y)s))' % locals()
1971            x_mod_y_pm = 'THEANO_MACRO_MOD(%(x)s, (-%(y)s))' % locals()
1972            x_div_y_mm = '((-%(x)s) / (-%(y)s))' % locals()
            # If we are in a gpuarray kernel, %(fail)s exits the kernel,
            # we do not have any error report, and we cannot set
            # Python error messages either, so for now we just perform the
            # division anyway, which returns a binary pattern of all 1s.
1977            div_zero = dedent('''
1978                #ifdef KERNEL
1979                    %(z)s = %(x_div_y_pp)s;
1980                #else
1981                    PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero");
1982                    %(fail)s
1983                #endif
1984                ''') % locals()
1985        elif t in imap(str, float_types):
1986            # We need to call different functions of math.h
1987            # depending on the type
1988            if t == 'float32':
1989                floor = 'floorf'
1990                fmod = 'fmodf'
1991            elif t == 'float64':
1992                floor = 'floor'
1993                fmod = 'fmod'
1994            else:
1995                raise NotImplementedError('type not supported', t)
1996
1997            x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals()
1998            x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals()
1999            x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals()
2000            x_div_y_pm = '%(floor)s(%(x)s / (-%(y)s))' % locals()
2001            x_mod_y_pm = '%(fmod)s(%(x)s, (-%(y)s))' % locals()
2002            x_div_y_mm = '%(floor)s((-%(x)s) / (-%(y)s))' % locals()
2003            div_zero = '%(z)s = %(x_div_y_pp)s;' % locals()
2004        elif t in complex_types:
2005            raise self.complex_error
2006        else:
2007            raise NotImplementedError('type not supported', t)
2008
2009        return dedent("""
2010            if (%(y)s == 0) {
2011                %(div_zero)s;
2012            } else if (%(y)s < 0) {
2013                if (%(x)s < 0) {
2014                    %(z)s = %(x_div_y_mm)s;
2015                } else {
2016                    %(z)s = - %(x_div_y_pm)s - ((%(x_mod_y_pm)s == 0) ? 0 : 1);
2017                }
2018            } else {
2019                if (%(x)s < 0) {
2020                    %(z)s = - %(x_div_y_mp)s - ((%(x_mod_y_mp)s == 0) ? 0 : 1);
2021                } else {
2022                    %(z)s = %(x_div_y_pp)s;
2023                }
2024            }
2025            """) % locals()
2026
2027    def c_code_cache_version(self):
2028        return (6,)
2029
2030    def grad(self, inputs, g_output):
2031        return [inp.zeros_like(dtype=theano.config.floatX)
2032                for inp in inputs]
2033int_div = IntDiv(upcast_out, name='int_div')
2034
2035
2036floor_div = int_div
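
# The C code in IntDiv reproduces Python's floor-division semantics, which
# differ from C's truncating integer division when the operands' signs differ:
#
#    7 // 2  ->  3        -7 // 2  -> -4   (C truncation would give -3)
#    7 // -2 -> -4        -7 // -2 ->  3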
2037
2038
2039def mod_check(x, y):
2040    if (as_scalar(x).type in complex_types or
2041            as_scalar(y).type in complex_types):
2042        # Currently forbidden.
2043        raise Mod.complex_error
2044    else:
2045        return mod(x, y)
2046
2047
2048class Mod(BinaryScalarOp):
2049    nfunc_spec = ('mod', 2, 1)
2050    complex_error = ComplexError(
2051        "Theano does not support the mod operator (%) on "
2052        "complex numbers, since numpy deprecated it.")
2053
2054    def impl(self, x, y):
2055        if isinstance(x, np.complex) or isinstance(y, np.complex):
2056            raise self.complex_error
2057        return x % y
2058
2059    def c_code_cache_version(self):
2060        return (9,)
2061
2062    def c_support_code(self):
        # We use a macro because Python treats % as a special
        # string-formatting character, and the output of c_code may be run
        # through another level of string formatting.
2066        return "#define THEANO_MACRO_MOD(x, y) (x % y)"
2067
2068    def c_code(self, node, name, inputs, outputs, sub):
2069        """
        We want the result to have the same sign as Python's %, not the sign
        given by other implementations of mod (e.g. C's fmod follows the
        dividend's sign).
2072
2073        """
2074        (x, y) = inputs
2075        (z,) = outputs
2076        fail = sub['fail']
2077        t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
2078        if (str(t) in imap(str, discrete_types) or
2079                t in ['uint8', 'int8', 'uint16', 'int16'] or
2080                t in ['uint32', 'int32', 'uint64', 'int64'] or
2081                t in discrete_types):
            # The above or's should not be needed anymore. However, for now we
            # keep them for safety, and verify they are useless with an
            # assert.
2085            assert str(t) in imap(str, discrete_types)
2086            x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
2087            x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
2088            x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
2089            x_mod_ymp = "THEANO_MACRO_MOD(-%(x)s, %(y)s)" % locals()
            # If we are in a gpuarray kernel, %(fail)s exits the kernel,
            # we do not have any error report, and we cannot set
            # Python error messages either, so for now we just compute the
            # modulo anyway, which returns a binary pattern depending on the dtype.
2094            mod_zero = dedent('''
2095                #ifdef KERNEL
2096                    %(z)s = %(x_mod_y)s;
2097                #else
2098                    PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero");
2099                    %(fail)s
2100                #endif
2101                ''') % locals()
2102        elif (str(t) in imap(str, float_types) or
2103              t in ['float32', 'float64'] or
2104              t in float_types):
            # The above or's should not be needed anymore. However, for now we
            # keep them for safety, and verify they are useless with an
            # assert.
2108            assert str(t) in imap(str, float_types)
2109            x_mod_y = "fmod(%(x)s, %(y)s)" % locals()
2110            x_mod_ymm = "fmod(-%(x)s, -%(y)s)" % locals()
2111            x_mod_ypm = "fmod(%(x)s, -%(y)s)" % locals()
2112            x_mod_ymp = "fmod(-%(x)s, %(y)s)" % locals()
2113            mod_zero = "%(z)s = %(x_mod_y)s;" % locals()
2114        elif str(t) in imap(str, complex_types):
2115            raise self.complex_error
2116        else:
2117            raise NotImplementedError('type not supported', t)
2118
2119        return dedent("""
2120            if (%(y)s == 0) {
2121                %(mod_zero)s;
2122            } else if (%(y)s < 0){
2123                if (%(x)s < 0){
2124                    %(z)s = -(%(x_mod_ymm)s);
2125                } else {
2126                    %(z)s = (%(x_mod_ypm)s) + (%(x_mod_ypm)s != 0 ? %(y)s : 0);
2127                }
2128            } else {
2129                if (%(x)s < 0){
2130                    %(z)s = - %(x_mod_ymp)s + (%(x_mod_ymp)s != 0 ? %(y)s : 0);
2131                } else {
2132                    %(z)s = %(x_mod_y)s;
2133                }
2134            }
2135            """) % locals()
2136
2137    def L_op(self, inputs, outputs, gout):
2138        (x, y) = inputs
2139        (gz,) = gout
2140        if outputs[0].type.dtype in discrete_types:
2141            # The gradient does not flow in if the output is discrete
2142            return [x.zeros_like(dtype=theano.config.floatX),
2143                    y.zeros_like(dtype=theano.config.floatX)]
2144        return [gz,
2145                -(x // y) * gz]
2146
2147mod = Mod(upcast_out, name='mod')
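
# The C code in Mod follows Python's convention: the result of x % y takes
# the sign of the divisor y, unlike C's fmod, which takes the sign of x:
#
#    7 % 3  ->  1        -7 % 3  ->  2
#    7 % -3 -> -2        -7 % -3 -> -1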
2148
2149
2150class Pow(BinaryScalarOp):
2151    nfunc_spec = ('power', 2, 1)
2152
2153    def impl(self, x, y):
2154        return x ** y
2155
2156    def c_code(self, node, name, inputs, outputs, sub):
2157        (x, y) = inputs
2158        (z,) = outputs
2159        if (node.inputs[0].type in complex_types or
2160                node.inputs[1].type in complex_types):
2161            raise NotImplementedError('type not supported', type)
2162        return "%(z)s = pow(%(x)s, %(y)s);" % locals()
2163
2164    def L_op(self, inputs, outputs, gout):
2165        (x, y) = inputs
2166        (gz,) = gout
2167        if gz.type in complex_types:
2168            raise NotImplementedError()
2169
2170        if outputs[0].type in discrete_types:
2171            return [x.zeros_like().astype(theano.config.floatX),
2172                    y.zeros_like().astype(theano.config.floatX)]
2173
2174        first_part = gz * y * x ** (y - 1)
2175
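        # When x == 0, log(x) is -inf and the naive expression for the
        # gradient w.r.t. the exponent (gz * log(x) * x ** y) would give NaN;
        # the switch below forces that gradient to 0 in that case.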
2176        second_part = gz * log(x) * x ** y
2177        second_part = switch(eq(x, 0), 0, second_part)
2178
2179        return (first_part, second_part)
2180
2181    def c_code_contiguous(self, node, name, inputs, outputs, sub):
2182        (x, y) = inputs
2183        (z,) = outputs
2184        if not theano.config.lib.amdlibm:
2185            raise theano.gof.utils.MethodNotDefined()
2186
2187        # We compare the dtype AND the broadcast flag
        # as this function does not broadcast
2189        if (node.inputs[0].type == node.outputs[0].type and
2190                node.inputs[1].type == node.outputs[0].type and
                # amdlibm 3.0 does not have a float64 version of this SIMD function
2192                node.inputs[0].dtype == 'float32' and
2193                node.inputs[1].dtype == 'float32'):
2194            dtype = 'float'
2195            fct = "amd_vrsa_powf"
2196            return """
2197        npy_intp n = PyArray_SIZE(%(z)s);
2198        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
2199        %(dtype)s * y = (%(dtype)s*) PyArray_DATA(%(y)s);
2200        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
2201        %(fct)s(n, x, y, z);
2202        """ % locals()
2203        # We compare the dtype and check we broadcast a scalar
2204        elif (node.inputs[0].type == node.outputs[0].type and
2205              node.inputs[1].dtype == node.outputs[0].dtype and
2206              all(node.inputs[1].broadcastable) and
              # amdlibm 3.0 does not have a float64 version of this SIMD function
2208              node.inputs[0].dtype == 'float32' and
2209              node.inputs[1].dtype == 'float32'):
2210            dtype = 'float'
2211            fct = "amd_vrsa_powxf"
2212            return """
2213        npy_intp n = PyArray_SIZE(%(z)s);
2214        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
2215        %(dtype)s * y = (%(dtype)s*) PyArray_DATA(%(y)s);
2216        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
2217        %(fct)s(n, x, *y, z);
2218        """ % locals()
2219
2220        raise theano.gof.utils.MethodNotDefined()
2221
2222
2223pow = Pow(upcast_out_min8, name='pow')
2224
2225
2226class Clip(ScalarOp):
2227    nin = 3
    # numpy.clip does not work correctly when min is bigger than max,
    # so we do not use nfunc_spec = ('clip', 3, 1).
2230
2231    def impl(self, x, min, max):
2232        if x < min:
2233            return min
2234        elif x > max:
2235            return max
2236        else:
2237            return x
2238
2239    def c_code(self, node, name, inputs, outputs, sub):
2240        (x, min, max) = inputs
2241        (z,) = outputs
2242        return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
2243
2244    def L_op(self, inputs, outputs, gout):
2245        (x, mn, mx) = inputs
2246        (gz,) = gout
2247        assert gz.type not in complex_types
2248        gx = ((x >= mn) & (x <= mx)) * gz
2249        gmn = (x < mn) * gz
2250        gmx = (x > mx) * gz
2251
2252        def handle_int(v):
2253            if outputs[0].type in int_types:
2254                return v.zeros_like().astype(config.floatX)
2255            return v
2256
2257        return list(map(handle_int, [gx, gmn, gmx]))
2258
# Don't allow complex even if numpy does,
# as this function has no mathematical meaning on complex numbers.
2261clip = Clip(upcast_out_no_complex, name='clip')
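
# The gradient above sends gz to x when mn <= x <= mx, to mn when x falls
# below the interval, and to mx when x exceeds it.  A small sketch (assuming
# the standard `theano.tensor` API):
#
#   import theano
#   import theano.tensor as tt
#   x = tt.dscalar('x')
#   g = theano.grad(tt.clip(x, 0., 1.), x)
#   f = theano.function([x], g)
#   f(0.5)   # -> 1.0  (inside the interval)
#   f(2.0)   # -> 0.0  (clipped; the gradient goes to the upper bound instead)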
2262
2263
2264class Second(BinaryScalarOp):
2265    def impl(self, x, y):
2266        return y
2267
2268    def c_code(self, node, name, inputs, outputs, sub):
2269        (x, y) = inputs
2270        (z,) = outputs
2271        return "%(z)s = %(y)s;" % locals()
2272
2273    def connection_pattern(self, node):
2274
2275        # x is never connected because its elements are never used
2276        # y is connected because its elements are copied over
2277
2278        return [[False], [True]]
2279
2280    def grad(self, inputs, gout):
2281
2282        (x, y) = inputs
2283        (gz,) = gout
2284        if y.type in continuous_types:
2285            # x is disconnected because the elements of x are not used
2286            return DisconnectedType()(), gz
2287        else:
2288            # when y is discrete, we assume the function can be extended
2289            # to deal with real-valued inputs by rounding them to the
2290            # nearest integer. f(x+eps) thus equals f(x) so the gradient
2291            # is zero, not disconnected or undefined
2292            return DisconnectedType()(), y.zeros_like()
2293
2294second = Second(transfer_type(1), name='second')
2295
2296
2297class Identity(UnaryScalarOp):
2298    def impl(self, input):
2299        return input
2300
2301    def c_code(self, node, name, inputs, outputs, sub):
2302        (x,) = inputs
2303        (z,) = outputs
2304        return "%(z)s = %(x)s;" % locals()
2305
2306    def grad(self, inputs, gout):
2307        (x,) = inputs
2308        (gz,) = gout
2309        if x.type in continuous_types:
2310            return gz,
2311        else:
2312            return x.zeros_like(dtype=theano.config.floatX),
2313identity = Identity(same_out, name='identity')
2314
2315
2316# CASTING OPERATIONS
2317class Cast(UnaryScalarOp):
2318    def __init__(self, o_type, name=None):
2319        if not isinstance(o_type, Scalar):
2320            raise TypeError(o_type)
2321        super(Cast, self).__init__(specific_out(o_type), name=name)
2322        self.o_type = o_type
2323        self.ctor = getattr(np, o_type.dtype)
2324
2325    def __str__(self):
2326        return '%s{%s}' % (self.__class__.__name__, self.o_type.dtype)
2327
2328    def clone_float32(self):
2329        if self.o_type == float16:
2330            return convert_to_float32
2331        return self
2332
2333    def make_new_inplace(self, output_types_preference=None, name=None):
2334        """
        This op's __init__ does not take the same parameters as other scalar
        ops, which breaks the insert_inplace_optimizer optimization.
        This method works around that by ignoring the
        output_types_preference passed by the optimization and replacing it
        with the current output type. This should only be triggered when
        both input and output have the same dtype anyway.
2341
2342        """
2343        return self.__class__(self.o_type, name)
2344
2345    def impl(self, input):
2346        return self.ctor(input)
2347
2348    def c_code(self, node, name, inputs, outputs, sub):
2349        (x,) = inputs
2350        (z,) = outputs
2351        if node.outputs[0].type == bool:
2352            return "%s = (%s) ? 1 : 0;" % (z, x)
2353        return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
2354
2355    def grad(self, inputs, gout):
2356        (x,) = inputs
2357        (gz,) = gout
2358        if self.o_type in continuous_types:
2359            return [gz]
2360        else:
2361            return [x.zeros_like().astype(theano.config.floatX)]
2362
2363    def c_code_cache_version(self):
2364        s = super(Cast, self).c_code_cache_version()
2365        if s:
2366            return (4,) + s
2367        else:
2368            return s
2369
2370convert_to_bool = Cast(bool, name='convert_to_bool')
2371convert_to_int8 = Cast(int8, name='convert_to_int8')
2372convert_to_int16 = Cast(int16, name='convert_to_int16')
2373convert_to_int32 = Cast(int32, name='convert_to_int32')
2374convert_to_int64 = Cast(int64, name='convert_to_int64')
2375convert_to_uint8 = Cast(uint8, name='convert_to_uint8')
2376convert_to_uint16 = Cast(uint16, name='convert_to_uint16')
2377convert_to_uint32 = Cast(uint32, name='convert_to_uint32')
2378convert_to_uint64 = Cast(uint64, name='convert_to_uint64')
2379convert_to_float16 = Cast(float16, name='convert_to_float16')
2380convert_to_float32 = Cast(float32, name='convert_to_float32')
2381convert_to_float64 = Cast(float64, name='convert_to_float64')
2382convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
2383convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
2384
2385_cast_mapping = {
2386    'bool': convert_to_bool,
2387    'int8': convert_to_int8,
2388    'int16': convert_to_int16,
2389    'int32': convert_to_int32,
2390    'int64': convert_to_int64,
2391    'uint8': convert_to_uint8,
2392    'uint16': convert_to_uint16,
2393    'uint32': convert_to_uint32,
2394    'uint64': convert_to_uint64,
2395    'float16': convert_to_float16,
2396    'float32': convert_to_float32,
2397    'float64': convert_to_float64,
2398    'complex64': convert_to_complex64,
2399    'complex128': convert_to_complex128}
2400
2401
2402def cast(x, dtype):
2403    """
2404    Symbolically cast `x` to a Scalar of given `dtype`.
2405
2406    """
2407    if dtype == 'floatX':
2408        dtype = config.floatX
2409
2410    _x = as_scalar(x)
2411    if _x.type.dtype == dtype:
2412        return _x
2413    if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
2414        raise TypeError('Casting from complex to real is ambiguous: consider'
2415                        ' real(), imag(), angle() or abs()')
2416    return _cast_mapping[dtype](_x)
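
# A minimal usage sketch of `cast` (illustrative only; `x` below is a
# hypothetical float64 scalar Variable):
#
#   x = float64()
#   y = cast(x, 'int32')     # applies Cast{int32}
#   cast(x, 'float64') is x  # True: same dtype, returned unchanged
#   cast(x, 'floatX')        # 'floatX' resolves to config.floatX
#   cast(complex64(), 'float32')  # raises TypeError (ambiguous)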
2417
2418
2419class Abs(UnaryScalarOp):
2420    nfunc_spec = ('abs', 1, 1)
2421
2422    def make_node(self, x):
2423        inputs = [as_scalar(input) for input in [x]]
2424        if inputs[0].type == complex64:
2425            outputs = [float32()]
2426        elif inputs[0].type == complex128:
2427            outputs = [float64()]
2428        else:
2429            outputs = [t() for t in self.output_types(
2430                [input.type for input in inputs])]
2431        return Apply(self, inputs, outputs)
2432
2433    def impl(self, x):
2434        return np.abs(x)
2435
2436    def L_op(self, inputs, outputs, gout):
2437        (x,) = inputs
2438        (gz,) = gout
2439        if (outputs[0].type in discrete_types):
2440            if x.type in discrete_types:
2441                return [x.zeros_like(dtype=theano.config.floatX)]
2442            else:
2443                return [x.zeros_like()]
2444
2445        if x.type in float_types:
2446            return gz * sgn(x),
2447        return gz * x / abs(x),  # formula works for complex and real
2448
2449    def c_code(self, node, name, inputs, outputs, sub):
2450        (x,) = inputs
2451        (z,) = outputs
2452        type = node.inputs[0].type
2453        if type in int_types:
2454            return "%(z)s = abs(%(x)s);" % locals()
2455        if type in float_types:
2456            return "%(z)s = fabs(%(x)s);" % locals()
2457        if type in complex_types:
2458            return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
2459        if node.outputs[0].type == bool:
2460            return "%(z)s = (%(x)s) ? 1 : 0;" % locals()
2461        if type in uint_types:
            # Unsigned ints are already their own absolute value.
2463            return "%(z)s = %(x)s;" % locals()
2464        raise NotImplementedError('type not supported', type)
2465abs_ = Abs(same_out)
2466
2467
2468class Sgn(UnaryScalarOp):
2469    nfunc_spec = ('sign', 1, 1)
2470
2471    @staticmethod
2472    def output_types_preference(x):
2473        if x == bool:
2474            raise TypeError(x)
2475        return same_out_nocomplex(x)
2476
2477    def impl(self, x):
2478        # casting to output type is handled by filter
2479        return np.sign(x)
2480
2481    def grad(self, inputs, gout):
2482        (x,) = inputs
2483        (gz,) = gout
2484        rval = x.zeros_like()
2485
2486        if rval.type.dtype in discrete_types:
2487            rval = rval.astype(theano.config.floatX)
2488
2489        return [rval]
2490
2491    def c_code(self, node, name, inputs, outputs, sub):
2492        # casting is done by compiler
2493        # TODO: use copysign
2494        (x,) = inputs
2495        (z,) = outputs
2496        type = node.inputs[0].type
2497        if type in float_types:
2498            return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
2499        if type in int_types:
2500            return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
2501        raise ComplexError('complex has no sgn')
2502
2503    def c_code_cache_version(self):
2504        s = super(Sgn, self).c_code_cache_version()
2505        if s:
2506            return (4,) + s
2507        else:  # if parent is unversioned, we are too
2508            return s
2509sgn = Sgn(name='sgn')
2510
2511
2512class Ceil(UnaryScalarOp):
2513    nfunc_spec = ('ceil', 1, 1)
2514
2515    def impl(self, x):
2516        return np.ceil(x)
2517
2518    def grad(self, inputs, gout):
2519        (x,) = inputs
2520        (gz,) = gout
2521        rval = x.zeros_like()
2522
2523        if rval.type.dtype in discrete_types:
2524            rval = rval.astype(theano.config.floatX)
2525
2526        return [rval]
2527
2528    def c_code(self, node, name, inputs, outputs, sub):
2529        (x,) = inputs
2530        (z,) = outputs
2531        cast = node.outputs[0].type.dtype_specs()[1]
2532        return "%(z)s = ceil((%(cast)s)%(x)s);" % locals()
2533ceil = Ceil(upgrade_to_float_no_complex, name='ceil')
2534
2535
2536class Floor(UnaryScalarOp):
2537    nfunc_spec = ('floor', 1, 1)
2538
2539    def impl(self, x):
2540        return np.floor(x)
2541
2542    def grad(self, inputs, gout):
2543        (x,) = inputs
2544        (gz,) = gout
2545        rval = x.zeros_like()
2546
2547        if rval.type.dtype in discrete_types:
2548            rval = rval.astype(theano.config.floatX)
2549
2550        return [rval]
2551
2552    def c_code(self, node, name, inputs, outputs, sub):
2553        (x,) = inputs
2554        (z,) = outputs
2555        cast = node.outputs[0].type.dtype_specs()[1]
2556        return "%(z)s = floor((%(cast)s)%(x)s);" % locals()
2557floor = Floor(upgrade_to_float_no_complex, name='floor')
2558
2559
2560class Trunc(UnaryScalarOp):
2561    nfunc_spec = ('trunc', 1, 1)
2562
2563    def impl(self, x):
2564        return np.trunc(x)
2565
2566    def grad(self, inputs, gout):
2567        (x,) = inputs
2568        (gz,) = gout
2569        return [x.zeros_like().astype(theano.config.floatX)]
2570
2571    def c_code(self, node, name, inputs, outputs, sub):
2572        (x,) = inputs
2573        (z,) = outputs
2574        return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals()
2575trunc = Trunc(upgrade_to_float_no_complex, name='trunc')
2576
2577
2578class RoundHalfToEven(UnaryScalarOp):
2579    """
    This function implements the same rounding as numpy: round half to even.

    The C/C++ round function is DIFFERENT!
2583    See http://en.wikipedia.org/wiki/Rounding for more details.
2584
2585    """
2586    nfunc_spec = ('around', 1, 1)
2587
2588    def impl(self, x):
2589        return np.round(x)
2590
2591    def grad(self, inputs, gout):
2592        (x,) = inputs
2593        (gz,) = gout
2594        rval = x.zeros_like()
2595
2596        if rval.type.dtype in discrete_types:
2597            rval = rval.astype(theano.config.floatX)
2598
2599        return [rval]
2600
2601    def c_code_cache_version(self):
2602        return (1,)
2603
2604    def c_code(self, node, name, inputs, outputs, sub):
2605        (x,) = inputs
2606        (z,) = outputs
2607        typ = node.outputs[0].type.dtype
2608        if typ not in ['float32', 'float64']:
2609            raise NotImplementedError("The output should be float32 or float64")
2610        if typ == 'float32':
2611            ctype = 'float'
2612            floor_function = 'floorf'
2613        else:
2614            ctype = 'double'
2615            floor_function = 'floor'
2616        return """
2617        /* Code inspired from NumPy npy_rint implementation. */
2618        {
2619            %(ctype)s y, r;
2620            y = %(floor_function)s(%(x)s);
2621            r = %(x)s - y;
2622            if(r > 0.5) {
2623                y += 1;
2624            } else if(r == 0.5) {
2625                r = y - 2.0*%(floor_function)s(0.5*y);
2626                /*
2627                If y is even, then r == 0
2628                If y is odd,  then r == 1
2629                So we can just add r to y, so that
                y will be incremented only if it is odd.
2631                */
2632                y += (int)r;
2633            }
2634            %(z)s = y;
2635        }
2636        """ % locals()
2637round_half_to_even = RoundHalfToEven(same_out_float_only)
2638
2639
2640def round_half_away_from_zero_(a):
2641    if a > 0:
2642        return np.floor(a + 0.5)
2643    else:
2644        return np.ceil(a - 0.5)
2645
2646round_half_away_from_zero_vec64 = np.vectorize(
2647    round_half_away_from_zero_,
2648    doc='round_half_away_from_zero_vec64')
2649round_half_away_from_zero_vec32 = np.vectorize(
2650    round_half_away_from_zero_,
2651    doc='round_half_away_from_zero_vec32',
2652    otypes=['float32'])
2653
2654
2655def round_half_away_from_zero_vec(a):
2656    if getattr(a, 'dtype', None) == np.float32:
2657        return round_half_away_from_zero_vec32(a)
2658    return round_half_away_from_zero_vec64(a)
2659
2660
2661class RoundHalfAwayFromZero(UnaryScalarOp):
2662    """
    Implement the same rounding algorithm as the C round() function.

    numpy.round is DIFFERENT!
2666    See http://en.wikipedia.org/wiki/Rounding for more details.
2667
2668    """
2669    def impl(self, x):
2670        return round_half_away_from_zero_vec(x)
2671
2672    def grad(self, inputs, gout):
2673        (x,) = inputs
2674        (gz,) = gout
2675        rval = x.zeros_like()
2676
2677        if rval.type.dtype in discrete_types:
2678            rval = rval.astype(theano.config.floatX)
2679
2680        return [rval]
2681
2682    def c_code(self, node, name, inputs, outputs, sub):
2683        (x,) = inputs
2684        (z,) = outputs
2685        if node.outputs[0].type.dtype in ['float32', 'float64']:
2686            return "%(z)s = round(%(x)s);" % locals()
2687        else:
2688            raise NotImplementedError("The output should be float32 or float64")
2689round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
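
# The two rounding ops differ only on exact halves; a few illustrative values:
#
#   round_half_to_even:        2.5 -> 2.0,  3.5 -> 4.0,  -2.5 -> -2.0
#   round_half_away_from_zero: 2.5 -> 3.0,  3.5 -> 4.0,  -2.5 -> -3.0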
2690
2691
2692class Neg(UnaryScalarOp):
    # We can use numpy.negative here, because even if it gives unexpected
    # results on Boolean arrays, it will only be passed other dtypes, as
    # Theano does not have a Boolean type for tensors.
2696    nfunc_spec = ('negative', 1, 1)
2697
2698    def impl(self, x):
2699        return -x
2700
2701    def L_op(self, inputs, outputs, gout):
2702        (x,) = inputs
2703        (gz,) = gout
2704        if outputs[0].type in discrete_types:
2705            if x.type in discrete_types:
2706                return [x.zeros_like(dtype=theano.config.floatX)]
2707            else:
2708                return [x.zeros_like()]
2709
2710        return -gz,
2711
2712    def c_code(self, node, name, inputs, outputs, sub):
2713        (x,) = inputs
2714        (z,) = outputs
2715        return "%(z)s = -%(x)s;" % locals()
2716neg = Neg(same_out_nobool, name='neg')
2717
2718pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
2719pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
2720pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
2721pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
2722pprint.assign(true_div, printing.OperatorPrinter('/', -1, 'left'))
2723pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left'))
2724pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
2725pprint.assign(mod, printing.OperatorPrinter('%', -1, 'left'))
2726
2727
2728class Inv(UnaryScalarOp):
2729    """
2730    Multiplicative inverse. Also called reciprocal.
2731
2732    """
2733    def impl(self, x):
2734        return np.float32(1.0) / x
2735
2736    def L_op(self, inputs, outputs, gout):
2737        (x,) = inputs
2738        (gz,) = gout
2739        if x.type in complex_types:
2740            raise NotImplementedError()
2741        if outputs[0].type in discrete_types:
2742            if x.type in discrete_types:
2743                return [x.zeros_like(dtype=theano.config.floatX)]
2744            else:
2745                return [x.zeros_like()]
2746
2747        return -gz / (x * x),
2748
2749    def c_code(self, node, name, inputs, outputs, sub):
2750        (x,) = inputs
2751        (z,) = outputs
2752        if node.inputs[0].type in complex_types:
2753            raise NotImplementedError()
2754        return "%(z)s = 1.0 / %(x)s;" % locals()
2755inv = Inv(upgrade_to_float, name='inv')
2756
2757
2758class Log(UnaryScalarOp):
2759    """
2760    log base e.
2761
2762    """
2763    nfunc_spec = ('log', 1, 1)
2764    amd_float32 = "amd_vrsa_logf"
2765    amd_float64 = "amd_vrda_log"
2766
2767    def impl(self, x):
2768        # If x is an int8 or uint8, numpy.log will compute the result in
2769        # half-precision (float16), where we want float32.
2770        x_dtype = str(getattr(x, 'dtype', ''))
2771        if x_dtype in ('int8', 'uint8'):
2772            return np.log(x, sig='f')
2773        return np.log(x)
2774
2775    def L_op(self, inputs, outputs, gout):
2776        (x,) = inputs
2777        (gz,) = gout
2778        if x.type in complex_types:
2779            raise NotImplementedError()
2780        if outputs[0].type in discrete_types:
2781            if x.type in discrete_types:
2782                return [x.zeros_like(dtype=theano.config.floatX)]
2783            else:
2784                return [x.zeros_like()]
2785
2786        return gz / x,
2787
2788    def c_code(self, node, name, inputs, outputs, sub):
2789        # todo: the version using log2 seems to be very slightly faster
2790        # on some machines for some reason, check if it's worth switching
2791        # return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
2792        (x,) = inputs
2793        (z,) = outputs
2794        if node.inputs[0].type in complex_types:
2795            raise NotImplementedError('type not supported', type)
2796        cast = node.outputs[0].type.dtype_specs()[1]
2797        return "%(z)s = log((%(cast)s)%(x)s);" % locals()
2798log = Log(upgrade_to_float, name='log')
2799
2800
2801class Log2(UnaryScalarOp):
2802    """
2803    log base 2.
2804
2805    """
2806    nfunc_spec = ('log2', 1, 1)
2807    amd_float32 = "amd_vrsa_log2f"
2808    amd_float64 = "amd_vrda_log2"
2809
2810    def impl(self, x):
2811        # If x is an int8 or uint8, numpy.log2 will compute the result in
2812        # half-precision (float16), where we want float32.
2813        x_dtype = str(getattr(x, 'dtype', ''))
2814        if x_dtype in ('int8', 'uint8'):
2815            return np.log2(x, sig='f')
2816        return np.log2(x)
2817
2818    def L_op(self, inputs, outputs, gout):
2819        (x,) = inputs
2820        (gz,) = gout
2821        if x.type in complex_types:
2822            raise NotImplementedError()
2823        if outputs[0].type in discrete_types:
2824            if x.type in discrete_types:
2825                return [x.zeros_like(dtype=theano.config.floatX)]
2826            else:
2827                return [x.zeros_like()]
2828
2829        return gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),
2830
2831    def c_code(self, node, name, inputs, outputs, sub):
2832        (x,) = inputs
2833        (z,) = outputs
2834        if node.inputs[0].type in complex_types:
2835            raise NotImplementedError('type not supported', type)
2836        cast = node.outputs[0].type.dtype_specs()[1]
2837        return "%(z)s = log2((%(cast)s)%(x)s);" % locals()
2838log2 = Log2(upgrade_to_float, name='log2')
2839
2840
2841class Log10(UnaryScalarOp):
2842    """
2843    log base 10.
2844
2845    """
2846    nfunc_spec = ('log10', 1, 1)
2847    amd_float32 = "amd_vrsa_log10f"
2848    amd_float64 = "amd_vrda_log10"
2849
2850    def impl(self, x):
2851        # If x is an int8 or uint8, numpy.log10 will compute the result in
2852        # half-precision (float16), where we want float32.
2853        x_dtype = str(getattr(x, 'dtype', ''))
2854        if x_dtype in ('int8', 'uint8'):
2855            return np.log10(x, sig='f')
2856        return np.log10(x)
2857
2858    def L_op(self, inputs, outputs, gout):
2859        (x,) = inputs
2860        (gz,) = gout
2861        if x.type in complex_types:
2862            raise NotImplementedError()
2863        if outputs[0].type in discrete_types:
2864            if x.type in discrete_types:
2865                return [x.zeros_like(dtype=theano.config.floatX)]
2866            else:
2867                return [x.zeros_like()]
2868
2869        return gz / (x * np.log(10.0)),
2870
2871    def c_code(self, node, name, inputs, outputs, sub):
2872        (x,) = inputs
2873        (z,) = outputs
2874        if node.inputs[0].type in complex_types:
2875            raise NotImplementedError('type not supported', type)
2876        cast = node.outputs[0].type.dtype_specs()[1]
2877        return "%(z)s = log10((%(cast)s)%(x)s);" % locals()
2878log10 = Log10(upgrade_to_float, name='log10')
2879
2880
2881class Log1p(UnaryScalarOp):
2882    """
2883    log(1+x).
2884
2885    """
2886    nfunc_spec = ('log1p', 1, 1)
2887
2888    def impl(self, x):
2889        # If x is an int8 or uint8, numpy.log1p will compute the result in
2890        # half-precision (float16), where we want float32.
2891        x_dtype = str(getattr(x, 'dtype', ''))
2892        if x_dtype in ('int8', 'uint8'):
2893            return np.log1p(x, sig='f')
2894        return np.log1p(x)
2895
2896    def L_op(self, inputs, outputs, gout):
2897        (x,) = inputs
2898        (gz,) = gout
2899        if gz.type in complex_types:
2900            raise NotImplementedError()
2901        if outputs[0].type in discrete_types:
2902            if x.type in discrete_types:
2903                return [x.zeros_like(dtype=theano.config.floatX)]
2904            else:
2905                return [x.zeros_like()]
2906
2907        return [gz / (1 + x)]
2908
2909    def c_code(self, node, name, inputs, outputs, sub):
2910        (x,) = inputs
2911        (z,) = outputs
2912        if node.inputs[0].type in complex_types:
2913            raise NotImplementedError('type not supported', type)
2914        cast = node.outputs[0].type.dtype_specs()[1]
2915        return "%(z)s = log1p((%(cast)s)%(x)s);" % locals()
2916log1p = Log1p(upgrade_to_float, name='log1p')
2917
2918
2919class Exp(UnaryScalarOp):
2920    nfunc_spec = ('exp', 1, 1)
2921    amd_float32 = "amd_vrsa_expf"
2922    amd_float64 = "amd_vrda_exp"
2923
2924    def impl(self, x):
2925        # If x is an int8 or uint8, numpy.exp will compute the result in
2926        # half-precision (float16), where we want float32.
2927        x_dtype = str(getattr(x, 'dtype', ''))
2928        if x_dtype in ('int8', 'uint8'):
2929            return np.exp(x, sig='f')
2930        return np.exp(x)
2931
2932    def L_op(self, inputs, outputs, gout):
2933        (x,) = inputs
2934        (gz,) = gout
2935        if x.type in complex_types:
2936            raise NotImplementedError()
2937        if outputs[0].type in discrete_types:
2938            if x.type in discrete_types:
2939                return [x.zeros_like(dtype=theano.config.floatX)]
2940            else:
2941                return [x.zeros_like()]
2942
2943        return gz * exp(x),
2944
2945    def c_code(self, node, name, inputs, outputs, sub):
2946        (x,) = inputs
2947        (z,) = outputs
2948        if node.inputs[0].type in complex_types:
2949            raise NotImplementedError('type not supported', type)
2950        cast = node.outputs[0].type.dtype_specs()[1]
2951        return "%(z)s = exp((%(cast)s)%(x)s);" % locals()
2952exp = Exp(upgrade_to_float, name='exp')
2953
2954
2955class Exp2(UnaryScalarOp):
2956    nfunc_spec = ('exp2', 1, 1)
2957
2958    def impl(self, x):
2959        # If x is an int8 or uint8, numpy.exp2 will compute the result in
2960        # half-precision (float16), where we want float32.
2961        x_dtype = str(getattr(x, 'dtype', ''))
2962        if x_dtype in ('int8', 'uint8'):
2963            return np.exp2(x, sig='f')
2964        return np.exp2(x)
2965
2966    def L_op(self, inputs, outputs, gout):
2967        (x,) = inputs
2968        (gz,) = gout
2969        if x.type in complex_types:
2970            raise NotImplementedError()
2971        if outputs[0].type in discrete_types:
2972            if x.type in discrete_types:
2973                return [x.zeros_like(dtype=theano.config.floatX)]
2974            else:
2975                return [x.zeros_like()]
2976
2977        return gz * exp2(x) * log(np.cast[x.type](2)),
2978
2979    def c_code(self, node, name, inputs, outputs, sub):
2980        (x,) = inputs
2981        (z,) = outputs
2982        if node.inputs[0].type in complex_types:
2983            raise NotImplementedError('type not supported', type)
2984        cast = node.outputs[0].type.dtype_specs()[1]
2985        return "%(z)s = exp2((%(cast)s)%(x)s);" % locals()
2986exp2 = Exp2(upgrade_to_float, name='exp2')
2987
2988
2989class Expm1(UnaryScalarOp):
2990    nfunc_spec = ('expm1', 1, 1)
2991
2992    def impl(self, x):
2993        # If x is an int8 or uint8, numpy.expm1 will compute the result in
2994        # half-precision (float16), where we want float32.
2995        x_dtype = str(getattr(x, 'dtype', ''))
2996        if x_dtype in ('int8', 'uint8'):
2997            return np.expm1(x, sig='f')
2998        return np.expm1(x)
2999
3000    def L_op(self, inputs, outputs, gout):
3001        (x,) = inputs
3002        (gz,) = gout
3003        if x.type in complex_types:
3004            raise NotImplementedError()
3005        if outputs[0].type in discrete_types:
3006            if x.type in discrete_types:
3007                return [x.zeros_like(dtype=theano.config.floatX)]
3008            else:
3009                return [x.zeros_like()]
3010
3011        return gz * exp(x),
3012
3013    def c_code(self, node, name, inputs, outputs, sub):
3014        (x,) = inputs
3015        (z,) = outputs
3016        if node.inputs[0].type in complex_types:
3017            raise NotImplementedError('type not supported', type)
3018        cast = node.outputs[0].type.dtype_specs()[1]
3019        return "%(z)s = expm1((%(cast)s)%(x)s);" % locals()
3020
3021    def c_code_cache_version(self):
3022        return (5,)
3023expm1 = Expm1(upgrade_to_float, name='expm1')
3024
3025
3026class Sqr(UnaryScalarOp):
3027    nfunc_spec = ('square', 1, 1)
3028
3029    def impl(self, x):
3030        return x * x
3031
3032    def L_op(self, inputs, outputs, gout):
3033        (x,) = inputs
3034        (gz,) = gout
3035        if gz.type in complex_types:
3036            raise NotImplementedError()
3037        if outputs[0].type in discrete_types:
3038            if x.type in discrete_types:
3039                return [x.zeros_like(dtype=theano.config.floatX)]
3040            else:
3041                return [x.zeros_like()]
3042
3043        return gz * x * 2,
3044
3045    def c_code(self, node, name, inputs, outputs, sub):
3046        (x,) = inputs
3047        (z,) = outputs
3048        return "%(z)s = %(x)s * %(x)s;" % locals()
3049sqr = Sqr(same_out, name='sqr')
3050
3051
3052class Sqrt(UnaryScalarOp):
3053    nfunc_spec = ('sqrt', 1, 1)
3054
3055    def impl(self, x):
3056        # If x is an int8 or uint8, numpy.sqrt will compute the result in
3057        # half-precision (float16), where we want float32.
3058        x_dtype = str(getattr(x, 'dtype', ''))
3059        if x_dtype in ('int8', 'uint8'):
3060            return np.sqrt(x, sig='f')
3061        return np.sqrt(x)
3062
3063    def L_op(self, inputs, outputs, gout):
3064        (x,) = inputs
3065        (gz,) = gout
3066        if gz.type in complex_types:
3067            raise NotImplementedError()
3068        if outputs[0].type in discrete_types:
3069            if x.type in discrete_types:
3070                return [x.zeros_like(dtype=theano.config.floatX)]
3071            else:
3072                return [x.zeros_like()]
3073
3074        return (gz * 0.5) / sqrt(x),
3075
3076    def c_code(self, node, name, inputs, outputs, sub):
3077        (x,) = inputs
3078        (z,) = outputs
3079        if node.inputs[0].type in complex_types:
3080            raise NotImplementedError('type not supported', type)
3081        cast = node.outputs[0].type.dtype_specs()[1]
3082        return "%(z)s = sqrt((%(cast)s)%(x)s);" % locals()
3083sqrt = Sqrt(upgrade_to_float, name='sqrt')
3084
3085
3086class Deg2Rad(UnaryScalarOp):
3087    nfunc_spec = ('deg2rad', 1, 1)
3088
3089    def impl(self, x):
3090        # If x is an int8 or uint8, numpy.deg2rad will compute the result in
3091        # half-precision (float16), where we want float32.
3092        x_dtype = str(getattr(x, 'dtype', ''))
3093        if x_dtype in ('int8', 'uint8'):
3094            return np.deg2rad(x, sig='f')
3095        return np.deg2rad(x)
3096
3097    def L_op(self, inputs, outputs, gout):
3098        (x,) = inputs
3099        (gz,) = gout
3100        if gz.type in complex_types:
3101            raise NotImplementedError()
3102        if outputs[0].type in discrete_types:
3103            if x.type in discrete_types:
3104                return [x.zeros_like(dtype=theano.config.floatX)]
3105            else:
3106                return [x.zeros_like()]
3107
3108        return gz * np.asarray(np.pi / 180, gz.type),
3109
3110    def c_code(self, node, name, inputs, outputs, sub):
3111        (x,) = inputs
3112        (z,) = outputs
3113        if node.inputs[0].type in complex_types:
3114            raise NotImplementedError('type not supported', type)
3115        return "%(z)s = %(x)s * (M_PI / 180.0);" % locals()
3116deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
3117
3118
3119class Rad2Deg(UnaryScalarOp):
3120    nfunc_spec = ('rad2deg', 1, 1)
3121
3122    def impl(self, x):
3123        # If x is an int8 or uint8, numpy.rad2deg will compute the result in
3124        # half-precision (float16), where we want float32.
3125        x_dtype = str(getattr(x, 'dtype', ''))
3126        if x_dtype in ('int8', 'uint8'):
3127            return np.rad2deg(x, sig='f')
3128        return np.rad2deg(x)
3129
3130    def L_op(self, inputs, outputs, gout):
3131        (x,) = inputs
3132        (gz,) = gout
3133        if gz.type in complex_types:
3134            raise NotImplementedError()
3135        if outputs[0].type in discrete_types:
3136            if x.type in discrete_types:
3137                return [x.zeros_like(dtype=theano.config.floatX)]
3138            else:
3139                return [x.zeros_like()]
3140
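        # d/dx rad2deg(x) = 180 / pi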
3141        return gz * np.asarray(180. / np.pi, gz.type),
3142
3143    def c_code(self, node, name, inputs, outputs, sub):
3144        (x,) = inputs
3145        (z,) = outputs
3146        if node.inputs[0].type in complex_types:
3147            raise NotImplementedError('type not supported', type)
3148        return "%(z)s = %(x)s * (180.0 / M_PI);" % locals()
3149rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
3150
3151
3152class Cos(UnaryScalarOp):
3153    nfunc_spec = ('cos', 1, 1)
3154    amd_float32 = "amd_vrsa_cosf"
3155    amd_float64 = "amd_vrda_cos"
3156
3157    def impl(self, x):
3158        # If x is an int8 or uint8, numpy.cos will compute the result in
        # half-precision (float16), whereas we want float32.
3160        x_dtype = str(getattr(x, 'dtype', ''))
3161        if x_dtype in ('int8', 'uint8'):
3162            return np.cos(x, sig='f')
3163        return np.cos(x)
3164
3165    def L_op(self, inputs, outputs, gout):
3166        (x,) = inputs
3167        (gz,) = gout
3168        if gz.type in complex_types:
3169            raise NotImplementedError()
3170        if outputs[0].type in discrete_types:
3171            if x.type in discrete_types:
3172                return [x.zeros_like(dtype=theano.config.floatX)]
3173            else:
3174                return [x.zeros_like()]
3175
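        # d/dx cos(x) = -sin(x)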
3176        return -gz * sin(x),
3177
3178    def c_code(self, node, name, inputs, outputs, sub):
3179        (x,) = inputs
3180        (z,) = outputs
3181        if node.inputs[0].type in complex_types:
3182            raise NotImplementedError('type not supported', type)
3183        cast = node.outputs[0].type.dtype_specs()[1]
3184        return "%(z)s = cos((%(cast)s)%(x)s);" % locals()
3185cos = Cos(upgrade_to_float, name='cos')
3186
3187
3188class ArcCos(UnaryScalarOp):
3189    nfunc_spec = ('arccos', 1, 1)
3190
3191    def impl(self, x):
3192        # If x is an int8 or uint8, numpy.arccos will compute the result in
        # half-precision (float16), whereas we want float32.
3194        x_dtype = str(getattr(x, 'dtype', ''))
3195        if x_dtype in ('int8', 'uint8'):
3196            return np.arccos(x, sig='f')
3197        return np.arccos(x)
3198
3199    def L_op(self, inputs, outputs, gout):
3200        (x,) = inputs
3201        (gz,) = gout
3202        if gz.type in complex_types:
3203            raise NotImplementedError()
3204        if outputs[0].type in discrete_types:
3205            if x.type in discrete_types:
3206                return [x.zeros_like(dtype=theano.config.floatX)]
3207            else:
3208                return [x.zeros_like()]
3209
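        # d/dx arccos(x) = -1 / sqrt(1 - x**2)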
3210        return - gz / sqrt(np.cast[x.type](1) - sqr(x)),
3211
3212    def c_code(self, node, name, inputs, outputs, sub):
3213        (x,) = inputs
3214        (z,) = outputs
3215        if node.inputs[0].type in complex_types:
3216            raise NotImplementedError('type not supported', type)
3217        cast = node.outputs[0].type.dtype_specs()[1]
3218        return "%(z)s = acos((%(cast)s)%(x)s);" % locals()
3219arccos = ArcCos(upgrade_to_float, name='arccos')
3220
3221
3222class Sin(UnaryScalarOp):
3223    nfunc_spec = ('sin', 1, 1)
3224    amd_float32 = "amd_vrsa_sinf"
3225    amd_float64 = "amd_vrda_sin"
3226
3227    def impl(self, x):
3228        # If x is an int8 or uint8, numpy.sin will compute the result in
        # half-precision (float16), whereas we want float32.
3230        x_dtype = str(getattr(x, 'dtype', ''))
3231        if x_dtype in ('int8', 'uint8'):
3232            return np.sin(x, sig='f')
3233        return np.sin(x)
3234
3235    def L_op(self, inputs, outputs, gout):
3236        (x,) = inputs
3237        (gz,) = gout
3238        if x.type in complex_types:
3239            raise NotImplementedError()
3240        if outputs[0].type in discrete_types:
3241            if x.type in discrete_types:
3242                return [x.zeros_like(dtype=theano.config.floatX)]
3243            else:
3244                return [x.zeros_like()]
3245
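        # d/dx sin(x) = cos(x)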
3246        return gz * cos(x),
3247
3248    def c_code(self, node, name, inputs, outputs, sub):
3249        (x,) = inputs
3250        (z,) = outputs
3251        if node.inputs[0].type in complex_types:
3252            raise NotImplementedError('type not supported', type)
3253        cast = node.outputs[0].type.dtype_specs()[1]
3254        return "%(z)s = sin((%(cast)s)%(x)s);" % locals()
3255sin = Sin(upgrade_to_float, name='sin')
3256
3257
3258class ArcSin(UnaryScalarOp):
3259    nfunc_spec = ('arcsin', 1, 1)
3260
3261    def impl(self, x):
3262        # If x is an int8 or uint8, numpy.arcsin will compute the result in
        # half-precision (float16), whereas we want float32.
3264        x_dtype = str(getattr(x, 'dtype', ''))
3265        if x_dtype in ('int8', 'uint8'):
3266            return np.arcsin(x, sig='f')
3267        return np.arcsin(x)
3268
3269    def L_op(self, inputs, outputs, gout):
3270        (x,) = inputs
3271        (gz,) = gout
3272        if gz.type in complex_types:
3273            raise NotImplementedError()
3274        if outputs[0].type in discrete_types:
3275            if x.type in discrete_types:
3276                return [x.zeros_like(dtype=theano.config.floatX)]
3277            else:
3278                return [x.zeros_like()]
3279
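        # d/dx arcsin(x) = 1 / sqrt(1 - x**2)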
3280        return gz / sqrt(np.cast[x.type](1) - sqr(x)),
3281
3282    def c_code(self, node, name, inputs, outputs, sub):
3283        (x,) = inputs
3284        (z,) = outputs
3285        if node.inputs[0].type in complex_types:
3286            raise NotImplementedError('type not supported', type)
3287        cast = node.outputs[0].type.dtype_specs()[1]
3288        return "%(z)s = asin((%(cast)s)%(x)s);" % locals()
3289arcsin = ArcSin(upgrade_to_float, name='arcsin')
3290
3291
3292class Tan(UnaryScalarOp):
3293    nfunc_spec = ('tan', 1, 1)
3294
3295    def impl(self, x):
3296        # If x is an int8 or uint8, numpy.tan will compute the result in
        # half-precision (float16), whereas we want float32.
3298        x_dtype = str(getattr(x, 'dtype', ''))
3299        if x_dtype in ('int8', 'uint8'):
3300            return np.tan(x, sig='f')
3301        return np.tan(x)
3302
3303    def L_op(self, inputs, outputs, gout):
3304        (x,) = inputs
3305        (gz,) = gout
3306        if x.type in complex_types:
3307            raise NotImplementedError()
3308        if outputs[0].type in discrete_types:
3309            if x.type in discrete_types:
3310                return [x.zeros_like(dtype=theano.config.floatX)]
3311            else:
3312                return [x.zeros_like()]
3313
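        # d/dx tan(x) = 1 / cos(x)**2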
3314        return gz / sqr(cos(x)),
3315
3316    def c_code(self, node, name, inputs, outputs, sub):
3317        (x,) = inputs
3318        (z,) = outputs
3319        if node.inputs[0].type in complex_types:
3320            raise NotImplementedError('type not supported', type)
3321        cast = node.outputs[0].type.dtype_specs()[1]
3322        return "%(z)s = tan((%(cast)s)%(x)s);" % locals()
3323tan = Tan(upgrade_to_float, name='tan')
3324
3325
3326class ArcTan(UnaryScalarOp):
3327    nfunc_spec = ('arctan', 1, 1)
3328
3329    def impl(self, x):
3330        # If x is an int8 or uint8, numpy.arctan will compute the result in
        # half-precision (float16), whereas we want float32.
3332        x_dtype = str(getattr(x, 'dtype', ''))
3333        if x_dtype in ('int8', 'uint8'):
3334            return np.arctan(x, sig='f')
3335        return np.arctan(x)
3336
3337    def L_op(self, inputs, outputs, gout):
3338        (x,) = inputs
3339        (gz,) = gout
3340        if gz.type in complex_types:
3341            raise NotImplementedError()
3342        if outputs[0].type in discrete_types:
3343            if x.type in discrete_types:
3344                return [x.zeros_like(dtype=theano.config.floatX)]
3345            else:
3346                return [x.zeros_like()]
3347
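        # d/dx arctan(x) = 1 / (1 + x**2)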
3348        return gz / (np.cast[x.type](1) + sqr(x)),
3349
3350    def c_code(self, node, name, inputs, outputs, sub):
3351        (x,) = inputs
3352        (z,) = outputs
3353        if node.inputs[0].type in complex_types:
3354            raise NotImplementedError('type not supported', type)
3355        cast = node.outputs[0].type.dtype_specs()[1]
3356        return "%(z)s = atan((%(cast)s)%(x)s);" % locals()
3357arctan = ArcTan(upgrade_to_float, name='arctan')
3358
3359
3360class ArcTan2(BinaryScalarOp):
3361    nfunc_spec = ('arctan2', 2, 1)
3362
3363    def impl(self, y, x):
3364        # If x and y are int8 or uint8, numpy.arctan2 will compute the result
        # in half-precision (float16), whereas we want float32.
3366        x_dtype = str(getattr(x, 'dtype', ''))
3367        if x_dtype in ('int8', 'uint8'):
            y_dtype = str(getattr(y, 'dtype', ''))
3369            if y_dtype in ('int8', 'uint8'):
3370                return np.arctan2(y, x, sig='f')
3371        return np.arctan2(y, x)
3372
3373    def L_op(self, inputs, outputs, gout):
3374        (y, x) = inputs
3375        (gz,) = gout
3376        if gz.type in complex_types:
3377            raise NotImplementedError()
3378        else:
3379            if outputs[0].type in discrete_types:
3380                if x.type in discrete_types:
3381                    gx = x.zeros_like(dtype=theano.config.floatX)
3382                else:
3383                    gx = x.zeros_like()
3384                if y.type in discrete_types:
3385                    gy = y.zeros_like(dtype=theano.config.floatX)
3386                else:
3387                    gy = y.zeros_like()
3388                return [gx, gy]
3389
3390            # If the output is float, the gradient should flow,
3391            # even if the inputs are ints
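            # d/dy arctan2(y, x) = x / (x**2 + y**2)
            # d/dx arctan2(y, x) = -y / (x**2 + y**2)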
3392            return [gz * x / (sqr(x) + sqr(y)),
3393                    gz * neg(y) / (sqr(x) + sqr(y))]
3394
3395    def c_code(self, node, name, inputs, outputs, sub):
3396        (y, x) = inputs
3397        (z,) = outputs
3398        if (node.inputs[0].type in complex_types or
3399                node.inputs[1].type in complex_types):
3400            raise NotImplementedError('type not supported', type)
3401        cast = node.outputs[0].type.dtype_specs()[1]
3402        return "%(z)s = atan2((%(cast)s)%(y)s, (%(cast)s)%(x)s);" % locals()
3403arctan2 = ArcTan2(upgrade_to_float, name='arctan2')
3404
3405
3406class Cosh(UnaryScalarOp):
3407    """
3408    cosh(x) = (exp(x) + exp(-x)) / 2.
3409
3410    """
3411    nfunc_spec = ('cosh', 1, 1)
3412
3413    def impl(self, x):
3414        # If x is an int8 or uint8, numpy.cosh will compute the result in
        # half-precision (float16), whereas we want float32.
3416        x_dtype = str(getattr(x, 'dtype', ''))
3417        if x_dtype in ('int8', 'uint8'):
3418            return np.cosh(x, sig='f')
3419        return np.cosh(x)
3420
3421    def L_op(self, inputs, outputs, gout):
3422        (x,) = inputs
3423        (gz,) = gout
3424        if x.type in complex_types:
3425            raise NotImplementedError()
3426        if outputs[0].type in discrete_types:
3427            if x.type in discrete_types:
3428                return [x.zeros_like(dtype=theano.config.floatX)]
3429            else:
3430                return [x.zeros_like()]
3431
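        # d/dx cosh(x) = sinh(x)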
3432        return gz * sinh(x),
3433
3434    def c_code(self, node, name, inputs, outputs, sub):
3435        (x,) = inputs
3436        (z,) = outputs
3437        if node.inputs[0].type in complex_types:
3438            raise NotImplementedError('type not supported', type)
3439        cast = node.outputs[0].type.dtype_specs()[1]
3440        return "%(z)s = cosh((%(cast)s)%(x)s);" % locals()
3441cosh = Cosh(upgrade_to_float, name='cosh')
3442
3443
3444class ArcCosh(UnaryScalarOp):
3445    nfunc_spec = ('arccosh', 1, 1)
3446
3447    def impl(self, x):
3448        # If x is an int8 or uint8, numpy.arccosh will compute the result in
        # half-precision (float16), whereas we want float32.
3450        x_dtype = str(getattr(x, 'dtype', ''))
3451        if x_dtype in ('int8', 'uint8'):
3452            return np.arccosh(x, sig='f')
3453        return np.arccosh(x)
3454
3455    def L_op(self, inputs, outputs, gout):
3456        (x,) = inputs
3457        (gz,) = gout
3458        if x.type in complex_types:
3459            raise NotImplementedError()
3460        if outputs[0].type in discrete_types:
3461            if x.type in discrete_types:
3462                return [x.zeros_like(dtype=theano.config.floatX)]
3463            else:
3464                return [x.zeros_like()]
3465
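        # d/dx arccosh(x) = 1 / sqrt(x**2 - 1)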
3466        return gz / sqrt(sqr(x) - np.cast[x.type](1)),
3467
3468    def c_code(self, node, name, inputs, outputs, sub):
3469        (x,) = inputs
3470        (z,) = outputs
3471        if node.inputs[0].type in complex_types:
3472            raise NotImplementedError('type not supported', type)
3473        cast = node.outputs[0].type.dtype_specs()[1]
3474        return "%(z)s = acosh((%(cast)s)%(x)s);" % locals()
3475arccosh = ArcCosh(upgrade_to_float, name='arccosh')
3476
3477
3478class Sinh(UnaryScalarOp):
3479    """
3480    sinh(x) = (exp(x) - exp(-x)) / 2.
3481
3482    """
3483    nfunc_spec = ('sinh', 1, 1)
3484
3485    def impl(self, x):
3486        # If x is an int8 or uint8, numpy.sinh will compute the result in
        # half-precision (float16), whereas we want float32.
3488        x_dtype = str(getattr(x, 'dtype', ''))
3489        if x_dtype in ('int8', 'uint8'):
3490            return np.sinh(x, sig='f')
3491        return np.sinh(x)
3492
3493    def L_op(self, inputs, outputs, gout):
3494        (x,) = inputs
3495        (gz,) = gout
3496        if x.type in complex_types:
3497            raise NotImplementedError()
3498        if outputs[0].type in discrete_types:
3499            if x.type in discrete_types:
3500                return [x.zeros_like(dtype=theano.config.floatX)]
3501            else:
3502                return [x.zeros_like()]
3503
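        # d/dx sinh(x) = cosh(x)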
3504        return gz * cosh(x),
3505
3506    def c_code(self, node, name, inputs, outputs, sub):
3507        (x,) = inputs
3508        (z,) = outputs
3509        if node.inputs[0].type in complex_types:
3510            raise NotImplementedError('type not supported', type)
3511        cast = node.outputs[0].type.dtype_specs()[1]
3512        return "%(z)s = sinh((%(cast)s)%(x)s);" % locals()
3513sinh = Sinh(upgrade_to_float, name='sinh')
3514
3515
3516class ArcSinh(UnaryScalarOp):
3517    nfunc_spec = ('arcsinh', 1, 1)
3518
3519    def impl(self, x):
3520        # If x is an int8 or uint8, numpy.arcsinh will compute the result in
        # half-precision (float16), whereas we want float32.
3522        x_dtype = str(getattr(x, 'dtype', ''))
3523        if x_dtype in ('int8', 'uint8'):
3524            return np.arcsinh(x, sig='f')
3525        return np.arcsinh(x)
3526
3527    def L_op(self, inputs, outputs, gout):
3528        (x,) = inputs
3529        (gz,) = gout
3530        if x.type in complex_types:
3531            raise NotImplementedError()
3532        if outputs[0].type in discrete_types:
3533            if x.type in discrete_types:
3534                return [x.zeros_like(dtype=theano.config.floatX)]
3535            else:
3536                return [x.zeros_like()]
3537
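        # d/dx arcsinh(x) = 1 / sqrt(x**2 + 1)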
3538        return gz / sqrt(sqr(x) + np.cast[x.type](1)),
3539
3540    def c_code(self, node, name, inputs, outputs, sub):
3541        (x,) = inputs
3542        (z,) = outputs
3543        if node.inputs[0].type in complex_types:
3544            raise NotImplementedError('type not supported', type)
3545        cast = node.outputs[0].type.dtype_specs()[1]
3546        return "%(z)s = asinh((%(cast)s)%(x)s);" % locals()
3547arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
3548
3549
3550class Tanh(UnaryScalarOp):
3551    """
3552    tanh(x) = sinh(x) / cosh(x)
3553            = (exp(2*x) - 1) / (exp(2*x) + 1).
3554
3555    """
3556    nfunc_spec = ('tanh', 1, 1)
3557
3558    def impl(self, x):
3559        # If x is an int8 or uint8, numpy.tanh will compute the result in
        # half-precision (float16), whereas we want float32.
3561        x_dtype = str(getattr(x, 'dtype', ''))
3562        if x_dtype in ('int8', 'uint8'):
3563            return np.tanh(x, sig='f')
3564        return np.tanh(x)
3565
3566    def L_op(self, inputs, outputs, gout):
3567        (x,) = inputs
3568        (gz,) = gout
3569        if x.type in complex_types:
3570            raise NotImplementedError()
3571        if outputs[0].type in discrete_types:
3572            if x.type in discrete_types:
3573                return [x.zeros_like(dtype=theano.config.floatX)]
3574            else:
3575                return [x.zeros_like()]
3576
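        # d/dx tanh(x) = 1 - tanh(x)**2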
3577        return gz * (1 - sqr(tanh(x))),
3578
3579    def c_code(self, node, name, inputs, outputs, sub):
3580        (x,) = inputs
3581        (z,) = outputs
3582        if node.inputs[0].type in complex_types:
3583            raise NotImplementedError('type not supported', type)
3584        cast = node.outputs[0].type.dtype_specs()[1]
3585        return "%(z)s = tanh((%(cast)s)%(x)s);" % locals()
3586tanh = Tanh(upgrade_to_float, name='tanh')
3587
3588
3589class ArcTanh(UnaryScalarOp):
3590    nfunc_spec = ('arctanh', 1, 1)
3591
3592    def impl(self, x):
3593        # If x is an int8 or uint8, numpy.arctanh will compute the result in
        # half-precision (float16), whereas we want float32.
3595        x_dtype = str(getattr(x, 'dtype', ''))
3596        if x_dtype in ('int8', 'uint8'):
3597            return np.arctanh(x, sig='f')
3598        return np.arctanh(x)
3599
3600    def L_op(self, inputs, outputs, gout):
3601        (x,) = inputs
3602        (gz,) = gout
3603        if x.type in complex_types:
3604            raise NotImplementedError()
3605        if outputs[0].type in discrete_types:
3606            if x.type in discrete_types:
3607                return [x.zeros_like(dtype=theano.config.floatX)]
3608            else:
3609                return [x.zeros_like()]
3610
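        # d/dx arctanh(x) = 1 / (1 - x**2)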
3611        return gz / (np.cast[x.type](1) - sqr(x)),
3612
3613    def c_code(self, node, name, inputs, outputs, sub):
3614        (x,) = inputs
3615        (z,) = outputs
3616        if node.inputs[0].type in complex_types:
3617            raise NotImplementedError('type not supported', type)
3618        cast = node.outputs[0].type.dtype_specs()[1]
3619        return "%(z)s = atanh((%(cast)s)%(x)s);" % locals()
3620arctanh = ArcTanh(upgrade_to_float, name='arctanh')
3621
3622
3623class Real(UnaryScalarOp):
3624    """
3625    Extract the real coordinate of a complex number.
3626
3627    """
    # numpy.real(float32) returns a view on its input.
3629    # nfunc_spec = ('real', 1, 1)
3630
3631    def impl(self, x):
3632        return np.real(x)
3633
3634    def grad(self, inputs, gout):
3635        (x,) = inputs
3636        (gz,) = gout
3637        return [complex(gz, 0)]
3638
3639real = Real(real_out, name='real')
3640
3641
3642class Imag(UnaryScalarOp):
3643    nfunc_spec = ('imag', 1, 1)
3644
3645    def impl(self, x):
3646        return np.imag(x)
3647
3648    def grad(self, inputs, gout):
3649        (x,) = inputs
3650        (gz,) = gout
3651        if x.type in complex_types:
3652            return [complex(0, gz)]
3653        elif x.type in float_types:
3654            return [second(x, 0)]
3655        else:
3656            return [x.zeros_like(dtype=theano.config.floatX)]
3657
3658imag = Imag(real_out, name='imag')
3659
3660
3661class Angle(UnaryScalarOp):
3662    nfunc_spec = ('angle', 1, 1)
3663
3664    def impl(self, x):
3665        return np.angle(x)
3666
3667    def grad(self, inputs, gout):
3668        # y = x.imag
3669        # r = sqrt(y**2 + x.real**2)
3670        # g = y/r
3671        # if x == 0 and y == 0:
3672        #     theta = 0
3673        # elif x >= 0:
3674        #     theta = numpy.arcsin(g)
3675        # else:
3676        #     theta = -numpy.arcsin(g)+numpy.pi
3677
3678        (c,) = inputs
3679        (gtheta,) = gout
3680        x = real(c)
3681        y = imag(c)
3682        r = abs(c)
3683
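        # gr backpropagates gtheta to the radius r through
        # theta = arcsin(y / r); gx and gy then push gr back through r = |c|.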
3684        gr = -gtheta * y / (r ** 2 * sqrt(1 - (y / r) ** 2))
3685        gx = gr * x / r
3686        gy = gr * y / r
        if c.type in complex_types:
            return [cast(complex(gx, gy), x.type.dtype)]
        elif c.type in float_types:
3690            return [cast(second(x, 0), x.type.dtype)]
3691        else:
3692            return [c.zeros_like(dtype=theano.config.floatX)]
3693
3694angle = Angle(specific_out(float64), name='angle')
3695
3696
3697class Complex(BinaryScalarOp):
3698    @staticmethod
3699    def output_types_preference(x, y):
3700        if x in complex_types:
3701            raise TypeError(x)
3702        if y in complex_types:
3703            raise TypeError(y)
3704
3705        up = Scalar.upcast(x, y)
3706        if up in ('float64', 'int64', 'uint64', 'int32', 'uint32'):
3707            return [complex128]
3708        else:
3709            return [complex64]
3710
3711    def impl(self, x, y):
3712        return np.complex(x, y)
3713
3714    def grad(self, inputs, gout):
3715        (x, y) = inputs
3716        (gz,) = gout
3717        return [cast(real(gz), x.type.dtype),
3718                cast(imag(gz), y.type.dtype)]
3719complex = Complex(name='complex')
3720
3721
3722class Conj(UnaryScalarOp):
3723    nfunc_spec = ('conj', 1, 1)
3724
3725    def impl(self, x):
3726        return np.conj(x)
3727
3728    def c_code(self, node, name, inputs, outputs, sub):
3729        (x,) = inputs
3730        (z,) = outputs
3731        if node.inputs[0].type in complex_types:
            # Complex types have no C implementation for conj here.
            raise NotImplementedError('type has no C code',
                                      node.inputs[0].type)
        # For non-complex types, conj is the identity, so a plain copy suffices.
        return "%(z)s = %(x)s;" % locals()
3736
3737conj = Conj(same_out_min8, name='conj')
3738
3739
3740class ComplexFromPolar(BinaryScalarOp):
3741    @staticmethod
3742    def output_types_preference(x, y):
3743        return Complex.output_types_preference(x, y)
3744
3745    def impl(self, r, theta):
3746        if r < 0:
3747            raise ValueError('polar radius must be non-negative', r)
3748        x = r * np.cos(theta)
3749        y = r * np.sin(theta)
3750        if x.dtype == 'float32':
3751            return np.complex64(np.complex(x, y))
3752        else:
3753            return np.complex128(np.complex(x, y))
3754
3755    def grad(self, inputs, gout):
3756        (r, theta) = inputs
3757        (gz,) = gout
3758        gr = gz * complex_from_polar(1, theta)
3759        gtheta = gz * complex_from_polar(r, -theta)
3760        return [gr, gtheta]
3761complex_from_polar = ComplexFromPolar(name='complex_from_polar')
3762
3763
3764class Composite(ScalarOp):
3765    """
3766    Composite is an Op that takes a graph of scalar operations and
3767    produces c code for the whole graph. Its purpose is to implement loop
3768    fusion.
3769
3770    Composite depends on all the Ops in its graph having C code.
3771
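    Examples
    --------
    A small, illustrative sketch (the variable names below are
    arbitrary; any graph built exclusively from scalar Ops would do)::

        from theano.scalar import Composite, add, float64, mul

        x, y = float64('x'), float64('y')
        comp = Composite([x, y], [add(mul(x, y), x)])
        out = comp(x, y)  # usable like any other scalar Op
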
3772    """
3773    init_param = ('inputs', 'outputs')
3774
3775    def __str__(self):
3776        if self.name is None:
3777            self.init_name()
3778        return self.name
3779
3780    def make_new_inplace(self, output_types_preference=None, name=None):
3781        """
        This Op's __init__ method does not take the same parameters as
        other scalar Ops, which breaks the insert_inplace_optimizer
        optimization. This method provides a way to work around that.
3785
3786        """
3787        d = dict([(k, getattr(self, k)) for k in self.init_param])
3788        out = self.__class__(**d)
3789        if name:
3790            out.name = name
3791        else:
3792            name = out.name
3793        super(Composite, out).__init__(output_types_preference, name)
3794        return out
3795
3796    def init_c_code(self):
3797        """
3798        Assemble the C code for this Composite Op.
3799
3800        The result is assigned to `self._c_code`.
3801        """
3802        # It was already called
3803        if hasattr(self, '_c_code'):
3804            return
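        # `subd` maps each fgraph variable to the C expression that will
        # stand for it in the generated code: %-format placeholders for
        # the inputs and outputs; literals and temporaries are added below.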
3805        subd = dict(chain(
3806            ((e, "%%(i%i)s" % i) for i, e in enumerate(self.fgraph.inputs)),
3807            ((e, "%%(o%i)s" % i) for i, e in enumerate(self.fgraph.outputs))))
3808
3809        for var in self.fgraph.variables:
3810            if var.owner is None:
3811                if var not in self.fgraph.inputs:
3812                    # This is an orphan
3813                    if isinstance(var, Constant):
3814                        subd[var] = var.type.c_literal(var.data)
3815                    else:
3816                        raise ValueError(
3817                            "All orphans in the fgraph to Composite must"
3818                            " be Constant instances.")
3819            elif (any(i.dtype == 'float16' for i in var.owner.inputs) or
3820                  any(o.dtype == 'float16' for o in var.owner.outputs)):
3821                # flag for elemwise ops to check.
3822                self.inner_float16 = True
3823
3824        _c_code = "{\n"
3825        self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
3826                          for j, n in enumerate(self.fgraph.toposort())]
3827
3828        i = 0
3829        for j, node in enumerate(self.fgraph.toposort()):
3830            for output in node.outputs:
3831                if output not in subd:
3832                    i += 1
3833                    name = "V%%(id)s_tmp%i" % i
3834                    subd[output] = name
3835                    _c_code += "%s %s;\n" % (
3836                        output.type.dtype_specs()[1], name)
3837            s = node.op.c_code(
3838                node,
3839                self.nodenames[j],
3840                [subd[input] for input in node.inputs],
3841                [subd[output] for output in node.outputs],
3842                dict(fail="%(fail)s", id="%%(id)s_%i" % j))
3843            _c_code += s
3844            _c_code += "\n"
3845        _c_code += "}\n"
3846        self._c_code = _c_code
3847
3848    def init_py_impls(self):
3849        """
        Build the functions (one per output of self) that implement this
        Op in Python, and store them in `self._impls`.
3851
3852        """
        # Consider the case where the graph is a DAG but not a tree, like:
        # add(*1 -> mul(x, y), *1)

        # We have an efficient way to build the executable: we build
        # and traverse each node only once.

        # But we don't have an efficient execution. We execute it
        # like a tree, so nodes that have more than one client will be
        # executed as many times as they have clients. In the example
        # above, *1 would be computed twice. Doing otherwise would
        # require a more complicated execution engine.

        # We need the executor to be fast to create, as we always
        # create it even when the C code will be used. The Python
        # implementation is already slow, so a fast execution matters
        # less there.
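        # Illustrative sketch (hypothetical graph): for inputs [x, y] and
        # the single output add(mul(x, y), x), the composed implementation
        # behaves like
        #     lambda inputs: add.impl(mul.impl(inputs[0], inputs[1]), inputs[0])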
3869
3870        memo = {}
3871
3872        def compose_impl(r):
3873            if r in memo:
3874                return memo[r]
3875            if r in self.fgraph.inputs:
3876                idx = self.fgraph.inputs.index(r)
3877
3878                def f(inputs):
3879                    return inputs[idx]
3880                memo[r] = f
3881                return f
3882            elif r.owner is None:  # in fgraph.orphans:
3883                def f(inputs):
3884                    return r.data
3885                memo[r] = f
3886                return f
3887            node = r.owner
3888            producers = [compose_impl(input) for input in node.inputs]
3889
3890            def f(inputs):
3891                return node.op.impl(*[p(inputs) for p in producers])
3892            memo[r] = f
3893            return f
3894        self._impls = [compose_impl(r) for r in self.fgraph.outputs]
3895
3896    def init_name(self):
3897        """
        Build a readable string representation of self.fgraph and store
        it in self.name.
3899
3900        """
3901        rval = self.name
3902        if rval is None:
3903            for i, r in enumerate(self.fgraph.inputs):
3904                r.name = 'i%i' % i
3905            for i, r in enumerate(self.fgraph.outputs):
3906                r.name = 'o%i' % i
3907            io = set(self.fgraph.inputs + self.fgraph.outputs)
3908            for i, r in enumerate(self.fgraph.variables):
3909                if r not in io and len(r.clients) > 1:
3910                    r.name = 't%i' % i
3911            rval = "Composite{%s}" % ', '.join([pprint(output) for output
3912                                                in self.fgraph.outputs])
3913            self.name = rval
3914
3915    def init_fgraph(self):
        # The clone done by FunctionGraph is needed, as we don't want
        # the fgraph to be attached to the variables: we need to pickle
        # them for the cache of the C module to work.
3919        fgraph = FunctionGraph(self.inputs, self.outputs)
3920        gof.MergeOptimizer().optimize(fgraph)
3921        for node in fgraph.apply_nodes:
3922            if not isinstance(node.op, ScalarOp):
3923                raise ValueError("The fgraph to Composite must be exclusively"
3924                                 " composed of ScalarOp instances.")
3925        self.fgraph = fgraph
3926
3927    def __init__(self, inputs, outputs):
        # We need to clone the graph, as sometimes its nodes already
        # contain a reference to an fgraph. As we want the Composite
        # to be picklable, it cannot keep references to an fgraph.

        # Also, if there is a Composite in the inner graph, we want to
        # remove it. In that case, we do a more complicated clone that
        # flattens the Composite. We don't need to do this recursively:
        # the way the fusion optimizer works, there is only one new
        # Composite at the output each time.
3937        for i in inputs:
3938            assert i not in outputs  # This isn't supported, use identity
3939        if len(outputs) > 1 or not any([isinstance(var.owner.op, Composite)
3940                                        for var in outputs]):
3941            # No inner Composite
3942            inputs, outputs = gof.graph.clone(inputs, outputs)
3943        else:
3944            # Inner Composite that we need to flatten
3945            assert len(outputs) == 1
3946            # 1. Create a new graph from inputs up to the
3947            # Composite
3948            res = theano.compile.rebuild_collect_shared(
3949                inputs=inputs,
3950                outputs=outputs[0].owner.inputs,
3951                copy_inputs_over=False)  # Clone also the inputs
3952            # 2. We continue this partial clone with the graph in
3953            # the inner Composite
3954            res2 = theano.compile.rebuild_collect_shared(
3955                inputs=outputs[0].owner.op.inputs,
3956                outputs=outputs[0].owner.op.outputs,
3957                replace=dict(izip(outputs[0].owner.op.inputs, res[1]))
3958            )
3959            assert len(res2[1]) == len(outputs)
3960            assert len(res[0]) == len(inputs)
3961            assert res[0] != inputs
3962            inputs, outputs = res[0], res2[1]
            # The next assert is commented out only for speed:
3964            # assert not any([isinstance(node.op, Composite) for node in
3965            #                theano.gof.graph.ops(inputs, outputs)])
3966
3967        self.inputs = copy(inputs)
3968        self.outputs = copy(outputs)
3969        self.inputs_type = tuple([input.type for input in inputs])
3970        self.outputs_type = tuple([output.type for output in outputs])
3971        self.nin = len(inputs)
3972        self.nout = len(outputs)
3973        self.init_fgraph()       # self.fgraph
3974        # Postpone the creation in case it isn't needed.
3975        #  self.init_name()      # self.name
3976        self.name = None
3977        self.prepare_node_called = set()
3978
3979    def prepare_node(self, node, storage_map, compute_map, impl):
3980        if impl == 'py':
3981            self.init_py_impls()  # self._impls
3982        if impl not in self.prepare_node_called:
3983            for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs):
3984                n.op.prepare_node(n, None, None, impl)
3985            self.prepare_node_called.add(impl)
3986
3987    def clone_float32(self):
3988        # This will not modify the fgraph or the nodes
3989        new_ins, new_outs = composite_f32.apply(self.fgraph)
3990        return Composite(new_ins, new_outs)
3991
3992    def output_types(self, input_types):
3993        if tuple(input_types) != self.inputs_type:
3994            raise TypeError("Wrong types for Composite. Expected %s, got %s."
3995                            % (self.inputs_type, tuple(input_types)))
3996        return self.outputs_type
3997
3998    def make_node(self, *inputs):
3999        if (tuple([i.type for i in self.inputs]) ==
4000                tuple([i.type for i in inputs])):
4001            return super(Composite, self).make_node(*inputs)
4002        else:
4003            # Make a new op with the right input type.
4004            assert len(inputs) == self.nin
4005            res = theano.compile.rebuild_collect_shared(
4006                self.outputs,
4007                replace=dict(izip(self.inputs, inputs)),
4008                rebuild_strict=False)
4009            # After rebuild_collect_shared, the Variable in inputs
4010            # are not necessarily in the graph represented by res.
4011            # res[2][0] is a dict that map from the original variable to the
4012            # cloned variable.
4013            cloned_inputs = [res[2][0][i] for i in inputs]
4014            node = Composite(cloned_inputs, res[1]).make_node(*inputs)
4015            return node
4016
4017    def perform(self, node, inputs, output_storage):
4018        for storage, impl in zip(output_storage, self._impls):
4019            storage[0] = impl(inputs)
4020
4021    def impl(self, *inputs):
4022        output_storage = [[None] for i in xrange(self.nout)]
4023        self.perform(None, inputs, output_storage)
4024        ret = utils.to_return_values([storage[0] for storage in
4025                                      output_storage])
4026        if self.nout > 1:
4027            ret = tuple(ret)
4028        return ret
4029
4030    def grad(self, inputs, output_grads):
4031        raise NotImplementedError("grad is not implemented for Composite")
4032
4033    def c_code(self, node, nodename, inames, onames, sub):
4034        self.init_c_code()
4035
4036        d = dict(chain(izip(("i%i" % i for i in xrange(len(inames))), inames),
4037                       izip(("o%i" % i for i in xrange(len(onames))),
4038                            onames)), **sub)
4039        d['nodename'] = nodename
4040        if 'id' not in sub:
4041            # The use of a dummy id is safe as the code is in a separate block.
            # It won't generate conflicting variable names.
4043            d['id'] = '_DUMMY_ID_'
4044
4045        return self._c_code % d
4046
4047    def c_code_cache_version(self):
4048        rval = [3]
4049        for x in self.fgraph.toposort():
4050            xv = x.op.c_code_cache_version()
4051            if xv:
4052                rval.append(xv)
4053            else:
4054                return ()
4055        return tuple(rval)
4056
4057    def c_support_code(self):
4058        rval = []
4059        for subnode in self.fgraph.toposort():
4060            try:
4061                rval.append(subnode.op.c_support_code().strip())
4062            except gof.utils.MethodNotDefined:
4063                pass
4064        # remove duplicate code blocks
4065        return "\n".join(sorted(set(rval)))
4066
4067    def c_support_code_apply(self, node, name):
4068        self.init_c_code()
4069        rval = []
4070        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
4071            try:
4072                subnode_support_code = subnode.op.c_support_code_apply(
4073                    subnode,
4074                    subnodename % dict(nodename=name))
4075                if subnode_support_code:
4076                    rval.append(subnode_support_code)
4077            except gof.utils.MethodNotDefined:
4078                pass
4079        # there should be no need to remove duplicate code blocks because
4080        # each block should have been specialized for the given nodename.
4081        # Any block that isn't specialized should be returned via
4082        # c_support_code instead of c_support_code_apply.
4083        return "\n".join(rval)
4084
4085    def __eq__(self, other):
4086        if self is other:
4087            return True
4088        if (type(self) != type(other) or
4089                self.nin != other.nin or
4090                self.nout != other.nout):
4091            return False
4092        # see __hash__ for comment on why there is no mention of fgraph
4093        # or module cache key here.
4094        self.init_c_code()    # self._c_code and self.nodenames
4095        other.init_c_code()
4096        return (self._c_code == other._c_code)
4097
4098    def __hash__(self):
4099        self.init_c_code()    # self._c_code and self.nodenames
4100        rval = hash((type(self),
4101                    self.nin,
4102                    self.nout,
4103                    self._c_code))
4104        # Note that in general, the configparser settings at the time
4105        # of code generation (__init__) affect the semantics of this Op.
4106        # This function assumes that all relevant info about the configparser
4107        # is embodied in _c_code.  So the _c_code, rather than self.fgraph,
4108        # is the signature of the semantics of this Op.
4109        # _c_code is preserved through unpickling, so the Op will not change
4110        # semantics when it is reloaded with different configparser
4111        # settings.
4112        return rval
4113
4114    def __getstate__(self):
4115        rval = dict(self.__dict__)
4116        rval.pop('_impls', None)
4117        rval.pop('prepare_node_called', None)
4118        del rval['fgraph']
4119        return rval
4120
4121    def __setstate__(self, d):
4122        self.__dict__.update(d)
4123        # We must call init to set fgraph and _impls again, as otherwise
4124        # self.perform will not work.
4125        self.prepare_node_called = set()
4126        self.init_fgraph()
4127        self.init_py_impls()
4128
4129
4130class Compositef32(object):
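    """
    Helper that rebuilds a scalar fgraph so that it no longer uses
    float16: float16 inputs and constants are replaced by float32
    equivalents and every node is cloned accordingly (see `apply`).

    """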
4131    # This is a dict of scalar op classes that need special handling
4132    special = {}
4133
4134    def apply(self, fgraph):
4135        mapping = {}
4136        topo = fgraph.toposort()
4137        for i in fgraph.inputs:
4138            if i.dtype == 'float16':
4139                mapping[i] = get_scalar_type('float32')()
4140                if hasattr(i.tag, 'test_value'):
4141                    mapping[i].tag.test_value = i.tag.test_value
4142            else:
4143                mapping[i] = i
4144        for node in topo:
4145            # Patch up for constants
4146            for i in node.inputs:
4147                if i not in mapping:
4148                    assert type(i) is ScalarConstant
4149                    if i.type == float16:
4150                        ni = ScalarConstant(float32, i.data)
4151                    else:
4152                        ni = i
4153                    mapping[i] = ni
4154            if type(node.op) in self.special:
4155                self.special[type(node.op)](node, mapping)
4156                continue
4157            new_node = node.clone_with_new_inputs(
4158                [mapping[inp] for inp in node.inputs],
4159                strict=False)
4160            # make sure we don't produce any float16.
4161            assert not any(o.dtype == 'float16' for o in new_node.outputs)
4162            for o, no in zip(node.outputs, new_node.outputs):
4163                mapping[o] = no
4164
4165        new_ins = [mapping[inp] for inp in fgraph.inputs]
4166        new_outs = [mapping[out] for out in fgraph.outputs]
4167        return new_ins, new_outs
4168
4169composite_f32 = Compositef32()
4170
4171
4172def handle_cast(node, mapping):
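    """
    Special-case handler used by Compositef32 for Cast nodes.

    Casts involving float16 are dropped or redirected so that the
    rebuilt graph never manipulates float16 values.

    """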
4173    inp = mapping[node.inputs[0]]
4174    out = node.outputs[0]
4175    node_ok = False
4176    if node.op.o_type == float16:
4177        if node.inputs[0].type == float32:
4178            # cast f32 -> f16, remove
4179            mapping[out] = inp
4180            return
4181        else:
4182            # cast to f16, convert to f32
4183            new_out = cast(inp, 'float32')
4184            # change the node for the following if
4185            node = new_out.owner
4186            mapping[out] = new_out
4187            node_ok = True
4188    if node.inputs[0].type == float16:
4189        if node.op.o_type == inp.type:
4190            # cast f16 to new input type, remove
4191            mapping[out] = inp
4192            return
4193    if not node_ok:
4194        new_node = node.clone_with_new_inputs([inp],
4195                                              strict=False)
4196        mapping[out] = new_node.outputs[0]
4197
4198Compositef32.special[Cast] = handle_cast
4199
4200
4201def handle_composite(node, mapping):
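    """
    Special-case handler used by Compositef32 for nested Composite
    nodes: clone the inner Composite as float32 and apply it to the
    already-converted inputs.

    """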
4202    new_op = node.op.clone_float32()
4203    new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
4204    assert len(new_outs) == len(node.outputs)
4205    for o, no in zip(node.outputs, new_outs):
4206        mapping[o] = no
4207
4208Compositef32.special[Composite] = handle_composite
4209