1from __future__ import absolute_import, print_function, division
2import sys
3from textwrap import dedent
4import warnings
5import logging
6
7import numpy as np
8from six import integer_types
9from six.moves import xrange
10
11import theano
12from theano.compat import izip
13from theano.gradient import DisconnectedType
14from theano import gof
15from theano.gof import Apply, hashtype, Op, Type, MethodNotDefined, ParamsType
16from theano.printing import pprint
17from theano import scalar as scal
18from theano.tensor.basic import alloc
19from theano.tensor.basic import (addbroadcast, clip, get_scalar_constant_value,
20                                 TensorType, NotScalarConstantError)
21from theano.tensor.elemwise import DimShuffle
22from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice
23from theano import config
24from theano.compat import Iterable
25
26from .inc_code import inc_code
27
28_logger = logging.getLogger("theano.tensor.subtensor")
29
30# Do a lazy import of the sparse module
31sparse_module_ref = None
32
33
34class AdvancedIndexingError(TypeError):
35    """
36    Raised when Subtensor is asked to perform advanced indexing.
37
38    """
39    pass
40
41
42class AdvancedBooleanIndexingError(TypeError):
43    """
44    Raised when Subtensor is asked to perform advanced indexing with boolean masks.
45
46    """
47    pass
48
49
50##########
51# Helpful functions to deal with Subtensor and IncSubtensor
52##########
53
54def make_constant(args):
55    """
56    Convert python litterals to theano constants in subtensor arguments.
57
58    """
59    def conv(a):
60            if a is None:
61                return a
62            elif isinstance(a, slice):
63                return slice(conv(a.start),
64                             conv(a.stop),
65                             conv(a.step))
66            elif isinstance(a, (integer_types, np.integer)):
67                return scal.ScalarConstant(scal.int64, a)
68            else:
69                return a
70    return tuple(map(conv, args))
71
72
73def get_idx_list(inputs, idx_list, get_count=False):
74    """
75    Given a list of inputs to the subtensor and its idx_list reorders
76    the inputs according to the idx list to get the right values.
77
78    If get_counts=True, instead returns the number of inputs consumed
79    during this process.
80
81    """
82
83    # The number of indices
84    n = len(inputs) - 1
85
86    # The subtensor (or idx_list) does not depend on the inputs.
87    if n == 0:
88        return tuple(idx_list)
89    indices = list(reversed(list(inputs[1:])))
90
91    # General case
92    def convert(entry):
93        if isinstance(entry, gof.Type):
94            return indices.pop()
95        elif isinstance(entry, slice):
96            return slice(convert(entry.start),
97                         convert(entry.stop),
98                         convert(entry.step))
99        else:
100            return entry
101    cdata = tuple(map(convert, idx_list))
102    if get_count:
103        return n - len(indices)
104    else:
105        return cdata
106
107
108def get_canonical_form_slice(theslice, length):
109    """
110    Given a slice [start:stop:step] transform it into a canonical form
111    that respects the conventions imposed by python and numpy.
112
113    In a canonical form a slice is represented by a canonical form slice,
114    in which 0 <= start <= stop <= length and step > 0, and a flag which says
115    if the resulting set of numbers needs to be reversed or not.
116
117    """
118    from theano.tensor import switch, lt, ge, sgn
119    if isinstance(theslice, slice):
120
121        def analyze(x):
122            try:
123                x_constant = get_scalar_constant_value(x)
124                is_constant = True
125            except theano.tensor.NotScalarConstantError:
126                x_constant = theano.tensor.extract_constant(x)
127                is_constant = False
128            return x_constant, is_constant
129
130        start, is_start_constant = analyze(theslice.start)
131        stop, is_stop_constant = analyze(theslice.stop)
132        step, is_step_constant = analyze(theslice.step)
133        length, is_length_constant = analyze(length)
134
135        if step is None:
136            step = 1
137            is_step_constant = True
138
139        # First handle the easier and common case where `step` is 1 and
140        # either `start` or `stop` is a range boundary. More specializations
141        # could be added later. This makes the resulting graph smaller than
142        # in the generic case below.
143        if step == 1:
144            is_start_0 = (
145                start is None or start == 0 or
146                (is_start_constant and is_length_constant and
147                 start < 0 and start + length <= 0))
148            is_stop_length = (
149                stop is None or stop in [length, sys.maxsize] or
150                (is_stop_constant and is_length_constant and
151                 stop >= length))
152            if is_start_0:
153                # 0:stop:1
154                if is_stop_length:
155                    # Full slice.
156                    return slice(0, length, 1), 1
157                if is_stop_constant and stop >= 0:
158                    return (slice(0, switch(lt(stop, length), stop, length),
159                                  1), 1)
160                stop_plus_len = stop + length
161                stop = switch(
162                    lt(stop, 0),
163                    # stop < 0
164                    switch(
165                        lt(stop_plus_len, 0),
166                        # stop + len < 0
167                        0,
168                        # stop + len >= 0
169                        stop_plus_len),
170                    # stop >= 0: use min(stop, length)
171                    switch(lt(stop, length), stop, length))
172                return slice(0, stop, 1), 1
173            elif is_stop_length:
174                # start:length:1
175                if is_start_constant and start >= 0:
176                    return slice(switch(lt(start, length), start, length),
177                                 length, 1), 1
178                start_plus_len = start + length
179                start = switch(
180                    lt(start, 0),
181                    # start < 0
182                    switch(
183                        lt(start_plus_len, 0),
184                        # start + len < 0
185                        0,
186                        # start + len >= 0
187                        start_plus_len),
188                    # start >= 0: use min(start, length)
189                    switch(lt(start, length), start, length))
190                return slice(start, length, 1), 1
191
192        # This is the generic case.
193
194        if is_step_constant:
195            # When we know the sign of `step`, the graph can be made simpler.
196            assert step != 0
197            if step > 0:
198                def switch_neg_step(a, b):
199                    return b
200                abs_step = step
201                sgn_step = 1
202            else:
203                def switch_neg_step(a, b):
204                    return a
205                abs_step = -step
206                sgn_step = -1
207        else:
208            is_step_neg = lt(step, 0)
209
210            def switch_neg_step(a, b):
211                return switch(is_step_neg, a, b)
212            abs_step = abs(step)
213            sgn_step = sgn(step)
214
215        defstart = switch_neg_step(length - 1, 0)
216        defstop = switch_neg_step(-1, length)
217        if start is None:
218            start = defstart
219        else:
220            start = switch(lt(start, 0), start + length, start)
221            start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
222            start = switch(ge(start, length),
223                           switch_neg_step(length - 1, length),
224                           start)
225        if stop is None or stop == sys.maxsize:
226            # The special "maxsize" case is probably not needed here,
227            # as slices containing maxsize are not generated by
228            # __getslice__ anymore.
229            stop = defstop
230        else:
231            stop = switch(lt(stop, 0), stop + length, stop)
232            stop = switch(lt(stop, 0), -1, stop)
233            stop = switch(ge(stop, length), length, stop)
234
235        nw_stop = switch_neg_step(start + 1, stop)
236        slice_len = (start - stop - 1) // abs_step + 1
237        slice_len = switch(lt(slice_len, 0), 0, slice_len)
238        neg_start = nw_stop - (slice_len - 1) * abs_step - 1
239        neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
240        nw_start = switch_neg_step(neg_start, start)
241        nw_start = switch(lt(nw_start, 0), 0, nw_start)
242        nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
243        # Ensure start <= stop.
244        nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop)
245
246        nw_step = abs_step
247        if step != 1:
248            reverse = sgn_step
249            return slice(nw_start, nw_stop, nw_step), reverse
250        else:
251            return slice(nw_start, nw_stop, nw_step), 1
252    else:
253        value = theano.tensor.extract_constant(theslice)
254        value = switch(lt(value, 0), (value + length), value)
255
256        return value, 1
257
258
259class Subtensor(Op):
260    """
261    Return a subtensor view.
262
263    The inputs array is the tensor x, followed by scalar integer types.
264    TODO: WRITEME: how are the scalar integer variables formatted?
265
266    This class uses a relatively complex internal representation of the inputs
267    to remember how the input tensor x should be sliced.
268
269    idx_list: instance variable TODO: WRITEME: is this a list or a tuple?
270                                        (old docstring gives two conflicting
271                                        descriptions)
272              elements are either integers, theano scalar types, or slices.
273              one element per "explicitly named dimension"
274                TODO: WRITEME: what is an "explicitly named dimension" ?
275
276              if integer:
277                  indexes into the inputs array
278              if slice:
279                  start/stop/step members of each slice are integer indices
280                  into the inputs array or None
281                  integer indices be actual integers or theano scalar types
282
283    Note that the idx_list defines the Op, so two Subtensor instances are
284    considered to be different Ops if they have different idx_list fields.
285    This means that the entries in it are theano Types, not theano Variables.
286
287    @todo: add support for advanced tensor indexing (in Subtensor_dx too).
288
289    """
290    e_subslice = 'nested slicing is not supported'
291    e_indextype = "Invalid index type or slice for Subtensor"
292    debug = 0
293    check_input = False
294    view_map = {0: [0]}
295    _f16_ok = True
296    __props__ = ("idx_list",)
297
298    @staticmethod
299    def collapse(idxs, cond):
300        """
301        Parameters
302        ----------
303        idxs : a list of indices or slices.
304        cond : a callable that returns a bool
305
306        Returns
307        -------
308        list
309            idxs, with the slices flattened out into a list.
310            If cond is true for an entry, does not flatten it.
311
312        """
313        ret = []
314
315        def helper(entry):
316            if cond(entry):
317                ret.append(entry)
318            elif isinstance(entry, slice):
319                helper(entry.start)
320                helper(entry.stop)
321                helper(entry.step)
322
323        for idx in idxs:
324            helper(idx)
325
326        return ret
327
328    @staticmethod
329    def convert(entry, slice_ok=True):
330        """
331        Change references to Variables into references to Types.
332
333        The "idx_list" field is unique to each Subtensor instance.
334        It is not unique to each Apply node, so it should not refer to
335        specific Variables.
336        TODO: WRITEME: This method also accepts "entry" already being a Type;
337            when would that happen?
338
339        """
340        invalid_scal_types = [scal.float64, scal.float32, scal.float16]
341        scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
342        tensor_types = [theano.tensor.lscalar, theano.tensor.iscalar,
343                        theano.tensor.wscalar, theano.tensor.bscalar]
344        invalid_tensor_types = [theano.tensor.fscalar, theano.tensor.dscalar,
345                                theano.tensor.cscalar, theano.tensor.zscalar]
346
347        if (isinstance(entry, (np.ndarray, theano.tensor.Variable)) and
348                hasattr(entry, 'dtype') and entry.dtype == 'bool'):
349            raise AdvancedBooleanIndexingError(Subtensor.e_indextype, entry)
350
351        if (isinstance(entry, gof.Variable) and
352            (entry.type in invalid_scal_types or
353             entry.type in invalid_tensor_types)):
354            raise TypeError("Expected an integer")
355
356        if isinstance(entry, gof.Variable) and entry.type in scal_types:
357            return entry.type
358        elif isinstance(entry, gof.Type) and entry in scal_types:
359            return entry
360
361        if (isinstance(entry, gof.Variable) and
362                entry.type in tensor_types and
363                np.all(entry.type.broadcastable)):
364            return scal.get_scalar_type(entry.type.dtype)
365        elif (isinstance(entry, gof.Type) and
366              entry in tensor_types and
367              np.all(entry.broadcastable)):
368            return scal.get_scalar_type(entry.dtype)
369        elif slice_ok and isinstance(entry, slice):
370            a = entry.start
371            b = entry.stop
372            c = entry.step
373
374            if a is not None:
375                slice_a = Subtensor.convert(a, False)
376            else:
377                slice_a = None
378
379            if b is not None and b != sys.maxsize:
380                # The special "maxsize" case is probably not needed here,
381                # as slices containing maxsize are not generated by
382                # __getslice__ anymore.
383                slice_b = Subtensor.convert(b, False)
384            else:
385                slice_b = None
386
387            if c is not None:
388                slice_c = Subtensor.convert(c, False)
389            else:
390                slice_c = None
391
392            return slice(slice_a, slice_b, slice_c)
393        elif isinstance(entry, (integer_types, np.integer)):
394            # Disallow the use of python scalars in idx_list
395            raise TypeError("Python scalar in idx_list."
396                            "Please report this error to theano-dev.")
397        else:
398            raise AdvancedIndexingError(Subtensor.e_indextype, entry)
399
400    def get_constant_idx(self, inputs, allow_partial=False,
401                         only_process_constants=False, elemwise=True):
402        """
403        Return the idx_list with constant inputs replaced by their
404        python scalar equivalent.
405        May raise `theano.tensor.NotScalarConstantError` if the idx contains
406        non-constant entries.
407
408        If allow_partial is True, then entries that are not constant will
409        stay as their input variable rather than raising an exception.
410
411        None entries are always left as-is.
412
413        Parameters
414        ----------
415        only_process_constants
416            If True, we only attempt to obtain the value of an index/slice if
417            it's directly constant and don't try to dig through dimshuffles,
418            fills, allocs, and other to figure out its value.
419
420        Examples
421        --------
422        Example usage where v, a are appropriately typed theano variables :
423        >>> b = a[v, 1:3]
424        >>> b.owner.op.idx_list
425        (Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
426        >>> b.owner.op.get_constant_idx(b.owner.inputs, allow_partial=True)
427        [v, slice(1, 3, None)]
428        >>> b.owner.op.get_constant_idx(b.owner.inputs)
429        NotScalarConstantError: v
430
431        """
432        real_idx = get_idx_list(inputs, self.idx_list)
433
434        def conv(val):
435            if val is None:
436                return None
437            elif isinstance(val, slice):
438                return slice(conv(val.start),
439                             conv(val.stop),
440                             conv(val.step))
441            else:
442                try:
443                    return get_scalar_constant_value(
444                        val,
445                        only_process_constants=only_process_constants,
446                        elemwise=elemwise)
447                except theano.tensor.NotScalarConstantError:
448                    if allow_partial:
449                        return val
450                    else:
451                        raise
452
453        return list(map(conv, real_idx))
454
455    def __init__(self, idx_list):
456        self.idx_list = tuple(map(self.convert, idx_list))
457
458    @staticmethod
459    def my_as_scalar(a):
460        # Since scal.as_scalar does not know about tensor types (it would
461        # create a circular import) , this method converts either a
462        # TensorVariable or a ScalarVariable to a scalar.
463        if isinstance(a, gof.Variable) and isinstance(a.type, TensorType):
464            return theano.tensor.scalar_from_tensor(a)
465        else:
466            return scal.as_scalar(a)
467
468    def make_node(self, x, *inputs):
469        """
470        Parameters
471        ----------
472        x
473            The tensor to take a subtensor of.
474        inputs
475            A list of theano Scalars.
476
477        """
478        x = theano.tensor.as_tensor_variable(x)
479        inputs = tuple(self.my_as_scalar(a) for a in inputs)
480
481        idx_list = list(self.idx_list)
482        if len(idx_list) > x.type.ndim:
483            raise IndexError('too many indices for array')
484
485        input_types = Subtensor.collapse(idx_list,
486                                         lambda entry: isinstance(entry,
487                                                                  gof.Type))
488        if len(inputs) != len(input_types):
489            raise IndexError(
490                "Not enough inputs to fill in the Subtensor template.",
491                inputs, idx_list)
492        for input, expected_type in izip(inputs, input_types):
493            if input.type != expected_type:
494                raise TypeError(
495                    "Wrong type for Subtensor template. Expected %s, got %s."
496                    % (input.type, expected_type))
497
498        # infer the broadcasting pattern
499        padded = (self.get_constant_idx((None,) + inputs, allow_partial=True) +
500                  [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
501        broadcastable = []
502        for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
503            if isinstance(p, slice):
504                if bc:
505                    start = p.start
506                    try:
507                        start = get_scalar_constant_value(start)
508                    except NotScalarConstantError:
509                        pass
510                    if start is None or start == 0:
511                        start = p.start
512                        if start is None:
513                            start = 0
514                        if (p.stop is None or
515                            (isinstance(p.stop, (integer_types, np.integer,
516                                                 np.ndarray)) and
517                             p.stop > start)):
518                            broadcastable.append(True)
519                            continue
520
521                broadcastable.append(False)
522
523        return gof.Apply(self,
524                         (x, ) + inputs,
525                         [theano.tensor.tensor(dtype=x.type.dtype,
526                                               broadcastable=broadcastable)])
527
528    def perform(self, node, inputs, out_):
529        out, = out_
530        x = inputs[0]
531
532        cdata = get_idx_list(inputs, self.idx_list)
533        if len(cdata) == 1:
534            cdata = cdata[0]
535
536        out[0] = np.asarray(x.__getitem__(cdata))
537
538    def infer_shape(self, node, shapes):
539        xshp = shapes[0]
540        assert len(xshp) == node.inputs[0].ndim
541        outshp = []
542        actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
543        padded = (actual_idx_list +
544                  [slice(None, None, None)] * (len(xshp) - len(self.idx_list)))
545        i = 0
546        for idx, xl in izip(padded, xshp):
547            if isinstance(idx, slice):
548                # If it is the default (None, None, None) slice, or a variant,
549                # the shape will be xl
550                if ((idx.start in [None, 0]) and
551                        (idx.stop in [None, sys.maxsize]) and
552                        (idx.step is None or idx.step == 1)):
553                    outshp.append(xl)
554                else:
555                    cnf = get_canonical_form_slice(idx, xl)[0]
556                    if cnf.step == 1:
557                        length = cnf.stop - cnf.start
558                    else:
559                        length = (cnf.stop - cnf.start - 1) // cnf.step + 1
560                    outshp.append(length)
561                i += 1
562            else:
563                # That dimension is dropped
564                pass
565        assert i == node.outputs[0].ndim
566        assert len(outshp) == node.outputs[0].ndim
567        return [outshp]
568
569    def grad(self, inputs, grads):
570        gz, = grads
571        x = inputs[0]
572        rest = inputs[1:]
573        if x.dtype in theano.tensor.discrete_dtypes:
574            first = x.zeros_like().astype(theano.config.floatX)
575        else:
576            # For best optimization, we let this as an inc.
577            # This allow the opt local_IncSubtensor_serialize to apply first.
578            # We have an optimization that will convert this to a
579            # set subtensor here at:
580            # theano/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
581            first = IncSubtensor(self.idx_list)(x.zeros_like(),
582                                                gz, *rest)
583        return ([first] + [DisconnectedType()()] * len(rest))
584
585    def connection_pattern(self, node):
586
587        rval = [[True]]
588
589        for ipt in node.inputs[1:]:
590            rval.append([False])
591
592        return rval
593
594    def __hash__(self):
595        # TODO: optimize by cache this hash value
596        msg = []
597        for entry in self.idx_list:
598            if isinstance(entry, slice):
599                msg += [(entry.start, entry.stop, entry.step)]
600            else:
601                msg += [entry]
602
603        idx_list = tuple(msg)
604        # backport
605        # idx_list = tuple((entry.start, entry.stop, entry.step)
606        #                 if isinstance(entry, slice)
607        #                 else entry
608        #                 for entry in self.idx_list)
609        return hash(idx_list)
610
611    @staticmethod
612    def str_from_slice(entry):
613        msg = []
614        for x in [entry.start, entry.stop, entry.step]:
615            if x is None:
616                msg.append("")
617            else:
618                msg.append(str(x))
619        return ":".join(msg)
620
621    def __str__(self):
622        indices = []
623        for entry in self.idx_list:
624            if isinstance(entry, slice):
625                indices.append(self.str_from_slice(entry))
626            else:
627                indices.append(str(entry))
628        return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
629
630    @staticmethod
631    def default_helper_c_code_args():
632        """
633        Returns a dictionary of default arguments to helper_c_code.
634
635        """
636
637        return {"c_prefix": "PyArray",
638                "strides_mul": 1}
639
640    @staticmethod
641    def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
642                      c_prefix=None,
643                      strides_mul=None):
644        """
645        The parameters c_prefix are there to allow reusing this
646        function on PyArray and GpuArray object.
647
648        This fct take as input the x.
649
650        """
651
652        default_args = Subtensor.default_helper_c_code_args()
653
654        if strides_mul is None:
655            strides_mul = default_args['strides_mul']
656
657        if c_prefix is None:
658            c_prefix = default_args['c_prefix']
659
660        #
661        # two arrays are created in C code:
662        # is_slice: len == ndim, 0 means int, 1 means slice
663        # subtensor_spec: len = n_ints + 3 * n_slices
664        #
665        fail = sub['fail']
666        init_cmds = []  # initialization for subtensor_spec
667        is_slice = []
668        # TODO: change that, it might lead to unexpected results,
669        # see assembla-#767
670        NONE_CODE = sys.maxsize - 1
671
672        pos = [0, 1]  # annoying version of global variable for init_entry
673
674        def inc_spec_pos(amt):
675            pos[0] += amt
676
677        def inc_input_pos(amt):
678            pos[1] += amt
679
680        def spec_pos():
681            return pos[0]
682
683        def input_pos():
684            return pos[1]
685
686        def init_entry(entry, depth=0):
687            if isinstance(entry, (np.integer, integer_types)):
688                init_cmds.append(
689                    "subtensor_spec[%i] = %i;" % (spec_pos(),
690                                                  entry))
691                inc_spec_pos(1)
692                if depth == 0:
693                    is_slice.append(0)
694            elif isinstance(entry, Type):
695                init_cmds.append(
696                    "subtensor_spec[%i] = %s;" % (spec_pos(),
697                                                  inputs[input_pos()]))
698                inc_spec_pos(1)
699                inc_input_pos(1)
700                if depth == 0:
701                    is_slice.append(0)
702            elif entry is None:
703                init_cmds.append(
704                    "subtensor_spec[%i] = %i;" % (spec_pos(),
705                                                  NONE_CODE))
706                inc_spec_pos(1)
707                if depth == 0:
708                    is_slice.append(0)
709            elif depth == 0 and isinstance(entry, slice):
710                init_entry(entry.start, depth + 1)
711                init_entry(entry.stop, depth + 1)
712                init_entry(entry.step, depth + 1)
713                is_slice.append(1)
714            else:
715                assert 0, entry
716
717        for entry in idx_list:
718            init_entry(entry)
719        # make sure we used all inputs
720        assert input_pos() == len(inputs), input_pos()
721        assert len(is_slice) <= node.inputs[0].ndim, node.inputs[0].ndim
722
723        len_is_slice = len(is_slice)
724
725        len_subtensor_spec = spec_pos()
726        subensor_spec = "npy_intp subtensor_spec[%(len_subtensor_spec)s];" % locals()
727        if len_subtensor_spec == 0:
728            subensor_spec = "npy_intp * subtensor_spec = NULL;"
729
730        if is_slice:
731            is_slice_init = "int is_slice[] = {" + ",".join([str(s) for s in
732                                                             is_slice]) + "};"
733        else:
734            is_slice_init = "int* is_slice = NULL;"
735        subtensor_init = "\n".join(init_cmds)
736
737        x, = inputs[:1]
738        z, = outputs
739
740        if view_ndim:
741            rval = """
742        // Argument of the view
743        npy_intp xview_dims[%(view_ndim)s];
744        npy_intp xview_strides[%(view_ndim)s];
745
746        """ % locals()
747        else:
748            rval = """
749        // Argument of the view
750        npy_intp* xview_dims = NULL;
751        npy_intp* xview_strides = NULL;
752
753        """
754
755        rval += """
756        // One more argument of the view
757        npy_intp xview_offset = 0;
758
759        // The subtensor is created by iterating over the dimensions
760        // and updating stride, shape, and data pointers
761
762        %(is_slice_init)s
763        %(subensor_spec)s
764        %(subtensor_init)s;
765        int spec_pos = 0; //position in subtensor_spec
766        int inner_ii = 0; // the current dimension of zview
767        int outer_ii = 0; // current dimension of z
768
769
770        for (; outer_ii < %(len_is_slice)s; ++outer_ii)
771        {
772            if (is_slice[outer_ii])
773            {
774                npy_intp length = %(c_prefix)s_DIMS(%(x)s)[outer_ii];
775                npy_intp slicelength;
776                npy_intp start = subtensor_spec[spec_pos+0];
777                npy_intp stop  = subtensor_spec[spec_pos+1];
778                npy_intp step  = subtensor_spec[spec_pos+2];
779                if (step == %(NONE_CODE)s) step = 1;
780
781                npy_intp defstart = step < 0 ? length-1 : 0;
782                npy_intp defstop = step < 0 ? -1 : length;
783
784                // logic adapted from
785                // PySlice_GetIndicesEx in python source
786                if (!step)
787                {
788                    PyErr_Format(PyExc_ValueError,
789                                 "slice step cannot be zero");
790                    %(fail)s;
791                }
792
793                if (start == %(NONE_CODE)s)
794                {
795                    start = defstart;
796                }
797                else
798                {
799                    if (start < 0) start += length;
800                    if (start < 0) start = (step < 0) ? -1 : 0;
801                    if (start >= length)
802                        start = (step < 0) ? length - 1 : length;
803                }
804
805                if (stop == %(NONE_CODE)s)
806                {
807                    stop = defstop;
808                }
809                else
810                {
811                    if (stop < 0) stop += length;
812                    if (stop < 0) stop = (step < 0) ? -1 : 0;
813                    if (stop >= length)
814                        stop = (step < 0) ? length - 1 : length;
815                }
816
817                if ((step < 0 && stop >= start)
818                    || (step > 0 && start >= stop)) {
819                    slicelength = 0;
820                }
821                else if (step < 0) {
822                    slicelength = (stop-start+1)/step+1;
823                }
824                else {
825                    slicelength = (stop-start-1)/step+1;
826                }
827
828                if (0){
829                    fprintf(stdout, "start %%zi\\n", start);
830                    fprintf(stdout, "stop %%zi\\n", stop);
831                    fprintf(stdout, "step %%zi\\n", step);
832                    fprintf(stdout, "length %%zi\\n", length);
833                    fprintf(stdout, "slicelength %%zi\\n", slicelength);
834                }
835
836                assert (slicelength <= length);
837
838                xview_offset += (npy_intp)%(c_prefix)s_STRIDES(%(x)s)[outer_ii]
839                    * start * %(strides_mul)s;
840                xview_dims[inner_ii] = slicelength;
841                xview_strides[inner_ii] = (npy_intp)%(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step;
842
843                inner_ii += 1;
844                spec_pos += 3;
845            }
846            else // tuple coord `outer_ii` is an int
847            {
848                int idx = subtensor_spec[spec_pos];
849                if (idx < 0) idx += %(c_prefix)s_DIMS(%(x)s)[outer_ii];
850                if (idx >= 0)
851                {
852                    if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
853                    {
854                        xview_offset += (npy_intp)%(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
855                               %(strides_mul)s;
856                    }
857                    else
858                    {
859                        PyErr_Format(PyExc_IndexError,"index out of bounds");
860                        %(fail)s;
861                    }
862                }
863                else
864                {
865                    PyErr_Format(PyExc_IndexError,"index out of bounds");
866                    %(fail)s;
867                }
868
869                spec_pos += 1;
870            }
871        }
872        assert (inner_ii <= %(view_ndim)s);
873        while (inner_ii < %(view_ndim)s)
874        {
875            assert (outer_ii < %(c_prefix)s_NDIM(%(x)s));
876            xview_dims[inner_ii] = %(c_prefix)s_DIMS(%(x)s)[outer_ii];
877            xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii];
878
879            inner_ii += 1;
880            outer_ii += 1;
881        }
882        """ % locals()
883        # print rval
884        return rval
885
886    @staticmethod
887    def helper_c_code_cache_version():
888        return (9,)
889
890    def c_code(self, node, name, inputs, outputs, sub):  # DEBUG
891        if not isinstance(node.inputs[0].type, theano.tensor.TensorType):
892            raise NotImplementedError()
893
894        x = inputs[0]
895        z, = outputs
896        ndim = node.inputs[0].ndim
897        view_ndim = node.outputs[0].ndim
898        fail = sub['fail']
899
900        decl = "PyArrayObject * xview = NULL;"
901
902        checkNDim = """
903        if (PyArray_NDIM(%(x)s) != %(ndim)s){
904            PyErr_SetString(PyExc_ValueError,
905                                     "Expected %(ndim)s dimensions input"
906                                        );
907            %(fail)s
908        }
909        """ % locals()
910
911        get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
912                                       self.idx_list, view_ndim)
913        build_view = """
914        //TODO: give this Op a second output so that this view can be cached
915        //TODO: alternatively, fix the memory leak on failure
916        Py_INCREF(PyArray_DESCR(%(x)s));
917        xview = (PyArrayObject*)PyArray_NewFromDescr(
918                &PyArray_Type,
919                PyArray_DESCR(%(x)s),
920                %(view_ndim)s,
921                xview_dims,
922                xview_strides,
923                PyArray_BYTES(%(x)s) + xview_offset,
924                PyArray_FLAGS(%(x)s),
925                NULL);
926        assert (PyArray_NDIM(xview) == %(view_ndim)s);
927        if (!xview)
928        {
929            %(fail)s;
930        }
931        """ % locals()
932
933        finish_view = """
934        Py_XDECREF(%(z)s);
935        Py_INCREF(py_%(x)s);
936        PyArray_SetBaseObject(xview, py_%(x)s);
937        assert(py_%(x)s == (PyObject*)%(x)s);
938        %(z)s = xview;
939        """ % locals()
940
941        return (decl + checkNDim +
942                "{" + get_xview + build_view + finish_view + "}")
943
944    def c_code_cache_version(self):
945        hv = self.helper_c_code_cache_version()
946        # If `helper_c_code_cache_version` is not versioned we do not want to
947        # have a versioned version of this op's C code.
948        if len(hv) == 0:
949            return ()
950        return (4, hv)
951
952    def R_op(self, inputs, eval_points):
953        # Subtensor is not differentiable wrt to its indices, therefore we
954        # do not even need to consider the eval_points provided for those
955        # (they should be defaulted to zeros_like by the global R_op)
956        if eval_points[0] is None:
957            return [None]
958        return self(eval_points[0], *inputs[1:], **dict(return_list=True))
959
960
961class SubtensorPrinter:
962
963    def process(self, r, pstate):
964        if r.owner is None:
965            raise TypeError("Can only print Subtensor.")
966        elif isinstance(r.owner.op, Subtensor):
967            idxs = r.owner.op.idx_list
968            inputs = list(r.owner.inputs)
969            input = inputs.pop(0)
970            sidxs = []
971            old_precedence = getattr(pstate, 'precedence', None)
972            try:
973                pstate.precedence = -1000
974
975                for entry in idxs:
976                    if isinstance(entry, integer_types):
977                        sidxs.append(str(entry))
978                    elif isinstance(entry, scal.Scalar):
979                        sidxs.append(pstate.pprinter.process(inputs.pop()))
980                    elif isinstance(entry, slice):
981                        if entry.start is None or entry.start == 0:
982                            msg1 = ""
983                        else:
984                            msg1 = entry.start
985
986                        if entry.stop is None or entry.stop == sys.maxsize:
987                            msg2 = ""
988                        else:
989                            msg2 = entry.stop
990
991                        if entry.step is None:
992                            msg3 = ""
993                        else:
994                            msg3 = ":%s" % entry.step
995
996                        sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
997            finally:
998                pstate.precedence = old_precedence
999
1000            try:
1001                pstate.precedence = 1000
1002                sub = pstate.pprinter.process(input, pstate)
1003            finally:
1004                pstate.precedence = old_precedence
1005            return "%s[%s]" % (sub, ", ".join(sidxs))
1006        else:
1007            raise TypeError("Can only print Subtensor.")
1008
1009pprint.assign(Subtensor, SubtensorPrinter())
1010
1011
1012def set_subtensor(x, y, inplace=False,
1013                  tolerate_inplace_aliasing=False):
1014    """
1015    Return x with the given subtensor overwritten by y.
1016
1017    Parameters
1018    ----------
1019    x
1020        Symbolic variable for the lvalue of = operation.
1021    y
1022        Symbolic variable for the rvalue of = operation.
1023    tolerate_inplace_aliasing
1024        See inc_subtensor for documentation.
1025
1026    Examples
1027    --------
1028    To replicate the numpy expression "r[10:] = 5", type
1029
1030    >>> r = ivector()
1031    >>> new_r = set_subtensor(r[10:], 5)
1032
1033    """
1034    return inc_subtensor(x, y, inplace, set_instead_of_inc=True,
1035                         tolerate_inplace_aliasing=tolerate_inplace_aliasing)
1036
1037
1038def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
1039                  tolerate_inplace_aliasing=False):
1040    """
1041    Return x with the given subtensor incremented by y.
1042
1043    Parameters
1044    ----------
1045    x
1046        The symbolic result of a Subtensor operation.
1047    y
1048        The amount by which to increment the subtensor in question.
1049    inplace
1050        Don't use. Theano will do it when possible.
1051    set_instead_of_inc
1052        If True, do a set_subtensor instead.
1053    tolerate_inplace_aliasing:
1054        Allow x and y to be views of a single underlying array even while
1055        working inplace. For correct results, x and y must not be overlapping
1056        views; if they overlap, the result of this Op will generally be
1057        incorrect. This value has no effect if inplace=False.
1058
1059    Examples
1060    --------
1061    To replicate the numpy expression "r[10:] += 5", type
1062
1063    >>> r = ivector()
1064    >>> new_r = inc_subtensor(r[10:], 5)
1065
1066    """
1067    # First of all, y cannot have a higher dimension than x,
1068    # nor have non-broadcastable dimensions where x is broadcastable.
1069
1070    x = theano.tensor.as_tensor_variable(x)
1071    y = theano.tensor.as_tensor_variable(y)
1072
1073    if y.ndim > x.ndim:
1074        raise TypeError(("Trying to increment a %d-dimensional "
1075                         "subtensor with a %d-dimensional value.") % (x.ndim,
1076                                                                      y.ndim))
1077
1078    dim_offset = x.ndim - y.ndim
1079    for dim in xrange(y.ndim):
1080        if (x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]):
1081            # It is acceptable to try to increment a subtensor with a
1082            # broadcastable dim with a tensor that is not broadcastable
1083            # on that dimension. However, its length must then be 1.
1084            # We insert a Rebroadcast Op to make sure it is the case.
1085            y = addbroadcast(y, dim)
1086
1087    if not x.owner:
1088        raise TypeError('x must be the result of a subtensor operation')
1089
1090    # retrieve idx_list from x.owner
1091    if isinstance(x.owner.op, Subtensor):
1092        if tolerate_inplace_aliasing:
1093            destroyhandler_tolerate_aliased = [[0, 1]]
1094        else:
1095            destroyhandler_tolerate_aliased = []
1096        the_op = IncSubtensor(
1097            x.owner.op.idx_list, inplace, set_instead_of_inc,
1098            destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased)
1099        real_x = x.owner.inputs[0]
1100        real_idxargs = x.owner.inputs[1:]
1101        return the_op(real_x, y, *real_idxargs)
1102    elif isinstance(x.owner.op, AdvancedSubtensor1):
1103        real_x = x.owner.inputs[0]
1104        ilist = x.owner.inputs[1]
1105        the_op = AdvancedIncSubtensor1(inplace,
1106                                       set_instead_of_inc=set_instead_of_inc)
1107        return the_op(real_x, y, ilist)
1108    elif isinstance(x.owner.op, AdvancedSubtensor):
1109        real_x = x.owner.inputs[0]
1110        ilist = x.owner.inputs[1:]
1111
1112        the_op = AdvancedIncSubtensor(inplace,
1113                                      set_instead_of_inc=set_instead_of_inc)
1114        return the_op(real_x, y, *ilist)
1115    elif isinstance(x.owner.op, AdvancedBooleanSubtensor):
1116        real_x = x.owner.inputs[0]
1117        ilist = x.owner.inputs[1:]
1118
1119        the_op = AdvancedBooleanIncSubtensor(inplace,
1120                                             set_instead_of_inc=set_instead_of_inc)
1121        return the_op(real_x, y, *ilist)
1122    elif isinstance(x.owner.op, DimShuffle):
1123        inner_x = x.owner.inputs[0]
1124        # In the dimshuffle case, there are in fact two dimshuffles:
1125        # one to make the indexed dimension the last one,
1126        # and one to put it back where it was. So, in the case where we have
1127        # inc_subtensor(x[:,i], y), the graph is actually
1128        # inc_subtensor((x.T)[i].T, y).
1129        # We could get all the way to x, and then get rid of the dimshuffles
1130        # completely, but the problem is that advanced_inc_subtensor1 can only
1131        # work on the first (outer-most, left-most) dimension of x,
1132        # just like advanced_subtensor1.
1133        # So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
1134        # transpose y if it is not a scalar or a vector), but then we need to
1135        # return something that has the same shape as x, not as x.T (inner_x).
1136        # So re-apply the outer dimshuffle on the new inc_subtensor,
1137        # and return advanced_inc_subtensor1(x.T, i, y.T).T.
1138
1139        # Get the dimshuffle pattern to apply to y.
1140        x_order = x.owner.op.new_order
1141        y_order = ['x'] * x.ndim
1142        for i, v in enumerate(x_order):
1143            if v != 'x' and (v - dim_offset) >= 0:
1144                y_order[v - dim_offset] = i
1145
1146        # Warn if this code path would have produced wrong results in the past
1147        if config.warn.inc_set_subtensor1:
1148            # Dimshuffle pattern for y that would be equivalent to past code
1149            prev_y_order = ['x'] * (dim_offset) + list(range(y.ndim))
1150            if y_order != prev_y_order:
1151                warnings.warn(
1152                    'Although your current code is fine, please note that '
1153                    'earlier versions prior to 0.7 (or this development '
1154                    'version) may have yielded an incorrect result in '
1155                    'this `inc_subtensor` or `set_subtensor` operation. '
1156                    'To remove this warning, you can either set the '
1157                    '`warn.inc_set_subtensor1` config option to `False`, '
1158                    'or `warn.ignore_bug_before` to at least "0.7".',
1159                    stacklevel=2)
1160
1161        inner_incsubtensor = inc_subtensor(
1162            inner_x,
1163            y.dimshuffle(y_order),
1164            inplace=inplace,
1165            set_instead_of_inc=set_instead_of_inc,
1166            tolerate_inplace_aliasing=tolerate_inplace_aliasing)
1167        # The broadcastable pattern of inner_x may not be the same as
1168        # the one of x, so we have to build a new dimshuffle here,
1169        # instead of reusing x.owner.op().
1170        return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
1171
1172    elif isinstance(x.owner.op, theano.tensor.Reshape):
1173        # This case happens when the indices are not arranged as a vector, but
1174        # as a higher-dimensional array. This is handled by the subtensor
1175        # by flattening this list, taking the subtensor, then reshaping the
1176        # result.
1177        inner_x = x.owner.inputs[0]
1178        # Try to apply inc_subtensor on inner_x.
1179        # If it works, there is no need to reshape, as the inc_subtensor
1180        # will have the same shape as inner_x, which is what we want.
1181        # We also explicitly duplicate y to its broadcasted shape
1182        # before we partially flatten it to inner_x dimension. This is
1183        # not strictly needed in all cases, but it is easier this way.
1184        if y.ndim > 0:
1185            # This if is needed to prevent some useless warning about
1186            # old code bug.
1187            expanded_y = alloc(y, *[x.shape[i] for i in xrange(x.ndim)])
1188            flattened_y = expanded_y.reshape(inner_x.shape)
1189        else:
1190            flattened_y = y
1191
1192        # Warn if this code path would have produced wrong results in the past
1193        if config.warn.inc_set_subtensor1:
1194            if inner_x.ndim > 1 and sum(y.broadcastable) > 0:
1195                warnings.warn(
1196                    'Although your current code is fine, please note that '
1197                    'earlier versions prior to 0.7 (or this development '
1198                    'version) may have yielded an incorrect result in '
1199                    'this `inc_subtensor` or `set_subtensor` operation. '
1200                    'To remove this warning, you can either set the '
1201                    '`warn.inc_set_subtensor1` config option to `False`, '
1202                    'or `warn.ignore_bug_before` to at least "0.7".',
1203                    stacklevel=2)
1204
1205        inner_incsubtensor = inc_subtensor(
1206            inner_x,
1207            flattened_y,
1208            inplace=inplace,
1209            set_instead_of_inc=set_instead_of_inc,
1210            tolerate_inplace_aliasing=tolerate_inplace_aliasing)
1211        return inner_incsubtensor
1212    else:
1213        raise TypeError('x must be the result of a subtensor operation')
1214
1215
1216class IncSubtensor(Op):
1217    """
1218    Increment a subtensor.
1219
1220    This is like numpy's
1221
1222        x[i,j,k] += y
1223
1224    It is used internally to implement the gradient on SubTensor.
1225
1226    Parameters
1227    ----------
1228    set_instead_of_inc
1229        If True set the subtensor to the value instead of incrementing it by
1230        that value.
1231
1232    """
1233
1234    check_input = False
1235    __props__ = ("idx_list", "inplace", "set_instead_of_inc")
1236
1237    def __init__(self, idx_list, inplace=False, set_instead_of_inc=False,
1238                 destroyhandler_tolerate_aliased=None):
1239        if destroyhandler_tolerate_aliased is None:
1240            destroyhandler_tolerate_aliased = []
1241        self.idx_list = list(map(Subtensor.convert, idx_list))
1242        self.inplace = inplace
1243        if inplace:
1244            self.destroy_map = {0: [0]}
1245        self.destroyhandler_tolerate_aliased = list(
1246            destroyhandler_tolerate_aliased)
1247        self.set_instead_of_inc = set_instead_of_inc
1248
1249    def __hash__(self):
1250        msg = []
1251        for entry in self.idx_list:
1252            if isinstance(entry, slice):
1253                msg += [(entry.start, entry.stop, entry.step)]
1254            else:
1255                msg += [entry]
1256
1257        idx_list = tuple(msg)
1258        # backport
1259        # idx_list = tuple((entry.start, entry.stop, entry.step)
1260        #                 if isinstance(entry, slice)
1261        #                 else entry
1262        #                 for entry in self.idx_list)
1263        return (hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) ^
1264                hash(self.set_instead_of_inc))
1265
1266    def __str__(self):
1267        indices = []
1268        for entry in self.idx_list:
1269            if isinstance(entry, slice):
1270                indices.append(Subtensor.str_from_slice(entry))
1271            else:
1272                indices.append(str(entry))
1273        if self.inplace:
1274            msg = 'Inplace'
1275        else:
1276            msg = ''
1277        if not self.set_instead_of_inc:
1278            msg += 'Inc'
1279        else:
1280            msg += 'Set'
1281        return "%s{%s;%s}" % (
1282            self.__class__.__name__,
1283            msg,
1284            ", ".join(indices))
1285
1286    def make_node(self, x, y, *inputs):
1287        """
1288        Parameters
1289        ----------
1290        x
1291            The tensor to increment.
1292        y
1293            The value to increment by.
1294        inputs: TODO WRITEME
1295
1296        """
1297        x, y = map(theano.tensor.as_tensor_variable, [x, y])
1298        if y.ndim > x.ndim:
1299            raise ValueError(("Trying to increment a %d-dimensional "
1300                              "subtensor with a %d-dimensional value.") % (
1301                                  x.ndim, y.ndim))
1302        inputs = tuple(map(Subtensor.my_as_scalar, inputs))
1303
1304        idx_list = list(self.idx_list)
1305        if len(idx_list) > x.type.ndim:
1306            raise IndexError('too many indices for array')
1307
1308        input_types = Subtensor.collapse(
1309            idx_list,
1310            lambda entry: isinstance(entry, gof.Type))
1311        if len(inputs) != len(input_types):
1312            raise IndexError(
1313                "Not enough inputs to fill in the Subtensor template.",
1314                inputs, idx_list)
1315        for input, expected_type in izip(inputs, input_types):
1316            if input.type != expected_type:
1317                raise TypeError(
1318                    "Wrong type for Subtensor template. Expected %s, got %s."
1319                    % (input.type, expected_type))
1320
1321        return gof.Apply(self,
1322                         (x, y) + inputs,
1323                         [x.type()])
1324
1325    def decl_view(self):
1326        return "PyArrayObject * zview = NULL;"
1327
1328    def perform(self, node, inputs, out_):
1329        out, = out_
1330        x, y = inputs[:2]
1331        indices = list(reversed(inputs[2:]))
1332
1333        def convert(entry):
1334            if isinstance(entry, gof.Type):
1335                rval = indices.pop()
1336                if sys.version_info < (2, 5):
1337                    # Before Python 2.5, PySlice_GetIndicesEx requires
1338                    # Python int to be passed.
1339                    rval_ = int(rval)
1340                    if rval_ != rval:
1341                        raise IndexError((
1342                            "Invalid value for indexing: %s. "
1343                            "That value may be too big.") % rval)
1344                    return rval_
1345                return rval
1346            elif isinstance(entry, slice):
1347                return slice(convert(entry.start),
1348                             convert(entry.stop),
1349                             convert(entry.step))
1350            else:
1351                return entry
1352
1353        cdata = tuple(map(convert, self.idx_list))
1354        if len(cdata) == 1:
1355            cdata = cdata[0]
1356        if not self.inplace:
1357            x = x.copy()
1358        sub_x = x.__getitem__(cdata)
1359        if sub_x.shape:
1360            # we've sliced out an N-D tensor with N > 0
1361            if not self.set_instead_of_inc:
1362                sub_x += y
1363            else:
1364                # sub_x += -sub_x + y
1365                x.__setitem__(cdata, y)
1366        else:
1367            # scalar case
1368            if not self.set_instead_of_inc:
1369                x.__setitem__(cdata, sub_x + y)
1370            else:
1371                x.__setitem__(cdata, y)
1372        out[0] = x
1373
1374    def c_code(self, node, name, inputs, outputs, sub):
1375
1376        # This method delegates much of the work to helper
1377        # methods. This method implements the main logic
1378        # but subclasses may override the helper methods
1379        # to change the particulars, e.g. GpuIncSubtensor
1380        # turns the view/copy operations on numpy arrays
1381        # into the same operations on gpu arrays.
1382
1383        self.do_type_checking(node)
1384
1385        if self.inplace:  # convert bool to int
1386            inplace = 1
1387        else:
1388            inplace = 0
1389        x = inputs[0]
1390        y = inputs[1]
1391        z, = outputs
1392        if self.set_instead_of_inc:  # convert bool to int
1393            op_is_set = 1
1394        else:
1395            op_is_set = 0
1396        fail = sub['fail']
1397        view_ndim = (node.inputs[0].ndim -
1398                     np.sum([not isinstance(idx, slice)
1399                             for idx in self.idx_list]))
1400
1401        copy_of_x = self.copy_of_x(x)
1402
1403        copy_input_if_necessary = """
1404        if (%(inplace)s)
1405        {
1406            if (%(x)s != %(z)s)
1407            {
1408                Py_XDECREF(%(z)s);
1409                Py_INCREF(%(x)s);
1410                %(z)s = %(x)s;
1411            }
1412        }
1413        else
1414        {
1415            Py_XDECREF(%(z)s);
1416            %(z)s = %(copy_of_x)s;
1417            if (!%(z)s) {
1418                // Exception already set
1419                %(fail)s
1420            }
1421        }
1422        """ % locals()
1423
1424        # get info needed to make zview: a view of %(z)s
1425        helper_args = self.get_helper_c_code_args()
1426
1427        get_zview = Subtensor.helper_c_code(
1428            node=node,
1429            name=name,
1430            inputs=outputs[:1] + inputs[2:],
1431            outputs=outputs,
1432            sub=sub,
1433            idx_list=self.idx_list,
1434            view_ndim=view_ndim,
1435            ** helper_args
1436        )
1437
1438        # Make a view on the output, as we will write into it.
1439        alloc_zview = self.make_view_array(z, view_ndim)
1440
1441        build_view = """
1442        //TODO: give this Op a second output so that this view can be cached
1443        //TODO: alternatively, fix the memory leak on failure
1444        %(alloc_zview)s;
1445        if (!zview)
1446        {
1447            %(fail)s;
1448        }
1449        """ % locals()
1450
1451        copy_into = self.copy_into("zview", y)
1452
1453        add_to_zview = self.add_to_zview(name, y, fail)
1454
1455        make_modification = """
1456        if (%(op_is_set)s)
1457        {
1458            if (%(copy_into)s) // does broadcasting
1459            {
1460                Py_DECREF(zview);
1461                %(fail)s;
1462            }
1463        }
1464        else
1465        {
1466            %(add_to_zview)s
1467        }
1468        """ % locals()
1469        return (self.decl_view() +
1470                copy_input_if_necessary +
1471                "{" +
1472                get_zview +
1473                build_view +
1474                make_modification +
1475                "Py_DECREF(zview);" +
1476                "}"
1477                )
1478
1479    def do_type_checking(self, node):
1480        """
1481        Should raise NotImplementedError if c_code does not support
1482        the types involved in this node.
1483
1484        """
1485
1486        if not isinstance(node.inputs[0].type, theano.tensor.TensorType):
1487            raise NotImplementedError()
1488
1489    def c_code_cache_version(self):
1490        hv = Subtensor.helper_c_code_cache_version()
1491        if hv:
1492            return (3, hv)
1493        else:
1494            return ()
1495
1496    def copy_of_x(self, x):
1497        """
1498        Parameters
1499        ----------
1500        x
1501            A string giving the name of a C variable pointing to an array.
1502
1503        Returns
1504        -------
1505        object
1506            C code expression to make a copy of x.
1507
1508        Base class uses PyArrayObject *, subclasses may override for
1509        different types of arrays.
1510
1511        """
1512        # Parameters of PyArrary_FromAny are:
1513        # array
1514        # dtype: we pass NULL to say any dtype is acceptable, so the existing
1515        #        dtype will be copied
1516        # min_depth: we pass 0 to have this parameter ignored
1517        # max_depth: we pass 0 to have this parameter ignored
1518        # requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
1519        # context: this is almost always NULL, I'm not sure what it's used for
1520        return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
1521                NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
1522
1523    def make_view_array(self, x, view_ndim):
1524        """
1525        Parameters
1526        ----------
1527        x
1528            A string identifying an array to be viewed.
1529        view_ndim
1530            A string specifying the number of dimensions to have in the view.
1531
1532        This doesn't need to actually set up the view with the right indexing;
1533        we'll do that manually later.
1534
1535        """
1536
1537        return """Py_INCREF(PyArray_DESCR(%(x)s));
1538        zview = (PyArrayObject*)PyArray_NewFromDescr(
1539                &PyArray_Type,
1540                PyArray_DESCR(%(x)s),
1541                %(view_ndim)s,
1542                xview_dims, //PyArray_DIMS(%(x)s),
1543                xview_strides, //PyArray_STRIDES(%(x)s),
1544                PyArray_BYTES(%(x)s) + xview_offset, //PyArray_DATA(%(x)s),
1545                PyArray_FLAGS(%(x)s),
1546                NULL);
1547        """ % locals()
1548
1549    def get_helper_c_code_args(self):
1550        """
1551        Return a dictionary of arguments to pass to helper_c_code.
1552
1553        """
1554        return Subtensor.default_helper_c_code_args()
1555
1556    def copy_into(self, view, source):
1557        """
1558        Parameters
1559        ----------
1560        view : string
1561            C code expression for an array.
1562        source : string
1563            C code expression for an array.
1564
1565        Returns
1566        -------
1567        object
1568            C code expression to copy source into view, and 0 on success.
1569
1570        """
1571        return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
1572
1573    def add_to_zview(self, name, x, fail):
1574        """
1575        Return C code to add x to zview. Should DECREF zview if the
1576        add fails.
1577
1578        """
1579
1580        return """
1581            PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
1582                    (PyObject*)zview, py_%(x)s);
1583            if (add_rval)
1584            {
1585                assert (PyArray_Check((PyObject*)add_rval));
1586                assert (PyArray_DATA(add_rval) == PyArray_DATA(zview));
1587                Py_DECREF(add_rval);
1588            }
1589            else
1590            {
1591                Py_DECREF(zview);
1592                %(fail)s;
1593            }""" % locals()
1594
1595    def infer_shape(self, node, shapes):
1596        return [shapes[0]]
1597
1598    def R_op(self, inputs, eval_points):
1599        if eval_points[0] is None or eval_points[1] is None:
1600            return [None]
1601        # Again we ignore eval points for indices because incsubtensor is
1602        # not differentiable wrt to those
1603        return self(eval_points[0], eval_points[1], *inputs[2:],
1604                    **dict(return_list=True))
1605
1606    def connection_pattern(self, node):
1607
1608        rval = [[True], [True]]
1609
1610        for ipt in node.inputs[2:]:
1611            rval.append([False])
1612
1613        return rval
1614
1615    def grad(self, inputs, grads):
1616        g_output, = grads
1617        x, y = inputs[:2]
1618        idx_list = inputs[2:]
1619
1620        if x.dtype in theano.tensor.discrete_dtypes:
1621            # The output dtype is the same as x
1622            gx = x.zeros_like(dtype=theano.config.floatX)
1623            if y.dtype in theano.tensor.discrete_dtypes:
1624                gy = y.zeros_like(dtype=theano.config.floatX)
1625            else:
1626                gy = y.zeros_like()
1627        elif x.dtype in theano.tensor.complex_dtypes:
1628            raise NotImplementedError("No support for complex grad yet")
1629        else:
1630            if self.set_instead_of_inc:
1631                gx = set_subtensor(
1632                    Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
1633                    theano.tensor.zeros_like(y))
1634            else:
1635                gx = g_output
1636            gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
1637            gy = _sum_grad_over_bcasted_dims(y, gy)
1638
1639        return [gx, gy] + [DisconnectedType()()] * len(idx_list)
1640
1641
1642def _sum_grad_over_bcasted_dims(x, gx):
1643    """
1644    Sum of gx over dimensions to reproduce x.broadcastable.
1645
1646    This is useful to sum gradients over certain dimensions when
1647    x has been broadcasted, and we need to sum the gradient contributions
1648    over all duplications.
1649
1650    """
1651    if gx.broadcastable != x.broadcastable:
1652        x_dim_added = gx.ndim - x.ndim
1653        x_broad = (True,) * x_dim_added + x.broadcastable
1654        assert sum(gx.broadcastable) < sum(x_broad)
1655        axis_to_sum = []
1656        for i in xrange(gx.ndim):
1657            if gx.broadcastable[i] is False and x_broad[i] is True:
1658                axis_to_sum.append(i)
1659            elif (gx.broadcastable[i] is True and
1660                  x_broad[i] is False):
1661                # This means that Theano was able to infer that
1662                # gx.shape[i] is 1, so x.shape[i] is 1, but we
1663                # didn't know it. It is fine.
1664                pass
1665            else:
1666                assert gx.broadcastable[i] == x_broad[i]
1667        gx = gx.sum(axis=axis_to_sum, keepdims=True)
1668        if gx.ndim != x.ndim:
1669            assert gx.ndim > x.ndim
1670            for i in xrange(x_dim_added):
1671                assert gx.broadcastable[i]
1672            gx = gx.dimshuffle(*list(range(x_dim_added, gx.ndim)))
1673        assert gx.broadcastable == x.broadcastable
1674    return gx
1675
1676
1677#########################
1678# Advanced indexing
1679#########################
1680#
1681# Should reproduce numpy's behaviour, see url:
1682# docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
1683
1684
1685class AdvancedSubtensor1(Op):
1686    """
1687    Implement x[ilist] where ilist is a vector of integers.
1688
1689    """
1690    # sparse_grad doesn't go in here since it only affects the output
1691    # of the grad() method.
1692    __props__ = ()
1693    _f16_ok = True
1694    check_input = False
1695
1696    def __init__(self, sparse_grad=False):
1697        self.sparse_grad = sparse_grad
1698
1699    def make_node(self, x, ilist):
1700        x_ = theano.tensor.as_tensor_variable(x)
1701        ilist_ = theano.tensor.as_tensor_variable(ilist)
1702        if ilist_.type.dtype not in theano.tensor.integer_dtypes:
1703            raise TypeError('index must be integers')
1704        if ilist_.type.ndim != 1:
1705            raise TypeError('index must be vector')
1706        if x_.type.ndim == 0:
1707            raise TypeError('cannot index into a scalar')
1708        bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
1709        return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype,
1710                                                     broadcastable=bcast)()])
1711
1712    def perform(self, node, inp, out_):
1713        x, i = inp
1714        out, = out_
1715        # Copy always implied by numpy advanced indexing semantic.
1716        if out[0] is not None and out[0].shape == (len(i),) + x.shape[1:]:
1717            o = out[0]
1718        else:
1719            o = None
1720
1721        # If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
1722        # int64 on 64-bit machines), numpy may raise the following error:
1723        # TypeError: array cannot be safely cast to required type.
1724        # We need to check if values in i can fit in numpy.intp, because
1725        # if they don't, that should be an error (no array can have that
1726        # many elements on a 32-bit arch).
1727        if i.dtype != np.intp:
1728            i_ = theano._asarray(i, dtype=np.intp)
1729            if not np.can_cast(i.dtype, np.intp):
1730                # Check if there was actually an incorrect conversion
1731                if np.any(i != i_):
1732                    raise IndexError(
1733                        'index contains values that are bigger '
1734                        'than the maximum array size on this system.', i)
1735            i = i_
1736
1737        out[0] = x.take(i, axis=0, out=o)
1738
1739    def connection_pattern(self, node):
1740        rval = [[True]]
1741
1742        for ipt in node.inputs[1:]:
1743            rval.append([False])
1744
1745        return rval
1746
1747    def grad(self, inputs, grads):
1748        global sparse_module_ref
1749        x, ilist = inputs
1750        gz, = grads
1751        assert len(inputs) == 2
1752        if self.sparse_grad:
1753            if x.type.ndim != 2:
1754                raise TypeError(
1755                    "AdvancedSubtensor1: you can't take the sparse grad"
1756                    " from a tensor with ndim != 2. ndim is " +
1757                    str(x.type.ndim))
1758            if sparse_module_ref is None:
1759                import theano.sparse as sparse_module_ref
1760
1761            rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz,
1762                                                                  ilist)]
1763        else:
1764            if x.dtype in theano.tensor.discrete_dtypes:
1765                # The output dtype is the same as x
1766                gx = x.zeros_like(dtype=theano.config.floatX)
1767            elif x.dtype in theano.tensor.complex_dtypes:
1768                raise NotImplementedError("No support for complex grad yet")
1769            else:
1770                gx = x.zeros_like()
1771            rval1 = [advanced_inc_subtensor1(gx, gz, ilist)]
1772        return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
1773
1774    def R_op(self, inputs, eval_points):
1775        if eval_points[0] is None:
1776            return [None]
1777        return self.make_node(eval_points[0], *inputs[1:]).outputs
1778
1779    def infer_shape(self, node, ishapes):
1780        x, ilist = ishapes
1781        return [ilist + x[1:]]
1782
1783    def c_support_code(self):
1784        # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
1785        # which is not defined. It should be NPY_MIN_LONG instead in that case.
1786        return dedent("""\
1787                #ifndef MIN_LONG
1788                #define MIN_LONG NPY_MIN_LONG
1789                #endif""")
1790
1791    def c_code(self, node, name, input_names, output_names, sub):
1792        if self.__class__ is not AdvancedSubtensor1:
1793            raise MethodNotDefined(
1794                "c_code defined for AdvancedSubtensor1,"
1795                " not for child class", type(self))
1796        a_name, i_name = input_names[0], input_names[1]
1797        output_name = output_names[0]
1798        fail = sub['fail']
1799        return """
1800            PyArrayObject *indices;
1801            int i_type = PyArray_TYPE(%(i_name)s);
1802            if (i_type != NPY_INTP) {
1803                // Cast %(i_name)s to NPY_INTP (expected by PyArray_TakeFrom),
1804                // if all values fit.
1805                if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
1806                    PyArray_SIZE(%(i_name)s) > 0) {
1807                    npy_int64 min_val, max_val;
1808                    PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS,
1809                                                       NULL);
1810                    if (py_min_val == NULL) {
1811                        %(fail)s;
1812                    }
1813                    min_val = PyLong_AsLongLong(py_min_val);
1814                    Py_DECREF(py_min_val);
1815                    if (min_val == -1 && PyErr_Occurred()) {
1816                        %(fail)s;
1817                    }
1818                    PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS,
1819                                                       NULL);
1820                    if (py_max_val == NULL) {
1821                        %(fail)s;
1822                    }
1823                    max_val = PyLong_AsLongLong(py_max_val);
1824                    Py_DECREF(py_max_val);
1825                    if (max_val == -1 && PyErr_Occurred()) {
1826                        %(fail)s;
1827                    }
1828                    if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
1829                        PyErr_SetString(PyExc_IndexError,
1830                                     "Index contains values "
1831                                     "that are bigger than the maximum array "
1832                                     "size on this system.");
1833                        %(fail)s;
1834                    }
1835                }
1836                indices = (PyArrayObject*) PyArray_Cast(%(i_name)s, NPY_INTP);
1837                if (indices == NULL) {
1838                    %(fail)s;
1839                }
1840            }
1841            else {
1842                 indices = %(i_name)s;
1843                 Py_INCREF(indices);
1844            }
1845            if (%(output_name)s != NULL) {
1846                npy_intp nd, i, *shape;
1847                nd = PyArray_NDIM(%(a_name)s) + PyArray_NDIM(indices) - 1;
1848                if (PyArray_NDIM(%(output_name)s) != nd) {
1849                    Py_CLEAR(%(output_name)s);
1850                }
1851                else {
1852                    shape = PyArray_DIMS(%(output_name)s);
1853                    for (i = 0; i < PyArray_NDIM(indices); i++) {
1854                        if (shape[i] != PyArray_DIMS(indices)[i]) {
1855                            Py_CLEAR(%(output_name)s);
1856                            break;
1857                        }
1858                    }
1859                    if (%(output_name)s != NULL) {
1860                        for (; i < nd; i++) {
1861                            if (shape[i] != PyArray_DIMS(%(a_name)s)[
1862                                                i-PyArray_NDIM(indices)+1]) {
1863                                Py_CLEAR(%(output_name)s);
1864                                break;
1865                            }
1866                        }
1867                    }
1868                }
1869            }
1870            %(output_name)s = (PyArrayObject*)PyArray_TakeFrom(
1871                        %(a_name)s, (PyObject*)indices, 0, %(output_name)s, NPY_RAISE);
1872            Py_DECREF(indices);
1873            if (%(output_name)s == NULL) %(fail)s;
1874        """ % locals()
1875
1876    def c_code_cache_version(self):
1877        return (0, 1, 2)
1878
1879advanced_subtensor1 = AdvancedSubtensor1()
1880
1881
1882class AdvancedIncSubtensor1(Op):
1883    """
1884    Increments a subtensor using advanced slicing (list of index).
1885
1886    """
1887
1888    __props__ = ('inplace', 'set_instead_of_inc')
1889    check_input = False
1890    params_type = ParamsType(inplace=scal.bool,
1891                             set_instead_of_inc=scal.bool)
1892
1893    def __init__(self, inplace=False, set_instead_of_inc=False):
1894        self.inplace = bool(inplace)
1895        self.set_instead_of_inc = bool(set_instead_of_inc)
1896        if inplace:
1897            self.destroy_map = {0: [0]}
1898
1899    def clone_inplace(self):
1900        return self.__class__(
1901            inplace=True,
1902            set_instead_of_inc=self.set_instead_of_inc)
1903
1904    def __str__(self):
1905        if self.inplace:
1906            msg = "inplace"
1907        else:
1908            msg = "no_inplace"
1909        if self.set_instead_of_inc:
1910            msg += ",set"
1911        else:
1912            msg += ",inc"
1913
1914        return self.__class__.__name__ + "{%s}" % msg
1915
1916    def make_node(self, x, y, ilist):
1917        x_ = theano.tensor.as_tensor_variable(x)
1918        y_ = theano.tensor.as_tensor_variable(y)
1919        ilist_ = theano.tensor.as_tensor_variable(ilist)
1920
1921        if ilist_.type.dtype not in theano.tensor.integer_dtypes:
1922            raise TypeError('index must be integers')
1923        if ilist_.type.ndim != 1:
1924            raise TypeError('index must be vector')
1925        if x_.type.ndim == 0:
1926            raise TypeError('cannot index into a scalar')
1927        if y_.type.ndim > x_.type.ndim:
1928            if self.set_instead_of_inc:
1929                opname = 'set'
1930            else:
1931                opname = 'increment'
1932            raise TypeError(
1933                'cannot %s x subtensor with ndim=%s'
1934                ' by y with ndim=%s to x subtensor with ndim=%s ' % (
1935                    opname, x_.type.ndim, y_.type.ndim, '?'))  # FIXME: too few args for format string
1936
1937        return Apply(self, [x_, y_, ilist_], [x_.type()])
1938
1939    def copy_of_x(self, x):
1940        """
1941        Parameters
1942        ----------
1943        x : string
1944            Gives the name of a C variable pointing to an array.
1945
1946        Returns
1947        -------
1948        object
1949            C code expression to make a copy of x.
1950
1951        Base class uses PyArrayObject *, subclasses may override for
1952        different types of arrays.
1953
1954        """
1955        # Parameters of PyArrary_FromAny are:
1956        # array
1957        # dtype: we pass NULL to say any dtype is acceptable, so the existing
1958        #        dtype will be copied
1959        # min_depth: we pass 0 to have this parameter ignored
1960        # max_depth: we pass 0 to have this parameter ignored
1961        # requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
1962        # context: this is almost always NULL, I'm not sure what it's used for
1963        return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
1964                NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
1965
1966    def c_support_code(self):
1967        return inc_code()
1968
1969    def c_code(self, node, name, input_names, output_names, sub):
1970        numpy_ver = [int(n) for n in np.__version__.split('.')[:2]]
1971        if bool(numpy_ver < [1, 8]):
1972            raise NotImplementedError
1973        x, y, idx = input_names
1974        out = output_names[0]
1975        copy_of_x = self.copy_of_x(x)
1976
1977        return """
1978        PyObject* rval = NULL;
1979        if (%(params)s->inplace)
1980        {
1981            if (%(x)s != %(out)s)
1982            {
1983                Py_XDECREF(%(out)s);
1984                Py_INCREF(%(x)s);
1985                %(out)s = %(x)s;
1986            }
1987        }
1988        else
1989        {
1990            Py_XDECREF(%(out)s);
1991            %(out)s = %(copy_of_x)s;
1992            if (!%(out)s) {
1993                // Exception already set
1994                %(fail)s
1995            }
1996        }
1997        if (inplace_increment(%(out)s, (PyObject *)%(idx)s, %(y)s, (1 - %(params)s->set_instead_of_inc))) {
1998            %(fail)s;
1999        }
2000        Py_XDECREF(rval);
2001        """ % dict(x=x, y=y, idx=idx, out=out, copy_of_x=copy_of_x,
2002                   params=sub['params'], fail=sub['fail'])
2003
2004    def c_code_cache_version(self):
2005        return (8,)
2006
2007    def perform(self, node, inp, out_, params):
2008        # TODO opt to make this inplace
2009        x, y, idx = inp
2010        out, = out_
2011        if not self.inplace:
2012            x = x.copy()
2013        # In Numpy, x[idx] += y doesn't work if the same index is present
2014        # many times: it does it only once. Is it a bug? In any case, for
2015        # this reason we implement our own 'inc' iteration.
2016
2017        if self.set_instead_of_inc:
2018            x[idx] = y
2019        else:
2020            np.add.at(x, idx, y)
2021
2022        out[0] = x
2023
2024    def infer_shape(self, node, ishapes):
2025        x, y, ilist = ishapes
2026        return [x]
2027
2028    def R_op(self, inputs, eval_points):
2029        if None in eval_points[:2]:
2030            return [None]
2031        return self.make_node(eval_points[0], eval_points[1],
2032                              *inputs[2:]).outputs
2033
2034    def connection_pattern(self, node):
2035
2036        rval = [[True], [True], [False]]
2037        return rval
2038
2039    def grad(self, inputs, grads):
2040        g_output, = grads
2041        x, y, idx_list = inputs
2042        if x.dtype in theano.tensor.discrete_dtypes:
2043            # The output dtype is the same as x
2044            gx = x.zeros_like(dtype=theano.config.floatX)
2045            if y.dtype in theano.tensor.discrete_dtypes:
2046                gy = y.zeros_like(dtype=theano.config.floatX)
2047            else:
2048                gy = y.zeros_like()
2049        elif x.dtype in theano.tensor.complex_dtypes:
2050            raise NotImplementedError("No support for complex grad yet")
2051        else:
2052            if self.set_instead_of_inc:
2053                gx = advanced_set_subtensor1(
2054                    g_output,
2055                    y.zeros_like(),
2056                    idx_list)
2057            else:
2058                gx = g_output
2059            gy = advanced_subtensor1(g_output, idx_list)
2060            gy = _sum_grad_over_bcasted_dims(y, gy)
2061
2062        return [gx, gy] + [DisconnectedType()()]
2063
2064advanced_inc_subtensor1 = AdvancedIncSubtensor1()
2065advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True)
2066
2067
2068def as_index_variable(idx):
2069    if idx is None:
2070        return NoneConst.clone()
2071    if isinstance(idx, slice):
2072        return make_slice(idx)
2073    if isinstance(idx, gof.Variable) and isinstance(idx.type, SliceType):
2074        return idx
2075    if isinstance(idx, gof.Variable) and isinstance(idx.type, NoneTypeT):
2076        return idx
2077    idx = theano.tensor.as_tensor_variable(idx)
2078    if idx.type.dtype not in theano.tensor.discrete_dtypes:
2079        raise TypeError('index must be integers or a boolean mask')
2080    return idx
2081
2082
2083def adv_index_broadcastable_pattern(a, idx):
2084    """
2085    This function is only used to determine the broadcast pattern for
2086    AdvancedSubtensor output variable.
2087
2088    For this, we make a fake ndarray and a fake idx and call use ask numpy
2089    the output. From this, we find the output broadcast pattern.
2090
2091    """
2092
2093    def replace_slice(v):
2094        if isinstance(v, gof.Apply):
2095            if len(v.outputs) != 1:
2096                raise ValueError(
2097                    "It is ambiguous which output of a multi-output Op has"
2098                    " to be fetched.", v)
2099            else:
2100                v = v.outputs[0]
2101
2102        if NoneConst.equals(v):
2103            return None
2104        if isinstance(v.type, SliceType):
2105            return slice(None, None)
2106
2107        if v.dtype == 'bool':
2108            return np.ones((2,) * v.ndim, v.dtype)
2109        else:
2110            return np.zeros((2,) * v.ndim, int)
2111
2112    newidx = tuple(map(replace_slice, idx))
2113
2114    # 2 - True = 1; 2 - False = 2
2115    fakeshape = [2 - bc for bc in a.broadcastable]
2116    retshape = np.empty(fakeshape)[newidx].shape
2117    return tuple([dim == 1 for dim in retshape])
2118
2119
2120def check_advanced_indexing_dimensions(input, idx_list):
2121    """
2122    This function checks if the index list in idx_list is correct.
2123    If there are any boolean masks, we check if the mask has the
2124    same shape as the input. This is enforced in NumPy 0.13.0 and
2125    newer, but not by earlier versions. If the size is not the same,
2126    this method raises an IndexError.
2127    """
2128    dim_seen = 0
2129    for index in idx_list:
2130        if index is np.newaxis:
2131            # skip, does not count as an input dimension
2132            pass
2133        elif isinstance(index, np.ndarray) and index.dtype == 'bool':
2134            for i in xrange(index.ndim):
2135                if index.shape[i] != input.shape[dim_seen + i]:
2136                    raise IndexError('boolean index did not match indexed array '
2137                                     'along dimension %d; dimension is %d but '
2138                                     'corresponding boolean dimension is %d' %
2139                                     (dim_seen + i, input.shape[dim_seen + i],
2140                                      index.shape[i]))
2141            dim_seen += index.ndim
2142        else:
2143            dim_seen += 1
2144
2145
2146def check_and_reject_bool(args_el):
2147    try:
2148        if (isinstance(args_el, (np.bool_, bool)) or
2149                args_el.dtype == 'bool'):
2150            raise TypeError('AdvancedSubtensor does not support boolean '
2151                            'masks for indexing. Use AdvancedBooleanSubtensor '
2152                            'instead. ')
2153    except AttributeError:
2154        pass
2155
2156    if (not isinstance(args_el, theano.tensor.Variable) and
2157            isinstance(args_el, Iterable)):
2158        for el in args_el:
2159            check_and_reject_bool(el)
2160
2161
2162class BaseAdvancedSubtensor(Op):
2163    """
2164    Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
2165    Implements advanced indexing with boolean masks.
2166
2167    """
2168
2169    # Should be used by __getitem__ and __getslice__, as follows:
2170    # AdvancedSubtensor()(self, *args) or
2171    # AdvancedBooleanSubtensor()(self, *args),
2172    # if args contains and advanced indexing pattern
2173    __props__ = ()
2174
2175    def make_node(self, x, *index):
2176        x = theano.tensor.as_tensor_variable(x)
2177
2178        index = tuple(map(as_index_variable, index))
2179        bcast = adv_index_broadcastable_pattern(x, index)
2180        return gof.Apply(self,
2181                         (x,) + index,
2182                         [theano.tensor.tensor(dtype=x.type.dtype,
2183                                               broadcastable=bcast)])
2184
2185    def R_op(self, inputs, eval_points):
2186        if eval_points[0] is None:
2187            return [None]
2188        return self.make_node(eval_points[0], *inputs[1:]).outputs
2189
2190    def infer_shape(self, node, ishapes):
2191        # Default case, we don't know
2192        raise theano.tensor.basic.ShapeError("case not implemented")
2193
2194    def perform(self, node, inputs, out_):
2195        out, = out_
2196        check_advanced_indexing_dimensions(inputs[0], inputs[1:])
2197        rval = inputs[0].__getitem__(tuple(inputs[1:]))
2198        # When there are no arrays, we are not actually doing advanced
2199        # indexing, so __getitem__ will not return a copy.
2200        # Since no view_map is set, we need to copy the returned value
2201        if not any(isinstance(v.type, TensorType) and v.ndim > 0
2202                   for v in node.inputs[1:]):
2203            rval = rval.copy()
2204        out[0] = rval
2205
2206    def connection_pattern(self, node):
2207        rval = [[True]]
2208
2209        for ipt in node.inputs[1:]:
2210            rval.append([False])
2211
2212        return rval
2213
2214
2215class AdvancedSubtensor(BaseAdvancedSubtensor):
2216    """
2217    Return a subtensor copy, using advanced indexing.
2218
2219    """
2220
2221    # Should be used by __getitem__ and __getslice__, as follows:
2222    # AdvancedSubtensor()(self, *args),
2223    # if args contains and advanced indexing pattern
2224
2225    def make_node(self, x, *index):
2226        check_and_reject_bool(index)
2227        return super(AdvancedSubtensor, self).make_node(x, *index)
2228
2229    def infer_shape(self, node, ishapes):
2230        # Really special case
2231        if len(ishapes) == 3:
2232            xshp, ind1shp, ind2shp = ishapes
2233            if (len(xshp) == 2 and
2234                    ind1shp is not None and len(ind1shp) == 1 and
2235                    ind2shp is not None and len(ind2shp) == 1):
2236                # if the graph is correct, we can assume ind1shp[0] and
2237                # ind2shp[0] will have the same value.
2238                # Try to return the one closest to the graph input.
2239                if node.inputs[2].owner is None:
2240                    return [ind2shp]
2241                else:
2242                    return [ind1shp]
2243        return super(AdvancedSubtensor, self).infer_shape(node, ishapes)
2244
2245    def grad(self, inputs, grads):
2246        gz, = grads
2247        x = inputs[0]
2248        if x.dtype in theano.tensor.discrete_dtypes:
2249            # The output dtype is the same as x
2250            gx = x.zeros_like(dtype=theano.config.floatX)
2251        elif x.dtype in theano.tensor.complex_dtypes:
2252            raise NotImplementedError("No support for complex grad yet")
2253        else:
2254            gx = x.zeros_like()
2255        rest = inputs[1:]
2256        return [advanced_inc_subtensor(gx, gz, *rest)] + \
2257            [DisconnectedType()()] * len(rest)
2258advanced_subtensor = AdvancedSubtensor()
2259
2260
2261class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
2262    """
2263    Return a subtensor copy, using advanced indexing with boolean masks.
2264
2265    """
2266
2267    # Should be used by __getitem__ and __getslice__, as follows:
2268    # AdvancedBooleanSubtensor()(self, *args),
2269    # if args contains and advanced indexing pattern with boolean masks
2270
2271    def grad(self, inputs, grads):
2272        gz, = grads
2273        x = inputs[0]
2274        if x.dtype in theano.tensor.discrete_dtypes:
2275            # The output dtype is the same as x
2276            gx = x.zeros_like(dtype=theano.config.floatX)
2277        elif x.dtype in theano.tensor.complex_dtypes:
2278            raise NotImplementedError("No support for complex grad yet")
2279        else:
2280            gx = x.zeros_like()
2281        rest = inputs[1:]
2282        return [advanced_boolean_inc_subtensor(gx, gz, *rest)] + \
2283            [DisconnectedType()()] * len(rest)
2284advanced_boolean_subtensor = AdvancedBooleanSubtensor()
2285
2286
2287class BaseAdvancedIncSubtensor(Op):
2288    """
2289    Base class for AdvancedIncSubtensor and AdvancedBooleanIncSubtensor.
2290    Increments a subtensor using advanced indexing.
2291    """
2292
2293    __props__ = ("inplace", "set_instead_of_inc")
2294
2295    def __init__(self, inplace=False, set_instead_of_inc=False):
2296        self.inplace = inplace
2297        self.set_instead_of_inc = set_instead_of_inc
2298        # The assert is needed as in the pass the first argument was
2299        # something else that was not used.
2300        assert isinstance(inplace, bool)
2301        if self.inplace:
2302            raise NotImplementedError('In place computation is not'
2303                                      ' implemented')
2304
2305    def __str__(self):
2306        return "%s{%s, %s}" % (self.__class__.__name__,
2307                               "inplace=" + str(self.inplace),
2308                               " set_instead_of_inc=" +
2309                               str(self. set_instead_of_inc))
2310
2311    def make_node(self, x, y, *inputs):
2312        x = theano.tensor.as_tensor_variable(x)
2313        y = theano.tensor.as_tensor_variable(y)
2314
2315        new_inputs = []
2316        for inp in inputs:
2317            if isinstance(inp, (list, tuple)):
2318                inp = theano.tensor.as_tensor_variable(inp)
2319            new_inputs.append(inp)
2320        return gof.Apply(self,
2321                         (x, y) + tuple(new_inputs),
2322                         [theano.tensor.tensor(
2323                             dtype=x.type.dtype,
2324                             broadcastable=x.type.broadcastable)])
2325
2326    def perform(self, node, inputs, out_):
2327        # TODO: 1. opt to make this in place 2. generalize as described in
2328        # AdvancedSubtensor's perform TODO
2329
2330        check_advanced_indexing_dimensions(inputs[0], inputs[2:])
2331
2332        out, = out_
2333        if not self.inplace:
2334            out[0] = inputs[0].copy()
2335        else:
2336            out[0] = inputs[0]
2337
2338        if self.set_instead_of_inc:
2339            out[0][tuple(inputs[2:])] = inputs[1]
2340        else:
2341            np.add.at(out[0], tuple(inputs[2:]), inputs[1])
2342
2343    def infer_shape(self, node, ishapes):
2344        return [ishapes[0]]
2345
2346    def connection_pattern(self, node):
2347
2348        rval = [[True], [True]]
2349
2350        for ipt in node.inputs[2:]:
2351            rval.append([False])
2352
2353        return rval
2354
2355    def R_op(self, inputs, eval_points):
2356        if None in eval_points[:2]:
2357            return [None]
2358        return self.make_node(eval_points[0], eval_points[1],
2359                              *inputs[2:]).outputs
2360
2361
2362class AdvancedIncSubtensor(BaseAdvancedIncSubtensor):
2363    """
2364    Increments a subtensor using advanced indexing.
2365    """
2366
2367    def make_node(self, x, y, *inputs):
2368        check_and_reject_bool(inputs)
2369        return super(AdvancedIncSubtensor, self).make_node(x, y, *inputs)
2370
2371    def grad(self, inpt, output_gradients):
2372        x, y = inpt[:2]
2373        idxs = inpt[2:]
2374        outgrad, = output_gradients
2375        if x.dtype in theano.tensor.discrete_dtypes:
2376            # The output dtype is the same as x
2377            gx = x.zeros_like(dtype=theano.config.floatX)
2378            if y.dtype in theano.tensor.discrete_dtypes:
2379                gy = y.zeros_like(dtype=theano.config.floatX)
2380            else:
2381                gy = y.zeros_like()
2382        elif x.dtype in theano.tensor.complex_dtypes:
2383            raise NotImplementedError("No support for complex grad yet")
2384        else:
2385            if self.set_instead_of_inc:
2386                gx = advanced_set_subtensor(
2387                    outgrad,
2388                    y.zeros_like(),
2389                    *idxs)
2390            else:
2391                gx = outgrad
2392            gy = advanced_subtensor(outgrad, *idxs)
2393            # Make sure to sum gy over the dimensions of y that have been
2394            # added or broadcasted
2395            gy = _sum_grad_over_bcasted_dims(y, gy)
2396        return [gx, gy] + \
2397            [DisconnectedType()() for _ in idxs]
2398advanced_inc_subtensor = AdvancedIncSubtensor()
2399advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
2400
2401
2402class AdvancedBooleanIncSubtensor(BaseAdvancedIncSubtensor):
2403    """
2404    Increments a subtensor using advanced indexing with boolean masks.
2405    """
2406
2407    def grad(self, inpt, output_gradients):
2408        x, y = inpt[:2]
2409        idxs = inpt[2:]
2410        outgrad, = output_gradients
2411        if x.dtype in theano.tensor.discrete_dtypes:
2412            # The output dtype is the same as x
2413            gx = x.zeros_like(dtype=theano.config.floatX)
2414            if y.dtype in theano.tensor.discrete_dtypes:
2415                gy = y.zeros_like(dtype=theano.config.floatX)
2416            else:
2417                gy = y.zeros_like()
2418        elif x.dtype in theano.tensor.complex_dtypes:
2419            raise NotImplementedError("No support for complex grad yet")
2420        else:
2421            if self.set_instead_of_inc:
2422                gx = advanced_set_subtensor(
2423                    outgrad,
2424                    y.zeros_like(),
2425                    *idxs)
2426            else:
2427                gx = outgrad
2428            gy = advanced_boolean_subtensor(outgrad, *idxs)
2429            # Make sure to sum gy over the dimensions of y that have been
2430            # added or broadcasted
2431            gy = _sum_grad_over_bcasted_dims(y, gy)
2432        return [gx, gy] + \
2433            [DisconnectedType()() for _ in idxs]
2434advanced_boolean_inc_subtensor = AdvancedBooleanIncSubtensor()
2435advanced_boolean_set_subtensor = AdvancedBooleanIncSubtensor(set_instead_of_inc=True)
2436
2437
2438def take(a, indices, axis=None, mode='raise'):
2439    a = theano.tensor.as_tensor_variable(a)
2440    indices = theano.tensor.as_tensor_variable(indices)
2441    # Reuse advanced_subtensor1 if indices is a vector
2442    if indices.ndim == 1:
2443        if mode == 'clip':
2444            indices = clip(indices, 0, a.shape[axis] - 1)
2445        elif mode == 'wrap':
2446            indices = indices % a.shape[axis]
2447        if axis is None:
2448            return advanced_subtensor1(a.flatten(), indices)
2449        elif axis == 0:
2450            return advanced_subtensor1(a, indices)
2451        else:
2452            if axis < 0:
2453                axis += a.ndim
2454            assert axis >= 0
2455            shuffle = list(range(a.ndim))
2456            shuffle[0] = axis
2457            shuffle[axis] = 0
2458            return advanced_subtensor1(
2459                a.dimshuffle(shuffle), indices).dimshuffle(shuffle)
2460    if axis is None:
2461        shape = indices.shape
2462        ndim = indices.ndim
2463    else:
2464        # If axis is 0, don't generate a useless concatenation.
2465        if axis == 0:
2466            shape = theano.tensor.concatenate(
2467                [indices.shape, a.shape[axis + 1:]])
2468        else:
2469            if axis < 0:
2470                axis += a.ndim
2471            shape = theano.tensor.concatenate(
2472                [a.shape[:axis], indices.shape, a.shape[axis + 1:]])
2473        ndim = a.ndim + indices.ndim - 1
2474    return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
2475