1from __future__ import absolute_import, print_function, division
2""" Tensor optimizations addressing the ops in basic.py.
3"""
4# TODO: intelligent merge for mul/add
5# TODO: 0*x -> 0
6
7from collections import defaultdict
8import logging
9import itertools
10import operator
11import sys
12import time
13import traceback
14import warnings
15
16import numpy as np
17from six import integer_types, iteritems
18from six.moves import reduce, xrange
19
20import theano
21from theano import gof
22from theano.compat import izip
23from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
24from theano.gof import Variable, Constant
25from theano.gof.opt import copy_stack_trace, in2out
26from theano.gof.utils import MethodNotDefined
27from theano.gradient import DisconnectedType
28from theano import config
29from theano.tensor.elemwise import Elemwise, DimShuffle
30from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
31                                     Subtensor, IncSubtensor, make_constant,
32                                     AdvancedIncSubtensor1,
33                                     AdvancedIncSubtensor,
34                                     AdvancedSubtensor1,
35                                     advanced_subtensor,
36                                     advanced_subtensor1,
37                                     advanced_inc_subtensor1)
38from theano.tensor.sort import TopKOp
39from theano import scalar
40from theano.scalar import basic
41from theano.tensor import basic as T
42from theano import compile  # to register the optimizer built by this file
43from theano.compile.ops import Shape, Shape_i
44from theano.tensor.type import (values_eq_approx_remove_inf,
45                                values_eq_approx_remove_nan,
46                                values_eq_approx_remove_inf_nan)
47
48from theano.gof.opt import (Optimizer, pre_constant_merge,
49                            pre_greedy_local_optimizer)
50from theano.gof import toolbox
51from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
52                                 extract_constant, NotScalarConstantError,
53                                 Reshape)
54from six import StringIO
55
56_logger = logging.getLogger('theano.tensor.opt')
57
58# Utilities
59
60
61def _fill_chain(new_out, orig_inputs):
62    for i in orig_inputs:
63        new_out = T.fill(i, new_out)
64    return [new_out]
65
66
67def encompasses_broadcastable(b1, b2):
68    """
69
70    Parameters
71    ----------
72    b1
73        The broadcastable attribute of a tensor type.
74    b2
75        The broadcastable attribute of a tensor type.
76
77    Returns
78    -------
79    bool
80        True if the broadcastable patterns b1 and b2 are such that b2 is
81        broadcasted to b1's shape and not the opposite.
82
83    """
84    if len(b1) < len(b2):
85        return False
86    b1 = b1[-len(b2):]
87    return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
88
89
90def merge_broadcastables(broadcastables):
91    return [all(bcast) for bcast in zip(*broadcastables)]
92
93
94def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
95    """Partition a list of variables into two kinds:
96    scalar constants, and the rest."""
97    consts = []
98    origconsts = []
99    nonconsts = []
100    for i in inputs:
101        try:
102            v = get_scalar_constant_value(i, elemwise=elemwise,
103                                          only_process_constants=only_process_constants)
104            consts.append(v)
105            origconsts.append(i)
106        except NotScalarConstantError:
107            nonconsts.append(i)
108    return consts, origconsts, nonconsts
109
110
111def broadcast_like(value, template, fgraph, dtype=None):
112    """
113    Return a Variable with the same shape and dtype as the template,
114    filled by broadcasting value through it. `value` will be cast as
115    necessary.
116
117    """
118    value = T.as_tensor_variable(value)
119    if value.type == template.type:
120        return value
121    if template not in fgraph.variables:
122        raise NotImplementedError('broadcast_like currently requires the '
123                                  'template Variable to be in the fgraph already')
124    if dtype is None:
125        dtype = template.dtype
126    value = T.cast(value, dtype)
127    if value.type == template.type:
128        return value
129    if hasattr(fgraph, 'shape_feature'):
130        new_shape = fgraph.shape_feature.shape_of[template]
131    else:
132        new_shape = template.shape
133    rval = T.alloc(value, *new_shape)
134    # the template may have 1s in its shape without being broadcastable
135    if rval.broadcastable != template.broadcastable:
136        rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
137                                     if rval.broadcastable[i] and
138                                     not template.broadcastable[i]])
139    assert rval.type.dtype == dtype
140
141    if rval.type.broadcastable != template.broadcastable:
142        raise AssertionError("rval.type.broadcastable is " +
143                             str(rval.type.broadcastable) +
144                             " but template.broadcastable is" +
145                             str(template.broadcastable))
146
147    return rval
148
149
150class InplaceElemwiseOptimizer(Optimizer):
151    """
152    We parametrise it to make it work for Elemwise and GpuElemwise op.
153    """
154    def __init__(self, OP):
155        self.op = OP
156
157    def add_requirements(self, fgraph):
158        fgraph.attach_feature(theano.gof.destroyhandler.DestroyHandler())
159
160    @staticmethod
161    def print_profile(stream, prof, level=0):
162        blanc = ('    ' * level)
163        print(blanc, "InplaceElemwiseOptimizer ", prof['opt'].op, file=stream)
164        for k in ['node_before',
165                  'nb_call_replace',
166                  'nb_call_validate',
167                  'nb_inconsistent']:
168            print(blanc, k, prof[k], file=stream)
169        ndim = prof['ndim']
170        if ndim:
171            print(blanc, "ndim", "nb", file=stream)
172            for n in sorted(ndim.keys()):
173                print(blanc, n, ndim[n], file=stream)
174
175    def apply(self, fgraph):
176        """
177        Usage: InplaceElemwiseOptimizer(op).optimize(fgraph)
178
179        Attempts to replace all Broadcast ops by versions of them
180        that operate inplace. It operates greedily: for each Broadcast
181        Op that is encountered, for each output, tries each input to
182        see if it can operate inplace on that input. If so, makes the
183        change and go to the next output or Broadcast Op.
184
185        Examples
186        --------
187
188            `x + y + z -> x += y += z`
189
190            `(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)`
191
192        """
193        # We should not validate too often as this takes too much time to
194        # execute!
195        # It is the _dfs_toposort() fct in theano/gof/destroyhandler.py
196        # that takes so much time.
197        # Should we try to use another lib that does toposort?
198        #   igraph: http://igraph.sourceforge.net/
199        #   networkx: https://networkx.lanl.gov/
200        # Should we try to use cython?
201        #   Compiling only that fct is not enough, should we try to add the
202        #   deque class too?
203        #   And init the deque and other list to an upper bound number of
204        #   elements?
205        # Maybe Theano should do online toposort as in
206        #   http://code.google.com/p/acyclic
207        #
208        # The next longest optimizer is the canonizer phase.
209        # Then I think it is the [io_?]toposort (need to validate) so check if
210        # the solution is also applicable there.
211
212        # We execute `validate` after this number of change.
213        prof = {'opt': self,
214                'node_before': len(fgraph.apply_nodes),
215                'nb_call_replace': 0,
216                'nb_call_validate': 0,
217                'nb_inconsistent': 0,
218                'ndim': defaultdict(lambda: 0)}
219
220        check_each_change = config.tensor.insert_inplace_optimizer_validate_nb
221        if check_each_change == -1:
222            if len(fgraph.apply_nodes) > 500:
223                check_each_change = 10
224            else:
225                check_each_change = 1
226
227        nb_change_no_validate = 0
228        chk = fgraph.checkpoint()
229
230        if fgraph.update_mapping:
231            update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
232        else:
233            update_outs = []
234
235        protected_inputs = [
236            f.protected for f in fgraph._features if
237            isinstance(f, theano.compile.function_module.Supervisor)]
238        protected_inputs = sum(protected_inputs, [])  # flatten the list
239        protected_inputs.extend(fgraph.outputs)
240        for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
241            op = node.op
242            # gpuarray GpuElemwise inherit from Elemwise
243            if not type(op) == self.op:
244                continue
245            # If big graph and the outputs are scalar, do not make it
246            # inplace.
247            if (check_each_change != 1 and
248                # If multiple outputs, they must all have the same size,
249                # so only check the first.
250                    getattr(node.outputs[0].type, 'ndim', -1) == 0):
251                continue
252
253            if op.inplace_pattern:
254                # Maybe this isn't needed anymore, but I don't want to
255                # rish regression now. This case only happen if the
256                # original node add already some inplace patter and we
257                # still try to add more pattern.
258
259                baseline = op.inplace_pattern
260                candidate_outputs = [i for i in xrange(len(node.outputs))
261                                     if i not in baseline]
262                # node inputs that are Constant, already destroyed,
263                # or fgraph protected inputs and fgraph outputs can't be used as
264                # inplace target.
265                # Remove here as faster.
266                candidate_inputs = [i for i in xrange(len(node.inputs))
267                                    if i not in baseline.values() and
268                                    not isinstance(node.inputs[i], Constant) and
269                                    # the next line should not be costly most of the time.
270                                    not fgraph.has_destroyers([node.inputs[i]]) and
271                                    node.inputs[i] not in protected_inputs]
272            else:
273                baseline = []
274                candidate_outputs = list(range(len(node.outputs)))
275                # node inputs that are Constant, already destroyed,
276                # fgraph protected inputs and fgraph outputs can't be used as inplace
277                # target.
278                # Remove here as faster.
279                candidate_inputs = [i for i in xrange(len(node.inputs))
280                                    if not isinstance(node.inputs[i], Constant) and
281                                    not fgraph.has_destroyers([node.inputs[i]]) and
282                                    node.inputs[i] not in protected_inputs]
283
284            verbose = False
285
286            raised_warning = not verbose
287
288            for candidate_output in candidate_outputs:
289
290                # If the output of the node can be established as an update
291                # output of the fgraph, visit the candidate_inputs in an order
292                # that will improve the chances of making the node operate
293                # inplace on the input it's meant to update
294                candidate_out_var = node.outputs[candidate_output]
295                sorted_candidate_inputs = candidate_inputs
296
297                if candidate_out_var in update_outs:
298
299                    # The candidate output is an update. Sort the
300                    # variables in candidate_inputs in the following order:
301                    # - Vars corresponding to the actual updated input
302                    #   (best case scenario is for the node that procudes
303                    #   an update to operate inplace on the variable to
304                    #   update)
305                    # - Vars computed inplace on the updates input (second
306                    #   best scenario if for the node to work inplace on
307                    #   a variable obtained by a chain of inplace on the
308                    #   variable to update. In some cases, this will be
309                    #   equivalent to operating inplace on the variable to
310                    #   update)
311                    # - Remaining variables
312                    updated_inputs = []
313                    for i, f_out in enumerate(fgraph.outputs):
314                        if (f_out is candidate_out_var and i in fgraph.update_mapping):
315                            updated_inp_idx = fgraph.update_mapping[i]
316                            updated_inputs.append(fgraph.inputs[updated_inp_idx])
317
318                    updated_vars = []
319                    vars_from_inplace = []
320                    other_vars = []
321                    for inp_idx in candidate_inputs:
322                        inp = node.inputs[inp_idx]
323                        if inp in updated_inputs:
324                            # the candidate input is the actual updated input
325                            updated_vars.append(inp_idx)
326                        elif (hasattr(fgraph, 'destroy_handler') and
327                              inp.owner and
328                              any([fgraph.destroy_handler.root_destroyer.get(up_inp, None) is inp.owner
329                                   for up_inp in updated_inputs])):
330
331                            # the candidate input is a variable computed
332                            # inplace on the updated input via a sequence of
333                            # one or more inplace operations
334                            vars_from_inplace.append(inp_idx)
335                        else:
336                            other_vars.append(inp_idx)
337
338                    sorted_candidate_inputs = (updated_vars +
339                                               vars_from_inplace + other_vars)
340
341                for candidate_input in sorted_candidate_inputs:
342                    # remove inputs that don't have the same dtype as the output
343                    if node.inputs[candidate_input].type != node.outputs[
344                            candidate_output].type:
345                        continue
346
347                    inplace_pattern = dict(baseline)
348                    inplace_pattern[candidate_output] = candidate_input
349                    try:
350                        if hasattr(op.scalar_op, "make_new_inplace"):
351                            new_scal = op.scalar_op.make_new_inplace(
352                                scalar.transfer_type(
353                                    *[inplace_pattern.get(i, o.dtype)
354                                      for i, o in enumerate(node.outputs)]))
355                        else:
356                            new_scal = op.scalar_op.__class__(
357                                scalar.transfer_type(
358                                    *[inplace_pattern.get(i, None)
359                                      for i in xrange(len(node.outputs))]))
360                        new_outputs = self.op(new_scal, inplace_pattern)(
361                            *node.inputs, **dict(return_list=True))
362                        new_node = new_outputs[0].owner
363
364                        for r, new_r in zip(node.outputs, new_outputs):
365                            prof['nb_call_replace'] += 1
366                            fgraph.replace(r, new_r,
367                                           reason="inplace_elemwise_optimizer")
368                        nb_change_no_validate += 1
369                        prof['ndim'][candidate_out_var.ndim] += 1
370                        if nb_change_no_validate >= check_each_change:
371                            prof['nb_call_validate'] += 1
372                            fgraph.validate()
373                            chk = fgraph.checkpoint()
374                            nb_change_no_validate = 0
375                    except (ValueError, InconsistencyError) as e:
376                        prof['nb_inconsistent'] += 1
377                        if check_each_change != 1 and not raised_warning:
378                            print(("Some inplace optimization was not "
379                                   "performed due to unexpected error:"),
380                                  file=sys.stderr)
381                            print(e, file=sys.stderr)
382                            raised_warning = True
383                        fgraph.revert(chk)
384                        continue
385                    candidate_inputs.remove(candidate_input)
386                    node = new_node
387                    baseline = inplace_pattern
388                    break
389
390        if nb_change_no_validate > 0:
391            try:
392                fgraph.validate()
393            except Exception:
394                if not raised_warning:
395                    print(("Some inplace optimization was not "
396                           "performed due to unexpected error"),
397                          file=sys.stderr)
398                fgraph.revert(chk)
399        return prof
400
401    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
402        print("%s%s (%s)" % (
403            (' ' * level), self.__class__.__name__, self.op), file=stream)
404        return inplace_elemwise_optimizer
405
406inplace_elemwise_optimizer = InplaceElemwiseOptimizer(T.Elemwise)
407compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
408                       'inplace_opt',  # for historic reason
409                       'inplace_elemwise_optimizer',
410                       'fast_run', 'inplace')
411
412
413def register_useless(lopt, *tags, **kwargs):
414    if type(lopt) == str:
415        def register(inner_lopt):
416            return register_useless(inner_lopt, lopt, *tags, **kwargs)
417        return register
418    else:
419        name = kwargs.pop('name', None) or lopt.__name__
420
421        compile.mode.local_useless.register(name, lopt, 'last', 'fast_run',
422                                            *tags, **kwargs)
423        return lopt
424
425
426def register_canonicalize(lopt, *tags, **kwargs):
427    if type(lopt) == str:
428        def register(inner_lopt):
429            return register_canonicalize(inner_lopt, lopt, *tags, **kwargs)
430        return register
431    else:
432        name = kwargs.pop('name', None) or lopt.__name__
433        compile.optdb['canonicalize'].register(name, lopt, 'fast_run',
434                                               *tags, **kwargs)
435        return lopt
436
437
438def register_stabilize(lopt, *tags, **kwargs):
439    if type(lopt) == str:
440        def register(inner_lopt):
441            return register_stabilize(inner_lopt, lopt, *tags, **kwargs)
442        return register
443    else:
444        name = kwargs.pop('name', None) or lopt.__name__
445        compile.optdb['stabilize'].register(name, lopt, 'fast_run',
446                                            *tags, **kwargs)
447        return lopt
448
449
450def register_specialize(lopt, *tags, **kwargs):
451    if type(lopt) == str:
452        def register(inner_lopt):
453            return register_specialize(inner_lopt, lopt, *tags, **kwargs)
454        return register
455    else:
456        name = kwargs.pop('name', None) or lopt.__name__
457        compile.optdb['specialize'].register(name, lopt, 'fast_run',
458                                             *tags, **kwargs)
459        return lopt
460
461
462def register_uncanonicalize(lopt, *tags, **kwargs):
463    if type(lopt) == str:
464        def register(inner_lopt):
465            return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs)
466        return register
467    else:
468        name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
469        compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags,
470                                                 **kwargs)
471        return lopt
472
473
474def register_specialize_device(lopt, *tags, **kwargs):
475    if type(lopt) == str:
476        def register(inner_lopt):
477            return register_specialize_device(inner_lopt, lopt, *tags, **kwargs)
478        return register
479    else:
480        name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
481        compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags,
482                                                    **kwargs)
483        return lopt
484
485
486#####################
487# Dot optimizations #
488#####################
489
490@register_canonicalize
491@register_stabilize
492@gof.local_optimizer([T.Dot])
493def local_0_dot_x(node):
494    if not isinstance(node.op, T.Dot):
495        return False
496
497    x = node.inputs[0]
498    y = node.inputs[1]
499    replace = False
500    try:
501        if get_scalar_constant_value(x, only_process_constants=True) == 0:
502            replace = True
503    except NotScalarConstantError:
504        pass
505
506    try:
507        if get_scalar_constant_value(y, only_process_constants=True) == 0:
508            replace = True
509    except NotScalarConstantError:
510        pass
511
512    if replace:
513        constant_zero = T.constant(0, dtype=node.outputs[0].type.dtype)
514        if x.ndim == 2 and y.ndim == 2:
515            constant_zero = assert_(constant_zero,
516                                    T.eq(x.shape[1], y.shape[0]))
517            return [T.alloc(constant_zero, x.shape[0], y.shape[1])]
518        elif x.ndim == 1 and y.ndim == 2:
519            constant_zero = assert_(constant_zero,
520                                    T.eq(x.shape[0], y.shape[0]))
521            return [T.alloc(constant_zero, y.shape[1])]
522        elif x.ndim == 2 and y.ndim == 1:
523            constant_zero = assert_(constant_zero,
524                                    T.eq(x.shape[1], y.shape[0]))
525            return [T.alloc(constant_zero, x.shape[0])]
526        elif x.ndim == 1 and y.ndim == 1:
527            constant_zero = assert_(constant_zero,
528                                    T.eq(x.shape[0], y.shape[0]))
529            return [constant_zero]
530        else:
531            _logger.warning("Optimization Warning: "
532                            "Optimization theano/opt.py:local_0_dot_x Found "
533                            "that it could apply, but was not implemented "
534                            "for dot product with these input types:\n"
535                            "(%s, %s)",
536                            x.type, y.type)
537
538######################
539# DimShuffle lifters #
540######################
541
542
543def apply_local_dimshuffle_lift(var):
544    # return var
545    # lift recursively
546    if not var.owner:
547        return var
548    new = local_dimshuffle_lift.transform(var.owner)
549    if new:
550        return new[0]
551    return var
552
553
554# Checks for two types of useless dimshuffles:
555#   1 - dimshuffle all dimensions in order.
556#   2 - dimshuffle a broadcastable dimension.
557def is_dimshuffle_useless(new_order, input):
558    is_useless = True
559    if len(new_order) == input.type.ndim:
560        all_broadcastable_dims = [i for (i, is_broadcastable)
561                                  in enumerate(input.type.broadcastable)
562                                  if is_broadcastable] + ['x']
563        for i in range(input.type.ndim):
564            if (new_order[i] == i or
565                    (i in all_broadcastable_dims and
566                     new_order[i] in all_broadcastable_dims)):
567                is_useless = True
568            else:
569                is_useless = False
570                break
571    else:
572        is_useless = False
573    return is_useless
574
575
576@gof.local_optimizer([DimShuffle])
577def local_dimshuffle_lift(node):
578    """
579    "Lifts" DimShuffle through Elemwise operations and merges
580    consecutive DimShuffles. Basically, applies the following
581    transformations on the whole graph:
582
583    DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
584    DimShuffle(DimShuffle(x)) => DimShuffle(x)
585    DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)
586
587    After this transform, clusters of Elemwise operations are
588    void of DimShuffle operations.
589
590    """
591    op = node.op
592    if not isinstance(op, DimShuffle):
593        return False
594
595    input = node.inputs[0]
596    inode = input.owner
597    new_order = op.new_order
598    if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
599        # Don't use make_node to have tag.test_value set.
600        new_inputs = []
601        for inp in inode.inputs:
602            new_inp = op.__class__(inp.type.broadcastable,
603                                   op.new_order)(inp)
604            new_inputs.append(apply_local_dimshuffle_lift(new_inp))
605        copy_stack_trace(node.outputs[0], new_inputs)
606        ret = inode.op(*new_inputs, **dict(return_list=True))
607        return ret
608    if inode and isinstance(inode.op, DimShuffle):
609        new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
610                     new_order]
611        input = inode.inputs[0]
612
613    if is_dimshuffle_useless(new_order, input):
614        return [input]
615    elif inode and isinstance(inode.op, DimShuffle):
616        ret = op.__class__(input.type.broadcastable, new_order)(input)
617        ret = apply_local_dimshuffle_lift(ret)
618        copy_stack_trace(node.outputs[0], ret)
619        return [ret]
620
621
622@register_canonicalize
623@gof.local_optimizer([Reshape])
624def local_useless_dimshuffle_in_reshape(node):
625    """
626    Removes useless DimShuffle operation inside Reshape:
627
628      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
629      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
630      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
631      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
632
633    """
634    op = node.op
635    if not isinstance(op, Reshape):
636        return False
637    if not (node.inputs[0].owner is not None and
638            isinstance(node.inputs[0].owner.op, DimShuffle)):
639        return False
640
641    new_order = node.inputs[0].owner.op.new_order
642    input = node.inputs[0].owner.inputs[0]
643    broadcastables = node.inputs[0].broadcastable
644    new_order_of_nonbroadcast = []
645    for i, bd in zip(new_order, broadcastables):
646        if not bd:
647            new_order_of_nonbroadcast.append(i)
648    no_change_in_order = all(
649        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
650        for i in xrange(len(new_order_of_nonbroadcast) - 1))
651    if no_change_in_order:
652        shape = node.inputs[1]
653        ret = op.__class__(node.outputs[0].ndim)(input, shape)
654        copy_stack_trace(node.outputs[0], ret)
655        return [ret]
656
657
658@register_canonicalize
659@gof.local_optimizer([DimShuffle])
660def local_lift_transpose_through_dot(node):
661    """
662    dot(x,y).T -> dot(y.T, x.T)
663
664    These optimizations "lift" (propagate towards the inputs) DimShuffle
665    through dot product.  It allows to put the graph in a more standard shape,
666    and to later merge consecutive DimShuffles.
667
668    The transformation should be apply whether or not the transpose is
669    inplace.  The newly-introduced transpositions are not inplace, this will
670    be taken care of in a later optimization phase.
671
672    """
673    if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)):
674        return False
675    if not (node.inputs[0].owner and
676            isinstance(node.inputs[0].owner.op, T.Dot)):
677        return False
678    x, y = node.inputs[0].owner.inputs
679
680    if x.ndim == y.ndim == 2:
681        # Output is dot product of transposed inputs in reverse order
682        ret = [T.dot(y.T, x.T)]
683
684        # Copy over stack trace to output from result of dot-product
685        copy_stack_trace(node.inputs[0], ret)
686        return ret
687
688register_canonicalize(local_dimshuffle_lift)
689register_specialize(local_dimshuffle_lift)
690
691######################
692# Casting operations #
693######################
694
695
696@register_canonicalize
697@register_specialize
698@gof.local_optimizer([T.TensorFromScalar])
699def local_tensor_scalar_tensor(node):
700    '''tensor_from_scalar(scalar_from_tensor(x)) -> x'''
701    if isinstance(node.op, T.TensorFromScalar):
702        s = node.inputs[0]
703        if s.owner and isinstance(s.owner.op, T.ScalarFromTensor):
704            t = s.owner.inputs[0]
705
706            # We don't need to copy over any stack traces here
707            return [t]
708
709
710@register_canonicalize
711@register_specialize
712@gof.local_optimizer([T.ScalarFromTensor])
713def local_scalar_tensor_scalar(node):
714    '''scalar_from_tensor(tensor_from_scalar(x)) -> x'''
715    if isinstance(node.op, T.ScalarFromTensor):
716        t = node.inputs[0]
717        if t.owner and isinstance(t.owner.op, T.TensorFromScalar):
718            s = t.owner.inputs[0]
719
720            # We don't need to copy over any stack traces here
721            return [s]
722
723#####################################
724# ShapeFeature, Shape optimizations
725#####################################
726
727
728class MakeVector(T.Op):
729    """Concatenate a number of scalars together into a vector.
730
731    This is a simple version of stack() that introduces far less cruft
732    into the graph. Should work with 0 inputs. The constant_folding
733    optimization will remove it.
734
735    """
736
737    __props__ = ("dtype",)
738
739    def __init__(self, dtype='int64'):
740        self.dtype = dtype
741
742    def make_node(self, *inputs):
743        inputs = list(map(T.as_tensor_variable, inputs))
744        if (not all(a.type == inputs[0].type for a in inputs) or
745                (len(inputs) > 0 and inputs[0].dtype != self.dtype)):
746            dtype = theano.scalar.upcast(self.dtype, *[i.dtype for i in inputs])
747            # upcast the input to the determined dtype,
748            # but don't downcast anything
749            assert dtype == self.dtype, (
750                "The upcast of the inputs to MakeVector should match the "
751                "dtype given in __init__.")
752            if not all(self.dtype == T.cast(i, dtype=dtype).dtype
753                       for i in inputs):
754                raise TypeError("MakeVector.make_node expected inputs"
755                                " upcastable to %s. got %s" %
756                                (self.dtype, str([i.dtype for i in inputs])))
757            inputs = [T.cast(i, dtype=dtype) for i in inputs]
758        assert all(self.dtype == a.dtype for a in inputs)
759        assert all(a.ndim == 0 for a in inputs)
760
761        if inputs:
762            dtype = inputs[0].type.dtype
763        else:
764            dtype = self.dtype
765        # bcastable = (len(inputs) == 1)
766        bcastable = False
767        otype = T.TensorType(broadcastable=(bcastable,), dtype=dtype)
768        return T.Apply(self, inputs, [otype()])
769
770    def perform(self, node, inputs, out_):
771        out, = out_
772        # not calling theano._asarray as optimization
773        if (out[0] is None) or (out[0].size != len(inputs)):
774            out[0] = theano._asarray(inputs, dtype=node.outputs[0].dtype)
775        else:
776            # assume that out has correct dtype. there is no cheap way to check
777            out[0][...] = inputs
778
779    def c_code_cache_version(self):
780        return (2,)
781
782    def c_code(self, node, name, inp, out_, sub):
783        out, = out_
784        # Shouldn't use PyArray_TYPE(inp[0]) for the dtype
785        # when len(inp) == 0 (we need to support this case.
786        # So there will be (1 * nb_dtype) + ((nb len(inp) - 1 ))
787        # different c code with the following algo
788        out_shape = len(inp)
789        out_num = np.dtype(node.outputs[0].dtype).num
790        # don't use dtype_%(out)s as when check_input=False, it isn't defined.
791        out_dtype = node.outputs[0].type.dtype_specs()[1]
792        if len(inp) > 0:
793            assert self.dtype == node.inputs[0].dtype
794            out_num = 'PyArray_TYPE(%s)' % inp[0]
795
796        ret = """
797        npy_intp dims[1];
798        dims[0] = %(out_shape)s;
799        if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){
800            Py_XDECREF(%(out)s);
801            %(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0);
802        }
803        """ % locals()
804        for idx, i in enumerate(inp):
805            ret += """
806            *((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s));
807            """ % locals()
808        return ret
809
810    def infer_shape(self, node, ishapes):
811        return [(len(ishapes),)]
812
813    def grad(self, inputs, output_gradients):
814        # If the output is of an integer dtype, no gradient shall pass
815        if self.dtype in theano.tensor.discrete_dtypes:
816            return [ipt.zeros_like().astype(theano.config.floatX)
817                    for ipt in inputs]
818
819        grads = []
820        for i, inp in enumerate(inputs):
821            grads.append(output_gradients[0][i])
822        return grads
823
824    def R_op(self, inputs, eval_points):
825        if None in eval_points:
826            return [None]
827        return self.make_node(*eval_points).outputs
828
829make_vector = MakeVector()
830
831
832class MakeVectorPrinter:
833    def process(self, r, pstate):
834        if r.owner is None:
835            raise TypeError("Can only print make_vector.")
836        elif isinstance(r.owner.op, MakeVector):
837            old_precedence = getattr(pstate, 'precedence', None)
838            try:
839                pstate.precedence = 1000
840                s = [pstate.pprinter.process(input)
841                     for input in r.owner.inputs]
842            finally:
843                pstate.precedence = old_precedence
844            return "[%s]" % ", ".join(s)
845        else:
846            raise TypeError("Can only print make_vector.")
847
848T.pprint.assign(MakeVector, MakeVectorPrinter())
849
850
851class ShapeFeature(object):
852    """Graph optimizer for removing all calls to shape().
853
854    This optimizer replaces all Shapes and Subtensors of Shapes with
855    Shape_i and MakeVector Ops.
856
857    This optimizer has several goals:
858
859    1. to 'lift' Shapes to as close to the inputs as possible.
860
861    2. to infer the shape of every node in the graph in terms of the
862       input shapes.
863
864    3. remove all fills (T.second, T.fill) from the graph
865
866    Lifting shapes as close to the inputs as possible is important for
867    canonicalization because it is very bad form to have to compute
868    something just to know how big it will be.  Firstly, it is a waste
869    of time to compute such outputs.  But it is important to get rid
870    of these outputs as early as possible in the compilation process
871    because the extra computations make it appear as if many internal
872    graph nodes have multiple clients.  Many optimizations refuse to
873    work on nodes with multiple clients.
874
875    Lifting is done by using an `<Op>.infer_shape` function if one is
876    present, or else using a conservative default.  An Op that
877    supports shape-lifting should define a infer_shape(self, node,
878    input_shapes) function.  The argument input_shapes is a tuple of
879    tuples... there is an interior tuple for each input to the node.
880    The tuple has as many elements as dimensions.  The element in
881    position i of tuple j represents the i'th shape component of the
882    j'th input.  The function should return a tuple of tuples.  One
883    output tuple for each node.output.  Again, the i'th element of the
884    j'th output tuple represents the output[j].shape[i] of the
885    function.  If an output is not a TensorType, then None should be
886    returned instead of a tuple for that output.
887
888    For example the infer_shape for a matrix-matrix product would accept
889    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
890
891    Inferring the shape of internal nodes in the graph is important
892    for doing size-driven optimizations.  If we know how big various
893    intermediate results will be, we can estimate the cost of many Ops
894    accurately, and generate c-code that is specific [e.g. unrolled]
895    to particular sizes.
896
897    In cases where you cannot figure out the shape, raise a ShapeError.
898
899    Notes
900    -----
901    Right now there is only the ConvOp that could really take
902    advantage of this shape inference, but it is worth it even
903    just for the ConvOp.  All that's necessary to do shape
904    inference is 1) to mark shared inputs as having a particular
905    shape, either via a .tag or some similar hacking; and 2) to
906    add an optional In() argument to promise that inputs will
907    have a certain shape (or even to have certain shapes in
908    certain dimensions). We can't automatically infer the shape of
909    shared variables as they can change of shape during the
910    execution by default.  (NOT IMPLEMENTED YET, BUT IS IN TRAC)
911
912
913    **Using Shape information in Optimizations**
914
915    To use this shape information in OPTIMIZATIONS, use the
916    ``shape_of`` dictionary.
917
918    For example:
919
920    .. code-block:: python
921
922        try:
923            shape_of = node.fgraph.shape_feature.shape_of
924        except AttributeError:
925            # This can happen when the mode doesn't include the ShapeFeature.
926            return
927
928        shape_of_output_zero = shape_of[node.output[0]]
929
930    The ``shape_of_output_zero`` symbol will contain a tuple, whose
931    elements are either integers or symbolic integers.
932
933    TODO: check to see if the symbols are necessarily
934    non-constant... or are integer literals sometimes Theano
935    constants?? That would be confusing.
936
937    """
938    def get_node_infer_shape(self, node):
939        try:
940            shape_infer = node.op.infer_shape
941        except AttributeError:
942            shape_infer = self.default_infer_shape
943
944        try:
945            o_shapes = shape_infer(node,
946                                   [self.shape_of[r] for r in node.inputs])
947        except ShapeError:
948            o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
949                                                       r in node.inputs])
950        except NotImplementedError as e:
951            raise NotImplementedError(
952                'Code called by infer_shape failed raising a '
953                'NotImplementedError. Raising NotImplementedError to '
954                'indicate that a shape cannot be computed is no longer '
955                'supported, and one should now use tensor.ShapeError '
956                'instead. The original exception message is: %s' % e)
957        except Exception as e:
958            msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
959                   '%s\nException encountered during infer_shape: '
960                   '%s\nException message: %s\nTraceback: %s') % (
961                node.op, [self.shape_of[r] for r in node.inputs],
962                type(e), str(e), traceback.format_exc())
963            if config.on_shape_error == "raise":
964                raise Exception(msg)
965            else:
966                _logger.warning(msg)
967            o_shapes = self.default_infer_shape(
968                node, [self.shape_of[r] for r in node.inputs])
969
970        return o_shapes
971
972    def get_shape(self, var, idx):
973        """ Optimization can call this to get the current shape_i
974
975        It is better to call this then use directly shape_of[var][idx]
976        as this method should update shape_of if needed.
977
978        TODO: Up to now, we don't update it in all cases. Update in all cases.
979        """
980        r = self.shape_of[var][idx]
981        if (r.owner and
982                isinstance(r.owner.op, Shape_i) and
983                r.owner.inputs[0] not in var.fgraph.variables):
984            assert var.owner
985            node = var.owner
986            # recur on inputs
987            for i in node.inputs:
988                if getattr(i, 'ndim', None) > 0:
989                    self.get_shape(i, 0)
990            o_shapes = self.get_node_infer_shape(node)
991            assert len(o_shapes) == len(node.outputs)
992
993            # Only change the variables and dimensions that would introduce
994            # extra computation
995            for new_shps, out in zip(o_shapes, node.outputs):
996                if not hasattr(out, 'ndim'):
997                    continue
998
999                merged_shps = list(self.shape_of[out])
1000                changed = False
1001                for i in range(out.ndim):
1002                    n_r = merged_shps[i]
1003                    if (n_r.owner and
1004                            isinstance(n_r.owner.op, Shape_i) and
1005                            n_r.owner.inputs[0] not in var.fgraph.variables):
1006                        changed = True
1007                        merged_shps[i] = new_shps[i]
1008                if changed:
1009                    self.set_shape(out, merged_shps, override=True)
1010            r = self.shape_of[var][idx]
1011        return r
1012
1013    def shape_ir(self, i, r):
1014        """Return symbolic r.shape[i] for tensor variable r, int i."""
1015        if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
1016            return self.lscalar_one
1017        else:
1018            # Do not call make_node for test_value
1019            s = Shape_i(i)(r)
1020            try:
1021                s = get_scalar_constant_value(s)
1022            except NotScalarConstantError:
1023                pass
1024            return s
1025
1026    def shape_tuple(self, r):
1027        """Return a tuple of symbolic shape vars for tensor variable r."""
1028        if not hasattr(r, 'ndim'):
1029            # This happen for NoneConst.
1030            return None
1031        return tuple([self.shape_ir(i, r) for i in xrange(r.ndim)])
1032
1033    def default_infer_shape(self, node, i_shapes):
1034        """Return a list of shape tuple or None for the outputs of node.
1035
1036        This function is used for Ops that don't implement infer_shape.
1037        Ops that do implement infer_shape should use the i_shapes parameter,
1038        but this default implementation ignores it.
1039
1040        """
1041        rval = []
1042        for r in node.outputs:
1043            try:
1044                rval.append(self.shape_tuple(r))
1045            except AttributeError:
1046                rval.append(None)
1047        return rval
1048
1049    def unpack(self, s_i, var):
1050        """Return a symbolic integer scalar for the shape element s_i.
1051
1052        The s_i argument was produced by the infer_shape() of an Op subclass.
1053
1054        var: the variable that correspond to s_i. This is just for
1055        error reporting.
1056
1057        """
1058        # unpack the s_i that the Op returned
1059        assert s_i is not None
1060        if s_i == 1:
1061            # don't make the optimizer merge a zillion ones together
1062            # by always returning the same object to represent 1
1063            return self.lscalar_one
1064        if type(s_i) is float and int(s_i) == s_i:
1065            s_i = int(s_i)
1066        if (type(s_i) in integer_types or
1067                isinstance(s_i, np.integer) or
1068                (isinstance(s_i, np.ndarray) and s_i.ndim == 0)):
1069            # this shape is a constant
1070            if s_i < 0:
1071                msg = "There is a negative shape in the graph!"
1072                msg += gof.utils.get_variable_trace_string(var)
1073                # The rest of the pipeline don't handle correctly this
1074                # case.  So we have 2 choices, stop compilation or
1075                # consider the shape as unknow.  As we have more
1076                # chance to give the stack trace here then later, I
1077                # choose that options as it would give better error
1078                # message.
1079                raise AssertionError(msg)
1080            return T.constant(s_i, dtype='int64')
1081        if type(s_i) in (tuple, list):
1082            # this dimension is the same as many of the inputs
1083            # which tells us that if one of the inputs is known,
1084            # the others all become known.
1085            # TODO: should be implemented in Elemwise, and Dot
1086            #
1087            # worst case, we loop over shape_of and replace things
1088            raise NotImplementedError(s_i)
1089
1090        # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
1091        if (s_i.owner and
1092                isinstance(s_i.owner.op, Subtensor) and
1093                s_i.owner.inputs[0].owner and
1094                isinstance(s_i.owner.inputs[0].owner.op, T.Shape)):
1095            assert s_i.ndim == 0
1096            assert len(s_i.owner.op.idx_list) == 1
1097
1098            # The current Subtensor always put constant index in the graph.
1099            # This was not True in the past. So call the Subtensor function
1100            # that will return the right index.
1101            idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
1102            assert len(idx) == 1
1103            idx = idx[0]
1104            try:
1105                i = get_scalar_constant_value(idx)
1106            except NotScalarConstantError:
1107                pass
1108            else:
1109                # Executed only if no exception was raised
1110                x = s_i.owner.inputs[0].owner.inputs[0]
1111                # x should already have been imported, and should be in shape_of.
1112                s_i = self.shape_of[x][i]
1113
1114        if s_i.type.dtype in theano.tensor.integer_dtypes:
1115            if getattr(s_i.type, 'ndim', 0):
1116                raise TypeError('Shape element must be scalar', s_i)
1117            return s_i
1118        else:
1119            raise TypeError('Unsupported shape element',
1120                            s_i, type(s_i), getattr(s_i, 'type', None))
1121
1122    def set_shape(self, r, s, override=False):
1123        """Assign the shape `s` to previously un-shaped variable `r`.
1124
1125        Parameters
1126        ----------
1127        r : a variable
1128        s : None or a tuple of symbolic integers
1129        override : If False, it mean r is a new object in the fgraph.
1130            If True, it mean r is already in the fgraph and we want to
1131            override its shape.
1132
1133        """
1134        if not override:
1135            assert r not in self.shape_of, 'r already in shape_of'
1136        if s is None:
1137            self.shape_of[r] = s
1138        else:
1139            if not isinstance(s, (tuple, list)):
1140                raise TypeError('shapes must be tuple/list', (r, s))
1141
1142            if r.ndim != len(s):
1143                sio = StringIO()
1144                theano.printing.debugprint(r, file=sio, print_type=True)
1145                raise AssertionError(
1146                    "Something inferred a shape with %d dimensions "
1147                    "for a variable with %d dimensions"
1148                    " for the variable:\n%s" % (
1149                        len(s), r.ndim, sio.getvalue()))
1150
1151            shape_vars = []
1152            for i in xrange(r.ndim):
1153                if (hasattr(r.type, 'broadcastable') and
1154                        r.type.broadcastable[i]):
1155                    shape_vars.append(self.lscalar_one)
1156                else:
1157                    shape_vars.append(self.unpack(s[i], r))
1158            assert all([not hasattr(r.type, "broadcastable") or
1159                        not r.type.broadcastable[i] or
1160                        # The two following comparison are a speed optimization
1161                        # But we never timed this speed optimization!
1162                        self.lscalar_one.equals(shape_vars[i]) or
1163                        self.lscalar_one.equals(
1164                            T.extract_constant(shape_vars[i]))
1165                        for i in xrange(r.ndim)])
1166            self.shape_of[r] = tuple(shape_vars)
1167            for sv in shape_vars:
1168                self.shape_of_reverse_index.setdefault(sv, set()).add(r)
1169
1170    def update_shape(self, r, other_r):
1171        """Replace shape of r by shape of other_r.
1172
1173        If, on some dimensions, the shape of other_r is not informative,
1174        keep the shape of r on those dimensions.
1175
1176        """
1177        # other_r should already have a shape
1178        assert other_r in self.shape_of, ('other_r not in shape_of', other_r)
1179        other_shape = self.shape_of[other_r]
1180
1181        # If other_shape has no information, call is pointless.
1182        if other_shape is None:
1183            return
1184
1185        if r in self.shape_of:
1186            r_shape = self.shape_of[r]
1187        else:
1188            # If no info is known on r's shape, use other_shape
1189            self.set_shape(r, other_shape)
1190            return
1191        if (other_r.owner and r.owner and
1192                other_r.owner.inputs == r.owner.inputs and
1193                other_r.owner.op == r.owner.op):
1194            # We are doing a merge. So the 2 shapes graph will be the
1195            # same.  This is only a speed optimization to call
1196            # ancestors() less frequently.
1197            return
1198
1199        # Merge other_shape with r_shape, giving the priority to other_shape
1200        merged_shape = []
1201        for i, ps in enumerate(other_shape):
1202            if r_shape is None and other_shape:
1203                merged_shape.append(other_shape[i])
1204            elif (ps.owner and
1205                    isinstance(getattr(ps.owner, 'op', None), Shape_i) and
1206                    ps.owner.op.i == i and
1207                    ps.owner.inputs[0] in (r, other_r)):
1208                # If other_shape[i] is uninformative, use r_shape[i].
1209                # For now, we consider 2 cases of uninformative other_shape[i]:
1210                #  - Shape_i(i)(other_r);
1211                #  - Shape_i(i)(r).
1212                merged_shape.append(r_shape[i])
1213            elif isinstance(r_shape[i], (Constant, integer_types)):
1214                # We do this to call less often ancestors and make
1215                # sure we have the simplest shape possible.
1216                merged_shape.append(r_shape[i])
1217            elif isinstance(other_shape[i], (Constant, integer_types)):
1218                # We do this to call less often ancestors and make
1219                # sure we have the simplest shape possible.
1220                merged_shape.append(other_shape[i])
1221            elif other_shape[i] == r_shape[i]:
1222                # This mean the shape is equivalent
1223                # We do not want to do the ancestor check in those cases
1224                merged_shape.append(r_shape[i])
1225            elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
1226                # Another case where we want to use r_shape[i] is when
1227                # other_shape[i] actually depends on r_shape[i]. In that case,
1228                # we do not want to substitute an expression with another that
1229                # is strictly more complex. Such a substitution could also lead
1230                # to cycles: if (in the future) r_shape[i] gets replaced by an
1231                # expression of other_shape[i], other_shape[i] may end up
1232                # depending on itself.
1233                merged_shape.append(r_shape[i])
1234            else:
1235                merged_shape.append(other_shape[i])
1236        assert all([(not hasattr(r.type, "broadcastable") or
1237                     not r.type.broadcastable[i] and
1238                     not other_r.type.broadcastable[i]) or
1239                    # The two following comparison are a speed optimization
1240                    # But we never timed this speed optimization!
1241                    self.lscalar_one.equals(merged_shape[i]) or
1242                    self.lscalar_one.equals(
1243                        T.extract_constant(merged_shape[i], only_process_constants=True))
1244                    for i in xrange(r.ndim)])
1245        self.shape_of[r] = tuple(merged_shape)
1246        for sv in self.shape_of[r]:
1247            self.shape_of_reverse_index.setdefault(sv, set()).add(r)
1248
1249    def set_shape_i(self, r, i, s_i):
1250        '''Replace element i of shape_of[r] by s_i'''
1251        assert r in self.shape_of
1252        prev_shape = self.shape_of[r]
1253        # prev_shape is a tuple, so we cannot change it inplace,
1254        # so we build another one.
1255        new_shape = []
1256        for j, s_j in enumerate(prev_shape):
1257            if j == i:
1258                new_shape.append(self.unpack(s_i, r))
1259            else:
1260                new_shape.append(s_j)
1261        assert all([not hasattr(r.type, "broadcastable") or
1262                    not r.type.broadcastable[idx] or
1263                    # The two following comparison are a speed optimization
1264                    # But we never timed this speed optimization!
1265                    self.lscalar_one.equals(new_shape[idx]) or
1266                    self.lscalar_one.equals(T.extract_constant(new_shape[idx]))
1267                    for idx in xrange(r.ndim)])
1268        self.shape_of[r] = tuple(new_shape)
1269        for sv in self.shape_of[r]:
1270            self.shape_of_reverse_index.setdefault(sv, set()).add(r)
1271
1272    def init_r(self, r):
1273        '''Register r's shape in the shape_of dictionary.'''
1274        if r not in self.shape_of:
1275            try:
1276                self.set_shape(r, self.shape_tuple(r))
1277            except AttributeError:  # XXX: where would this come from?
1278                self.set_shape(r, None)
1279
1280    def make_vector_shape(self, r):
1281        return make_vector(*self.shape_of[r])
1282
1283    #
1284    # Feature interface
1285    #
1286    #
1287    def on_attach(self, fgraph):
1288        assert not hasattr(fgraph, 'shape_feature')
1289        fgraph.shape_feature = self
1290        # Must be local to the object as otherwise we reuse the same
1291        # variable for multiple fgraph!
1292        self.lscalar_one = T.constant(1, dtype='int64')
1293        assert self.lscalar_one.type == T.lscalar
1294
1295        self.shape_of = {}
1296        # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
1297
1298        self.scheduled = {}
1299        # Variable ->
1300
1301        self.shape_of_reverse_index = {}
1302        # shape var -> graph v
1303
1304        for node in fgraph.toposort():
1305            self.on_import(fgraph, node, reason='on_attach')
1306
1307    def on_detach(self, fgraph):
1308        self.shape_of = {}
1309        self.scheduled = {}
1310        self.shape_of_reverse_index = {}
1311        del fgraph.shape_feature
1312
1313    def on_import(self, fgraph, node, reason):
1314        if node.outputs[0] in self.shape_of:
1315            # this is a revert, not really an import
1316            for r in node.outputs + node.inputs:
1317                assert r in self.shape_of
1318            return
1319
1320        for i, r in enumerate(node.inputs):
1321            # make sure we have shapes for the inputs
1322            self.init_r(r)
1323
1324        o_shapes = self.get_node_infer_shape(node)
1325
1326        # this is packed information
1327        # an element of o_shapes is either None or a tuple
1328        #   elements of the tuple can be either strings, or ints
1329        if len(o_shapes) != len(node.outputs):
1330            raise Exception(
1331                ('The infer_shape method for the Op "%s" returned a list ' +
1332                 'with the wrong number of element: len(o_shapes) = %d ' +
1333                 ' != len(node.outputs) = %d') % (str(node.op),
1334                                                  len(o_shapes),
1335                                                  len(node.outputs)))
1336
1337        # Ensure shapes are in 'int64'. This is to make sure the assert
1338        # found in the `local_useless_subtensor` optimization does not fail.
1339        for sh_idx, sh in enumerate(o_shapes):
1340            if sh is None:
1341                continue
1342            if not isinstance(sh, (list, tuple)):
1343                raise ValueError("infer_shape of %s didn't return a list of"
1344                                 " list. It returned '%s'" % (str(node), str(o_shapes)))
1345            new_shape = []
1346            for i, d in enumerate(sh):
1347                # Note: we ignore any shape element that is not typed (i.e.,
1348                # does not have a 'dtype' attribute). This means there may
1349                # still remain int elements that are int32 on 32-bit platforms,
1350                # but this works with `local_useless_subtensor`, so for now we
1351                # keep it this way. See #266 for a better long-term fix.
1352                if getattr(d, 'dtype', 'int64') != 'int64':
1353                    assert d.dtype in theano.tensor.discrete_dtypes, (node, d.dtype)
1354                    assert str(d.dtype) != 'uint64', node
1355                    new_shape += sh[len(new_shape):i + 1]
1356                    if isinstance(d, T.Constant):
1357                        casted_d = T.constant(d.data, dtype='int64')
1358                    else:
1359                        casted_d = theano.tensor.cast(d, 'int64')
1360                    new_shape[i] = casted_d
1361            if new_shape:
1362                # We replace the shape with wrong dtype by the one with
1363                # 'int64'.
1364                new_shape += sh[len(new_shape):]
1365                o_shapes[sh_idx] = tuple(new_shape)
1366
1367        for r, s in izip(node.outputs, o_shapes):
1368            self.set_shape(r, s)
1369
1370    def on_change_input(self, fgraph, node, i, r, new_r, reason):
1371        if new_r not in self.shape_of:
1372            # It happen that the fgraph didn't called on_import for some
1373            # new_r.  This happen when new_r don't have an
1374            # owner(i.e. it is a constant or an input of the graph)
1375            # update_shape suppose that r and new_r are in shape_of.
1376            self.init_r(new_r)
1377
1378        # This tells us that r and new_r must have the same shape if
1379        # we didn't know that the shapes are related, now we do.
1380        self.update_shape(new_r, r)
1381
1382        # change_input happens in two cases:
1383        # 1) we are trying to get rid of r, or
1384        # 2) we are putting things back after a failed transaction.
1385
1386        # In case 1, if r has a shape_i client, we will want to
1387        # replace the shape_i of r with the shape of new_r.  Say that
1388        # r is *scheduled*.
1389        # At that point, node is no longer a client of r, but of new_r
1390        for (shpnode, idx) in (r.clients + [(node, i)]):
1391            if isinstance(getattr(shpnode, 'op', None), Shape_i):
1392                idx = shpnode.op.i
1393                repl = self.shape_of[new_r][idx]
1394                if repl.owner is shpnode:
1395                    # This mean the replacement shape object is
1396                    # exactly the same as the current shape object. So
1397                    # no need for replacement. This happen for example
1398                    # with the InputToGpuOptimizer optimizer.
1399                    continue
1400                if (repl.owner and
1401                        repl.owner.inputs[0] is shpnode.inputs[0] and
1402                        isinstance(repl.owner.op, Shape_i) and
1403                        repl.owner.op.i == shpnode.op.i):
1404                    # The replacement is a shape_i of the same
1405                    # input. So no need to do this equivalent
1406                    # replacement.
1407                    continue
1408
1409                if shpnode.outputs[0] in theano.gof.graph.ancestors([repl]):
1410                    raise InconsistencyError(
1411                        "This substitution would insert a cycle in the graph:"
1412                        "node: %s, i: %i, r: %s, new_r: %s"
1413                        % (node, i, r, new_r))
1414
1415                self.scheduled[shpnode] = new_r
1416        # In case 2, if r is a variable that we've scheduled for shape update,
1417        # then we should cancel it.
1418        unscheduled = [k for k, v in self.scheduled.items() if v == r]
1419        for k in unscheduled:
1420            del self.scheduled[k]
1421
1422        # In either case, r could be in shape_of.values(), that is, r itself
1423        # is the shape of  something. In that case, we want to update
1424        # the value in shape_of, to keep it up-to-date.
1425        for v in self.shape_of_reverse_index.get(r, []):
1426            # The reverse index is only approximate. It is not updated on
1427            # deletion of variables, or on change_input so it might be the
1428            # case that there are a few extra `v`'s in it that no longer have
1429            # a shape of r or possibly have been deleted from shape_of
1430            # entirely. The important thing is that it permits to recall
1431            # all variables with r in their shape.
1432            for ii, svi in enumerate(self.shape_of.get(v, [])):
1433                if svi == r:
1434                    self.set_shape_i(v, ii, new_r)
1435        self.shape_of_reverse_index[r] = set()
1436
1437    def same_shape(self, x, y, dim_x=None, dim_y=None):
1438        """Return True if we are able to assert that x and y have the
1439        same shape.
1440
1441        dim_x and dim_y are optional. If used, they should be an index
1442        to compare only 1 dimension of x and y.
1443
1444        """
1445        sx = self.shape_of[x]
1446        sy = self.shape_of[y]
1447        if sx is None or sy is None:
1448            return False
1449        if dim_x is not None:
1450            sx = [sx[dim_x]]
1451        if dim_y is not None:
1452            sy = [sy[dim_y]]
1453        assert len(sx) == len(sy)
1454
1455        # We look on each dimensions we want to compare.
1456        # If any of them can't be asserted to be equal, return False.
1457        # Otherwise, we return True at the end.
1458        for dx, dy in zip(sx, sy):
1459            if dx is dy:
1460                continue
1461            # Need to try to find that they are the same shape. We
1462            # need to compare the full graph. It could be slow. So I
1463            # just implement for now the case of Shape_i.
1464            if not dx.owner or not dy.owner:
1465                return False
1466            if (not isinstance(dx.owner.op, Shape_i) or
1467                    not isinstance(dy.owner.op, Shape_i)):
1468                return False
1469            opx = dx.owner.op
1470            opy = dy.owner.op
1471            if not (opx.i == opy.i):
1472                return False
1473            # FB I'm not sure if this handle correctly constants.
1474            if dx.owner.inputs[0] == dy.owner.inputs[0]:
1475                continue
1476            # To be sure to cover all case, call equal_computation.
1477            # Can't use theano.gof.graph.is_same_graph(dx, dy)
1478            # As it currently expect that dx and dy aren't in a FunctionGraph
1479            from theano.scan_module.scan_utils import equal_computations
1480            if not equal_computations([dx], [dy]):
1481                return False
1482        return True
1483
1484
1485class ShapeOptimizer(Optimizer):
1486    """Optimizer that serves to add ShapeFeature as an fgraph feature."""
1487    def add_requirements(self, fgraph):
1488        fgraph.attach_feature(ShapeFeature())
1489
1490    def apply(self, fgraph):
1491        pass
1492
1493
1494class UnShapeOptimizer(Optimizer):
1495    """Optimizer remove ShapeFeature as an fgraph feature."""
1496    def apply(self, fgraph):
1497        for feature in fgraph._features:
1498            if isinstance(feature, ShapeFeature):
1499                fgraph.remove_feature(feature)
1500
1501# Register it after merge1 optimization at 0. We don't want to track
1502# the shape of merged node.
1503theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(),
1504                                   0.1, 'fast_run', 'fast_compile')
1505# Not enabled by default for now. Some crossentropy opt use the
1506# shape_feature.  They are at step 2.01. uncanonicalize is at step
1507# 3. After it goes to 48.5 that move to the gpu. So 10 seem resonable.
1508theano.compile.mode.optdb.register('UnShapeOpt', UnShapeOptimizer(),
1509                                   10)
1510
1511
1512def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
1513    def local_elemwise_alloc(node):
1514        """
1515        elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
1516          -> elemwise(x, y.TensorType(BROADCAST CONDITION))
1517
1518        elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
1519          -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
1520
1521        BROADCAST CONDITION: the condition is that the one input that are
1522        not to be optimized to have the same broadcast pattern as the
1523        output.
1524
1525        We can change the alloc by a dimshuffle as the elemwise
1526        already have the shape info.  The dimshuffle will be faster
1527        to exec.
1528
1529        """
1530        if not isinstance(node.op, ElemwiseOP):
1531            return False
1532
1533        if len(node.outputs) > 1:
1534            # Ensure all outputs have the same broadcast pattern
1535            # This is a supposition that I'm not sure is always true.
1536            assert all([o.type.broadcastable ==
1537                        node.outputs[0].type.broadcastable for o in
1538                        node.outputs[1:]])
1539
1540        # The broadcast pattern of the ouptut must match the broadcast
1541        # pattern of at least one of the inputs.
1542        if not any([i.type.broadcastable ==
1543                    node.outputs[0].type.broadcastable for i in node.inputs]):
1544            return False
1545
1546        def dimshuffled_alloc(i):
1547            return (isinstance(i.owner.op, DimShuffleOP) and
1548                    i.owner.inputs[0].owner and
1549                    isinstance(i.owner.inputs[0].owner.op, AllocOP))
1550
1551        # At least one input must have an owner that is either a AllocOP or a
1552        # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
1553        # nothing to optimize.
1554        if not any([i.owner and (isinstance(i.owner.op, AllocOP) or
1555                                 dimshuffled_alloc(i)) for i in node.inputs]):
1556            return False
1557
1558        # Search for input that we can use as a baseline for the dimensions.
1559        assert_op_idx = -1
1560        for idx, i in enumerate(node.inputs):
1561            if i.type.broadcastable == node.outputs[0].type.broadcastable:
1562                # Prefer an input that is not a AllocOP nor a DimShuffleOP of a
1563                # AllocOP so that all allocs can be optimized.
1564                if not (i.owner and (isinstance(i.owner.op, AllocOP) or
1565                        dimshuffled_alloc(i))):
1566                    assert_op_idx = idx
1567                    break
1568
1569        # It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
1570        if assert_op_idx < 0:
1571            # We want to optimize as many allocs as possible. When
1572            # there is more than one then do all but one.  number of
1573            # inputs with alloc or dimshuffle alloc
1574            l2 = [i for i in node.inputs
1575                  if (i.owner and (isinstance(i.owner.op, AllocOP) or
1576                      dimshuffled_alloc(i)))]
1577            # If only 1 alloc or dimshuffle alloc, it is the one we
1578            # will use for the shape. So no alloc would be removed.
1579            if len(l2) > 1:
1580                # l containt inputs with alloc or dimshuffle alloc
1581                # only.  Its length will always be at least one, as we
1582                # checked that before
1583                l = [idx for idx, i in enumerate(node.inputs)
1584                     if i.broadcastable == node.outputs[0].broadcastable]
1585                assert_op_idx = l[0]  # The first one is as good as any to use.
1586            else:
1587                # Nothing would be optimized!
1588                return False
1589
1590        assert_op = node.inputs[assert_op_idx]
1591        cmp_op = assert_op
1592        new_i = []
1593        same_shape = node.fgraph.shape_feature.same_shape
1594        for i in node.inputs:
1595            # Remove alloc
1596            if (i.owner and isinstance(i.owner.op, AllocOP) and
1597                    i.owner.inputs[0].type != i.owner.outputs[0].type):
1598                # when i.owner.inputs[0].type == i.owner.outputs[0].type we
1599                # will remove that alloc later
1600                assert i.type.ndim == cmp_op.ndim
1601                if theano.config.experimental.local_alloc_elemwise_assert:
1602                    get_shape = node.fgraph.shape_feature.get_shape
1603                    cond = []
1604                    for idx in xrange(i.type.ndim):
1605                        if (not i.type.broadcastable[idx] and
1606                                not same_shape(i, cmp_op, idx, idx)):
1607                            i_shp = get_shape(i, idx)
1608                            cmp_shp = get_shape(cmp_op, idx)
1609                            cond.append(T.eq(i_shp, cmp_shp))
1610                    if cond:
1611                        assert_op = assert_(assert_op, *cond)
1612                new_i.append(i.owner.inputs[0])
1613
1614            # Remove Alloc in DimShuffle
1615            elif i.owner and dimshuffled_alloc(i):
1616                assert i.type.ndim == cmp_op.type.ndim
1617                if theano.config.experimental.local_alloc_elemwise_assert:
1618                    assert_cond = [T.eq(i.shape[idx], cmp_op.shape[idx])
1619                                   for idx in xrange(i.type.ndim)
1620                                   if not i.type.broadcastable[idx] and
1621                                   not same_shape(i, cmp_op, idx, idx)]
1622                    if assert_cond:
1623                        assert_op = assert_(assert_op, *assert_cond)
1624                alloc_input = i.owner.inputs[0].owner.inputs[0]
1625                if alloc_input.ndim != i.owner.inputs[0].ndim:
1626                    # The alloc can add dimension to the value
1627                    # We add a dimshuffle to add them.
1628                    # We let later optimization merge the multiple dimshuffle
1629                    nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
1630                    alloc_input = alloc_input.dimshuffle(
1631                        ['x'] * nb_dim_to_add +
1632                        list(range(alloc_input.ndim)))
1633
1634                # We need to keep the dimshuffle. It could swap axes or
1635                # add dimensions anywhere.
1636                r_i = i.owner.op(alloc_input)
1637
1638                # Copy stack trace from i to new_i
1639                copy_stack_trace(i, r_i)
1640                new_i.append(r_i)
1641            else:
1642                new_i.append(i)
1643        new_i[assert_op_idx] = assert_op
1644
1645        ret = node.op(*new_i, return_list=True)
1646
1647        # Copy over stack trace from previous outputs to new outputs.
1648        copy_stack_trace(node.outputs, ret)
1649        return ret
1650
1651    return local_elemwise_alloc
1652
1653# TODO, global optimizer that lift the assert to the beginning of the graph.
1654# TODO, optimize all inputs when possible -- currently when all inputs have
1655# an alloc all but one is optimized.
1656
1657local_elemwise_alloc = register_specialize(
1658    gof.local_optimizer([T.Elemwise])(
1659        local_elemwise_alloc_op(T.Elemwise, T.Alloc, T.DimShuffle)),
1660    'local_alloc_elemwise')
1661
1662
1663@gof.local_optimizer([T.Elemwise])
1664def local_fill_sink(node):
1665    """
1666    f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
1667    f need to be an elemwise that isn't a fill.
1668    """
1669    if (not hasattr(node, 'op') or
1670            not isinstance(node.op, T.Elemwise) or
1671            node.op == T.fill):
1672        return False
1673    models = []
1674    inputs = []
1675    for input in node.inputs:
1676        if input.owner and input.owner.op == T.fill:
1677            models.append(input.owner.inputs[0])
1678            inputs.append(input.owner.inputs[1])
1679        else:
1680            inputs.append(input)
1681    if not models:
1682        return False
1683    c = node.op(*inputs)
1684    for model in models:
1685        if model.type != c.type:
1686            c = T.fill(model, c)
1687
1688    # The newly created node c doesn't has 'clients',
1689    # so this iteration is took place with node.outputs[0]
1690    replacements = {node.outputs[0]: c}
1691    for client, cl_idx in node.outputs[0].clients:
1692        if (hasattr(client, 'op') and
1693                isinstance(client.op, T.Elemwise) and
1694                not client.op == T.fill):
1695            client_inputs = client.inputs[:]
1696            client_inputs[cl_idx] = c
1697            new_client = client.op(*client_inputs)
1698
1699            # Add clients to new_client
1700            new_client.owner.outputs[0].clients = client.outputs[0].clients
1701            r = local_fill_sink.transform(new_client.owner)
1702            if not r:
1703                continue
1704            replacements.update(r)
1705    return replacements
1706
1707register_canonicalize(local_fill_sink)
1708
1709
1710@register_specialize
1711@register_stabilize
1712# @register_canonicalize  # We make full pass after the canonizer phase.
1713@gof.local_optimizer([T.fill])
1714def local_fill_to_alloc(node):
1715    """fill(s,v) -> alloc(v, shape(s))
1716
1717    This is an important optimization because with the shape_to_shape_i
1718    optimization, the dependency on 's' is often removed.
1719
1720    """
1721    if node.op == T.fill:
1722        r, v = node.inputs
1723        if v.type == node.outputs[0].type:
1724            # this is a useless fill, erase it.
1725            rval = [v]
1726        elif v.type.broadcastable == node.outputs[0].type.broadcastable:
1727            # this is a cast
1728            rval = [T.cast(v, node.outputs[0].type.dtype)]
1729        elif r.type.broadcastable == node.outputs[0].type.broadcastable:
1730            # we are broadcasting v somehow, but not r
1731            o = broadcast_like(v, r, node.fgraph, dtype=v.dtype)
1732            copy_stack_trace(node.outputs[0], o)
1733            rval = [o]
1734        else:
1735            # we are broadcasting both v and r,
1736            # the output shape must be computed
1737            #
1738            # TODO: implement this case (including a test!)
1739            #
1740            #  I think the strategy should be to extend the shorter
1741            #  shape vector with 1s (how?) and then take the
1742            #  elementwise max of the two.  - how to flag an error of
1743            #  shape mismatch where broadcasting should be illegal?
1744            return
1745            # TODO: cut out un-necessary dimshuffles of v
1746
1747        assert rval[0].type == node.outputs[0].type, (
1748            'rval', rval[0].type, 'orig', node.outputs[0].type, 'node',
1749            node,)  # theano.printing.debugprint(node.outputs[0], file='str'))
1750        return rval
1751
1752# Register this after stabilize at 1.5 to make sure stabilize don't
1753# get affected by less canonicalized graph due to alloc.
1754compile.optdb.register('local_fill_to_alloc',
1755                       in2out(local_fill_to_alloc),
1756                       1.51, 'fast_run')
1757# Needed to clean some extra alloc added by local_fill_to_alloc
1758compile.optdb.register('local_elemwise_alloc',
1759                       in2out(local_elemwise_alloc),
1760                       1.52, 'fast_run')
1761
1762
1763@register_canonicalize("fast_compile")
1764@register_useless
1765@gof.local_optimizer([T.fill])
1766def local_useless_fill(node):
1767    """fill(s,v) -> v
1768
1769    This optimization is only needed in FAST_COMPILE to make the code
1770    more readable. Normally, it is done by the local_fill_to_alloc
1771    opt.
1772
1773    """
1774    if node.op == T.fill:
1775        r, v = node.inputs
1776        if v.type == node.outputs[0].type:
1777            # this is a useless fill, erase it.
1778            # also, we don't need to copy over any stack traces here
1779            return [v]
1780
1781
1782@register_specialize
1783@register_stabilize
1784@register_canonicalize
1785@register_useless
1786@gof.local_optimizer([T.alloc])
1787def local_useless_alloc(node):
1788    """
1789    If the input type is the same as the output type (dtype and broadcast)
1790    there is no change in the shape of the input. So this is just a simple copy
1791    of the input. This is not needed.
1792
1793    """
1794    op = node.op
1795    if not isinstance(op, Alloc):
1796        return False
1797
1798    input = node.inputs[0]
1799    output = node.outputs[0]
1800
1801    # Check if dtype and broadcast remain the same.
1802    if input.type == output.type:
1803        # We don't need to copy over any stack traces here
1804        return [input]
1805
1806
1807@register_specialize
1808@register_stabilize
1809@register_canonicalize
1810@gof.local_optimizer([T.alloc])
1811def local_canonicalize_alloc(node):
1812    """If the input type is the same as the output type (dtype and broadcast)
1813    there is no change in the shape of the input. So this is just a simple copy
1814    of the input. This is not needed. (as local_useless_alloc)
1815
1816    Also, it will canonicalize alloc by creating Dimshuffle after the
1817    alloc to introduce the dimensions of constant size 1.
1818
1819    See https://github.com/Theano/Theano/issues/4072 to know why this
1820    is needed.
1821
1822    """
1823    op = node.op
1824    if not isinstance(op, Alloc):
1825        return False
1826
1827    input = node.inputs[0]
1828    output = node.outputs[0]
1829
1830    # Check if dtype and broadcast remain the same.
1831    if input.type == output.type:
1832        # We don't need to copy over any stack traces here
1833        return [input]
1834
1835    # Allow local_merge_alloc to do its work first
1836    clients = getattr(output, 'clients', [])
1837    for client, i in clients:
1838        if client != "output" and isinstance(client.op, Alloc):
1839            return
1840
1841    # Check if alloc adds a broadcastable dimension with shape 1.
1842
1843    output_shape = node.inputs[1:]
1844    num_dims_with_size_1_added_to_left = 0
1845    for i in range(len(output_shape) - input.ndim):
1846        if extract_constant(output_shape[i], only_process_constants=True) == 1:
1847            num_dims_with_size_1_added_to_left += 1
1848        else:
1849            break
1850    new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
1851    if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= input.ndim:
1852        if output.broadcastable[num_dims_with_size_1_added_to_left:] == input.broadcastable:
1853            inner = input
1854        else:
1855            inner = op(*([input] + new_output_shape))
1856        dimshuffle_new_order = (['x'] * num_dims_with_size_1_added_to_left +
1857                                list(xrange(len(new_output_shape))))
1858        return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
1859
1860
1861# Don't register by default.
1862@gof.local_optimizer([T.AllocEmpty])
1863def local_alloc_empty_to_zeros(node):
1864    """This convert AllocEmpty to Alloc of 0.
1865
1866    This help investigate NaN with NanGuardMode.  Not registered by
1867    default. To activate it, use the Theano flag
1868    optimizer_including=alloc_empty_to_zeros. This also enable
1869    the GPU version of this optimizations.
1870
1871    """
1872    if isinstance(node.op, T.AllocEmpty):
1873        return [T.zeros(node.inputs, dtype=node.outputs[0].dtype)]
1874compile.optdb.register('local_alloc_empty_to_zeros',
1875                       in2out(local_alloc_empty_to_zeros),
1876                       # After move to gpu and merge2, before inplace.
1877                       49.3,
1878                       'alloc_empty_to_zeros',)
1879
1880
1881@register_specialize
1882@register_canonicalize
1883@gof.local_optimizer([T.Shape])
1884def local_shape_to_shape_i(node):
1885    if node.op == T.shape:
1886        # This optimization needs ShapeOpt and fgraph.shape_feature
1887        if not hasattr(node.fgraph, 'shape_feature'):
1888            return
1889        shape_feature = node.fgraph.shape_feature
1890        ret = shape_feature.make_vector_shape(node.inputs[0])
1891
1892        # We need to copy over stack trace from input to output
1893        copy_stack_trace(node.outputs[0], ret)
1894        return [ret]
1895
1896
1897# TODO: Not sure what type of node we are expecting here
1898@register_specialize
1899@register_canonicalize
1900@gof.local_optimizer(None)
1901def local_track_shape_i(node):
1902    try:
1903        shape_feature = node.fgraph.shape_feature
1904    except AttributeError:
1905        return
1906    if node in shape_feature.scheduled:
1907        # Don't unschedule node as it could be reinserted in the
1908        # fgraph as we don't change it in the shapefeature internal
1909        # structure.
1910        assert isinstance(node.op, Shape_i)
1911        replacement = shape_feature.scheduled[node]
1912        return [shape_feature.shape_of[replacement][node.op.i]]
1913
1914
1915@register_specialize
1916@register_canonicalize
1917@gof.local_optimizer([Subtensor])
1918def local_subtensor_inc_subtensor(node):
1919    """
1920    Subtensor(SetSubtensor(x, y, idx), idx) -> y
1921
1922    """
1923    if isinstance(node.op, Subtensor):
1924        x = node.inputs[0]
1925        if not x.owner or not isinstance(x.owner.op, IncSubtensor):
1926            return
1927        if not x.owner.op.set_instead_of_inc:
1928            return
1929
1930        if (x.owner.inputs[2:] == node.inputs[1:] and
1931                tuple(x.owner.op.idx_list) == tuple(node.op.idx_list)):
1932            out = node.outputs[0]
1933            y = x.owner.inputs[1]
1934            # If the dtypes differ, cast y into x.dtype
1935            if x.dtype != y.dtype:
1936                y = y.astype(x.dtype)
1937            if out.type == y.type:
1938                # if x[idx] and y have the same type, directly return y
1939                return [y]
1940            else:
1941                # The difference is related to broadcasting pattern
1942                assert out.broadcastable != y.broadcastable
1943                # We have to alloc y to the shape of x[idx]
1944                x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
1945                return [T.alloc(y, *x_subtensor.shape)]
1946        else:
1947            return
1948
1949
1950@register_specialize
1951@register_canonicalize
1952@gof.local_optimizer([Subtensor])
1953def local_subtensor_remove_broadcastable_index(node):
1954    """
1955    Remove broadcastable dimension with index 0 or -1
1956    a[:,:,:,0] -> a.dimshuffle(0,1,2), when
1957        a.broadcastable = (False, False, False, True)
1958    a[0,:,-1,:] -> a.dimshuffle(1,3), when
1959        a.broadcastable = (True, False, True, False)
1960
1961    """
1962    if isinstance(node.op, Subtensor):
1963        idx = node.op.idx_list
1964    else:
1965        return
1966
1967    remove_dim = []
1968    node_inputs_idx = 1
1969    for dim, elem in enumerate(idx):
1970        if isinstance(elem, (scalar.Scalar)):
1971            # The idx is a Scalar, ie a Type. This means the actual index
1972            # is contained in node.inputs[1]
1973            dim_index = node.inputs[node_inputs_idx]
1974            if type(dim_index) == theano.scalar.basic.ScalarConstant:
1975                dim_index = dim_index.value
1976            if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
1977                remove_dim.append(dim)
1978                node_inputs_idx += 1
1979            else:
1980                return
1981        elif isinstance(elem, slice):
1982            if elem != slice(None):
1983                return
1984        elif isinstance(elem, (integer_types, np.integer)):
1985            if elem in [0, -1] and node.inputs[0].broadcastable[dim]:
1986                remove_dim.append(dim)
1987        else:
1988            raise TypeError('case not expected')
1989
1990    if len(remove_dim) == 0:
1991        return
1992    else:
1993        all_dim = range(node.inputs[0].ndim)
1994        remain_dim = [x for x in all_dim if x not in remove_dim]
1995        return [node.inputs[0].dimshuffle(tuple(remain_dim))]
1996
1997
1998@register_specialize
1999@register_canonicalize('fast_compile_gpu')
2000@register_useless
2001@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
2002def local_subtensor_make_vector(node):
2003    """
2004    Replace all subtensor(make_vector) like:
2005    [a,b,c][0] -> a
2006    [a,b,c][0:2] -> [a,b]
2007
2008    Replace all AdvancedSubtensor1(make_vector) like:
2009    [a,b,c][[0,2]] -> [a,c]
2010
2011    We can do this for constant indexes.
2012
2013    """
2014    x = node.inputs[0]
2015    if not x.owner or x.owner.op != make_vector:
2016        return
2017
2018    if isinstance(node.op, Subtensor):
2019        # This optimization needs ShapeOpt and fgraph.shape_feature
2020        try:
2021            idx, = node.op.idx_list
2022        except Exception:
2023            # 'how can you have multiple indexes into a shape?'
2024            raise
2025
2026        if isinstance(idx, (scalar.Scalar, T.TensorType)):
2027            # The idx is a Scalar, ie a Type. This means the actual index
2028            # is contained in node.inputs[1]
2029            old_idx, idx = idx, node.inputs[1]
2030            assert idx.type == old_idx
2031    elif isinstance(node.op, AdvancedSubtensor1):
2032        idx = node.inputs[1]
2033    else:
2034        return
2035
2036    if isinstance(idx, (integer_types, np.integer)):
2037        # We don't need to copy over any stack traces here
2038        return [x.owner.inputs[idx]]
2039    elif isinstance(idx, Variable):
2040        if idx.ndim == 0:
2041            # if it is a constant we can do something with it
2042            try:
2043                v = get_scalar_constant_value(idx, only_process_constants=True)
2044                if isinstance(v, np.integer):
2045                    # Python 2.4 wants to index only with Python integers
2046                    v = int(v)
2047                # We don't need to copy over any stack traces here
2048                try:
2049                    ret = [x.owner.inputs[v]]
2050                except IndexError:
2051                    raise NotScalarConstantError("Bad user graph!")
2052                return ret
2053            except NotScalarConstantError:
2054                pass
2055        elif idx.ndim == 1 and isinstance(idx, T.Constant):
2056            values = list(map(int, list(idx.value)))
2057            ret = make_vector(*[x.owner.inputs[v] for v in values])
2058
2059            # Copy over stack trace from previous output to new output
2060            copy_stack_trace(node.outputs[0], ret)
2061            ret = T.patternbroadcast(ret, node.outputs[0].broadcastable)
2062            return [ret]
2063        else:
2064            raise TypeError('case not expected')
2065    elif isinstance(idx, slice):
2066        # it is a slice of ints and/or Variables
2067        # check subtensor to see if it can contain constant variables, and if
2068        # it can, then try to unpack them.
2069        try:
2070            const_slice = node.op.get_constant_idx(node.inputs,
2071                                                   allow_partial=False)[0]
2072            ret = make_vector(*x.owner.inputs[const_slice])
2073            # Copy over stack trace from previous outputs to new output
2074            copy_stack_trace(node.outputs, ret)
2075            ret = T.patternbroadcast(ret, node.outputs[0].broadcastable)
2076            return [ret]
2077        except NotScalarConstantError:
2078            pass
2079    else:
2080        raise TypeError('case not expected')
2081
2082
2083# TODO: the other optimization for and, or, xor, le and ge see ticket #496.
2084
2085@register_useless
2086@register_canonicalize('fast_compile')
2087@register_specialize
2088@gof.local_optimizer([T.Elemwise])
2089def local_useless_elemwise(node):
2090    """
2091    eq(x, x) -> 1
2092    neq(x, x) -> 0
2093    mul(x) -> x
2094    add(x) -> x
2095    identity(x) -> x
2096    and(x, 1) -> x  (if x.dtype == 'bool')
2097    and(x, 0) -> zeros_like(x)
2098    or(x, 0) -> x
2099    or(x, 1) -> ones_like(x)  (if x.dtype == 'bool')
2100    xor(x, x) -> zeros_like(x)
2101
2102    """
2103    if isinstance(node.op, T.Elemwise):
2104        # We call zeros_like and one_like with opt=True to generate a
2105        # cleaner graph.
2106        dtype = node.outputs[0].dtype
2107
2108        if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
2109            if node.inputs[0] == node.inputs[1]:
2110                # it is the same var in the graph. That will always be true
2111                ret = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
2112
2113                # Copy stack trace from input to constant output
2114                copy_stack_trace(node.outputs[0], ret)
2115                return [ret]
2116        elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
2117            if node.inputs[0] == node.inputs[1]:
2118                # it is the same var in the graph. That will always be false
2119                ret = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
2120
2121                # Copy stack trace from input to constant output
2122                copy_stack_trace(node.outputs[0], ret)
2123                return [ret]
2124
2125        elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
2126            # No need to copy over any stack trace
2127            return [node.inputs[0]]
2128
2129        elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
2130            # No need to copy over any stack trace
2131            return [node.inputs[0]]
2132        elif (node.op.scalar_op == theano.scalar.identity and
2133              len(node.inputs) == 1):
2134            return [node.inputs[0]]
2135
2136        elif (isinstance(node.op.scalar_op, scalar.AND) and
2137              len(node.inputs) == 2):
2138
2139            if isinstance(node.inputs[0], T.TensorConstant):
2140                const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
2141                if not isinstance(const_val, Variable):
2142                    if const_val == 0:
2143                        return [T.zeros_like(node.inputs[1], dtype=dtype,
2144                                             opt=True)]
2145                    elif node.outputs[0].dtype == 'bool':
2146                        # If the output is not Boolean, it is the bitwise AND,
2147                        # and this optimization would be wrong
2148                        return [node.inputs[1].astype(node.outputs[0].dtype)]
2149
2150            if isinstance(node.inputs[1], T.TensorConstant):
2151                const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
2152                if not isinstance(const_val, Variable):
2153                    if const_val == 0:
2154                        return [T.zeros_like(node.inputs[0], dtype=dtype,
2155                                             opt=True)]
2156                    elif node.outputs[0].dtype == 'bool':
2157                        # If the output is not Boolean, it is the bitwise AND,
2158                        # and this optimization would be wrong
2159                        return [node.inputs[0].astype(node.outputs[0].dtype)]
2160
2161        elif (isinstance(node.op.scalar_op, scalar.OR) and
2162              len(node.inputs) == 2):
2163
2164            if isinstance(node.inputs[0], T.TensorConstant):
2165                const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
2166                if not isinstance(const_val, Variable):
2167                    if const_val == 0:
2168                        return [node.inputs[1].astype(node.outputs[0].dtype)]
2169                    elif node.outputs[0].dtype == 'bool':
2170                        # If the output is not Boolean, it is the bitwise OR,
2171                        # and this optimization would be wrong
2172                        return [T.ones_like(node.inputs[1], dtype=dtype,
2173                                            opt=True)]
2174
2175            if isinstance(node.inputs[1], T.TensorConstant):
2176                const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
2177                if not isinstance(const_val, Variable):
2178                    if const_val == 0:
2179                        return [node.inputs[0].astype(node.outputs[0].dtype)]
2180                    elif node.outputs[0].dtype == 'bool':
2181                        # If the output is not Boolean, it is the bitwise OR,
2182                        # and this optimization would be wrong
2183                        return [T.ones_like(node.inputs[0], dtype=dtype,
2184                                            opt=True)]
2185
2186        elif (isinstance(node.op.scalar_op, scalar.XOR) and
2187              len(node.inputs) == 2):
2188            if node.inputs[0] is node.inputs[1]:
2189                return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
2190
2191
2192@register_specialize
2193@gof.local_optimizer([T.Elemwise])
2194def local_alloc_unary(node):
2195    """unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
2196    if isinstance(node.op, T.Elemwise) and len(node.inputs) == 1:
2197        a = node.inputs[0]
2198        if a.owner and isinstance(a.owner.op, T.Alloc):
2199            x = a.owner.inputs[0]
2200            shp = a.owner.inputs[1:]
2201            v = node.op(x)
2202            # T.alloc does not preserve the stacktrace of v,
2203            # so we need to copy it over from x.
2204            copy_stack_trace(node.outputs[0], v)
2205            ret = T.alloc(T.cast(v, node.outputs[0].dtype), *shp)
2206
2207            # T.cast does not preserve the stacktrace of x,
2208            # so we need to copy it over to the output.
2209            copy_stack_trace([node.outputs[0], a], ret)
2210            return [ret]
2211
2212
2213@register_canonicalize
2214@register_specialize
2215@gof.local_optimizer([T.Elemwise])
2216def local_cast_cast(node):
2217    """cast(cast(x, dtype1), dtype2)
2218
2219    when those contrain:
2220    dtype1 == dtype2
2221    OR the base dtype is the same (int, uint, float, complex)
2222          and the first cast cause an upcast.
2223
2224    """
2225    if (not isinstance(node.op, T.Elemwise) or
2226            not isinstance(node.op.scalar_op, scalar.Cast)):
2227        return
2228    x = node.inputs[0]
2229    if (not x.owner or
2230            not isinstance(x.owner.op, T.Elemwise) or
2231            not isinstance(x.owner.op.scalar_op, scalar.Cast)):
2232        return
2233
2234    type1 = x.owner.op.scalar_op.o_type
2235    type2 = node.op.scalar_op.o_type
2236    base = x.owner.inputs[0]
2237
2238    if type1 == type2:
2239        # We don't need to copy over any stack traces here
2240        return [x]
2241
2242    if(is_an_upcast(base.dtype, type1.dtype)):
2243        # Checking for further redundancy. Eg: int8 -> int32 -> int8
2244        if(type2.dtype == base.dtype):
2245            return x.owner.inputs
2246        else:
2247            # Apply the second cast only
2248            v = node.op(base)
2249            # Copy stack trace from the output of the original cast
2250            copy_stack_trace(node.outputs[0], v)
2251            return [v]
2252
2253
2254def is_an_upcast(type1, type2):
2255    """Given two data types (as strings), check if converting to
2256    type2 from type1 constitutes an upcast.
2257    Differs from theano.scalar.upcast
2258
2259    """
2260    category = {
2261        # The first number in the pair is the dtype (bool, uint, int, float,
2262        # complex). Conversion from higher to lower is never an upcast.
2263        # The second number roughly indicates the precision. Again, conversion
2264        # from higher to lower is never an upcast.
2265
2266        'bool': (0, 0),
2267        'uint8': (1, 1), 'uint16': (1, 2), 'uint32': (1, 3), 'uint64': (1, 4),
2268        'int8': (2, 1), 'int16': (2, 2), 'int32': (2, 3), 'int64': (2, 4),
2269        'float16': (3, 1.5), 'float32': (3, 2.5), 'float64': (3, 3.5),
2270        'complex64': (4, 3), 'complex128': (4, 4)
2271    }
2272
2273    cat1 = category[type1]
2274    cat2 = category[type2]
2275
2276    if(cat2[0] >= cat1[0] and cat2[1] > cat1[1]):
2277        return True
2278    else:
2279        return False
2280
2281
2282@register_canonicalize
2283@register_specialize
2284@gof.local_optimizer([T.Elemwise])
2285def local_func_inv(node):
2286    """
2287    Check for two consecutive operations that are functional inverses
2288    and remove them from the function graph.
2289
2290    """
2291    inv_pairs = (
2292        (basic.Deg2Rad, basic.Rad2Deg),
2293        (basic.Cosh, basic.ArcCosh),
2294        (basic.Tanh, basic.ArcTanh),
2295        (basic.Sinh, basic.ArcSinh),
2296        (basic.Conj, basic.Conj),
2297        (basic.Neg, basic.Neg),
2298        (basic.Inv, basic.Inv),
2299    )
2300    x = node.inputs[0]
2301
2302    if not isinstance(node.op, T.Elemwise):
2303        return
2304    if (not x.owner or not isinstance(x.owner.op, T.Elemwise)):
2305        return
2306
2307    prev_op = x.owner.op.scalar_op
2308    node_op = node.op.scalar_op
2309
2310    for inv_pair in inv_pairs:
2311        if is_inverse_pair(node_op, prev_op, inv_pair):
2312            # We don't need to copy stack trace, because the optimization
2313            # is trivial and maintains the earlier stack trace
2314            return x.owner.inputs
2315
2316    return
2317
2318
2319def is_inverse_pair(node_op, prev_op, inv_pair):
2320    """
2321    Given two consecutive operations, check if they are the
2322    provided pair of inverse functions.
2323
2324    """
2325    node_is_op0 = isinstance(node_op, inv_pair[0])
2326    node_is_op1 = isinstance(node_op, inv_pair[1])
2327    prev_is_op0 = isinstance(prev_op, inv_pair[0])
2328    prev_is_op1 = isinstance(prev_op, inv_pair[1])
2329
2330    return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0)
2331
2332
2333class Assert(T.Op):
2334    """
2335    Implements assertion in a computational graph.
2336
2337    Returns the first parameter if the condition is true, otherwise, triggers
2338    AssertionError.
2339
2340    Notes
2341    -----
2342    This Op is a debugging feature. It can be removed from the graph
2343    because of optimizations, and can hide some possible optimizations to
2344    the optimizer. Specifically, removing happens if it can be determined
2345    that condition will always be true. Also, the output of the Op must be
2346    used in the function computing the graph, but it doesn't have to be
2347    returned.
2348
2349    Examples
2350    --------
2351    >>> import theano
2352    >>> T = theano.tensor
2353    >>> x = T.vector('x')
2354    >>> assert_op = T.opt.Assert()
2355    >>> func = theano.function([x], assert_op(x, x.size<2))
2356
2357    """
2358    _f16_ok = True
2359    __props__ = ('msg',)
2360    view_map = {0: [0]}
2361
2362    check_input = False
2363
2364    def __init__(self, msg="Theano Assert failed!"):
2365        self.msg = msg
2366
2367    def __setstate__(self, attrs):
2368        self.__dict__.update(attrs)
2369        if not hasattr(self, 'msg'):
2370            self.msg = "Theano Assert failed!"
2371
2372    def make_node(self, value, *conds):
2373        if not isinstance(value, Variable):
2374            value = T.as_tensor_variable(value)
2375        cond = [T.as_tensor_variable(c) for c in conds]
2376        assert np.all([c.type.ndim == 0 for c in cond])
2377        return gof.Apply(self, [value] + cond, [value.type()])
2378
2379    def perform(self, node, inputs, out_):
2380        out, = out_
2381        v = inputs[0]
2382        out[0] = v
2383        assert np.all(inputs[1:]), self.msg
2384
2385    def grad(self, input, output_gradients):
2386        return output_gradients + [DisconnectedType()()] * (len(input) - 1)
2387
2388    def connection_pattern(self, node):
2389        return [[1]] + [[0]] * (len(node.inputs) - 1)
2390
2391    def c_code(self, node, name, inames, onames, sub):
2392        value = inames[0]
2393        out = onames[0]
2394        check = []
2395        fail = sub['fail']
2396        msg = self.msg.replace('"', '\\"').replace('\n', '\\n')
2397        for idx in xrange(len(inames) - 1):
2398            i = inames[idx + 1]
2399            dtype = node.inputs[idx + 1].dtype
2400            check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])'
2401                         '{PyErr_SetString(PyExc_AssertionError,"%(msg)s");'
2402                         '%(fail)s}' % locals())
2403        check = "\n".join(check)
2404        return """
2405        %(check)s
2406        Py_XDECREF(%(out)s);
2407        %(out)s = %(value)s;
2408        Py_INCREF(%(value)s);
2409        """ % locals()
2410
2411    def c_code_cache_version(self):
2412        return (3, 0)
2413
2414    def infer_shape(self, node, input_shapes):
2415        return [input_shapes[0]]
2416
2417assert_ = Assert()
2418# Unittest.assert_ is a deprecated name for assertTrue.
2419# 2to3 convert theano.tensor.opt.assert_ to theano.tensor.opt.assertTrue
2420# So I define a new name as a work around.
2421assert_op = assert_
2422
2423
2424@register_specialize
2425@gof.local_optimizer([Assert])
2426def local_remove_useless_assert(node):
2427    if isinstance(node.op, Assert):
2428        cond = []
2429        for c in node.inputs[1:]:
2430            try:
2431                const = get_scalar_constant_value(c)
2432
2433                if 0 != const.ndim or const == 0:
2434                    # Should we raise an error here? How to be sure it
2435                    # is not catched?
2436                    cond.append(c)
2437            except NotScalarConstantError:
2438                cond.append(c)
2439
2440        if len(cond) == 0:
2441            # We don't need to copy over any stack traces here
2442            return [node.inputs[0]]
2443        if len(cond) != len(node.inputs) - 1:
2444            ret = assert_(node.inputs[0], *cond)
2445
2446            # We copy over stack trace from the output of the original assert
2447            copy_stack_trace(node.outputs[0], ret)
2448            return [ret]
2449
2450
2451@gof.local_optimizer([Assert])
2452def local_remove_all_assert(node):
2453    """An optimization disabled by default that removes all asserts from
2454    the graph.
2455
2456    Notes
2457    -----
2458    See the :ref:`unsafe` section to know how to enable it.
2459
2460    """
2461    if not isinstance(node.op, Assert):
2462        return
2463
2464    # We don't need to copy over any stack traces here
2465    return [node.inputs[0]]
2466# Disabled by default
2467compile.optdb['canonicalize'].register('local_remove_all_assert',
2468                                       local_remove_all_assert,
2469                                       'unsafe',
2470                                       use_db_name_as_tag=False)
2471compile.optdb['stabilize'].register('local_remove_all_assert',
2472                                    local_remove_all_assert,
2473                                    'unsafe',
2474                                    use_db_name_as_tag=False)
2475compile.optdb['specialize'].register('local_remove_all_assert',
2476                                     local_remove_all_assert,
2477                                     'unsafe',
2478                                     use_db_name_as_tag=False)
2479compile.optdb['useless'].register('local_remove_all_assert',
2480                                  local_remove_all_assert,
2481                                  'unsafe',
2482                                  use_db_name_as_tag=False)
2483
2484#######################
2485# Constant Canonicalization
2486############################
2487
2488
2489@register_canonicalize
2490@gof.local_optimizer([T.Elemwise])
2491def local_upcast_elemwise_constant_inputs(node):
2492    """This explicitly upcasts constant inputs to elemwise Ops, when
2493    those Ops do implicit upcasting anyway.
2494
2495    Rationale: it helps merge things like (1-x) and (1.0 - x).
2496
2497    """
2498    if len(node.outputs) > 1:
2499        return
2500    try:
2501        shape_i = node.fgraph.shape_feature.shape_i
2502    except AttributeError:
2503        shape_i = None
2504    if isinstance(node.op, T.Elemwise):
2505        scalar_op = node.op.scalar_op
2506        # print "aa", scalar_op.output_types_preference
2507        if (getattr(scalar_op, 'output_types_preference', None)
2508                in (T.scal.upgrade_to_float, T.scal.upcast_out)):
2509            # this is the kind of op that we can screw with the input
2510            # dtypes by upcasting explicitly
2511            output_dtype = node.outputs[0].type.dtype
2512            new_inputs = []
2513            for i in node.inputs:
2514                if i.type.dtype == output_dtype:
2515                    new_inputs.append(i)
2516                else:
2517                    try:
2518                        # works only for scalars
2519                        cval_i = get_scalar_constant_value(i,
2520                                                           only_process_constants=True)
2521                        if all(i.broadcastable):
2522                            new_inputs.append(T.shape_padleft(
2523                                T.cast(cval_i, output_dtype),
2524                                i.ndim))
2525                        else:
2526                            if shape_i is None:
2527                                return
2528                            new_inputs.append(
2529                                T.alloc(T.cast(cval_i, output_dtype),
2530                                        *[shape_i(d)(i)
2531                                          for d in xrange(i.ndim)]))
2532                            # print >> sys.stderr, "AAA",
2533                            # *[Shape_i(d)(i) for d in xrange(i.ndim)]
2534                    except NotScalarConstantError:
2535                        # for the case of a non-scalar
2536                        if isinstance(i, T.TensorConstant):
2537                            new_inputs.append(T.cast(i, output_dtype))
2538                        else:
2539                            new_inputs.append(i)
2540
2541            if new_inputs != node.inputs:
2542                rval = [node.op(*new_inputs)]
2543                if rval[0].type != node.outputs[0].type:
2544                    # This can happen for example when floatX=float32
2545                    # and we do the true division between and int64
2546                    # and a constant that will get typed as int8.
2547
2548                    # As this is just to allow merging more case, if
2549                    # the upcast don't work, we can just skip it.
2550                    return
2551
2552                # Copy over output stacktrace from before upcasting
2553                copy_stack_trace(node.outputs[0], rval)
2554                return rval
2555
2556##################
2557# Subtensor opts #
2558##################
2559
2560
2561@register_useless
2562@register_canonicalize
2563@register_specialize
2564@gof.local_optimizer([IncSubtensor])
2565def local_useless_inc_subtensor(node):
2566    """
2567    Remove IncSubtensor, when we overwrite the full inputs with the
2568    new value.
2569
2570    """
2571    if not isinstance(node.op, IncSubtensor):
2572        return
2573    if node.op.set_instead_of_inc is False:
2574        # This is an IncSubtensor, so the init value must be zeros
2575        try:
2576            c = get_scalar_constant_value(node.inputs[0],
2577                                          only_process_constants=True)
2578            if c != 0:
2579                return
2580        except NotScalarConstantError:
2581            return
2582    if (node.inputs[0].ndim != node.inputs[1].ndim or
2583            node.inputs[0].broadcastable != node.inputs[1].broadcastable):
2584        # FB: I didn't check if this case can happen, but this opt
2585        # don't support it.
2586        return
2587    # We have a SetSubtensor or an IncSubtensor on zeros
2588    # If is this IncSubtensor useful?
2589
2590    # Check that we keep all the original data.
2591    # Put the constant inputs in the slice.
2592    idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
2593    if all(isinstance(e, slice) and e.start is None and
2594           e.stop is None and (e.step is None or T.extract_constant(e.step,
2595                               only_process_constants=True) == -1)
2596           for e in idx_cst):
2597        # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
2598        # based on run time shapes, so we must check they are the same.
2599        if not hasattr(node.fgraph, 'shape_feature'):
2600            return
2601        if not node.fgraph.shape_feature.same_shape(node.inputs[0],
2602                                                    node.inputs[1]):
2603            return
2604        # There is no reverse, so we don't need a replacement.
2605        if all(e.step is None
2606               for e in node.op.idx_list):
2607            # They are the same shape, so we can remore this IncSubtensor
2608            return [node.inputs[1]]
2609        ret = Subtensor(node.op.idx_list)(*node.inputs[1:])
2610        # Copy over previous output stacktrace
2611        copy_stack_trace(node.outputs, ret)
2612        return [ret]
2613
2614
2615@register_canonicalize
2616@gof.local_optimizer([AdvancedIncSubtensor1])
2617def local_set_to_inc_subtensor(node):
2618    """
2619    AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
2620    AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
2621
2622    """
2623    if (isinstance(node.op, AdvancedIncSubtensor1) and
2624            node.op.set_instead_of_inc and
2625            node.inputs[1].owner and
2626            isinstance(node.inputs[1].owner.op, Elemwise) and
2627            isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
2628        addn = node.inputs[1].owner
2629        subn = None
2630        other = None
2631
2632        if (addn.inputs[0].owner and
2633                isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
2634            subn = addn.inputs[0].owner
2635            other = addn.inputs[1]
2636        elif (addn.inputs[1].owner and
2637              isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)):
2638            subn = addn.inputs[1].owner
2639            other = addn.inputs[0]
2640        else:
2641            return
2642        if (subn.inputs[1] != node.inputs[2] or
2643                subn.inputs[0] != node.inputs[0]):
2644            return
2645        ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])
2646        # Copy over previous output stacktrace
2647        # Julian: I'm not sure about this at all...
2648        copy_stack_trace(node.outputs, ret)
2649        return [ret]
2650
2651
2652@register_useless
2653@register_canonicalize
2654@register_specialize
2655@gof.local_optimizer([Subtensor])
2656def local_useless_slice(node):
2657    """
2658    Remove Subtensor of the form X[0, :] -> X[0]
2659    """
2660    if isinstance(node.op, Subtensor):
2661        slices = get_idx_list(node.inputs, node.op.idx_list)
2662        last_slice = len(slices)
2663        for s in slices[::-1]:
2664            # check if slice and then check slice indices
2665            if (isinstance(s, slice) and s.start is None and s.stop is None and
2666                    (s.step is None or T.extract_constant(s.step,
2667                                                          only_process_constants=True) == 1)):
2668                last_slice -= 1
2669            else:
2670                break
2671        # check if we removed something
2672        if last_slice < len(slices):
2673            subtens = Subtensor(slices[:last_slice])
2674            sl_ins = Subtensor.collapse(slices[:last_slice],
2675                                        lambda x: isinstance(x, T.Variable))
2676            out = subtens(node.inputs[0], *sl_ins)
2677            # Copy over previous output stacktrace
2678            copy_stack_trace(node.outputs, out)
2679            return [out]
2680
2681
2682@register_canonicalize
2683@register_specialize
2684@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
2685def local_useless_subtensor(node):
2686    """
2687    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
2688    AdvancedSubtensor1 case, the full input is taken when the indices are
2689    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
2690    list/vector or the ARange op.
2691
2692    """
2693
2694    # If the optimization is tried over a node that is not a part of graph before
2695    if not hasattr(node, 'fgraph'):
2696        return
2697
2698    # This optimization needs ShapeOpt and fgraph.shape_feature
2699    if not hasattr(node.fgraph, 'shape_feature'):
2700        return
2701
2702    shape_of = node.fgraph.shape_feature.shape_of
2703
2704    if isinstance(node.op, Subtensor):
2705        cdata = node.op.get_constant_idx(node.inputs, allow_partial=True,
2706                                         only_process_constants=True)
2707        for pos, idx in enumerate(cdata):
2708            if not isinstance(idx, slice):
2709                # If idx is not a slice, this means we remove this dimension
2710                # from the output, so the subtensor is not useless
2711                return False
2712            if idx.start is not None and idx.start != 0:
2713                # If the start of the slice is different from 0, or is a
2714                # variable, then we assume the subtensor is not useless
2715                return False
2716            if idx.step is not None and idx.step != 1:
2717                # If we are going backwards, or skipping elements, then this
2718                # is not a useless subtensor
2719                return False
2720
2721        for pos, idx in enumerate(cdata):
2722
2723            length_pos = shape_of[node.inputs[0]][pos]
2724
2725            if isinstance(idx.stop, (integer_types, np.integer)):
2726                length_pos_data = sys.maxsize
2727                try:
2728                    length_pos_data = get_scalar_constant_value(length_pos,
2729                                                                only_process_constants=True)
2730                except NotScalarConstantError:
2731                    pass
2732
2733                if idx.stop < length_pos_data:
2734                    return False
2735            elif isinstance(idx.stop, gof.Variable):
2736                length_pos_shape_i = idx.stop
2737                # length_pos is a tensor variable, but length_pos_shape_i
2738                # is a scalar variable. We try to see if they represent
2739                # the same underlying variable.
2740                if (length_pos_shape_i.owner and
2741                        isinstance(length_pos_shape_i.owner.op,
2742                                   T.ScalarFromTensor)):
2743                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
2744                elif (length_pos.owner and
2745                      isinstance(length_pos.owner.op, T.TensorFromScalar)):
2746                    length_pos = length_pos.owner.inputs[0]
2747                else:
2748                    # We did not find underlying variables of the same type
2749                    return False
2750
2751                # The type can be different: int32 vs int64. length_pos
2752                # should always be int64 as that is what the shape
2753                # tracker keep. Subtensor accept any scalar int{8,16,32,64}
2754                # as index type.
2755                assert str(length_pos.type.dtype) == "int64"
2756                assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
2757                                                              "int32", "int64"]
2758
2759                # length_pos_shape_i cannot be None
2760                if length_pos_shape_i != length_pos:
2761                    return False
2762            elif idx.stop is None:
2763                pass
2764            else:
2765                return False
2766    elif isinstance(node.op, AdvancedSubtensor1):
2767        # get length of the indexed tensor along the first axis
2768        try:
2769            length = get_scalar_constant_value(shape_of[node.inputs[0]][0],
2770                                               only_process_constants=True)
2771        except NotScalarConstantError:
2772            return False
2773
2774        # get index (which must be a vector by definition)
2775        idx = node.inputs[1]
2776
2777        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
2778        # this optimization
2779        if isinstance(idx, T.Constant):
2780            idx = idx.value
2781            if len(idx) != length:
2782                return False
2783            if np.any(idx != np.arange(length)):
2784                return False
2785        elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
2786            try:
2787                start, stop, step = map(lambda x: get_scalar_constant_value(x,
2788                                                                            only_process_constants=True),
2789                                        idx.owner.inputs)
2790            except NotScalarConstantError:
2791                return False
2792
2793            if start != 0:
2794                return False
2795            if stop != length:
2796                return False
2797            if step != 1:
2798                return False
2799        else:
2800            return False
2801    else:
2802        return False
2803
2804    # We don't need to copy over any stacktrace here,
2805    # because previous stacktrace should suffice.
2806    return [node.inputs[0]]
2807
2808
2809# fast_compile to allow opt subtensor(cast{float32}(make_vector))
2810@register_canonicalize('fast_compile')
2811@gof.local_optimizer([Subtensor])
2812def local_subtensor_lift(node):
2813    """
2814    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
2815
2816    Handles the following unary ops:
2817    elemwise(x,...)[idx] -> elemwise(x[idx],...)
2818      when x,... are broadcasted scalar or not broadcasted at all
2819    rebroadcast(x)[idx] => rebroadcast(x[idx])
2820
2821    """
2822    if isinstance(node.op, Subtensor):
2823        u = node.inputs[0]
2824        if not u.owner or len(u.clients) > 1:
2825            return False
2826
2827        if isinstance(u.owner.op, T.Elemwise) and len(u.owner.inputs) == 1:
2828            idx = node.inputs[1:]
2829            x_idx = node.op(u.owner.inputs[0], *idx)
2830            # Copy over previous output stacktrace
2831            copy_stack_trace(node.outputs, x_idx)
2832            ret = u.owner.op(x_idx)
2833            # Copy over previous output stacktrace
2834            # and stacktrace from previous unary operation
2835            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
2836            return [ret]
2837
2838        if isinstance(u.owner.op, T.Elemwise):
2839            new_inputs = []
2840            if all([sum(i.type.broadcastable) == 0 for i in u.owner.inputs]):
2841                # There is no broadcastable in the inputs
2842                idx = node.inputs[1:]
2843                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
2844                # Copy over previous output stacktrace
2845                copy_stack_trace(node.outputs[0], new_inputs)
2846
2847                ret = u.owner.op(*new_inputs)
2848                # Copy over previous output stacktrace
2849                # and stacktrace from previous unary operation
2850                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
2851                return [ret]
2852            elif all([sum(i.type.broadcastable) in [i.ndim, 0]
2853                      for i in u.owner.inputs]):
2854                # There is no broadcastable in the inputs or it is scalar
2855                idx = node.inputs[1:]
2856                new_inputs = []
2857                for i in u.owner.inputs:
2858                    if sum(i.type.broadcastable) == 0:
2859                        new_inputs.append(node.op(i, *idx))
2860                    else:
2861                        # If the subtensor remove some dims, we must
2862                        # lower the number of dimensions of this scalar.
2863                        if node.outputs[0].ndim == i.ndim:
2864                            new_inputs.append(i)
2865                        else:
2866                            new_inputs.append(
2867                                i.dimshuffle(['x'] * node.outputs[0].ndim))
2868
2869                # Copy over previous output stacktrace
2870                copy_stack_trace(node.outputs[0], new_inputs)
2871
2872                ret = u.owner.op(*new_inputs)
2873                # Copy over previous output stacktrace
2874                # and stacktrace from previous unary operation
2875                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
2876                return [ret]
2877
2878        if isinstance(u.owner.op, T.Rebroadcast):
2879            # make sure that Rebroadcast has only 1 input
2880            assert len(u.owner.inputs) == 1
2881
2882            # Subtensor might reduce dim., adapt broadcast pattern accordingly
2883            new_axis = []
2884
2885            # loop through indices being subtensor-ed
2886            # i indexes broadcastable pattern before subtensor
2887            # j indexes broadcastable pattern after subtensor
2888            j = 0
2889            for (i, x) in enumerate(node.op.idx_list):
2890                # if its not a slice, it will reduce the dimension, should
2891                # not appear in the broascastable dimensions
2892                if isinstance(x, slice):
2893                    new_axis += [(j, u.broadcastable[i])]
2894                    j += 1
2895            # now keep the broadcastable pattern of all
2896            # items not appearing in subtensor list
2897            for i in xrange(len(node.op.idx_list), len(u.broadcastable)):
2898                new_axis += [(j, u.broadcastable[i])]
2899                j += 1
2900
2901            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
2902            # Copy over previous output stacktrace
2903            copy_stack_trace(node.outputs[0], subt_x)
2904
2905            rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
2906            # Copy over previous output stacktrace
2907            # and stacktrace from previous unary operation
2908            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
2909
2910            return [rbcast_subt_x]
2911
2912
2913def merge_two_slices(slice1, len1, slice2, len2):
2914    """
2915     This function merges two slices into a single slice. The code works on
2916     the assumption that:
2917
2918     a) slice1 is actually a slice and not an index, while slice2
2919        can be just an index.
2920
2921     b) the two slices **have been applied consecutively** on the same
2922        tensor
2923
2924    The output slice is **not** in canonical form, but actually just a slice
2925    that can be applied to a tensor to produce the same output as applying
2926    the two consecutive slices.
2927    ``len1`` is the length of the tensor **before** applying the first slice,
2928    while ``len2`` is the length **after** applying the first slice.
2929    """
2930    list_opt = [local_abs_merge, local_mul_switch_sink,
2931                local_upcast_elemwise_constant_inputs,
2932                local_useless_switch, constant_folding]
2933
2934    if type(slice1) is not slice:
2935        raise ValueError(('First provided slice should actually be of type'
2936                         'slice and not an index !'), slice1)
2937    sl1, reverse1 = get_canonical_form_slice(slice1, len1)
2938    sl2, reverse2 = get_canonical_form_slice(slice2, len2)
2939
2940    if type(sl2) is not slice:
2941        if reverse1 is None:
2942            # The first slice is not in reverse, which makes things a lot
2943            # more clear.
2944            # In this case we need to take care only of the special cases:
2945            # len2 <=0    -> throw index error regardless of sl2
2946            # sl2 > len2  -> throw index error
2947            # sl2 < -len2 -> throw index error
2948            # To get a index error we simply use len1+1 to indicate we are
2949            # out of bounds, because passing this index through the formula
2950            # of getting the mixed slice is not guaranteed to result in an
2951            # index error. The **issue though** if that the error will
2952            # complain about accessing element len1+1 which is probably not
2953            # too intuitive for the user
2954            val = sl1.start + sl2 * sl1.step
2955            val = T.switch(T.le(len2, 0), len1 + 1, val)
2956            val = T.switch(T.ge(sl2, len2), len1 + 1, val)
2957            val = T.switch(T.lt(sl2, 0), - len1 - 1, val)
2958            if sl1.step:
2959                val = T.switch(T.eq(sl1.step, 0), len1 + 1, val)
2960            val = pre_greedy_local_optimizer(list_opt, val)
2961            return val
2962        else:
2963            # We are in the more complex case when we do not actually know
2964            # if the first slice was in reverse or not.
2965            # in case it was not in reverse:
2966            p_val = sl1.start + sl2 * sl1.step
2967            # case it was in reverse we need to realize that we do not want
2968            # the k-th element from sl.start but the k-th element from
2969            # sl.stop backwards
2970            n_val = sl1.stop - 1 - sl2 * sl1.step
2971            if config.warn.subtensor_merge_bug:
2972                warnings.warn((
2973                    'Your current code is fine, but Theano versions '
2974                    'prior to 0.5rc2 might have given an incorrect result. '
2975                    'To disable this warning, set the Theano flag '
2976                    'warn.subtensor_merge_bug to False.'))
2977            # we need to pick either n_val or p_val and then follow same
2978            # steps as above for covering the index error cases
2979            val = T.switch(T.lt(reverse1, 0), n_val, p_val)
2980            val = T.switch(T.le(len2, 0), len1 + 1, val)
2981            val = T.switch(T.ge(sl2, len2), len1 + 1, val)
2982            val = T.switch(T.lt(sl2, 0), - len1 - 1, val)
2983            if sl1.step:
2984                val = T.switch(T.eq(sl1.step, 0), len1 + 1, val)
2985            val = pre_greedy_local_optimizer(list_opt, val)
2986            return val
2987    else:
2988        # We are deleaing with two slices that need to be put together
2989        # according to the two steps we have 4 different combinations of
2990        # positive/negative. I will denote the case I'm looking at by
2991        # suffixes to the variables (nn,np,pn,pp):
2992        flen = sl2.stop - sl2.start
2993        p_step = sl1.step * sl2.step
2994        n_step = sl1.step * sl2.step * -1
2995
2996        pp_start = T.minimum(sl1.start + sl2.start * sl1.step, sl1.stop)
2997        pp_stop = T.minimum(sl1.start + sl2.stop * sl1.step, sl1.stop)
2998
2999        pn_stop = sl1.start + (sl2.start - 1) * sl1.step
3000        pn_stop = T.switch(T.and_(T.lt(pn_stop, 0),
3001                                  T.gt(flen, 0)),
3002                           -len1 - 1,
3003                           T.minimum(pn_stop, sl1.stop))
3004        pn_start = sl1.start + (sl2.stop - 1) * sl1.step
3005        pn_start = T.minimum(pn_start, sl1.stop)
3006        pn_start = T.maximum(pn_start, 0)
3007
3008        np_stop = sl1.stop - sl2.stop * sl1.step - 1
3009        np_stop = T.switch(T.and_(T.lt(np_stop, 0),
3010                                  T.gt(flen, 0)),
3011                           -len1 - 1,
3012                           T.maximum(sl1.start - 1, np_stop))
3013        np_start = T.maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1)
3014
3015        nn_start = T.maximum(sl1.start,
3016                             (sl1.stop - 1) - (sl2.stop - 1) * sl1.step)
3017        nn_stop = T.maximum(sl1.start, sl1.stop - sl2.start * sl1.step)
3018
3019        start = T.switch(T.lt(reverse2 * reverse1, 0),
3020                         T.switch(T.lt(reverse1, 0), np_start, pn_start),
3021                         T.switch(T.lt(reverse1, 0), nn_start,
3022                                  pp_start))
3023
3024        stop = T.switch(T.lt(reverse2 * reverse1, 0),
3025                        T.switch(T.lt(reverse1, 0), np_stop, pn_stop),
3026                        T.switch(T.lt(reverse1, 0), nn_stop, pp_stop))
3027
3028        step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step)
3029        start = T.switch(T.le(flen, 0), 0, start)
3030        stop = T.switch(T.le(flen, 0), 0, stop)
3031
3032        # The canonical form of the slice is pretty complicated
3033        # and is not simplified. We simplify it in advance here
3034        # as otherwise this create too many useless optimization that
3035        # DebugMode must check.
3036        start = pre_greedy_local_optimizer(list_opt, start)
3037        stop = pre_greedy_local_optimizer(list_opt, stop)
3038        step = pre_greedy_local_optimizer(list_opt, step)
3039        start = pre_greedy_local_optimizer(list_opt, start)
3040        stop = pre_greedy_local_optimizer(list_opt, stop)
3041        step = pre_greedy_local_optimizer(list_opt, step)
3042
3043        # Pre merge constant for the same reason.
3044        start, stop, step = pre_constant_merge([start, stop, step])
3045
3046        return slice(start, stop, step)
3047
3048
3049@register_canonicalize
3050@register_specialize
3051@gof.local_optimizer([Subtensor])
3052def local_subtensor_merge(node):
3053    """
3054    Refactored optimization to deal with all cases of tensor merging.
3055    Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
3056    expresses all slices in a canonical form, and then merges them together.
3057
3058    """
3059
3060    if isinstance(node.op, Subtensor):
3061        u = node.inputs[0]
3062        if u.owner and isinstance(u.owner.op, Subtensor):
3063            # We can merge :)
3064            # x actual tensor on which we are picking slices
3065            x = u.owner.inputs[0]
3066            # slices of the first applied subtensor
3067            slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
3068            slices2 = get_idx_list(node.inputs, node.op.idx_list)
3069            # Get the shapes of the vectors !
3070            try:
3071                # try not to introduce new shape into the graph
3072                xshape = node.fgraph.shape_feature.shape_of[x]
3073                ushape = node.fgraph.shape_feature.shape_of[u]
3074            except AttributeError:
3075                # Following the suggested use of shape_feature which should
3076                # consider the case when the compilation mode doesn't
3077                # include the ShapeFeature
3078                xshape = x.shape
3079                ushape = u.shape
3080
3081            merged_slices = []
3082            pos_2 = 0
3083            pos_1 = 0
3084            while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
3085                slice1 = slices1[pos_1]
3086                if type(slice1) is slice:
3087                    merged_slices.append(
3088                        merge_two_slices(slice1,
3089                                         xshape[pos_1],
3090                                         slices2[pos_2],
3091                                         ushape[pos_2]))
3092                    pos_2 += 1
3093                else:
3094                    merged_slices.append(slice1)
3095                pos_1 += 1
3096
3097            if pos_2 < len(slices2):
3098                merged_slices += slices2[pos_2:]
3099            else:
3100                merged_slices += slices1[pos_1:]
3101
3102            merged_slices = make_constant(merged_slices)
3103            subtens = Subtensor(merged_slices)
3104
3105            sl_ins = Subtensor.collapse(
3106                merged_slices,
3107                lambda x: isinstance(x, T.Variable))
3108            # Do not call make_node for test_value
3109            out = subtens(x, *sl_ins)
3110
3111            # Copy over previous output stacktrace
3112            # and stacktrace from previous slicing operation.
3113            # Why? Because, the merged slicing operation could have failed
3114            # because of either of the two original slicing operations
3115            orig_out = node.outputs[0]
3116            copy_stack_trace([orig_out, node.inputs[0]], out)
3117
3118            # Restore original broadcastable dimensions that `subtens()` may
3119            # have been unable to infer again
3120            if out.type != orig_out.type:
3121                assert out.dtype == orig_out.dtype
3122                assert out.ndim == orig_out.ndim
3123                out = T.patternbroadcast(out, orig_out.broadcastable)
3124                copy_stack_trace([orig_out, node.inputs[0]], out)
3125            return [out]
3126
3127
3128@register_useless
3129@register_canonicalize
3130@register_specialize
3131@gof.local_optimizer([Subtensor])
3132def local_subtensor_of_alloc(node):
3133    """
3134
3135    alloc(val)[x:y] -> alloc(val[...])
3136    alloc(val)[x:y] -> alloc(val)
3137    This can be seen as a lift, but it also reduce the number of computation/memory.
3138
3139    """
3140    if not isinstance(node.op, Subtensor):
3141        return False
3142    u = node.inputs[0]
3143    if u.owner is None:
3144        return False
3145    if not isinstance(u.owner.op, T.Alloc):
3146        return False
3147    slices = get_idx_list(node.inputs, node.op.idx_list)
3148    val = u.owner.inputs[0]
3149    dims = u.owner.inputs[1:]
3150    assert len(slices) <= len(dims)
3151
3152    # Number of dimensions added to val
3153    n_added_dims = u.ndim - val.ndim
3154    # Dimensions of the returned alloc
3155    nw_dims = []
3156    # Slices to take from val
3157    val_slices = []
3158
3159    for i, (sl, dim) in enumerate(zip(slices, dims)):
3160        # If val was not copied over that dim,
3161        # we need to take the appropriate subtensor on it.
3162        if i >= n_added_dims:
3163            # We check that the corresponding val dimensions was
3164            # not a broadcasted dimensions.
3165            if (val.type.ndim > (i - n_added_dims) and
3166                    val.type.broadcastable[i - n_added_dims]):
3167                val_slices.append(slice(None))
3168            else:
3169                val_slices.append(sl)
3170
3171        csl, _ = get_canonical_form_slice(sl, dim)
3172        if type(csl) is not slice:
3173            # That dimension is removed.
3174            pass
3175        else:
3176            nw_dim = csl.stop - csl.start
3177
3178            if csl.step != 1:
3179                # Do not add the ceil_intdiv() graphs in the graphs
3180                # when this is not needed as it prevent detecting the
3181                # correct broadcast pattern.
3182                nw_dim = T.ceil_intdiv(nw_dim, csl.step)
3183            nw_dims += [nw_dim]
3184
3185    nw_val = val[tuple(val_slices)]
3186    nw_dims += dims[len(slices):]
3187    if nw_val.ndim > len(nw_dims):
3188        return False
3189    rval = T.alloc(nw_val, *nw_dims)
3190    if type(rval) not in (list, tuple):
3191        rval = [rval]
3192    if rval[0].type != node.outputs[0].type:
3193        # It happen that the make_node() isn't able to infer the same pattern.
3194        # We know it is safe, so fix that.
3195        rval[0] = T.patternbroadcast(rval[0], node.outputs[0].broadcastable)
3196
3197    return rval
3198
3199
3200@register_canonicalize
3201@register_stabilize
3202@register_specialize
3203@gof.local_optimizer([Subtensor])
3204def local_subtensor_of_dot(node):
3205    """
3206    This optimization translates T.dot(A, B)[idxs] into T.dot(A[idxs_a], B[idxs_b]),
3207    where idxs_a and idxs_b are defined appropriately.
3208
3209    idxs_a is the first A.ndim-1 entries of idxs,
3210    and idxs_b is the remaining entries of idxs (if any),
3211    modified to skip the second-to-last dimension of B
3212    (because dot sums over this dimension).
3213
3214    """
3215    if not isinstance(node.op, Subtensor):
3216        return
3217    if (not node.inputs[0].owner or
3218            not isinstance(node.inputs[0].owner.op, T.Dot)):
3219        return
3220    # If there is other node that use the outputs of the dot
3221    # We don't want to compute twice the sub part.
3222    if len(node.inputs[0].clients) > 1:
3223        return
3224
3225    a = node.inputs[0].owner.inputs[0]
3226    b = node.inputs[0].owner.inputs[1]
3227
3228    idx_list = get_idx_list(node.inputs, node.op.idx_list)
3229
3230    num_a_indices = min(a.ndim - 1, len(idx_list))
3231    a_indices = idx_list[:num_a_indices]
3232    b_indices = idx_list[num_a_indices:]
3233
3234    # This is necessary because np.dot sums the last index of a with the second to last of b
3235    # so we want to skip the second-to-last index into b.
3236    # This wasn't necessary for a, because we just omitted the last index.
3237    # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
3238    # (dot also handles b.ndim < 2 as a special case)
3239    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
3240        b_indices = (b_indices[:b.ndim - 2] +
3241                     (slice(None, None, None),) + b_indices[b.ndim - 2:])
3242
3243    a_sub = a.__getitem__(tuple(a_indices))
3244    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
3245
3246    # Copy over previous output stacktrace to a_sub and b_sub,
3247    # because an error in the subtensor operation (e.g. an index error)
3248    # on either a or b must correspond to an error in the
3249    # subtensor operation on their dot product.
3250    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
3251
3252    # Copy over previous output stacktrace and previous dot product stacktrace,
3253    # because an error here may correspond to an either in either the original
3254    # dot product, or in the dot product after the subtensor operation.
3255    r = T.dot(a_sub, b_sub)
3256    copy_stack_trace([node.outputs[0], node.inputs[0]], r)
3257
3258    return [r]
3259
3260
3261@register_canonicalize
3262@gof.local_optimizer([T.add])
3263def local_IncSubtensor_serialize(node):
3264    """
3265    When using Subtensor, gradient graphs can be ugly.
3266
3267    If we ask for grad(f(a[0]), a), we are going to get something like
3268
3269        IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
3270
3271    This might be ugly, but at least it's as fast as you could want.
3272    If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse...
3273
3274        Elemwise{Add}
3275            IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
3276            IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1])
3277            IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2])
3278
3279    This is much worse because this time we have to produce 3 matrices
3280    the size of 'a', just so we can add them together.
3281
3282    This Op rearranges IncSubtensor's that all work on the same
3283    initial argument (here, Elemwise{second}(a,0)) into a chain.  The
3284    advantage of the chain structure is that each one can be optimized
3285    later in the pipeline to operate inplace.
3286
3287    Ideally, the op will do something like this:
3288
3289    #
3290    #  add(x, incsubtensor(b, c), incsubtensor(b, d))
3291    #  -> incsubtensor(incsubtensor(add(x,b,b), c), d)
3292
3293    """
3294    def movable(i):
3295        # Return True iff this is a incsubtensor that we can move
3296        return (i.owner and
3297                isinstance(i.owner.op, (IncSubtensor,
3298                                        AdvancedIncSubtensor1,
3299                                        AdvancedIncSubtensor,)) and
3300                i.type == o_type and
3301                len(i.clients) == 1 and
3302                not i.owner.op.set_instead_of_inc)
3303
3304    if node.op == T.add:
3305        o_type = node.outputs[0].type
3306
3307        movable_inputs = [i for i in node.inputs if movable(i)]
3308
3309        if movable_inputs:
3310            new_inputs = ([i for i in node.inputs if not movable(i)] +
3311                          [mi.owner.inputs[0] for mi in movable_inputs])
3312            if len(new_inputs) == 0:
3313                new_add = new_inputs[0]
3314            else:
3315                new_add = T.add(*new_inputs)
3316
3317                # Copy over stacktrace from original output, as an error
3318                # (e.g. an index error) in this add operation should
3319                # correspond to an error in the original add operation.
3320                copy_stack_trace(node.outputs[0], new_add)
3321
3322            # stack up the new incsubtensors
3323            tip = new_add
3324            for mi in movable_inputs:
3325                assert tip.type == o_type
3326                assert tip.type == mi.owner.inputs[0].type
3327                tip = mi.owner.op(tip, *mi.owner.inputs[1:])
3328                # Copy over stacktrace from outputs of the original
3329                # "movable" operation to the new operation.
3330                copy_stack_trace(node.outputs + mi.owner.outputs, tip)
3331
3332            return [tip]
3333
3334        # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
3335
3336# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
3337# Otherwise in some cases it was making the EQ optimizer use 45. In
3338# the TopoOptimizer, the EQ only use 5 passes.
3339compile.optdb.register('pre_local_IncSubtensor_serialize',
3340                       in2out(local_IncSubtensor_serialize),
3341                       # Just before canonizer
3342                       0.99, 'fast_run')
3343
3344
3345# after priority 50 Destructive inplace operations
3346# gemm is the first one now, at priority 70
3347
3348@gof.local_optimizer([IncSubtensor], inplace=True)
3349def local_inplace_setsubtensor(node):
3350    """
3351    Also work for GpuIncSubtensor.
3352
3353    """
3354    if isinstance(node.op, IncSubtensor) and not node.op.inplace:
3355        dta = node.op.destroyhandler_tolerate_aliased
3356        new_op = node.op.__class__(
3357            node.op.idx_list, inplace=True,
3358            set_instead_of_inc=node.op.set_instead_of_inc,
3359            destroyhandler_tolerate_aliased=dta)
3360        new_node = new_op(*node.inputs)
3361        val = getattr(node.outputs[0].tag, 'nan_guard_mode_check', True)
3362        new_node.tag.nan_guard_mode_check = val
3363
3364        # Copy stacktrace from original outputs to new outputs.
3365        # This is sensible, because the new operation is the
3366        # same as the old one, but now with different attributes.
3367        copy_stack_trace(node.outputs, new_node)
3368        return [new_node]
3369    return False
3370compile.optdb.register('local_inplace_setsubtensor',
3371                       TopoOptimizer(
3372                           local_inplace_setsubtensor,
3373                           failure_callback=TopoOptimizer.warn_inplace),
3374                       60, 'fast_run', 'inplace')  # DEBUG
3375
3376
3377@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
3378def local_inplace_incsubtensor1(node):
3379    """
3380    Also work for GpuAdvancedIncSubtensor1.
3381
3382    """
3383    if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
3384        new_op = node.op.clone_inplace()
3385        new_node = new_op(*node.inputs)
3386
3387        # Copy stacktrace from original outputs to new outputs.
3388        # This is sensible, because the new operation is the
3389        # same as the old one, but now with different attributes.
3390        copy_stack_trace(node.outputs, new_node)
3391        return [new_node]
3392    return False
3393compile.optdb.register('local_inplace_incsubtensor1',
3394                       TopoOptimizer(
3395                           local_inplace_incsubtensor1,
3396                           failure_callback=TopoOptimizer.warn_inplace),
3397                       60, 'fast_run', 'inplace')  # DEBUG
3398
3399
3400# Register old name
3401@register_canonicalize("local_incsubtensor_of_allocs")
3402@register_stabilize("local_incsubtensor_of_allocs")
3403@gof.local_optimizer([IncSubtensor,
3404                      AdvancedIncSubtensor,
3405                      AdvancedIncSubtensor1])
3406def local_incsubtensor_of_zeros(node):
3407    """
3408    IncSubtensor(x, zeros, idx) -> x
3409
3410    """
3411    if (isinstance(node.op, (IncSubtensor,
3412                             AdvancedIncSubtensor,
3413                             AdvancedIncSubtensor1)) and
3414            not node.op.set_instead_of_inc):
3415        x = node.inputs[0]
3416        y = node.inputs[1]
3417        try:
3418            # Don't use only_process_constants=True. We need to
3419            # investigate Alloc of 0s but with non constant shape.
3420            if get_scalar_constant_value(y, elemwise=False) == 0:
3421                # No need to copy over the stacktrace,
3422                # because x should already have a stacktrace
3423                return [x]
3424        except NotScalarConstantError:
3425            return
3426
3427
3428@register_canonicalize
3429@register_specialize
3430@gof.local_optimizer([IncSubtensor])
3431def local_incsubtensor_of_zeros_to_setsubtensor(node):
3432    """
3433    IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
3434    """
3435    if (isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc):
3436        x = node.inputs[0]
3437
3438        if isinstance(x, T.Constant) and not np.any(x.data):
3439            return [IncSubtensor(node.op.idx_list,
3440                                 node.op.inplace,
3441                                 set_instead_of_inc=True,
3442                                 destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased,
3443                                 )(*node.inputs)]
3444
3445
3446@register_canonicalize('local_setsubtensor_of_allocs')
3447@register_stabilize('local_setsubtensor_of_allocs')
3448@gof.local_optimizer([IncSubtensor])
3449def local_setsubtensor_of_constants(node):
3450    """
3451    SetSubtensor(x, x[idx], idx) -> x
3452
3453    when x is constant or alloc.
3454
3455    """
3456    if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
3457        x = node.inputs[0]
3458        y = node.inputs[1]
3459
3460        # Don't use only_process_constants=True. We need to
3461        # investigate Alloc of 0s but with non constant shape.
3462        try:
3463            replace_x = get_scalar_constant_value(x, elemwise=False)
3464        except NotScalarConstantError:
3465            return
3466
3467        try:
3468            replace_y = get_scalar_constant_value(y, elemwise=False)
3469        except NotScalarConstantError:
3470            return
3471
3472        if replace_x == replace_y:
3473
3474            # No need to copy over the stacktrace,
3475            # because x should already have a stacktrace
3476            return [x]
3477        else:
3478            return False
3479
3480
3481@register_canonicalize
3482@register_stabilize
3483@gof.local_optimizer([AdvancedSubtensor1])
3484def local_adv_sub1_adv_inc_sub1(node):
3485    """Optimize the possible AdvSub1(AdvSetSub1(...), ...).
3486
3487    AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y
3488
3489    Notes
3490    -----
3491    This opt add AssertOp. Otherwise, it would remove shape and
3492    index error. If you want to get rid of them, see the
3493    :ref:`unsafe_optimization` section.
3494
3495    WARNING:
3496    A previous version of this optimization also matched
3497    AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y
3498    This is incorrect when there are duplicate indices.
3499    The current version warns the user about potential past issues.
3500
3501    """
3502    if not isinstance(node.op, AdvancedSubtensor1):
3503        return
3504    inp = node.inputs[0]
3505    if (not inp.owner or
3506            not isinstance(inp.owner.op, AdvancedIncSubtensor1)):
3507        return
3508    idx = node.inputs[1]
3509    idx2 = inp.owner.inputs[2]
3510    x = inp.owner.inputs[0]
3511    y = inp.owner.inputs[1]
3512    if idx is not idx2:
3513        return
3514    if (not inp.owner.op.set_instead_of_inc and
3515            # Don't use only_process_constants=True. We need to
3516            # investigate Alloc of 0s but with non constant shape.
3517            T.extract_constant(x, elemwise=False) != 0):
3518        return
3519
3520    if not inp.owner.op.set_instead_of_inc:
3521        if config.warn.inc_subtensor1_opt:
3522            warnings.warn(
3523                'Your current code is fine, but Theano versions '
3524                'between 0.7rc1 and 0.10 (or development versions '
3525                'between Nov. 2014 and May 2017) '
3526                'might have given incorrect results. This graph has '
3527                'following pattern: inc_subtensor(zeros[idx], x)[idx], '
3528                'where idx is an array of integers. This used to be '
3529                'optimized to "x", which is incorrect if there are '
3530                'duplicated indices in idx. '
3531                'To disable this warning, set the Theano flag '
3532                'warn.inc_subtensor1_opt to False.')
3533        return
3534
3535    cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
3536    if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
3537        cond.append(T.eq(idx.shape[0], y.shape[0]))
3538    r = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 "
3539               "that was optimized away")(y, *cond)
3540    copy_stack_trace(y, r)
3541
3542    if r.dtype == node.outputs[0].dtype:
3543        return [r]
3544    # It is possible that y is upcast or downcast to x.dtype.
3545    # In all case, as we set or add with 0, we can just cast y.
3546    r2 = T.cast(r, node.outputs[0].dtype)
3547
3548    # Copy over stacktrace from before casting, since
3549    # we don't expect problems in the casting operation,
3550    # and any problems in the indexing would have been spotted above.
3551    copy_stack_trace(r, r2)
3552    return [r2]
3553
3554
3555@register_specialize
3556@register_stabilize
3557@register_canonicalize
3558@register_useless
3559@gof.local_optimizer([IncSubtensor,
3560                      AdvancedIncSubtensor,
3561                      AdvancedIncSubtensor1])
3562def local_useless_inc_subtensor_alloc(node):
3563    """
3564    Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
3565    a fully or partially broadcastable variable, by one that skips the
3566    intermediate `alloc` where possible.
3567
3568    """
3569    if isinstance(node.op, (IncSubtensor,
3570                            AdvancedIncSubtensor,
3571                            AdvancedIncSubtensor1)):
3572        x = node.inputs[0]
3573        y = node.inputs[1]
3574        i = node.inputs[2:]
3575
3576        if y.owner is not None and isinstance(y.owner.op, T.Alloc):
3577            # `z` is the input of the Alloc op, i.e. T.alloc(z, <shape>)
3578            z = y.owner.inputs[0]
3579
3580            try:
3581                shape_feature = node.fgraph.shape_feature
3582            except AttributeError:
3583                # The shape feature may not be available in some mode, but we
3584                # need it for this optimization, so don't continue.
3585                return False
3586
3587            shape_of = shape_feature.shape_of
3588            same_shape = shape_feature.same_shape
3589
3590            # Get the subtensor of `x` indexed by `i` in order to compare
3591            # shapes later.
3592            if isinstance(node.op, IncSubtensor):
3593                xi = Subtensor(node.op.idx_list)(x, *i)
3594            elif isinstance(node.op, AdvancedIncSubtensor):
3595                xi = advanced_subtensor(x, *i)
3596            elif isinstance(node.op, AdvancedIncSubtensor1):
3597                xi = advanced_subtensor1(x, *i)
3598            else:
3599                raise Exception('Should never happen!')
3600
3601            reason = 'local_useless_incsubtensor_alloc'
3602
3603            # Add `xi` to the shape feature `fgraph`. This is important for
3604            # shape inference later because the variable must be part of the
3605            # function graph in order to call `same_shape` on it.
3606            if xi not in shape_of:
3607                shape_feature.on_import(node.fgraph, xi.owner,
3608                                        '%s: add `xi`' % reason)
3609
3610            # `xi` may have more dimensions than `y` since the subtensor ops
3611            # do automatic broadcasting of the increment internally. Thus, we
3612            # need to make the leading implicitly broadcasted dimensions
3613            # explicit for shape comparison later.
3614            if xi.ndim > y.ndim:
3615                y = T.shape_padleft(y, xi.ndim - y.ndim)
3616                if y not in shape_of:
3617                    shape_feature.on_import(node.fgraph, y.owner,
3618                                            '%s: add `y`' % reason)
3619
3620            # Build `z_broad` explicitly to include extra implicit dimensions.
3621            z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable)
3622
3623            cond = [
3624                # The shapes of `y` and `xi` must either agree or `y` may
3625                # also have shape equal to 1 which may be treated as a
3626                # broadcastable dimension by the subtensor op.
3627                T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k]))
3628                # Loop over all dimensions.
3629                for k in xrange(xi.ndim)
3630                # We need to check the above shapes, if
3631                # * the pre-alloc increment `z` is broadcastable in
3632                # dimension `k` (if it isn't, then the shapes of `z` and
3633                # `y` are the same by the definition of the `Alloc` op in
3634                # this dimension and replacing `y` by `z` will not hide a
3635                # shape error), and
3636                # * `xi` and `y` do not have the same shape in dimension
3637                # `k` or we cannot infer the shape statically (if the
3638                # shapes of `xi` and `y` are not the same, then replacing
3639                # `y` by `z` will hide the shape error of `y`), and
3640                # * the shape of `y` is not equal to 1 or we cannot infer
3641                # the shape statically (if the shape of `y` is equal to
3642                # 1, then `y` is broadcasted by the inc_subtensor op
3643                # internally, so the shapes of `xi` and `y` do not need
3644                # to match in dimension `k`; else we need to check at
3645                # runtime that the shape of `y` is either 1 or the same
3646                # as `xi` or otherwise replacing `y` by `z` will hide a
3647                # shape error).
3648                if (z_broad[k] and
3649                    not same_shape(xi, y, dim_x=k, dim_y=k) and
3650                    shape_of[y][k] != 1)]
3651
3652            if len(cond) > 0:
3653                msg = '`x[i]` and `y` do not have the same shape.'
3654                z = Assert(msg)(z, *cond)
3655
3656            r = node.op(x, z, *i)
3657            # Copy over stacktrace from previous output, since
3658            # we don't expect problems when removing the intermediate
3659            # alloc operation and so we still want to point at the line
3660            # of the inc_subtensor operation.
3661            copy_stack_trace(node.outputs, r)
3662
3663            return [r]
3664
3665
3666####################
3667# Rebroadcast opts #
3668####################
3669
3670@register_useless
3671@register_canonicalize
3672@register_specialize
3673@gof.local_optimizer([T.Rebroadcast])
3674def local_useless_rebroadcast(node):
3675    """
3676    Remove Rebroadcast if id does not actually change the broadcasting pattern.
3677
3678    """
3679    if isinstance(node.op, T.Rebroadcast):
3680        x = node.inputs[0]
3681        if np.all(x.broadcastable == node.outputs[0].broadcastable):
3682            # No broadcastable flag was modified
3683            # No need to copy over stack trace,
3684            # because x should already have a stack trace.
3685            return [x]
3686        else:
3687            # Keep the flags that modify something
3688            new_axis = {}
3689            for dim, bc in list(node.op.axis.items()):
3690                if x.broadcastable[dim] != bc:
3691                    new_axis[dim] = bc
3692            if new_axis == node.op.axis:
3693                # All flags are useful
3694                return
3695            else:
3696                r = T.Rebroadcast(*list(new_axis.items()))(x)
3697                # Copy over stacktrace from previous output
3698                copy_stack_trace(node.outputs, r)
3699                return [r]
3700
3701
3702@register_canonicalize
3703@register_specialize
3704@gof.local_optimizer([T.Rebroadcast])
3705def local_rebroadcast_lift(node):
3706    """
3707    Lifts Rebroadcast through unary Elemwise operations,
3708    and merges consecutive Rebroadcasts.
3709
3710    Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x))
3711    Rebroadcast(Rebroadcast(x)) => Rebroadcast(x)
3712
3713    """
3714    op = node.op
3715    if not isinstance(op, T.Rebroadcast):
3716        return False
3717
3718    input = node.inputs[0]
3719    inode = input.owner
3720    if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
3721        # It may happen that `input` has no client because this optimization
3722        # is called from `apply_rebroadcast_opt`, which in particular is used
3723        # by the `unbroadcast` function before we are in the actual function
3724        # compilation phase.
3725        if hasattr(input, 'clients') and len(input.clients) == 1:
3726            rebroadcasted = T.Rebroadcast(*list(op.axis.items()))(
3727                inode.inputs[0])
3728            # Copy over stacktrace from previous output (after rebroadcasting)
3729            # to new output, because an error in the new graph right after
3730            # rebroadcasting must have been caused by the previous rebroadcasting.
3731            copy_stack_trace(node.outputs, rebroadcasted)
3732
3733            rval = inode.op.make_node(rebroadcasted).outputs
3734
3735            # Copy over stacktrace from previous output (after rebroadcasting)
3736            # and input (after elemwise operation) to new output, because an
3737            # error in the new graph could have been caused by either of the
3738            # two ops.
3739            copy_stack_trace(node.outputs + node.inputs, rval)
3740
3741            return rval
3742    if inode and isinstance(inode.op, T.Rebroadcast):
3743        # the "axis" specification in the outer Rebroadcast overrides
3744        # the axis of the inner one
3745        axis = inode.op.axis.copy()
3746        axis.update(op.axis)
3747        iinput = inode.inputs[0]
3748
3749        rval = [T.Rebroadcast(*list(axis.items()))(iinput)]
3750
3751        # Copy over stacktrace from previous output (after second rebroadcast)
3752        # and from previous input (after first rebroadcast op) because an error in
3753        # the new graph could have been caused by either of the two
3754        # rebroadcast ops.
3755        copy_stack_trace(node.outputs + node.inputs, rval)
3756        return rval
3757
3758
3759def apply_rebroadcast_opt(rval):
3760    """
3761    Apply as many times as required the optimization local_useless_rebroadcast
3762    and local_rebroadcast_lift.
3763
3764    Parameters
3765    ----------
3766    rval: a Variable
3767
3768    Returns
3769    -------
3770    A Variable (the same if no optimization can be applied)
3771
3772    """
3773
3774    changed = True
3775    while changed and rval.owner:
3776        changed = False
3777        rval2 = theano.tensor.opt.local_useless_rebroadcast.transform(
3778            rval.owner)
3779        if rval2:
3780            assert len(rval2) == 1
3781            rval = rval2[0]
3782            changed = True
3783        if rval.owner:
3784            rval2 = theano.tensor.opt.local_rebroadcast_lift.transform(
3785                rval.owner)
3786            if rval2:
3787                assert len(rval2) == 1
3788                rval = rval2[0]
3789                changed = True
3790    return rval
3791
3792
3793#############
3794# Join opts #
3795#############
3796@register_specialize
3797@register_canonicalize
3798@register_useless
3799@gof.local_optimizer([T.Join])
3800def local_join_1(node):
3801    """Join(i, x) => x
3802
3803    Remove Join() when only one element is joined.
3804
3805    """
3806    if not isinstance(node.op, T.Join):
3807        return
3808    tensors = node.inputs[1:]
3809    if len(tensors) == 1:
3810        # We don't need to copy over any stacktrace here, because the
3811        # input variable should already have its own stacktrace.
3812        return [tensors[0]]
3813
3814
3815# TODO: merge in local_useless_join
3816@register_useless
3817@register_specialize
3818@register_canonicalize
3819@gof.local_optimizer([T.Join])
3820def local_join_empty(node):
3821    """Join(i, x, y, empty) => Join(i, x, y)
3822
3823    Remove empty inputs to joins. The empty inputs can be anywhere.
3824
3825    """
3826    if not isinstance(node.op, T.Join):
3827        return
3828    new_inputs = []
3829    try:
3830        join_idx = get_scalar_constant_value(node.inputs[0],
3831                                             only_process_constants=True)
3832    except NotScalarConstantError:
3833        return
3834    for idx in xrange(1, len(node.inputs)):
3835        inp = node.inputs[idx]
3836        # We can not use size == 0,, as this can change shape from 3,0
3837        # to 2,0.  This trigger DebugMode error. This happen with
3838        # stack(...,[]) as this add a dimshuffle on [], that add a
3839        # dimensions with shape 1.
3840        if isinstance(inp, theano.Constant) and inp.data.shape[join_idx] == 0:
3841            continue
3842        new_inputs.append(inp)
3843    if len(new_inputs) < len(node.inputs) - 1:
3844        if len(new_inputs) == 0:
3845            # T.join do not work in that case.
3846            # constant folding will take care of this case.
3847            return
3848        ret = T.join(node.inputs[0], *new_inputs)
3849        o = node.outputs[0]
3850        if ret.dtype != o.dtype:
3851            # Join can upcast some inputs
3852            return
3853
3854        # Copy over stacktrace from previous output (after join op)
3855        # to new output, because an error in the new op must be caused
3856        # by an error in the old join op.
3857        copy_stack_trace(node.outputs, ret)
3858
3859        if ret.type != o.type:
3860            assert ret.dtype == o.dtype
3861            assert ret.ndim == o.ndim
3862            ret = T.patternbroadcast(ret, node.outputs[0].broadcastable)
3863
3864        # Copy over stacktrace from previous output
3865        # (after patternbroadcast op) for same reasons as before.
3866        copy_stack_trace(node.outputs, ret)
3867
3868        return [ret]
3869
3870
3871@register_specialize
3872@register_canonicalize
3873@register_useless
3874@gof.local_optimizer([T.Join])
3875def local_join_make_vector(node):
3876    """Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
3877
3878    Merge MakeVector inputs to Join. This can make the join completly
3879    disapear with the local_join_1 opt.
3880
3881    """
3882    if not isinstance(node.op, T.Join) or node.outputs[0].ndim != 1:
3883        return
3884    new_inputs = [node.inputs[1]]
3885    for idx in xrange(2, len(node.inputs)):
3886        inp = node.inputs[idx]
3887        if (inp.owner and
3888                isinstance(inp.owner.op, MakeVector) and
3889                new_inputs[-1].owner and
3890                isinstance(new_inputs[-1].owner.op, MakeVector) and
3891                # MakeVector have a dtype parameter
3892                inp.owner.op == new_inputs[-1].owner.op):
3893            inps = new_inputs[-1].owner.inputs + inp.owner.inputs
3894            new_inputs[-1] = inp.owner.op(*inps)
3895
3896            # Copy over stacktrace from previous output (after join op)
3897            # to new intermediate output, because an error in the intermediate
3898            # op must be caused by an error in the old join op.
3899            copy_stack_trace(node.outputs, new_inputs[-1])
3900        else:
3901            new_inputs.append(inp)
3902    if len(new_inputs) < len(node.inputs) - 1:
3903        ret = T.join(node.inputs[0], *new_inputs)
3904
3905        # Copy over stacktrace from previous output (after join op)
3906        # to new output, because an error in the new op must be caused
3907        # by an error in the old join op.
3908        copy_stack_trace(node.outputs, ret)
3909        return [ret]
3910
3911
3912#################
3913#  speed/memory #
3914#################
3915@register_canonicalize
3916@register_specialize
3917@gof.local_optimizer([T.elemwise.Sum])
3918def local_sumsqr2dot(node):
3919    """
3920    This optimization detects T.sqr( W.dimshuffle('x',0,1) * G.dimshuffle(0,'x',1) ).sum(axis=(1,2))
3921     and converts this to T.dot(T.sqr(G), T.sqr(W).sum(axis=0)).
3922    """
3923    if (isinstance(node.op, T.elemwise.Sum) and
3924            isinstance(node.op.scalar_op, theano.scalar.basic.Add) and node.op.axis == (1, 2)):
3925        in1 = node.inputs[0]
3926        out = node.outputs[0]
3927
3928        if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Sqr)):
3929            in_sqr = in1.owner.inputs[0]
3930            if (in_sqr.owner and isinstance(in_sqr.owner.op, T.Elemwise) and
3931                    isinstance(in_sqr.owner.op.scalar_op, theano.scalar.basic.Mul) and len(in_sqr.owner.inputs) == 2):
3932                in_mul1, in_mul2 = in_sqr.owner.inputs
3933
3934                if (isinstance(in_mul1.owner.op, T.elemwise.DimShuffle) and in_mul1.owner.op.new_order == ('x', 0, 1) and
3935                        isinstance(in_mul2.owner.op, T.elemwise.DimShuffle) and in_mul2.owner.op.new_order == (0, 'x', 1)):
3936                    W = in_mul1.owner.inputs[0]
3937                    G = in_mul2.owner.inputs[0]
3938
3939                    new_out = T.dot(T.sqr(G), T.sqr(W).sum(axis=0))
3940                    if new_out.dtype != out.dtype:
3941                        new_out = T.cast(new_out, dtype=out.dtype)
3942                    return [new_out]
3943
3944
3945#################
3946# Exp stability #
3947#################
3948@register_stabilize
3949@register_specialize
3950@register_canonicalize
3951@gof.local_optimizer([T.Elemwise])
3952def local_expm1(node):
3953    """
3954    This optimization detects exp(a)-1 and converts this to expm1(a).
3955    """
3956    if (isinstance(node.op, T.Elemwise) and
3957            isinstance(node.op.scalar_op, theano.scalar.basic.Sub)):
3958        in1, in2 = node.inputs
3959        out = node.outputs[0]
3960
3961        if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
3962                T.extract_constant(in2, only_process_constants=False) == 1):
3963            in11 = in1.owner.inputs[0]
3964            new_out = T.expm1(in11)
3965
3966            if new_out.dtype != out.dtype:
3967                new_out = T.cast(new_out, dtype=out.dtype)
3968            if new_out.type != out.type:
3969                return
3970            return [new_out]
3971
3972
3973###############
3974# Switch opts #
3975###############
3976@register_useless('local_remove_switch_const_cond')
3977@register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
3978@register_specialize
3979@gof.local_optimizer([T.Elemwise])
3980def local_useless_switch(node):
3981    """
3982    This optimization makes the following changes in the graph:
3983        T.switch(cond,left,right) -->
3984               if cond is constant and cond == 0: right
3985               if cond is constant and cond != 0: left
3986               if left is right -> left
3987
3988        T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
3989    """
3990    if (isinstance(node.op, T.Elemwise) and
3991            isinstance(node.op.scalar_op, scalar.basic.Switch)):
3992        cond = T.extract_constant(node.inputs[0],
3993                                  only_process_constants=True)
3994        if ((type(cond) is np.ndarray and cond.ndim == 0) or
3995                isinstance(cond, np.number)):
3996            if cond == 0:
3997                correct_out = node.inputs[2]
3998            else:
3999                correct_out = node.inputs[1]
4000
4001            if correct_out.ndim != node.outputs[0].ndim:
4002                # TODO: broadcast?
4003                return False
4004            if correct_out.dtype != node.outputs[0].dtype:
4005                out = T.cast(correct_out, node.outputs[0].dtype)
4006            else:
4007                out = correct_out
4008
4009            if out.type.broadcastable != node.outputs[0].type.broadcastable:
4010                # We need to copy data to the new dimensions during execution
4011
4012                # We should not depend on node.outputs as this would
4013                # make the new node depend on the old one that will
4014                # get optimized again. So this create a cycle.
4015                shps = []
4016                for idx, (b1, b2), in enumerate(zip(out.type.broadcastable,
4017                                                    node.outputs[0].type.broadcastable)):
4018                    if b1 == b2:
4019                        shps.append(out.shape[idx])
4020                    elif not node.inputs[1].type.broadcastable[idx]:
4021                        shps.append(node.inputs[1].shape[idx])
4022                    else:
4023                        shps.append(node.inputs[2].shape[idx])
4024                out = T.alloc(out, *shps)
4025            else:
4026                out = out
4027
4028            # Copy over stacktrace from selected output to new output
4029            copy_stack_trace(node.outputs + correct_out, out)
4030            return [out]
4031        # if left is right -> left
4032        if node.inputs[1] is node.inputs[2]:
4033            # Note: No need to copy over stacktrace, because the input node
4034            # already has its own stacktrace
4035            if cond.type == node.inputs[1].type:
4036                return [node.inputs[1]]
4037
4038            ret = T.fill(cond, node.inputs[1])
4039
4040            # Copy over stacktrace from switch output and correct branch
4041            copy_stack_trace(node.outputs + node.inputs[1], ret)
4042            return [ret]
4043
4044        # This case happens with scan.
4045        # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
4046        left = node.inputs[1]
4047        right = node.inputs[2]
4048        cond_var = node.inputs[0]
4049        if cond_var.owner and \
4050           isinstance(cond_var.owner.op, T.Elemwise) and \
4051           isinstance(cond_var.owner.op.scalar_op, scalar.LE) and \
4052           cond_var.owner.inputs[0].owner and \
4053           isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and \
4054           T.extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 and \
4055           T.extract_constant(left, only_process_constants=True) == 0 and \
4056           right is cond_var.owner.inputs[0]:
4057            assert right.type == node.outputs[0].type
4058            # No need to copy over stacktrace, because the right input node
4059            # already has its own stacktrace
4060            return [right]
4061        return False
4062    return False
4063
4064
4065@register_specialize
4066@register_canonicalize
4067@gof.local_optimizer([T.mul])
4068def local_mul_switch_sink(node):
4069    """
4070    This optimization makes the following changes in the graph:
4071    T.mul(A,T.switch(cond,0,iff),B) -->  T.switch(cond,0,T.mul(A,B,iff))
4072    T.mul(A,T.switch(cond,ift,0),B) -->  T.switch(cond,T.mul(A,B,ift),0)
4073    A and B being several (or none) symbolic variables.
4074    This is useful because A and B may not be numerically stable and give
4075    NaN or inf values for cases where the switch returns 0.
4076    With this optimization T.grad(T.switch(...)) has the right behavior.
4077
4078    Examples
4079    --------
4080      x -> f(x)
4081      x -> g(x)
4082      y = T.switch(cond,f(x),g(x))
4083      **without the optimization
4084      T.grad(y,x) -> grad(f(x),x) * grad(y,f(x)) +  grad(g(x),x) * grad(y,g(x))
4085      **with the optimization
4086      T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x))
4087    This will be particularly useful for the lazyif because we skip
4088    an entire part of the graph.
4089
4090    """
4091    if node.op != T.mul:
4092        return False
4093    for idx, i in enumerate(node.inputs):
4094        if i.owner and i.owner.op == T.switch:
4095            switch = i.owner
4096            try:
4097                if (get_scalar_constant_value(
4098                        switch.inputs[1], only_process_constants=True) == 0.):
4099                    listmul = node.inputs[:idx] + node.inputs[idx + 1:]
4100                    fmul = T.mul(*(listmul + [switch.inputs[2]]))
4101
4102                    # Copy over stacktrace for elementwise multiplication op
4103                    # from previous elementwise multiplication op.
4104                    # An error in the multiplication (e.g. errors due to
4105                    # inconsistent shapes), will point to the
4106                    # multiplication op.
4107                    copy_stack_trace(node.outputs, fmul)
4108
4109                    fct = [T.switch(switch.inputs[0], 0,
4110                                    fmul)]
4111                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
4112
4113                    # Copy over stacktrace for switch op from both previous
4114                    #  elementwise multiplication op and previous switch op,
4115                    # because an error in this part can be caused by either
4116                    # of the two previous ops.
4117                    copy_stack_trace(node.outputs + switch.outputs, fct)
4118                    return fct
4119            except NotScalarConstantError:
4120                pass
4121            try:
4122                if (get_scalar_constant_value(
4123                        switch.inputs[2], only_process_constants=True) == 0.):
4124                    listmul = node.inputs[:idx] + node.inputs[idx + 1:]
4125                    fmul = T.mul(*(listmul + [switch.inputs[1]]))
4126                    # Copy over stacktrace for elementwise multiplication op
4127                    # from previous elementwise multiplication op.
4128                    # An error in the multiplication (e.g. errors due to
4129                    # inconsistent shapes), will point to the
4130                    # multiplication op.
4131                    copy_stack_trace(node.outputs, fmul)
4132
4133                    fct = [T.switch(switch.inputs[0],
4134                                    fmul, 0)]
4135                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
4136
4137                    # Copy over stacktrace for switch op from both previous
4138                    # elementwise multiplication op and previous switch op,
4139                    # because an error in this part can be caused by either
4140                    # of the two previous ops.
4141                    copy_stack_trace(node.outputs + switch.outputs, fct)
4142                    return fct
4143            except NotScalarConstantError:
4144                pass
4145    return False
4146
4147
4148@register_canonicalize
4149@gof.local_optimizer([T.true_div, T.int_div])
4150def local_div_switch_sink(node):
4151    """
4152    This optimization makes the following changes in the graph:
4153    T.div(T.switch(cond,0,iff),A) -->  T.switch(cond,0,T.div(iff,A))
4154    T.div(T.switch(cond,ift,0),A) -->  T.switch(cond,T.div(ift,A),0)
4155
4156    A being a symbolic variable.
4157    This is useful because A may not be numerically stable and give
4158    NaN or inf values for cases where the switch returns 0.
4159    See local_mul_switch_sink for more details.
4160
4161    """
4162    if (node.op != T.true_div and node.op != T.int_div):
4163        return False
4164    op = node.op
4165    if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
4166        switch = node.inputs[0].owner
4167        try:
4168            if get_scalar_constant_value(switch.inputs[1],
4169                                         only_process_constants=True) == 0.:
4170                fdiv = op(switch.inputs[2], node.inputs[1])
4171                # Copy over stacktrace for elementwise division op
4172                # from previous elementwise multiplication op.
4173                # An error in the division (e.g. errors due to
4174                # inconsistent shapes or division by zero),
4175                # will point to the new division op.
4176                copy_stack_trace(node.outputs, fdiv)
4177
4178                fct = [T.switch(switch.inputs[0], 0,
4179                                fdiv)]
4180                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
4181
4182                # Copy over stacktrace for switch op from both previous
4183                # elementwise division op and previous switch op,
4184                # because an error in this part can be caused by either
4185                # of the two previous ops.
4186                copy_stack_trace(node.outputs + switch.outputs, fct)
4187                return fct
4188        except NotScalarConstantError:
4189            pass
4190        try:
4191            if get_scalar_constant_value(switch.inputs[2],
4192                                         only_process_constants=True) == 0.:
4193                fdiv = op(switch.inputs[1], node.inputs[1])
4194                # Copy over stacktrace for elementwise division op
4195                # from previous elementwise multiplication op.
4196                # An error in the division (e.g. errors due to
4197                # inconsistent shapes or division by zero),
4198                # will point to the new division op.
4199                copy_stack_trace(node.outputs, fdiv)
4200
4201                fct = [T.switch(switch.inputs[0],
4202                                fdiv, 0)]
4203                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
4204
4205                # Copy over stacktrace for switch op from both previous
4206                # elementwise division op and previous switch op,
4207                # because an error in this part can be caused by either
4208                # of the two previous ops.
4209                copy_stack_trace(node.outputs + switch.outputs, fct)
4210                return fct
4211        except NotScalarConstantError:
4212            pass
4213    return False
4214
4215
4216# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
4217# condition, to enable further simplification of their branches
4218# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
4219@register_canonicalize
4220@gof.local_optimizer([T.Elemwise])
4221def local_merge_switch_same_cond(node):
4222    scal = theano.scalar
4223    # node must be binary elemwise or add or mul
4224    if not isinstance(node.op, T.Elemwise) or not isinstance(
4225            node.op.scalar_op, (scal.BinaryScalarOp, scal.Add, scal.Mul)):
4226        return
4227    # all inputs must be switch
4228    if not all(s.owner and isinstance(s.owner.op, T.Elemwise) and
4229               isinstance(s.owner.op.scalar_op, scal.Switch)
4230               for s in node.inputs):
4231        return
4232    # all switch conditions must be the same
4233    cond = node.inputs[0].owner.inputs[0]
4234    if not all(s.owner.inputs[0] is cond for s in node.inputs[1:]):
4235        return
4236    # pull out switch
4237    return [T.switch(cond,
4238                     node.op(*[s.owner.inputs[1] for s in node.inputs]),
4239                     node.op(*[s.owner.inputs[2] for s in node.inputs]))]
4240
4241
4242#############
4243# Tile Opts #
4244#############
4245@register_useless
4246@register_canonicalize
4247@register_stabilize
4248@gof.local_optimizer([T.Tile])
4249def local_useless_tile(node):
4250    """Tile(x, (1,)*N) -> x
4251
4252    This is useless tile. (1,)*N, just mean a vector with all element
4253    being 1.
4254
4255    """
4256    if isinstance(node.op, T.Tile):
4257        try:
4258            a = T.get_scalar_constant_value(node.inputs[1],
4259                                            only_process_constants=True)
4260            if a == 1:
4261                try:
4262                    l = T.get_vector_length(node.inputs[1])
4263                    if l == node.inputs[0].ndim:
4264                        # No need to copy over any stacktrace as previous
4265                        # input variable already has a stacktrace
4266                        return [node.inputs[0]]
4267                    elif l < node.inputs[0].ndim:
4268                        # The Op don't support that case, so we can't
4269                        # implement the opt and test it.
4270                        return
4271                        return [node.inputs[0]]
4272                    else:
4273                        # The Op don't support that case, so we can't
4274                        # implement the opt and test it.
4275                        return
4276                        x_nd = node.inputs[0].ndim
4277                        broad = ['x'] * (l - x_nd) + xrange(x_nd)
4278                        ret = node.inputs[0].dimshuffle(broad)
4279                        # Copy over stacktrace from previous output node,
4280                        # and from node before tiling operation.
4281                        copy_stack_trace(node.outputs + node.inputs[0], ret)
4282                        return [ret]
4283                except ValueError:
4284                    return
4285        except NotScalarConstantError:
4286            return
4287
4288
4289##############
4290# Split Opts #
4291##############
4292@register_useless
4293@register_canonicalize
4294@register_specialize
4295@gof.local_optimizer([T.Split])
4296def local_useless_split(node):
4297    """ Split{n_splits=1}(x, y) -> x
4298
4299    Remove Split with only 1 split.
4300
4301    """
4302    if isinstance(node.op, T.Split):
4303        if node.op.len_splits == 1:
4304            x, axis, splits = node.inputs
4305            out = assert_op(x, T.eq(splits.shape[0], 1))
4306            # Copy over stacktrace from previous output node.
4307            copy_stack_trace(node.outputs, out)
4308            out2 = assert_op(out, T.eq(x.shape[axis], splits[0]))
4309            # Copy over stacktrace from previous output node.
4310            copy_stack_trace(out, out2)
4311
4312            return [out2]
4313
4314
4315################
4316# Flatten Opts #
4317################
4318@register_canonicalize
4319@register_stabilize
4320@gof.local_optimizer([T.Flatten])
4321def local_flatten_lift(node):
4322    """
4323    Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
4324
4325    This optimization is needed by optimization
4326    nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
4327
4328    """
4329    if (isinstance(node.op, T.Flatten) and
4330            node.inputs[0].owner and
4331            isinstance(node.inputs[0].owner.op, T.Elemwise) and
4332            len(node.inputs[0].owner.inputs) == 1):
4333        f = node.op(node.inputs[0].owner.inputs[0])
4334
4335        # Copy over stacktrace from previous output node (flatten op),
4336        # since this is the op which may cause an error for f.
4337        copy_stack_trace(node.outputs, f)
4338
4339        e = node.inputs[0].owner.op(f)
4340
4341        # Copy over stacktrace from previous output node and from unary
4342        # elementwise output node since if there was an error, it would
4343        # probably have come from that operation.
4344        copy_stack_trace(node.outputs + [node.inputs[0]], e)
4345
4346        return [e]
4347
4348##################
4349# Reshape opts   #
4350##################
4351
4352
4353def local_reshape_chain(op):
4354    @gof.local_optimizer([op])
4355    def f(node):
4356        """
4357        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
4358
4359        """
4360        if not opt.check_chain(node, op, op):
4361            return False
4362
4363        # TODO: this can permit a failing program to run by eliminating
4364        #       the lower reshape
4365        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
4366
4367        # Copy over stacktrace from previous output node, as any error
4368        # in new computational graph would have been caused by last op
4369        # in the old computational graph.
4370        copy_stack_trace(node.outputs, rval)
4371
4372        # It might happen that the desired output of this node has a
4373        # broadcastable pattern that does not match that of 'rval'. This is
4374        # when originally, we were able to figure out that one of the
4375        # dimensions of the reshape is one, but some other transformation
4376        # replaced the shape by one for which this cannot be guessed.
4377        # We should try to figure out why we lost the information about this
4378        # constant value... but in the meantime, better not apply this
4379        # optimization.
4380        if rval.broadcastable == node.outputs[0].broadcastable:
4381            return [rval]
4382        else:
4383            return False
4384
4385    return f
4386register_canonicalize(local_reshape_chain(T.Reshape),
4387                      name='local_reshape_chain')
4388
4389
4390@register_useless
4391@register_canonicalize
4392@register_stabilize
4393@gof.local_optimizer([T.Reshape])
4394def local_useless_reshape(node):
4395    """
4396    Remove two kinds of useless reshape.
4397
4398    Remove Reshape when both the input and output have a single dimension.
4399    Remove Reshape when reshaping to the shape of the input.
4400
4401    """
4402    op = node.op
4403    if not isinstance(op, Reshape):
4404        return False
4405
4406    input = node.inputs[0]
4407    output = node.outputs[0]
4408    output_shape = node.inputs[1]
4409
4410    if input.ndim != output.ndim:
4411        return False
4412
4413    # Simple case: both input and output have a single dimension.
4414    # This could hide errors if the user provides inconsistent shapes.
4415    if (input.ndim == 1 and output.ndim == 1 and
4416            input.broadcastable == output.broadcastable):
4417        return [input]
4418
4419    # Second case: all the shapes match the input shape
4420    # Match Reshape(x, x.shape)
4421    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
4422        shape_input = output_shape.owner.inputs[0]
4423        if shape_input == input:
4424            return [input]
4425
4426    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
4427    # broadcastable and constant dimensions
4428    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
4429        output_shape_is = output_shape.owner.inputs
4430
4431        if not hasattr(node, 'fgraph'):
4432            shape_feature = None
4433        else:
4434            shape_feature = getattr(node.fgraph, 'shape_feature', None)
4435
4436        nb_m1 = 0
4437        shape_match = [False] * input.ndim
4438        for dim in xrange(input.ndim):
4439            outshp_i = output_shape_is[dim]
4440            # Match Shape_i{dim}(input)
4441            if (outshp_i.owner and isinstance(outshp_i.owner.op, Shape_i) and
4442                    outshp_i.owner.op.i == dim and
4443                    outshp_i.owner.inputs[0] == input):
4444                shape_match[dim] = True
4445                continue
4446
4447            # Match Shape(input)[dim]
4448            if (outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and
4449                    len(outshp_i.owner.inputs) == 2 and
4450                    extract_constant(outshp_i.owner.inputs[1]) == dim):
4451                subtensor_inp = outshp_i.owner.inputs[0]
4452                if (subtensor_inp.owner and
4453                        isinstance(subtensor_inp.owner.op, Shape)):
4454                    shape_input_i = subtensor_inp.owner.inputs[0]
4455                    if shape_input_i == input:
4456                        shape_match[dim] = True
4457                        continue
4458
4459            # Match 1 if input.broadcastable[dim] is True
4460            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
4461            if input.broadcastable[dim] and cst_outshp_i == 1:
4462                shape_match[dim] = True
4463                continue
4464
4465            # Match -1
4466            if cst_outshp_i == -1:
4467                shape_match[dim] = True
4468                nb_m1 += 1
4469                continue
4470
4471            # Match shape_of[input][dim] or its constant equivalent
4472            if shape_feature:
4473                inpshp_i = shape_feature.get_shape(input, dim)
4474                if (inpshp_i == outshp_i or
4475                    (extract_constant(inpshp_i, only_process_constants=1) ==
4476                     extract_constant(outshp_i, only_process_constants=1))):
4477                    shape_match[dim] = True
4478                    continue
4479
4480        if all(shape_match) and nb_m1 <= 1:
4481            return [input]
4482
4483        # TODO later: if all the shapes except one match, we may want to
4484        # consider it useless as well, like we do in the 1-dim case.
4485
4486
4487@register_canonicalize
4488@gof.local_optimizer([T.Reshape])
4489def local_reshape_to_dimshuffle(node):
4490    """
4491    Broadcastable dimensions in Reshape are replaced with dimshuffle.
4492
4493    The goal is to avoid using reshape to add or remove broadcastable
4494    dimensions, but use dimshuffle instead, so dimshuffles can cancel out
4495    or be removed later on.
4496
4497    For example:
4498        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
4499        - reshape(x, (1, m, 1, n, 1, 1))
4500          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
4501    """
4502    op = node.op
4503    if not isinstance(op, Reshape):
4504        return False
4505
4506    input = node.inputs[0]
4507    output = node.outputs[0]
4508    output_shape = node.inputs[1]
4509
4510    dimshuffle_new_order = []
4511    new_output_shape = []
4512    index = 0  # index over the output of the new reshape
4513    for i in xrange(output.ndim):
4514        # Since output_shape is a symbolic vector, we trust extract_constant
4515        # to go through however it is formed to see if its i-th element is 1.
4516        # We need only_process_constants=False for that.
4517        dim = extract_constant(output_shape[i], only_process_constants=False,
4518                               elemwise=False)
4519        if dim == 1:
4520            dimshuffle_new_order.append('x')
4521        else:
4522            dimshuffle_new_order.append(index)
4523            new_output_shape.append(dim)
4524            index = index + 1
4525    if index != output.ndim:
4526        inner = op.__class__(len(new_output_shape))(input, new_output_shape)
4527        copy_stack_trace(output, inner)
4528        new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
4529        copy_stack_trace(output, new_node)
4530        return new_node
4531
4532
4533@register_canonicalize
4534@register_stabilize
4535@gof.local_optimizer([T.Reshape])
4536def local_reshape_lift(node):
4537    """
4538    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
4539
4540    This optimization is needed by optimization
4541    nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
4542
4543    """
4544    if (isinstance(node.op, T.Reshape) and
4545            node.inputs[0].owner and
4546            isinstance(node.inputs[0].owner.op, T.Elemwise) and
4547            len(node.inputs[0].owner.inputs) == 1):
4548        r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
4549        # Copy stacktrace from previous Reshape op, as an error in new
4550        # Reshape op could only have been caused by old one.
4551        copy_stack_trace(node.outputs, r)
4552
4553        e = node.inputs[0].owner.op(r)
4554        # Copy stacktrace from both previous Reshape and UnaryElemwise op
4555        # because an error in new cg could have been caused by either ops.
4556        copy_stack_trace(node.outputs + node.inputs, e)
4557
4558        # In rare case the original broadcast was (False, True), but
4559        # the new one is (False, False). So don't crash in that case.
4560        if e.type != node.outputs[0].type:
4561            re = T.patternbroadcast(e, node.outputs[0].broadcastable)
4562
4563            # Copy over stack trace.
4564            # If the graph fails it is usually due to the fact that a dimension
4565            # that should be broadcastable does not actually have length 1,
4566            copy_stack_trace(e, re)
4567        else:
4568            re = e
4569
4570        return [re]
4571
4572
4573##################
4574# Middleman cuts #
4575##################
4576
4577register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
4578
4579################
4580# Canonization #
4581################
4582
4583
4584class Canonizer(gof.LocalOptimizer):
4585    r"""
4586    Simplification tool. The variable is a local_optimizer. It is best used
4587    with a TopoOptimizer in in_to_out order.
4588
4589    Usage: Canonizer(main, inverse, reciprocal, calculate)
4590
4591    Parameters
4592    ----------
4593    main
4594        A suitable Op class that is commutative, associative and
4595        takes one to an arbitrary number of inputs, e.g. add or
4596        mul
4597    inverse
4598        An Op class such that inverse(main(x, y), y) == x
4599        e.g. sub or true_div
4600    reciprocal
4601        A function such that main(x, reciprocal(y)) == inverse(x, y)
4602        e.g. neg or inv
4603    calculate
4604        Function that takes a list of numpy.ndarray instances
4605        for the numerator, another list for the denumerator,
4606        and calculates inverse(main(\*num), main(\*denum)). It
4607        takes a keyword argument, aslist. If True, the value
4608        should be returned as a list of one element, unless
4609        the value is such that value = main(). In that case,
4610        the return value should be an empty list.
4611
4612    Examples
4613    --------
4614    >>> import theano.tensor as T
4615    >>> from theano.tensor.opt import Canonizer
4616    >>> add_canonizer = Canonizer(T.add, T.sub, T.neg, \
4617    ...                           lambda n, d: sum(n) - sum(d))
4618    >>> mul_canonizer = Canonizer(T.mul, T.true_div, T.inv, \
4619    ...                           lambda n, d: prod(n) / prod(d))
4620
4621    Examples of optimizations mul_canonizer can perform:
4622
4623    | x / x -> 1
4624    | (x * y) / x -> y
4625    | x / y / x -> 1 / y
4626    | x / y / z -> x / (y * z)
4627    | x / (y / z) -> (x * z) / y
4628    | (a / b) * (b / c) * (c / d) -> a / d
4629    | (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
4630    | 2 * x / 2 -> x
4631    | x * y * z -> Elemwise(T.mul){x,y,z} #only one pass over the memory.
4632    |           !-> Elemwise(T.mul){x,Elemwise(T.mul){y,z}}
4633
4634    """
4635
4636    def __init__(self, main, inverse, reciprocal, calculate,
4637                 use_reciprocal=True):
4638        self.main = main
4639        self.inverse = inverse
4640        self.reciprocal = reciprocal
4641        self.calculate = calculate
4642        self.use_reciprocal = use_reciprocal
4643
4644        self.external_simplifiers = []
4645
4646    def add_simplifier(self, simplifier, reason):
4647        self.external_simplifiers.append((reason, simplifier))
4648
4649    def tracks(self):
4650        return [self.main, self.inverse, self.reciprocal]
4651
4652    def get_num_denum(self, input):
4653        r"""
4654        This extract two lists, num and denum, such that the input is:
4655        self.inverse(self.main(\*num), self.main(\*denum)). It returns
4656        the two lists in a (num, denum) pair.
4657
4658        For example, for main, inverse and reciprocal = \*, / and inv(),
4659
4660        | input -> returned value (num, denum)
4661
4662        | x*y -> ([x, y], [])
4663        | inv(x) -> ([], [x])
4664        | inv(x) * inv(y) -> ([], [x, y])
4665        | x*y/z -> ([x, y], [z])
4666        | log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
4667        | (((a / b) * c) / d) -> ([a, c], [b, d])
4668        | a / (b / c) -> ([a, c], [b])
4669        | log(x) -> ([log(x)], [])
4670        | x**y -> ([x**y], [])
4671        | x * y * z -> ([x, y, z], [])
4672
4673        """
4674        # This function is recursive.  The idea is that there is a
4675        # get_num_denum recursion in which the internal ops are all
4676        # one of (main, inverse, reciprocal, DimShuffle) and the
4677        # internal data nodes all have the dtype of the 'input'
4678        # argument. The leaf-Variables of the graph covered by the
4679        # recursion may be of any Variable type.
4680
4681        if input.owner is None or input.owner.op not in [
4682                self.main, self.inverse, self.reciprocal]:
4683            if input.owner and isinstance(input.owner.op, T.DimShuffle):
4684                # If input is a DimShuffle of some input which does
4685                # something like this:
4686
4687                # * change a vector of length N into a 1xN row matrix
4688                # * change a scalar into a 1x1x1 tensor
4689                # * in general, complete the shape of a tensor
4690                #   with broadcastable 1s to the *left*
4691                # Then we will simply discard the DimShuffle and return
4692                # the num/denum of its input
4693                dsn = input.owner    # dimshuffle node
4694                dsop = dsn.op        # dimshuffle op
4695
4696                # the first input of the dimshuffle i.e. the ndarray to redim
4697                dsi0 = dsn.inputs[0]
4698
4699                # The compatible order is a DimShuffle "new_order" of the form:
4700                # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
4701
4702                # That kind of DimShuffle only adds broadcastable
4703                # dimensions on the left, without discarding any
4704                # existing broadcastable dimension and is inserted
4705                # automatically by Elemwise when the inputs have
4706                # different numbers of dimensions (hence why we can
4707                # discard its information - we know we can retrieve it
4708                # later on).
4709                compatible_order = (('x',) *
4710                                    (input.type.ndim - dsi0.type.ndim) +
4711                                    tuple(range(dsi0.type.ndim)))
4712                if dsop.new_order == compatible_order:
4713                    # If the "new_order" is the one we recognize,
4714                    # we return the num_denum of the dimshuffled input.
4715                    return self.get_num_denum(input.owner.inputs[0])
4716                else:
4717                    # This is when the input isn't produced by main,
4718                    # inverse or reciprocal.
4719                    return [input], []
4720            else:
4721                return [input], []
4722        num = []
4723        denum = []
4724        parent = input.owner
4725
4726        # We get the (num, denum) pairs for each input
4727        # pairs = [self.get_num_denum(input2) if input2.type.dtype ==
4728        # input.type.dtype else ([input2], []) for input2 in
4729        # parent.inputs]
4730        pairs = [self.get_num_denum(input2) for input2 in parent.inputs]
4731
4732        if parent.op == self.main:
4733            # If we have main(x, y, ...), numx, denumx, numy, denumy, ...
4734            # then num is concat(numx, numy, num...) and denum is
4735            # concat(denumx, denumy, denum...) note that main() can have any
4736            # number of arguments >= 0 concat is list concatenation
4737            num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
4738            denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
4739        elif parent.op == self.inverse:
4740            # If we have inverse(x, y), numx, denumx, numy and denumy
4741            # then num is concat(numx, denumy) and denum is
4742            # concat(denumx, numy) note that inverse() is binary
4743            num = pairs[0][0] + pairs[1][1]
4744            denum = pairs[0][1] + pairs[1][0]
4745        elif parent.op == self.reciprocal:
4746            # If we have reciprocal(x), numx, denumx
4747            # then num is denumx and denum is numx
4748            # note that reciprocal() is unary
4749            num = pairs[0][1]
4750            denum = pairs[0][0]
4751        return num, denum
4752
4753    def merge_num_denum(self, num, denum):
4754        r"""
4755        Utility function which takes two lists, num and denum, and
4756        returns something which is equivalent to inverse(main(\*num),
4757        main(\*denum)), but depends on the length of num and the length
4758        of denum (in order to minimize the number of operations).
4759
4760        Let n = len(num) and d = len(denum):
4761
4762        | n=0, d=0: neutral element (given by self.calculate([], []))
4763        |           (for example, this would be 0 if main is addition
4764        |           and 1 if main is multiplication)
4765        | n=1, d=0: num[0]
4766        | n=0, d=1: reciprocal(denum[0])
4767        | n=1, d=1: inverse(num[0], denum[0])
4768        | n=0, d>1: reciprocal(main(\*denum))
4769        | n>1, d=0: main(\*num)
4770        | n=1, d>1: inverse(num[0], main(\*denum))
4771        | n>1, d=1: inverse(main(\*num), denum[0])
4772        | n>1, d>1: inverse(main(\*num), main(\*denum))
4773
4774        Given the values of n and d to which they are associated, all
4775        of the above are equivalent to:
4776        inverse(main(\*num), main(\*denum))
4777
4778        """
4779
4780        ln, ld = len(num), len(denum)
4781        if not ln and not ld:
4782            return T.as_tensor_variable(self.calculate([], []))
4783        if not ln:
4784            if self.use_reciprocal:
4785                return self.reciprocal(self.merge_num_denum(denum, []))
4786            else:
4787                ln = [self.calculate([], [], aslist=False)]
4788        if not ld:
4789            if ln == 1:
4790                # num[0] should always be a variable
4791                assert isinstance(num[0], gof.Variable)
4792                return num[0]
4793            else:
4794                return self.main(*num)
4795        return self.inverse(self.merge_num_denum(num, []),
4796                            self.merge_num_denum(denum, []))
4797
4798    @staticmethod
4799    def get_constant(v):
4800        """
4801
4802        Returns
4803        -------
4804        object
4805            A numeric constant if v is a Constant or, well, a
4806            numeric constant. If v is a plain Variable, returns None.
4807
4808        """
4809        if isinstance(v, Constant):
4810            if getattr(v.tag, 'unique_value', None) is not None:
4811                data = v.tag.unique_value
4812            else:
4813                data = v.data
4814            if data.ndim == 0:
4815                return data
4816            else:
4817                return None
4818        elif isinstance(v, Variable):
4819            return None
4820        else:
4821            return v
4822
4823    def simplify(self, num, denum, out_type):
4824        """
4825        Shorthand for:
4826
4827        .. code-block:: python
4828
4829            self.simplify_constants(*self.simplify_factors(num, denum))
4830
4831        """
4832        rval = self.simplify_constants(*self.simplify_factors(num, denum),
4833                                       out_type=out_type)
4834        for reason, simplifier in self.external_simplifiers:
4835            # TODO: document that 'reason' is associated with this
4836            #       simplification to help auditing when things go
4837            #       wrong
4838            rval = simplifier(*rval)
4839        return rval
4840
4841    def simplify_factors(self, num, denum):
4842        """
4843        For any Variable r which is both in num and denum, removes it
4844        from both lists. Modifies the lists inplace. Returns the
4845        modified lists. For example:
4846
4847        | [x], [x] -> [], []
4848        | [x, y], [x] -> [y], []
4849        | [a, b], [c, d] -> [a, b], [c, d]
4850
4851        """
4852        ln = len(num)
4853        ld = len(denum)
4854        if (ld > 2 and ln > 2):
4855            # Faster version for "big" inputs.
4856            while True:
4857                s = set(num)
4858                # Inputs can appear multiple times
4859                redo = len(s) != len(num)
4860                inter = s.intersection(denum)
4861                for v in inter:
4862                    num.remove(v)
4863                    denum.remove(v)
4864                if not redo or not inter:
4865                    break
4866        else:
4867            for v in list(num):
4868                if v in denum:
4869                    num.remove(v)
4870                    denum.remove(v)
4871        return num, denum
4872
4873    def simplify_constants(self, orig_num, orig_denum, out_type=None):
4874        """
4875        Find all constants and put them together into a single constant.
4876
4877        Finds all constants in orig_num and orig_denum (using
4878        get_constant) and puts them together into a single
4879        constant. The constant is inserted as the first element of the
4880        numerator. If the constant is the neutral element, it is
4881        removed from the numerator.
4882
4883        Examples
4884        --------
4885        Let main be multiplication:
4886
4887        | [2, 3, x], [] -> [6, x], []
4888        | [x, y, 2], [4, z] -> [0.5, x, y], [z]
4889        | [x, 2, y], [z, 2] -> [x, y], [z]
4890
4891        """
4892        # Lists representing the numerator and denumerator
4893        num, denum = [], []
4894
4895        # Lists representing the *constant* elements of num and denum
4896        numct, denumct = [], []
4897
4898        for v in orig_num:
4899            ct = self.get_constant(v)
4900            if ct is not None:
4901                # We found a constant in the numerator!
4902                # We add it to numct
4903                numct.append(ct)
4904            else:
4905                num.append(v)
4906        for v in orig_denum:
4907            ct = self.get_constant(v)
4908            if ct is not None:
4909                denumct.append(ct)
4910            else:
4911                denum.append(v)
4912
4913        if self.use_reciprocal or num:
4914            # This will calculate either:
4915            # [inverse(main(*numct), main(*denumct))]
4916            # [] - if inverse(main(*numct), main(*denumct)) is the
4917            # neutral element
4918            ct = self.calculate(numct, denumct, aslist=True,
4919                                out_type=out_type)
4920        else:
4921            # This happens if we don't allow the reciprocal and the
4922            # numerator is empty. That means we will need to represent
4923            # reciprocal(x) like inverse(neutral_element, x) so
4924            # we can't allow ct == []
4925            # TODO: why is this branch needed when merge_num_denum
4926            # does it for us?
4927            ct = [self.calculate(numct, denumct, aslist=False,
4928                                 out_type=out_type)]
4929
4930        # Wrapping ct in a Constant with the right dtype
4931        ct = [T.constant(c, dtype=out_type.dtype) for c in ct]
4932
4933        if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
4934            # In that case we should only have one constant in `ct`.
4935            assert len(ct) == 1
4936            first_num_ct = self.get_constant(orig_num[0])
4937            if first_num_ct is not None and ct[0].type.values_eq(ct[0].data,
4938                                                                 first_num_ct):
4939                # This is an important trick :( if it so happens that:
4940                # * there's exactly one constant on the numerator and none on
4941                #   the denominator
4942                # * it's not the neutral element (ct is an empty list in that
4943                #   case)
4944                # * the constant is the same as the first argument in the
4945                #   numerator (we only check the first argument because the
4946                #   canonizer puts the computed constants first)
4947                # -> then we return very exactly the original num/denum.
4948                # If we don't do that the optimizer will just loop
4949                # infinitely because it will not catch on that there are
4950                # no changes to be made and every time it will want to
4951                # replace something by the same thing...
4952                # Note that it is important to use `values_eq` instead of
4953                # the == operator, to handle NaN values correctly.
4954                return orig_num, orig_denum
4955
4956        return ct + num, denum
4957
4958    def transform(self, node):
4959        op = node.op
4960        if op not in [self.main, self.inverse, self.reciprocal]:
4961            return False
4962
4963        assert len(node.outputs) == 1
4964        out = node.outputs[0]
4965
4966        # out won't have a clients field when we didn't commit a
4967        # started change in the graph.  We can't do the check if we
4968        # want to skip it, so we force the skip it. It should be
4969        # reapplied later.
4970        if not hasattr(out, 'clients'):
4971            return
4972
4973        # check if any of the clients of this node would be part of
4974        # this canonized graph...  if so, we do nothing and wait for
4975        # them to be transformed.
4976        for c, c_idx in out.clients:
4977            if c == 'output':
4978                continue
4979            while (isinstance(getattr(c, 'op', None), DimShuffle) and
4980                   len(c.outputs[0].clients) <= 1):
4981                c = c.outputs[0].clients[0][0]
4982            if getattr(c, 'op', '') in [self.main, self.inverse,
4983                                        self.reciprocal]:
4984                return False
4985
4986        # Here we make the canonical version of the graph around this node
4987        # See the documentation of get_num_denum and simplify
4988        orig_num, orig_denum = self.get_num_denum(node.outputs[0])
4989        num, denum = self.simplify(list(orig_num), list(orig_denum), out.type)
4990
4991        def same(x, y):
4992            return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in
4993                                            zip(x, y))
4994
4995        if same(orig_num, num) and same(orig_denum, denum):
4996            # We return False if there are no changes
4997            return False
4998
4999        new = self.merge_num_denum(num, denum)
5000        if new.type.dtype != out.type.dtype:
5001            new = T.cast(new, out.type.dtype)
5002
5003        assert (new.type == out.type) == (not (new.type != out.type))
5004
5005        if not (new.type == out.type):
5006            new = _fill_chain(new, node.inputs)[0]
5007
5008        if new.type == out.type:
5009            # This happen with test
5010            # theano/tensor/tests/test_opt.py:T_local_switch_sink
5011            new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
5012
5013            # We need to implement the copy over of the stacktrace.
5014            # See issue #5104.
5015            return [new]
5016        else:
5017            _logger.warning(' '.join(('CANONIZE FAILED: new, out = ',
5018                                      new, ',', out, 'types',
5019                                      new.type, ',', out.type)))
5020            return False
5021
5022    def __str__(self):
5023        return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (
5024            self.main, self.inverse, self.reciprocal))
5025
5026
5027def mul_calculate(num, denum, aslist=False, out_type=None):
5028    if not num and not denum:
5029        # Smallest 1 possible.
5030        if aslist:
5031            return []
5032        else:
5033            return np.int8(1)
5034
5035    # Make sure we do not accidentally upcast data types.
5036    if out_type is None:
5037        out_dtype = scalar.upcast(*[v.dtype for v in (num + denum)])
5038    else:
5039        out_dtype = out_type.dtype
5040    one = theano._asarray(1, dtype=out_dtype)
5041
5042    v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one)
5043    if aslist:
5044        if np.all(v == 1):
5045            return []
5046        else:
5047            return [v]
5048    return v
5049
5050local_mul_canonizer = Canonizer(T.mul, T.true_div, T.inv, mul_calculate, False)
5051register_canonicalize(local_mul_canonizer, name='local_mul_canonizer')
5052
5053
5054@gof.local_optimizer([T.neg])
5055def local_neg_to_mul(node):
5056    if node.op == T.neg:
5057        return [T.mul(np.array(-1, dtype=node.inputs[0].dtype),
5058                node.inputs[0])]
5059register_canonicalize(local_neg_to_mul)
5060
5061
5062@register_specialize
5063@gof.local_optimizer([T.Sum, T.elemwise.Prod])
5064def local_sum_prod_mul_by_scalar(node):
5065    """
5066    sum(scalar * smth) -> scalar * sum(smth)
5067    sum(-smth) -> -sum(smth)
5068
5069    or
5070
5071    prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
5072    prod(-smth) -> -1 ** size(smth) * prod(smth)
5073
5074    """
5075    # TODO: if the the thing inside the Sum is a division,
5076    # we should get at the numerator....
5077    if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
5078        node_inps, = node.inputs
5079        if node_inps.owner and node_inps.owner.op == T.mul:
5080            terms = node_inps.owner.inputs
5081            scalars = [t.dimshuffle() for t in terms if
5082                       np.all(t.type.broadcastable)]
5083
5084            if len(scalars) == 0:
5085                # Nothing to optimize here
5086                return
5087
5088            non_scalars = [t for t in terms if not np.all(t.broadcastable)]
5089
5090            # Perform the op only on the non-scalar inputs, if applicable
5091            if len(non_scalars) == 0:
5092                new_op_input_nb_elements = 1
5093                new_op_output = 1
5094            elif len(non_scalars) == 1:
5095                new_op_input_nb_elements = non_scalars[0].size
5096                new_op_output = node.op(non_scalars[0])
5097            else:
5098                new_op_input = T.mul(*non_scalars)
5099                # We assume that errors always come from the prod/mul op in the
5100                # original computational graph, and therefore need to only
5101                # copy over its output stacktrace.
5102                copy_stack_trace(node.outputs, new_op_input)
5103
5104                new_op_input_nb_elements = new_op_input.size
5105                new_op_output = node.op(new_op_input)
5106
5107            if not len(non_scalars) == 0:
5108                # Copy over stacktrace from previous output to new mul op,
5109                # for same reason as above.
5110                copy_stack_trace(node.outputs, new_op_output)
5111
5112            # If node.op is a T.elemwise.Prod, then the scalars need to be
5113            # raised to the power of the number of elements in the input
5114            # to the Prod
5115            if (isinstance(node.op, T.elemwise.Prod) and
5116                    new_op_input_nb_elements != 1):
5117
5118                scalars = [s ** new_op_input_nb_elements for s in scalars]
5119
5120            # Scale the output of the op by the scalars and return as
5121            # replacement for the original output
5122            mul_inputs = scalars
5123            if new_op_input_nb_elements != 1:
5124                mul_inputs.append(new_op_output)
5125
5126            if len(mul_inputs) == 1:
5127                # Copy over stacktrace from previous output to new mul op,
5128                # for same reason as above.
5129                copy_stack_trace(node.outputs, mul_inputs)
5130
5131                return mul_inputs
5132            else:
5133                ret = T.mul(*mul_inputs)
5134                # Copy over stacktrace from previous output to new mul op,
5135                # for same reason as above.
5136                copy_stack_trace(node.outputs, [ret] + mul_inputs)
5137
5138                return [ret]
5139
5140        if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
5141            s = node.op(node_inps.owner.inputs[0])
5142            ret = T.neg(s)
5143            # There are never errors in the negative op, thus
5144            # we need only to copy over stacktrace from previous output node to
5145            # the two new ops.
5146            copy_stack_trace(node.outputs, [s, ret])
5147
5148            return [ret]
5149
5150
5151@register_specialize
5152@gof.local_optimizer([T.Elemwise])
5153def local_elemwise_sub_zeros(node):
5154    """
5155    Elemwise{sub}(X,X) -> zeros_like(X)
5156    """
5157    if (isinstance(node.op, T.Elemwise) and
5158            node.op.scalar_op.nin == 2 and
5159            node.op.scalar_op == scalar.sub and
5160            node.inputs[0] == node.inputs[1]):
5161        res = T.zeros_like(node.inputs[0])
5162        # Copy over stacktrace from previous output.
5163        # This could help for failures due to out-of-memory.
5164        copy_stack_trace(node.outputs, res)
5165        return [res]
5166
5167
5168@register_useless
5169@register_specialize
5170@register_stabilize
5171@register_canonicalize
5172@gof.local_optimizer([T.Elemwise])
5173def local_useless_elemwise_comparison(node):
5174    """...
5175
5176    :note: These cases appear in the graph generated by scan.
5177           These optimizations will make the graph easier to read.
5178    # Comparing to itself is constant
5179    Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
5180    Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
5181    Elemwise[{minimum,maximum}](X, X) -> X
5182
5183    # Comparing shape to 0 can be constant
5184    Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
5185    Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
5186    Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
5187    Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
5188    Elemwise[minimum](X.shape[i], 0) -> 0
5189    Elemwise[minimum](0, X.shape[i]) -> 0
5190
5191    # The shape can be replaced with sum of shapes
5192    Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
5193    Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
5194
5195    # Shapes are never negative
5196    # Needed by Reshape.infer_shape
5197    Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X)
5198
5199    """
5200    if not isinstance(node.op, T.Elemwise):
5201        return
5202    if node.op.scalar_op.nin != 2:
5203        return
5204
5205    # We call zeros_like and one_like with opt=True to generate a
5206    # cleaner graph.
5207    dtype = node.outputs[0].dtype
5208
5209    # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
5210    if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
5211       node.inputs[0] is node.inputs[1]:
5212        res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
5213        # Copy over stacktrace from previous output.
5214        copy_stack_trace(node.outputs, res)
5215        return [res]
5216    # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
5217    if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
5218       node.inputs[0] is node.inputs[1]:
5219        res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
5220
5221        # Copy over stacktrace from previous output.
5222        copy_stack_trace(node.outputs, res)
5223        return [res]
5224    # Elemwise[{minimum,maximum}](X, X) -> X
5225    if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
5226       node.inputs[0] is node.inputs[1]:
5227        res = node.inputs[0]
5228        # Copy over stacktrace from previous output.
5229        copy_stack_trace(node.outputs, res)
5230        return [res]
5231
5232    # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
5233    if isinstance(node.op.scalar_op, scalar.LT) and \
5234       node.inputs[0].owner and \
5235       isinstance(node.inputs[0].owner.op, Shape_i) and \
5236       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5237        res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
5238        # Copy over stacktrace from previous output.
5239        copy_stack_trace(node.outputs, res)
5240        return [res]
5241    # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
5242    if isinstance(node.op.scalar_op, scalar.GE) and \
5243       node.inputs[0].owner and \
5244       isinstance(node.inputs[0].owner.op, Shape_i) and \
5245       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5246        res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
5247        # Copy over stacktrace from previous output.
5248        copy_stack_trace(node.outputs, res)
5249        return [res]
5250    # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
5251    if isinstance(node.op.scalar_op, scalar.Maximum) and \
5252       node.inputs[0].owner and \
5253       isinstance(node.inputs[0].owner.op, Shape_i) and \
5254       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5255        # No need to copy over stacktrace.
5256        return [node.inputs[0]]
5257    # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
5258    if isinstance(node.op.scalar_op, scalar.Maximum) and \
5259       T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
5260       node.inputs[1].owner and \
5261       isinstance(node.inputs[1].owner.op, Shape_i):
5262        # No need to copy over stacktrace.
5263        return [node.inputs[1]]
5264    # Elemwise[minimum](X.shape[i], 0) -> 0
5265    if isinstance(node.op.scalar_op, scalar.Minimum) and \
5266       node.inputs[0].owner and \
5267       isinstance(node.inputs[0].owner.op, Shape_i) and \
5268       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5269        res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
5270        # Copy over stacktrace from previous output.
5271        copy_stack_trace(node.outputs, res)
5272        return [res]
5273
5274    # Elemwise[minimum](0, X.shape[i]) -> 0
5275    if isinstance(node.op.scalar_op, scalar.Minimum) and \
5276       T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
5277       node.inputs[1].owner and \
5278       isinstance(node.inputs[1].owner.op, Shape_i):
5279        res = T.zeros_like(node.inputs[1], dtype=dtype, opt=True)
5280        # Copy over stacktrace from previous output.
5281        copy_stack_trace(node.outputs, res)
5282        return [res]
5283
5284    # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
5285    if isinstance(node.op.scalar_op, scalar.LT) and \
5286       node.inputs[0].owner and \
5287       isinstance(node.inputs[0].owner.op, Elemwise) and \
5288       isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
5289       all([isinstance(var.owner and var.owner.op, Shape_i)
5290            for var in node.inputs[0].owner.inputs]) and \
5291       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5292        res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
5293        # Copy over stacktrace from previous output.
5294        copy_stack_trace(node.outputs, res)
5295        return [res]
5296    # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
5297    if isinstance(node.op.scalar_op, scalar.GE) and \
5298       node.inputs[0].owner and \
5299       isinstance(node.inputs[0].owner.op, Elemwise) and \
5300       isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
5301       all([isinstance(var.owner and var.owner.op, Shape_i)
5302            for var in node.inputs[0].owner.inputs]) and \
5303       T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
5304        res = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
5305
5306        # Copy over stacktrace from previous output.
5307        copy_stack_trace(node.outputs, res)
5308        return [res]
5309
5310    # Elemwise[EQ](Subtensor(Shape(x)), -N)
5311    # Elemwise[EQ](somegraph that only depend of shape, -N)
5312    # TODO: handle the case where the -N is on either side
5313        """
5314 |Elemwise{eq,no_inplace} [id B] ''
5315 | |Subtensor{int64} [id C] ''
5316 | | |Join [id D] ''
5317 | | | |TensorConstant{0} [id E]
5318 | | | |Subtensor{int64:int64:} [id F] ''
5319 | | | | |Shape [id G] ''
5320        """
5321    def investigate(node):
5322        " Return True if values will be shapes, so >= 0"
5323        if isinstance(node.op, (T.Shape, Shape_i)):
5324            return True
5325        elif isinstance(node.op, Subtensor) and node.inputs[0].owner:
5326            return investigate(node.inputs[0].owner)
5327        elif isinstance(node.op, T.Join):
5328            return all(v.owner and
5329                       investigate(v.owner) for v in node.inputs[1:])
5330        elif isinstance(node.op, MakeVector):
5331            return all(v.owner and
5332                       investigate(v.owner) for v in node.inputs)
5333
5334    if (isinstance(node.op.scalar_op, scalar.EQ) and
5335            node.inputs[0].owner and
5336            investigate(node.inputs[0].owner)):
5337        try:
5338            cst = get_scalar_constant_value(node.inputs[1],
5339                                            only_process_constants=True)
5340
5341            res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
5342
5343            if cst < 0:
5344                # Copy over stacktrace from previous output.
5345                copy_stack_trace(node.outputs, res)
5346
5347                return [res]
5348
5349        except NotScalarConstantError:
5350            pass
5351    return
5352
5353
5354@register_canonicalize
5355@register_specialize
5356@gof.local_optimizer([T.Sum, T.elemwise.Prod])
5357def local_sum_prod_div_dimshuffle(node):
5358    """
5359    sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
5360    if dimension l of the DimShuffle is 'x'
5361
5362    or
5363
5364    prod(a / dimshuffle{...}(b), axis=l) ->
5365    prod(a, axis={...}) / b ** a.shape[l],
5366    if dimension l of the DimShuffle is 'x'
5367    """
5368
5369    # It does not make much sense now to extend it to the case where the
5370    # dimshuffle is in the numerator, since elemwise inversion of the
5371    # denominator would still be needed before the summation or production.
5372
5373    if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
5374        axis = node.op.axis
5375        if axis is None:
5376            axis = list(range(node.inputs[0].ndim))
5377        node_input = node.inputs[0]
5378        if node_input.owner and node_input.owner.op == T.true_div:
5379            numerator, denominator = node_input.owner.inputs
5380
5381            # Old, bugged logic, reproduced here only to warn users
5382            if (config.warn.sum_div_dimshuffle_bug and
5383                    isinstance(node.op, T.Sum) and
5384                    numerator.owner and
5385                    isinstance(numerator.owner.op, T.DimShuffle)):
5386                # Check compatibility
5387                new_order = numerator.owner.op.new_order
5388                compatible_dims = True
5389                for ax in axis:
5390                    if len(new_order) <= ax or new_order[ax] != 'x':
5391                        compatible_dims = False
5392                        break
5393
5394                if compatible_dims:
5395                    _logger.warn('WARNING: Your current code is fine, but'
5396                                 ' Theano versions between '
5397                                 'rev. 3bd9b789f5e8 (2010-06-16) and'
5398                                 ' cfc6322e5ad4 (2010-08-03) would '
5399                                 'have given an incorrect result. '
5400                                 'To disable this warning, set the Theano'
5401                                 ' flag warn.sum_div_dimshuffle_bug to'
5402                                 ' False.')
5403
5404            if denominator.owner and isinstance(denominator.owner.op,
5405                                                T.DimShuffle):
5406                dimshuffle_input = denominator.owner.inputs[0]
5407                dimshuffle_order = denominator.owner.op.new_order
5408
5409                compatible_dims = []
5410                incompatible_dims = []
5411                for ax in axis:
5412                    if (ax < len(dimshuffle_order) and
5413                            dimshuffle_order[ax] == 'x'):
5414                        compatible_dims.append(ax)
5415                    else:
5416                        incompatible_dims.append(ax)
5417                reordered_incompatible_dims = []
5418                for ic_ax in incompatible_dims:
5419                    reordered_incompatible_dims.append(
5420                        ic_ax - sum(
5421                            [1 for c_ax in compatible_dims if c_ax < ic_ax]))
5422
5423                if len(compatible_dims) > 0:
5424                    optimized_dimshuffle_order = list(
5425                        ax for i, ax in enumerate(dimshuffle_order)
5426                        if (i not in axis) or (ax != 'x'))
5427
5428                    # Removing leading 'x' (since it will be done automatically)
5429                    while (len(optimized_dimshuffle_order) > 0 and
5430                           optimized_dimshuffle_order[0] == 'x'):
5431                        del optimized_dimshuffle_order[0]
5432
5433                    # if optimized_dimshuffle_order is sorted with
5434                    # not 'x', then dimshuffle is useless.
5435                    if all(i == e for i, e in
5436                           enumerate(optimized_dimshuffle_order)):
5437                        optimized_dimshuffle = dimshuffle_input
5438                    else:
5439                        optimized_dimshuffle = T.DimShuffle(
5440                            dimshuffle_input.type.broadcastable,
5441                            optimized_dimshuffle_order)(dimshuffle_input)
5442
5443                        if (config.warn.sum_div_dimshuffle_bug and
5444                                isinstance(node.op, T.Sum)):
5445                            _logger.warn('WARNING: Your current code is fine,'
5446                                         ' but Theano versions between '
5447                                         'rev. 3bd9b789f5e8 (2010-06-16) and'
5448                                         ' cfc6322e5ad4 (2010-08-03) would '
5449                                         'have given an incorrect result. '
5450                                         'To disable this warning, set the'
5451                                         ' Theano flag '
5452                                         'warn.sum_div_dimshuffle_bug'
5453                                         ' to False.')
5454
5455                    if isinstance(node.op, T.Sum):
5456                        op_on_compatible_dims = T.sum(
5457                            numerator, axis=compatible_dims)
5458                        rval = T.true_div(
5459                            op_on_compatible_dims,
5460                            optimized_dimshuffle)
5461                        if len(reordered_incompatible_dims) > 0:
5462                            rval = T.sum(rval,
5463                                         axis=reordered_incompatible_dims)
5464                    elif isinstance(node.op, T.elemwise.Prod):
5465                        op_on_compatible_dims = T.prod(
5466                            numerator, axis=compatible_dims)
5467                        dtype = numerator.dtype
5468                        rval = T.true_div(
5469                            op_on_compatible_dims,
5470                            (optimized_dimshuffle **
5471                                T.prod([numerator.shape[ax].astype(dtype)
5472                                        for ax in compatible_dims])))
5473                        if len(reordered_incompatible_dims) > 0:
5474                            rval = T.prod(rval,
5475                                          axis=reordered_incompatible_dims)
5476                    return [rval]
5477
5478
5479@register_canonicalize
5480@gof.local_optimizer([T.Sum, T.elemwise.Prod])
5481def local_sum_prod_all_to_none(node):
5482    """
5483    Sum{0,1,...N} -> Sum{} or
5484    Prod{0,1,...N} -> Prod{}
5485
5486    """
5487    if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
5488        opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
5489        # if all the axes are named, then use None as a shorthand
5490        # this permits more merging
5491        if node.op.axis is None:
5492            return
5493        if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
5494            return [opt_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
5495
5496
5497@register_canonicalize
5498@gof.local_optimizer([T.Sum, T.elemwise.Prod])
5499def local_op_of_op(node):
5500    """
5501    Prod(Prod()) -> single Prod()
5502    or
5503    Sum(Sum()) -> single Sum()
5504
5505    """
5506    if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum):
5507        opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
5508        node_inps, = node.inputs
5509        out_dtype = node.op.dtype
5510        # We manipulate the graph so this is done to make sure the opt
5511        # doesn't affect other computations.
5512        if len(node_inps.clients) == 1:
5513            if (node_inps.owner and
5514                    (isinstance(node_inps.owner.op, node.op.__class__))):
5515
5516                # check to see either the inner or outer prod is doing a
5517                # product over all axis, in which case we can remove it
5518                if node_inps.owner.op.axis is None or node.op.axis is None:
5519                    return [opt_type(None, dtype=out_dtype)(
5520                        node_inps.owner.inputs[0])]
5521
5522                # figure out which axes were in the original sum
5523                newaxis = list(tuple(node_inps.owner.op.axis))
5524                for i in node.op.axis:
5525                    new_i = i
5526                    for ii in node_inps.owner.op.axis:
5527                        if new_i >= ii:
5528                            new_i += 1
5529                    assert new_i not in newaxis
5530                    newaxis.append(new_i)
5531
5532                assert len(newaxis) == len(list(node_inps.owner.op.axis) +
5533                                           list(node.op.axis))
5534
5535                # The old bugged logic. We keep it there to generate a warning
5536                # when we generated bad code.
5537                alldims = list(range(node_inps.owner.inputs[0].type.ndim))
5538                alldims = [d for i, d in enumerate(alldims) if i
5539                           in node_inps.owner.op.axis]
5540                alldims = [d for i, d in enumerate(alldims)
5541                           if i in node.op.axis]
5542                newaxis_old = [i for i in
5543                               xrange(node_inps.owner.inputs[0].type.ndim)
5544                               if i not in alldims]
5545
5546                if (theano.config.warn.sum_sum_bug and
5547                        newaxis != newaxis_old and
5548                        len(newaxis) == len(newaxis_old)):
5549                    _logger.warn(
5550                        "WARNING (YOUR CURRENT CODE IS FINE): Theano "
5551                        "versions between version 9923a40c7b7a and August "
5552                        "2nd, 2010 generated bugged code in this case. "
5553                        "This happens when there are two consecutive sums "
5554                        "in the graph and the intermediate sum is not "
5555                        "used elsewhere in the code. Some safeguard "
5556                        "removed some bad code, but not in all cases. You "
5557                        "are in one such case. To disable this warning "
5558                        "(that you can safely ignore since this bug has "
5559                        "been fixed) set the theano flag "
5560                        "`warn.sum_sum_bug` to False.")
5561
5562                combined = opt_type(newaxis, dtype=out_dtype)
5563                return [combined(node_inps.owner.inputs[0])]
5564
5565
5566ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
5567              T.elemwise.Sum, T.elemwise.Prod,
5568              T.elemwise.ProdWithoutZeros]
5569
5570
5571@register_canonicalize
5572@register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
5573@gof.local_optimizer(ALL_REDUCE)
5574def local_reduce_join(node):
5575    """
5576    Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
5577
5578    Notes
5579    -----
5580    Supported scalar.op are Maximum, Mimimum in some cases and Add and Mul in
5581    all cases.
5582
5583    Currently we must reduce on axis 0. It is probably extensible to the case
5584    where we join and reduce on the same set of axis.
5585
5586    """
5587    if (isinstance(node.op, T.CAReduce) and
5588            node.inputs[0].owner and
5589            isinstance(node.inputs[0].owner.op, T.Join)):
5590        join = node.inputs[0].owner
5591        if T.extract_constant(join.inputs[0], only_process_constants=True) != 0:
5592            return
5593
5594        if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
5595            # Support only 2 inputs for now
5596            if len(join.inputs) != 3:
5597                return
5598        elif not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul)):
5599            return
5600        elif len(join.inputs) <= 2:
5601            # This is a useless join, that will get removed by another opt.
5602            return
5603
5604        new_inp = []
5605        for inp in join.inputs[1:]:
5606            inp = inp.owner
5607            if not inp:
5608                return
5609            if (not isinstance(inp.op, DimShuffle) or
5610                    inp.op.new_order != ('x',) +
5611                    tuple(range(inp.inputs[0].ndim))):
5612                return
5613            new_inp.append(inp.inputs[0])
5614        ret = Elemwise(node.op.scalar_op)(*new_inp)
5615
5616        if ret.dtype != node.outputs[0].dtype:
5617            # The reduction do something about the dtype.
5618            return
5619
5620        reduce_axis = node.op.axis
5621        if reduce_axis is None:
5622            reduce_axis = tuple(xrange(node.inputs[0].ndim))
5623
5624        # I put this warning late to don't add extra warning.
5625        if len(reduce_axis) != 1 or 0 not in reduce_axis:
5626            if theano.config.warn.reduce_join:
5627                warnings.warn((
5628                    'Your current code is fine, but Theano versions '
5629                    'prior to 0.7 (or this development version Sept 2014) '
5630                    'might have given an incorrect result for this code. '
5631                    'To disable this warning, set the Theano flag '
5632                    'warn.reduce_join to False. The problem was an '
5633                    'optimization, that modified the pattern '
5634                    '"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
5635                    'did not check the reduction axis. So if the '
5636                    'reduction axis was not 0, you got a wrong answer.'))
5637            return
5638
5639        # We add the new check late to don't add extra warning.
5640        try:
5641            join_axis = get_scalar_constant_value(join.inputs[0],
5642                                                  only_process_constants=True)
5643
5644            if join_axis != reduce_axis[0]:
5645                return
5646        except NotScalarConstantError:
5647            return
5648
5649        return [ret]
5650
5651
5652@register_canonicalize('fast_compile', 'local_cut_useless_reduce')
5653@register_useless('local_cut_useless_reduce')
5654@gof.local_optimizer(ALL_REDUCE)
5655def local_useless_reduce(node):
5656    """Sum(a, axis=[]) -> a  """
5657    if isinstance(node.op, T.CAReduce):
5658        summed, = node.inputs
5659        # if reduce were doing anything, the output ndim would be reduced
5660        if summed.type == node.outputs[0].type:
5661            return [summed]
5662
5663
5664@register_canonicalize
5665@register_uncanonicalize
5666@register_specialize
5667@gof.local_optimizer(ALL_REDUCE)
5668def local_reduce_broadcastable(node):
5669    """Remove reduction over broadcastable dimensions."""
5670    if isinstance(node.op, T.CAReduce):
5671        reduced, = node.inputs
5672        odtype = node.outputs[0].dtype
5673        if node.op.axis is None:
5674            if all(reduced.broadcastable):
5675                return [reduced.dimshuffle().astype(odtype)]
5676        else:
5677            axis = list(node.op.axis)
5678            cuttable = [a for a in axis if reduced.broadcastable[a]]
5679            if cuttable:
5680                # -- we can remove some axes of summation,
5681                #    which simplifies the codegen for sum, especially on GPU
5682                new_axis = []
5683                pattern = []
5684                ii = 0
5685                for p in xrange(reduced.ndim):
5686                    if p not in cuttable:
5687                        if p in axis:
5688                            new_axis.append(ii)
5689                        pattern.append(p)
5690                        ii += 1
5691                new_reduced = reduced.dimshuffle(*pattern)
5692                if new_axis:
5693                    if type(node.op) == theano.tensor.elemwise.CAReduce:
5694                        # This happen for tensor.max(), tensor.min()
5695                        new_op = node.op.__class__(node.op.scalar_op,
5696                                                   axis=new_axis)
5697                    else:
5698                        new_op = node.op.__class__(axis=new_axis)
5699                    return [new_op(new_reduced)]
5700                else:
5701                    # -- in this case we can remove the reduction completely
5702                    return [new_reduced.astype(odtype)]
5703
5704
5705@register_specialize
5706@gof.local_optimizer([T.Sum, T.elemwise.Prod])
5707def local_opt_alloc(node):
5708    """
5709    sum(alloc(constant,shapes...)) => constant*prod(shapes)
5710    or
5711    prod(alloc(constant,shapes...)) => constant**prod(shapes)
5712
5713    """
5714    if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
5715        node_inps, = node.inputs
5716        if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
5717            input = node_inps.owner.inputs[0]
5718            shapes = node_inps.owner.inputs[1:]
5719            try:
5720                val = get_scalar_constant_value(input,
5721                                                only_process_constants=True)
5722                assert val.size == 1
5723                val = val.reshape(1)[0]
5724                # check which type of op
5725                size = T.mul(*shapes)
5726                if input.dtype in ["float16", "float32"]:
5727                    # shapes are ints and normally int64.
5728                    # We don't want to have a float64 upcast
5729                    # We don't want to downcast to float16
5730                    # as we fear it could loose too much precision
5731                    # that will be amplified by the mul/pow below.
5732                    size = size.astype('float32')
5733                if (node.op.axis is None or
5734                        node.op.axis == tuple(range(input.ndim))):
5735                    if isinstance(node.op, T.Sum):
5736                        val = val * size
5737                    else:
5738                        val = val ** size
5739                    # Sum can change the input dtype (upcast or bool
5740                    # -> float32) by default or by user request.
5741                    # We can ignore the acc_dtype, as there is only 1
5742                    # elemwise we will do and not a sequence, so there is no
5743                    # accumulation of errors.
5744                    # So mostly, we just need to cast the output to the old
5745                    # dtype.
5746                    val = val.astype(node.outputs[0].dtype)
5747                    return [val]
5748                to_prod = [shapes[i] for i in xrange(len(shapes))
5749                           if i in node.op.axis]
5750                if to_prod:
5751                    size = T.mul(*to_prod)
5752                    if isinstance(node.op, T.Sum):
5753                        val *= size
5754                    else:
5755                        val = val ** size
5756                # See comments above.
5757                val = val.astype(node.outputs[0].dtype)
5758                return [T.alloc(val,
5759                                *[shapes[i] for i in xrange(len(shapes))
5760                                  if i not in node.op.axis])]
5761            except NotScalarConstantError:
5762                pass
5763
5764
5765@register_specialize
5766@gof.local_optimizer([T.neg])
5767def local_neg_neg(node):
5768    # other specializations shouldn't put this in,
5769    # but sometimes they do
5770    if node.op == T.neg:
5771        if node.inputs[0].owner and node.inputs[0].owner.op == T.neg:
5772            return [node.inputs[0].owner.inputs[0]]
5773
5774
5775@register_specialize
5776@gof.local_optimizer([T.neg])
5777def local_neg_div_neg(node):
5778    """
5779    - (-a / b) -> a / b
5780
5781    Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
5782
5783    """
5784    if node.op == T.neg:
5785        if node.inputs[0].owner and node.inputs[0].owner.op == T.true_div:
5786            frac = node.inputs[0]
5787            num, denom = frac.owner.inputs
5788            if num.owner and num.owner.op == T.neg:
5789                if len(frac.clients) == 1:
5790                    # No other clients of the original division
5791                    new_num = num.owner.inputs[0]
5792                    return [T.true_div(new_num, denom)]
5793            elif np.all(num.broadcastable) and isinstance(num, Constant):
5794                if len(frac.clients) == 1:
5795                    new_num = -num.data
5796                    return [T.true_div(new_num, denom)]
5797
5798
5799@gof.local_optimizer([T.mul])
5800def local_mul_zero(node):
5801    """
5802    As part of canonicalization, we replace multiplication by zero
5803    with zero.
5804
5805    """
5806    if node.op == T.mul:
5807        otype = node.outputs[0].type
5808
5809        for i in node.inputs:
5810            try:
5811                value = get_scalar_constant_value(i)
5812            except NotScalarConstantError:
5813                continue
5814            # print 'MUL by value', value, node.inputs
5815            if value == 0:
5816                # print '... returning zeros'
5817                return _fill_chain(theano._asarray(0, dtype=otype.dtype),
5818                                   node.inputs)
5819register_canonicalize(local_mul_zero)
5820
5821
5822@gof.local_optimizer([T.true_div])
5823def local_div_to_inv(node):
5824    if node.op == T.true_div and np.all(
5825            local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
5826        out = node.outputs[0]
5827        new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:],
5828                                                            []))
5829        # The ones could have forced upcasting
5830        if new_out.dtype != out.dtype:
5831            new_out = T.cast(new_out, dtype=out.dtype)
5832        # The ones could have forced a specific length
5833        if new_out.type != out.type:
5834            new_out = broadcast_like(new_out, out, node.fgraph)
5835        return [new_out]
5836    else:
5837        return False
5838register_specialize(local_div_to_inv)
5839
5840
5841@gof.local_optimizer([T.inv])
5842def local_inv_canon(node):
5843    if node.op == T.inv:
5844        return [T.pow(node.inputs[0], -1.0)]
5845    else:
5846        return False
5847register_canonicalize(local_inv_canon)
5848
5849
5850@gof.local_optimizer([T.pow])
5851def local_pow_canonicalize(node):
5852    if node.op == T.pow:
5853        cst = local_mul_canonizer.get_constant(node.inputs[1])
5854        if cst == 0:
5855            return [broadcast_like(1, node.outputs[0], node.fgraph)]
5856        if cst == 1:
5857            return [broadcast_like(node.inputs[0], node.outputs[0], node.fgraph)]
5858    else:
5859        return False
5860register_canonicalize(local_pow_canonicalize)
5861
5862
5863@register_specialize
5864@gof.local_optimizer([T.mul])
5865def local_mul_to_sqr(node):
5866    """
5867    x*x -> sqr(x)
5868
5869    This is faster on the GPU when memory fetching is a big part of
5870    the computation time.
5871
5872    """
5873    if node.op == T.mul:
5874        if len(node.inputs) == 2:
5875            if node.inputs[0] is node.inputs[1]:
5876                return [T.sqr(node.inputs[0])]
5877
5878
5879@register_canonicalize
5880@gof.local_optimizer([T.int_div])
5881def local_intdiv_by_one(node):
5882    """x // 1 -> x
5883    """
5884    if node.op in [T.int_div]:
5885        if isinstance(node.inputs[1], T.TensorConstant) and \
5886           np.all(node.inputs[1].value == 1):
5887            return [node.inputs[0].astype(node.outputs[0].dtype)]
5888
5889
5890@register_canonicalize
5891@register_specialize
5892@gof.local_optimizer([T.int_div, T.true_div])
5893def local_zero_div(node):
5894    """0 / x -> 0
5895    """
5896    if isinstance(node.op, T.Elemwise) and isinstance(
5897            node.op.scalar_op, (theano.scalar.IntDiv, theano.scalar.TrueDiv)):
5898        if local_mul_canonizer.get_constant(node.inputs[0]) == 0:
5899            ret = broadcast_like(0, node.outputs[0], node.fgraph)
5900            ret.tag.values_eq_approx = values_eq_approx_remove_nan
5901            return [ret]
5902
5903
5904@gof.local_optimizer([T.pow])
5905def local_pow_specialize(node):
5906    # here, we are past the point of canonicalization, so we don't want
5907    # to put in un-necessary fills.
5908    if node.op == T.pow:
5909        # the idea here is that we have pow(x, y)
5910        odtype = node.outputs[0].dtype
5911        xsym = node.inputs[0]
5912        ysym = node.inputs[1]
5913        y = local_mul_canonizer.get_constant(ysym)
5914        if (y is not None) \
5915                and encompasses_broadcastable(xsym.type.broadcastable,
5916                                              ysym.type.broadcastable):
5917            rval = None
5918
5919            if np.all(y == 2):
5920                rval = [T.sqr(xsym)]
5921            if np.all(y == 1):
5922                rval = [xsym]
5923            if np.all(y == 0):
5924                rval = [T.fill(xsym, np.asarray(1, dtype=odtype))]
5925            if np.all(y == 0.5):
5926                rval = [T.sqrt(xsym)]
5927            if np.all(y == -0.5):
5928                rval = [T.inv(T.sqrt(xsym))]
5929            if np.all(y == -1):
5930                rval = [T.inv(xsym)]
5931            if np.all(y == -2):
5932                rval = [T.inv(T.sqr(xsym))]
5933            if rval:
5934                rval[0] = T.cast(rval[0], odtype)
5935                assert rval[0].type == node.outputs[0].type, (
5936                    rval, node.outputs)
5937                return rval
5938    else:
5939        return False
5940register_specialize(local_pow_specialize)
5941
5942
5943@register_specialize_device
5944@gof.local_optimizer([T.pow])
5945def local_pow_specialize_device(node):
5946    """
5947    This optimization is not the same on all device. We do it only on cpu here.
5948    """
5949    if node.op == T.pow:
5950        # the idea here is that we have pow(x, y)
5951        odtype = node.outputs[0].dtype
5952        xsym = node.inputs[0]
5953        ysym = node.inputs[1]
5954        y = local_mul_canonizer.get_constant(ysym)
5955
5956        # the next line is needed to fix a strange case that I don't
5957        # know how to make a separate test.
5958        # That happen in the test_opt.py:test_log_erfc test.
5959        # y is a ndarray with dtype int8 and value 2,4 or 6. This make
5960        # the abs(y) <= 512 fail!
5961        # taking the value outside ndarray solve the problem.
5962        # it could be that in that case, numpy make the comparaison
5963        # into the wrong type(do in int8 that overflow.)
5964        if isinstance(y, np.ndarray):
5965            assert y.size == 1
5966            try:
5967                y = y[0]
5968            except IndexError:
5969                pass
5970        if (y is not None) \
5971                and encompasses_broadcastable(xsym.type.broadcastable,
5972                                              ysym.type.broadcastable):
5973            rval = None
5974            # 512 is too small for the cpu and too big for some gpu!
5975            if abs(y) == int(abs(y)) and abs(y) <= 512:
5976                pow2 = [xsym]
5977                pow2_scal = [theano.scalar.get_scalar_type(xsym.dtype)()]
5978                y_to_do = abs(y)
5979                for i in xrange(int(np.log2(y_to_do))):
5980                    pow2.append(T.sqr(pow2[i]))
5981                    pow2_scal.append(theano.scalar.sqr(pow2_scal[i]))
5982                rval1 = None
5983                rval1_scal = None
5984                while y_to_do > 0:
5985                    log_to_do = int(np.log2(y_to_do))
5986                    if rval1:
5987                        rval1 *= pow2[log_to_do]
5988                        rval1_scal *= pow2_scal[log_to_do]
5989                    else:
5990                        rval1 = pow2[log_to_do]
5991                        rval1_scal = pow2_scal[log_to_do]
5992                    y_to_do -= 2 ** log_to_do
5993
5994                if abs(y) > 2:
5995                    # We fuse all the pow together here to make
5996                    # compilation faster
5997                    rval1 = Elemwise(
5998                        theano.scalar.Composite(
5999                            [pow2_scal[0]], [rval1_scal])).make_node(xsym)
6000                if y < 0:
6001                    rval = [T.inv(rval1)]
6002                else:
6003                    rval = [rval1]
6004            if rval:
6005                rval[0] = T.cast(rval[0], odtype)
6006                assert rval[0].type == node.outputs[0].type, (
6007                    rval, node.outputs)
6008                return rval
6009
6010
6011@gof.local_optimizer([T.mul])
6012def local_mul_specialize(node):
6013    """
6014    Remove special-case constants from mul arguments and useless neg in inputs.
6015
6016    mul(-1, x) -> neg(x)
6017    mul(1, x, y) -> mul(x, y)
6018    mul(0, ...) -> alloc(0, shapes...)
6019
6020    This is not done if we would add more nodes in the graph, like with:
6021
6022    mul(-1, x, y) -/-> neg(mul(x, y))
6023
6024    """
6025    # here, we are past the point of canonicalization, so we don't
6026    # want to put in un-necessary fills.
6027    #
6028    # at this point [post canonicalize], mul() may have many inputs.
6029    if node.op == T.mul:
6030        # the idea here is that we have pow(x, y)
6031        neg = False
6032        new_inputs = []
6033        nb_neg_node = 0
6034        nb_cst = 0
6035        for input in node.inputs:
6036            # remove any neg arguments
6037            while input.owner and input.owner.op == T.neg:
6038                neg ^= True
6039                input = input.owner.inputs[0]
6040                nb_neg_node += 1
6041
6042            # remove special case arguments of 1, -1 or 0
6043            y = local_mul_canonizer.get_constant(input)
6044            if y == 1.0:
6045                nb_cst += 1
6046            elif y == -1.0:
6047                nb_cst += 1
6048                neg ^= True  # toggles
6049            elif y == 0.0:
6050                # if we find any zero, we just return right away
6051                return [broadcast_like(0, node.outputs[0], node.fgraph)]
6052            else:
6053                new_inputs.append(input)
6054
6055        if new_inputs != node.inputs:
6056            if new_inputs:
6057                if len(new_inputs) == 1:
6058                    if neg:
6059                        if new_inputs[0].dtype in (T.uint_dtypes + ['bool']):
6060                            return
6061                        else:
6062                            rval = -new_inputs[0]
6063                    else:
6064                        rval = new_inputs[0]
6065                else:
6066                    # The next case would cause a replace by an equivalent case.
6067                    if (neg and
6068                            nb_neg_node == 0 and
6069                            nb_cst == 1):
6070                        return
6071                    elif neg:
6072                        # Don't add an extra neg node as we can't
6073                        # fully replace this mul by a neg.
6074                        m1 = np.asarray(-1, dtype=node.outputs[0].dtype)
6075                        new_inputs = [m1] + new_inputs
6076                    rval = T.mul(*new_inputs)
6077
6078                return [broadcast_like(rval, node.outputs[0], node.fgraph)]
6079            else:
6080                # there are no variable inputs to mul
6081                # N.B. this could have been constant-folded...
6082                if neg:
6083                    return [broadcast_like(-1, node.outputs[0], node.fgraph)]
6084                else:
6085                    return [broadcast_like(1, node.outputs[0], node.fgraph)]
6086
6087register_specialize(local_mul_specialize)
6088
6089
6090@gof.local_optimizer([T.add])
6091def local_add_specialize(node):
6092    def fill_chain(v):
6093        out = _fill_chain(v, node.inputs)
6094        return out
6095
6096    # here, we are past the point of canonicalization, so we don't want
6097    # to put in un-necessary fills.
6098    if node.op == T.add:
6099        new_inputs = []
6100        for input in node.inputs:
6101            try:
6102                y = get_scalar_constant_value(input)
6103            except NotScalarConstantError:
6104                y = input
6105            if np.all(y == 0.0):
6106                continue
6107            new_inputs.append(input)
6108
6109        if len(new_inputs) < len(node.inputs):
6110            dtype = node.outputs[0].type.dtype
6111            if len(new_inputs) == 0:
6112                # we got rid of the entire expression!
6113                ndim = node.outputs[0].type.ndim
6114                # Reuse call to constant for cache()
6115                cst = T.constant(np.zeros((1,) * ndim, dtype=dtype))
6116                assert cst.type.broadcastable == (True,) * ndim
6117                return fill_chain(cst)
6118
6119            if len(new_inputs) == 1:
6120                ret = fill_chain(new_inputs[0])
6121            else:
6122                ret = fill_chain(T.add(*new_inputs))
6123            # The dtype should not be changed. It can happen if the input
6124            # that was forcing upcasting was equal to 0.
6125            if ret[0].dtype != dtype:
6126                ret = [T.cast(ret[0], dtype)]
6127            return ret
6128    else:
6129        return False
6130register_specialize(local_add_specialize)
6131
6132mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer,
6133                                         local_fill_sink, apply_all_opts=True),
6134                       name='mul_canonizer_groups')
6135
6136
6137def check_for_x_over_absX(numerators, denominators):
6138    """Convert x/abs(x) into sign(x). """
6139    # TODO: this function should dig/search through dimshuffles
6140    # This won't catch a dimshuffled absolute value
6141    for den in list(denominators):
6142        if (den.owner and den.owner.op == T.abs_ and
6143                den.owner.inputs[0] in numerators):
6144            if den.owner.inputs[0].type.dtype.startswith('complex'):
6145                # TODO: Make an Op that projects a complex number to
6146                #      have unit length but projects 0 to 0.  That
6147                #      would be a weird Op, but consistent with the
6148                #      special case below.  I heard there's some
6149                #      convention in Matlab that is similar to
6150                #      this... but not sure.
6151                pass
6152            else:
6153                denominators.remove(den)
6154                numerators.remove(den.owner.inputs[0])
6155                numerators.append(T.sgn(den.owner.inputs[0]))
6156    return numerators, denominators
6157local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX')
6158
6159
6160@register_canonicalize
6161@gof.local_optimizer([T.abs_])
6162def local_abs_lift(node):
6163    """
6164    Move the abs toward the input.
6165
6166    This is needed for check_for_x_over_absX to apply in more case.
6167
6168    """
6169    if node.op == T.abs_ and node.inputs[0].owner:
6170        assert node.nin == 1
6171        if node.inputs[0].owner.op == T.mul:
6172            return [T.mul(*[T.abs_(i) for i in node.inputs[0].owner.inputs])]
6173        if node.inputs[0].owner.op == T.true_div:
6174            i = node.inputs[0].owner.inputs
6175            return [T.true_div(T.abs_(i[0]), T.abs_(i[1]))]
6176
6177
6178@register_specialize
6179@gof.local_optimizer([T.mul, T.true_div])
6180def local_abs_merge(node):
6181    """
6182    Merge abs generated by local_abs_lift when the canonizer don't
6183    need it anymore
6184
6185    """
6186    if node.op == T.mul and sum([i.owner.op == T.abs_ for i in node.inputs
6187                                 if i.owner]) > 1:
6188        inputs = []
6189        for i in node.inputs:
6190            if i.owner and i.owner.op == T.abs_:
6191                inputs.append(i.owner.inputs[0])
6192            elif isinstance(i, Constant):
6193                try:
6194                    const = get_scalar_constant_value(i,
6195                                                      only_process_constants=True)
6196                except NotScalarConstantError:
6197                    return False
6198                if not (const >= 0).all():
6199                    return False
6200                inputs.append(i)
6201            else:
6202                return False
6203        return [T.abs_(T.mul(*inputs))]
6204    if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in
6205                                      node.inputs if i.owner]) == 2:
6206        return [T.abs_(T.true_div(node.inputs[0].owner.inputs[0],
6207                                  node.inputs[1].owner.inputs[0]))]
6208
6209
6210@register_stabilize
6211@register_specialize
6212@gof.local_optimizer([T.log])
6213def local_log1p(node):
6214    # log(1+x) -> log1p(x)
6215    # log(1-x) -> log1p(-x)
6216    if node.op == T.log:
6217        log_arg, = node.inputs
6218        if log_arg.owner and log_arg.owner.op == T.add:
6219            scalars, scalar_inputs, nonconsts = scalarconsts_rest(
6220                log_arg.owner.inputs, only_process_constants=True)
6221            # scalar_inputs are potentially dimshuffled and fill'd scalars
6222            if scalars and np.allclose(np.sum(scalars), 1):
6223                if nonconsts:
6224                    if len(nonconsts) > 1:
6225                        ninp = T.add(*nonconsts)
6226                    else:
6227                        ninp = nonconsts[0]
6228                    if ninp.dtype != log_arg.type.dtype:
6229                        ninp = ninp.astype(node.outputs[0].dtype)
6230                    return _fill_chain(T.log1p(ninp), scalar_inputs)
6231
6232        elif log_arg.owner and log_arg.owner.op == T.sub:
6233            one = T.extract_constant(log_arg.owner.inputs[0],
6234                                     only_process_constants=True)
6235            if one != 1:
6236                return
6237            other = log_arg.owner.inputs[1]
6238            if other.dtype != log_arg.dtype:
6239                other = other.astype(log_arg.dtype)
6240            return [T.log1p(T.neg(other))]
6241
6242
6243# TODO: in canonicalize, change log10 and log2 -> log
6244@register_stabilize
6245@register_specialize
6246@gof.local_optimizer([T.log])
6247def local_log_add(node):
6248    # log(exp(x)+exp(y))
6249    #
6250    # Suppose x >= y
6251    # log(exp(x) + exp(y))
6252    # log(exp(x) * (1 + exp(y)/exp(x)))
6253    # x + log(1 + exp(y)/exp(x))
6254    # x + log1p(exp(y)/exp(x))
6255    # x + log1p(exp(y-x))
6256    if node.op == T.log:
6257        z = node.inputs[0]
6258        if z.owner and z.owner.op == T.add:
6259            zi = z.owner.inputs
6260            if len(zi) != 2:
6261                # -- upgrading Maximum to handle multiple inputs wasn't trivial
6262                #    TODO
6263                # raise NotImplementedError()
6264                return
6265            pre_exp = [x.owner.inputs[0] for x in zi
6266                       if x.owner and x.owner.op == T.exp]
6267            if len(pre_exp) == len(zi):
6268                # all arguments to add are exp(<something>)
6269                max_pre = T.maximum(*pre_exp)
6270
6271                ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre
6272                                                      for p in pre_exp])))
6273                ret.tag.values_eq_approx = values_eq_approx_remove_inf
6274                return [ret]
6275
6276
6277@gof.local_optimizer([T.log])
6278def local_log_sum_exp(node):
6279    # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
6280
6281    if node.op != T.log:
6282        return
6283
6284    sum_node = node.inputs[0].owner
6285    # If the sum has keepdims=True, there might be a dimshuffle
6286    if sum_node and isinstance(sum_node.op, T.DimShuffle):
6287        dimshuffle_op = sum_node.op
6288        sum_node = sum_node.inputs[0].owner
6289    else:
6290        dimshuffle_op = None
6291
6292    if not sum_node or not isinstance(sum_node.op, T.Sum):
6293        return
6294
6295    exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
6296    if not exp_node or not (
6297            isinstance(exp_node.op, Elemwise) and
6298            isinstance(exp_node.op.scalar_op, scalar.Exp)):
6299        return
6300
6301    pre_exp = exp_node.inputs[0]
6302    max_pre_exp = T.max(pre_exp, axis=axis)
6303    max_pre_exp_keepdims = T.makeKeepDims(pre_exp, max_pre_exp, axis)
6304
6305    ret = (max_pre_exp +
6306           T.log(T.sum(T.exp(pre_exp - max_pre_exp_keepdims), axis=axis)))
6307
6308    # Restore the dimshuffle op, if any.
6309    if dimshuffle_op:
6310        ret = dimshuffle_op(ret)
6311
6312    return [ret]
6313
6314
6315compile.optdb.register('local_log_sum_exp',
6316                       in2out(local_log_sum_exp, ignore_newtrees=True),
6317                       1.6, 'fast_run')
6318
6319
6320def add_calculate(num, denum, aslist=False, out_type=None):
6321    # TODO: make sure that this function and mul_calculate are similar
6322    if out_type is None:
6323        zero = 0.0
6324    else:
6325        zero = theano._asarray(0, dtype=out_type.dtype)
6326    # zero = 0.0 if out_type is None else theano._asarray(0,
6327    # dtype=out_type.dtype)
6328    if out_type and out_type.dtype == 'bool':
6329        if len(denum) == 0:
6330            # NumPy 1.14 do not accept to do "bool - bool"
6331            v = reduce(np.add, num, zero)
6332        else:
6333            raise Exception(
6334                "bool subtraction not supported. This should not happen as"
6335                " an earlier error should have been raised")
6336    else:
6337        v = reduce(np.add, num, zero) - reduce(np.add, denum, zero)
6338    if aslist:
6339        if np.all(v == 0):
6340            return []
6341        else:
6342            return [v]
6343    return v
6344
6345
6346local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
6347add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer,
6348                                         local_fill_sink, apply_all_opts=True),
6349                       name='add_canonizer_group')
6350
6351
6352register_canonicalize(local_add_canonizer, name='local_add_canonizer')
6353
6354
6355##################
6356# Distributivity #
6357##################
6358
6359
6360def distribute_greedy(pos_pairs, neg_pairs, num, denum,
6361                      out_type, minscore=0):
6362    # each pair in pos_pairs and neg_pairs is a num/denum pair. this
6363    # function attempts to add num and denum to the corresponding parts
6364    # of each pair, and counts how many multiplications/divisions can
6365    # be saved in that way.
6366
6367    # each division is counted like div_cost multiplications
6368    # (typically, division costs more so we are willing to multiply more
6369    # in order to divide less)
6370    # 1.5 was obtained through an informal test and may very well be
6371    # platform dependent
6372    div_cost = 1.5
6373
6374    # score is number of operations saved, higher is better
6375    score = len(num) + div_cost * len(denum)
6376    new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
6377                                           [(n + num, d + denum, out_type) for (n, d)
6378                                            in pos_pairs]))
6379    new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
6380                                           [(n + num, d + denum, out_type) for (n, d)
6381                                            in neg_pairs]))
6382    for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs +
6383                                new_neg_pairs):
6384        # We calculate how many operations we are saving with the new
6385        # num and denum
6386        score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd)
6387    if score <= minscore:
6388        # the change is not applied because it adds too many operations
6389        return False, pos_pairs, neg_pairs
6390    return True, new_pos_pairs, new_neg_pairs
6391
6392
6393def attempt_distribution(factor, num, denum, out_type):
6394    # we try to insert each num and each denum in the factor
6395    # returns: changes?, new_factor, new_num, new_denum
6396    # if there are changes, new_num and new_denum contain all the numerators
6397    # and denumerators that could not be distributed in the factor
6398    pos, neg = local_add_canonizer.get_num_denum(factor)
6399    if len(pos) == 1 and not neg:
6400        return False, factor, num, denum
6401    pos_pairs = list(map(local_mul_canonizer.get_num_denum, pos))
6402    neg_pairs = list(map(local_mul_canonizer.get_num_denum, neg))
6403    change = False
6404    for n in list(num):
6405        success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
6406                                                          neg_pairs, [n], [], out_type)
6407        if success:
6408            change = True
6409            num.remove(n)
6410    for d in list(denum):
6411        success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
6412                                                          neg_pairs, [], [d], out_type)
6413        if success:
6414            change = True
6415            denum.remove(d)
6416    if not change:
6417        return change, factor, num, denum
6418    else:
6419        return change, local_add_canonizer.merge_num_denum(
6420            list(itertools.starmap(local_mul_canonizer.merge_num_denum,
6421                                   pos_pairs)),
6422            list(itertools.starmap(local_mul_canonizer.merge_num_denum,
6423                                   neg_pairs))), num, denum
6424
6425
6426@register_canonicalize
6427@register_stabilize
6428@gof.local_optimizer([T.mul, T.true_div, T.inv])
6429def local_greedy_distributor(node):
6430    """
6431    Optimize by reducing the number of multiplications and/or divisions.
6432
6433    This optimization tries to apply distributivity of multiplication
6434    to addition in order to reduce the number of multiplications
6435    and/or divisions that must be done. The algorithm weighs division
6436    more than multiplication to account for the former's slightly
6437    greater computational cost.
6438
6439    The following expressions are simplified:
6440    1. ((a/x + b/y) * x * y) --> a*y + b*x
6441    2. ((a/x + b) * x) --> a + b*x
6442    3. There are other forms too where node is a true_div.
6443
6444    The following expressions are not simplified:
6445    4. ((a + b) * x) -/-> a*x + b*x
6446
6447    This optimization aims to reduce computational cost. It may also
6448    increase numerical stability, e.g. when x and/or y tend to 0 in
6449    example 1.
6450
6451    """
6452
6453    out = node.outputs[0]
6454    num, denum = local_mul_canonizer.get_num_denum(out)
6455    if len(num) == 1 and not denum:
6456        return False
6457
6458    new_num, new_denum = [], []
6459
6460    change = False
6461
6462    out_type = out.type
6463    for candidate in list(num):
6464        if candidate not in num:
6465            continue
6466        num.remove(candidate)
6467        _change, candidate, num, denum = attempt_distribution(
6468            candidate, num, denum, out_type,)
6469
6470        change |= _change
6471        new_num.append(candidate)
6472
6473    for candidate in list(denum):
6474        if candidate not in denum:
6475            continue
6476        denum.remove(candidate)
6477        _change, candidate, denum, num = attempt_distribution(
6478            candidate, denum, num, out_type)
6479        change |= _change
6480        new_denum.append(candidate)
6481    if not change:
6482        return False
6483
6484    new_num += num
6485    new_denum += denum
6486
6487    rval = local_mul_canonizer.merge_num_denum(new_num, new_denum)
6488
6489    if not (rval.type == out.type):
6490        # WHY DOES THIS HAPPEN?
6491        return False
6492
6493    return [rval]
6494
6495
6496@gof.local_optimizer(None)
6497def constant_folding(node):
6498    for input in node.inputs:
6499        if not isinstance(input, Constant):
6500            return False
6501    # condition:  all inputs are constant
6502    if not node.op.do_constant_folding(node):
6503        # The op asks not to be constant folded.
6504        return False
6505
6506    storage_map = dict([(i, [i.data]) for i in node.inputs])
6507    compute_map = dict([(i, [True]) for i in node.inputs])
6508    for o in node.outputs:
6509        storage_map[o] = [None]
6510        compute_map[o] = [False]
6511    impl = None
6512    if (hasattr(node.op, 'python_constant_folding') and
6513            node.op.python_constant_folding(node)):
6514        impl = 'py'
6515    thunk = node.op.make_thunk(node, storage_map, compute_map,
6516                               no_recycling=[], impl=impl)
6517
6518    required = thunk()
6519    assert not required  # a node whose inputs are all provided should always
6520    # return successfully
6521    rval = []
6522    for output in node.outputs:
6523        assert compute_map[output][0], (output, storage_map[output][0])
6524        try:
6525            constant = output.type.Constant
6526        except AttributeError:
6527            constant = Constant
6528
6529        v = constant(output.type, storage_map[output][0])
6530        copy_stack_trace(output, v)
6531
6532        rval.append(v)
6533    return rval
6534
6535
6536topo_constant_folding = in2out(constant_folding, ignore_newtrees=True,
6537                               name="topo_constant_folding")
6538register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
6539register_uncanonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
6540register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
6541register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
6542
6543
6544def get_clients(node):
6545    """
6546    Used by erf/erfc opt to track less frequent op.
6547
6548    """
6549    return [c for c, i in node.outputs[0].clients
6550            if c != "output"]
6551
6552
6553def get_clients2(node):
6554    """
6555    Used by erf/erfc opt to track less frequent op.
6556
6557    """
6558    l = []
6559    for c, i in node.outputs[0].clients:
6560        if c != "output":
6561            for var in c.outputs:
6562                l.extend([cc for cc, ii in var.clients if cc != "output"])
6563    return l
6564
6565# 1+erf(x)=>erfc(-x)
6566local_one_plus_erf = gof.PatternSub((T.add,
6567                                     1,
6568                                     (T.erf, 'x')),
6569                                    (T.erfc, (T.neg, 'x')),
6570                                    allow_multiple_clients=True,
6571                                    name='local_one_plus_erf',
6572                                    tracks=[T.erf],
6573                                    get_nodes=get_clients)
6574register_canonicalize(local_one_plus_erf)
6575register_stabilize(local_one_plus_erf)
6576register_specialize(local_one_plus_erf)
6577
6578# 1-erf(x)=>erfc(x)
6579local_one_minus_erf = gof.PatternSub((T.sub,
6580                                      1,
6581                                      (T.erf, 'x')),
6582                                     (T.erfc, 'x'),
6583                                     allow_multiple_clients=True,
6584                                     name='local_one_minus_erf',)
6585register_canonicalize(local_one_minus_erf)
6586register_stabilize(local_one_minus_erf)
6587register_specialize(local_one_minus_erf)
6588
6589local_one_minus_erf2 = gof.PatternSub((T.add,
6590                                      1,
6591                                      (T.mul, -1, (T.erf, 'x'))),
6592                                      (T.erfc, 'x'),
6593                                      allow_multiple_clients=True,
6594                                      name='local_one_minus_erf2')
6595register_canonicalize(local_one_minus_erf2)
6596register_stabilize(local_one_minus_erf2)
6597register_specialize(local_one_minus_erf2)
6598
6599# 1+(-erf(x))=>erfc(x) This is a different graph then the previous as
6600# the canonicalize don't work completly
6601local_one_plus_neg_erf = gof.PatternSub((T.add,
6602                                         1,
6603                                         (T.neg, (T.erf, 'x'))),
6604                                        (T.erfc, 'x'),
6605                                        allow_multiple_clients=True,
6606                                        name='local_one_plus_neg_erf',
6607                                        tracks=[T.erf],
6608                                        get_nodes=get_clients2)
6609register_canonicalize(local_one_plus_neg_erf)
6610register_stabilize(local_one_plus_neg_erf)
6611register_specialize(local_one_plus_neg_erf)
6612
6613# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
6614# will put the -1 as the first argument.
6615local_erf_minus_one = gof.PatternSub((T.add,
6616                                      -1,
6617                                      (T.erf, 'x')),
6618                                     (T.neg, (T.erfc, 'x')),
6619                                     allow_multiple_clients=True,
6620                                     name='local_erf_minus_one',
6621                                     tracks=[T.erf],
6622                                     get_nodes=get_clients)
6623register_canonicalize(local_erf_minus_one)
6624register_stabilize(local_erf_minus_one)
6625register_specialize(local_erf_minus_one)
6626
6627# 1-erfc(x) => erf(x)
6628local_one_minus_erfc = gof.PatternSub((T.sub,
6629                                       1,
6630                                       (T.erfc, 'x')),
6631                                      (T.erf, 'x'),
6632                                      allow_multiple_clients=True,
6633                                      name='local_one_minus_erfc',
6634                                      tracks=[T.erfc],
6635                                      get_nodes=get_clients)
6636register_canonicalize(local_one_minus_erfc)
6637register_stabilize(local_one_minus_erfc)
6638register_specialize(local_one_minus_erfc)
6639
6640local_one_minus_erfc2 = gof.PatternSub((T.add,
6641                                        1,
6642                                        (T.neg, (T.erfc, 'x'))),
6643                                       (T.erf, 'x'),
6644                                       allow_multiple_clients=True,
6645                                       name='local_one_minus_erfc2',
6646                                       tracks=[T.erfc],
6647                                       get_nodes=get_clients2)
6648register_canonicalize(local_one_minus_erfc2)
6649register_stabilize(local_one_minus_erfc2)
6650register_specialize(local_one_minus_erfc2)
6651
6652local_one_minus_erfc3 = gof.PatternSub((T.add,
6653                                        1,
6654                                        (T.mul, -1, (T.erfc, 'x'))),
6655                                       (T.erf, 'x'),
6656                                       allow_multiple_clients=True,
6657                                       name='local_one_minus_erfc3',
6658                                       tracks=[T.erfc],
6659                                       get_nodes=get_clients2)
6660register_canonicalize(local_one_minus_erfc3)
6661register_stabilize(local_one_minus_erfc3)
6662register_specialize(local_one_minus_erfc3)
6663
6664# 1+(-erfc(x)) => erf(x) This is a different graph then the previous as
6665# the canonicalize don't work completly
6666local_one_add_neg_erfc = gof.PatternSub((T.add,
6667                                         1,
6668                                         (T.neg, (T.erfc, 'x'))),
6669                                        (T.erf, 'x'),
6670                                        allow_multiple_clients=True,
6671                                        name='local_one_add_neg_erfc',
6672                                        tracks=[T.erfc],
6673                                        get_nodes=get_clients2)
6674
6675register_canonicalize(local_one_add_neg_erfc)
6676register_stabilize(local_one_add_neg_erfc)
6677register_specialize(local_one_add_neg_erfc)
6678
6679# (-1)+erfc(-x)=>erf(x)
6680local_erf_neg_minus_one = gof.PatternSub((T.add,
6681                                          -1,
6682                                          (T.erfc, (T.neg, 'x'))),
6683                                         (T.erf, 'x'),
6684                                         allow_multiple_clients=True,
6685                                         name='local_erf_neg_minus_one',
6686                                         tracks=[T.erfc],
6687                                         get_nodes=get_clients)
6688register_canonicalize(local_erf_neg_minus_one)
6689register_stabilize(local_erf_neg_minus_one)
6690register_specialize(local_erf_neg_minus_one)
6691
6692# (-1)+erfc(-1*x)=>erf(x)
6693local_erf_neg_minus_one2 = gof.PatternSub((T.add,
6694                                           -1,
6695                                           (T.erfc, (T.mul, -1, 'x'))),
6696                                          (T.erf, 'x'),
6697                                          allow_multiple_clients=True,
6698                                          name='local_erf_neg_minus_one2',
6699                                          tracks=[T.erfc],
6700                                          get_nodes=get_clients)
6701register_canonicalize(local_erf_neg_minus_one2)
6702register_stabilize(local_erf_neg_minus_one2)
6703register_specialize(local_erf_neg_minus_one2)
6704
6705
6706# Stability optimization
6707# log(erfc(x)) => when x>threashold,
6708#              -x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))
6709# for float64: threshold=26.641747557 was choosed with:
6710#  [(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64'))))
6711#   for i in numpy.arange(26.641747557,26.6417475571,.00000000001)]
6712# for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc(
6713#        numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
6714#        10.0541948,10.0541951,.0000001)]
6715@register_stabilize
6716@register_specialize
6717@gof.local_optimizer([T.log])
6718def local_log_erfc(node):
6719    if node.op != T.log:
6720        return False
6721    if not node.inputs[0].owner or node.inputs[0].owner.op != T.erfc:
6722        return False
6723
6724    if hasattr(node.tag, 'local_log_erfc_applied'):
6725        # We use that flag to don't apply the optimization recursively
6726        return False
6727    node.tag.local_log_erfc_applied = True
6728
6729    x = node.inputs[0].owner.inputs[0]
6730    stab_value = (-x ** 2 - T.log(x) - .5 * T.log(np.pi) +
6731                  T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
6732                  15 / (8 * x ** 6)))
6733
6734    if (node.outputs[0].dtype == 'float32' or
6735            node.outputs[0].dtype == 'float16'):
6736        threshold = 10.0541949
6737    elif node.outputs[0].dtype == 'float64':
6738        threshold = 26.641747557
6739
6740    ret = T.switch(x < threshold, node.outputs[0], stab_value)
6741    ret.tag.values_eq_approx = values_eq_approx_remove_inf
6742    return [ret]
6743
6744
6745# Stability optimization of the grad of log(erfc(x))
6746# ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
6747# ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
6748#                            sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
6749# for float64: threshold=26.63 see at the end of the fct for the explanation
6750# for float32: threshold=9.3 see at the end of the fct for the explanation
6751# TODO: remove the contraint that there are only 2 inputs to exp(x**2)
6752#      is the second.
6753# TODO: at the test point 10 in float32, there is instability in the original
6754#      value. The original gives -30.0, the stab -20.1 and in float64 -18.1.
6755#      Make it so that the test does not generate an error in that case!
6756@register_stabilize
6757@register_specialize
6758@gof.local_optimizer([T.true_div])
6759def local_grad_log_erfc_neg(node):
6760    if node.op != T.true_div:
6761        return False
6762    if not node.inputs[1].owner or node.inputs[1].owner.op != T.erfc:
6763        return False
6764    erfc = node.inputs[1]
6765    erfc_x = erfc.owner.inputs[0]
6766    if not node.inputs[0].owner:
6767        return False
6768
6769    # The mul is optional.
6770    if node.inputs[0].owner.op != T.mul:
6771        mul = None
6772        y = []
6773        if not node.inputs[0].owner or node.inputs[0].owner.op != T.exp:
6774            return False
6775        exp = node.inputs[0]
6776    else:
6777        mul = node.inputs[0]
6778        exp = None
6779        for idx, inp in enumerate(mul.owner.inputs):
6780            if inp.owner and inp.owner.op == T.exp:
6781                exp = inp
6782                break
6783        if len(mul.owner.inputs) == 2:
6784            y = [mul.owner.inputs[1 - idx]]
6785        else:
6786            y = mul.owner.inputs[:]
6787            del y[idx]
6788    del mul
6789    if not exp.owner.inputs[0].owner:
6790        return False
6791
6792    if exp.owner.inputs[0].owner.op == T.neg:
6793        neg = exp.owner.inputs[0]
6794        if (not neg.owner.inputs[0].owner or
6795                neg.owner.inputs[0].owner.op != T.sqr):
6796            return False
6797        sqr = neg.owner.inputs[0]
6798        x = sqr.owner.inputs[0]
6799    elif exp.owner.inputs[0].owner.op == T.mul:
6800        # We should compare that -(erfc_x**2) is equivalent to mul_neg.
6801        # There is currently no easy way to do this in the general case,
6802        # so we implement some common case for now.
6803
6804        # In many cases the neg are replaced by mul in the graph.
6805        # This also allows to stabilize log(erfc(cst*x)).
6806        mul_neg = exp.owner.inputs[0]
6807
6808        # In case that multiple mul are not fused together, we do it here.
6809        def check_input(inputs):
6810            new_inputs = []
6811            for i in inputs:
6812                if i.owner and i.owner.op == T.mul:
6813                    new_inputs.extend(check_input(i.owner.inputs))
6814                else:
6815                    new_inputs.append(i)
6816            return new_inputs
6817        mul_inputs = check_input(mul_neg.owner.inputs)
6818
6819        # Put the constant first.
6820        for i in xrange(len(mul_inputs)):
6821            if isinstance(i, Constant):
6822                if i == 0:
6823                    break
6824                else:
6825                    tmp = mul_inputs[0]
6826                    mul_inputs[0] = mul_inputs[i]
6827                    mul_inputs[i] = tmp
6828                    break
6829        mul_neg = T.mul(*mul_inputs)
6830
6831        try:
6832            cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0],
6833                                             only_process_constants=True)
6834        except NotScalarConstantError:
6835            return False
6836
6837        if len(mul_neg.owner.inputs) == 2:
6838            if (not mul_neg.owner.inputs[1].owner or
6839                    mul_neg.owner.inputs[1].owner.op != T.sqr):
6840                return False
6841            sqr = mul_neg.owner.inputs[1]
6842            x = sqr.owner.inputs[0]
6843        elif len(mul_neg.owner.inputs) == 3:
6844            if mul_neg.owner.inputs[1] is not mul_neg.owner.inputs[2]:
6845                return False
6846            x = mul_neg.owner.inputs[1]
6847        else:
6848            return False
6849
6850        if cst2 != -1:
6851            if (not erfc_x.owner or erfc_x.owner.op != T.mul or
6852                    len(erfc_x.owner.inputs) != 2):
6853                # todo implement that case
6854                return False
6855            if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
6856                return False
6857
6858            x = erfc_x
6859            try:
6860                cst = get_scalar_constant_value(erfc_x.owner.inputs[0],
6861                                                only_process_constants=True)
6862            except NotScalarConstantError:
6863                return False
6864            if cst2 != -cst * 2:
6865                return False
6866
6867            # The constant is valid. Must check that the
6868        elif erfc_x is not x:
6869            return False
6870
6871    else:
6872        return False
6873
6874    if hasattr(node.tag, 'local_grad_log_erfc_neg'):
6875        # We use that flag to don't apply the optimization recursively
6876        return False
6877
6878    # we move the y outside the div.
6879    true_div_no_mul = T.true_div(exp, erfc)
6880    true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True
6881
6882    # aaron value
6883    stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) +
6884                  3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) *
6885                  T.cast(T.sqrt(np.pi), dtype=x.dtype))
6886
6887    if x.dtype == 'float32' or x.dtype == 'float16':
6888        threshold = 9.3
6889        # threshold = 10.1
6890    elif x.dtype == 'float64':
6891        threshold = 26.641747557
6892    ret = T.switch(x < threshold, true_div_no_mul, stab_value)
6893    if y:
6894        ret = T.mul(ret, *y)
6895    ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan
6896    return [ret]
6897    """
6898The libm used for the test is amdlibm
6899    #([y*]exp(-(x**2)))/erfc(x) # The mul is optional
6900#exp(x**2)/erfc(-x) => when x>threashold,
6901#-x*(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))*sqrt(pi) for float64:
6902#threshold=26.63 see below for float32: threshold=9.3 see below TODO
6903#remove the contraint that there are only 2 inputs to mul TODO: should
6904#we cast numpy.pi to x.dtype?
6905
6906#float32 threshold 9.3 as the approximation is more precise at that
6907#point and more stable.
6908import numpy, scipy.special
6909r = numpy.arange(9,10.06,.01)
6910
6911p64=[(numpy.exp(-(x**2)))/scipy.special.erfc(x) for x in r]
6912p32=[(numpy.exp(-(x**2)))/scipy.special.erfc(x) for x in
6913numpy.asarray(r,dtype='float32')]
6914a64=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))*numpy.sqrt(numpy.pi)
6915for x in r]
6916a32=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))
6917     * numpy.float32(numpy.sqrt(numpy.pi))
6918for x in numpy.asarray(r,dtype='float32')] for idx,(a,b,c,d,e) in
6919enumerate(zip(r,p64,p32,a64,a32)):print
6920a,b,c,d,e,c-b,e-b,numpy.absolute(c-b)<numpy.absolute(e-b)
6921
6922#, show that the value don't look stable at some point before inf.
6923for i in xrange(1,len(p32)): print r[i], p32[i]-p32[i-1]
6924
6925#float64 threshold is 26.63 the approx seam more precise at that
6926point.  r = numpy.arange(26.2,26.7,.001)
6927#scipy.special.erfc(numpy.float128(x)) don't work
6928#p128=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in
6929numpy.float128(r)] #those value have been computed with g++
6930theano/misc/erfc_stability_threshold.c && ./a.out
6931p128=numpy.float128(['46.47206725', '46.47383842', '46.47560959',
6932'46.47738076', '46.47915193', '46.48092309', '46.48269426',
6933'46.48446543', '46.48623660', '46.48800777', '46.48977894',
6934'46.49155011', '46.49332128', '46.49509245', '46.49686362',
6935'46.49863479', '46.50040596', '46.50217713', '46.50394830',
6936'46.50571947', '46.50749064', '46.50926181', '46.51103298',
6937'46.51280415', '46.51457532', '46.51634649', '46.51811766',
6938'46.51988883', '46.52166000', '46.52343118', '46.52520235',
6939'46.52697352', '46.52874469', '46.53051586', '46.53228703',
6940'46.53405820', '46.53582938', '46.53760055', '46.53937172',
6941'46.54114289', '46.54291407', '46.54468524', '46.54645641',
6942'46.54822758', '46.54999876', '46.55176993', '46.55354110',
6943'46.55531227', '46.55708345', '46.55885462', '46.56062579',
6944'46.56239697', '46.56416814', '46.56593931', '46.56771049',
6945'46.56948166', '46.57125283', '46.57302401', '46.57479518',
6946'46.57656636', '46.57833753', '46.58010871', '46.58187988',
6947'46.58365105', '46.58542223', '46.58719340', '46.58896458',
6948'46.59073575', '46.59250693', '46.59427810', '46.59604928',
6949'46.59782045', '46.59959163', '46.60136280', '46.60313398',
6950'46.60490516', '46.60667633', '46.60844751', '46.61021868',
6951'46.61198986', '46.61376104', '46.61553221', '46.61730339',
6952'46.61907456', '46.62084574', '46.62261692', '46.62438809',
6953'46.62615927', '46.62793045', '46.62970163', '46.63147280',
6954'46.63324398', '46.63501516', '46.63678633', '46.63855751',
6955'46.64032869', '46.64209987', '46.64387104', '46.64564222',
6956'46.64741340', '46.64918458', '46.65095576', '46.65272693',
6957'46.65449811', '46.65626929', '46.65804047', '46.65981165',
6958'46.66158283', '46.66335401', '46.66512519', '46.66689636',
6959'46.66866754', '46.67043872', '46.67220990', '46.67398108',
6960'46.67575226', '46.67752344', '46.67929462', '46.68106580',
6961'46.68283698', '46.68460816', '46.68637934', '46.68815052',
6962'46.68992170', '46.69169288', '46.69346406', '46.69523524',
6963'46.69700642', '46.69877760', '46.70054878', '46.70231997',
6964'46.70409115', '46.70586233', '46.70763351', '46.70940469',
6965'46.71117587', '46.71294705', '46.71471824', '46.71648942',
6966'46.71826060', '46.72003178', '46.72180296', '46.72357414',
6967'46.72534533', '46.72711651', '46.72888769', '46.73065887',
6968'46.73243006', '46.73420124', '46.73597242', '46.73774361',
6969'46.73951479', '46.74128597', '46.74305715', '46.74482834',
6970'46.74659952', '46.74837070', '46.75014189', '46.75191307',
6971'46.75368426', '46.75545544', '46.75722662', '46.75899781',
6972'46.76076899', '46.76254018', '46.76431136', '46.76608254',
6973'46.76785373', '46.76962491', '46.77139610', '46.77316728',
6974'46.77493847', '46.77670965', '46.77848084', '46.78025202',
6975'46.78202321', '46.78379439', '46.78556558', '46.78733677',
6976'46.78910795', '46.79087914', '46.79265032', '46.79442151',
6977'46.79619269', '46.79796388', '46.79973507', '46.80150625',
6978'46.80327744', '46.80504863', '46.80681981', '46.80859100',
6979'46.81036219', '46.81213337', '46.81390456', '46.81567575',
6980'46.81744693', '46.81921812', '46.82098931', '46.82276050',
6981'46.82453168', '46.82630287', '46.82807406', '46.82984525',
6982'46.83161644', '46.83338762', '46.83515881', '46.83693000',
6983'46.83870119', '46.84047238', '46.84224357', '46.84401475',
6984'46.84578594', '46.84755713', '46.84932832', '46.85109951',
6985'46.85287070', '46.85464189', '46.85641308', '46.85818427',
6986'46.85995546', '46.86172665', '46.86349784', '46.86526903',
6987'46.86704022', '46.86881141', '46.87058260', '46.87235379',
6988'46.87412498', '46.87589617', '46.87766736', '46.87943855',
6989'46.88120974', '46.88298093', '46.88475212', '46.88652331',
6990'46.88829450', '46.89006569', '46.89183688', '46.89360807',
6991'46.89537927', '46.89715046', '46.89892165', '46.90069284',
6992'46.90246403', '46.90423522', '46.90600642', '46.90777761',
6993'46.90954880', '46.91131999', '46.91309119', '46.91486238',
6994'46.91663357', '46.91840476', '46.92017596', '46.92194715',
6995'46.92371834', '46.92548953', '46.92726073', '46.92903192',
6996'46.93080311', '46.93257431', '46.93434550', '46.93611669',
6997'46.93788789', '46.93965908', '46.94143028', '46.94320147',
6998'46.94497266', '46.94674386', '46.94851505', '46.95028625',
6999'46.95205744', '46.95382864', '46.95559983', '46.95737103',
7000'46.95914222', '46.96091341', '46.96268461', '46.96445581',
7001'46.96622700', '46.96799820', '46.96976939', '46.97154059',
7002'46.97331178', '46.97508298', '46.97685417', '46.97862537',
7003'46.98039657', '46.98216776', '46.98393896', '46.98571015',
7004'46.98748135', '46.98925255', '46.99102374', '46.99279494',
7005'46.99456614', '46.99633733', '46.99810853', '46.99987973',
7006'47.00165092', '47.00342212', '47.00519332', '47.00696452',
7007'47.00873571', '47.01050691', '47.01227811', '47.01404931',
7008'47.01582050', '47.01759170', '47.01936290', '47.02113410',
7009'47.02290530', '47.02467649', '47.02644769', '47.02821889',
7010'47.02999009', '47.03176129', '47.03353249', '47.03530369',
7011'47.03707489', '47.03884608', '47.04061728', '47.04238848',
7012'47.04415968', '47.04593088', '47.04770208', '47.04947328',
7013'47.05124448', '47.05301568', '47.05478688', '47.05655808',
7014'47.05832928', '47.06010048', '47.06187168', '47.06364288',
7015'47.06541408', '47.06718528', '47.06895648', '47.07072768',
7016'47.07249888', '47.07427009', '47.07604129', '47.', '47.07958369',
7017'47.08135489', '47.08312609', '47.08489729', '47.08666850',
7018'47.08843970', '47.09021090', '47.09198210', '47.09375330',
7019'47.09552450', '47.09729571', '47.09906691', '47.10083811',
7020'47.10260931', '47.10438052', '47.10615172', '47.10792292',
7021'47.10969412', '47.11146533', '47.11323653', '47.11500773',
7022'47.11677894', '47.11855014', '47.12032134', '47.12209255',
7023'47.12386375', '47.12563495', '47.12740616', '47.12917736',
7024'47.13094857', '47.13271977', '47.13449097', '47.13626218',
7025'47.13803338', '47.13980459', '47.14157579', '47.14334700',
7026'47.14511820', '47.14688941', '47.14866061', '47.15043182',
7027'47.15220302', '47.15397423', '47.15574543', '47.15751664',
7028'47.15928784', '47.16105905', '47.16283025', '47.16460146',
7029'47.16637266', '47.16814387', '47.16991508', '47.17168628',
7030'47.17345749', '47.17522869', '47.17699990', '47.17877111',
7031'47.18054231', '47.18231352', '47.18408473', '47.18585593',
7032'47.18762714', '47.18939835', '47.19116956', '47.19294076',
7033'47.19471197', '47.19648318', '47.19825439', '47.20002559',
7034'47.20179680', '47.20356801', '47.20533922', '47.20711042',
7035'47.20888163', '47.21065284', '47.21242405', '47.21419526',
7036'47.21596647', '47.21773767', '47.21950888', '47.22128009',
7037'47.22305130', '47.22482251', '47.22659372', '47.22836493',
7038'47.23013614', '47.23190735', '47.23367855', '47.23544976',
7039'47.23722097', '47.23899218', '47.24076339', '47.24253460',
7040'47.24430581', '47.24607702', '47.24784823', '47.24961944',
7041'47.25139065', '47.25316186', '47.25493307', '47.25670429',
7042'47.25847550', '47.26024671', '47.26201792', '47.26378913',
7043'47.26556034', '47.26733155', '47.26910276', '47.27087397',
7044'47.27264518', '47.27441640', '47.27618761', '47.27795882',
7045'47.27973003', '47.28150124', '47.28327246', '47.28504367',
7046'47.28681488', '47.28858609', '47.29035730', '47.29212852',
7047'47.29389973', '47.29567094', '47.29744215', '47.29921337',
7048'47.30098458', '47.30275579', '47.30452701', '47.30629822',
7049'47.30806943', '47.30984065', '47.31161186', '47.31338307',
7050'47.31515429', '47.31692550', '47.31869671', '47.32046793',
7051'47.32223914', '47.32401036', '47.32578157', '47.32755278',
7052'47.32932400', '47.33109521', '47.33286643', '47.33463764',
7053'47.33640886', '47.33818007', '47.33995129', '47.34172250',
7054'47.34349372', '47.34526493', '47.34703615', '47.34880736',
7055'47.35057858', '47.35234979', '47.35412101', '47.35589223'])
7056p64=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in r]
7057a128=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))
7058      *numpy.float128(numpy.sqrt(numpy.pi))
7059      for x in numpy.asarray(r,dtype='float128')]
7060a64=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)+63/(7*x**8))**(-1))
7061     *numpy.sqrt(numpy.pi)
7062     for x in r] for a,b,c,d in zip(r,p128,p64,a64):print a,b,c,d,c-b,d-b
7063
7064for i in xrange(1,len(p64)): print i, 64[i]-p64[i-1]
7065   """
7066
7067
7068# ###############
7069# # Loop fusion #
7070# ###############
7071def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
7072                             maker=None):
7073    """
7074    We parametrize it to make it work for Elemwise and GpuElemwise op.
7075
7076    Parameters
7077    ----------
7078    OP
7079        GpuElemwise or Elemwise class (the one that we want to fuse)
7080    max_input_fct
7081        A function that returns the maximum number of inputs
7082        that this elemwise can take (useful for GpuElemwise).
7083        GPU kernel currently has a limit of 256 bytes for
7084        the size of all parameters passed to it. As currently
7085        we pass many information only by parameter, we must
7086        limit how many ops we fuse together to avoid busting
7087        that 256 limit.
7088
7089        On the CPU we limit to 32 input variables
7090        since that is the maximum numpy support.
7091
7092    """
7093    if maker is None:
7094        def maker(node, scalar_op):
7095            return OP(scalar_op)
7096
7097    def local_fuse(node):
7098        """
7099        As part of specialization, we fuse two consecutive elemwise Ops of the
7100        same shape.
7101
7102        For mixed dtype, we let the Composite op do the cast. It lets the C
7103        compiler do the cast.
7104        The number of dimensions is validated at call time by theano itself.
7105
7106        """
7107        # META TODO:  PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
7108        # TODO: use broadcast flag?
7109
7110        # TODO: don't do this optimization as a localOptimizer.
7111        # Analyze the graph in terms of elemwise subgraphs, and then
7112        # replace each subgraph with a Composite version.
7113
7114        # TODO: use malloc and copy to transfer arguments that don't
7115        # fit within the parameter space of 256 bytes
7116        #
7117        # TODO: Merge with multiple output to merge when an inputs
7118        # have multiple clients. This can't be done with a local
7119        # optimiser.
7120
7121        # TODO: Related: Support composites with multiple outputs
7122
7123        # TODO: Use Composite to combine Elemwise and Reduce
7124        # operations.  We have to loop over the data anyway... might
7125        # as well sum it up while we're at it (this can be trickier
7126        # than i'm making it seound here. The data-traversal should be
7127        # done contiguously, and the summing-up might not be easy or
7128        # worthwhile if the summation axis doesn't line up with a
7129        # contiguous dimension)
7130
7131        if type(node.op) is not OP:
7132            return False
7133
7134        if len(node.outputs) > 1:
7135            # We don't support the fusion for node with multiple outputs.
7136            return
7137        inputs = []  # inputs of the new Elemwise op.
7138        s_inputs = []  # inputs of the new scalar op used by the Composite.
7139        # Inputs of the new scalar op that represents the current node.
7140        s_g = []
7141
7142        # There is a hard limit of 256 bytes for the formal argument list to a
7143        # GPU kernel function.
7144        max_nb_input = max_input_fct(node)
7145        # The number of inputs to the new fused op if we do not fuse more
7146        # inputs.
7147        new_nb_input = len(node.inputs)
7148        # Did we fuse something?
7149        # Needed as we can fuse unary op that don't change the number of
7150        # inputs.
7151        # And there is a case where the inputs are the same as the current
7152        # node. That won't change the number of inputs of the new op.
7153        fused = False
7154
7155        for i in node.inputs:
7156            do_fusion = False
7157            catch = False
7158            # Will store inputs of the fused node that are not currently inputs
7159            # of the node we want to create (to avoid duplicating inputs).
7160            tmp_input = []
7161            # Same as tmp_input, but for scalars.
7162            tmp_scalar = []
7163
7164            # We should not check the number of inputs here
7165            # As fusing op don't always change the number of input.
7166            # If a variable is used as multiple into to the same node,
7167            # we still want to fusion. So we take the set.
7168            if (i.owner and
7169                    isinstance(i.owner.op, OP) and
7170                    len(set([n for n, idx in i.clients])) == 1 and
7171                    # Do not merge elemwise that don't have the same
7172                    # broadcastable pattern to don't redo duplicate
7173                    # computation due to broadcast.
7174                    i.owner.outputs[0].broadcastable ==
7175                    node.outputs[0].broadcastable):
7176                do_fusion = True
7177                try:
7178                    tmp_s_input = []
7179                    # we should not put duplicate input into s_inputs and inputs
7180                    for ii in i.owner.inputs:
7181                        if ii in inputs:
7182                            tmp_s_input.append(s_inputs[inputs.index(ii)])
7183                        elif ii in tmp_input:
7184                            tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
7185                        else:
7186                            tmp = scalar.get_scalar_type(ii.dtype).make_variable()
7187                            try:
7188                                tv = gof.op.get_test_value(ii)
7189                                if tv.size > 0:
7190                                    tmp.tag.test_value = tv.flatten()[0]
7191                                else:
7192                                    tmp.tag.test_value = tv
7193                            except AttributeError:
7194                                pass
7195                            tmp_s_input.append(tmp)
7196                            tmp_input.append(ii)
7197                            tmp_scalar.append(tmp_s_input[-1])
7198                    s_op = i.owner.op.scalar_op(*tmp_s_input,
7199                                                return_list=True)
7200
7201                    # if the scalar_op don't have a c implementation,
7202                    # we skip its fusion to allow the fusion of the
7203                    # other ops.
7204                    i.owner.op.scalar_op.c_code(s_op[0].owner,
7205                                                "test_presence_of_c_code",
7206                                                ["x" for x in i.owner.inputs],
7207                                                ["z" for z in i.owner.outputs],
7208                                                {"fail": "%(fail)s"})
7209                except MethodNotDefined:
7210                    catch = True
7211                except NotImplementedError:
7212                    catch = True
7213                if catch:
7214                    _logger.info(("%s does not implement the c_code function."
7215                                  " As well as being potentially slow, this"
7216                                  " disables loop fusion of this op.") %
7217                                 str(i.owner.op.scalar_op))
7218                    do_fusion = False
7219
7220            # Compute the number of inputs in case we fuse this input.
7221            # We subtract 1 because we replace the existing input with the new
7222            # inputs from `tmp_input`.
7223            new_nb_input_ = new_nb_input + len(tmp_input) - 1
7224
7225            # If the new input is already an input of the current node, it was
7226            # already counted when `new_nb_input` was initialized to
7227            # len(node.inputs).
7228            # This can happen when a variable is used both by the Elemwise to
7229            # fuse and the current node.
7230            for x in tmp_input:
7231                if x in node.inputs:
7232                    new_nb_input_ -= 1
7233
7234            if do_fusion and (new_nb_input_ <= max_nb_input):
7235                fused = True
7236                new_nb_input = new_nb_input_
7237                inputs.extend(tmp_input)
7238                s_inputs.extend(tmp_scalar)
7239                s_g.extend(s_op)
7240            else:
7241                # We must support the case where the same variable appear many
7242                # time in the inputs
7243                if inputs.count(i) == node.inputs.count(i):
7244                    s = s_inputs[inputs.index(i)]
7245                else:
7246                    s = scalar.get_scalar_type(i.dtype).make_variable()
7247                    try:
7248                        if theano.config.compute_test_value != 'off':
7249                            v = gof.op.get_test_value(i)
7250                            if v.size > 0:
7251                                s.tag.test_value = v.flatten()[0]
7252                    except AttributeError:
7253                        pass
7254
7255                    inputs.append(i)
7256                    s_inputs.append(s)
7257                s_g.append(s)
7258
7259        if not fused:
7260            return False
7261
7262        if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
7263            raise Exception("""Something has gone wrong with the elemwise
7264fusion optimization. We skip this optimization. You can ignore this message,
7265your code will run correctly, but may be slower.""")
7266
7267        s_new_out = node.op.scalar_op(*s_g, return_list=True)
7268        try:
7269            s_new_out[0].owner.op.c_code(s_new_out[0].owner,
7270                                         "test_presence_of_c_code",
7271                                         ["x" for x in s_g],
7272                                         ["z" for x in s_new_out],
7273                                         {"fail": "%(fail)s"})
7274        except MethodNotDefined:
7275            _logger.info(("%s does not implement the c_code function."
7276                          " As well as being potentially slow, this disables "
7277                          "loop fusion of this op.") % str(
7278                              s_new_out[0].owner.op))
7279            return False
7280        except NotImplementedError:
7281            _logger.info(("%s does not implement the c_code function. As well"
7282                          " as being potentially slow, this disables loop"
7283                          " fusion of this op.") % str(s_new_out[0].owner.op))
7284            return False
7285
7286        # create the composite op.
7287        C = scalar.Composite(s_inputs, s_new_out)
7288
7289        # create the new node.
7290        # Do not call make_node to have test_value
7291        n = maker(node, C)(*inputs).owner
7292        assert len(n.outputs) == 1
7293        assert node.outputs[0].dtype == n.outputs[0].dtype
7294
7295        if len(n.inputs) > max_nb_input:
7296            _logger.info('loop fusion failed because Op would exceed'
7297                         ' kernel argument limit.')
7298            return False
7299
7300        # we fuse as many that we can at the same time to make debug mode faster
7301        # debug mode will be faster as it won't test all intermediate step.
7302        while True:
7303            ret = local_fuse(n)
7304            if ret is not False and ret is not None:
7305                # print n,ret
7306                assert len(ret) == len(n.outputs)
7307                assert len(ret) == 1
7308                n = ret[0].owner
7309            else:
7310                break
7311
7312        return n.outputs
7313    return local_fuse
7314
7315
7316def elemwise_max_input_fct(node):
7317    # The Elemwise.perform use numpy ufunc and they are limited to 31
7318    # inputs.
7319    if not theano.config.cxx:
7320        return 31
7321    return 1024
7322
7323
7324local_elemwise_fusion = local_elemwise_fusion_op(T.Elemwise,
7325                                                 elemwise_max_input_fct)
7326
7327
7328class FusionOptimizer(Optimizer):
7329    """Graph optimizer for Fusion of elemwise operations."""
7330    def __init__(self, local_optimizer):
7331        Optimizer.__init__(self)
7332        self.optimizer = local_optimizer
7333
7334    def add_requirements(self, fgraph):
7335        fgraph.attach_feature(toolbox.ReplaceValidate())
7336
7337    def apply(self, fgraph):
7338        did_something = True
7339        nb_iter = 0
7340        nb_replacement = 0
7341        nb_inconsistency_replace = 0
7342        time_toposort = 0
7343        if fgraph.profile:
7344            validate_before = fgraph.profile.validate_time
7345            callbacks_before = fgraph.execute_callbacks_times.copy()
7346            callback_before = fgraph.execute_callbacks_time
7347        while did_something:
7348            t0 = time.time()
7349            nodelist = list(fgraph.toposort())
7350            time_toposort += time.time() - t0
7351            nodelist.reverse()
7352            did_something = False
7353            for node in nodelist:
7354                # Don't try to fuse node that have already been fused.
7355                if node in fgraph.apply_nodes:
7356                    new_outputs = self.optimizer(node)
7357                    if new_outputs:
7358                        assert len(new_outputs) == len(node.outputs)
7359                        try:
7360                            fgraph.replace_all_validate(
7361                                list(zip(node.outputs, new_outputs)),
7362                                reason=self.__class__.__name__)
7363                            did_something = True
7364                            nb_replacement += 1
7365                        except InconsistencyError:
7366                            nb_inconsistency_replace += 1
7367                            pass
7368            nb_iter += 1
7369
7370        if fgraph.profile:
7371            validate_time = fgraph.profile.validate_time - validate_before
7372            callback_time = fgraph.execute_callbacks_time - callback_before
7373            callbacks_time = {}
7374            for k, v in iteritems(fgraph.execute_callbacks_times):
7375                if k in callbacks_before:
7376                    callbacks_time[k] = v - callbacks_before[k]
7377                else:
7378                    callbacks_time[k] = v
7379        else:
7380            validate_time = None
7381            callback_time = None
7382            callbacks_time = {}
7383        return (self, nb_iter, nb_replacement,
7384                nb_inconsistency_replace,
7385                validate_time, callback_time, callbacks_time,
7386                time_toposort)
7387
7388    @staticmethod
7389    def print_profile(stream, prof, level=0):
7390        blanc = ('    ' * level)
7391        print(blanc, "FusionOptimizer", file=stream)
7392        print(blanc, " nb_iter", prof[1], file=stream)
7393        print(blanc, " nb_replacement", prof[2], file=stream)
7394        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
7395        print(blanc, " validate_time", prof[4], file=stream)
7396        print(blanc, " callback_time", prof[5], file=stream)
7397        if prof[5] > 1:
7398            print(blanc, " callbacks_time", file=stream)
7399            for i in sorted(iteritems(prof[6]), key=lambda a: a[1])[::-1]:
7400                if i[1] > 0:
7401                    print(blanc, "     ", i)
7402        print(blanc, " time_toposort", prof[7], file=stream)
7403
7404
7405def local_add_mul_fusion(node):
7406    """Fuse consecutive add or mul in one such node with more inputs.
7407
7408    It is better to fuse add/mul that way then in a Composite node as
7409    this make the inner graph of the Composite smaller. This allow to
7410    put more computation in a Composite before hitting the max
7411    recusion limit when pickling Composite.
7412
7413    """
7414    if (not isinstance(node.op, Elemwise) or
7415            not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))):
7416        return False
7417
7418    s_op = node.op.scalar_op.__class__
7419    new_inp = []
7420    fused = False
7421    nb_inputs = len(node.inputs)
7422    max_inputs = float('inf')
7423    if hasattr(node.op, 'max_inputs'):
7424        max_inputs = node.op.max_inputs(node)
7425    for inp in node.inputs:
7426        if (inp.owner and
7427                isinstance(inp.owner.op, Elemwise) and
7428                isinstance(inp.owner.op.scalar_op, s_op) and
7429                # Do not duplicate the operation.
7430                len(inp.clients) == 1 and
7431                (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs):
7432            new_inp.extend(inp.owner.inputs)
7433            fused = True
7434        else:
7435            new_inp.append(inp)
7436
7437    # We can not compare the number of inputs as Mul and Add could have
7438    # 0 or 1 inputs in some corner cases.
7439    if fused:
7440        output = node.op(*new_inp)
7441        copy_stack_trace(node.outputs[0], output)
7442
7443        # Do the recursion here to help lower the number of
7444        # FusionOptimizer iteration.
7445        if output.owner:
7446            output2 = local_add_mul_fusion(output.owner)
7447            if output2:
7448                return output2
7449        return [output]
7450
7451if config.tensor.local_elemwise_fusion:
7452    _logger.debug("enabling optimization fusion elemwise in fast_run")
7453    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
7454    fuse_seqopt = gof.SequenceDB()
7455    fuse_seqopt.register('local_add_mul_fusion',
7456                         FusionOptimizer(local_add_mul_fusion),
7457                         0, 'fast_run', 'fusion')
7458    fuse_seqopt.register('composite_elemwise_fusion',
7459                         FusionOptimizer(local_elemwise_fusion),
7460                         1, 'fast_run', 'fusion')
7461    compile.optdb.register('elemwise_fusion',
7462                           fuse_seqopt, 49,
7463                           'fast_run', 'fusion', 'local_elemwise_fusion',
7464                           'FusionOptimizer')
7465else:
7466    _logger.debug("not enabling optimization fusion elemwise in fast_run")
7467    compile.optdb.register('elemwise_fusion',
7468                           FusionOptimizer(local_elemwise_fusion), 49,
7469                           'fusion', 'local_elemwise_fusion',
7470                           'FusionOptimizer')
7471
7472
7473@register_canonicalize
7474@gof.local_optimizer([Elemwise])
7475def local_useless_composite(node):
7476    """For elemwise Composite that have multiple outputs, remove the
7477    outputs that are not used.
7478
7479    """
7480    if (not isinstance(node.op, Elemwise) or
7481            not isinstance(node.op.scalar_op, scalar.Composite)):
7482        return
7483    comp = node.op.scalar_op
7484    idx = [i for i, o_extern in enumerate(node.outputs)
7485           if o_extern.clients]
7486    if len(idx) < len(node.outputs):
7487        new_outputs = [comp.outputs[i] for i in idx]
7488        c = scalar.Composite(inputs=comp.inputs,
7489                             outputs=new_outputs)
7490        e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
7491        return dict(zip([node.outputs[i] for i in idx], e))
7492
7493# ############################
7494# # Remove consider_constant #
7495# ############################
7496
7497
7498# Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
7499# just returns the input, it should be removed from the graph to
7500@register_canonicalize('fast_compile')
7501@register_useless('fast_compile')
7502@gof.local_optimizer(None)
7503def local_view_op(node):
7504    if isinstance(node.op, theano.compile.ops.ViewOp):
7505        return node.inputs
7506
7507
7508@register_useless
7509@register_canonicalize
7510@register_stabilize
7511@register_specialize
7512@gof.local_optimizer([T.Alloc])
7513def local_merge_alloc(node):
7514    # This opt takes care of several cases:
7515    # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
7516    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
7517    # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
7518    if not isinstance(node.op, T.Alloc):
7519        return False
7520    if not node.inputs[0].owner or not isinstance(
7521            node.inputs[0].owner.op, T.Alloc):
7522        return False
7523    inputs_outer = node.inputs
7524    inputs_inner = node.inputs[0].owner.inputs
7525    dims_outer = inputs_outer[1:]
7526    dims_inner = inputs_inner[1:]
7527    dims_outer_rev = dims_outer[::-1]
7528    dims_inner_rev = dims_inner[::-1]
7529    # check if the pattern of broadcasting is matched, in the reversed ordering.
7530    # The reverse ordering is needed when an Alloc add an implicit new
7531    # broadcasted dimensions to its inputs[0]. Eg:
7532    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
7533    i = 0
7534    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
7535        if dim_inner != dim_outer:
7536            if isinstance(dim_inner, Constant) and dim_inner.data == 1:
7537                pass
7538            else:
7539                dims_outer[-1 - i] = Assert(
7540                    "You have a shape error in your graph. To see a better"
7541                    " error message and a stack trace of where in your code"
7542                    " the error is created, use the Theano flags"
7543                    " optimizer=None or optimizer=fast_compile.")(
7544                    dim_outer, T.eq(dim_outer, dim_inner))
7545        i += 1
7546    return [T.alloc(inputs_inner[0], *dims_outer)]
7547
7548
7549@register_useless('fast_compile')
7550@gof.local_optimizer([TopKOp])
7551def local_useless_topk(node):
7552    """
7553    TopKOp generates two outputs by default
7554    This opt removes the useless ones
7555
7556    """
7557    op = node.op
7558    if not isinstance(op, TopKOp):
7559        return
7560    if not (op.return_values and op.return_indices):
7561        return False
7562
7563    x, k = node.inputs
7564    ret_val = bool(node.outputs[0].clients)
7565    ret_idx = bool(node.outputs[1].clients)
7566
7567    if not (ret_val ^ ret_idx):
7568        # both true -> nothing to remove
7569        # both false -> let pruner handle
7570        return False
7571
7572    old_output = node.outputs[ret_idx]
7573    new_output = TopKOp(
7574        axis=op.axis,
7575        sorted=op.sorted,
7576        idx_dtype=op.idx_dtype,
7577        return_values=ret_val,
7578        return_indices=ret_idx)(x, k)
7579    copy_stack_trace(node.outputs[0], new_output)
7580    return {old_output: new_output}
7581