1"""A `Type` and `Op` classes to work with numpy.ndarrays symbolically."""
2from __future__ import absolute_import, print_function, division
3
4from six.moves import builtins
5import sys
6import warnings
7
8import numpy as np
9from six import integer_types
10from six.moves import xrange
11import numbers
12
13import theano
14from theano.compat import izip
15from theano import config
16from theano import gof
17from theano.gof import Apply, Constant, Op, Variable, ParamsType
18from theano.gof.type import Generic
19
20from theano.scalar import int32 as int32_t
21from theano.tensor import elemwise
22from theano.tensor.var import (AsTensorError, TensorVariable,
23                               TensorConstant, TensorConstantSignature,
24                               _tensor_py_operators)
25from theano.tensor.type import TensorType, values_eq_approx_always_true
26from theano.tensor.type_other import NoneConst
27from theano import scalar as scal
28from functools import partial
29from theano import compile, printing
30from theano.printing import pprint, min_informative_str
# Kept for backward compatibility.
32from theano.compile import Rebroadcast, Shape, shape
33from theano.scalar import int32
34
35
36# We use these exceptions as well.
37import theano.scalar.sharedvar
38from theano.gradient import grad_undefined
39from theano.gradient import grad_not_implemented
40from theano.gradient import DisconnectedType
41
42# set up the external interface
43from theano.tensor.elemwise import Elemwise, DimShuffle, CAReduce, Sum
44
45import logging
46_logger = logging.getLogger("theano.tensor.basic")
47
48__docformat__ = "restructuredtext en"
49
# Keep references to the Python builtins; these names are shadowed later in
# this module.
51python_complex = complex
52python_any = any
53python_all = all
54
55# Define common subsets of dtypes (as strings).
56complex_dtypes = list(map(str, scal.complex_types))
57continuous_dtypes = list(map(str, scal.continuous_types))
58float_dtypes = list(map(str, scal.float_types))
59integer_dtypes = list(map(str, scal.integer_types))
60discrete_dtypes = list(map(str, scal.discrete_types))
61all_dtypes = list(map(str, scal.all_types))
62int_dtypes = list(map(str, scal.int_types))
63uint_dtypes = list(map(str, scal.uint_types))
64
65
66class ShapeError(Exception):
67    """Raised when the shape cannot be computed."""
68    pass
69
70
71def check_equal_numpy(x, y):
72    """
73    Return True iff x and y are equal.
74
75    Checks the dtype and shape if x and y are numpy.ndarray instances.
76
77    """
78    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
79        return (x.dtype == y.dtype and x.shape == y.shape and
80                np.all(abs(x - y) < 1e-10))
81    elif (isinstance(x, np.random.RandomState) and
82          isinstance(y, np.random.RandomState)):
83        return python_all(np.all(a == b) for a, b in
84                          izip(x.__getstate__(), y.__getstate__()))
85    else:
86        return x == y
87
88compile.register_checker(check_equal_numpy)
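
# Illustrative sketch of how the checker above behaves (the values are made
# up for the example; the behaviour follows directly from the code above):
#
#     >>> check_equal_numpy(np.zeros(3), np.zeros(3))
#     True
#     >>> check_equal_numpy(np.zeros(3), np.zeros((3, 1)))   # shapes differ
#     False
#     >>> check_equal_numpy(2, 2.0)                          # plain == fallback
#     True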
89
90
91__oplist_constructor_list = []
92"""List of functions to be listed as op constructors in the oplist
93(`gen_oplist`, doc/oplist.txt)."""
94
95
96def constructor(f):
97    """Add `f` to :doc:`oplist`.
98
99    Make `f` appear as a constructor in the oplist (`gen_oplist`,
100    doc/oplist.txt).
101
102    """
103    __oplist_constructor_list.append(f)
104    return f
105
106
107def __oplist_tag(thing, tag):
108    tags = getattr(thing, '__oplist_tags', [])
109    tags.append(tag)
110    thing.__oplist_tags = tags
111
112
113def as_tensor_variable(x, name=None, ndim=None):
114    """Return `x`, transformed into a `TensorType`.
115
116    This function is often used by `make_node` methods of `Op` subclasses
117    to turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
118    `TensorType` instances into valid input list elements.
119
120    Parameters
121    ----------
122    x : Apply instance, Variable instance, numpy.ndarray, or number
        The object to be transformed into a `Variable` in a sensible way. An
        ndarray argument will not be copied, but a list of numbers will be
        copied to make an ndarray.
126    name : str or None
127        If a new `Variable` instance is created, it will be named with this
128        string.
129    ndim : None or integer
130        Return a Variable with this many dimensions.
131
132    Raises
133    ------
134    ValueError
135        If an `Apply` with more than one output is fetched or
136        if `x` cannot be made into a Variable with `ndim` dimensions.
137    AsTensorError
138        If `x` cannot be converted to a TensorType Variable.
139
140    """
141    if hasattr(x, '_as_TensorVariable'):
142        return x._as_TensorVariable()  # TODO: pass name and ndim arguments
143
144    if isinstance(x, gof.Apply):
145        # use Apply's default output mechanism
146        if (x.op.default_output is None) and (len(x.outputs) != 1):
147            raise ValueError(
148                "It is ambiguous which output of a multi-output Op has"
149                " to be fetched.", x)
150
151        x = x.default_output()
152    if isinstance(x, Variable):
153        if isinstance(x.type, scal.Scalar):
154            x = tensor_from_scalar(x)
155
156        if not isinstance(x.type, TensorType):
157            raise AsTensorError(
158                "Variable type field must be a TensorType.", x, x.type)
159
160        if ndim is None:
161            return x
162        else:
163            if (x.type.ndim > ndim):
164                # strip off leading broadcastable dimensions
165                first_non_broadcastable = [idx for idx in xrange(x.ndim)
166                                           if not x.broadcastable[idx]][0]
167                x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
168                if x.ndim > ndim:
169                    raise ValueError(
170                        'TensorType could not be cast to have %i dimensions'
171                        % ndim, x.type
172                    )
173                return x
174            elif (x.type.ndim < ndim):
175                return shape_padleft(x, n_ones=(ndim - x.type.ndim))
176            else:
177                return x
178    if isinstance(x, (tuple, list)) and python_any(isinstance(xi, Variable)
179                                                   for xi in x):
180        try:
181            return stack(x)
182        except (TypeError, ValueError):
183            pass
184
185    if isinstance(x, bool):
186        raise AsTensorError(
187            "Cannot cast True or False as a tensor variable. Please use "
188            "np.array(True) or np.array(False) if you need these constants. "
189            "This error might be caused by using the == operator on "
190            "Variables. v == w does not do what you think it does, "
191            "use theano.tensor.eq(v, w) instead.")
192
193    try:
194        return constant(x, name=name, ndim=ndim)
195    except TypeError:
196        try:
197            str_x = str(x)
198        except Exception:
199            str_x = repr(x)
200        raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
201
# This has a different name because `_as_tensor_variable` is the function
# that ops use to upcast their arguments. This internal-use alias is a good
# place to put debugging hooks, better than the global `as_tensor_variable`.
206_as_tensor_variable = as_tensor_variable
207
208as_tensor = as_tensor_variable
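
# Typical usage of `as_tensor_variable` (a sketch; the inputs are arbitrary
# examples):
#
#     >>> v = as_tensor_variable(np.zeros((2, 3)))    # wrapped as a constant
#     >>> v.type.ndim
#     2
#     >>> w = as_tensor_variable(np.zeros((2, 3)), ndim=4)   # pad on the left
#     >>> w.type.broadcastable
#     (True, True, False, False)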
209
210
211def constant(x, name=None, ndim=None, dtype=None):
212    """Return a symbolic `Constant` with value `x`.
213
214    Raises
215    ------
216    TypeError
217        `x` could not be converted to a numpy.ndarray.
218    ValueError
219        `x` could not be expanded to have ndim dimensions.
220
221    Note
222    ----
    We keep a small cache of frequently used constants.
    This speeds up the Merge optimization for big graphs.
    We want to cache all scalars so that equal constants do not need to be
    merged as often, but we do not want to cache too much.
    So we cache [u]int and float scalars whose value is between -10 and 10.
    All broadcast patterns of scalars are cached.
230
231    """
232    x_ = scal.convert(x, dtype=dtype)
233
234    bcastable = [d == 1 for d in x_.shape]
235    if ndim is not None:
236        if len(bcastable) < ndim:
237            bcastable = [True] * (ndim - len(bcastable)) + bcastable
238        elif len(bcastable) > ndim:
239            # TODO: strip off dimensions of size 1
240            raise ValueError(
241                'ndarray could not be cast to constant with %i dimensions' %
242                ndim)
243        assert len(bcastable) == ndim
244
245    try:
246        ttype = TensorType(dtype=x_.dtype, broadcastable=bcastable)
247        if not constant.enable:
248            return TensorConstant(ttype, x_, name=name)
249
250        sig = TensorConstantSignature((ttype, x_))
251        if sig in constant_cache:
252            return constant_cache[sig]
253
254        ret = TensorConstant(ttype, x_, name=name)
255        if (x_.size == 1 and
256            (-10) <= x_ <= 10 and
257            (x_.dtype in int_dtypes or x_.dtype in uint_dtypes or
258             (x_.dtype in float_dtypes and
259              # Limit the size of the cache.
260              len(constant_cache) < 10000))):
261            constant_cache[sig] = ret
262            # This is needed to raise a good error to the user.
263            ret.cached = True
264        return ret
265    except Exception:
266        raise TypeError("Could not convert %s to TensorType" % x, type(x))
267
268
269constant.enable = True
270constant_cache = {}
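
# Sketch of the caching behaviour described above, assuming the cache is
# enabled (values chosen for illustration):
#
#     >>> a = constant(5)
#     >>> b = constant(5)
#     >>> a is b            # small scalar constants are cached and reused
#     True
#     >>> constant(np.arange(3)) is constant(np.arange(3))   # non-scalars are not
#     False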
271
272
273def _obj_is_wrappable_as_tensor(x):
274    try:
275        constant(x)
276        return True
277    except TypeError:
278        return False
279
280
281if int(config.tensor.cmp_sloppy) > 1:
282    # This config variable is a quick-and-dirty way to get low-precision
283    # comparisons.  For a more precise setting of these tolerances set
284    # them explicitly in your user code by assigning, for example,
285    # "theano.tensor.basic.float32_atol = ..."
286
    # When config.tensor.cmp_sloppy > 1 we are even more sloppy. This is
    # useful for testing the GPU, as GPUs do not use extended precision,
    # which causes differences larger than the normal sloppy tolerances.
290    float16_atol = 1e-2
291    float16_rtol = 5e-2
292
293    float32_atol = 5e-4
294    float32_rtol = 1e-3
295
296    float64_rtol = 1e-4
297    float64_atol = 1e-3
298elif int(config.tensor.cmp_sloppy):
299    float16_atol = 5e-3
300    float16_rtol = 1e-2
301
302    float32_atol = 1e-4
303    float32_rtol = 1e-3
304
305    float64_rtol = 1e-4
306    float64_atol = 1e-3
307else:
    # If you change these values in a test, do not forget to restore them
    # when the test ends, including when the test fails.
310    float16_atol = 1e-3
311    float16_rtol = 1e-3
312
313    float32_atol = 1e-5
314    float32_rtol = 1e-5
315
    # Defaults from numpy.allclose.
    # Do not be stricter than numpy's rtol, as that causes spurious errors.
319    float64_rtol = 1.0000000000000001e-05
320    float64_atol = 1e-8
321
322
323def _get_atol_rtol(a, b):
324    tiny = ('float16',)
325    narrow = ('float32', 'complex64')
326    if (str(a.dtype) in tiny) or (str(b.dtype) in tiny):
327        atol = float16_atol
328        rtol = float16_rtol
329    elif (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
330        atol = float32_atol
331        rtol = float32_rtol
332    else:
333        atol = float64_atol
334        rtol = float64_rtol
335    return atol, rtol
336
337
338def _allclose(a, b, rtol=None, atol=None):
339    a = np.asarray(a)
340    b = np.asarray(b)
341    atol_, rtol_ = _get_atol_rtol(a, b)
342    if rtol is not None:
343        rtol_ = rtol
344    if atol is not None:
345        atol_ = atol
346
347    return np.allclose(a, b, atol=atol_, rtol=rtol_)
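
# Sketch of how `_allclose` selects tolerances, assuming the default
# (non-sloppy) settings above (values are illustrative):
#
#     >>> x = np.float32([1.0])
#     >>> _allclose(x, x + np.float32(5e-6))          # float32 tolerances apply
#     True
#     >>> _allclose(x, x + np.float32(5e-6), rtol=1e-7, atol=0)  # override
#     False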
348
349
350class NotScalarConstantError(Exception):
351    """
352    Raised by get_scalar_constant_value if called on something that is
353    not a scalar constant.
354    """
355
356
357class EmptyConstantError(NotScalarConstantError):
358    """
    Raised by get_scalar_constant_value if called on something that is an
    empty (zero-size) constant.
361    """
362
363
364def numpy_scalar(data):
365    """ Return a scalar stored in a numpy ndarray.
366
367    Raises
368    ------
369     NotScalarConstantError
370        If the numpy ndarray is not a scalar.
371
372    """
373
374    # handle case where data is numpy.array([])
375    if (data.ndim > 0 and
376        (len(data.shape) == 0 or
377         builtins.max(data.shape) == 0)):
378        assert np.all(np.array([]) == data)
379        raise EmptyConstantError()
380    try:
381        np.complex(data)  # works for all numeric scalars
382        return data
383    except Exception:
384        raise NotScalarConstantError(
385            'v.data is non-numeric, non-scalar, or has more than one'
386            ' unique value', data)
387
388
389get_scalar_constant_value_elemwises = (
390    scal.Cast, scal.Switch,
391    scal.NEQ, scal.EQ,
392    scal.LT, scal.GT, scal.LE, scal.GE,
393    scal.Sub, scal.Add, scal.Mod, scal.Mul,
394    scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum)
395
396
397def get_scalar_constant_value(orig_v, elemwise=True,
398                              only_process_constants=False,
399                              max_recur=10):
400    """Return the constant scalar(0-D) value underlying variable `v`.
401
402    If `v` is the output of dimshuffles, fills, allocs, rebroadcasts,
403    cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
404    and some pattern with Subtensor, this function digs through them.
405
406    If `v` is not some view of constant scalar data, then raise a
407    NotScalarConstantError.
408
409    Parameters
410    ----------
411    elemwise : bool
412        If False, we won't try to go into elemwise. So this call is faster.
413        But we still investigate in Second Elemwise (as this is a substitute
414        for Alloc)
415    only_process_constants : bool
416        If True, we only attempt to obtain the value of `orig_v` if it's
417        directly constant and don't try to dig through dimshuffles, fills,
418        allocs, and other to figure out its value.
419    max_recur : int
420        The maximum number of recursion.
421
422    Notes
423    -----
424        There may be another function similar to this one in the code,
425        but I'm not sure where it is.
426
427    """
428    v = orig_v
429    while True:
430        if v is None:
431            # None is not a scalar (and many uses of this function seem
432            # to depend on passing it None)
433            raise NotScalarConstantError()
434
435        if isinstance(v, (np.integer, integer_types, float)):
436            return np.asarray(v)
437
438        if isinstance(v, np.ndarray):
439            return numpy_scalar(v).copy()
440
441        if isinstance(v, Constant):
442            if getattr(v.tag, 'unique_value', None) is not None:
443                data = v.tag.unique_value
444            else:
445                data = v.data
446            return numpy_scalar(data).copy()
447
448        if (not only_process_constants and
449                getattr(v, 'owner', None) and
450                max_recur > 0):
451            max_recur -= 1
452            if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
                                       # OutputGuard is only used in debugmode, but we
                                       # keep it here to avoid problems with old pickles.
455                                       compile.ops.OutputGuard,
456                                       compile.DeepCopyOp)):
457                v = v.owner.inputs[0]
458                continue
459            elif isinstance(v.owner.op, theano.compile.ops.Shape_i):
460                i = v.owner.op.i
461                inp = v.owner.inputs[0]
462                if isinstance(inp, Constant):
463                    return np.asarray(inp.data.shape[i])
464                # The shape of a broadcastable dimension is 1
465                if (hasattr(inp.type, 'broadcastable') and
466                        inp.type.broadcastable[i]):
467                    return np.asarray(1)
468
            # Do not act like the constant_folding optimization here: this
            # function is used too early in the optimization phase, so that
            # would interfere with the stabilization optimizations and be too
            # slow. We list all the scalar Ops used by
            # get_canonical_form_slice() so it can determine the broadcast
            # pattern correctly.
474            elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)):
475                v = v.owner.inputs[0]
476                continue
477            elif isinstance(v.owner.op, theano.tensor.opt.Assert):
478                # check if all conditions are constant and true
479                cond = [get_scalar_constant_value(c, max_recur=max_recur)
480                        for c in v.owner.inputs[1:]]
481                if builtins.all([0 == c.ndim and c != 0 for c in cond]):
482                    v = v.owner.inputs[0]
483                    continue
484            elif isinstance(v.owner.op, scal.ScalarOp):
485                if isinstance(v.owner.op, scal.Second):
                    # We do not need both inputs to be constant for Second.
487                    shp, val = v.owner.inputs
488                    v = val
489                    continue
490                if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
491                    const = [get_scalar_constant_value(i, max_recur=max_recur)
492                             for i in v.owner.inputs]
493                    ret = [[None]]
494                    v.owner.op.perform(v.owner, const, ret)
495                    return ret[0][0].copy()
            # In fast_compile, local_fill_to_alloc is not enabled, so we need
            # to investigate Second just like Alloc; therefore the `elemwise`
            # flag does not disable the check for Second.
499            elif isinstance(v.owner.op, Elemwise):
500                if isinstance(v.owner.op.scalar_op, scal.Second):
                    # We do not need both inputs to be constant for Second.
502                    shp, val = v.owner.inputs
503                    v = val
504                    continue
505                elif elemwise and isinstance(
506                        v.owner.op.scalar_op,
507                        get_scalar_constant_value_elemwises):
508                    const = [get_scalar_constant_value(i, max_recur=max_recur)
509                             for i in v.owner.inputs]
510                    ret = [[None]]
511                    v.owner.op.perform(v.owner, const, ret)
512                    return ret[0][0].copy()
513            elif (isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
514                  v.ndim == 0):
515                if isinstance(v.owner.inputs[0], TensorConstant):
516                    cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
517                    try:
518                        return v.owner.inputs[0].data.__getitem__(cdata).copy()
519                    except IndexError:
520                        raise IndexError(
521                            str(tuple(v.owner.op.idx_list)) +
522                            " is not a valid index into " +
523                            str(v.owner.inputs[0].data))
524
                # The index list 'idx_list' should have the same length as
                # the number of dimensions of the input.
527                # TODO: implement the case where we take a scalar in a matrix
528                assert len(v.owner.op.idx_list) == v.owner.inputs[0].ndim
529
530                # Needed to make better graph in this test in
531                # theano/tensor/tests/test_sharedvar.py:
532                # test_shared_options.test_specify_shape_partial
533                if ((v.owner.inputs[0].owner and
534                     isinstance(v.owner.inputs[0].owner.op, Join) and
535                     len(v.owner.op.idx_list) == 1)):
536                    # Ensure the Join is joining only scalar variables (so that
537                    # the constant value can be found at the same index as the
538                    # one used in the sub-tensor).
539                    if python_all(var.ndim == 0 for var in
540                                  v.owner.inputs[0].owner.inputs[1:]):
541                        idx = v.owner.op.idx_list[0]
542                        if isinstance(idx, gof.Type):
543                            idx = get_scalar_constant_value(v.owner.inputs[1],
544                                                            max_recur=max_recur)
545                        # Note the '+ 1' is because the first argument to Join
546                        # is the axis.
547                        ret = v.owner.inputs[0].owner.inputs[idx + 1]
548                        ret = get_scalar_constant_value(ret, max_recur=max_recur)
                        # Join can implicitly cast its inputs in some cases.
550                        return theano._asarray(ret, dtype=v.type.dtype)
551                    if python_all(var.ndim == 1 for var in
552                                  v.owner.inputs[0].owner.inputs[1:]):
553                        idx = v.owner.op.idx_list[0]
554                        if isinstance(idx, gof.Type):
555                            idx = get_scalar_constant_value(v.owner.inputs[1],
556                                                            max_recur=max_recur)
557                        try:
558                            # TODO: assert joined axis is 0.
559                            length = 0
560                            loop = False
561                            for joined in v.owner.inputs[0].owner.inputs[1:]:
562                                ll = get_vector_length(joined)
563                                if idx < length + ll:
564                                    v = joined[idx - length]
565                                    loop = True
566                                    break
567                                length += ll
568                            if loop:
569                                continue
570                        except TypeError:
571                            pass
572                        except ValueError:
573                            pass
574
575                elif (v.owner.inputs[0].owner and
576                      isinstance(v.owner.inputs[0].owner.op,
577                                 theano.tensor.opt.MakeVector) and
                      # MakeVector normally accepts only scalars as inputs.
                      # We keep this check in case that changes in the future.
580                      python_all(var.ndim == 0 for var in
581                                 v.owner.inputs[0].owner.inputs) and
582                      len(v.owner.op.idx_list) == 1):
583
584                    idx = v.owner.op.idx_list[0]
585                    if isinstance(idx, gof.Type):
586                        idx = get_scalar_constant_value(v.owner.inputs[1],
587                                                        max_recur=max_recur)
588                    # Python 2.4 does not support indexing with numpy.integer
589                    # So we cast it.
590                    idx = int(idx)
591                    ret = v.owner.inputs[0].owner.inputs[idx]
592                    ret = get_scalar_constant_value(ret, max_recur=max_recur)
                    # MakeVector can implicitly cast its inputs in some cases.
594                    return theano._asarray(ret, dtype=v.type.dtype)
595
                # This is needed when we take the grad, as the Shape ops have
                # not yet been converted into MakeVector.
598                owner = v.owner
599                leftmost_parent = owner.inputs[0]
600                if (leftmost_parent.owner and
601                    isinstance(leftmost_parent.owner.op,
602                               theano.tensor.Shape)):
603                    op = owner.op
604                    idx_list = op.idx_list
605                    idx = idx_list[0]
606                    if isinstance(idx, gof.Type):
607                        idx = get_scalar_constant_value(owner.inputs[1],
608                                                        max_recur=max_recur)
609                    grandparent = leftmost_parent.owner.inputs[0]
610                    gp_broadcastable = grandparent.type.broadcastable
611                    ndim = grandparent.type.ndim
612                    if grandparent.owner and isinstance(grandparent.owner.op,
613                                                        Rebroadcast):
614                        ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
615                        l = [b1 or b2 for b1, b2 in zip(ggp_broadcastable,
616                                                        gp_broadcastable)]
617                        gp_broadcastable = tuple(l)
618
619                    assert ndim == len(gp_broadcastable)
620
621                    if not (idx < len(gp_broadcastable)):
622                        msg = ("get_scalar_constant_value detected " +
623                               "deterministic IndexError: x.shape[%d] " +
624                               "when x.ndim=%d.") % (idx, ndim)
625                        if config.exception_verbosity == 'high':
626                            msg += ' x=%s' % min_informative_str(v)
627                        else:
628                            msg += ' x=%s' % str(v)
629                        raise ValueError(msg)
630
631                    if gp_broadcastable[idx]:
632                        return np.asarray(1)
633
634        raise NotScalarConstantError(v)
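
# Illustrative behaviour of `get_scalar_constant_value` (a sketch; the
# expressions are made up for the example):
#
#     >>> int(get_scalar_constant_value(constant(3) * 2))   # digs through Elemwise
#     6
#     >>> get_scalar_constant_value(iscalar())      # not constant -> raises
#     Traceback (most recent call last):
#     ...
#     NotScalarConstantError: ...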
635
636
637# Easy constructors
638
639def tensor(*args, **kwargs):
640    name = kwargs.pop('name', None)
641    return TensorType(*args, **kwargs)(name=name)
642
643
644def _multi(*fns):
645    def f2(f, *names):
646        if names and isinstance(names[0], integer_types):
            if names[0] == 1:
648                return f()
649            else:
650                return [f() for i in xrange(names[0])]
651        if isinstance(names, tuple):
652            if len(names) == 1:
653                names = names[0]
654        if len(names) == 1:
655            return f(names)
656        else:
657            return [f(name) for name in names]
658    if len(fns) == 1:
659        return partial(f2, fns)
660    else:
661        return [partial(f2, f) for f in fns]
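
# Naming convention for the pre-instantiated types and constructors below:
# the leading letter encodes the dtype --
#   c: complex64,  z: complex128,  f: float32,  d: float64,
#   b: int8,       w: int16,       i: int32,    l: int64.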
662
663cscalar = TensorType('complex64', ())
664zscalar = TensorType('complex128', ())
665fscalar = TensorType('float32', ())
666dscalar = TensorType('float64', ())
667bscalar = TensorType('int8', ())
668wscalar = TensorType('int16', ())
669iscalar = TensorType('int32', ())
670lscalar = TensorType('int64', ())
671
672
673def scalar(name=None, dtype=None):
674    """Return a symbolic scalar variable.
675
676    Parameters
677    ----------
678    dtype: numeric
679        None means to use theano.config.floatX.
680    name
681        A name to attach to this variable.
682
683    """
684    if dtype is None:
685        dtype = config.floatX
686    type = TensorType(dtype, ())
687    return type(name)
688
689scalars, fscalars, dscalars, iscalars, lscalars = _multi(
690    scalar, fscalar, dscalar, iscalar, lscalar)
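
# Example usage of these constructors (a sketch; names are arbitrary):
#
#     >>> x = dscalar('x')            # a single float64 scalar variable
#     >>> y, z = dscalars('y', 'z')   # several at once, via _multi
#     >>> w = scalar('w')             # dtype defaults to config.floatX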
691
692int_types = bscalar, wscalar, iscalar, lscalar
693float_types = fscalar, dscalar
694complex_types = cscalar, zscalar
695int_scalar_types = int_types
696float_scalar_types = float_types
697complex_scalar_types = complex_types
698
699cvector = TensorType('complex64', (False, ))
700zvector = TensorType('complex128', (False, ))
701fvector = TensorType('float32', (False, ))
702dvector = TensorType('float64', (False, ))
703bvector = TensorType('int8', (False,))
704wvector = TensorType('int16', (False,))
705ivector = TensorType('int32', (False, ))
706lvector = TensorType('int64', (False, ))
707
708
709def vector(name=None, dtype=None):
710    """Return a symbolic vector variable.
711
712    Parameters
713    ----------
714    dtype: numeric
715        None means to use theano.config.floatX.
716    name
717        A name to attach to this variable
718
719    """
720    if dtype is None:
721        dtype = config.floatX
722    type = TensorType(dtype, (False, ))
723    return type(name)
724
725vectors, fvectors, dvectors, ivectors, lvectors = _multi(
726    vector, fvector, dvector, ivector, lvector)
727
728int_vector_types = bvector, wvector, ivector, lvector
729float_vector_types = fvector, dvector
730complex_vector_types = cvector, zvector
731
732cmatrix = TensorType('complex64', (False, False))
733zmatrix = TensorType('complex128', (False, False))
734fmatrix = TensorType('float32', (False, False))
735dmatrix = TensorType('float64', (False, False))
736bmatrix = TensorType('int8', (False, False))
737wmatrix = TensorType('int16', (False, False))
738imatrix = TensorType('int32', (False, False))
739lmatrix = TensorType('int64', (False, False))
740
741
742def matrix(name=None, dtype=None):
743    """Return a symbolic matrix variable.
744
745    Parameters
746    ----------
747    dtype: numeric
748        None means to use theano.config.floatX.
749    name
750        A name to attach to this variable.
751
752    """
753    if dtype is None:
754        dtype = config.floatX
755    type = TensorType(dtype, (False, False))
756    return type(name)
757
758matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(
759    matrix, fmatrix, dmatrix, imatrix, lmatrix)
760
761int_matrix_types = bmatrix, wmatrix, imatrix, lmatrix
762float_matrix_types = fmatrix, dmatrix
763complex_matrix_types = cmatrix, zmatrix
764
765crow = TensorType('complex64', (True, False))
766zrow = TensorType('complex128', (True, False))
767frow = TensorType('float32', (True, False))
768drow = TensorType('float64', (True, False))
769brow = TensorType('int8', (True, False))
770wrow = TensorType('int16', (True, False))
771irow = TensorType('int32', (True, False))
772lrow = TensorType('int64', (True, False))
773
774
775def row(name=None, dtype=None):
776    """Return a symbolic row variable (ndim=2, broadcastable=[True,False]).
777
778    Parameters
779    ----------
780    dtype: numeric type
781        None means to use theano.config.floatX.
782    name
783        A name to attach to this variable.
784
785    """
786    if dtype is None:
787        dtype = config.floatX
788    type = TensorType(dtype, (True, False))
789    return type(name)
790rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
791
792ccol = TensorType('complex64', (False, True))
793zcol = TensorType('complex128', (False, True))
794fcol = TensorType('float32', (False, True))
795dcol = TensorType('float64', (False, True))
796bcol = TensorType('int8', (False, True))
797wcol = TensorType('int16', (False, True))
798icol = TensorType('int32', (False, True))
799lcol = TensorType('int64', (False, True))
800
801
802def col(name=None, dtype=None):
803    """Return a symbolic column variable (ndim=2, broadcastable=[False,True]).
804
805    Parameters
806    ----------
807    dtype : numeric
808        None means to use theano.config.floatX.
809    name
810        A name to attach to this variable.
811
812    """
813    if dtype is None:
814        dtype = config.floatX
815    type = TensorType(dtype, (False, True))
816    return type(name)
817cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
818
819ctensor3 = TensorType('complex64', ((False,) * 3))
820ztensor3 = TensorType('complex128', ((False,) * 3))
821ftensor3 = TensorType('float32', ((False,) * 3))
822dtensor3 = TensorType('float64', ((False,) * 3))
823btensor3 = TensorType('int8', ((False,) * 3))
824wtensor3 = TensorType('int16', ((False,) * 3))
825itensor3 = TensorType('int32', ((False,) * 3))
826ltensor3 = TensorType('int64', ((False,) * 3))
827
828
829def tensor3(name=None, dtype=None):
830    """Return a symbolic 3-D variable.
831
832    Parameters
833    ----------
834    dtype: numeric type
835        None means to use theano.config.floatX.
836    name
837        A name to attach to this variable.
838
839    """
840    if dtype is None:
841        dtype = config.floatX
842    type = TensorType(dtype, (False, False, False))
843    return type(name)
844
845tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = _multi(
846    tensor3, ftensor3, dtensor3, itensor3, ltensor3)
847
848ctensor4 = TensorType('complex64', ((False,) * 4))
849ztensor4 = TensorType('complex128', ((False,) * 4))
850ftensor4 = TensorType('float32', ((False,) * 4))
851dtensor4 = TensorType('float64', ((False,) * 4))
852btensor4 = TensorType('int8', ((False,) * 4))
853wtensor4 = TensorType('int16', ((False,) * 4))
854itensor4 = TensorType('int32', ((False,) * 4))
855ltensor4 = TensorType('int64', ((False,) * 4))
856
857
858def tensor4(name=None, dtype=None):
859    """Return a symbolic 4-D variable.
860
861    Parameters
862    ----------
863    dtype: numeric type
864        None means to use theano.config.floatX.
865    name
866        A name to attach to this variable.
867
868    """
869    if dtype is None:
870        dtype = config.floatX
871    type = TensorType(dtype, (False, False, False, False))
872    return type(name)
873tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = _multi(
874    tensor4, ftensor4, dtensor4, itensor4, ltensor4)
875
876ctensor5 = TensorType('complex64', ((False,) * 5))
877ztensor5 = TensorType('complex128', ((False,) * 5))
878ftensor5 = TensorType('float32', ((False,) * 5))
879dtensor5 = TensorType('float64', ((False,) * 5))
880btensor5 = TensorType('int8', ((False,) * 5))
881wtensor5 = TensorType('int16', ((False,) * 5))
882itensor5 = TensorType('int32', ((False,) * 5))
883ltensor5 = TensorType('int64', ((False,) * 5))
884
885
886def tensor5(name=None, dtype=None):
887    """Return a symbolic 5-D variable.
888
889    Parameters
890    ----------
891    dtype: numeric type
892        None means to use theano.config.floatX.
893    name
894        A name to attach to this variable.
895
896    """
897    if dtype is None:
898        dtype = config.floatX
899    type = TensorType(dtype, (False, False, False, False, False))
900    return type(name)
901tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = _multi(
902    tensor5, ftensor5, dtensor5, itensor5, ltensor5)
903
904ctensor6 = TensorType('complex64', ((False,) * 6))
905ztensor6 = TensorType('complex128', ((False,) * 6))
906ftensor6 = TensorType('float32', ((False,) * 6))
907dtensor6 = TensorType('float64', ((False,) * 6))
908btensor6 = TensorType('int8', ((False,) * 6))
909wtensor6 = TensorType('int16', ((False,) * 6))
910itensor6 = TensorType('int32', ((False,) * 6))
911ltensor6 = TensorType('int64', ((False,) * 6))
912
913
914def tensor6(name=None, dtype=None):
915    """Return a symbolic 6-D variable.
916
917    Parameters
918    ----------
919    dtype: numeric type
920        None means to use theano.config.floatX.
921    name
922        A name to attach to this variable.
923
924    """
925    if dtype is None:
926        dtype = config.floatX
927    type = TensorType(dtype, (False,) * 6)
928    return type(name)
929tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = _multi(
930    tensor6, ftensor6, dtensor6, itensor6, ltensor6)
931
932ctensor7 = TensorType('complex64', ((False,) * 7))
933ztensor7 = TensorType('complex128', ((False,) * 7))
934ftensor7 = TensorType('float32', ((False,) * 7))
935dtensor7 = TensorType('float64', ((False,) * 7))
936btensor7 = TensorType('int8', ((False,) * 7))
937wtensor7 = TensorType('int16', ((False,) * 7))
938itensor7 = TensorType('int32', ((False,) * 7))
939ltensor7 = TensorType('int64', ((False,) * 7))
940
941
942def tensor7(name=None, dtype=None):
943    """Return a symbolic 7-D variable.
944
945    Parameters
946    ----------
947    dtype: numeric type
948        None means to use theano.config.floatX.
949    name
950        A name to attach to this variable.
951
952    """
953    if dtype is None:
954        dtype = config.floatX
955    type = TensorType(dtype, (False,) * 7)
956    return type(name)
957tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = _multi(
958    tensor7, ftensor7, dtensor7, itensor7, ltensor7)
959
960
961Tensor = TensorType
962
963
964# This bizarre push-import avoids a circular dependency.
965elemwise.as_tensor_variable = as_tensor_variable
966elemwise.TensorType = TensorType
967elemwise.TensorVariable = TensorVariable
968elemwise.TensorConstant = TensorConstant
969
970#########################
971# Utilities
972#########################
973
974
975def _scal_elemwise_with_nfunc(nfunc, nin, nout):
976    """
977    Replace a symbol definition with an elementwise version of the
978    corresponding scalar Op.  If it is not None, the nfunc argument
979    should be a string such that getattr(numpy, nfunc) implements
980    a vectorized version of the elemwise operation. nin is the number
981    of inputs expected by that function, and nout is the number of
982    **destination** inputs it takes. That is, the function should
983    take nin+nout inputs. nout == 0 means that the numpy function
984    does not take a numpy array argument to put its result in.
985
986    """
987    def construct(symbol):
988        symbolname = symbol.__name__
989        inplace = symbolname.endswith('_inplace')
990        if inplace:
991            msg = "inplace"
992        else:
993            msg = "no_inplace"
994
995        n = "Elemwise{%s,%s}" % (symbolname, msg)
996
997        if inplace:
998            scalar_op = getattr(scal, symbolname[:-len('_inplace')])
999            inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
1000            rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n,
1001                                     nfunc_spec=(nfunc and (nfunc, nin, nout)))
1002        else:
1003            scalar_op = getattr(scal, symbolname)
1004            rval = elemwise.Elemwise(scalar_op, name=n,
1005                                     nfunc_spec=(nfunc and (nfunc, nin, nout)))
1006
1007        if getattr(symbol, '__doc__', False):
1008            rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
1009
1010        # for the meaning of this see the ./epydoc script
1011        # it makes epydoc display rval as if it were a function, not an object
1012        rval.__epydoc_asRoutine = symbol
1013        rval.__module__ = 'tensor'
1014
1015        pprint.assign(rval, printing.FunctionPrinter(symbolname))
1016
1017        return rval
1018    return construct
1019
1020_scal_elemwise = _scal_elemwise_with_nfunc(None, None, None)
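
# Sketch of how these decorators are applied later in this module to define
# the element-wise ops (the `add` example is illustrative of the pattern):
#
#     @_scal_elemwise_with_nfunc('add', 2, 1)
#     def add(a, *other_terms):
#         """elementwise addition"""
#
# The decorated symbol is replaced by an `Elemwise` wrapping the scalar op of
# the same name (here `scal.add`), with `nfunc_spec` pointing at `numpy.add`.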
1021
1022
1023def _pack(x):
1024    """
1025    Convert x to a list if it is an iterable, otherwise wrap it in a list.
1026    """
1027    try:
1028        return list(x)
1029    except TypeError:
1030        return [x]
1031
1032
1033def check_and_normalize_axes(x, axis):
1034    """
1035    Check axes, normalize and convert them to a Python list of integers.
1036    Return an empty list if argument is None.
1037
1038    Parameters
1039    ----------
1040    x: Tensor variable
    axis: integer, tuple or list of integers
1042
1043    Returns
1044    -------
1045    axis: list of integers
1046    """
1047    x = as_tensor_variable(x)
1048    if axis is None:
1049        axis = []
    elif (isinstance(axis, (integer_types, np.integer)) or
          (isinstance(axis, np.ndarray) and axis.ndim == 0)):
        axis = [int(axis)]
1053    elif isinstance(axis, (tuple, list, np.ndarray)):
1054        axis = [int(i) for i in axis]
1055    elif isinstance(axis, Variable):
1056        if NoneConst.equals(axis):
1057            axis = []
1058        elif not isinstance(axis, TensorConstant):
1059            raise TypeError("Computation needs a constant axis. Got %s" % axis)
1060        else:
1061            assert axis.dtype in integer_dtypes
            if (isinstance(axis.data, (integer_types, np.integer)) or
                    (isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
                axis = [int(axis.data)]
1065            elif isinstance(axis.data, (list, np.ndarray)):
1066                axis = [int(i) for i in axis.data]
1067    else:
1068        raise TypeError("Axis must be an integer, tuple, list of integers or a TensorVariable. Got %s" % axis)
1069    if len(axis) > 0:
1070        for i in range(len(axis)):
1071            if axis[i] < 0:
1072                axis[i] += x.type.ndim
1073            if axis[i] < 0 or axis[i] >= x.type.ndim:
1074                raise ValueError("Computation needs a valid axis number for %d-D tensor. Got %d" % (x.type.ndim, axis[i]))
1075        axis = list(set(axis))
1076        axis.sort()
1077    return axis
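
# Illustrative behaviour of `check_and_normalize_axes` (a sketch):
#
#     >>> x = tensor3()
#     >>> check_and_normalize_axes(x, None)
#     []
#     >>> check_and_normalize_axes(x, -1)          # negative axes are wrapped
#     [2]
#     >>> check_and_normalize_axes(x, (0, -1))     # deduplicated and sorted
#     [0, 2]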
1078
1079
1080#########################
1081# Casting Operations
1082#########################
1083
1084class TensorFromScalar(Op):
1085
1086    __props__ = ()
1087
1088    def make_node(self, s):
1089        assert isinstance(s.type, scal.Scalar)
1090        return Apply(self,
1091                     [s],
1092                     [tensor(dtype=s.type.dtype,
1093                             broadcastable=())])
1094
1095    def perform(self, node, inp, out_):
1096        s, = inp
1097        out, = out_
1098        out[0] = np.asarray(s)
1099
1100    def infer_shape(self, node, in_shapes):
1101        return [()]
1102
1103    def grad(self, inp, grads):
1104        s, = inp
1105        dt, = grads
1106        if s.type.dtype in float_dtypes:
1107            assert dt.type.dtype in float_dtypes
1108            return [scalar_from_tensor(dt)]
1109
1110        # If the input dtype is an integer, then so is the output dtype,
1111        # and the "zero" gradient can be represented in that int dtype.
1112        # Currently, theano.grad insists that the dtype of the returned
1113        # gradient has a float dtype, so we use floatX.
1114        if s.type.dtype in discrete_dtypes:
1115            return [s.zeros_like().astype(theano.config.floatX)]
1116
1117        raise NotImplementedError("grad not implemented for complex dtypes")
1118
1119tensor_from_scalar = TensorFromScalar()
1120
1121
1122class ScalarFromTensor(Op):
1123
1124    __props__ = ()
1125
1126    def make_node(self, t):
1127        assert isinstance(t.type, TensorType)
1128        assert t.type.broadcastable == ()
1129        return Apply(self,
1130                     [t],
1131                     [scal.get_scalar_type(dtype=t.type.dtype).make_variable()]
1132                     )
1133
1134    def perform(self, node, inp, out_):
1135        s, = inp
1136        out, = out_
1137        out[0] = s.flatten()[0]
1138
1139    def infer_shape(self, node, in_shapes):
1140        return [()]
1141
1142    def grad(self, inp, grads):
1143        s, = inp
1144        dt, = grads
1145        return [tensor_from_scalar(dt)]
1146
1147    def R_op(self, inputs, eval_points):
1148        if None in eval_points:
1149            return [None]
1150        return self.make_node(*eval_points).outputs
1151
1152    def c_code(self, node, name, inputs, outputs, sub):
1153        x, = inputs
1154        z, = outputs
1155        fail = sub['fail']
1156        return """
1157        %(z)s = ((dtype_%(x)s*)(PyArray_DATA(%(x)s)))[0];
1158        """ % locals()
1159
1160    def c_code_cache_version(self):
1161        return (1,)
1162
1163scalar_from_tensor = ScalarFromTensor()
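
# Round-trip sketch between the scalar and tensor worlds (illustrative only;
# the variable names are made up):
#
#     >>> s = scal.float64('s')             # a theano.scalar variable
#     >>> t = tensor_from_scalar(s)         # 0-d TensorType variable
#     >>> t.type.broadcastable
#     ()
#     >>> s2 = scalar_from_tensor(t)        # back to a Scalar variable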
1164
1165
1166# to be removed as we get the epydoc routine-documenting thing going
1167# -JB 20080924
1168def _conversion(real_value, name):
1169    __oplist_tag(real_value, 'casting')
1170    real_value.__module__ = 'tensor.basic'
1171    pprint.assign(real_value, printing.FunctionPrinter(name))
1172    return real_value
1173
1174
# These _convert_to_<type> functions have leading underscores to indicate that
1176# they should not be called directly.  They do not perform sanity checks about
1177# what types you are casting to what.  That logic is implemented by the
1178# `cast()` function below.
1179
1180_convert_to_bool = _conversion(
1181    elemwise.Elemwise(scal.convert_to_bool), 'bool')
1182"""Cast to boolean"""
1183
1184_convert_to_int8 = _conversion(
1185    elemwise.Elemwise(scal.convert_to_int8), 'int8')
1186"""Cast to 8-bit integer"""
1187
1188_convert_to_int16 = _conversion(
1189    elemwise.Elemwise(scal.convert_to_int16), 'int16')
1190"""Cast to 16-bit integer"""
1191
1192_convert_to_int32 = _conversion(
1193    elemwise.Elemwise(scal.convert_to_int32), 'int32')
1194"""Cast to 32-bit integer"""
1195
1196_convert_to_int64 = _conversion(
1197    elemwise.Elemwise(scal.convert_to_int64), 'int64')
1198"""Cast to 64-bit integer"""
1199
1200_convert_to_uint8 = _conversion(
1201    elemwise.Elemwise(scal.convert_to_uint8), 'uint8')
1202"""Cast to unsigned 8-bit integer"""
1203
1204_convert_to_uint16 = _conversion(
1205    elemwise.Elemwise(scal.convert_to_uint16), 'uint16')
1206"""Cast to unsigned 16-bit integer"""
1207
1208_convert_to_uint32 = _conversion(
1209    elemwise.Elemwise(scal.convert_to_uint32), 'uint32')
1210"""Cast to unsigned 32-bit integer"""
1211
1212_convert_to_uint64 = _conversion(
1213    elemwise.Elemwise(scal.convert_to_uint64), 'uint64')
1214"""Cast to unsigned 64-bit integer"""
1215
1216_convert_to_float16 = _conversion(
1217    elemwise.Elemwise(scal.convert_to_float16), 'float16')
1218"""Cast to half-precision floating point"""
1219
1220_convert_to_float32 = _conversion(
1221    elemwise.Elemwise(scal.convert_to_float32), 'float32')
1222"""Cast to single-precision floating point"""
1223
1224_convert_to_float64 = _conversion(
1225    elemwise.Elemwise(scal.convert_to_float64), 'float64')
1226"""Cast to double-precision floating point"""
1227
1228_convert_to_complex64 = _conversion(
1229    elemwise.Elemwise(scal.convert_to_complex64), 'complex64')
1230"""Cast to single-precision complex"""
1231
1232_convert_to_complex128 = _conversion(
1233    elemwise.Elemwise(scal.convert_to_complex128), 'complex128')
1234"""Cast to double-precision complex"""
1235
1236_cast_mapping = {
1237    'bool': _convert_to_bool,
1238    'int8': _convert_to_int8,
1239    'int16': _convert_to_int16,
1240    'int32': _convert_to_int32,
1241    'int64': _convert_to_int64,
1242    'uint8': _convert_to_uint8,
1243    'uint16': _convert_to_uint16,
1244    'uint32': _convert_to_uint32,
1245    'uint64': _convert_to_uint64,
1246    'float16': _convert_to_float16,
1247    'float32': _convert_to_float32,
1248    'float64': _convert_to_float64,
1249    'complex64': _convert_to_complex64,
1250    'complex128': _convert_to_complex128}
1251
1252
1253@constructor
1254def cast(x, dtype):
1255    """Symbolically cast `x` to a Tensor of type `dtype`."""
1256    if dtype == 'floatX':
1257        dtype = config.floatX
1258
1259    _x = as_tensor_variable(x)
1260    if _x.type.dtype == dtype:
1261        return _x
1262    if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
1263        raise TypeError((
1264            'Casting from complex to real is ambiguous: consider real(), '
1265            'imag(), angle() or abs()'))
1266    return _cast_mapping[dtype](x)
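
# Example usage of `cast` (a sketch; names are arbitrary):
#
#     >>> x = fmatrix('x')
#     >>> cast(x, 'int32').dtype            # symbolic element-wise cast
#     'int32'
#     >>> cast(x, 'float32') is x           # no-op when the dtype already matches
#     True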
1267
1268##########################
1269# Unary Operations
1270##########################
1271
1272
1273class MaxAndArgmax(Op):
1274    """
1275    Calculate the max and argmax over a given axis or over all axes.
1276
1277    """
1278    nin = 2  # tensor, axis
1279    nout = 2  # max val, max idx
1280    E_axis = 'invalid axis'
1281    params_type = Generic()
1282    __props__ = ('axis',)
1283    _f16_ok = True
1284
1285    def __init__(self, axis):
1286        assert isinstance(axis, list)
1287        self.axis = tuple(axis)
1288
1289    def get_params(self, node):
1290        return self.axis
1291
1292    def make_node(self, x):
1293        x = _as_tensor_variable(x)
1294
1295        # We keep the original broadcastable flags for dimensions on which
1296        # we do not perform the max / argmax.
1297        all_axes = set(self.axis)
1298        broadcastable = [b for i, b in enumerate(x.type.broadcastable)
1299                         if i not in all_axes]
1300        inputs = [x]
1301        outputs = [tensor(x.type.dtype, broadcastable, name='max'),
1302                   tensor('int64', broadcastable, name='argmax')]
1303        return Apply(self, inputs, outputs)
1304
1305    def perform(self, node, inp, outs, params):
1306        x = inp[0]
1307        axes = params
1308        max, max_idx = outs
1309        if axes is None:
1310            axes = tuple(range(x.ndim))
1311        else:
1312            axes = tuple(int(ax) for ax in axes)
1313        max[0] = theano._asarray(np.max(x, axes),
1314                                 dtype=node.outputs[0].dtype)
1315        # Numpy does not support multiple axes for argmax
1316        # Work around
1317        keep_axes = np.array([i for i in range(x.ndim) if i not in axes],
1318                             dtype='int64')
1319        # Not-reduced axes in front
1320        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
1321        kept_shape = transposed_x.shape[:len(keep_axes)]
1322        reduced_shape = transposed_x.shape[len(keep_axes):]
1323
1324        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
1325        # Otherwise reshape would complain citing float arg
1326        new_shape = kept_shape + (np.prod(reduced_shape, dtype='int64'),)
1327        reshaped_x = transposed_x.reshape(new_shape)
1328
1329        max_idx[0] = theano._asarray(np.argmax(reshaped_x, axis=-1),
1330                                     dtype='int64')
1331
1332    def c_code(self, node, name, inp, out, sub):
1333        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
1334            raise NotImplementedError("NumPy C-API can compute max and argmax only for 1 axis or for all axes.")
1335        x = inp[0]
1336        axis = sub['params']
1337        max, argmax = out
1338        fail = sub["fail"]
1339        ret = """
1340        #if PY_MAJOR_VERSION >= 3
1341            #ifndef PyInt_AS_LONG
1342                #define PyInt_AS_LONG PyLong_AS_LONG
1343            #endif
1344        #endif
1345
1346        int axis;
1347
1348        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
1349            axis = NPY_MAXDIMS;
1350        } else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
1351            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
1352            axis = (int)PyInt_AS_LONG(axis_object);
1353            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
1354                PyErr_SetString(PyExc_ValueError,
1355                "MaxAndArgmax: bad axis argument");
1356                %(fail)s
1357            }
1358        } else {
1359            PyErr_SetString(PyExc_NotImplementedError,
1360            "MaxAndArgmax: NumPy C-API can compute max and argmax only for 1 axis or for all axes.");
1361            %(fail)s
1362        }
1363
1364        Py_CLEAR(%(max)s);
1365        Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
1366
1367        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
1368        if (%(max)s == NULL) {
1369            %(fail)s;
1370        }
1371        if (!PyArray_CheckExact(%(max)s)) {
1372            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
1373            if(%(max)s == NULL){
1374                %(fail)s;
1375            }
1376        }
1377
1378        %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
1379        if (%(argmax)s == NULL) {
1380            Py_CLEAR(%(max)s);
1381            %(fail)s;
1382        }
1383        if (!PyArray_CheckExact(%(argmax)s)) {
1384            %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
1385            if(%(argmax)s == NULL){
1386                %(fail)s;
1387            }
1388        }
1389        if (PyArray_TYPE(%(argmax)s) != NPY_INT64) {
1390            PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
1391            if (NULL == tmp){
1392                %(fail)s;
1393            }
1394            Py_DECREF(%(argmax)s);
1395            %(argmax)s = (PyArrayObject*)tmp;
1396        }
1397        """
1398        return ret % locals()
1399
1400    def c_code_cache_version(self):
1401        return (5,)
1402
1403    def infer_shape(self, node, shapes):
1404        ishape = shapes[0]
1405        rval = tuple(ishape[i] for (i, b) in enumerate(
1406            node.inputs[0].type.broadcastable) if i not in self.axis)
1407        return [rval, rval]
1408
1409    def R_op(self, inputs, eval_points):
1410        if eval_points[0] is None:
1411            return [None, None]
1412        if len(self.axis) != 1:
1413            raise ValueError(('R_op supported for arg_max only for '
1414                              'one axis!'))
1415        if self.axis[0] > 1:
1416            raise ValueError(('R_op supported for arg_max only when '
1417                              ' axis is 0 or 1'))
1418        if inputs[0].ndim != 2:
1419            raise ValueError(('R_op supported for arg_max only when '
1420                              ' input is a matrix'))
1421        max_vals, max_pos = self.make_node(*inputs).outputs
1422        if self.axis[0] == 0:
1423            return [eval_points[0][max_pos,
1424                                   arange(eval_points[0].shape[1])], None]
1425        else:
1426            return [eval_points[0][arange(eval_points[0].shape[0]),
1427                                   max_pos], None]
1428
1429    def grad(self, inp, grads):
1430        # The strict sense mathematical gradient of the maximum function is
1431        # not calculated here for it is not defined at every point where some
1432        # coordinates are identical. However, since the latter set has null
1433        # Lebesgue measure, the result may be interpreted as weak gradient.
1434
1435        # @note: This function should work correctly for L{vector}s.
1436        # (x, y), (gz, gw)
1437        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
1438        # gMax * dMax/dx + gArgMax * dArgMax/dx,
1439        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
        # g_max has one less dimension than x, so you need to complete
        # g_max to x's shape; when axis=0 the broadcasting mechanism
        # does it automatically.
1443        x = inp[0]
1444        axis = _as_tensor_variable(self.axis)
1445        g_max, g_max_idx = grads
1446
1447        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
1448        g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType)
1449
1450        # if the op is totally disconnected, so are its inputs
1451        if g_max_disconnected and g_max_idx_disconnected:
1452            return [DisconnectedType()(), DisconnectedType()()]
1453
1454        # if the max is disconnected but the argmax is not,
1455        # the gradient on its inputs is zero
1456        if g_max_disconnected:
1457            return [x.zeros_like()]
1458        if NoneConst.equals(axis):
1459            axis_ = list(range(x.ndim))
1460        else:
1461            axis_ = axis
1462        xmax = max(x, axis_)
1463
1464        # Raise the g_max and xmax to the same number of dim as the input.
1465        pattern = []
1466        out_dim = 0
1467        if NoneConst.equals(axis):
1468            # We are taking the max/argmax over all dimensions.
1469            axis = None
1470        for i in xrange(x.ndim):
1471            if axis is None or i in axis.data:
1472                pattern.append('x')
1473            else:
1474                pattern.append(out_dim)
1475                out_dim += 1
1476        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
1477        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
1478
1479        # Set the grad to the correct position.
1480        g_x = eq(xmax_pad, x) * g_max_pad
1481        return g_x,
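
# The multi-axis argmax workaround used in `MaxAndArgmax.perform` above can be
# illustrated in plain NumPy (a sketch with made-up data):
#
#     >>> x = np.arange(24).reshape(2, 3, 4)
#     >>> axes = (1, 2)
#     >>> keep = tuple(i for i in range(x.ndim) if i not in axes)
#     >>> t = np.transpose(x, keep + axes)                # kept axes first
#     >>> flat = t.reshape(t.shape[:len(keep)] + (-1,))   # collapse reduced axes
#     >>> np.argmax(flat, axis=-1)                        # one argmax per kept index
#     array([11, 11])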
1482
1483
1484class Argmax(Op):
1485    """
1486    Calculate the argmax over a given axis or over all axes.
1487    """
1488    nin = 2  # tensor, axis
1489    nout = 1
1490    E_axis = 'invalid axis'
1491    __props__ = ('axis',)
1492    _f16_ok = True
1493
1494    params_type = ParamsType(c_axis=scal.int64)
1495
1496    def __init__(self, axis):
1497        if axis is not None:
1498            axis = tuple(axis)
        self.axis = axis
1500
1501    def get_params(self, node):
1502        if self.axis is not None and len(self.axis) == 1:
1503            c_axis = np.int64(self.axis[0])
1504        else:
1505            # The value here doesn't matter, it won't be used
1506            c_axis = np.int64(-1)
1507        return self.params_type.get_params(c_axis=c_axis)
1508
1509    def make_node(self, x, axis=None):
1510        x = _as_tensor_variable(x)
1511        if self.axis is None:
1512            all_axes = list(range(x.ndim))
1513        else:
1514            all_axes = self.axis
1515        inputs = [x]
1516
1517        # We keep the original broadcastable flags for dimensions on which
1518        # we do not perform the argmax.
1519        broadcastable = [b for i, b in enumerate(x.type.broadcastable)
1520                         if i not in all_axes]
1521        outputs = [tensor('int64', broadcastable, name='argmax')]
1522        return Apply(self, inputs, outputs)
1523
1524    def prepare_node(self, node, storage_map, compute_map, impl):
1525        if len(node.inputs) == 2:
            raise ValueError(
                'You are trying to compile a graph with an old Argmax node.'
                '  Either reoptimize your graph or rebuild it to get the new'
                ' node format.')
1527
1528    def perform(self, node, inp, outs, params):
1529        x, = inp
1530        axes = self.axis
1531        max_idx, = outs
1532        if axes is None:
1533            axes = tuple(range(x.ndim))
1534
1535        # Numpy does not support multiple axes for argmax
1536        # Work around
1537        keep_axes = np.array([i for i in range(x.ndim) if i not in axes],
1538                             dtype='int64')
1539        # Not-reduced axes in front
1540        transposed_x = np.transpose(x, np.concatenate((keep_axes,
1541                                                       axes)))
1542        kept_shape = transposed_x.shape[:len(keep_axes)]
1543        reduced_shape = transposed_x.shape[len(keep_axes):]
1544        new_shape = kept_shape + (np.prod(reduced_shape),)
1545        reshaped_x = transposed_x.reshape(new_shape)
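        # For example, if x.shape == (2, 3, 4) and axes == (1, 2):
        # keep_axes == [0], transposed_x.shape == (2, 3, 4) and
        # reshaped_x.shape == (2, 12), so the argmax below runs over one
        # flattened block of 12 values per kept row.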
1546
1547        max_idx[0] = theano._asarray(np.argmax(reshaped_x, axis=-1),
1548                                     dtype='int64')
1549
1550    def c_code(self, node, name, inp, out, sub):
1551        x, = inp
1552        argmax, = out
1553        fail = sub["fail"]
1554        params = sub["params"]
1555        if self.axis is None:
1556            axis_code = "axis = NPY_MAXDIMS;"
1557        else:
1558            if len(self.axis) > 1:
1559                raise NotImplementedError()
1560            # params is only used here for now
1561            axis_code = """
1562            axis = %(params)s->c_axis;
1563            if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
1564                PyErr_SetString(PyExc_ValueError,
1565                "Argmax, bad axis argument");
1566                %(fail)s
1567            }
1568            """ % locals()
1569        ret = """
1570        int axis;
1571
1572        Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
1573        %(axis_code)s
1574
1575        %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
1576        if(%(argmax)s == NULL){
1577            %(fail)s;
1578        }
1579        if(!PyArray_CheckExact(%(argmax)s)){
1580            %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
1581            if(%(argmax)s == NULL){
1582                %(fail)s;
1583            }
1584        }
1585        if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
1586            PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
1587            if (NULL == tmp){
1588                %(fail)s;
1589            }
1590            Py_DECREF(%(argmax)s);
1591            %(argmax)s = (PyArrayObject*)tmp;
1592        }
1593        """
1594        return ret % locals()
1595
1596    def c_code_cache_version(self):
1597        return (1,)
1598
1599    def infer_shape(self, node, shapes):
1600        ishape, = shapes
1601        if self.axis is None:
1602            return [()]
1603        rval = tuple([ishape[i] for (i, b) in enumerate(
1604            node.inputs[0].type.broadcastable) if i not in self.axis])
1605        return [rval]
1606
1607    def grad(self, inp, grads):
1608        x, = inp
1609
1610        return [x.zeros_like()]
1611
1612
1613def makeKeepDims(x, y, axis):
1614    """
    Reintroduces in `y`, with length one, the axes of `x` that have been left
    out in a prior reduction of `x`. The resulting tensor can then broadcast
    correctly against the original tensor `x`.
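
    A small doctest-style sketch (the symbolic matrix ``x`` is illustrative):

    >>> import theano.tensor as tt
    >>> x = tt.matrix('x')
    >>> y = tt.sum(x, axis=1)
    >>> tt.makeKeepDims(x, y, 1).broadcastable
    (False, True)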
1618
1619    """
1620    x = as_tensor_variable(x)
1621    y = as_tensor_variable(y)
1622
1623    if axis is None:
1624        axis = list(range(x.type.ndim))
1625    elif isinstance(axis, (integer_types, np.integer)):
1626        axis = [axis]
1627    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
1628        axis = [int(axis)]
1629    else:
1630        axis = [int(a) for a in axis]
1631    newaxis = []
1632    for a in axis:
1633        if not isinstance(a, integer_types):
1634            raise ValueError(
1635                "keepdims option can be used only with constant axis")
1636        if a < 0:
1637            a += x.type.ndim
1638        newaxis.append(a)
1639    i = 0
1640    new_dims = []
1641    for j, _ in enumerate(x.type.broadcastable):
1642        if j in newaxis:
1643            new_dims.append('x')
1644        else:
1645            new_dims.append(i)
1646            i += 1
1647    return DimShuffle(y.type.broadcastable, new_dims)(y)
1648
1649
1650@constructor
1651def max_and_argmax(a, axis=None, keepdims=False):
1652    """
1653    Returns maximum elements and their indices obtained by iterating over
1654    given axis.
1655
1656    When axis is None (the default value), the max is performed
1657    over the flattened tensor.
1658
1659    Parameters
1660    ----------
1661    keepdims : bool
1662        If this is set to True, the axes which are reduced are left in
1663        the result as dimensions with size one. With this option, the result
1664        will broadcast correctly against the original tensor.
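
    Examples
    --------
    A small sketch (the symbolic matrix ``x`` is illustrative):

    >>> import theano.tensor as tt
    >>> x = tt.matrix('x')
    >>> mx, amx = tt.max_and_argmax(x, axis=0)
    >>> mx.ndim, amx.ndim
    (1, 1)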
1665
1666    """
1667    # Check axis and convert it to a Python list of integers.
1668    # Axis will be used as an op param of MaxAndArgmax.
1669    a = as_tensor_variable(a)
1670    axis = check_and_normalize_axes(a, axis)
1671    if len(axis) == 0:
1672        axis = list(range(a.type.ndim))
1673    out, argout = MaxAndArgmax(axis)(a)
1674
1675    if keepdims:
1676        out = makeKeepDims(a, out, axis)
1677        argout = makeKeepDims(a, argout, axis)
1678    return [out, argout]
1679
1680
1681@constructor
1682def max(x, axis=None, keepdims=False):
1683    """
1684    Returns maximum elements obtained by iterating over given axis.
1685
1686    When axis is None (the default value), the max is performed
1687    over the flattened tensor.
1688
1689    Parameters
1690    ----------
1691    keepdims: bool
1692        If this is set to True, the axes which are reduced are left in
1693        the result as dimensions with size one. With this option, the result
1694        will broadcast correctly against the original tensor.
1695
1696    Notes
1697    -----
    Like NumPy, we raise an error when trying to reduce a dimension of size 0.
1699
1700    """
1701
1702    # We have a choice of implementing this call with the
1703    # CAReduce op or the MaxAndArgmax op.
1704
1705    # MaxAndArgmax supports grad and Rop, so we prefer to use that.
1706    # CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
1707    # with CAReduce at compile time, so at this stage the important
1708    # thing is supporting all user interface features, not speed.
1709    # Some cases can be implemented only with CAReduce.
1710
1711    # We thus prefer to use MaxAndArgmax, if possible. It does not
1712    # support all axis arguments, so we may need to fall back to CAReduce.
1713
1714    try:
1715        out = max_and_argmax(x, axis)[0]
1716    except Exception:
1717        out = CAReduce(scal.maximum, axis)(x)
1718
1719    if keepdims:
1720        out = makeKeepDims(x, out, axis)
1721    return out
1722
1723
1724@constructor
1725def argmax(x, axis=None, keepdims=False):
1726    """
1727    Returns indices of maximum elements obtained by iterating over given axis.
1728
1729    When axis is None (the default value), the argmax is performed
1730    over the flattened tensor.
1731
1732    Parameters
1733    ----------
1734    keepdims : bool
1735        If this is set to True, the axes which are reduced are left in
1736        the result as dimensions with size one. With this option, the result
1737        will broadcast correctly against the original tensor.
1738
1739    """
1740    argout = max_and_argmax(x, axis)[1]
1741
1742    if keepdims:
1743        argout = makeKeepDims(x, argout, axis)
1744    return argout
1745
1746
1747@constructor
1748def min(x, axis=None, keepdims=False):
1749    """
1750    Returns minimum elements obtained by iterating over given axis.
1751
1752    When axis is None (the default value), the min is performed
1753    over the flattened tensor.
1754
1755    Parameters
1756    ----------
1757    keepdims: bool
1758        If this is set to True, the axes which are reduced are left in
1759        the result as dimensions with size one. With this option, the result
1760        will broadcast correctly against the original tensor.
1761
1762    """
1763    x = as_tensor_variable(x)
1764    str_x_type = str(x.dtype)
1765    if str_x_type.startswith('float') or str_x_type in int_dtypes:
1766        return -max(-x, axis=axis, keepdims=keepdims)
1767    elif str_x_type in uint_dtypes:
1768        itype = np.iinfo(x.dtype)
1769        max_val = np.array(itype.max, dtype=itype.dtype)
1770        return max_val - max(max_val - x, axis=axis, keepdims=keepdims)
1771    elif str_x_type == 'bool':
1772        return ~max(~x, axis=axis, keepdims=keepdims)
1773    else:
1774        # Be careful about unsigned integers, complex
1775        raise NotImplementedError()
1776
1777
1778@constructor
1779def argmin(x, axis=None, keepdims=False):
1780    """
1781    Returns indices of minimum elements obtained by iterating over given axis.
1782
1783    When axis is None (the default value), the argmin is performed
1784    over the flattened tensor.
1785
1786    Parameters
1787    ----------
1788    keepdims: bool
1789        If this is set to True, the axes which are reduced are left in
1790        the result as dimensions with size one. With this option, the result
1791        will broadcast correctly against the original tensor.
1792
1793    """
1794    x = as_tensor_variable(x)
1795    str_x_type = str(x.dtype)
1796    if str_x_type.startswith('float') or str_x_type in int_dtypes:
1797        return argmax(-x, axis=axis, keepdims=keepdims)
1798    elif str_x_type in uint_dtypes:
1799        itype = np.iinfo(x.dtype)
1800        return argmax(itype.max - x, axis=axis, keepdims=keepdims)
1801    elif str_x_type == 'bool':
1802        return argmax(~x, axis=axis, keepdims=keepdims)
1803    else:
1804        # Be careful about unsigned integers, complex
1805        raise NotImplementedError()
1806
1807
1808@constructor
1809def smallest(*args):
1810    """
1811    Return the [elementwise] smallest of a variable number of arguments.
1812
1813    Like python's min.
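
    A small sketch (the symbolic vectors below are illustrative):

    >>> import theano.tensor as tt
    >>> a = tt.vector('a')
    >>> b = tt.vector('b')
    >>> c = tt.vector('c')
    >>> tt.smallest(a, b, c).ndim  # elementwise min of the three vectors
    1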
1814
1815    """
1816    if len(args) == 2:
1817        a, b = args
1818        return switch(a < b, a, b)
1819    else:
1820        return min(stack(args), axis=0)
1821
1822
1823@constructor
1824def largest(*args):
1825    """
1826    Return the [elementwise] largest of a variable number of arguments.
1827
1828    Like python's max.
1829
1830    """
1831    if len(args) == 2:
1832        a, b = args
1833        return switch(a > b, a, b)
1834    else:
1835        return max(stack(args), axis=0)
1836
1837
1838##########################
1839# Comparison
1840##########################
1841
1842@_scal_elemwise
1843def lt(a, b):
1844    """a < b"""
1845
1846
1847@_scal_elemwise
1848def gt(a, b):
1849    """a > b"""
1850
1851
1852@_scal_elemwise
1853def le(a, b):
1854    """a <= b"""
1855
1856
1857@_scal_elemwise
1858def ge(a, b):
1859    """a >= b"""
1860
1861
1862@_scal_elemwise
1863def eq(a, b):
1864    """a == b"""
1865
1866
1867@_scal_elemwise
1868def neq(a, b):
1869    """a != b"""
1870
1871
1872@_scal_elemwise
1873def isnan(a):
1874    """isnan(a)"""
1875
# Rename isnan to isnan_ so that it can be bypassed when not needed.
# glibc 2.23 doesn't allow isnan on int, so we remove it from the graph.
1878isnan_ = isnan
1879
1880
1881def isnan(a):
1882    """isnan(a)"""
1883    a = as_tensor_variable(a)
1884    if a.dtype in discrete_dtypes:
1885        return alloc(np.asarray(False, dtype="bool"),
1886                     *[a.shape[i] for i in range(a.ndim)])
1887    return isnan_(a)
1888
1889
1890@_scal_elemwise
1891def isinf(a):
1892    """isinf(a)"""
1893
# Rename isinf to isinf_ so that it can be bypassed when not needed.
# glibc 2.23 doesn't allow isinf on int, so we remove it from the graph.
1896isinf_ = isinf
1897
1898
1899def isinf(a):
1900    """isinf(a)"""
1901    a = as_tensor_variable(a)
1902    if a.dtype in discrete_dtypes:
1903        return alloc(np.asarray(False, dtype="bool"),
1904                     *[a.shape[i] for i in range(a.ndim)])
1905    return isinf_(a)
1906
1907
1908def allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
1909    """
1910    Implement Numpy's ``allclose`` on tensors.
1911
1912    ``absolute(a - b) <= (atol + rtol * absolute(b))``
1913
1914    Parameters
1915    ----------
1916    a : tensor
1917        Input to compare.
1918    b : tensor
1919        Input to compare.
1920    rtol : float
1921        The relative tolerance parameter.
1922    atol : float
1923        The absolute tolerance parameter.
1924    equal_nan: bool
1925        Whether to consider nan's in the same place to be close.
1926
1927    Returns
1928    -------
1929    bool
        A boolean value (of type int8, as returned by the tensor elementwise
        `all` function) indicating whether all elements in a and b lie within
        the tolerance range defined above.
1933
1934    Notes
1935    -----
1936    Not a symmetric equation. See Numpy's documentation.
1937
1938    """
1939    return all(isclose(a, b, rtol, atol, equal_nan))
1940
1941
1942def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
1943    """
1944    Implements Numpy's ``isclose`` on tensors.
1945
1946    The tolerance values are positive, typically very small numbers. The
1947    relative difference (`rtol` * abs(`b`)) and the absolute difference
1948    `atol` are added together to compare against the absolute difference
1949    between `a` and `b`.
1950
1951    ``absolute(a - b) <= (atol + rtol * absolute(b))``
1952
1953    Parameters
1954    ----------
1955    a : tensor
1956        Input to compare.
1957    b : tensor
1958        Input to compare.
1959    rtol : float
1960        The relative tolerance parameter.
1961    atol : float
1962        The absolute tolerance parameter.
1963    equal_nan : bool
1964        Whether to consider nan's in the same place to be close
1965
1966    Returns
1967    -------
1968    int8
1969        A boolean (int8) array where two arrays are element-wise equal
1970        within a tolerance.
1971
1972    Notes
1973    -----
1974    Not a symmetric equation. See Numpy's documentation.
1975
1976    Examples
1977    --------
1978    >>> import theano
1979    >>> import numpy as np
1980    >>> a = theano._asarray([1e10, 1e-7], dtype="float64")
1981    >>> b = theano._asarray([1.00001e10, 1e-8], dtype="float64")
1982    >>> theano.tensor.isclose(a, b).eval()
1983    array([1, 0], dtype=int8)
1984    >>> a = theano._asarray([1e10, 1e-8], dtype="float64")
1985    >>> b = theano._asarray([1.00001e10, 1e-9], dtype="float64")
1986    >>> theano.tensor.isclose(a, b).eval()
1987    array([1, 1], dtype=int8)
1988    >>> a = theano._asarray([1e10, 1e-8], dtype="float64")
1989    >>> b = theano._asarray([1.0001e10, 1e-9], dtype="float64")
1990    >>> theano.tensor.isclose(a, b).eval()
1991    array([0, 1], dtype=int8)
1992    >>> a = theano._asarray([1.0, np.nan], dtype="float64")
1993    >>> b = theano._asarray([1.0, np.nan], dtype="float64")
1994    >>> theano.tensor.isclose(a, b).eval()
    array([1, 0], dtype=int8)
1996    >>> a = theano._asarray([1.0, np.nan], dtype="float64")
1997    >>> b = theano._asarray([1.0, np.nan], dtype="float64")
1998    >>> theano.tensor.isclose(a, b, equal_nan=True).eval()
    array([1, 1], dtype=int8)
2000    >>> a = theano._asarray([1.0, np.inf], dtype="float64")
2001    >>> b = theano._asarray([1.0, -np.inf], dtype="float64")
2002    >>> theano.tensor.isclose(a, b).eval()
    array([1, 0], dtype=int8)
2004    >>> a = theano._asarray([1.0, np.inf], dtype="float64")
2005    >>> b = theano._asarray([1.0, np.inf], dtype="float64")
2006    >>> theano.tensor.isclose(a, b).eval()
    array([1, 1], dtype=int8)
2008
2009    """
2010    # close will be an int8 array of 1 where within tolerance
2011    # and 0 where not within tolerance or there was a nan or inf value.
2012    diff = abs(a - b)
2013    tolerance = atol + rtol * abs(b)
2014    close_prelim = le(diff, tolerance)
2015
2016    a_nan = isnan(a)
2017    b_nan = isnan(b)
2018    nans = bitwise_or(a_nan, b_nan)
2019
2020    a_inf = isinf(a)
2021    b_inf = isinf(b)
2022    infs = bitwise_or(a_inf, b_inf)
2023
2024    nans_or_infs = bitwise_or(nans, infs)
2025
2026    # close is now an array of 0's except where elements are not nan or inf
2027    # and are within the tolerance.
2028    close = bitwise_and(close_prelim, bitwise_not(nans_or_infs))
2029
2030    # deal with signed inf values. this will make an array inf_eq of 0's
2031    # except where inf values have the same sign.
2032    both_infs = bitwise_and(a_inf, b_inf)
2033    inf_signs_eq = eq(a_inf * sgn(a), b_inf * sgn(b))
2034    inf_eq = bitwise_and(both_infs, inf_signs_eq)
2035
2036    # now create the potential result combining close and inf_eq
2037    close_with_infs = bitwise_or(close, inf_eq)
2038
2039    # deal with comparing nan's.
2040    if equal_nan:
2041        both_nans = bitwise_and(a_nan, b_nan)
2042        return bitwise_or(close_with_infs, both_nans)
2043    # otherwise nan's aren't considered close.
2044    else:
2045        return close_with_infs
2046
2047
2048##########################
2049# Condition
2050##########################
2051
2052@_scal_elemwise
2053def switch(cond, ift, iff):
2054    """if cond then ift else iff"""
2055
2056where = switch
2057##########################
2058# Bit-wise
2059##########################
2060
2061
2062@_scal_elemwise
2063def and_(a, b):
2064    """bitwise a & b"""
2065bitwise_and = and_  # numpy name for it
2066
2067
2068@_scal_elemwise
2069def or_(a, b):
2070    """bitwise a | b"""
2071bitwise_or = or_  # numpy name for it
2072
2073
2074@_scal_elemwise
2075def xor(a, b):
2076    """bitwise a ^ b"""
2077bitwise_xor = xor  # numpy name for it
2078
2079
2080@_scal_elemwise
2081def invert(a):
2082    """bitwise ~a"""
2083bitwise_not = invert  # numpy alias for it
2084
2085
2086##########################
2087# Math
2088##########################
2089
2090@_scal_elemwise
2091def abs_(a):
2092    """|`a`|
2093
2094    TensorVariable overloads the `TensorVariable.__abs__` operator so that
2095    this function is called when you type abs(a).
2096
2097    """
2098
2099pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
2100
2101
2102@_scal_elemwise
2103def exp(a):
2104    """e^`a`"""
2105
2106
2107@_scal_elemwise
2108def exp2(a):
2109    """2^`a`"""
2110
2111
2112@_scal_elemwise
2113def expm1(a):
2114    """e^`a` - 1"""
2115
2116
2117@_scal_elemwise
2118def neg(a):
2119    """-a"""
2120
2121
2122# numpy.reciprocal does integer division on integer inputs
2123# (which is not very interesting)
2124@_scal_elemwise
2125def inv(a):
2126    """1.0/a"""
2127
2128
2129@_scal_elemwise
2130def log(a):
2131    """base e logarithm of a"""
2132
2133
2134@_scal_elemwise
2135def log2(a):
2136    """base 2 logarithm of a"""
2137
2138
2139@_scal_elemwise
2140def log10(a):
2141    """base 10 logarithm of a"""
2142
2143
2144@_scal_elemwise
2145def log1p(a):
2146    """log(1+a)"""
2147
2148
2149@_scal_elemwise
2150def sgn(a):
2151    """sign of a"""
2152
2153
2154@_scal_elemwise
2155def ceil(a):
2156    """ceiling of a"""
2157
2158
2159@_scal_elemwise
2160def floor(a):
2161    """floor of a"""
2162
2163
2164@_scal_elemwise
2165def trunc(a):
2166    """trunc of a"""
2167
2168
2169@constructor
2170def iround(a, mode=None):
2171    """cast(round(a,mode),'int64')"""
2172    return cast(round(a, mode), 'int64')
2173
2174
2175@constructor
2176def round(a, mode=None):
2177    """round_mode(a) with mode in [half_away_from_zero, half_to_even].
2178    Default to half_to_even."""
2179    if mode is None:
2180        mode = "half_to_even"
2181        if config.warn.round:
2182            warnings.warn(
2183                "theano.tensor.round() changed its default from"
2184                " `half_away_from_zero` to `half_to_even` to have"
2185                " the same default as NumPy. Use the Theano flag"
2186                " `warn.round=False` to disable this warning.")
2187    if mode == "half_away_from_zero":
2188        return round_half_away_from_zero(a)
2189    elif mode == "half_to_even":
2190        return round_half_to_even(a)
2191    else:
2192        raise Exception("round mode %s is not implemented." % mode)
2193
2194
2195@_scal_elemwise
2196def round_half_to_even(a):
2197    """round_half_to_even(a)"""
2198
2199
2200@_scal_elemwise
2201def round_half_away_from_zero(a):
2202    """round_half_away_from_zero(a)"""
2203
2204
2205@_scal_elemwise
2206def sqr(a):
2207    """square of a"""
2208
2209
2210# alias to sqr, included to maintain similarity with numpy interface
2211square = sqr
2212
2213
2214def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None):
2215    """Calculate the covariance matrix.
2216    Covariance indicates the level to which two variables vary together.
2217    If we examine N-dimensional samples, :math:`m = [x_1, x_2, ... x_N]^T`,
2218    then the covariance matrix element :math:`C_{ij}` is the covariance of
2219    :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
    of :math:`x_i`. Code and docstring ported from numpy.

    Parameters
    ----------
2222    m : array_like
2223        A 2-D array containing multiple variables and observations.
2224        Each row of `m` represents a variable, and each column is
2225        observations of all those variables.
2226    y : array_like, optional
2227        An additional set of variables and observations. `y` has the same form
2228        as that of `m`.
2229    rowvar : bool, optional
2230        If `rowvar` is True (default), then each row represents a
2231        variable, with observations in the columns. Otherwise, the relationship
2232        is transposed: each column represents a variable, while the rows
2233        contain observations.
2234    bias : bool, optional
2235        Default normalization (False) is by ``(N - 1)``, where ``N`` is the
2236        number of observations given (unbiased estimate). If `bias` is True, then
2237        normalization is by ``N``. These values can be overridden by using the
2238        keyword ``ddof``.
2239    ddof : int, optional
2240        If not ``None`` the default value implied by `bias` is overridden.
2241        The default value is ``None``.
2242    Returns
2243    -------
2244    out : The covariance matrix of the variables.
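
    Examples
    --------
    A minimal sketch (the symbolic matrix ``x`` is illustrative):

    >>> import theano.tensor as tt
    >>> x = tt.matrix('x')   # rows are variables, columns are observations
    >>> c = tt.cov(x)
    >>> c.ndim
    2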
2245    """
2246
2247    if fweights is not None:
2248        raise NotImplementedError('fweights are not implemented')
2249    if aweights is not None:
2250        raise NotImplementedError('aweights are not implemented')
2251
2252    if not rowvar and m.shape[0] != 1:
2253        m = m.T
2254
2255    if y is not None:
2256        if not rowvar and y.shape[0] != 1:
2257            y = y.T
2258        m = theano.tensor.concatenate((m, y), axis=0)
2259
2260    if ddof is None:
2261        if not bias:
2262            ddof = 1
2263        else:
2264            ddof = 0
2265
2266    # Determine the normalization
2267    fact = m.shape[1] - ddof
2268
2269    m -= m.mean(axis=1, keepdims=1)
2270    c = m.dot(m.T)
2271    c *= theano.tensor.constant(1) / fact
2272    return c.squeeze()
2273
2274
2275@_scal_elemwise
2276def sqrt(a):
2277    """square root of a"""
2278
2279
2280@_scal_elemwise
2281def deg2rad(a):
2282    """convert degree a to radian"""
2283
2284
2285@_scal_elemwise
2286def rad2deg(a):
2287    """convert radian a to degree"""
2288
2289
2290@_scal_elemwise
2291def cos(a):
2292    """cosine of a"""
2293
2294
2295@_scal_elemwise
2296def arccos(a):
2297    """arccosine of a"""
2298
2299
2300@_scal_elemwise
2301def sin(a):
2302    """sine of a"""
2303
2304
2305@_scal_elemwise
2306def arcsin(a):
2307    """arcsine of a"""
2308
2309
2310@_scal_elemwise
2311def tan(a):
2312    """tangent of a"""
2313
2314
2315@_scal_elemwise
2316def arctan(a):
2317    """arctangent of a"""
2318
2319
2320@_scal_elemwise
2321def arctan2(a, b):
2322    """arctangent of a / b"""
2323
2324
2325@_scal_elemwise
2326def cosh(a):
2327    """hyperbolic cosine of a"""
2328
2329
2330@_scal_elemwise
2331def arccosh(a):
2332    """hyperbolic arc cosine of a"""
2333
2334
2335@_scal_elemwise
2336def sinh(a):
2337    """hyperbolic sine of a"""
2338
2339
2340@_scal_elemwise
2341def arcsinh(a):
2342    """hyperbolic arc sine of a"""
2343
2344
2345@_scal_elemwise
2346def tanh(a):
2347    """hyperbolic tangent of a"""
2348
2349
2350@_scal_elemwise
2351def arctanh(a):
2352    """hyperbolic arc tangent of a"""
2353
2354
2355@_scal_elemwise
2356def erf(a):
2357    """error function"""
2358
2359
2360@_scal_elemwise
2361def erfc(a):
2362    """complementary error function"""
2363
2364
2365@_scal_elemwise
2366def erfcx(a):
2367    """scaled complementary error function"""
2368
2369
2370@_scal_elemwise
2371def erfinv(a):
2372    """inverse error function"""
2373
2374
2375@_scal_elemwise
2376def erfcinv(a):
2377    """inverse complementary error function"""
2378
2379
2380@_scal_elemwise
2381def gamma(a):
2382    """gamma function"""
2383
2384
2385@_scal_elemwise
2386def gammaln(a):
2387    """log gamma function"""
2388
2389
2390@_scal_elemwise
2391def psi(a):
2392    """derivative of log gamma function"""
2393
2394
2395@_scal_elemwise
2396def tri_gamma(a):
2397    """second derivative of the log gamma function"""
2398
2399
2400@_scal_elemwise
2401def chi2sf(x, k):
2402    """chi squared survival function"""
2403
2404
2405@_scal_elemwise
2406def gammainc(k, x):
2407    """Regularized lower gamma function"""
2408
2409
2410@_scal_elemwise
2411def gammaincc(k, x):
2412    """Regularized upper gamma function"""
2413
2414
2415@_scal_elemwise
2416def gammau(k, x):
2417    """Upper incomplete gamma function."""
2418
2419
2420@_scal_elemwise
2421def gammal(k, x):
2422    """Lower incomplete gamma function."""
2423
2424
2425@_scal_elemwise
2426def j0(x):
2427    """Bessel function of the first kind of order 0."""
2428
2429
2430@_scal_elemwise
2431def j1(x):
2432    """Bessel function of the first kind of order 1."""
2433
2434
2435@_scal_elemwise
2436def jv(v, x):
2437    """Bessel function of the first kind of order v (real)."""
2438
2439
2440@_scal_elemwise
2441def i0(x):
2442    """Modified Bessel function of the first kind of order 0."""
2443
2444
2445@_scal_elemwise
2446def i1(x):
2447    """Modified Bessel function of the first kind of order 1."""
2448
2449
2450@_scal_elemwise
2451def iv(v, x):
2452    """Modified Bessel function of the first kind of order v (real)."""
2453
2454
2455@_scal_elemwise
2456def real(z):
2457    """Return real component of complex-valued tensor `z`"""
2458_tensor_py_operators.real = property(real)
2459
2460
2461@_scal_elemwise
2462def imag(z):
2463    """Return imaginary component of complex-valued tensor `z`"""
2464_tensor_py_operators.imag = property(imag)
2465
2466
2467@_scal_elemwise
2468def angle(z):
2469    """Return polar-coordinate angle of complex-valued tensor `z`"""
2470
2471
2472@_scal_elemwise  # numpy.complex cannot build tensors
2473def complex(real, imag):
2474    """Return complex-valued tensor with `real` and `imag` components"""
2475
2476
2477@_scal_elemwise
2478def conj(z):
2479    """Return the complex conjugate of `z`."""
2480
2481
2482@_scal_elemwise
2483def complex_from_polar(abs, angle):
2484    """Return complex-valued tensor from polar coordinate specification."""
2485
2486##########################
2487# Misc
2488##########################
2489
2490
2491# fill, _fill_inplace = _elemwise(scal.second, 'fill',
2492# """fill WRITEME (elemwise)""")
2493@_scal_elemwise
2494def second(a, b):
2495    """Create a matrix by filling the shape of a with b"""
2496
2497fill = second
2498pprint.assign(fill, printing.FunctionPrinter('fill'))
2499
2500
2501@constructor
2502def ones_like(model, dtype=None, opt=False):
2503    """equivalent of numpy.ones_like
2504    Parameters
2505    ----------
2506    model : tensor
2507    dtype : data-type, optional
    opt : If True, we will return a constant instead of a graph when possible.
          Useful for Theano optimization, not for users building a graph, as it
          means that `model` isn't always present in the resulting graph.
2511
2512    Returns
2513    -------
2514    tensor
2515        tensor the shape of model containing ones of the type of dtype.
2516    """
2517    if dtype is None:
2518        dtype = model.type.dtype
2519    ret = constant(1.0, dtype=dtype)
2520    if opt and ret.type == model.type:
2521        return ret
2522    return fill(model, ret)
2523
2524
2525@constructor
2526def zeros_like(model, dtype=None, opt=False):
2527    """equivalent of numpy.zeros_like
2528    Parameters
2529    ----------
2530    model : tensor
2531    dtype : data-type, optional
    opt : If True, we will return a constant instead of a graph when possible.
          Useful for Theano optimization, not for users building a graph, as it
          means that `model` isn't always present in the resulting graph.
2535
2536    Returns
2537    -------
2538    tensor
2539        tensor the shape of model containing zeros of the type of dtype.
2540    """
2541
2542    if dtype is None:
2543        dtype = model.type.dtype
2544    ret = constant(0.0, dtype=dtype)
2545    if opt and ret.type == model.type:
2546        return ret
2547    return fill(model, ret)
2548
2549
2550def zeros(shape, dtype=None):
2551    """
2552    Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
2553    """
2554    if not isinstance(shape, (list, tuple, TensorVariable)):
2555        shape = [shape]
2556    if dtype is None:
2557        dtype = config.floatX
2558    return alloc(np.array(0, dtype=dtype), *shape)
2559
2560
2561def ones(shape, dtype=None):
2562    """
2563    Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
2564    """
2565    if not isinstance(shape, (list, tuple, TensorVariable)):
2566        shape = [shape]
2567    if dtype is None:
2568        dtype = config.floatX
2569    return alloc(np.array(1, dtype=dtype), *shape)
2570
2571
2572class Nonzero(gof.Op):
2573    """
2574    Return the indices of the elements that are non-zero.
2575
2576    Returns a matrix of shape (ndim, number of nonzero elements) such that
2577    element (i,j) is the index in the ith dimension of the jth non-zero
2578    element.
2579
    Note this is different from NumPy, which returns a tuple of arrays, one for
2581    each dimension of the input array.
2582
2583    Parameters
2584    ----------
2585    a : array_like
2586        Input array.
2587
2588    Returns
2589    -------
2590    matrix
2591        Matrix containing the indices of the non-zero elements of a.
2592
2593    See Also
2594    --------
2595    nonzero_values : Return the non-zero elements of the input array
2596    flatnonzero : Return the indices of the non-zero elements of the
2597        flattened input array.
2598
2599    """
2600    __props__ = ()
2601
2602    def make_node(self, a):
2603        a = as_tensor_variable(a)
2604        if a.ndim == 0:
2605            raise ValueError('Nonzero only supports non-scalar arrays.')
2606        output = [TensorType(dtype='int64', broadcastable=(False, False))()]
2607        return gof.Apply(self, [a], output)
2608
2609    def perform(self, node, inp, out_):
2610        a = inp[0]
2611        out, = out_
2612
2613        result_tuple = np.nonzero(a)
2614        if len(result_tuple[0]) > 0:
2615            result = np.vstack(result_tuple)
2616        else:
2617            result = np.zeros((len(result_tuple), 0))
2618
2619        out[0] = result.astype('int64')
2620
2621    def grad(self, inp, grads):
2622        return [grad_undefined(self, 0, inp[0])]
2623
2624
2625_nonzero = Nonzero()
2626
2627
2628def nonzero(a, return_matrix=False):
2629    """
2630    Returns one of the following:
2631
2632        If return_matrix is False (default, same as NumPy):
2633            A tuple of vector arrays such that the ith element of the jth array
2634            is the index of the ith non-zero element of the input array in the
2635            jth dimension.
2636
2637        If return_matrix is True (same as Theano Op):
2638            Returns a matrix of shape (ndim, number of nonzero elements) such
2639            that element (i,j) is the index in the ith dimension of the jth
2640            non-zero element.
2641
2642    Parameters
2643    ----------
2644    a : array_like
2645        Input array.
2646    return_matrix : bool
2647        If True, returns a symbolic matrix. If False, returns a tuple of
2648        arrays. Defaults to False.
2649
2650    Returns
2651    -------
2652    tuple of vectors or matrix
2653
2654    See Also
2655    --------
2656    nonzero_values : Return the non-zero elements of the input array
2657    flatnonzero : Return the indices of the non-zero elements of the
2658        flattened input array.
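
    Examples
    --------
    A small sketch (the symbolic matrix ``x`` is illustrative):

    >>> import theano.tensor as tt
    >>> x = tt.matrix('x')
    >>> rows, cols = tt.nonzero(x)                # tuple of index vectors
    >>> idx = tt.nonzero(x, return_matrix=True)   # one (2, n) int64 matrix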
2659
2660    """
2661    matrix_result = _nonzero(a)
2662    if return_matrix:
2663        return matrix_result
2664    else:
2665        if a.ndim > 0:
2666            tuple_result = tuple([matrix_result[i] for i in xrange(a.ndim)])
2667        else:
2668            tuple_result = tuple([matrix_result[0]])
2669        return tuple_result
2670
2671
2672def flatnonzero(a):
2673    """
2674    Return a vector of indices that are non-zero in the flattened version of a.
2675
2676    This is equivalent to nonzero(a.flatten(), return_matrix=True)[0]
2677
2678    Parameters
2679    ----------
2680    a : tensor
2681        Input tensor
2682
2683    Returns
2684    -------
2685    vector
2686        Output vector, containing the indices of the elements of `a.flatten()`
2687        that are non-zero.
2688
2689    See Also
2690    --------
2691    nonzero : Return the indices of the non-zero elements of the input array.
2692    nonzero_values : Return the non-zero elements of the input array
2693
2694    """
2695    if a.ndim == 0:
2696        raise ValueError('Nonzero only supports non-scalar arrays.')
2697    return nonzero(a.flatten(), return_matrix=True)[0]
2698
2699
2700def nonzero_values(a):
2701    """
2702    Return a vector of non-zero elements contained in the input array.
2703
2704    The following behavior works to extract non-zero elements from an array
2705    in NumPy but is *NOT* supported by Theano:
2706
2707        a[numpy.nonzero(a)]
2708
2709    Instead, the nonzero_values function or method should be used:
2710
2711        tensor.nonzero_values(a)
2712        a.nonzero_values()
2713
2714    This is equivalent to the following:
2715
2716        a.flatten()[tensor.flatnonzero(a)]
2717
2718    Parameters
2719    ----------
2720    a : tensor
2721        Input tensor
2722
2723    Returns
2724    -------
2725    vector
2726        Output vector, containing the non-zero elements of a.
2727
2728    See Also
2729    --------
2730    nonzero : Return the indices of the non-zero elements of the input array.
2731    flatnonzero : Return the indices of the non-zero elements of the
2732        flattened input array.
2733
2734    """
2735    return a.flatten()[flatnonzero(a)]
2736
2737
2738class Tri(gof.Op):
2739
2740    __props__ = ("dtype",)
2741
2742    def __init__(self, dtype=None):
2743        if dtype is None:
2744            dtype = config.floatX
2745        self.dtype = dtype
2746
2747    def make_node(self, N, M, k):
2748        N = as_tensor_variable(N)
2749        M = as_tensor_variable(M)
2750        k = as_tensor_variable(k)
2751        return gof.Apply(
2752            self,
2753            [N, M, k],
2754            [TensorType(dtype=self.dtype, broadcastable=(False, False))()])
2755
2756    def perform(self, node, inp, out_):
2757        N, M, k = inp
2758        out, = out_
2759        out[0] = np.tri(N, M, k, dtype=self.dtype)
2760
2761    def infer_shape(self, node, in_shapes):
2762        out_shape = [node.inputs[0], node.inputs[1]]
2763        return [out_shape]
2764
2765    def grad(self, inp, grads):
2766        return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
2767
2768
2769def tri(N, M=None, k=0, dtype=None):
2770    """
2771    An array with ones at and below the given diagonal and zeros elsewhere.
2772
2773    Parameters
2774    ----------
2775    N : int
2776        Number of rows in the array.
2777    M : int, optional
2778        Number of columns in the array.
2779        By default, `M` is taken equal to `N`.
2780    k : int, optional
2781        The sub-diagonal at and below which the array is filled.
2782        `k` = 0 is the main diagonal, while `k` < 0 is below it,
2783        and `k` > 0 is above.  The default is 0.
2784    dtype : dtype, optional
2785        Data type of the returned array.  The default is float.
2786
2787    Returns
2788    -------
2789    Array of shape (N, M)
2790        Array with its lower triangle filled with ones and zero elsewhere;
        in other words ``T[i,j] == 1`` for ``j <= i + k``, 0 otherwise.
2792
2793    """
2794    if dtype is None:
2795        dtype = config.floatX
2796    if M is None:
2797        M = N
2798    op = Tri(dtype)
2799    return op(N, M, k)
2800
2801
2802def tril(m, k=0):
2803    """
2804    Lower triangle of an array.
2805
2806    Return a copy of an array with elements above the `k`-th diagonal zeroed.
2807
2808    Parameters
2809    ----------
2810    m : array_like, shape (M, N)
2811        Input array.
2812    k : int, optional
2813        Diagonal above which to zero elements.  `k = 0` (the default) is the
2814        main diagonal, `k < 0` is below it and `k > 0` is above.
2815
2816    Returns
2817    -------
2818    array, shape (M, N)
2819        Lower triangle of `m`, of same shape and data-type as `m`.
2820
2821    See Also
2822    --------
2823    triu : Same thing, only for the upper triangle.
2824
2825    """
2826    return m * tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype)
2827
2828
2829def triu(m, k=0):
2830    """
2831    Upper triangle of an array.
2832
2833    Return a copy of a matrix with the elements below the `k`-th diagonal
2834    zeroed.
2835
2836    Please refer to the documentation for `tril` for further details.
2837
2838    See Also
2839    --------
2840    tril : Lower triangle of an array.
2841
2842    """
2843    return m * (1 - tri(m.shape[0], m.shape[1], k=k - 1, dtype=m.dtype))
2844
2845
2846class Eye(gof.Op):
2847
2848    __props__ = ("dtype", )
2849
2850    def __init__(self, dtype=None):
2851        if dtype is None:
2852            dtype = config.floatX
2853        self.dtype = dtype
2854
2855    def make_node(self, n, m, k):
2856        n = as_tensor_variable(n)
2857        m = as_tensor_variable(m)
2858        k = as_tensor_variable(k)
2859        assert n.ndim == 0
2860        assert m.ndim == 0
2861        assert k.ndim == 0
2862        return gof.Apply(
2863            self,
2864            [n, m, k],
2865            [TensorType(dtype=self.dtype, broadcastable=(False, False))()])
2866
2867    def perform(self, node, inp, out_):
2868        n, m, k = inp
2869        out, = out_
2870        out[0] = np.eye(n, m, k, dtype=self.dtype)
2871
2872    def infer_shape(self, node, in_shapes):
2873        out_shape = [node.inputs[0], node.inputs[1]]
2874        return [out_shape]
2875
2876    def grad(self, inp, grads):
2877        return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
2878
2879
2880def eye(n, m=None, k=0, dtype=None):
2881    """Return a 2-D array with ones on the diagonal and zeros elsewhere.
2882
2883    Parameters
2884    ----------
2885    n : int
2886        Number of rows in the output.
2887    m : int, optional
        Number of columns in the output. If None, defaults to `n`.
2889    k : int, optional
2890        Index of the diagonal: 0 (the default) refers to the main diagonal,
2891        a positive value refers to an upper diagonal, and a negative value
2892        to a lower diagonal.
2893    dtype : data-type, optional
2894        Data-type of the returned array.
2895
2896    Returns
2897    -------
2898    ndarray of shape (N,M)
2899        An array where all elements are equal to zero, except for the `k`-th
2900        diagonal, whose values are equal to one.
2901
2902    """
2903    if dtype is None:
2904        dtype = config.floatX
2905    if m is None:
2906        m = n
2907    localop = Eye(dtype)
2908    return localop(n, m, k)
2909
2910
2911def identity_like(x):
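    """Return a matrix with the same row and column sizes as the 2-D tensor
    `x`, with ones on the main diagonal and zeros elsewhere (a thin wrapper
    around `eye` using `x`'s shape and dtype)."""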
2912    return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
2913
2914
2915def alloc_validate_shape(shape):
2916    sh = [as_tensor_variable(s) for s in shape]
2917    bcast = []
2918    for i, s in enumerate(sh):
2919        def err_str():
2920            if config.exception_verbosity == 'high':
2921                return '\n' + min_informative_str(s)
2922            else:
2923                return str(s)
2924        if s.type.dtype not in integer_dtypes:
2925            s_as_str = err_str()
2926            raise TypeError('Shape arguments to Alloc must be integers, '
2927                            'but argument %s is not for apply node: %s' %
2928                            (i, s_as_str))
2929        if s.ndim != 0:
2930            s_as_str = err_str()
            raise TypeError(
                'Each shape dimension to Alloc must be a scalar, '
                'but dimension %s has %d dimensions for apply node: %s' %
                (i, s.ndim, s_as_str))
2935
2936        # if s is constant 1, then we're broadcastable in that dim
2937        try:
2938            const_shp = get_scalar_constant_value(s)
2939        except NotScalarConstantError:
2940            const_shp = None
2941        bcast.append(1 == const_shp)
2942    return sh, bcast
2943
2944
2945class Alloc(gof.Op):
2946    """Create a Tensor from an initial value and a desired shape.
2947
2948    alloc(value, shape0, shape1, ..., shapeN)
2949
2950    Returns an N-dimensional tensor initialized by `value` using something
2951    equivalent to
2952
2953        z = numpy.zeros(shape, value.dtype)
2954        z += value
2955
2956    The result has N dimensions, has the dtype of `value` and is obtained by
2957    broadcasting value over the output ndarray.
2958
2959    This Op is used to replace fill() during optimizations because after shapes
2960    are lifted, the first argument to fill can often be pruned from the graph.
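
    A minimal usage sketch of the module-level `alloc` instance (values are
    illustrative):

    >>> import theano.tensor as tt
    >>> z = tt.alloc(0.5, 3, 4)  # symbolic (3, 4) tensor filled with 0.5
    >>> z.ndim
    2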
2961
2962    """
2963    _f16_ok = True
2964    __props__ = ()
2965
2966    def validate_shape(self, shape):
2967        return alloc_validate_shape(shape)
2968
2969    def make_node(self, value, *shape):
2970        v = as_tensor_variable(value)
2971        sh, bcast = alloc_validate_shape(shape)
2972        if v.ndim > len(sh):
2973            raise TypeError("The Alloc value to use has more dimensions"
2974                            " than the specified dimensions",
2975                            v.ndim, len(sh))
2976        otype = TensorType(dtype=v.dtype, broadcastable=bcast)
2977        return gof.Apply(self, [v] + sh, [otype()])
2978
2979    def perform(self, node, inputs, out_):
2980        out, = out_
2981        v = inputs[0]
2982        sh = tuple([int(i) for i in inputs[1:]])
2983        if out[0] is None or out[0].shape != sh:
2984            if v.size == 1 and v.item() == 0:
2985                out[0] = np.zeros(sh, dtype=v.dtype)
2986            else:
2987                out[0] = np.empty(sh, dtype=v.dtype)
2988                out[0][...] = v  # broadcast v to fill us up
2989        else:
2990            # reuse the allocated memory.
2991            out[0][...] = v  # broadcast v to fill us up
2992
2993    def c_code(self, node, name, inp, out, sub):
2994        vv = inp[0]
2995        ndim = len(inp[1:])
2996        zz, = out
2997        fail = sub['fail']
2998
2999        code = """
3000            npy_intp shape[%(ndim)s];
3001            """ % dict(ndim=ndim)
3002
3003        # Initialize shape
3004        for i, shp_i in enumerate(inp[1:]):
3005            code += """
3006                shape[%(i)s] = ((dtype_%(shp_i)s*) PyArray_DATA(%(shp_i)s))[0];
3007                """ % dict(i=i, shp_i=shp_i)
3008
3009        code += """
3010            int need_new_out = (NULL == %(zz)s);
3011            for (int i = 0; i < %(ndim)s; i++)
3012                need_new_out = (need_new_out
3013                                || (PyArray_DIMS(%(zz)s)[i] != shape[i]));
3014
3015            if (need_new_out)
3016            {
3017                Py_XDECREF(%(zz)s);
3018                %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
3019                    shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s));
3020                if (!%(zz)s)
3021                {
3022                    PyErr_SetString(PyExc_MemoryError, "alloc failed");
3023                    %(fail)s
3024                }
3025            }
3026
3027            // This function takes care of broadcasting
3028            if (PyArray_CopyInto(%(zz)s, %(vv)s) == -1)
3029              %(fail)s
3030            """ % dict(vv=vv, ndim=ndim, zz=zz, fail=fail)
3031
3032        return code
3033
3034    def c_code_cache_version(self):
3035        return (2,)
3036
3037    def infer_shape(self, node, input_shapes):
3038        return [node.inputs[1:]]
3039
3040    def connection_pattern(self, node):
3041
3042        rval = [[True]]
3043
3044        for ipt in node.inputs[1:]:
3045            rval.append([False])
3046
3047        return rval
3048
3049    def grad(self, inputs, grads):
3050        x = inputs[0]
3051        gz = grads[0]
3052        n_axes_to_sum = gz.ndim - x.ndim
3053        # The number of dimensions added
3054        axis = list(range(n_axes_to_sum))
3055        # The broadcasted dimensions
3056        axis_broadcasted = []
3057        axis_kept = []
3058        for i, (ib, gb) in enumerate(
3059            zip(inputs[0].broadcastable,
3060                # We need the dimensions corresponding to x
3061                grads[0].broadcastable[-inputs[0].ndim:])):
3062            if ib and not gb:
3063                axis_broadcasted.append(i + n_axes_to_sum)
3064            else:
3065                axis_kept.append(i)
3066        gx = gz.sum(axis=axis + axis_broadcasted)
3067        if axis_broadcasted:
3068            new_order = ['x'] * x.ndim
3069            for idx, axis in enumerate(axis_kept):
3070                new_order[axis] = idx
3071            gx = gx.dimshuffle(new_order)
3072            # Dimshuffle to add back the broadcasted dims
3073        # The *elements* of the output are not connected to
3074        # the inputs that specify the shape. If you grow the
3075        # shape by epsilon, the existing elements do not
3076        # change.
3077        return [gx] + [DisconnectedType()() for i in inputs[1:]]
3078
3079    def __call__(self, val, *shapes, **kwargs):
3080        """
3081        If the alloc would be useless, this function returns val.
3082
3083        If this function is called outside of a graph optimization context
3084        (for instance, it is manually called by a user building a graph),
3085        then we always return an Alloc node, to allow for DebugMode to check
3086        for size mismatches.
3087
3088        If you always want an Alloc node, call make_node.
3089
3090        """
3091        ret = super(Alloc, self).__call__(val, *shapes, **kwargs)
3092        try:
3093            # It makes optimization difficult when useless allocs are thrown
3094            # into the graph at every stage of optimization.  This little logic
3095            # tries to help at least in some cases.
3096            if hasattr(val, 'fgraph') and (val.type == ret.type):
3097                return val
3098        except AttributeError:
3099            pass
3100        return ret
3101
3102    def R_op(self, inputs, eval_points):
3103        if eval_points[0] is None:
3104            return [None]
3105        return self(eval_points[0], *inputs[1:], **dict(return_list=True))
3106
3107    def do_constant_folding(self, node):
3108        if not getattr(node.outputs[0], 'clients', []):
3109            # If there are no clients then there is no point doing constant
3110            # folding.
3111            return False
3112        for client in node.outputs[0].clients:
3113            if client[0] == 'output':
3114                # If the output is a constant, it will have to be deepcopied
3115                # each time the function is called.  So we do not fold.
3116                return False
3117            elif (
3118                # The following ops work inplace of their input id 0.
3119                client[1] == 0 and
3120                isinstance(client[0].op, (
                    # Ops that will work inplace on the Alloc. So if they
                    # get constant_folded, they would copy the
                    # constant and this is less efficient.

                    # Not doing the constant folding could also lower
                    # the peak memory usage, as the "constant" won't
                    # always exist.
3128                    theano.tensor.subtensor.IncSubtensor,
3129                    theano.tensor.subtensor.AdvancedIncSubtensor1,
3130                    theano.tensor.subtensor.AdvancedIncSubtensor,
3131                    theano.tensor.blas.Gemv,
3132                    theano.tensor.blas_c.CGemv,
3133                    theano.tensor.blas.Ger,
3134                    theano.tensor.blas_c.CGer,
3135                    theano.tensor.blas_scipy.ScipyGer))):
3136                return False
            # If the client is a transfer to the GPU, we don't want to
            # fold. We let the Alloc be moved to the GPU, then the GPU
            # optimizer decides whether it needs to fold it or not.
3140            elif client[0].op.__class__.__name__.lower().startswith("gpu"):
3141                return False
3142        return True
3143
3144alloc = Alloc()
3145pprint.assign(alloc, printing.FunctionPrinter('alloc'))
3146
3147
3148def transfer(var, target):
3149    """
3150    Return a version of `var` transferred to `target`.
3151
    `cpu` means a TensorType (on the CPU).  Other types may define
    additional targets.
3154
3155    Parameters
3156    ----------
3157    var : variable
3158        A theano variable
3159    target : str
3160        The target of the transfer
3161    """
3162    if target == 'cpu':
3163        return as_tensor_variable(var)
3164    else:
3165        for trans in transfer._others:
3166            res = trans(var, target)
3167            if res is not None:
3168                return res
3169    raise ValueError("Can't transfer to target %s" % (target,))
3170
3171transfer._others = []
3172
3173
3174def register_transfer(fn):
3175    """
3176    Register a transfer function for alternative targets.
3177
3178    Parameters
3179    ----------
    fn : callable
        Called as ``fn(var, target)``; it should return a variable transferred
        to ``target``, or None if it does not handle that target.
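
    A hedged sketch of such a callable (the target name ``'my_dev'`` and the
    helper ``move_to_my_dev`` below are hypothetical):

    >>> def my_transfer(var, target):                 # doctest: +SKIP
    ...     if target == 'my_dev':
    ...         return move_to_my_dev(var)            # hypothetical helper
    ...     return None                               # target not handled here
    >>> theano.tensor.register_transfer(my_transfer)  # doctest: +SKIP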
3181    """
3182    transfer._others.append(fn)
3183
3184"""Create a duplicate of `a` (with duplicated storage)"""
3185tensor_copy = elemwise.Elemwise(scal.identity)
3186pprint.assign(tensor_copy, printing.IgnorePrinter())
3187
3188
3189@constructor
3190def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
3191    """
3192    Computes the sum along the given axis(es) of a tensor `input`.
3193
3194    When axis is None (the default value), the sum is performed
3195    over the flattened tensor.
3196
3197    For full documentation see ``tensor.elemwise.Sum``.
3198    In particular please pay attention to the important warning when using
3199    a custom acc_dtype.
3200
3201    Parameters
3202    ----------
3203    keepdims: bool
3204        If this is set to True, the axes which are reduced are left in
3205        the result as dimensions with size one. With this option, the result
3206        will broadcast correctly against the original tensor.
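
    Examples
    --------
    A small sketch (the symbolic matrix ``x`` is illustrative):

    >>> import theano.tensor as tt
    >>> x = tt.matrix('x')
    >>> tt.sum(x, axis=0).ndim
    1
    >>> tt.sum(x, axis=0, keepdims=True).ndim
    2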
3207
3208    """
3209
3210    out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
3211
3212    if keepdims:
3213        out = makeKeepDims(input, out, axis)
3214    return out
3215
3216pprint.assign(Sum(), printing.FunctionPrinter('sum'))
3217
3218
3219@constructor
3220def prod(input, axis=None, dtype=None, keepdims=False, acc_dtype=None,
3221         no_zeros_in_input=False):
3222    """
3223    Computes the product along the given axis(es) of a tensor `input`.
3224
3225    When axis is None (the default value), the product is performed
3226    over the flattened tensor.
3227
3228    For full documentation see ``tensor.elemwise.Prod``.
3229
3230    Parameters
3231    ----------
3232    keepdims: bool
3233        If this is set to True, the axes which are reduced are left in
3234        the result as dimensions with size one. With this option, the result
3235        will broadcast correctly against the original tensor.
3236
3237    """
3238
3239    out = elemwise.Prod(axis, dtype=dtype, acc_dtype=acc_dtype,
3240                        no_zeros_in_input=no_zeros_in_input)(input)
3241
3242    if keepdims:
3243        out = makeKeepDims(input, out, axis)
3244    return out
3245
3246
3247class Mean(elemwise.CAReduce):
3248    def __init__(self, axis=None):
3249        elemwise.CAReduce.__init__(self, scal.add, axis)
3250        assert self.axis is None or len(self.axis) == 1
3251
3252    def __str__(self):
3253        if self.axis is not None:
3254            return "Mean{%s}" % (", ".join(str(x) for x in self.axis))
3255        else:
3256            return "Mean"
3257
3258    def _output_dtype(self, idtype):
3259        # we want to protect against overflow
3260        return 'float64'
3261
3262    def perform(self, node, inp, out):
3263        input, = inp
3264        output, = out
3265        if self.axis is None:
3266            axis = None
3267        else:
3268            axis = self.axis[0]
3269        # numpy.asarray is needed as otherwise we can end up with a
3270        # numpy scalar.
3271        output[0] = np.asarray(np.mean(input, dtype='float64',
3272                                       axis=axis))
3273
3274    def c_code(self, node, name, inames, onames, sub):
3275        if self.axis is not None:
3276            return super(Op, self).c_code(node, name, inames, onames, sub)
3277        ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub)
        # TODO: the C code supports only axis=None; other cases fall back
        # to the Python implementation (perform).
3279        return ret + """
3280  *((double *)PyArray_DATA(%s)) /= PyArray_SIZE(%s);
3281  """ % (onames[0], inames[0])
3282
3283# TODO: implement the grad. When done and tested, you can make this the default
3284# version.
3285#    def grad(self, (x,), (gout,)):
3286#      import pdb;pdb.set_trace()
3287#      return grad(mean(x, self.axis, op=False),[x])
3288
3289
3290@constructor
3291def mean(input, axis=None, dtype=None, op=False, keepdims=False,
3292         acc_dtype=None):
3293    """
3294    Computes the mean value along the given axis(es) of a tensor `input`.
3295
3296    Parameters
3297    ----------
3298    axis : None or int or (list of int) (see `Sum`)
3299        Compute the mean along this axis of the tensor.
3300        None means all axes (like numpy).
3301    dtype: None or string
3302        Dtype to cast the result of the inner summation into.
3303        For instance, by default, a sum of a float32 tensor will be
3304        done in float64 (acc_dtype would be float64 by default),
        but that result will be cast back to float32.
3306    keepdims: bool
3307        If this is set to True, the axes which are reduced are
3308        left in the result as dimensions with size one. With this option,
3309        the result will broadcast correctly against the original tensor.
3310    acc_dtype: None or string
3311        Dtype to use for the inner summation. This will not
3312        necessarily be the dtype of the output (in particular
3313        if it is a discrete (int/uint) dtype, the output will
3314        be in a float type). If None, then we use the same rules as `sum()`.
3315
3316    Notes
3317    -----
    On the GPU, if you specify dtype=float32, everything will be done on the GPU.
3319
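    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.fmatrix('x')
    >>> m_all = theano.tensor.mean(x)           # mean of all elements (float32)
    >>> m_col = theano.tensor.mean(x, axis=0)   # mean of each column
    >>> m_keep = theano.tensor.mean(x, axis=0, keepdims=True)  # shape (1, cols)
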
3320    """
3321    input = as_tensor_variable(input)
3322    if op:
3323        if dtype not in (None, 'float64'):
3324            raise NotImplementedError(
3325                'The Mean op does not support the dtype argument, '
3326                'and will always use float64. If you want to specify '
3327                'the dtype, call tensor.mean(..., op=False).',
3328                dtype)
3329        if acc_dtype not in (None, 'float64'):
3330            raise NotImplementedError(
3331                'The Mean op does not support the acc_dtype argument, '
3332                'and will always use float64. If you want to specify '
3333                'acc_dtype, call tensor.mean(..., op=False).',
                acc_dtype)
3335        out = Mean(axis)(input)
3336        if keepdims:
3337            out = makeKeepDims(input, out, axis)
3338        return out
3339
3340    if dtype is not None:
3341        # The summation will be done with the specified dtype.
3342        # sum() will complain if it is not suitable.
3343        sum_dtype = dtype
3344    else:
3345        sum_dtype = None
3346        # float16 overflows on the cast way too often
3347        if input.dtype == 'float16':
3348            sum_dtype = 'float32'
3349
3350    s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims,
3351            acc_dtype=acc_dtype)
3352    shp = shape(input)
3353
3354    # Cast shp into a float type
3355    # TODO Once we have a consistent casting policy, we could simply
3356    # use true_div.
3357    if s.dtype in ('float16', 'float32', 'complex64'):
3358        shp = cast(shp, 'float32')
3359    else:
3360        shp = cast(shp, 'float64')
3361
3362    if axis is None:
3363        axis = list(range(input.ndim))
3364    elif isinstance(axis, (integer_types, np.integer)):
3365        axis = [axis]
3366    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
3367        axis = [int(axis)]
3368    else:
3369        axis = [int(a) for a in axis]
3370
3371    # This sequential division will possibly be optimized by Theano:
3372    for i in axis:
3373        s = true_div(s, shp[i])
3374
3375    # This can happen when axis is an empty list/tuple
3376    if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
3377        s = cast(s, shp.dtype)
3378
3379    if dtype == 'float16' or (dtype is None and input.dtype == 'float16'):
3380        s = cast(s, 'float16')
3381    s.name = 'mean'
3382    return s
3383
3384
3385@constructor
3386def var(input, axis=None, ddof=0, keepdims=False, corrected=False):
3387    """
3388    Computes the variance along the given axis(es) of a tensor `input`.
3389
3390    Parameters
3391    ----------
3392    axis: None or int or (list of int) (see `Sum`)
3393        Compute the variance along this axis of the tensor.
3394        None means all axes (like numpy).
    ddof: int
        Degrees of freedom; 0 would compute the ML estimate
        and 1 would compute the unbiased estimate.
3397    keepdims : bool
3398        If this is set to True, the axes which are reduced are
3399        left in the result as dimensions with size one. With this option,
3400        the result will broadcast correctly against the original tensor.
3401    corrected : bool
3402        If this is set to True, the 'corrected_two_pass' algorithm is
3403        used to compute the variance.
        Reference: http://www.cs.yale.edu/publications/techreports/tr222.pdf
3405
3406    Notes
3407    -----
    By default, this uses the two-pass algorithm (reference below):
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
    The 'corrected_two_pass' algorithm (enabled with the 'corrected' flag)
    is numerically more stable. There exist other implementations that
    offer better stability, but they are probably slower.
3413
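    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> v_ml = theano.tensor.var(x, axis=0)                      # ML estimate (ddof=0)
    >>> v_unbiased = theano.tensor.var(x, axis=0, ddof=1)        # unbiased estimate
    >>> v_stable = theano.tensor.var(x, axis=0, corrected=True)  # corrected two-pass
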
3414    """
3415
    if isinstance(ddof, bool):
        raise ValueError('Parameter keepdims is now at index 3: (input, '
                         'axis=None, ddof=0, keepdims=False, corrected=False)')
3419
3420    input_ndim = input.type.ndim
3421    if axis is None:
3422        axis = list(range(input_ndim))
3423    elif isinstance(axis, (integer_types, np.integer)):
3424        axis = [axis]
3425    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
3426        axis = [int(axis)]
3427    else:
3428        axis = [int(a) for a in axis]
3429
3430    # compute the axis-wise mean
3431    mean_input = mean(input, axis, keepdims=True)
3432
3433    # center the input
3434    centered_input = input - mean_input
3435
3436    # return the mean sqr
3437    two = constant(2, dtype=centered_input.dtype)
3438    if ddof == 0:
3439        v = mean((centered_input ** two), axis, keepdims=keepdims)
3440    else:
3441        shp = shape(input) - ddof
3442        v = sum((centered_input ** two), axis=axis, keepdims=keepdims)
3443        for i in axis:
3444            v = true_div(v, shp[i])
3445
3446    # use 'corrected_two_pass' algorithm
3447    if corrected:
3448        if ddof == 0:
3449            error = mean(centered_input, axis, keepdims=keepdims) ** 2
3450        else:
3451            shp = shape(input) - ddof
3452            shp_inp = shape(input)
3453            error = sum(centered_input, axis=axis, keepdims=keepdims) ** 2
3454            for i in axis:
3455                error = true_div(error, shp[i] * shp_inp[i])
3456        v = v - error
3457
3458    v.name = 'var'
3459    return v
3460
3461
3462@constructor
3463def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
3464    """
3465    Computes the standard deviation along the given axis(es) of a tensor `input`.
3466
3467    Parameters
3468    ----------
    axis: None or int or (list of int) (see `Sum`)
        Compute the standard deviation along this axis of the tensor.
        None means all axes (like numpy).
    ddof: int
        Degrees of freedom; 0 would compute the ML estimate
        and 1 would compute the unbiased estimate.
3474    keepdims : bool
3475        If this is set to True, the axes which are reduced are
3476        left in the result as dimensions with size one. With this option,
3477        the result will broadcast correctly against the original tensor.
3478    corrected : bool
3479        If this is set to True, the 'corrected_two_pass' algorithm is
3480        used to compute the variance.
        Reference: http://www.cs.yale.edu/publications/techreports/tr222.pdf
3482
3483    Notes
3484    -----
    It calls 'var()', which uses the two-pass algorithm (reference below):
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
    Function 'var()' also supports the 'corrected_two_pass' algorithm (enabled
    with the 'corrected' flag), which is numerically more stable. There exist
    other implementations that offer better stability, but they are probably
    slower.
3490
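    Examples
    --------
    A minimal illustrative sketch mirroring `var` (the names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> s = theano.tensor.std(x, axis=1, ddof=1)  # unbiased std of each row
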
3491    """
3492
    if isinstance(ddof, bool):
        raise ValueError('Parameter keepdims is now at index 3: (input, '
                         'axis=None, ddof=0, keepdims=False, corrected=False)')
3496
3497    ret = sqrt(var(input=input, axis=axis, ddof=ddof,
3498                   keepdims=keepdims, corrected=corrected))
3499    ret.name = 'std'
3500    return ret
3501
3502
3503class Default(gof.Op):
3504    """
3505    Takes an input x and a default value.
3506
3507    If the input is not None, a reference to it is returned.
3508    If the input is None, a copy of the default value is returned instead.
3509    The input and the default must have exactly the same type.
3510
3511    """
3512    view_map = {0: [0]}
3513    __props__ = ()
3514
3515    def make_node(self, x, default):
3516        x, default = as_tensor_variable(x), as_tensor_variable(default)
3517        if x.type != default.type:
3518            raise TypeError('Both default() arguments must have same type',
3519                            x, default)
3520        return gof.Apply(self, [x, default], [default.type()])
3521
3522    def perform(self, node, inp, out_):
3523        x, default = inp
3524        out, = out_
3525        if x is None:
            # Why copy? Theano can't yet understand out[0] being a view of
            # either x or default, so out[0] can be a view of x, but must be
            # a copy of default.
3528            out[0] = default.copy()
3529        else:
3530            out[0] = x
3531
3532default = Default()
3533setdefault = default  # legacy
3534
3535
3536##########################
3537# Arithmetics
3538##########################
3539@_scal_elemwise
3540def maximum(x, y):
3541    """elemwise maximum. See max for the maximum in one tensor"""
3542    # see decorator for function body
3543
3544
3545@_scal_elemwise
3546def minimum(x, y):
3547    """elemwise minimum. See min for the minimum in one tensor"""
3548    # see decorator for function body
3549
3550
3551def div_proxy(x, y):
3552    """Proxy for either true_div or int_div, depending on types of x, y."""
3553    f = scal.int_or_true_div(
3554        as_tensor_variable(x).dtype in discrete_dtypes,
3555        as_tensor_variable(y).dtype in discrete_dtypes)
3556    if f is scal.int_div:
3557        return int_div(x, y)
3558    else:
3559        return true_div(x, y)
3560
3561
3562def divmod(x, y):
    """elementwise divmod, using floor_div and mod_check"""
3564    return floor_div(x, y), mod_check(x, y)
3565
3566
3567@_scal_elemwise
3568def add(a, *other_terms):
3569    """elementwise addition"""
3570    # see decorator for function body
3571
3572
3573@_scal_elemwise
3574def sub(a, b):
3575    """elementwise subtraction"""
3576    # see decorator for function body
3577
3578
3579@_scal_elemwise
3580def mul(a, *other_terms):
3581    """elementwise multiplication"""
3582    # see decorator for function body
3583
3584
3585@_scal_elemwise
3586def true_div(a, b):
3587    """elementwise [true] division (inverse of multiplication)"""
3588    # see decorator for function body
3589
3590
3591@_scal_elemwise
3592def int_div(a, b):
3593    """elementwise [floor] division (inverse of multiplication)"""
3594    # see decorator for function body
3595
3596
3597# floor_div and int_div are the same thing
3598floor_div = int_div
3599
3600
3601def ceil_intdiv(a, b):
3602    """
3603    Safely compute ceil(float_division(a, b)).
3604
3605    Works for all dtypes, but mostly useful when a and b are int.
3606
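    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> a = theano.tensor.lscalar('a')
    >>> b = theano.tensor.lscalar('b')
    >>> c = theano.tensor.ceil_intdiv(a, b)
    >>> int(c.eval({a: 7, b: 2}))  # ceil(7 / 2)
    4
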
3607    """
3608    # If a and b are int with not many significant bits, we could
3609    # cast them to float to avoid doing the modulo. We do not know if this
3610    # is faster or not. But this is not safe for int64 as the cast will
3611    # lose precision.
3612    # e.g.: cast(cast(a, scalar.upcast(a, 'float32')) / b, scal.upcast(a, b))
3613
3614    # We cast for the case when a and b are uint*. Otherwise neq will
3615    # force their upcast to int.
3616    div = int_div(a, b)
3617    ret = cast(neq(a % b, 0), div.dtype) + div
    assert ret.dtype == scal.upcast(div.owner.inputs[0].type.dtype,
                                    div.owner.inputs[1].type.dtype)
3619    return ret
3620
3621
3622def mod_check(x, y):
3623    """Make sure we do not try to use complex numbers."""
3624    if ((as_tensor_variable(x).dtype in complex_dtypes or
3625         as_tensor_variable(y).dtype in complex_dtypes)):
3626        # Currently forbidden.
3627        raise scal.Mod.complex_error
3628    else:
3629        return mod(x, y)
3630
3631
3632@_scal_elemwise
3633def mod(a, b):
3634    """elementwise modulo"""
3635    # see decorator for function body
3636
3637
3638@_scal_elemwise
3639def pow(a, b):
3640    """elementwise power"""
3641    # see decorator for function body
3642
3643
3644@_scal_elemwise
3645def clip(x, min, max):
3646    """
3647    Clip x to be between min and max.
3648
3649    Notes
3650    -----
3651    When `x` is equal to the boundaries, the output is considered
3652    to be `x`, so at these points, the gradient of the cost wrt the output
3653    will be propagated to `x`, not to `min` nor `max`. In other words,
3654    on these points, the gradient wrt `x` will be equal to the gradient wrt
3655    the output, and the gradient wrt `min` and `max` will be zero.
3656
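    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.vector('x')
    >>> y = theano.tensor.clip(x, 0., 1.)  # elementwise min(max(x_i, 0.), 1.)
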
3657    """
3658    # see decorator for function body
3659    # for grep: clamp, bound
3660
3661pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
3662pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
3663pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
3664pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
3665pprint.assign(true_div, printing.OperatorPrinter('/', -1, 'left'))
3666pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left'))
3667pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
3668
3669
3670##########################
3671# View Operations
3672##########################
3673
3674
3675def extract_constant(x, elemwise=True, only_process_constants=False):
3676    """
3677    This function is basically a call to tensor.get_scalar_constant_value.
3678
    The main difference is the behaviour in case of failure. While
    get_scalar_constant_value raises a NotScalarConstantError, this function
    returns x as a tensor if possible. If x is a ScalarVariable from a
3682    scalar_from_tensor, we remove the conversion. If x is just a
3683    ScalarVariable, we convert it to a tensor with tensor_from_scalar.
3684
3685    """
3686    try:
3687        x = get_scalar_constant_value(x,
3688                                      elemwise,
3689                                      only_process_constants)
3690    except NotScalarConstantError:
3691        pass
3692    if ((isinstance(x, scal.ScalarVariable) or
3693         isinstance(x, scal.sharedvar.ScalarSharedVariable))):
3694        if x.owner and isinstance(x.owner.op, ScalarFromTensor):
3695            x = x.owner.inputs[0]
3696        else:
3697            x = tensor_from_scalar(x)
3698    return x
3699
3700
3701def transpose(x, axes=None):
3702    """
3703    Reorder the dimensions of x. (Default: reverse them)
3704
3705    This is a macro around dimshuffle that matches the numpy.transpose function.
3706
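    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> y = theano.tensor.transpose(x)               # same as x.T
    >>> z = theano.tensor.transpose(x, axes=[0, 1])  # explicit (identity) permutation
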
3707    """
3708    if axes is None:
3709        axes = list(range((x.ndim - 1), -1, -1))
3710    ret = DimShuffle(x.broadcastable, axes)(x)
3711    if x.name and axes == list(range((x.ndim - 1), -1, -1)):
3712        ret.name = x.name + '.T'
3713    return ret
3714
3715
3716def batched_dot(a, b):
3717    """
3718    Compute the batched dot product of two variables:
3719
3720        batched_dot(a, b)[i] = dot(a[i], b[i])
3721
3722    Note that this batched_dot function does one of three things, in the
3723    following sequence:
3724
3725        1.  If either a or b is a vector, it returns the batched elementwise
3726            product without calling the Theano BatchedDot op.
3727
3728        2.  If both a and b have either 2 or 3 dimensions, it calls Theano's
3729            BatchedDot op on a and b.
3730
3731        3.  If either a or b has more than 3 dimensions, it calls Theano's
3732            batched_tensordot function with appropriate axes. The
3733            batched_tensordot function expresses high-dimensional batched
3734            dot products in terms of batched matrix-matrix dot products, so
            it may be possible to optimize it further for performance.
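
    Examples
    --------
    An illustrative sketch of case 2, batched matrix products (the variable
    names are only examples):

    >>> a3 = theano.tensor.tensor3('a3')        # shape (batch, n, k)
    >>> b3 = theano.tensor.tensor3('b3')        # shape (batch, k, m)
    >>> c3 = theano.tensor.batched_dot(a3, b3)  # shape (batch, n, m)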
3736    """
3737    a, b = as_tensor_variable(a), as_tensor_variable(b)
3738
3739    if a.ndim == 0:
3740        raise TypeError("a must have at least one (batch) axis")
3741    elif b.ndim == 0:
3742        raise TypeError("b must have at least one (batch) axis")
3743    elif a.ndim == 1:
3744        return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b
3745    elif b.ndim == 1:
3746        return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1)))
3747    elif a.ndim > 3 or b.ndim > 3:
3748        return batched_tensordot(
3749            a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
3750    else:
3751        # avoid circular import
3752        return theano.tensor.blas.BatchedDot()(a, b)
3753
3754
3755def batched_tensordot(x, y, axes=2):
3756    """
3757    Compute a batched tensordot product.
3758
3759    A hybrid of batched_dot and tensordot, this function computes the
3760    tensordot product between the two tensors, by iterating over the
3761    first dimension to perform a sequence of tensordots.
3762
3763    Parameters
3764    ----------
3765    x : tensor
        A tensor with shape, e.g. for the 3-D case, (dim1, dim3, dim2).
    y : tensor
        A tensor with shape, e.g. for the 3-D case, (dim1, dim2, dim4).
3769    axes: int or array-like of length 2
3770        If an integer, the number of axes to sum over.
3771        If an array, it must have two array elements containing the axes to sum
3772        over in each tensor.
3773
3774        If an integer i, it is converted to an array containing
3775        the last i dimensions of the first tensor and the first
3776        i dimensions of the second tensor (excluding the first
3777        (batch) dimension):
            axes = [list(range(a.ndim - i, a.ndim)), list(range(1, i + 1))]
3779
3780        If an array, its two elements must contain compatible axes
3781        of the two tensors. For example, [[1, 2], [2, 4]] means sum
3782        over the 2nd and 3rd axes of a and the 3rd and 5th axes of b.
3783        (Remember axes are zero-indexed!) The 2nd axis of a and the
3784        3rd axis of b must have the same shape; the same is true for
3785        the 3rd axis of a and the 5th axis of b.
3786
3787    Like tensordot, this function uses a series of dimshuffles and
3788    reshapes to reduce the tensor dot product to a matrix or vector
3789    dot product.  Finally, it calls batched_dot to compute the result.
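
    Examples
    --------
    An illustrative sketch (shapes and names are hypothetical):

    >>> u = theano.tensor.tensor3('u')  # shape (batch, d1, d2)
    >>> v = theano.tensor.tensor3('v')  # shape (batch, d2, d3)
    >>> w = theano.tensor.batched_tensordot(u, v, axes=[[2], [1]])  # (batch, d1, d3)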
3790    """
3791    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
3792
3793
3794def split(x, splits_size, n_splits, axis=0):
3795    the_split = Split(n_splits)
3796    return the_split(x, axis, splits_size)
3797
3798
3799class Split(Op):
3800    """Partition a `TensorVariable` along some axis.
3801
3802    Examples
3803    --------
    >>> x = vector()
    >>> splits = lvector()

    You have to declare right away how many split_points there will be.

    >>> ra, rb, rc = split(x, splits, n_splits=3, axis=0)
    >>> f = function([x, splits], [ra, rb, rc])
    >>> a, b, c = f([0, 1, 2, 3, 4, 5], [3, 2, 1])

    a == [0, 1, 2]
    b == [3, 4]
    c == [5]
3813
3814    """
3815
3816    len_splits = None
3817    """A Split instance will have this many outputs, and require that
3818    the splits argument to `perform` have exactly this many elements.
3819    """
3820    __props__ = ("len_splits",)
3821
3822    def __init__(self, len_splits):
3823        self.len_splits = int(len_splits)
3824
3825    def __str__(self):
3826        return self.__class__.__name__ + "{%s}" % self.len_splits
3827
3828    def make_node(self, x, axis, splits):
3829        """WRITEME"""
3830        x = as_tensor_variable(x)
3831        axis = as_tensor_variable(axis)
3832        splits = as_tensor_variable(splits)
3833
3834        if splits.type not in int_vector_types:
3835            raise TypeError('splits must have type tensor.lvector',
3836                            splits.type)
3837        if axis.type not in int_types:
3838            raise TypeError('axis must have type lscalar', axis.type)
3839
3840#         # The following lines are necessary if we allow splits of zero
3841#         if isinstance(axis, gof.Constant):
3842#             x = unbroadcast(x, int(axis.data))
3843#         else:
3844#             x = unbroadcast(x, *range(x.type.ndim))
3845
3846        inputs = [x, axis, splits]
3847        outputs = [x.type() for i in xrange(self.len_splits)]
3848
3849        return Apply(self, inputs, outputs)
3850
3851    def perform(self, node, inputs, outputs):
3852        """WRITEME"""
3853        x, axis, splits = inputs
        # In Python 2.4, x.shape[numpy.asarray(1)] does not work.
3855        if sys.version_info[0:2] == (2, 4) and axis.size == 1:
3856            axis = int(axis)
3857
3858        try:
3859            len_along_axis = x.shape[axis]
3860        except Exception:
3861            raise ValueError('Split.perform() with axis=(%s) is invalid'
3862                             ' for x.shape==(%s)'
3863                             % (axis, x.shape))
3864        if len(splits) != self.len_splits:
3865            raise ValueError('In Split.perform(), len(splits) != len_splits.',
3866                             (len(splits), self.len_splits))
3867
3868        if np.sum(splits) != len_along_axis:
3869            raise ValueError('The splits sum to %s, expected %s' %
3870                             (np.sum(splits), len_along_axis))
3871        if python_any([nb < 0 for nb in splits]):
3872            raise ValueError('Split: you tried to make an ndarray with a '
3873                             'negative number of elements.')
3874
3875        # Checking is done, let's roll the splitting algorithm!
3876        # Basically we step along the given axis of x, extracting
3877        # subtensors of size splits[i] as we go along.
3878
3879        general_key = [slice(None, None, None) for s in x.shape]
3880        lower_idx = 0
3881        for i in xrange(self.len_splits):
3882            upper_idx = lower_idx + splits[i]
3883            general_key[axis] = slice(lower_idx, upper_idx, None)
3884            outputs[i][0] = x.__getitem__(tuple(general_key)).copy()
3885            lower_idx = upper_idx
3886
3887    def infer_shape(self, node, in_shapes):
3888        axis = node.inputs[1]
3889        splits = node.inputs[2]
3890        shp_x, shp_axis, shp_splits = in_shapes
3891        out_shapes = []
3892        for i in xrange(self.len_splits):
3893            temp = as_tensor_variable(shp_x)
3894            temp = theano.tensor.subtensor.set_subtensor(temp[axis], splits[i])
3895            temp = [temp[i] for i in xrange(len(shp_x))]
3896            out_shapes.append(temp)
3897        return out_shapes
3898
3899    def grad(self, inputs, g_outputs):
3900        """Join the gradients along the axis that was used to split x."""
3901        x, axis, n = inputs
3902        outputs = self(*inputs, **dict(return_list=True))
3903        # If all the output gradients are disconnected, then so are the inputs
3904        if python_all([isinstance(g.type, DisconnectedType)
3905                       for g in g_outputs]):
3906            return [DisconnectedType()(),
3907                    grad_undefined(self, 1, axis),
3908                    grad_undefined(self, 2, n)]
3909        # Else, we have to make them zeros before joining them
3910        new_g_outputs = []
3911        for o, g in zip(outputs, g_outputs):
3912            if isinstance(g.type, DisconnectedType):
3913                new_g_outputs.append(o.zeros_like())
3914            else:
3915                new_g_outputs.append(g)
3916
3917        return [join(axis, *new_g_outputs),
3918                grad_undefined(self, 1, axis),
3919                grad_undefined(self, 2, n)]
3920
3921    def R_op(self, inputs, eval_points):
3922        if eval_points[0] is None:
            return [None for i in xrange(self.len_splits)]
3924        return self.make_node(eval_points[0], *inputs[1:]).outputs
3925
3926    def c_code_cache_version(self):
3927        return (2,)
3928
3929    def c_support_code(self):
3930        return """
3931        /* Return 1 if output has the correct shape. */
3932        int split_output_shape_is_correct (
3933            PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
3934        ) {
3935            return
3936                PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
3937                && memcmp(
3938                    PyArray_DIMS(output),
3939                    PyArray_DIMS(array_to_split),
3940                    axis_to_split * sizeof(npy_intp)
3941                ) == 0
3942                && memcmp(
3943                    PyArray_DIMS(output) + axis_to_split + 1,
3944                    PyArray_DIMS(array_to_split) + axis_to_split + 1,
3945                    (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
3946                ) == 0
3947                && split_size == PyArray_DIM(output, axis_to_split);
3948        }
3949        """
3950
3951    def c_code(self, node, name, inputs, outputs, sub):
3952        if self.len_splits == 0:
            # There are no outputs, so there is nothing to do.
3954            return ''
3955
3956        # outputs_pointers lists the addresses of the pointers to the outputs.
3957        outputs_pointers = '&' + (', &'.join(outputs))
3958        x, axis, splits = inputs
3959        fail = sub['fail']
3960        x_typenum = np.dtype(node.inputs[0].dtype).num
3961        x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
3962        axis_dtype = node.inputs[1].type.dtype_specs()[1]
3963        splits_dtype = node.inputs[2].type.dtype_specs()[1]
3964        expected_splits_count = self.len_splits
3965
3966        return """
3967        int ndim = PyArray_NDIM(%(x)s);
3968        int axis = (int)(*(%(axis_dtype)s*)PyArray_GETPTR1(%(axis)s, 0));
3969        int splits_count = PyArray_DIM(%(splits)s, 0);
3970        npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
3971        npy_intp* split_dims = NULL;
3972        PyObject* split_view = NULL;
3973        npy_intp data_offset;
3974        int i;
3975        PyArrayObject** outputs[] = {%(outputs_pointers)s};
3976
3977        /* Check inputs. */
3978
3979        if (splits_count != %(expected_splits_count)s) {
3980            PyErr_Format(PyExc_ValueError,
3981                "Split: splits count (%%d) != expected count (%%d).", splits_count, %(expected_splits_count)s);
3982            %(fail)s
3983        }
3984
3985        if (axis < 0) {
3986            axis += ndim;
3987        }
3988        if (axis < 0 || axis >= ndim) {
3989            PyErr_Format(PyExc_IndexError, "Split: invalid axis %%d for a %%d-D array.", axis, ndim);
3990            %(fail)s
3991        }
3992        len_along_axis = PyArray_DIM(%(x)s, axis);
3993
3994        for (i = 0; i < splits_count; ++i) {
3995            current_split_length = (npy_intp)(*(%(splits_dtype)s*)PyArray_GETPTR1(%(splits)s, i));
3996            if (current_split_length < 0) {
3997                PyErr_Format(PyExc_ValueError,
3998                    "Split: you try to take a negative number (%%ld) of elements.", current_split_length);
3999                %(fail)s
4000            }
4001            sum_of_splits += current_split_length;
4002        }
4003        if (sum_of_splits != len_along_axis) {
            PyErr_Format(PyExc_ValueError, "Split: the splits sum to %%ld, expected %%ld.", sum_of_splits, len_along_axis);
4005            %(fail)s
4006        }
4007
4008        /* Check outputs. */
4009
4010        split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
4011        if (split_dims == NULL) {
4012            PyErr_NoMemory();
4013            %(fail)s
4014        }
4015
4016        memcpy(split_dims, PyArray_DIMS(%(x)s), ndim * sizeof(npy_intp));
4017
4018        for (i = 0; i < splits_count; ++i) {
4019            PyArrayObject** output = outputs[i];
4020            current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i));
4021            if (*output == NULL || !split_output_shape_is_correct(*output, %(x)s, axis, current_split_length)) {
4022                Py_XDECREF(*output);
4023                split_dims[axis] = current_split_length;
4024                *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, %(x_typenum)s, PyArray_IS_F_CONTIGUOUS(%(x)s));
                if (*output == NULL) {
4026                    PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
4027                    free(split_dims);
4028                    %(fail)s
4029                }
4030            }
4031        }
4032
4033        /* Compute split. */
4034
4035        for (i = 0; i < splits_count; ++i) {
4036            current_split_length = (npy_intp) (* (%(splits_dtype)s*) PyArray_GETPTR1(%(splits)s, i));
4037            data_offset = PyArray_STRIDE(%(x)s, axis) * current_split_start;
4038            split_dims[axis] = current_split_length;
4039            split_view = PyArray_New(&PyArray_Type,
4040                                    ndim, split_dims,
4041                                    %(x_typenum)s,
4042                                    PyArray_STRIDES(%(x)s),
4043                                    PyArray_BYTES(%(x)s) + data_offset,
4044                                    %(x_itemsize)s,
4045                                    PyArray_FLAGS(%(x)s),
4046                                    NULL);
4047            if (split_view == NULL) {
4048                PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
4049                free(split_dims);
4050                %(fail)s
4051            }
4052            if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {
4053                PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
4054                Py_XDECREF(split_view);
4055                free(split_dims);
4056                %(fail)s
4057            }
4058            Py_XDECREF(split_view);
4059            current_split_start += current_split_length;
4060        }
4061
4062        free(split_dims);
4063        """ % locals()
4064
4065
4066def addbroadcast(x, *axes):
4067    """
4068    Make the input broadcastable in the specified axes.
4069
4070    For example, addbroadcast(x, 0) will make the first dimension of
4071    x broadcastable. When performing the function, if the length of
4072    x along that dimension is not 1, a ValueError will be raised.
4073
    We apply the optimization here so as not to pollute the graph,
    especially during the GPU optimization.
4076
4077    Parameters
4078    ----------
4079    x : tensor_like
4080        Input theano tensor.
    axes : an int or an iterable object such as list or tuple of int values
        The dimension(s) along which the tensor x should be broadcastable.
4083        If the length of x along these dimensions is not 1, a ValueError will
4084        be raised.
4085
4086    Returns
4087    -------
4088    tensor
4089        A theano tensor, which is broadcastable along the specified dimensions.
4090
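    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')  # broadcastable: (False, False)
    >>> y = theano.tensor.addbroadcast(x, 0)
    >>> y.broadcastable
    (True, False)
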
4091    """
4092    rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
4093    return theano.tensor.opt.apply_rebroadcast_opt(rval)
4094
4095
4096def unbroadcast(x, *axes):
4097    """
4098    Make the input impossible to broadcast in the specified axes.
4099
    For example, unbroadcast(x, 0) will make the first dimension of
    x unbroadcastable, so its length is no longer assumed to be 1 at
    run time.

    We apply the optimization here so as not to pollute the graph,
    especially during the GPU optimization.
4106
4107    Parameters
4108    ----------
4109    x : tensor_like
4110        Input theano tensor.
    axes : an int or an iterable object such as list or tuple of int values
        The dimension(s) along which the tensor x should be made
        unbroadcastable.
4115
4116    Returns
4117    -------
4118    tensor
4119        A theano tensor, which is unbroadcastable along the specified dimensions.
4120
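    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> r = theano.tensor.row('r')  # broadcastable: (True, False)
    >>> y = theano.tensor.unbroadcast(r, 0)
    >>> y.broadcastable
    (False, False)
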
4121    """
4122    rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
4123    return theano.tensor.opt.apply_rebroadcast_opt(rval)
4124
4125
4126def patternbroadcast(x, broadcastable):
4127    """
4128    Make the input adopt a specific broadcasting pattern.
4129
4130    Broadcastable must be iterable. For example,
4131    patternbroadcast(x, (True, False)) will make the first
4132    dimension of x broadcastable and the second dimension
4133    not broadcastable, so x will now be a row.
4134
    We apply the optimization here so as not to pollute the graph,
    especially during the GPU optimization.
4137
4138    Parameters
4139    ----------
4140    x : tensor_like
4141        Input theano tensor.
4142    broadcastable : an iterable object such as list or tuple of bool values
4143        A set of boolean values indicating whether a dimension should be
4144        broadcastable or not. If the length of x along these dimensions is
4145        not 1, a ValueError will be raised.
4146
4147    Returns
4148    -------
4149    tensor
        A theano tensor with the specified broadcastable pattern.
4151
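    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> y = theano.tensor.patternbroadcast(x, (True, False))
    >>> y.broadcastable
    (True, False)
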
4152    """
4153    rval = Rebroadcast(*[(i, broadcastable[i])
4154                         for i in xrange(len(broadcastable))])(x)
4155    return theano.tensor.opt.apply_rebroadcast_opt(rval)
4156
4157
4158class Join(Op):
4159    """
4160    Concatenate multiple `TensorVariable`s along some axis.
4161
4162    The axis must be given as first argument. All tensors must have the same
4163    shape along all dimensions other than this axis.
4164    Of course, TensorVariable instances do not have a shape, so this error
4165    cannot be caught until runtime.  See `perform()`.
4166
4167    See Also
4168    --------
4169    stack : For joins involving scalar values
4170
4171    Examples
4172    --------
4173    >>> x, y, z = tensor.matrix(), tensor.matrix(), tensor.matrix()
4174    >>> u = tensor.vector()
4175
4176    >>> r = join(0, x, y, z)
4177    >>> c = join(1, x, y, z)
4178    >>> join(2, x, y, z)    # WRONG: the axis has to be an index into the shape
4179    >>> join(0, x, u)       # WRONG: joined tensors must have the same rank
4180
4181    """
4182    check_input = False
4183    __props__ = ("view",)
4184
4185    def __init__(self, view=-1):
4186        self.view = view
4187        if view != -1:
4188            # since the first input is always the axis, the tensors
4189            # start from index 1.
4190            self.view_map = {0: [1 + view]}
4191
4192    def __str__(self):
4193        if self.view == -1:
4194            return self.__class__.__name__
4195        else:
4196            return "%s{%s}" % (
4197                self.__class__.__name__,
4198                ", ".join("%s=%r" % (p, getattr(self, p))
4199                          for p in self.__props__))
4200
4201    def __setstate__(self, d):
4202        self.__dict__.update(d)
4203        if not hasattr(self, "view"):
4204            self.view = -1
4205
4206    def make_node(self, *axis_and_tensors):
4207        """
4208        Parameters
4209        ----------
4210        axis: an Int or integer-valued Variable
4211        tensors
4212            A variable number (but not zero) of tensors to
4213            concatenate along the specified axis.  These tensors must have
4214            the same shape along all dimensions other than this axis.
4215
4216        Returns
4217        -------
4218        A symbolic Variable
4219            It has the same ndim as the input tensors, and the most inclusive
4220            dtype.
4221
4222        """
4223        axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
4224        if not tensors:
4225            raise ValueError('Cannot join an empty list of tensors')
4226        as_tensor_variable_args = [as_tensor_variable(x) for x in tensors]
4227
4228        dtypes = [x.type.dtype for x in as_tensor_variable_args]
4229        out_dtype = scal.upcast(*dtypes)
4230
4231        def output_maker(bcastable):
4232            return tensor(dtype=out_dtype, broadcastable=bcastable)
4233
4234        return self._make_node_internal(
4235            axis, tensors, as_tensor_variable_args, output_maker)
4236
4237    def _make_node_internal(self, axis, tensors,
4238                            as_tensor_variable_args, output_maker):
4239        if not python_all(targs.type.ndim for targs
4240                          in as_tensor_variable_args):
4241            raise TypeError('Join cannot handle arguments of dimension 0.'
4242                            ' For joining scalar values, see @stack')
4243        # Handle single-tensor joins immediately.
4244        if len(as_tensor_variable_args) == 1:
4245            bcastable = list(as_tensor_variable_args[0].type.broadcastable)
4246        else:
4247            # When the axis is fixed, a dimension should be
4248            # broadcastable if at least one of the inputs is
4249            # broadcastable on that dimension (see justification below),
4250            # except for the axis dimension.
4251            # Initialize bcastable all false, and then fill in some trues with
4252            # the loops.
4253            bcastable = [False] * len(
4254                as_tensor_variable_args[0].type.broadcastable)
4255            ndim = len(bcastable)
4256            # Axis can also be a constant
4257            if not isinstance(axis, integer_types):
4258                try:
4259                    # Note : `get_scalar_constant_value` returns a ndarray not
4260                    # an int
4261                    axis = int(get_scalar_constant_value(axis))
4262
4263                except NotScalarConstantError:
4264                    pass
4265            if isinstance(axis, integer_types):
4266                # Basically, broadcastable -> length 1, but the
4267                # converse does not hold. So we permit e.g. T/F/T
4268                # joins, and if they fail at runtime they fail, but if
4269                # they don't then it means that the argument where
4270                # that broadcastable flag was False had length 1 along
4271                # this dimension, and therefore this dimension should
4272                # be broadcastable for the output.
4273
4274                if axis < -ndim:
4275                    raise IndexError("Join axis %d out of bounds [0, %d)" %
4276                                     (axis, ndim))
4277                if axis < 0:
4278                    axis += ndim
4279
4280                for x in as_tensor_variable_args:
4281                    for current_axis, bflag in enumerate(x.type.broadcastable):
4282                        # Constant negative axis can no longer be negative at
                        # this point. It is safe to compare this way.
4284                        if current_axis == axis:
4285                            continue
4286                        if bflag:
4287                            bcastable[current_axis] = True
4288                try:
4289                    bcastable[axis] = False
4290                except IndexError:
4291                    raise ValueError('Join argument "axis" is out of range'
4292                                     ' (given input dimensions)')
4293            else:
4294                # When the axis may vary, no dimension can be guaranteed to be
4295                # broadcastable.
4296                bcastable = [False] * len(
4297                    as_tensor_variable_args[0].type.broadcastable)
4298
4299        if not python_all([x.ndim == len(bcastable)
4300                           for x in as_tensor_variable_args[1:]]):
4301            raise TypeError("Join() can only join tensors with the same "
4302                            "number of dimensions.")
4303
4304        inputs = [as_tensor_variable(axis)] + list(as_tensor_variable_args)
4305        if inputs[0].type not in int_types:
4306            raise TypeError('Axis could not be cast to an integer type',
4307                            axis, inputs[0].type, int_types)
4308
4309        outputs = [output_maker(bcastable)]
4310
4311        node = Apply(self, inputs, outputs)
4312        return node
4313
4314    def perform(self, node, axis_and_tensors, out_):
4315        out, = out_
4316        view = self.view
4317        axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
        # With a view, check whether all the other tensors are empty along
        # the join axis.
4319        if (view != -1) and np.all(
4320                [tensor.shape[axis] == 0 for tensor in
4321                 tensors[0:view] + tensors[view + 1:]]):
4322            out[0] = tensors[view]
4323
4324        else:
4325            ndim = tensors[0].ndim
4326            if axis < -ndim:
4327                raise IndexError("Join axis %d out of bounds [0, %d)" %
4328                                 (axis, ndim))
4329
4330            out[0] = theano._asarray(np.concatenate(tensors, axis=axis),
4331                                     dtype=node.outputs[0].type.dtype)
4332
4333    def c_code_cache_version(self):
4334        return (5,)
4335
4336    def c_code(self, node, name, inputs, outputs, sub):
4337        axis, tensors = inputs[0], inputs[1:]
4338        view = self.view
4339        non_empty_tensor = tensors[view]
4340        input_1 = tensors[0]
4341        l = len(tensors)
4342        out, = outputs
4343        fail = sub['fail']
4344        adtype = node.inputs[0].type.dtype_specs()[1]
4345        copy_to_list = []
4346
4347        for i, inp in enumerate(tensors):
4348            copy_to_list.append(
4349                """Py_INCREF(%s);
4350                   PyList_SetItem(list, %s, (PyObject*)%s);"""
4351                % (inp, i, inp))
4352
4353        copy_inputs_to_list = '\n'.join(copy_to_list)
4354        n = len(tensors)
4355
4356        code = """
4357        int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
4358        PyObject* list = PyList_New(%(l)s);
4359        %(copy_inputs_to_list)s
4360        int tensors_lens_sum;
4361        if(%(view)s != -1) {
4362            tensors_lens_sum = 0;
4363
4364            for(int i=0; i < %(n)s; i++){
4365                tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
4366            }
4367            tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
4368        }
4369        if(%(view)s != -1 && tensors_lens_sum == 0) {
4370            Py_XDECREF(%(out)s);
4371            Py_INCREF(%(non_empty_tensor)s);
4372            %(out)s = %(non_empty_tensor)s;
4373        }else{
4374            //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
4375            int ndim = PyArray_NDIM(%(input_1)s);
4376            if( axis < -ndim ){
4377                PyErr_Format(PyExc_IndexError,
4378                             "Join axis %%d out of bounds [0, %%d)", axis, ndim);
4379                %(fail)s
4380            }
4381            Py_XDECREF(%(out)s);
4382            %(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
4383            Py_DECREF(list);
4384            if(!%(out)s){
4385                %(fail)s
4386            }
4387        }
4388        """ % locals()
4389        return code
4390
4391    def R_op(self, inputs, eval_points):
4392        if None in eval_points[1:]:
4393            return [None]
4394        return self.make_node(inputs[0], *eval_points[1:]).outputs
4395
4396    def grad(self, axis_and_tensors, grads):
4397        """ The gradient wrt a join op is a `Split`, used to partition
4398        the gradient along the `axis` which was used for joining.
4399        """
4400        gz, = grads
4401        axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
4402
4403        rval = [grad_undefined(self, 0, axis)]
4404
4405        dtypes = [as_tensor_variable(x).type.dtype for x in tensors]
4406        out_dtype = scal.upcast(*dtypes)
4407
4408        if 'float' in out_dtype or 'complex' in out_dtype:
4409            # assume that this is differentiable
4410            split = Split(len(tensors))
4411            split_gz = split(gz, axis, stack([shape(x)[axis]
4412                                              for x in tensors]))
4413            # If there is only one split, it might not be in a list.
4414            if not isinstance(split_gz, list):
4415                split_gz = [split_gz]
            # Split.make_node isn't always able to infer the right
            # broadcast pattern. As the gradient needs to keep that
            # information, restore it if needed.
4419            split_gz = [patternbroadcast(g, t.broadcastable)
4420                        for t, g in zip(tensors, split_gz)]
4421            rval = rval + split_gz
4422        else:
4423            # the output has integer type, so the gradient through it
4424            # is 0
4425            rval = rval + [tensor.zeros_like(dtype=config.floatX)
4426                           for tensor in tensors]
4427
4428        return rval
4429
4430    def infer_shape(self, node, ishapes):
4431        # ishapes[0] contains the size of the axis on which we join
4432        # Join op should get at least one input to join
4433        assert len(ishapes) > 1
4434        n_dim = len(ishapes[1])
4435        for shp in ishapes[1:]:
4436            assert shp is not None
4437            assert len(shp) == n_dim
4438
4439        # The joining dimension could be negative, but we need it to be
4440        # in [0, n_dim) in the loop below.
        # An axis < -n_dim or >= n_dim would be invalid, but this is
4442        # not checked here. An Assert op would be a way of addressing that,
4443        # but it may disrupt optimizations.
4444        join_dim = switch(ge(node.inputs[0], 0),
4445                          node.inputs[0],
4446                          node.inputs[0] + n_dim)
4447        out_shapes = []
4448        for dim in xrange(n_dim):
            # We have to deal with 2 possible cases here:
            #   a) the dimension on which we join (called t_side, from the
            #      true side of the `switch` below, which compares the
            #      current dimension with the joining dimension);
            #   b) a non-joining dimension (in which case a symbolic
            #      assertion could be used to make sure all tensors have the
            #      same number of elements on this non-joined dimension;
            #      this is f_side).
4457            # initialize
4458            t_side = ishapes[1][dim]
4459            f_side = ishapes[1][dim]
4460            # loop over tensors and sum for the joining dimension
4461            for shp in ishapes[2:]:
4462                t_side = t_side + shp[dim]
4463            # return the dimensions found
4464            out_shapes.append(switch(eq(dim, join_dim),
4465                              t_side, f_side))
4466
4467        return [tuple(out_shapes)]
4468
4469
4470join_ = Join()
4471pprint.assign(Join, printing.FunctionPrinter('join'))
4472
4473
4474def join(axis, *tensors_list):
4475    """
4476    Convenience function to concatenate `TensorType`s along the given axis.
4477
    This function will not add the op in the graph when it is not useful.
    For example, if the list of tensors to concatenate contains a single
    tensor, it will just return that tensor.
4481
4482    Parameters
4483    ----------
4484    tensors : list of tensors (or list-like)
4485        A list of tensors to be concatenated along the given axis.
4486        The shapes of the tensors to be concatenated must be all
4487        identical, except in the dimension (`axis`) on which they are to
4488        be joined.
4489    axis : int (symbolic or literal)
4490        On which dimension should the tensors be joined?  The `axis`
4491        must be a valid index into the shape of the tensors to be
4492        concatenated.
4493        The `axis` parameter may either be an integer or an object that
4494        can be converted to a scalar using `as_scalar`(`axis`). In the
4495        former case, the axis is fixed at construction, while in the
4496        latter it may vary over time depending on the value of the
4497        `axis` variable.
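
    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> y = theano.tensor.matrix('y')
    >>> z = theano.tensor.join(0, x, y)  # the rows of y appended below those of x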
4498    """
4499    if len(tensors_list) == 1:
4500        return tensors_list[0]
4501    else:
4502        return join_(axis, *tensors_list)
4503
4504
4505def roll(x, shift, axis=None):
4506    """
4507    Convenience function to roll TensorTypes along the given axis.
4508
    The syntax mirrors that of the numpy.roll function.
4510
4511    Parameters
4512    ----------
4513    x : tensor_like
4514        Input tensor.
4515    shift : int (symbolic or literal)
4516        The number of places by which elements are shifted.
4517    axis : int (symbolic or literal), optional
4518        The axis along which elements are shifted. By default, the array
4519        is flattened before shifting, after which the original
4520        shape is restored.
4521
4522    Returns
4523    -------
4524    tensor
4525        Output tensor, with the same shape as ``x``.
4526
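    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.vector('x')
    >>> y = theano.tensor.roll(x, 2)  # [x[-2], x[-1], x[0], x[1], ...]
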
4527    """
4528    if axis is None:
4529        if x.ndim > 1:
4530            y = x.flatten()
4531            return roll(y, shift, axis=0).reshape(x.shape)
4532        else:
4533            axis = 0
4534
4535    if axis < 0:
4536        axis += x.ndim
4537
4538    # Shift may be larger than the size of the axis. If so, since the
4539    # roll operation is cyclic, we can take the shift modulo the size
4540    # of the axis
4541    shift = shift % x.shape[axis]
4542
4543    # A slice of all elements in a dimension ':'
4544    allslice = slice(None)
4545    # List of slices describing the front half [:, :, shift:, :]
4546    front_slice = slice(-shift, None)
4547    front_list = ([allslice] * axis + [front_slice] +
4548                  [allslice] * (x.ndim - axis - 1))
4549    # List of slices describing the back half [:, :, :shift, :]
4550    end_slice = slice(0, -shift)
4551    end_list = ([allslice] * axis + [end_slice] +
4552                [allslice] * (x.ndim - axis - 1))
4553    return join(axis,
4554                x.__getitem__(tuple(front_list)),
4555                x.__getitem__(tuple(end_list)))
4556
4557
4558@constructor
4559def shape_padleft(t, n_ones=1):
4560    """Reshape `t` by left-padding the shape with `n_ones` 1s.
4561
4562    See Also
4563    --------
4564    shape_padaxis
4565    shape_padright
4566    Dimshuffle
4567
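    Examples
    --------
    A minimal illustrative sketch (the variable name is only an example):

    >>> x = theano.tensor.matrix('x')  # broadcastable: (False, False)
    >>> theano.tensor.shape_padleft(x, 2).broadcastable
    (True, True, False, False)
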
4568    """
4569    _t = as_tensor_variable(t)
4570
4571    pattern = ['x'] * n_ones + [i for i in xrange(_t.type.ndim)]
4572    return DimShuffle(_t.broadcastable, pattern)(_t)
4573
4574
4575@constructor
4576def shape_padright(t, n_ones=1):
4577    """Reshape `t` by right-padding the shape with `n_ones` 1s.
4578
4579    See Also
4580    --------
4581    shape_padaxis
4582    shape_padleft
4583    Dimshuffle
4584
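    Examples
    --------
    A minimal illustrative sketch (the variable name is only an example):

    >>> x = theano.tensor.matrix('x')  # broadcastable: (False, False)
    >>> theano.tensor.shape_padright(x).broadcastable
    (False, False, True)
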
4585    """
4586    _t = as_tensor_variable(t)
4587
4588    pattern = [i for i in xrange(_t.type.ndim)] + ['x'] * n_ones
4589    return DimShuffle(_t.broadcastable, pattern)(_t)
4590
4591
4592@constructor
4593def shape_padaxis(t, axis):
4594    """Reshape `t` by inserting 1 at the dimension `axis`.
4595
4596    Example
4597    -------
4598    >>> tensor = theano.tensor.tensor3()
4599    >>> theano.tensor.shape_padaxis(tensor, axis=0)
4600    DimShuffle{x,0,1,2}.0
4601    >>> theano.tensor.shape_padaxis(tensor, axis=1)
4602    DimShuffle{0,x,1,2}.0
4603    >>> theano.tensor.shape_padaxis(tensor, axis=3)
4604    DimShuffle{0,1,2,x}.0
4605    >>> theano.tensor.shape_padaxis(tensor, axis=-1)
4606    DimShuffle{0,1,2,x}.0
4607
4608    See Also
4609    --------
4610    shape_padleft
4611    shape_padright
4612    Dimshuffle
4613
4614    """
4615    _t = as_tensor_variable(t)
4616
4617    ndim = _t.ndim + 1
4618    if not -ndim <= axis < ndim:
4619        msg = 'axis {0} is out of bounds [-{1}, {1})'.format(axis, ndim)
4620        raise IndexError(msg)
4621    if axis < 0:
4622        axis += ndim
4623
4624    pattern = [i for i in xrange(_t.type.ndim)]
4625    pattern.insert(axis, 'x')
4626    return DimShuffle(_t.broadcastable, pattern)(_t)
4627
4628
4629@constructor
4630def stack(*tensors, **kwargs):
4631    """Stack tensors in sequence on given axis (default is 0).
4632
4633    Take a sequence of tensors and stack them on given axis to make a single
4634    tensor. The size in dimension `axis` of the result will be equal to the number
4635    of tensors passed.
4636
    Note: The interface stack(*tensors) is deprecated; you should use
    stack(tensors, axis=0) instead.
4639
4640    Parameters
4641    ----------
4642    tensors : list or tuple of tensors
4643        A list of tensors to be stacked.
4644    axis : int
4645        The index of the new axis. Default value is 0.
4646
4647    Examples
4648    --------
4649    >>> a = theano.tensor.scalar()
4650    >>> b = theano.tensor.scalar()
4651    >>> c = theano.tensor.scalar()
4652    >>> x = theano.tensor.stack([a, b, c])
4653    >>> x.ndim # x is a vector of length 3.
4654    1
4655    >>> a = theano.tensor.tensor4()
4656    >>> b = theano.tensor.tensor4()
4657    >>> c = theano.tensor.tensor4()
4658    >>> x = theano.tensor.stack([a, b, c])
4659    >>> x.ndim # x is a 5d tensor.
4660    5
4661    >>> rval = x.eval(dict((t, np.zeros((2, 2, 2, 2))) for t in [a, b, c]))
4662    >>> rval.shape # 3 tensors are stacked on axis 0
4663    (3, 2, 2, 2, 2)
4664    >>> x = theano.tensor.stack([a, b, c], axis=3)
4665    >>> x.ndim
4666    5
4667    >>> rval = x.eval(dict((t, np.zeros((2, 2, 2, 2))) for t in [a, b, c]))
4668    >>> rval.shape # 3 tensors are stacked on axis 3
4669    (2, 2, 2, 3, 2)
4670    >>> x = theano.tensor.stack([a, b, c], axis=-2)
4671    >>> x.ndim
4672    5
4673    >>> rval = x.eval(dict((t, np.zeros((2, 2, 2, 2))) for t in [a, b, c]))
4674    >>> rval.shape # 3 tensors are stacked on axis -2
4675    (2, 2, 2, 3, 2)
4676    """
4677    # ---> Remove this when moving to the new interface:
4678    if not tensors and not kwargs:
4679        raise Exception('theano.tensor.stack(tensors, axis) must have at least'
4680                        ' one parameter')
4681
4682    if not kwargs and not isinstance(tensors[0], (list, tuple)):
4683        warnings.warn('stack(*tensors) interface is deprecated, use'
4684                      ' stack(tensors, axis=0) instead.', DeprecationWarning,
4685                      stacklevel=3)
4686        axis = 0
4687    elif 'tensors' in kwargs:
4688        tensors = kwargs['tensors']
4689        if 'axis' in kwargs:
4690            axis = kwargs['axis']
4691        else:
4692            axis = 0
4693    else:
4694        if len(tensors) == 2:
4695            axis = tensors[1]
4696        elif 'axis' in kwargs:
4697            axis = kwargs['axis']
4698        else:
4699            axis = 0
4700        tensors = tensors[0]
4701    # <--- Until here.
4702
4703    if len(tensors) == 0:
4704        raise Exception('tensors is empty. You should at least provide one'
4705                        ' tensor to theano.tensor.stack(tensors, axis).')
4706
4707    # If all tensors are scalars of the same type, call make_vector.
4708    # It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
4709
4710    # This should be an optimization!
    # Doing it here makes the graph less canonicalized
    # (more types need to be understood by all optimizations),
    # and DebugMode can't detect errors in this code as it is not in an
    # optimization.
4715    # See ticket #660
4716    if np.all(
4717        [  # in case there is direct int in tensors.
4718            isinstance(t, (np.number, float, integer_types,
4719                           python_complex)) or
4720            (isinstance(t, Variable) and
4721             isinstance(t.type, TensorType) and
4722             t.ndim == 0)
4723            for t in tensors]):
4724        # in case there is direct int
4725        tensors = list(map(as_tensor_variable, tensors))
4726        dtype = scal.upcast(*[i.dtype for i in tensors])
4727        return theano.tensor.opt.MakeVector(dtype)(*tensors)
4728    return join(axis, *[shape_padaxis(t, axis) for t in tensors])
4729
4730
4731@constructor
4732def concatenate(tensor_list, axis=0):
4733    """Alias for `join`(axis, *tensor_list).
4734
4735    This function is similar to `join`, but uses the signature of
4736    numpy's concatenate function.
4737
4738    Raises
4739    ------
4740    TypeError
4741        The tensor_list must be a tuple or list.
4742
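    Examples
    --------
    A minimal illustrative sketch (the variable names are only examples):

    >>> x = theano.tensor.matrix('x')
    >>> y = theano.tensor.matrix('y')
    >>> z = theano.tensor.concatenate((x, y), axis=1)  # note the tuple around x, y
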
4743    """
4744    # Check someone did not make the common mistake to do something like:
4745    #   c = concatenate(x, y)
4746    # instead of
4747    #   c = concatenate((x, y))
4748    if not isinstance(tensor_list, (tuple, list)):
4749        raise TypeError(
4750            "The 'tensors' argument must be either a tuple "
4751            "or a list, make sure you did not forget () or [] around "
4752            "arguments of concatenate.", tensor_list)
4753    return join(axis, *tensor_list)
4754
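# Usage sketch (illustration only): `concatenate` takes a single list/tuple of
# tensors plus an axis, mirroring numpy.concatenate.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     x, y = T.dmatrix('x'), T.dmatrix('y')
#     z = T.concatenate([x, y], axis=1)     # note the list: not concatenate(x, y)
#     f = theano.function([x, y], z)
#     f(np.ones((2, 3)), np.zeros((2, 2))).shape   # -> (2, 5)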
4755
4756def get_vector_length(v):
4757    """Return the run-time length of a symbolic vector.
4758
4759    Parameters
4760    ----------
4761    v
4762        A rank-1 TensorType variable.
4763
4764    Raises
4765    ------
4766    TypeError
        `v` is not of the proper type.
    ValueError
        No special case applies, so the length is not known.
4770        In general this is not possible, but for a number of special cases
4771        the length can be determined at compile / graph-construction time.
4772        This function implements these special cases.
4773
4774    """
4775    v = as_tensor_variable(v)
4776    if v.ndim != 1:
        raise TypeError("argument must be a symbolic vector, got '%s'" %
                        v)
4779    if v.type.broadcastable[0]:
4780        return 1
4781    if isinstance(v, gof.Constant) and v.type.ndim == 1:
4782        return len(v.data)
4783    if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector):
4784        return len(v.owner.inputs)
4785    if v.owner and isinstance(v.owner.op, Shape):
4786        return v.owner.inputs[0].type.ndim
4787    # If we take a slice, we know how many elements it will result in
4788    if ((v.owner and
4789         isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
4790         isinstance(v.owner.op.idx_list[0], slice) and
4791         v.owner.inputs[0].owner and
4792         isinstance(v.owner.inputs[0].owner.op, theano.compile.ops.Shape))):
4793        start = extract_constant(theano.tensor.subtensor.get_idx_list(
4794            v.owner.inputs, v.owner.op.idx_list)[0].start)
4795        stop = extract_constant(theano.tensor.subtensor.get_idx_list(
4796            v.owner.inputs, v.owner.op.idx_list)[0].stop)
4797        step = extract_constant(theano.tensor.subtensor.get_idx_list(
4798            v.owner.inputs, v.owner.op.idx_list)[0].step)
4799
4800        ndim = v.owner.inputs[0].owner.inputs[0].ndim
4801        types = (numbers.Integral, np.integer)
4802        if start is None:
4803            start = 0
4804        elif isinstance(start, types) and start < 0:
4805            start += ndim
4806            if start < 0:
4807                start = 0
4808        if stop is None:
4809            stop = ndim
4810        elif isinstance(stop, types):
4811            if stop > ndim:
4812                stop = ndim
4813            elif stop < 0:
4814                stop += ndim
4815        if step is None:
4816            step = 1
4817
4818        if (isinstance(stop, types) and
4819                isinstance(start, types) and
4820                isinstance(step, types) and
4821                start >= 0 and stop >= 0 and
4822                step > 0 and stop >= start):
4823            return (stop - start - 1) // step + 1
4824    if isinstance(v, Variable):
4825        msg = theano.printing.debugprint(v, file='str')
4826    else:
4827        msg = str(v)
4828    raise ValueError("length not known: %s" % msg)
4829
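# Usage sketch (illustration only) of the special cases handled above: the
# length of a symbolic vector is returned when it can be deduced at graph
# construction time, otherwise a ValueError is raised.
#
#     import theano.tensor as T
#     x = T.dtensor3('x')
#     T.get_vector_length(x.shape)              # -> 3 (Shape of a known-ndim input)
#     T.get_vector_length(T.stack([1, 2, 3]))   # -> 3 (MakeVector case)
#     T.get_vector_length(T.dvector('v'))       # raises ValueError: length not known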
4830
4831@constructor
4832def horizontal_stack(*args):
4833    """
    Horizontally stack two or more L{TensorType}s.

    Stack 2D L{TensorType}s along the second axis (column wise). These
    L{TensorType}s must have the same shape along all dimensions but the
    second.
4839
4840    """
4841    # Note: 'horizontal_stack' and 'vertical_stack' do not behave exactly like
4842    # Numpy's hstack and vstack functions. This is intended, because Numpy's
4843    # functions have potentially confusing/incoherent behavior (try them on 1D
4844    # arrays). If this is fixed in a future version of Numpy, it may be worth
4845    # trying to get closer to Numpy's way of doing things. In the meantime,
4846    # better keep different names to emphasize the implementation divergences.
4847    assert len(args) >= 2
4848    for arg in args:
4849        assert arg.type.ndim == 2
4850    return concatenate(args, axis=1)
4851
4852
4853@constructor
4854def vertical_stack(*args):
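    """
    Vertically stack two or more L{TensorType}s.

    Stack 2D L{TensorType}s along the first axis (row wise). These
    L{TensorType}s must have the same shape along all dimensions but the
    first.

    """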
4855    assert len(args) >= 2
4856    for arg in args:
4857        assert arg.type.ndim == 2
4858    return concatenate(args, axis=0)
4859
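# Usage sketch (illustration only): both helpers expect 2D inputs and reduce
# to `concatenate` along axis 1 and axis 0 respectively.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     a, b = T.dmatrix('a'), T.dmatrix('b')
#     h = T.horizontal_stack(a, b)          # concatenate((a, b), axis=1)
#     v = T.vertical_stack(a, b)            # concatenate((a, b), axis=0)
#     f = theano.function([a, b], [h, v])
#     hv, vv = f(np.ones((2, 2)), np.zeros((2, 2)))
#     hv.shape, vv.shape                    # -> ((2, 4), (4, 2))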
4860
4861class Reshape(Op):
4862    """Perform a reshape operation of the input x to the new shape shp.
    The number of dimensions to which to reshape (ndim) must be
4864    known at graph build time.
4865    """
4866    view_map = {0: [0]}  # output 0 is potentially aliased to inputs [0]
4867    _f16_ok = True
4868
4869    check_input = False
4870    __props__ = ("ndim",)
4871    params_type = ParamsType(ndim=int32)
4872    # name does not participate because it doesn't affect computations
4873
4874    def __init__(self, ndim, name=None):
4875        self.ndim = int(ndim)
        if ndim < 0:
            raise ValueError("The number of dimensions of the reshaped "
                             "output (ndim) must be 0 or greater")
4878        assert name is None, 'name attribute for Reshape has been deprecated'
4879
4880    def __str__(self):
4881        return '%s{%s}' % (self.__class__.__name__, self.ndim)
4882
4883    def make_node(self, x, shp):
4884        x = as_tensor_variable(x)
4885        shp_orig = shp
4886        shp = as_tensor_variable(shp, ndim=1)
4887        if not (shp.dtype in int_dtypes or
4888                (isinstance(shp, TensorConstant) and shp.data.size == 0)):
4889            # It raises an error if shp is not of integer type,
4890            # except when shp is constant and empty
4891            # (in this case, shp.dtype does not matter anymore).
4892            raise TypeError("Shape must be integers", shp, shp.dtype)
4893        assert shp.ndim == 1
4894        if isinstance(shp, TensorConstant):
4895            bcast = [s == 1 for s in shp.data]
4896            return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcast)])
4897        else:
4898            bcasts = [False] * self.ndim
4899            shp_list = shp_orig
4900            if hasattr(shp_orig, "ndim") and shp_orig.ndim == 0:
4901                shp_list = [shp_orig]
4902            for index in xrange(self.ndim):
4903                y = shp_list[index]
4904                y = as_tensor_variable(y)
4905                # Try to see if we can infer that y has a constant value of 1.
4906                # If so, that dimension should be broadcastable.
4907                try:
4908                    bcasts[index] = (
4909                        hasattr(y, 'get_scalar_constant_value') and
4910                        y.get_scalar_constant_value() == 1)
4911                except NotScalarConstantError:
4912                    pass
4913            return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
4914
4915    def perform(self, node, inp, out_, params):
4916        x, shp = inp
4917        out, = out_
4918        if (len(shp) != self.ndim):
4919            raise ValueError('shape argument to Reshape.perform has incorrect'
4920                             ' length %i'
4921                             ', should be %i' % (len(shp), self.ndim), shp)
4922        try:
4923            out[0] = np.reshape(x, shp)
4924        except Exception:
4925            raise ValueError('Cannot reshape input of shape %s to shape %s' %
4926                             (x.shape, shp))
4927
4928    def connection_pattern(self, node):
4929        return [[True], [False]]
4930
4931    def grad(self, inp, grads):
4932        x, shp = inp
4933        g_out, = grads
4934        return [reshape(g_out, shape(x), ndim=x.ndim),
4935                DisconnectedType()()]
4936
4937    def R_op(self, inputs, eval_points):
4938        if eval_points[0] is None:
4939            return [None]
4940        return self(eval_points[0], *inputs[1:], **dict(return_list=True))
4941
4942    def infer_shape(self, node, ishapes):
4943        # inputs[1] can contain at most one value of '-1', meaning the actual
4944        # shape of the output will be automatically computed by reshape, so
4945        # that the total number of elements stays the same.
4946        # TODO: Maybe put that formula here?
4947        # It's not trivial, because we would have to check if the product of
4948        # all the non-minus-one shapes is a divisor of the product of the
4949        # original shapes.
4950
4951        # The following expression leads to cycles in feature_shape,
4952        # because it tries to replace the Shape_i node by the switch
4953        # statement, which depends on Shape_i.
4954        # return [tuple([switch(eq(node.inputs[1][i], -1),
4955        #                      theano.tensor.opt.Shape_i(i)(node.outputs[0]),
4956        #                      node.inputs[1][i])
4957        #                    for i in xrange(self.ndim)]
4958        #    )]
4959
4960        # Here, we only simplify if the shape (node.inputs[1]) is a constant,
4961        # ideally it would suffice to check that it is always non-negative.
4962
4963        # If current variable is a scalar and its dimensionality should
4964        # change to self.ndim, then use size 1 for all new dimensions.
4965        if len(ishapes[0]) == 0:
4966            return [(1,) * self.ndim]
4967
4968        requ = node.inputs[1]
4969        input_size = mul(*ishapes[0])
4970        if isinstance(requ, theano.tensor.TensorConstant):
4971            requ = list(requ.data)
4972            requ_part = [ele for ele in requ if ele != -1]
4973            crit = len(requ) - len(requ_part)
4974            if crit == 1 and len(requ_part) > 0:
4975                # If there are both 0 and -1 in requ_size, it is impossible
4976                # to determine a right output, but we can at least prevent
4977                # a division by 0. We do not want to keep a negative
4978                # size here as it could lead to further weird errors
4979                # after other optimizations.
4980                requ_size = mul(*requ_part)
4981                missing = input_size // (1 if requ_size == 0 else requ_size)
4982                for i, ele in enumerate(requ):
4983                    if ele == -1:
4984                        requ[i] = missing
4985            elif crit == 1:  # we reshape to -1
4986                requ = [input_size] if ishapes[0] else [1]
4987            elif crit > 1:
4988                raise ValueError('shape argument to Reshape.perform'
4989                                 ' must have at most one entry equal to -1')
4990            return [requ]
4991        else:
4992            requ = [requ[i] for i in xrange(self.ndim)]
            # since new_dims can have a negative value (-1), the
            # product of all values is negated
            # to give a positive value.
4996            # To avoid optimization complexity, we avoid checking
4997            # for the case when there are two or more '-1' values.
4998            if self.ndim:
4999                requ_size = -mul(*requ)
5000                # If there are both 0 and -1 in requ_size, it is impossible
5001                # to determine a right output, but we can at least prevent
5002                # a division by 0. We do not want to keep a negative
5003                # size here as it could lead to further weird errors
5004                # after other optimizations.
5005                rest_size = input_size // maximum(requ_size, 1)
5006            return [tuple([switch(eq(requ[i], -1),
5007                                  rest_size,
5008                                  requ[i])
5009                           for i in xrange(self.ndim)])]
5010
5011    def c_code_cache_version(self):
5012        return (8,)
5013
5014    def c_code(self, node, name, inputs, outputs, sub):
5015        if isinstance(node.inputs[0], TensorVariable):
5016            x, shp = inputs
5017            z, = outputs
5018            sdtype = node.inputs[1].type.dtype_specs()[1]
5019            fail = sub['fail']
5020            params = sub['params']
5021            return """
5022            assert (PyArray_NDIM(%(shp)s) == 1);
5023            npy_intp new_dims[%(params)s->ndim];
5024            PyArray_Dims newshape;
5025            newshape.ptr = new_dims;
5026            newshape.len = %(params)s->ndim;
5027            for (int ii = 0; ii < %(params)s->ndim; ++ii)
5028            {
            // -- We do not want an explicit cast here. the shp can be any
            // -- int* dtype. The compiler will implicitly upcast it, but
            // -- will err if this would downcast. This could happen if the
            // -- user passes an int64 dtype, but npy_intp ends up being int32.
5033                new_dims[ii] = ((%(sdtype)s*)(
5034                        PyArray_BYTES(%(shp)s) +
5035                        ii * PyArray_STRIDES(%(shp)s)[0]))[0];
5036            }
5037            Py_XDECREF(%(z)s);
5038            %(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
5039            if (!%(z)s)
5040            {
5041                //The error message should have been set by PyArray_Newshape
5042                %(fail)s;
5043            }
5044            """ % locals()
5045        else:
5046            return Op.c_code(self, node, name, inputs, outputs, sub)
5047
5048
5049def reshape(x, newshape, ndim=None):
5050    if ndim is None:
5051        newshape = as_tensor_variable(newshape)
5052        if newshape.ndim != 1:
5053            raise TypeError(
5054                "New shape in reshape must be a vector or a list/tuple of"
5055                " scalar. Got %s after conversion to a vector." % newshape)
5056        try:
5057            ndim = get_vector_length(newshape)
5058        except ValueError:
5059            raise ValueError(
5060                "The length of the provided shape (%s) cannot "
5061                "be automatically determined, so Theano is not able "
5062                "to know what the number of dimensions of the reshaped "
5063                "variable will be. You can provide the 'ndim' keyword "
5064                "argument to 'reshape' to avoid this problem." % newshape)
5065    op = Reshape(ndim)
5066    rval = op(x, newshape)
5067    return rval
5068
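# Usage sketch (illustration only): when the target shape is a constant
# list/tuple, `ndim` is deduced from its length; with a fully symbolic shape
# vector, `ndim` must be passed explicitly.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     x = T.dmatrix('x')
#     y = T.reshape(x, (-1,))               # ndim deduced as 1
#     theano.function([x], y)(np.arange(6.).reshape(2, 3))
#     # -> array([ 0.,  1.,  2.,  3.,  4.,  5.])
#     s = T.lvector('s')                    # run-time shape, length unknown
#     z = T.reshape(x, s, ndim=2)           # here ndim has to be given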
5069
5070class Flatten(Op):
5071    """
5072    Flatten a tensor.
5073
5074    Flattens a tensor to `outdim` dimensions by preserving the leading
5075    outdim - 1 shape components.
5076
5077    .. note:: The interface Flatten(Op) is deprecated, you should use flatten.
5078    """
5079    view_map = {0: [0]}
5080
5081    check_input = False
5082    __props__ = ("outdim",)
5083
5084    def __init__(self, outdim=1):
5085        warnings.warn(
5086            "Flatten class is deprecated, "
5087            "please use flatten method instead.",
5088            DeprecationWarning,
5089            stacklevel=4)
5090        self.outdim = int(outdim)
5091
5092    def __str__(self):
5093        return '%s{%s}' % (self.__class__.__name__, self.outdim)
5094
5095    def make_node(self, x):
5096        t_x = as_tensor_variable(x)
5097        if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
5098            raise ValueError('invalid output ndimensions (%i) for tensor of '
5099                             'rank %i' % (self.outdim, t_x.ndim))
5100
5101        # Infer the broadcastable pattern of the output. For every dimension
5102        # unaffected by the flatten, the broadcast flag should be unchanged.
5103        # For the dimension resulting from the collapse of other dimensions,
5104        # it should be broadcastable iff all the collapsed dimensions were
5105        # broadcastable.
5106        bcast_kept_dims = x.broadcastable[:self.outdim - 1]
5107        bcast_new_dim = python_all(x.broadcastable[self.outdim - 1:])
5108        broadcastable = bcast_kept_dims + (bcast_new_dim,)
5109
5110        return gof.Apply(self, [t_x], [tensor(x.type.dtype,
5111                                              broadcastable)])
5112
5113    def perform(self, node, inp, out_):
5114        x, = inp
5115        out, = out_
5116        outdim = self.outdim
5117        if outdim == 1:
5118            try:
5119                out[0] = x.reshape(x.size)
5120            except AttributeError:
5121                out[0] = x.reshape((np.prod(x.shape),))
5122        elif outdim == len(x.shape):
5123            out[0] = x
5124        else:
5125            newshape = (x.shape[:outdim - 1] +
5126                        (np.prod(x.shape[outdim - 1:]),))
5127            out[0] = x.reshape(newshape)
5128
5129    def infer_shape(self, node, in_shapes):
5130        in_shp, = in_shapes
5131        part1 = in_shp[:self.outdim - 1]
5132        part2 = in_shp[self.outdim - 1:]
5133
5134        if len(part2) > 1:
5135            part2 = (prod(part2, dtype='int64'),)
5136        elif len(part2) == 1:
5137            # We do not want to force an upcast of part2 if its length is 1
5138            pass
5139        else:
5140            if len(in_shp) == 0 and self.outdim == 1:
5141                part2 = (1,)
5142            else:
5143                raise ValueError('invalid output ndimensions (%i) for tensor '
5144                                 'of rank %i' % (self.outdim, len(in_shp)))
5145
5146        out_shape = (part1 + part2)
5147        return [out_shape]
5148
5149    def grad(self, inp, grads):
5150        x, = inp
5151        g_out, = grads
5152        return [reshape(g_out, shape(x), x.ndim)]
5153
5154    def R_op(self, inputs, eval_points):
5155        if None in eval_points:
5156            return [None]
5157        return self.make_node(*eval_points).outputs
5158
5159    def c_code_cache_version(self):
5160        return (1, 1)
5161
5162    def c_code(self, node, name, inputs, outputs, sub):
5163        x, = inputs
5164        out, = outputs
5165        outdim = self.outdim
5166        fail = sub['fail']
5167        return """
5168        if (%(outdim)s == PyArray_NDIM(%(x)s))
5169        {
5170            Py_XDECREF(%(out)s);
5171            Py_XINCREF(%(x)s);
5172            %(out)s = %(x)s;
5173        }
5174        else
5175        {
5176            Py_XDECREF(%(out)s);
5177
5178            if (%(outdim)s == 1)
5179            {
5180                npy_intp size = PyArray_SIZE(%(x)s);
5181                PyArray_Dims newshape;
5182                newshape.ptr = &size;
5183                newshape.len = 1;
5184                %(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
5185                                                           &newshape,
5186                                                           NPY_CORDER);
5187            }
5188            else
5189            {
5190                npy_intp *oldshape = PyArray_DIMS(%(x)s);
5191                npy_intp newshape_dims[%(outdim)s];
5192
5193                int i;
5194                for (i = 0; i < %(outdim)s - 1; ++i)
5195                    newshape_dims[i] = oldshape[i];
5196
5197                newshape_dims[i] = 1;
5198
5199                for (int j = %(outdim)s - 1; j < PyArray_NDIM(%(x)s); ++j)
5200                    newshape_dims[i] *= oldshape[j];
5201
5202                PyArray_Dims newshape;
5203                newshape.ptr = newshape_dims;
5204                newshape.len = %(outdim)s;
5205                %(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
5206                                                           &newshape,
5207                                                           NPY_CORDER);
5208            }
5209        }
5210        if (!%(out)s)
5211        {
5212            //The error message should have been set by
5213            // PyArray_Newshape
5214            %(fail)s;
5215        }
5216        """ % locals()
5217
5218
5219def is_flat(var, ndim=None, outdim=None):
    """
    Verifies that the dimensionality of the variable is equal to
    `ndim`. This method is usually called after the flatten method on a
    variable, where the first ndim-1 dimension size(s) of the variable
    are kept intact, and the last dimension size of the variable is made
    equal to the product of its remaining dimension size(s), such that
    the variable ends up with as many dimensions as ndim.

    Parameters
    ----------
        var : theano.tensor.var.TensorVariable
            the theano var on which the dimensionality is checked.

        ndim : int
            the expected dimensionality of var.

        outdim : int
            DEPRECATED synonym for ndim.

    Returns
    -------
    bool
        the comparison result of var's dim
        and the expected ndim.
    """
5242    if outdim is None and ndim is None:
5243        ndim = 1
5244    elif outdim is not None and ndim is not None:
5245        raise ValueError("You should only specify ndim")
5246    elif outdim is not None:
5247        warnings.warn(
5248            "flatten outdim parameter is deprecated, use ndim instead.")
5249        ndim = outdim
5250    return var.ndim == ndim
5251
5252
5253def flatten(x, ndim=None, outdim=None):
    """
    Reshapes the variable x by keeping
    the first ndim-1 dimension size(s) of x the same,
    and making the last dimension size of x equal to
    the product of its remaining dimension size(s).

    Parameters
    ----------
        x : theano.tensor.var.TensorVariable
            the variable that should be reshaped.

        ndim : int
            the number of dimensions of the returned variable.
            Default 1.

        outdim : int
            DEPRECATED synonym for ndim.

    Returns
    -------
    theano.tensor.var.TensorVariable
        the flattened variable with dimensionality of ndim
    """
5275    if outdim is None and ndim is None:
5276        ndim = 1
5277    elif outdim is not None and ndim is not None:
5278        raise ValueError("You should only specify ndim")
5279    elif outdim is not None:
5280        warnings.warn(
5281            "flatten outdim parameter is deprecated, use ndim instead.")
5282
5283        ndim = outdim
5284    # Any input variable can be flattened to have ndim of 1,
5285    # even if it's a scalar. Otherwise, ndim must be positive
5286    # and smaller than x.ndim.
5287    if ndim < 1 or (ndim > 1 and ndim > x.ndim):
5288        raise ValueError('ndim %s out of bound [1, %d)'
5289                         % (ndim, x.ndim + 1))
5290
5291    if ndim > 1:
5292        dims = tuple(x.shape[:ndim - 1]) + (-1,)
5293    else:
5294        dims = (-1,)
5295    x_reshaped = x.reshape(dims)
5296    bcast_kept_dims = x.broadcastable[:ndim - 1]
5297    bcast_new_dim = python_all(x.broadcastable[ndim - 1:])
5298    broadcastable = bcast_kept_dims + (bcast_new_dim,)
5299    x_reshaped = theano.tensor.addbroadcast(
5300        x_reshaped, *filter(lambda i: broadcastable[i], range(ndim)))
5301    return x_reshaped
5302
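# Usage sketch (illustration only): `flatten` keeps the first ndim - 1 axes
# and collapses the remaining ones into a single trailing axis.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     x = T.dtensor4('x')
#     y = T.flatten(x, ndim=2)              # keep axis 0, collapse axes 1..3
#     f = theano.function([x], y)
#     f(np.zeros((2, 3, 4, 5))).shape       # -> (2, 60)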
5303
5304# class TileGrad(Op):
5305#     """
5306#     Calculates the gradient of the Tile Op.
5307#     """
5308#     # this is so weird, I can't think of how to make this a general thing.
5309#     def make_node(self, x, reps, g_out):
5310#         return gof.Apply(self, [x, reps, g_out], [x.type()])
5311#
5312#     def perform(self, node, inp, out):
5313#         x, reps, g_out = inp
5314#         gx, = out
5315#         xsh = x.shape
5316#         if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
5317#             gx[0] = numpy.sum(g_out, axis=0)
5318#         else:
5319#             raise NotImplementedError('x.shape, reps combination not '
5320#                                       'supported', (x.shape, reps))
5321#
5322# tilegrad = TileGrad()
5323
5324
5325class Tile(Op):
5326    """
5327    Construct an array by repeating the input x according to reps pattern.
5328
5329    .. note:: Deprecated
5330              Use tile() instead.
5331
    Tiles its input according to reps. The length of reps is the number of
    dimensions of x and contains the number of times to tile x in each
    dimension.
5335
5336    See Also
5337    --------
5338    numpy.tile : http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
5339
5340    """
5341    __props__ = ("ndim",)
5342
5343    def __init__(self, ndim):
5344        self.ndim = ndim
5345
5346    def __str__(self):
5347        return self.__class__.__name__ + "{ndim=%d}" % self.ndim
5348
5349    def make_node(self, x, reps):
5350        warnings.warn((
5351            "Tile op is deprecated, use tile function instead."), stacklevel=3)
5352        x = as_tensor_variable(x)
5353        reps = as_tensor_variable(reps)
5354        return gof.Apply(self, [x, reps], [tensor(x.type.dtype, [False] *
5355                                                  self.ndim)])
5356
5357    def perform(self, node, inp, out_):
5358        x, reps = inp
5359        out, = out_
5360        res = np.tile(x, reps)
5361        if res.ndim != self.ndim:
5362            raise ValueError(
5363                'Tile.perform produced incorrect number of dimensions')
5364
5365        if (np.asarray(reps) == 1).all():
            # In that case, some NumPy versions return a view!  As this
            # op isn't declared as inplace, we need to check for that and
            # copy the data.
5369            if np.may_share_memory(res, x):
5370                res = res.copy()
5371        out[0] = res
5372
5373    def infer_shape(self, node, in_shapes):
5374        # Note: in contrast with numpy, it is assumed that x.shape and reps
5375        # have equal length;  see also tile function below
5376
5377        # Note: if reps were to be allowed not to be a constant and x.shape
5378        # and reps to be unequal, the following block of code could be used:
5379        # prepend 1 to x.shape if needed
5380        # if self.ndim > x.ndim:
5381        # shp = concatenate(ones(self.ndim - x.ndim), shp)
5382        # prepend 1 to reps if needed
5383        # reps = concatenate(ones(self.ndim - reps.shape[0]), reps)
5384
5385        x, reps = node.inputs
5386        shp = in_shapes[0]
5387        tiled_shp = shp * reps
5388        out_shape = []
5389        for i in xrange(self.ndim):
5390            out_shape.append(tiled_shp[i])
5391        return [out_shape]
5392
5393    def grad(self, inp, grads):
5394        x, reps = inp
5395        g_out, = grads
5396        # return [tilegrad(x, reps, g_out), None]
5397        raise NotImplementedError()
5398
5399
5400def tile(x, reps, ndim=None):
5401    """
5402    Tile input array `x` according to `reps`.
5403
5404    See the docstring of `numpy.tile` for details.
5405
    'reps' can be a constant integer (e.g. 3), a constant vector (e.g. [2, 3]),
    a symbolic scalar (e.g. tensor.iscalar()), a symbolic vector
    (e.g. tensor.ivector()) or a list of symbolic scalars
    (e.g. [tensor.iscalar(), tensor.iscalar()]).

    ndim is the number of dimensions of the output. If it is provided, ndim
    should be equal to or larger than x.ndim and len(reps); otherwise
    max(x.ndim, len(reps)) is used as ndim. If reps is a symbolic vector,
    ndim has to be provided.
5414
5415    """
5416
5417    if ndim is not None and ndim < x.ndim:
        raise ValueError("ndim should be equal to or larger than x.ndim")
5419
5420    # if reps is tensor.scalar, integer or tensor.vector, we convert it to a list.
5421    if not isinstance(reps, (list, tuple)):
5422        reps_astensor = as_tensor_variable(reps)
5423        ndim_check = reps_astensor.ndim
5424        if reps_astensor.dtype not in theano.tensor.discrete_dtypes:
5425            raise ValueError("elements of reps must be integer dtype")
5426
5427        # tensor.scalar/integer case
5428        if ndim_check == 0:
5429            reps = [reps]
5430
5431        # tensor.vector case
5432        elif ndim_check == 1:
5433            if ndim is None:
5434                raise ValueError("if reps is tensor.vector, you should specify "
5435                                 "the ndim")
5436            else:
5437                offset = ndim - reps.shape[0]
5438
5439                # assert that reps.shape[0] does not exceed ndim
5440                offset = theano.tensor.opt.assert_(offset, ge(offset, 0))
5441
5442                # if reps.ndim is less than x.ndim, we pad the reps with
5443                # "1" so that reps will have the same ndim as x.
5444                reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
5445                reps = reps_
5446
5447        # other raise error
5448        else:
5449            raise ValueError("the dimension of reps should not exceed 1")
5450    else:
5451        if ndim is not None and len(reps) > ndim:
            raise ValueError("len(reps) should be equal to or less than ndim")
5453        if not np.all([isinstance(r, integer_types) or
5454                       (isinstance(r, TensorVariable) and
5455                        r.dtype in theano.tensor.discrete_dtypes) for r in reps]):
5456            raise ValueError("elements of reps must be scalars of integer dtype")
5457
5458    # if reps.ndim is less than x.ndim, we pad the reps with
5459    # "1" so that reps will have the same ndim as x.
5460    reps = list(reps)
5461    if ndim is None:
5462        ndim = builtins.max(len(reps), x.ndim)
5463    if len(reps) < ndim:
5464        reps = [1] * (ndim - len(reps)) + reps
5465
5466    shape = [1] * (ndim - x.ndim) + [x.shape[i] for i in xrange(x.ndim)]
5467    alloc_shape = reps + shape
5468    y = alloc(x, *alloc_shape)
5469    shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
5470    shuffle_ind = shuffle_ind.transpose().flatten()
5471    y = y.dimshuffle(*shuffle_ind)
5472    new_shapes = [sh * reps[i] for i, sh in enumerate(shape)]
5473    y = y.reshape(new_shapes)
5474
5475    return y
5476
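# Usage sketch (illustration only): with a constant `reps`, the output ndim is
# max(x.ndim, len(reps)) and each axis is repeated reps[i] times, as in
# numpy.tile.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     x = T.dmatrix('x')
#     y = T.tile(x, (2, 3))                 # ndim deduced as 2
#     f = theano.function([x], y)
#     f(np.ones((2, 2))).shape              # -> (4, 6)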
5477
5478class ARange(Op):
5479    """Create an array containing evenly spaced values within a given interval.
5480
5481    Parameters and behaviour are the same as numpy.arange().
5482
5483    """
5484    __props__ = ("dtype",)
5485
5486    def __init__(self, dtype):
5487        self.dtype = dtype
5488
5489    def make_node(self, start, stop, step):
5490        start, stop, step = map(as_tensor_variable, (start, stop, step))
5491        assert start.ndim == 0
5492        assert stop.ndim == 0
5493        assert step.ndim == 0
5494
5495        inputs = [start, stop, step]
5496        outputs = [tensor(self.dtype, (False,))]
5497
5498        return Apply(self, inputs, outputs)
5499
5500    @theano.configparser.change_flags(warn_float64='ignore')
5501    def infer_shape(self, node, i_shapes):
5502        # Note start, stop and step can be float numbers.
5503        start, stop, step = node.inputs
5504
5505        def is_constant_value(var, value):
5506            try:
5507                v = get_scalar_constant_value(var)
5508                return np.all(v == value)
5509            except NotScalarConstantError:
5510                pass
5511            return False
5512
5513        def upcast(var):
5514            if (var.dtype in integer_dtypes and
                    # We do not want to cast uint64 to int64 as this can
                    # lose information. If we upcast uint64 with int64,
                    # this gives float64. This is safer than checking for
                    # uint64 in case we support [u]int128 or others in the
                    # future.
5520                    scal.upcast(var.dtype, 'int64') == 'int64'):
5521                return cast(var, 'int64')
5522            return var
5523
5524        if is_constant_value(step, 1):
5525            if is_constant_value(start, 0):
5526                return [(cast(stop, 'int64'),)]
5527            else:
5528                stop = upcast(stop)
5529                start = upcast(start)
5530                return [(maximum(cast(stop - start, 'int64'), 0),)]
5531        else:
5532            stop = upcast(stop)
5533            start = upcast(start)
5534            return [(maximum(cast(ceil(cast((stop - start), 'float64') / step),
5535                    'int64'), 0),)]
5536
5537    def perform(self, node, inp, out_):
5538        start, stop, step = inp
5539        out, = out_
5540        start = start.item()
5541        stop = stop.item()
5542        step = step.item()
5543        out[0] = np.arange(start, stop, step, dtype=self.dtype)
5544
5545    def connection_pattern(self, node):
5546
5547        return [[True], [False], [True]]
5548
5549    def L_op(self, inputs, outputs, grads):
5550        start, stop, step = inputs
5551        gz, = grads
5552        # `start` and `step` affect the output values
5553        # but the outputs are integers so there's
5554        # no gradient through them.
5555        # When they are not integers, the gradients are
5556        # as expressed below.
5557        # `stop` does not affect the output values,
5558        # just the output shape, so it is disconnected.
5559
5560        if self.dtype in discrete_dtypes:
5561            return [start.zeros_like(dtype=config.floatX),
5562                    DisconnectedType()(),
5563                    step.zeros_like(dtype=config.floatX)]
5564        else:
5565            num_steps_taken = outputs[0].shape[0]
5566            return [gz.sum(),
5567                    DisconnectedType()(),
5568                    (gz * arange(num_steps_taken, dtype=self.dtype)).sum()]
5569
5570    def R_op(self, inputs, eval_points):
5571        return [None]
5572_arange = {}
5573
5574
5575def arange(start, stop=None, step=1, dtype=None):
5576    # If only one argument is provided, it is in fact the "stop" argument,
5577    # and start is 0.
5578    if stop is None:
5579        start, stop = 0, start
5580
5581    start, stop, step = map(as_tensor_variable, (start, stop, step))
5582    # If dtype is not provided, infer it from the other arguments
5583    if dtype is None:
5584        dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype)
5585        # don't try to be stingy and byte-optimize, this leads to
5586        # overflow problems.
5587        if dtype in int_dtypes:
5588            dtype = 'int64'
5589        if dtype in uint_dtypes:
5590            dtype = 'uint64'
5591        if config.cast_policy in ('numpy', 'numpy+floatX'):
5592            # We enforce numpy semantics, except in the special case where
5593            # `config.cast_policy` is 'numpy+floatX' and we want to use float32
5594            # rather than float64.
5595            # As an example, if `start`, `stop` and `step` are all int32,
5596            # `numpy.arange` returns an int64 array (on 64-bit platforms),
5597            # while the upcast above returns int32.
5598            numpy_dtype = np.arange(
5599                start=np.array(0, dtype=start.dtype),
5600                stop=np.array(1, dtype=stop.dtype),
5601                step=np.array(1, dtype=step.dtype)).dtype
5602            if numpy_dtype != dtype:
5603                if (config.cast_policy == 'numpy+floatX' and
5604                    config.floatX == 'float32' and
5605                    numpy_dtype == 'float64' and
5606                    # No explicit float64 in the three arguments?
5607                    python_all(
5608                        dt != 'float64'
5609                        for dt in [s.dtype for s in (start, stop, step)])):
5610                    # We use float32 instead.
5611                    assert dtype != 'float64'
5612                    dtype = 'float32'
5613                else:
5614                    # We use the same dtype as numpy instead of the result of
5615                    # the upcast.
5616                    dtype = str(numpy_dtype)
5617
5618    if dtype not in _arange:
5619        _arange[dtype] = ARange(dtype)
5620    return _arange[dtype](start, stop, step)
5621
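# Usage sketch (illustration only, assuming the default cast_policy): integer
# inputs are upcast to int64 to avoid overflow, as explained above.
#
#     import theano
#     import theano.tensor as T
#     n = T.iscalar('n')
#     r = T.arange(n)                       # start=0, step=1
#     r.dtype                               # -> 'int64'
#     theano.function([n], r)(5)            # -> array([0, 1, 2, 3, 4])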
5622
5623class _nd_grid(object):
5624    """Create a dense n-dimensional 'meshgrid' with equally spaced points.
5625
5626    Used to create the instance ``mgrid`` and ``ogrid`` which act similarly
5627    to their numpy equivalents.
5628
5629    Parameters
5630    ----------
    sparse : boolean, optional, default=False
5632        Specifying False leads to the equivalent of numpy's mgrid functionality.
5633        Specifying True leads to the equivalent of ogrid.
5634
5635    Examples
5636    --------
5637    >>> a = T.mgrid[0:5, 0:3]
5638    >>> a[0].eval()
5639    array([[0, 0, 0],
5640           [1, 1, 1],
5641           [2, 2, 2],
5642           [3, 3, 3],
5643           [4, 4, 4]], dtype=int8)
5644    >>> a[1].eval()
5645    array([[0, 1, 2],
5646           [0, 1, 2],
5647           [0, 1, 2],
5648           [0, 1, 2],
5649           [0, 1, 2]], dtype=int8)
5650    >>> b = T.ogrid[0:5, 0:3]
5651    >>> b[0].eval()
5652    array([[0],
5653           [1],
5654           [2],
5655           [3],
5656           [4]], dtype=int8)
5657    >>> b[1].eval()
5658    array([[0, 1, 2, 3]], dtype=int8)
5659
5660    """
5661
5662    def __init__(self, sparse=False):
5663        self.sparse = sparse
5664
5665    def __getitem__(self, *args):
5666
5667        ndim = len(args[0])
5668        for sl in args[0]:
5669            if isinstance(sl.step, python_complex):
5670                raise NotImplementedError("Not implemented for slices "
5671                                          "whose step is complex")
5672        ranges = [arange(sl.start or 0,
5673                         sl.stop,
5674                         sl.step or 1) for sl in args[0]]
5675        shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
5676                  for j, r in enumerate(ranges)]
5677        ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]
5678        if self.sparse:
5679            grids = ranges
5680        else:
5681            grids = []
5682            ones = [ones_like(r) for r in ranges]
5683            for i in range(ndim):
5684                grid = 1
5685                for j in range(ndim):
5686                    if j == i:
5687                        grid = grid * ranges[j]
5688                    else:
5689                        grid = grid * ones[j]
5690                grids.append(grid)
5691        return grids
5692
5693
5694mgrid = _nd_grid()
5695ogrid = _nd_grid(sparse=True)
5696
5697
5698class PermuteRowElements(Op):
5699    """Permute the elements of each row (inner-most dim) of a tensor.
5700
    A permutation will be applied to every row (vector) of the input tensor x.
    Depending on the dimensionality of x and the permutation tensor y,
    different cases are possible.

    - If y.ndim = 1, y is a single permutation, that will be applied to every
      vector of x. For instance, if x is a matrix, the same permutation will
      be applied to each row of x.
    - If x.ndim = y.ndim, each row of x corresponds to a row of y, containing
      a permutation that will be applied to that row. For instance, if x and y
      are two matrices, a different permutation will be applied to each row
      of x.
    - If x.ndim > y.ndim, y will be broadcasted to fit x, then each row
      (vector) of x will be reordered according to the corresponding row of y.
      (This is a generalization of the first case.)
    - If x.ndim = 1, every permutation in y will be applied to x, and the
      output will contain all the results.
    - If x.ndim < y.ndim, x will be broadcasted to fit y, and different
      permutations contained in y will be applied to each vector in x. (This
      is a generalization of the previous case.)

    If the "inverse" argument is True, the Op will perform the inverse
    permutation instead.
5721    """
5722    __props__ = ()
5723
5724    def make_node(self, x, y, inverse):
5725        x = as_tensor_variable(x)
5726        y = as_tensor_variable(y)
5727        if inverse:  # as_tensor_variable does not accept booleans
5728            inverse = as_tensor_variable(1)
5729        else:
5730            inverse = as_tensor_variable(0)
5731
5732        # y should contain integers
5733        assert y.type.dtype in integer_dtypes
5734        # Inverse should be an integer scalar
5735        assert (inverse.type.ndim == 0 and inverse.type.dtype in integer_dtypes)
5736
5737        # Match shapes of x and y
5738        x_dim = x.type.ndim
5739        y_dim = y.type.ndim
5740
5741        if x_dim > y_dim:
5742            y = shape_padleft(y, n_ones=(x_dim - y_dim))
5743        elif x_dim < y_dim:
5744            x = shape_padleft(x, n_ones=(y_dim - x_dim))
5745
5746        # Compute the broadcastable pattern of the output
5747        out_broadcastable = [xb and yb for xb, yb in
5748                             izip(x.type.broadcastable, y.type.broadcastable)]
5749        out_type = tensor(dtype=x.type.dtype, broadcastable=out_broadcastable)
5750
5751        inputlist = [x, y, inverse]
5752        outputlist = [out_type]
5753        return Apply(self, inputlist, outputlist)
5754
5755    def _rec_perform(self, node, x, y, inverse, out, curdim):
5756        """Perform the permutation by doing a recursion over the input
5757        dimensions.
5758
5759        For every dimension, starting with the leftmost, the right set of
5760        indices is determined (depending if broadcasting or not), then
5761        the function is recursively called on the appropriate subtensors.
5762
        The terminal case is reached when the current tensors are vectors;
        then the permutation contained in y is applied to x.
5765
5766        Parameters
5767        ----------
5768        x : tensor
5769            The input tensor, on which the permutation is applied.
5770        y : tensor
5771            Tensor containing the permutations to apply.
5772        out : tensor
5773            Tensor storing the output result.
5774        curdim : int
5775            Counter of the current depth of recursion.
5776        inverse
            Whether to apply the permutations or their inverse.
5778
5779        """
5780        if len(x.shape) == 1:
5781            # Numpy advanced indexing works in this case
5782            if inverse:
5783                out[y] = x[:]
5784            else:
5785                out[:] = x[y]
5786        else:
5787            xs0 = x.shape[0]
5788            ys0 = y.shape[0]
5789            if xs0 == ys0:
5790                for i in xrange(xs0):
5791                    self._rec_perform(node, x[i], y[i], inverse, out[i],
5792                                      curdim + 1)
5793            elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
5794                # Broadcast y
5795                for i in xrange(xs0):
5796                    self._rec_perform(node, x[i], y[0], inverse, out[i],
5797                                      curdim + 1)
5798            elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
5799                # Broadcast x
5800                for i in xrange(ys0):
5801                    self._rec_perform(node, x[0], y[i], inverse, out[i],
5802                                      curdim + 1)
5803            else:
5804                raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
5805
5806    def perform(self, node, inp, out):
5807        x, y, inverse = inp
5808        outs, = out
5809        x_s = x.shape
5810        y_s = y.shape
5811        assert len(x_s) == len(y_s)
5812
5813        # Make sure the output is big enough
5814        out_s = []
5815        for xdim, ydim in izip(x_s, y_s):
5816            if xdim == ydim:
5817                outdim = xdim
5818            elif xdim == 1:
5819                outdim = ydim
5820            elif ydim == 1:
5821                outdim = xdim
5822            else:
5823                raise ValueError('Dimension mismatch: %s, %s' % (xdim, ydim))
5824            out_s.append(outdim)
5825
5826        if outs[0] is None or outs[0].shape != out_s:
5827            outs[0] = np.empty(out_s, dtype=x.dtype)
5828
5829        self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
5830
5831    def infer_shape(self, node, in_shapes):
5832        shp_x = in_shapes[0]
5833        shp_y = in_shapes[1]
5834        assert len(shp_x) == len(shp_y)
5835        out_shape = []
5836        for i in xrange(len(shp_x)):
5837            out_shape.append(maximum(shp_x[i], shp_y[i]))
5838        return [out_shape]
5839
5840    def grad(self, inp, grads):
5841        x, y, inverse = inp
5842        gz, = grads
5843        # First, compute the gradient wrt the broadcasted x.
5844        # If 'inverse' is False (0), apply the inverse of y on gz.
5845        # Else, apply y on gz.
5846        gx = permute_row_elements(gz, y, eq(inverse, 0))
5847
5848        # If x has been broadcasted along some axes, we need to sum
5849        # the gradient over these axes, but keep the dimension (as
5850        # broadcastable)
5851        broadcasted_dims = [dim for dim in xrange(gz.type.ndim)
5852                            if x.type.broadcastable[dim] and
5853                            not gz.type.broadcastable[dim]]
5854        gx = Sum(axis=broadcasted_dims)(gx)
5855
5856        # Sum(...) removed the dimensions in broadcasted_dims,
5857        # so we need to put them back.
5858        newdims = []
5859        i = 0
5860        for dim in xrange(gz.type.ndim):
5861            if dim in broadcasted_dims:
5862                newdims.append('x')
5863            else:
5864                newdims.append(i)
5865                i += 1
5866
5867        gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
5868        assert gx.type.broadcastable == x.type.broadcastable
5869
5870        # if x is an integer type, then so is the output.
5871        # this means f(x+eps) = f(x) so the gradient with respect
5872        # to x is zero
5873        if x.type.dtype in discrete_dtypes:
5874            gx = x.zeros_like()
5875
5876        # The elements of y and of inverse both affect the output,
5877        # so they are connected to the output,
5878        # and the transformation isn't defined if their values
5879        # are non-integer, so the gradient with respect to them is
5880        # undefined
5881
5882        return [gx, grad_undefined(self, 1, y),
                grad_undefined(self, 2, inverse)]
5884
5885_permute_row_elements = PermuteRowElements()
5886
5887
5888def permute_row_elements(x, y, inverse=0):
5889    return _permute_row_elements(x, y, inverse)
5890
5891
5892def inverse_permutation(perm):
5893    """Computes the inverse of permutations.
5894
    Each row of the input should contain a permutation of the integers
    0, 1, ..., (row length - 1).
5896
5897    """
5898    return permute_row_elements(
5899        arange(perm.shape[-1], dtype=perm.dtype),
5900        perm,
5901        inverse=True)
5902
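# Usage sketch (illustration only): `permute_row_elements` reorders each row of
# x with the corresponding row of y; `inverse_permutation` builds on it.
#
#     import numpy as np
#     import theano
#     import theano.tensor as T
#     p = T.lvector('p')
#     inv = T.inverse_permutation(p)
#     theano.function([p], inv)(np.array([2, 0, 1]))   # -> array([1, 2, 0])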
5903
5904#########################
5905# Linalg : Dot
5906#########################
5907#
5908# For BLAS-related ops see blas.py
5909#
5910# TODO: Dotinv should go here, Eigs, Svd, etc.
5911
5912
5913class Dot(Op):
5914    """
5915    Computes the dot product of two variables. For two matrices, this is
5916    equivalent to matrix multiplication. For two vectors, this is the inner
5917    product.
5918
5919    Notes
5920    -----
5921    Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
5922    (see tensor.blas).
5923    Vector-vector products are sometimes optimized to Ger or CGer (see
5924    tensor.blas).
5925    Matrix-vector products are sometimes optimized to Gemv, CGemv (see
5926    tensor.blas).
5927
5928    """
5929    __props__ = ()
5930
5931    # the rationale for Dot22 is related to getting GEMM Ops into the
5932    # graph.  See Dot22 in tensor.blas for details.
5933
5934    def make_node(self, *inputs):
5935        inputs = list(map(as_tensor_variable, inputs))
5936
5937        if len(inputs) != 2:
5938            raise TypeError(
5939                'theano.tensor.Dot: 2 arguments required, %d given ' %
5940                len(inputs))
5941        if inputs[0].ndim not in (1, 2):
5942            raise TypeError(
5943                'theano.tensor.Dot: input 0 (0-indexed) must have ndim of '
5944                '1 or 2, %d given. Consider calling theano.tensor.dot '
5945                'instead.' % inputs[0].ndim)
5946        if inputs[1].ndim not in (1, 2):
5947            raise TypeError(
5948                'theano.tensor.Dot: input 1 (0-indexed) must have ndim of '
5949                '1 or 2, %d given. Consider calling theano.tensor.dot '
5950                'instead.' % inputs[1].ndim)
5951
5952        i_broadcastables = [input.type.broadcastable for input in inputs]
5953        bx, by = i_broadcastables
5954        if len(by) == 2:  # y is a matrix
5955            bz = bx[:-1] + by[-1:]
5956        elif len(by) == 1:  # y is vector
5957            bz = bx[:-1]
5958
5959        i_dtypes = [input.type.dtype for input in inputs]
5960        outputs = [tensor(scal.upcast(*i_dtypes), bz)]
5961        return Apply(self, inputs, outputs)
5962
5963    def perform(self, node, inp, out):
5964        x, y = inp
5965        z, = out
5966
5967        # the asarray is here because dot between two vectors
5968        # gives a numpy float object but we need to return a 0d
5969        # ndarray
5970        z[0] = np.asarray(np.dot(x, y))
5971
5972    def grad(self, inp, grads):
5973
5974        x, y = inp
5975        gz, = grads
5976        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
5977
5978        # grad is scalar, so x is vector and y is vector
5979        if gdim == 0:
5980            xgrad = gz * y
5981            ygrad = gz * x
5982
5983        # x is vector, y is matrix, grad is vector
5984        elif xdim == 1 and ydim == 2:
5985            xgrad = dot(gz, y.T)
5986            ygrad = outer(x.T, gz)
5987
5988        # x is matrix, y is vector, grad is vector
5989        elif xdim == 2 and ydim == 1:
5990            xgrad = outer(gz, y.T)
5991            ygrad = dot(x.T, gz)
5992
5993        # x is matrix, y is matrix, grad is matrix
5994        elif xdim == ydim == 2:
5995            xgrad = dot(gz, y.T)
5996            ygrad = dot(x.T, gz)
5997
        # If x or y contain broadcastable dimensions but only one of
        # them knows that a matching dimension is broadcastable, the
        # above code doesn't always return the right broadcast pattern.
        # This causes problems down the road. See gh-1461.
6002        if xgrad.broadcastable != x.broadcastable:
6003            xgrad = patternbroadcast(xgrad, x.broadcastable)
6004        if ygrad.broadcastable != y.broadcastable:
6005            ygrad = patternbroadcast(ygrad, y.broadcastable)
6006
6007        rval = xgrad, ygrad
6008
6009        for elem in rval:
6010            assert elem.dtype.find('float') != -1
6011
6012        return rval
6013
6014    def R_op(self, inputs, eval_points):
        # R_op for a \dot b evaluated at c for a and d for b is
6016        # simply c \dot b + a \dot d
6017
6018        assert len(inputs) == 2
6019        assert len(eval_points) == 2
6020        if eval_points[0] is None and eval_points[1] is None:
6021            return [None]
6022
6023        if eval_points[0]:
6024            t1 = self(eval_points[0], inputs[1])
6025        if eval_points[1]:
6026            t2 = self(inputs[0], eval_points[1])
6027
6028        if eval_points[0] and eval_points[1]:
6029            return [t1 + t2]
6030        elif eval_points[0]:
6031            return [t1]
6032        else:
6033            return [t2]
6034
6035    def infer_shape(self, node, shapes):
6036        xshp, yshp = shapes
6037        x, y = node.inputs
6038
6039        # vector / vector
6040        if x.ndim == 1 and y.ndim == 1:
6041            return [()]
6042        # matrix / vector
6043        if x.ndim == 2 and y.ndim == 1:
6044            return [xshp[:-1]]
6045        # vector / matrix
6046        if x.ndim == 1 and y.ndim == 2:
6047            return [yshp[-1:]]
6048        # matrix / matrix
6049        if x.ndim == 2 and y.ndim == 2:
6050            return [xshp[:-1] + yshp[-1:]]
6051        raise NotImplementedError()
6052
6053    def __str__(self):
6054        return "dot"
6055
6056_dot = Dot()
6057pprint.assign(_dot, printing.OperatorPrinter(printing.special['middle_dot'],
6058                                             -1, 'left'))
6059
6060
6061def dot(a, b):
6062    """
6063    Computes the dot product of two variables.
6064
6065    For two matrices, this is equivalent to matrix multiplication.
6066    For two vectors, this is the inner product.
6067    When one variable is a scalar, this is like elementwise multiplication.
6068    For N dimensions, this is a sum product over the last axis
6069    of the first array and the second-to-last axis of the second array:
6070
6071        dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
6072
6073    Note that this dot function does one of three things, in the following
6074    sequence:
6075
6076        1.  If either a or b is scalar, it returns the elementwise product
6077            without calling the Theano Dot op.
6078
6079        2.  If either a or b has more than 2 dimensions, it calls Theano's
6080            tensordot function with appropriate axes. The tensordot function
6081            expresses high-dimensional dot products in terms of 2D matrix
            multiplications, so it may be possible to further optimize for
6083            performance.
6084
6085        3.  If both a and b have either 1 or 2 dimensions, it calls Theano's
6086            Dot op on a and b.
6087
6088    Notes
6089    -----
6090    Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
6091    (see tensor.blas).
6092    Vector-vector products are sometimes optimized to Ger or CGer (see
6093    tensor.blas).
6094    Matrix-vector products are sometimes optimized to Gemv, CGemv (see
6095    tensor.blas).
6096
6097    """
6098    a, b = as_tensor_variable(a), as_tensor_variable(b)
6099
6100    if a.ndim == 0 or b.ndim == 0:
6101        return a * b
6102    elif a.ndim > 2 or b.ndim > 2:
6103        return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
6104    else:
6105        return _dot(a, b)
6106
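# Usage sketch (illustration only) of the three cases listed in the docstring
# above:
#
#     import theano.tensor as T
#     m, v, s = T.dmatrix('m'), T.dvector('v'), T.dscalar('s')
#     T.dot(m, v).ndim                      # -> 1, plain Dot op (matrix-vector)
#     T.dot(s, m)                           # scalar operand: elementwise product
#     t = T.dtensor3('t')
#     T.dot(t, m).ndim                      # -> 3, dispatched to tensordot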
6107
6108#########################
6109# Linalg : TensorDot
6110#########################
6111
6112def _tensordot_as_dot(a, b, axes, dot, batched):
6113    """
6114    Reduces a tensor dot product to a matrix or vector dot product. Based
6115    on code from Tijmen Tieleman's gnumpy
6116    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
6117
6118    Please see the documentation of tensordot for the meaning of the a, b
6119    and axes arguments.
6120
6121    :param dot: a function that accepts two symbolic variables and computes
6122                the appropriate dot product (e.g. dot, batched_dot)
6123    :type dot: function
6124
6125    :param batched: whether to treat the first axis of a and b as a batch
6126                    axis.  If so, this axis will be preserved in the output,
6127                    allowing this function to be used also for batched
6128                    tensor dot products.
6129    :type batched: boolean
6130
6131    :returns: a tensor with shape equal to the concatenation of a's shape
6132              (less any dimensions that were summed over) and b's shape
6133              (less the first dimension and any dimensions that were summed
6134              over).
6135    :rtype: symbolic tensor
6136    """
6137    a, b = as_tensor_variable(a), as_tensor_variable(b)
6138
6139    if not np.isscalar(axes) and len(axes) != 2:
6140        raise ValueError('Axes should be an integer or a '
6141                         'list/tuple of len 2 (%s was provided)'
6142                         % str(axes))
6143
6144    # if 'axes' is a number of axes to multiply and sum over (trailing axes
6145    # of a, leading axes of b), we can just reshape and use dot.
6146    elif np.isscalar(axes):
6147        axes = int(axes)
6148
6149        for operand_name, operand in (("a", a), ("b", b)):
6150            if axes > operand.ndim:
6151                raise ValueError(
6152                    'axes can not be larger than the dimension of %s '
6153                    '(%s.ndim=%i, axes=%i)'
6154                    % (operand_name, operand_name, operand.ndim, axes))
6155            if batched and axes == operand.ndim:
6156                raise ValueError(
6157                    'axes to sum over must not include the batch axis '
6158                    'of %s (%s.ndim=%i, axes=%i)'
6159                    % (operand_name, operand_name, operand.ndim, axes))
6160
6161        batch_axes = 1 if batched else 0
6162        a_outaxes = slice(0, a.ndim - axes)
6163        b_outaxes = slice(batch_axes + axes, b.ndim)
6164        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
6165        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
6166        outndim = len(outbcast)
6167
6168        a_shape = [1] * 2
6169        b_shape = [1] * 2
6170
6171        # compute total size of summed axes
6172        for i in xrange(0, axes):
6173            a_shape[1] *= a.shape[-(i + 1)]
6174            b_shape[0] *= b.shape[batch_axes + i]
6175        # compute total size of other axes
6176        for i in xrange(0, a.ndim - axes - batch_axes):
6177            a_shape[0] *= a.shape[batch_axes + i]
6178        for i in xrange(0, b.ndim - axes - batch_axes):
6179            b_shape[1] *= b.shape[-(i + 1)]
6180
6181        if batched:
6182            a_shape.insert(0, a.shape[0])
6183            b_shape.insert(0, b.shape[0])
6184
6185        a_reshaped = a.reshape(a_shape)
6186        b_reshaped = b.reshape(b_shape)
6187
6188        out_reshaped = dot(a_reshaped, b_reshaped)
6189        out = out_reshaped.reshape(outshape, outndim)
6190        # Make sure the broadcastable pattern of the result is correct,
6191        # since some shape information can be lost in the reshapes.
6192        return patternbroadcast(out, outbcast)
6193
6194    # if 'axes' is a list, transpose a and b such that the summed axes of a
6195    # are last and the summed axes of b are first.
6196    else:
6197        axes = [_pack(axes_) for axes_ in axes]
6198
6199        if len(axes[0]) != len(axes[1]):
6200            raise ValueError('Axes elements must have the same length.')
6201
6202        for i, (operand_name, operand) in enumerate((("a", a),
6203                                                     ("b", b))):
6204            if len(axes[i]) > operand.ndim:
                raise ValueError(
                    'axes[%i] should be array_like, with length no greater '
                    'than the number of dimensions of %s '
                    '(%s.ndim=%i, len(axes[%i])=%i).' %
                    (i, operand_name, operand_name, operand.ndim,
                     i, len(axes[i])))
6210            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
                raise ValueError(
                    'axes[%i] contains dimensions greater than or equal '
                    'to %s.ndim (%s.ndim=%i, max(axes[%i])=%i).' %
                    (i, operand_name, operand_name, operand.ndim,
                     i, np.max(np.array(axes[i]))))
6216            if batched and 0 in axes[i]:
6217                raise ValueError(
6218                    'axes to sum over must not contain the batch axis '
6219                    '(axes[%i]=%s)' %
6220                    (i, axes[i]))
6221
6222        batch_axes = [0] if batched else []
6223        other_axes = [[x for x in xrange(operand.ndim)
6224                       if x not in axes[i] and x not in batch_axes]
6225                      for i, operand in enumerate((a, b))]
6226
6227        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
6228        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
6229
6230        # now that a and b are in the right order, recur with integer axes
6231        return _tensordot_as_dot(a_shuffled, b_shuffled, len(axes[0]),
6232                                 dot=dot, batched=batched)
6233
6234
6235def tensordot(a, b, axes=2):
6236    """
6237    Compute a generalized dot product over provided axes.
6238
6239    Given two tensors a and b, tensordot computes a generalized dot product over
6240    the provided axes. Theano's implementation reduces all expressions to
6241    matrix or vector dot products and is based on code from Tijmen Tieleman's
6242    gnumpy (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
6243
6244    Parameters
6245    ----------
6246    a: symbolic tensor
6247        The first tensor variable.
6248    b: symbolic tensor
        The second tensor variable.
6250    axes: int or array-like of length 2
6251        If an integer, the number of axes to sum over.
6252        If an array, it must have two array elements containing the axes
6253        to sum over in each tensor.
6254
6255        Note that the default value of 2 is not guaranteed to work
6256        for all values of a and b, and an error will be raised if
6257        that is the case. The reason for keeping the default is to
6258        maintain the same signature as numpy's tensordot function
6259        (and np.tensordot raises analogous errors for non-compatible
6260        inputs).
6261
6262        If an integer i, it is converted to an array containing
6263        the last i dimensions of the first tensor and the first
6264        i dimensions of the second tensor:
            axes = [list(range(a.ndim - i, a.ndim)), list(range(i))]
6266
6267        If an array, its two elements must contain compatible axes
6268        of the two tensors. For example, [[1, 2], [2, 0]] means sum
6269        over the 2nd and 3rd axes of a and the 3rd and 1st axes of b.
6270        (Remember axes are zero-indexed!) The 2nd axis of a and the
6271        3rd axis of b must have the same shape; the same is true for
6272        the 3rd axis of a and the 1st axis of b.
6273
6274    Returns
6275    -------
6276    symbolic tensor
6277        A tensor with shape equal to the concatenation of a's shape
6278        (less any dimensions that were summed over) and b's shape
6279        (less any dimensions that were summed over).
6280
6281    Examples
6282    --------
6283    It may be helpful to consider an example to see what tensordot does.
6284    Theano's implementation is identical to NumPy's. Here a has shape (2, 3, 4)
6285    and b has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
6286    note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
6287    are compatible. The resulting tensor will have shape (2, 5, 6) -- the
6288    dimensions that are not being summed:
6289
6290    >>> a = np.random.random((2,3,4))
6291    >>> b = np.random.random((5,6,4,3))
6292
6293    #tensordot
6294    >>> c = np.tensordot(a, b, [[1,2],[3,2]])
6295
6296    #loop replicating tensordot
6297    >>> a0, a1, a2 = a.shape
6298    >>> b0, b1, _, _ = b.shape
6299    >>> cloop = np.zeros((a0,b0,b1))
6300
6301    #loop over non-summed indices -- these exist
6302    #in the tensor product.
6303    >>> for i in range(a0):
6304    ...     for j in range(b0):
6305    ...         for k in range(b1):
6306    ...             #loop over summed indices -- these don't exist
6307    ...             #in the tensor product.
6308    ...             for l in range(a1):
6309    ...                 for m in range(a2):
6310    ...                     cloop[i,j,k] += a[i,l,m] * b[j,k,m,l]
6311
6312    >>> np.allclose(c, cloop)
    True
6314
6315    This specific implementation avoids a loop by transposing a and b such that
6316    the summed axes of a are last and the summed axes of b are first. The
6317    resulting arrays are reshaped to 2 dimensions (or left as vectors, if
6318    appropriate) and a matrix or vector dot product is taken. The result is
6319    reshaped back to the required output dimensions.
6320
6321    In an extreme case, no axes may be specified. The resulting tensor
6322    will have shape equal to the concatenation of the shapes of a and b:
6323
6324    >>> c = np.tensordot(a, b, 0)
    >>> print(a.shape)
    (2, 3, 4)
    >>> print(b.shape)
    (5, 6, 4, 3)
    >>> print(c.shape)
    (2, 3, 4, 5, 6, 4, 3)
6331
6332    See the documentation of numpy.tensordot for more examples.
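
    Finally, a symbolic sketch of the first example above, assuming the
    standard theano.tensor API (shapes are only known at run time):

    >>> import theano.tensor as T
    >>> x = T.tensor3('x')   # e.g. shape (2, 3, 4)
    >>> y = T.tensor4('y')   # e.g. shape (5, 6, 4, 3)
    >>> z = T.tensordot(x, y, [[1, 2], [3, 2]])
    >>> z.ndim               # 3 + 4 - 2 * 2 remaining dimensions
    3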
6333
6334    """
6335    return _tensordot_as_dot(a, b, axes, dot=dot, batched=False)
6336
6337
6338def outer(x, y):
6339    """Return vector-vector outer product.
6340
6341    If an input isn't a vector, we flatten it first.
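
    A quick numeric sketch, assuming ``theano.function`` is available:

    >>> import numpy as np
    >>> import theano
    >>> import theano.tensor as T
    >>> u, v = T.vectors('u', 'v')
    >>> f = theano.function([u, v], T.outer(u, v))
    >>> np.allclose(f([1, 2], [3, 4, 5]), np.outer([1, 2], [3, 4, 5]))
    True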
6342
6343    """
6344    if x.ndim != 1:
6345        x = x.flatten()
6346    if y.ndim != 1:
6347        y = y.flatten()
6348    return dot(
6349        x.dimshuffle(0, 'x'),
6350        y.dimshuffle('x', 0))
6351
6352
def any(x, axis=None, keepdims=False):
    """Check if any element along the given axis is true (logical OR)."""
    out = elemwise.Any(axis)(x)
6355
6356    if keepdims:
6357        out = makeKeepDims(x, out, axis)
6358    return out
6359
6360
def all(x, axis=None, keepdims=False):
    """Check if all elements along the given axis are true (logical AND)."""
    out = elemwise.All(axis)(x)
6363
6364    if keepdims:
6365        out = makeKeepDims(x, out, axis)
6366    return out
6367
6368
# Some NumPy versions, such as 1.9.2, return a view for numpy.diagonal.
6370x = np.zeros((4, 4))
6371numpy_diagonal_return_view = np.may_share_memory(np.diagonal(x), x)
6372del x
6373
6374
6375class ExtractDiag(Op):
6376    """
6377    Return specified diagonals.
6378
6379    If x is 2-D, returns the diagonal of x with the given offset,
6380    i.e., the collection of elements of the form x[i, i+offset].
6381    If x has more than two dimensions, then the axes specified by
6382    axis1 and axis2 are used to determine the 2-D sub-array whose
6383    diagonal is returned. The shape of the resulting array can be
6384    determined by removing axis1 and axis2 and appending an index
6385    to the right equal to the size of the resulting diagonals.
6386
6387    Parameters
6388    ----------
6389    x: A tensor variable with x.ndim >= 2.
6390
6391    offset: Offset of the diagonal from the main diagonal.
6392        Can be positive or negative.
6393        Defaults to main diagonal (0).
6394
6395    axis1: Axis to be used as the first axis of the 2-D
6396        sub-arrays from which the diagonals should be taken.
6397        Defaults to first axis (0).
6398
6399    axis2: Axis to be used as the second axis of the 2-D
6400        sub-arrays from which the diagonals should be taken.
6401        Defaults to second axis (1).
6402
6403
6404
6405    Returns
6406    -------
6407    array_of_diagonals:
        If x is 2-D, a 1-D array of the same type as x, containing the
        diagonal, is returned.
6410        If the dimension of x is greater than two, then an
6411        array of diagonals is returned, "packed" from left-most
6412        dimension to right-most (e.g., if x is 3-D, then the
6413        diagonals are "packed" along rows).
6414
6415
6416
6417    Raises
6418    ------
6419    ValueError
6420        If the dimension of x is less than 2.
6421
6422
6423    See Also
6424    --------
6425    numpy.diagonal:
6426        https://docs.scipy.org/doc/numpy-dev/reference/generated/numpy.diagonal.html
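
    Examples
    --------
    A small usage sketch, assuming the op is applied through ``theano.function``:

    >>> import numpy as np
    >>> import theano
    >>> x = theano.tensor.matrix('x')
    >>> f = theano.function([x], theano.tensor.ExtractDiag(offset=1)(x))
    >>> a = np.arange(9).reshape(3, 3).astype(theano.config.floatX)
    >>> np.allclose(f(a), np.diagonal(a, offset=1))
    True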
6427    """
6428    __props__ = ("offset", "axis1", "axis2", "view")
6429
6430    def __init__(self, offset=0, axis1=0, axis2=1, view=False):
6431        self.view = view
6432        if self.view and not numpy_diagonal_return_view:
6433            warnings.warn("View will forced to False. ExtractDiag property view is "
6434                          "set to True but numpy version %s and prior versions of "
6435                          "numpy.diagonal() do not return a view. Update "
6436                          "numpy to use ExtractDiag(view=True)" %
6437                          np.version.version)
6438            self.view = False
6439        if self.view:
6440            self.view_map = {0: [0]}
6441        self.offset = offset
6442        self.axis1 = axis1
6443        self.axis2 = axis2
6444
6445    def make_node(self, x):
6446        x = as_tensor_variable(x)
6447
6448        if x.ndim < 2:
6449            raise ValueError('ExtractDiag needs an input with 2 or more '
6450                             'dimensions', x)
6451        return Apply(self, [x], [x.type.__class__(
6452            dtype=x.dtype,
6453            broadcastable=[False] * (x.ndim - 1))()])
6454
6455    def perform(self, node, inputs, outputs):
6456        (x,) = inputs
6457        (z,) = outputs
6458        z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
6459        if not self.view:
6460            z[0] = z[0].copy()
6461
6462    def grad(self, inputs, gout):
6463        (x,) = inputs
6464        (gz,) = gout
6465
6466        if x.ndim == 2:
6467            x = theano.tensor.zeros_like(x)
6468            xdiag = theano.tensor.AllocDiag(offset=self.offset)(gz)
6469            return [theano.tensor.set_subtensor(
6470                x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)]
6471        else:
6472            warnings.warn("gradient of theano.tensor.basic.ExtractDiag only"
6473                          "works for matrices.")
6474            return [grad_not_implemented(self, 0, x)]
6475
6476    def infer_shape(self, node, shapes):
6477        in_shape, = shapes
6478        dim1 = in_shape[self.axis1]
6479        dim2 = in_shape[self.axis2]
6480        out_shape = [d for i, d in enumerate(in_shape)
6481                     if i not in (self.axis1, self.axis2)]
6482        # The following logic is inspired by C code of PyArray_Diagonal().
6483        offset = self.offset
6484        if offset > 0:
6485            diag_size = clip(dim2 - offset, 0, dim1)
6486        elif offset < 0:
6487            diag_size = clip(dim1 + offset, 0, dim2)
6488        else:
6489            diag_size = minimum(dim1, dim2)
6490        out_shape.append(diag_size)
6491        return [tuple(out_shape)]
6492
6493    def __setstate__(self, state):
6494        self.__dict__.update(state)
6495        if self.view and not numpy_diagonal_return_view:
6496            warnings.warn("View will forced to False. ExtractDiag property view is "
6497                          "set to True but numpy version %s and prior versions of "
6498                          "numpy.diagonal() do not return a view. Update "
6499                          "numpy to use ExtractDiag(view=True)" %
6500                          np.version.version)
6501            self.view = False
6502
6503        if self.view:
6504            self.view_map = {0: [0]}
6505
6506        if "offset" not in state:
6507            self.offset = 0
6508        if "axis1" not in state:
6509            self.axis1 = 0
6510        if "axis2" not in state:
6511            self.axis2 = 1
6512
6513
6514def diagonal(a, offset=0, axis1=0, axis2=1):
6515    """
    A helper function for `theano.tensor.ExtractDiag`. It accepts a tensor with
6517    `ndim >= 2` as input. The name `diagonal` is just meant to keep it
6518    consistent with numpy.
6519
6520    Parameters
6521    ----------
    a : symbolic tensor
        Input tensor with `a.ndim >= 2`.
    offset : int
        Offset of the diagonal from the main diagonal. Can be positive or
        negative. Defaults to the main diagonal (0).
    axis1 : int
        Axis to be used as the first axis of the 2-D sub-arrays from which
        the diagonals are taken. Defaults to the first axis (0).
    axis2 : int
        Axis to be used as the second axis of the 2-D sub-arrays from which
        the diagonals are taken. Defaults to the second axis (1).
6527
6528    Returns
6529    -------
6530    tensor : symbolic tensor
6531
6532    """
6533    return ExtractDiag(offset, axis1, axis2)(a)
6534
6535
6536class AllocDiag(Op):
6537    """
6538    An op that copies a vector to the diagonal of an empty matrix. It does the
6539    inverse of ExtractDiag.
6540
    Usage: T.AllocDiag(offset)(x)

    `x` should be a tensor vector. The `offset` passed to the constructor
    indicates which diagonal the vector values go into. By default it is set
    to `0`, which corresponds to placing the values of x on the main diagonal
    of the returned matrix.
6547
6548    Parameters
6549    ----------
6550    axis1: Axis to be used as the first axis of the 2-D
6551        sub-arrays to which the diagonals will be allocated.
6552        Defaults to first axis (0).
6553
6554    axis2: Axis to be used as the second axis of the 2-D
6555        sub-arrays to which the diagonals will be allocated.
6556        Defaults to second axis (1).
6557
6558    offset: Offset of the diagonal from the main diagonal defined by `axis1`
6559        and `axis2`.
6560        Can be positive or negative.
6561        Defaults to main diagonal (0).
6562
6563    x: symbolic vector
        A tensor vector consisting of the diagonal values.
6565
6566    Returns
6567    -------
    tensor : symbolic tensor
        A tensor with the values of the passed vector placed on the
        specified diagonal.
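
    Examples
    --------
    A small usage sketch, assuming the op is applied through ``theano.function``:

    >>> import numpy as np
    >>> import theano
    >>> v = theano.tensor.vector('v')
    >>> f = theano.function([v], theano.tensor.AllocDiag(offset=1)(v))
    >>> a = np.array([1, 2], dtype=theano.config.floatX)
    >>> np.allclose(f(a), np.diag(a, k=1))
    True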
6570
6571    """
6572
6573    __props__ = ("offset", "axis1", "axis2")
6574
6575    def __init__(self, offset=0, axis1=0, axis2=1):
6576        self.offset = offset
6577        self.axis1 = axis1
6578        self.axis2 = axis2
6579
6580    def make_node(self, diag):
6581        diag = as_tensor_variable(diag)
6582        if diag.type.ndim < 1:
6583            raise ValueError('AllocDiag needs an input with 1 or more '
6584                             'dimensions', diag.type)
6585        return Apply(
6586            self, [diag],
6587            [diag.type.__class__(
6588                dtype=diag.dtype,
6589                broadcastable=[False] * (diag.ndim + 1))()]
6590        )
6591
6592    def perform(self, node, inputs, outputs):
6593        (x,) = inputs
6594        (z,) = outputs
6595
6596        axis1 = np.minimum(self.axis1, self.axis2)
6597        axis2 = np.maximum(self.axis1, self.axis2)
6598        offset = self.offset
6599
6600        # Create array with one extra dimension for resulting matrix
6601        result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
6602        result = np.zeros(result_shape, dtype=x.dtype)
6603
6604        # Create slice for diagonal in final 2 axes
6605        idxs = np.arange(x.shape[-1])
6606        diagonal_slice = ((len(result_shape) - 2) * [slice(None)] +
6607                          [idxs + np.maximum(0, -offset),
6608                           idxs + np.maximum(0, offset)])
6609
6610        # Fill in final 2 axes with x
6611        result[tuple(diagonal_slice)] = x
6612
6613        if len(x.shape) > 1:
6614            # Re-order axes so they correspond to diagonals at axis1, axis2
6615            axes = list(range(len(x.shape[:-1])))
6616            last_idx = axes[-1]
6617            axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
6618            axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
6619            result = result.transpose(axes)
6620
6621        z[0] = result
6622
6623    def grad(self, inputs, gout):
6624        (gz,) = gout
6625        return [diagonal(
6626            gz,
6627            offset=self.offset,
6628            axis1=self.axis1,
6629            axis2=self.axis2
6630        )]
6631
6632    def infer_shape(self, nodes, shapes):
6633        (x_shape,) = shapes
6634        axis1 = np.minimum(self.axis1, self.axis2)
6635        axis2 = np.maximum(self.axis1, self.axis2)
6636
6637        result_shape = list(x_shape[:-1])
6638        diag_shape = x_shape[-1] + abs(self.offset)
6639        result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
6640        result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
6641        return [tuple(result_shape)]
6642
6643    def __setstate__(self, state):
6644        if "view_map" in state:
6645            del state["view_map"]
6646
6647        self.__dict__.update(state)
6648
6649        if "offset" not in state:
6650            self.offset = 0
6651        if "axis1" not in state:
6652            self.axis1 = 0
6653        if "axis2" not in state:
6654            self.axis2 = 1
6655
6656
6657def diag(v, k=0):
6658    """
6659    A helper function for two ops: `theano.tensor.ExtractDiag` and
6660    `theano.tensor.AllocDiag`. The name `diag` is meant to keep it consistent
    with numpy. It accepts both tensor vectors and tensor matrices.
    If the passed tensor variable `v` has `v.ndim >= 2`, it builds an
    `ExtractDiag` instance and returns a vector whose entries equal the k-th
    diagonal of `v`; otherwise, if `v.ndim` is `1`, it builds an `AllocDiag`
    instance and returns a matrix with `v` on its k-th diagonal.
6666
6667    Parameters
6668    ----------
6669    v : symbolic tensor
6670    k : int
6671        offset
6672
6673    Returns
6674    -------
6675    tensor : symbolic tensor
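
    Examples
    --------
    A round-trip sketch, assuming ``theano.function`` is available:

    >>> import numpy as np
    >>> import theano
    >>> v = theano.tensor.vector('v')
    >>> # vector -> matrix (AllocDiag), then matrix -> vector (ExtractDiag)
    >>> f = theano.function([v], theano.tensor.diag(theano.tensor.diag(v)))
    >>> a = np.array([1, 2, 3], dtype=theano.config.floatX)
    >>> np.allclose(f(a), a)
    True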
6676
6677    """
6678
6679    if v.ndim == 1:
6680        return AllocDiag(k)(v)
6681    elif v.ndim >= 2:
6682        return diagonal(v, offset=k)
6683    else:
6684        raise ValueError("Input must has v.ndim >= 1.")
6685
6686
6687def stacklists(arg):
6688    """
6689    Recursively stack lists of tensors to maintain similar structure.
6690
6691    This function can create a tensor from a shaped list of scalars:
6692
6693    Examples
6694    --------
6695    >>> from theano.tensor import stacklists, scalars, matrices
6696    >>> from theano import function
6697    >>> a, b, c, d = scalars('abcd')
6698    >>> X = stacklists([[a, b], [c, d]])
6699    >>> f = function([a, b, c, d], X)
6700    >>> f(1, 2, 3, 4)
6701    array([[ 1.,  2.],
6702           [ 3.,  4.]], dtype=float32)
6703
6704    We can also stack arbitrarily shaped tensors. Here we stack matrices into
6705    a 2 by 2 grid:
6706
6707    >>> from numpy import ones
6708    >>> a, b, c, d = matrices('abcd')
6709    >>> X = stacklists([[a, b], [c, d]])
6710    >>> f = function([a, b, c, d], X)
6711    >>> x = ones((4, 4), 'float32')
6712    >>> f(x, x, x, x).shape
6713    (2, 2, 4, 4)
6714
6715    """
6716    if isinstance(arg, (tuple, list)):
6717        return stack(list(map(stacklists, arg)))
6718    else:
6719        return arg
6720
6721
6722def ptp(a, axis=None):
6723    """
6724    Range of values (maximum - minimum) along an axis.
6725
6726    The name of the function comes from the acronym for peak to peak.
6727
6728    Parameters
6729    ----------
6730    a
6731        Input tensor.
6732    axis
6733        Axis along which to find the peaks. By default, flatten the array.
6734
6735    Returns
6736    -------
6737    array
6738        A new array holding the result.
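
    Examples
    --------
    A small sketch, assuming ``theano.function`` is available:

    >>> import numpy as np
    >>> import theano
    >>> x = theano.tensor.matrix('x')
    >>> f = theano.function([x], theano.tensor.ptp(x, axis=1))
    >>> a = np.array([[4, 9, 2], [10, 2, 7]], dtype=theano.config.floatX)
    >>> np.allclose(f(a), np.ptp(a, axis=1))
    True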
6739
6740    """
6741
6742    a = as_tensor_variable(a)
6743
6744    out = max(a, axis) - min(a, axis)
6745
6746    return out
6747
6748
def power(x, y):
    """Elementwise power, equivalent to `x ** y`."""
    return x ** y
6751
6752
6753def swapaxes(y, axis1, axis2):
6754    "swap axes of inputted tensor"
6755    y = as_tensor_variable(y)
6756    ndim = y.ndim
6757    li = list(range(0, ndim))
6758    li[axis1], li[axis2] = li[axis2], li[axis1]
6759    return y.dimshuffle(li)
6760
6761
6762def choose(a, choices, out=None, mode='raise'):
6763    """
6764    Construct an array from an index array and a set of arrays to choose from.
6765
6766    First of all, if confused or uncertain, definitely look at the Examples -
6767    in its full generality, this function is less simple than it might seem
6768    from the following code description (below ndi = numpy.lib.index_tricks):
6769
6770    np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)]).
6771
6772    But this omits some subtleties. Here is a fully general summary:
6773
6774    Given an ``index`` array (a) of integers and a sequence of n arrays
6775    (choices), a and each choice array are first broadcast, as necessary,
6776    to arrays of a common shape; calling these Ba and
6777    Bchoices[i], i = 0,...,n-1 we have that, necessarily,
6778    Ba.shape == Bchoices[i].shape for each i.
6779    Then, a new array with shape Ba.shape is created as follows:
6780
6781    - if mode=raise (the default), then, first of all, each element of a
6782      (and thus Ba) must be in the range [0, n-1]; now, suppose that
6783      i (in that range) is the value at the (j0, j1, ..., jm) position in Ba -
6784      then the value at the same position in the new array is the value in
6785      Bchoices[i] at that same position;
6786
6787    - if mode=wrap, values in a (and thus Ba) may be any (signed) integer;
6788      modular arithmetic is used to map integers outside the range [0, n-1]
6789      back into that range; and then the new array is constructed as above;
6790
6791    - if mode=clip, values in a (and thus Ba) may be any (signed) integer;
6792      negative integers are mapped to 0; values greater than n-1 are mapped
6793      to n-1; and then the new array is constructed as above.
6794
6795    Parameters
6796    ----------
6797    a : int array
6798        This array must contain integers in [0, n-1], where n is the number of
6799        choices, unless mode=wrap or mode=clip, in which cases any integers
6800        are permissible.
6801    choices : sequence of arrays
6802        Choice arrays. a and all of the choices must be broadcastable to
6803        the same shape. If choices is itself an array (not recommended),
6804        then its outermost dimension (i.e., the one corresponding to
6805        choices.shape[0]) is taken as defining the ``sequence``.
    out : array, optional
        Accepted only to keep the same signature as NumPy; Theano's
        implementation does not support it, so it must be left as None.
6809    mode : {``raise`` (default), ``wrap``, ``clip``}, optional
6810        Specifies how indices outside [0, n-1] will be treated:
6811        ``raise`` : an exception is raised
6812        ``wrap`` : value becomes value mod n
6813        ``clip`` : values < 0 are mapped to 0, values > n-1 are mapped to n-1
6814
6815    Returns
6816    -------
    merged_array : array
6818        The merged result.
6819
6820    Raises
6821    ------
    ValueError: shape mismatch
6823        If a and each choice array are not all broadcastable to the same shape.
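
    Examples
    --------
    A small sketch mirroring NumPy's behaviour, assuming ``theano.function``
    is available:

    >>> import numpy as np
    >>> import theano
    >>> a = theano.tensor.lvector('a')
    >>> c = theano.tensor.matrix('c')
    >>> f = theano.function([a, c], theano.tensor.choose(a, c))
    >>> choices = np.array([[0, 1, 2, 3], [10, 11, 12, 13],
    ...                     [20, 21, 22, 23], [30, 31, 32, 33]],
    ...                    dtype=theano.config.floatX)
    >>> np.allclose(f([2, 3, 1, 0], choices),
    ...             np.choose([2, 3, 1, 0], choices))
    True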
6824
6825    """
    # This is done to keep the same function signature as NumPy.
6827    assert out is None
6828    return Choose(mode)(a, choices)
6829
6830
6831class Choose(Op):
6832    __props__ = ('mode',)
6833
6834    def __init__(self, mode):
6835        assert mode in ("raise", "wrap", "clip")
6836        self.mode = mode
6837
6838    def infer_shape(self, node, shapes):
6839
6840        if isinstance(node.inputs[1], TensorVariable):
6841            # We have padded node.inputs[0] to the right number of
6842            # dimensions for the output
6843            l = []
6844            for sh1, sh2, b1 in zip(shapes[0],
6845                                    shapes[1][1:],
6846                                    node.inputs[0].broadcastable):
6847                if b1:
6848                    l.append(sh2)
6849                else:
6850                    l.append(sh1)
6851            return [tuple(l)]
6852        else:
6853            import theano.typed_list
6854            assert isinstance(node.inputs[1],
6855                              theano.typed_list.TypedListVariable)
6856            raise ShapeError("Case not implemented")
6857            shape = shapes[0]
6858            for i in xrange(len(shapes[0]) - 1):
6859                shape[i] = shapes[1][i]
6860            return [(shape)]
6861
6862    def make_node(self, a, choices):
6863        # Import here as it isn't imported by default and we can't
6864        # import at the top as it would cause circular import.
6865        import theano.typed_list
6866        a = as_tensor_variable(a)
6867        if a.dtype not in theano.tensor.discrete_dtypes:
6868            raise TypeError(
6869                'choose first argument must have an [u]int* dtype. Got %s.'
6870                % a.dtype)
6871
6872        if isinstance(choices, (tuple, list,
6873                                theano.typed_list.TypedListVariable)):
6874            choice = theano.typed_list.make_list(choices)
6875            choice_ndim = choice.ttype.ndim
6876            choice_bcast = choice.ttype.broadcastable
6877        else:
6878            choice = as_tensor_variable(choices)
6879            choice_ndim = choice.ndim - 1
6880            choice_bcast = choice.broadcastable[1:]
6881        out_ndim = np.max([a.ndim, choice_ndim])
6882
6883        # Make explicit all added broadcastable dimensions.
6884        a = shape_padleft(a, out_ndim - a.ndim)
6885        if len(choice_bcast) != out_ndim:
6886            if isinstance(choice.type, TensorType):
6887                choice = choice.dimshuffle(0,
6888                                           *(('x',) * (out_ndim - choice_ndim) +
6889                                             tuple(range(1, choice.ndim))))
6890                choice_ndim = choice.ndim - 1
6891                choice_bcast = choice.broadcastable[1:]
6892            else:
6893                raise NotImplementedError(
6894                    "We currently didn't implemented that case. "
6895                    "To make it work, explicitly add dimensions "
6896                    "of size one for dimensions that will be broadcasted")
6897
6898        bcast = [False] * out_ndim
6899        for idx, (b1, b2) in enumerate(
6900            zip(a.broadcastable,
6901                (True,) * (out_ndim - choice_ndim) + choice_bcast)):
6902            if b1 and b2:
6903                bcast[idx] = True
6904        o = TensorType(choice.dtype, bcast)
6905        return Apply(self, [a, choice], [o()])
6906
6907    def perform(self, node, inputs, outputs):
6908        (z,) = outputs
6909        a = inputs[0]
6910        choice = inputs[1]
6911        # TODO reuse out?
6912        z[0] = np.choose(a, choice, mode=self.mode)
6913
6914
6915class AllocEmpty(gof.Op):
6916    """Implement Alloc on the cpu, but without initializing memory."""
6917
6918    __props__ = ("dtype", )
6919    params_type = ParamsType(typecode=int32_t)
6920
6921    # specify the type of the data
6922    def __init__(self, dtype):
6923        assert isinstance(dtype, str), dtype
6924        self.dtype = dtype.lower()
6925
6926    @property
6927    def typecode(self):
6928        return np.dtype(self.dtype).num
6929
6930    def make_node(self, *shape):
6931        shape, bcast = alloc_validate_shape(shape)
6932        otype = TensorType(dtype=self.dtype, broadcastable=bcast)
6933        output = otype()
6934
6935        output.tag.values_eq_approx = values_eq_approx_always_true
        # The output can contain nan/inf.  output.type is a new
6937        # instance, so we can do this only for that variable.
6938        output.type.filter_checks_isfinite = False
6939
6940        # We can't reuse filter_checks_isfinite as by default it is
6941        # False and it is set to true only in DebugMode.
6942        # We can't set it in the type as other make_node can reuse the type.
6943        # We can't set it in the variable as it isn't copied when we copy
        # the variable. So we set it in the tag.
6945        output.tag.nan_guard_mode_check = False
6946        return Apply(self, shape, [output])
6947
6948    def debug_perform(self, node, inputs, out_, params):
6949        self.perform(node, inputs, out_, params)
6950        out_[0][0].fill(-123456789)
6951
6952    def perform(self, node, inputs, out_, params):
6953        out, = out_
6954        sh = tuple([int(i) for i in inputs])
6955        if out[0] is None or out[0].shape != sh:
6956            out[0] = np.empty(sh, dtype=self.dtype)
6957
6958    def c_code(self, node, name, inputs, out_, sub):
6959        out, = out_
6960        fail = sub['fail']
6961        shps = inputs
6962        nd = len(shps)
6963        params = sub['params']
6964        str = "npy_intp dims[%(nd)s];\n" % locals()
6965        for idx, sh in enumerate(shps):
6966            str += "dims[%(idx)s] =" \
6967                   "((npy_intp)((dtype_%(sh)s*)" \
6968                   " PyArray_DATA(%(sh)s))[0]);\n" % locals()
6969
6970        # Validate that the output storage exists
6971        str += "if(%(out)s==NULL\n" % locals()
6972        for idx, sh in enumerate(shps):
6973            str += "||PyArray_DIMS(%(out)s)[%(idx)s]!=dims[%(idx)s]" % locals()
6974
6975        str += """){
6976            /* Reference received to invalid output variable.
6977            Decrease received reference's ref count and allocate new
6978            output variable */
6979            Py_XDECREF(%(out)s);
6980            %(out)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s,
6981                                                    dims,
6982                                                    %(params)s->typecode,
6983                                                    0);
6984            if (!%(out)s)
6985            {
6986                PyErr_SetString(PyExc_MemoryError, "alloc failed");
6987                %(fail)s;
6988            }
6989        }
6990        """ % locals()
6991        return str
6992
6993    def infer_shape(self, node, input_shapes):
6994        return [node.inputs]
6995
6996    def c_code_cache_version(self):
6997        return (4,)
6998
6999    def do_constant_folding(self, node):
7000        return False
7001
7002    def connection_pattern(self, node):
7003        return [[False] for i in node.inputs]
7004
7005    def grad(self, inputs, grads):
7006        return [DisconnectedType()() for i in inputs]
7007
7008    def R_op(self, inputs, eval_points):
7009        return [zeros(inputs, self.dtype)]
7010