1"""
2This module provides optimizations for scan.
3The Optimization provided in this file:
4
5local opt: remove_constants_and_unused_inputs_scan,
6           constant_folding_for_scan2,
7           scan_merge_inouts
8           They are wrapped in in2out to create global opt.
9global opt: ScanInplaceOptimizer,
10            PushOutNonSeqScan,
11            PushOutSeqScan,
12            PushOutDot1,
13            ScanMerge,
14            ScanSaveMem
15
16How the are registered:
17
18optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
19scan_eqopt1 -> scan_seqopt1
20scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
21                PushOutNonSeqScan(2),
22                PushOutSeqScan(3), PushOutDot1(4)
23scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
24               This is important, as the order is important and all global
25               optimizer run before local optimizer in the order they where
26               registered. (So don't change the order we register them!)
27               If we convert to local optimizer, we must convert all of them
28               to local optimizer. But:
29               1) can ScanMerge be made local? Can we keep only this one
30               global?
31               2) ScanSaveMem assert that we remove all nodes outputs,
32                  we need to keep this.
33               3) It is ScanSaveMem suppose the the others ran before.
34                  I added an assert at one place, but didn't looked for
35                  other place.
36               4) Moving this to local opt could speed up significant this opt,
37                  as we pass frequently on all nodes in the graph for no
38                  good reason.
39               5) We register remove_constant_*  many places, as some
40                  opt create them and let this one clean up the mess.
41                  Doing it that way, make things simpler for those already
42                  complex opt.
43
44               in2out(constant_folding),
45               in2out(remove_constants_and_unused_inputs_scan1),
46               ScanMerge,
47               in2out(remove_constants_and_unused_inputs_scan2),
48               in2out(scan_merge_inouts),
49               ScanSaveMem,
50               in2out(remove_constants_and_unused_inputs_scan3)
51"""
52from __future__ import absolute_import, print_function, division
53import logging
54import copy
55from sys import maxsize
56from collections import OrderedDict
57import numpy as np
58
59import theano
60from theano import tensor, scalar
61from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
62from theano import gof
63from six import integer_types, iteritems
64from six.moves import xrange
65from theano.compile import optdb
66from theano.compile.function_module import deep_copy_op
67from theano.gof import toolbox, DestroyHandler, InconsistencyError
68from theano.gof.opt import Optimizer
69from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
70
71from theano.scan_module import scan_op
72from theano.scan_module import scan_utils
73from theano.scan_module.scan_utils import equal_computations, scan_args
74
75__docformat__ = 'restructedtext en'
76__authors__ = ("Razvan Pascanu "
77               "Frederic Bastien "
78               "James Bergstra "
79               "Pascal Lamblin "
80               "Arnaud Bergeron ")
81__copyright__ = "(c) 2010, Universite de Montreal"
82__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
83
84
85# Logging function for sending warning or info
86_logger = logging.getLogger('theano.scan_module.scan_opt')
87
88list_opt_slice = [tensor.opt.local_abs_merge,
89                  tensor.opt.local_mul_switch_sink,
90                  tensor.opt.local_upcast_elemwise_constant_inputs,
91                  tensor.opt.local_useless_switch,
92                  tensor.opt.constant_folding]
93
94
95def warning(*msg):
96    _logger.warning('WARNING theano.scan: ' + ' '.join(msg))
97
98
99def info(*msg):
100    _logger.info('INFO theano.scan: ' + ' '.join(msg))
101
102
103@gof.local_optimizer([scan_op.Scan])
104def remove_constants_and_unused_inputs_scan(node):
105    """
106    Move constants into the inner graph, and remove unused inputs.
107
108    Constants that are in the outer graph are represented by a free symbolic
109    variable in the inner graph. If we move them into the inner graph,
110    constant-folding can happen in the inner graph.
111    This is applied only on sequences and non-sequences,
112    not on initial states.
113
114    """
115    if not isinstance(node.op, scan_op.Scan):
116        return False
117    op = node.op
118    # We only need to take care of sequences and other arguments
119    st = op.n_seqs
120    st += int(sum([len(x) for x in
121                   op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
122    st += op.n_sit_sot
123    st += op.n_shared_outs
124
125    op_ins = op.inputs
126    op_outs = op.outputs
127
128    # Corresponds to the initial states, which should stay untouched.
129    # We put those variables aside, and put them back at the end.
130    out_stuff_inner = op_ins[op.n_seqs:st]
131
132    non_seqs = op_ins[st:]
133    st = (op.n_seqs +
134          op.n_mit_mot +
135          op.n_mit_sot +
136          op.n_sit_sot +
137          op.n_nit_sot +
138          op.n_shared_outs + 1)
139    outer_non_seqs = node.inputs[st:]
140    out_stuff_outer = node.inputs[1 + op.n_seqs:st]
141
142    # To replace constants in the outer graph by clones in the inner graph
143    givens = OrderedDict()
144    # All the inputs of the inner graph of the new scan
145    nw_inner = []
146    # Same for the outer graph, initialized w/ number of steps
147    nw_outer = [node.inputs[0]]
148
149    all_ins = gof.graph.inputs(op_outs)
150    for idx in xrange(op.n_seqs):
151        node_inp = node.inputs[idx + 1]
152        if (isinstance(node_inp, tensor.TensorConstant) and
153                node_inp.tag.unique_value is not None):
154            try:
155                # This works if input is a constant that has all entries
156                # equal
157                givens[op_ins[idx]] = node_inp.clone()[0]
158            except TypeError:
159                pass
160        elif op_ins[idx] in all_ins:
161            # Check for identical other sequence
162            identical_seqs = [x for x in nw_outer
163                              if scan_utils.equal_computations(
164                                  [x], [node_inp])]
165            if identical_seqs:
166                index = node.inputs.index(identical_seqs[0]) - 1
167                givens[op_ins[idx]] = op_ins[index]
168            else:
169                nw_inner.append(op_ins[idx])
170                nw_outer.append(node_inp)
171
172    nw_n_seqs = len(nw_inner)
173    # Add outputs stuff
174    nw_inner += out_stuff_inner
175    nw_outer += out_stuff_outer
176
177    # Look through non sequences
178    nw_inner_nonseq = []
179    nw_outer_nonseq = []
180    for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)):
181        if isinstance(nw_out, tensor.Constant):
182            givens[nw_in] = nw_out.clone()
183        elif nw_in in all_ins:
184            # Indices of elements of nw_outer_nonseq that are equivalent
185            # to nw_out.
186            identical_nonseq_idx = [
187                i for (i, x) in enumerate(nw_outer_nonseq)
188                if scan_utils.equal_computations([x], [nw_out])]
189            if identical_nonseq_idx:
190                givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
191            else:
192                nw_inner_nonseq.append(nw_in)
193                nw_outer_nonseq.append(nw_out)
194
195    nw_inner.extend(nw_inner_nonseq)
196    nw_outer.extend(nw_outer_nonseq)
197
198    if len(nw_inner) != len(op_ins):
199        op_outs = scan_utils.clone(op_outs, replace=givens)
200        nw_info = copy.deepcopy(op.info)
201        nw_info['n_seqs'] = nw_n_seqs
202        # DEBUG CHECK
203        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
204        nw_outs = nwScan(*nw_outer, **dict(return_list=True))
205        return OrderedDict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
206    else:
207        return False
208
209
210# This is a global opt for historical reason
211# It should be possible to change it to a local opt.
212class PushOutNonSeqScan(gof.Optimizer):
213    """
214    A global optimizer for pushing out the variables inside the scan that depend
215    only on non-sequences.
216    """
217
218    def __init__(self):
219        gof.Optimizer.__init__(self)
220
221    def add_requirements(self, fgraph):
222        fgraph.attach_feature(gof.toolbox.ReplaceValidate())
223
224    def apply(self, fgraph):
225        nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
226                                                               scan_op.Scan)]
227        for node in nodelist:
228            self.process_node(fgraph, node)
229
230    def process_node(self, fgraph, node):
231        """
232        IMPORTANT NOTE: This function uses set and dictionary data structures.
233        By default they are not ordered for efficiency reasons. Take care
234        and make sure of changing them with their Ordered counterparts if you
235        need to iterate over these variables.
236
237        """
238        # this flag tells if there was any change during the last iterations
239        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
240            node.op.inputs, node.op.outputs)
241
242        local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
243                                                         clean_outputs)
244        local_fgraph_outs_set = set(clean_outputs)
245        local_fgraph_outs_map = dict([(v, k) for k, v in
246                                      enumerate(clean_outputs)])
247
248        to_remove_set = set()
249        to_replace_set = set()
250        to_replace_map = OrderedDict()
251
252        def add_to_replace(y):
253            to_replace_set.add(y)
254            to_replace_map[y] = add_to_replace.n
255            add_to_replace.n += 1
256        add_to_replace.n = 0
257
258        replace_with_in = []
259        replace_with_out = []
260
261        op = node.op
262        # Construct the list of non_sequences to simplify a few things
263        inner_non_seqs = op.inner_non_seqs(clean_inputs)
264        inner_non_seqs_set = set(inner_non_seqs)
265        inner_non_seqs_map = dict([(v, k) for k, v in
266                                   enumerate(inner_non_seqs)])
267
268        outer_non_seqs = op.outer_non_seqs(node.inputs)
269
270        inner_seqs = op.inner_seqs(clean_inputs)
271        outer_seqs = op.outer_seqs(node.inputs)
272
273        assert len(inner_non_seqs) == len(outer_non_seqs)
274        assert len(inner_seqs) == len(outer_seqs)
275
276        for nd in local_fgraph_topo:
277            if (  # we haven't already looked at this node
278                    nd not in to_remove_set and
279                    all([((x in inner_non_seqs_set) or
280                        (x.owner in to_remove_set) or
281                        isinstance(x, tensor.Constant))
282                        for x in nd.inputs]) and
283                    # we can do this because the assumption is that a
284                    # viewOp or deepCopyOp will be just at the end of the
285                    # function and not somewhere in the middle ..
286                    not isinstance(nd.op, theano.compile.ViewOp) and
287                    not isinstance(nd.op, theano.compile.DeepCopyOp)):
288
289                # We have a candidate node to removable
290                # Step 1. Reconstruct it on outside
291                to_remove_set.add(nd)
292                outside_ins = []
293                for x in nd.inputs:
294                    if x in inner_non_seqs_set:
295                        _idx = inner_non_seqs_map[x]
296                        outside_ins.append(outer_non_seqs[_idx])
297                    elif x in to_replace_set:
298                        outside_ins.append(replace_with_out[to_replace_map[x]])
299                    elif isinstance(x, theano.Constant):
300                        outside_ins.append(x.clone())
301                    else:
302                        raise Exception(
303                            ('Error in the `scan_pushout_non_seq_'
304                             'operations`. The optimization tries '
305                             'to move some computation fron scan '
306                             'which is not allowed to move. Report '
307                             'this on theano-users list'), x)
308                outside_ins = [x.type.filter_variable(y) for x, y in
309                               zip(nd.inputs, outside_ins)]
310
311                # Do not call make_node for test_value
312                nw_outer_node = nd.op(*outside_ins,
313                                      **dict(return_list=True))[0].owner
314
315                # Step 2. Create variables for replacements
316                for idx, y in enumerate(nd.outputs):
317                    y_place_holder = scan_utils.safe_new(y, '_replace')
318                    add_to_replace(y)
319                    replace_with_in.append(y_place_holder)
320                    assert isinstance(y, type(nw_outer_node.outputs[idx]))
321                    replace_with_out.append(nw_outer_node.outputs[idx])
322
323        # We need to check all candidate replacements and choose those that
324        # make sense for us
325        # Step 1. which elements of `to_replace` are used by remaining
326        # components of the inner function
327        clean_to_replace = []
328        clean_replace_with_in = []
329        clean_replace_with_out = []
330        existent_nodes = [nd for nd in local_fgraph_topo
331                          if nd not in to_remove_set]
332        existent_nodes_set = set(existent_nodes)
333
334        to_keep_set = set([])
335        for nd in existent_nodes:
336            to_keep_set.update(nd.inputs)
337
338        for out, idx in to_replace_map.items():
339            if (  # If types are different, conversion Op will be inserted,
340                    # and it may trigger an infinite loop.
341                    replace_with_in[idx].type == out.type and
342                    out in to_keep_set and
343                    out.owner not in existent_nodes_set):
344                clean_to_replace.append(out)
345                clean_replace_with_in.append(replace_with_in[idx])
346                clean_replace_with_out.append(replace_with_out[idx])
347
348        if len(clean_to_replace) > 0:
349            # We can finally put an end to all this madness
350            givens = OrderedDict()
351            nw_outer = []
352            nw_inner = []
353            for to_repl, repl_in, repl_out in zip(clean_to_replace,
354                                                  clean_replace_with_in,
355                                                  clean_replace_with_out):
356                if isinstance(repl_out, theano.Constant):
357                    repl_in = repl_out.clone()
358                else:
359                    nw_inner.append(repl_in)
360                    nw_outer.append(repl_out)
361                givens[to_repl] = repl_in
362
363            op_outs = scan_utils.clone(clean_outputs, replace=givens)
364            op_ins = clean_inputs + nw_inner
365
366            # Reconstruct node
367            nwScan = scan_op.Scan(op_ins, op_outs, op.info)
368
369            # Do not call make_node for test_value
370            nw_node = nwScan(*(node.inputs + nw_outer),
371                             **dict(return_list=True))[0].owner
372
373            fgraph.replace_all_validate_remove(
374                list(zip(node.outputs, nw_node.outputs)),
375                remove=[node],
376                reason='scanOp_pushout_nonseqs_ops')
377            return True
378        elif not to_keep_set:
379            # Nothing in the inner graph should be kept
380            replace_with = OrderedDict()
381            for out, idx in to_replace_map.items():
382                if out in local_fgraph_outs_set:
383                    x = node.outputs[local_fgraph_outs_map[out]]
384                    y = replace_with_out[idx]
385                    shape = [shp for shp in y.shape]
386                    replace_with[x] = tensor.alloc(y,
387                                                   node.inputs[0],
388                                                   *shape)
389
390            # We need to add one extra dimension to the outputs
391            # because the scan op expects for a tensor3, to which an
392            # subtensor is applied that takes only the last element
393            if replace_with:
394                if len(node.outputs) == len(replace_with):
395                    # Every output of the node has a replacement, the Scan
396                    # node can be removed from the graph
397                    fgraph.replace_all_validate_remove(
398                        replace_with.items(),
399                        remove=[node],
400                        reason='scanOp_pushout_nonseqs_ops')
401                else:
402                    # The node has some outputs for which no replacement has
403                    # been established. This can occur for outputs that are
404                    # not produced by apply nodes (since the optimizations
405                    # only visits apply nodes) such as constants or inputs
406                    # passed directly as outputs. The replacements can be
407                    # performed but the Scan node can't be removed at this
408                    # point.
409                    fgraph.replace_all_validate(
410                        replace_with.items(),
411                        reason='scanOp_pushout_nonseqs_ops')
412
413        else:
414            return False
415
416
417# This is a global opt for historical reason
418# It should be possible to change it to a local opt.
419class PushOutSeqScan(gof.Optimizer):
420    """
421    A global optimizer for pushing out the variables inside the
422    scan that depend only on constants and sequences.
423    """
424
425    def __init__(self):
426        gof.Optimizer.__init__(self)
427
428    def add_requirements(self, fgraph):
429        fgraph.attach_feature(gof.toolbox.ReplaceValidate())
430
431    def apply(self, fgraph):
432        nodelist = [x for x in fgraph.toposort()
433                    if isinstance(x.op, scan_op.Scan)]
434        for node in nodelist:
435            self.process_node(fgraph, node)
436
437    def process_node(self, fgraph, node):
438        """
439        IMPORTANT NOTE: This function uses set and dictionary data structure.
440        By default they are not ordered for efficiency reasons. Take care
441        and make sure of changing them to Ordered versions if you need to
442        iterate over those variables.
443
444        """
445        # this flag tells if there was any change during the last iterations
446        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
447            node.op.inputs, node.op.outputs)
448
449        local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
450                                                         clean_outputs)
451        local_fgraph_outs_set = set(clean_outputs)
452        local_fgraph_outs_map = dict([(v, k) for k, v in
453                                      enumerate(clean_outputs)])
454
455        to_remove_set = set()
456        to_replace_set = set()
457        to_replace_map = OrderedDict()
458
459        def add_to_replace(y):
460            to_replace_set.add(y)
461            to_replace_map[y] = add_to_replace.n
462            add_to_replace.n += 1
463        add_to_replace.n = 0
464
465        replace_with_in = []
466        replace_with_out = []
467
468        op = node.op
469        # Construct the list of non_sequences to simplify a few things
470        inner_non_seqs = op.inner_non_seqs(clean_inputs)
471        inner_non_seqs_set = set(inner_non_seqs)
472        inner_non_seqs_map = dict([(v, k) for k, v in
473                                   enumerate(inner_non_seqs)])
474
475        outer_non_seqs = op.outer_non_seqs(node.inputs)
476        inner_seqs = op.inner_seqs(clean_inputs)
477        inner_seqs_set = set(inner_seqs)
478        inner_seqs_map = dict([(v, k) for k, v in
479                               enumerate(inner_seqs)])
480
481        outer_seqs = op.outer_seqs(node.inputs)
482        assert len(inner_non_seqs) == len(outer_non_seqs)
483        assert len(inner_seqs) == len(outer_seqs)
484
485        for nd in local_fgraph_topo:
486            if (nd not in to_remove_set and
487                all([(x in inner_non_seqs_set) or
488                     (x.owner in to_remove_set) or
489                     isinstance(x, tensor.Constant) or
490                     (x in inner_seqs_set) for x in nd.inputs]) and
491                    isinstance(nd.op, theano.tensor.Elemwise)):
492
493                outside_ins = []
494                depends_on_seqs = False
495
496                for x in nd.inputs:
497                    if x in inner_non_seqs_set:
498                        _idx = inner_non_seqs_map[x]
499                        outside_ins.append(outer_non_seqs[_idx])
500                    elif x in inner_seqs_set:
501                        outside_ins.append(outer_seqs[inner_seqs_map[x]])
502                        depends_on_seqs = True
503                    elif x in to_replace_set:
504                        outside_ins.append(replace_with_out[
505                            to_replace_map[x]])
506                        depends_on_seqs = True
507                    elif isinstance(x, theano.Constant):
508                        outside_ins.append(x.clone())
509                    else:
510                        raise Exception(
511                            ('Error in the `scan_pushout_seq_'
512                             'operations`. The optimization tries '
513                             'to move some computation fron scan '
514                             'which is not allowed to move. Report '
515                             'this on theano-users list'), x)
516
517                if not depends_on_seqs:
518                    # Removing this node from the inner graph of scan
519                    # should be handled by the PushOutNonSeqScan
520                    # optimization. The current optimization only tries
521                    # to pull sequence-dependant computation out of
522                    # scan.
523                    continue
524
525                to_remove_set.add(nd)
526
527                # Do not call make_node for test_value
528                nw_outer_node = nd.op(*outside_ins,
529                                      **dict(return_list=True))[0].owner
530
531                # Step 2. Create variables for replacements
532                for idx, y in enumerate(nd.outputs):
533                    y_place_holder = scan_utils.safe_new(y, '_replace')
534                    add_to_replace(y)
535                    replace_with_in.append(y_place_holder)
536                    replace_with_out.append(nw_outer_node.outputs[idx])
537
538            elif (nd not in to_remove_set and
539                  isinstance(nd.op, theano.tensor.DimShuffle) and
540                  (nd.inputs[0] in inner_seqs_set or
541                   nd.inputs[0].owner in to_remove_set)):
542
543                to_remove_set.add(nd)
544                x = nd.inputs[0]
545                if x in inner_seqs_set:
546                    outside_ins = outer_seqs[inner_seqs_map[x]]
547                elif x in to_replace_set:
548                    outside_ins = replace_with_out[to_replace_map[x]]
549                new_ord = (0,)
550                for old_ord in nd.op.new_order:
551                    if (old_ord == 'x'):
552                        new_ord += (old_ord,)
553                    else:
554                        new_ord += (old_ord + 1,)
555                new_outer = outside_ins.dimshuffle(new_ord)
556                y = nd.outputs[0]
557                y_place_holder = scan_utils.safe_new(y, '_replace')
558                add_to_replace(y)
559                replace_with_in.append(y_place_holder)
560                replace_with_out.append(new_outer)
561
562                if hasattr(new_outer.tag, "test_value"):
563                    new_sh = new_outer.tag.test_value.shape
564                    ref_sh = (outside_ins.tag.test_value.shape[0],)
565                    ref_sh += nd.outputs[0].tag.test_value.shape
566                    assert new_sh == ref_sh
567
568        # We need to check all candidate replacements and choose those that
569        # make sense for us
570        # Step 1. which elements of `to_replace` are used by remaining
571        # components of the inner function
572        clean_to_replace = []
573        clean_replace_with_in = []
574        clean_replace_with_out = []
575
576        existent_nodes = [nd for nd in local_fgraph_topo
577                          if nd not in to_remove_set]
578        existent_nodes_set = set(existent_nodes)
579
580        to_keep_set = set([])
581        for nd in existent_nodes:
582            to_keep_set.update(nd.inputs)
583
584        for out, idx in to_replace_map.items():
585            if (out in to_keep_set and out.owner not in existent_nodes_set and
586                # If types are different, conversion Op will be inserted,
587                # and it may trigger an infinite loop.
588                    replace_with_in[idx].type == out.type):
589
590                clean_to_replace.append(out)
591                clean_replace_with_in.append(replace_with_in[idx])
592                clean_replace_with_out.append(replace_with_out[idx])
593
594        if len(clean_to_replace) > 0:
595            # We can finally put an end to all this madness
596            givens = OrderedDict()
597            nw_outer = []
598            nw_inner = []
599            for to_repl, repl_in, repl_out in zip(clean_to_replace,
600                                                  clean_replace_with_in,
601                                                  clean_replace_with_out):
602                if isinstance(repl_out, theano.Constant):
603                    repl_in = repl_out.clone()
604                else:
605                    nw_inner.append(repl_in)
606                    nw_outer.append(repl_out)
607
608                givens[to_repl] = repl_in
609
610            op_outs = scan_utils.clone(clean_outputs, replace=givens)
611            op_ins = nw_inner + clean_inputs
612
613            # Reconstruct node
614            nw_info = op.info.copy()
615            nw_info['n_seqs'] += len(nw_inner)
616            nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
617            # Do not call make_node for test_value
618            nw_node = nwScan(*(node.inputs[:1] + nw_outer + node.inputs[1:]),
619                             **dict(return_list=True))[0].owner
620
621            fgraph.replace_all_validate_remove(
622                list(zip(node.outputs, nw_node.outputs)),
623                remove=[node],
624                reason='scanOp_pushout_seqs_ops')
625            return True
626        elif (not to_keep_set and
627              not op.as_while and
628              not op.outer_mitmot(node)):
629            # Nothing in the inner graph should be kept
630            replace_with = OrderedDict()
631            for out, idx in to_replace_map.items():
632                if out in local_fgraph_outs_set:
633                    x = node.outputs[local_fgraph_outs_map[out]]
634                    _y = replace_with_out[idx]
635                    ls = clean_outputs
636                    if out in op.inner_mitsot_outs(ls):
637                        odx = op.inner_mitsot_outs(ls).index(out)
638                        inp = op.outer_mitsot(node)[odx]
639                        st = abs(np.min(op.mitsot_taps()))
640                        y = tensor.set_subtensor(inp[st:], _y)
641                    elif out in op.inner_sitsot_outs(ls):
642                        odx = op.inner_sitsot_outs(ls).index(out)
643                        inp = op.outer_sitsot(node)[odx]
644                        y = tensor.set_subtensor(inp[1:], _y)
645                    elif out in op.inner_nitsot_outs(ls):
646                        y = _y
647                    else:
648                        y = _y[-1]
649                    replace_with[x] = y
650
651            # We need to add one extra dimension to the outputs
652            if replace_with and len(replace_with) == len(node.outputs):
653                fgraph.replace_all_validate_remove(
654                    list(replace_with.items()),
655                    remove=[node],
656                    reason='scanOp_pushout_seqs_ops')
657                return True
658        else:
659            return False
660
661
662class PushOutScanOutput(gof.Optimizer):
663    """
664    This is an optimization that can push operations performed
665    at the end of the inner graph of scan to outside of scan.
666    """
667
668    def __init__(self):
669        gof.Optimizer.__init__(self)
670
671    def add_requirements(self, fgraph):
672        fgraph.attach_feature(gof.toolbox.ReplaceValidate())
673
674    def apply(self, fgraph):
675        # Don't perform the optimization on as_while scans. Because these scans
676        # don't run for a predetermined number of steps, handling them is
677        # more complicated and this optimization doesn't support it at the
678        # moment.
679        nodelist = [x for x in fgraph.toposort()
680                    if (isinstance(x.op, scan_op.Scan) and
681                        not x.op.as_while)]
682        for node in nodelist:
683            # Process the node as long as something gets optimized
684            while node is not None:
685                node = self.process_node(fgraph, node)
686
687    def process_node(self, fgraph, node):
688
689        op = node.op
690
691        # Use scan_args to parse the inputs and outputs of scan for ease of
692        # use
693        args = scan_args(node.inputs, node.outputs,
694                         op.inputs, op.outputs, op.info)
695
696        new_scan_node = None
697        clients = {}
698        local_fgraph_topo = theano.gof.graph.io_toposort(args.inner_inputs,
699                                                         args.inner_outputs,
700                                                         clients=clients)
701
702        for nd in local_fgraph_topo:
703            if (isinstance(nd.op, theano.tensor.elemwise.Elemwise) and
704                    isinstance(nd.op.scalar_op, scalar.Add) and
705                    nd.out in args.inner_out_sit_sot and
706                    self.inner_sitsot_only_last_step_used(nd.out, args)):
707
708                # Ensure that one of the input to the add is the output of
709                # the add from a previous iteration of the inner function
710                sitsot_idx = args.inner_out_sit_sot.index(nd.out)
711                if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
712
713                    # Ensure that the other input to the add is a dot product
714                    # between 2 matrices which will become a tensor3 and a
715                    # matrix if pushed outside of the scan. Also make sure
716                    # that the output of the Dot is ONLY used by the 'add'
717                    # otherwise doing a Dot in the outer graph will only
718                    # duplicate computation.
719
720                    sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
721                                                    sitsot_idx])
722
723                    # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
724                    dot_in_idx = 1 - sitsot_in_idx
725
726                    dot_input = nd.inputs[dot_in_idx]
727
728                    if (dot_input.owner is not None and
729                        isinstance(dot_input.owner.op, theano.tensor.Dot) and
730                        len(clients[dot_input]) == 1 and
731                        dot_input.owner.inputs[0].ndim == 2 and
732                        dot_input.owner.inputs[1].ndim == 2 and
733                        self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3 and
734                            self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3):
735
736                        # The optimization can be be applied in this case.
737
738                        # Move out of scan the two inputs to the Dot and
739                        # perform a dot outside of scan on these two inputs
740                        inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
741                        (outer_dot_inputs,
742                         new_scan_node,
743                         new_scan_args) = \
744                            self.push_out_inner_vars(fgraph, inner_dot_inputs,
745                                                     node, args)
746
747                        # Collapse some of the dimensions of the tensors
748                        # so that they become matrices. This is because a
749                        # dot is usually faster on two large matrices than
750                        # a bunch of small ones
751                        outer_dot_inputs[0] = theano.tensor.flatten(
752                            outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2)
753
754                        shape_input1 = theano.tensor.shape(outer_dot_inputs[1])
755                        outer_dot_inputs[1] =\
756                            outer_dot_inputs[1].reshape((shape_input1[0] *
757                                                         shape_input1[1],
758                                                         shape_input1[2]))
759
760                        # Perform the dot on the newly obtained matrices and
761                        # add the initial value
762                        outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
763                        init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
764                        replacement = outer_dot_output + init_value
765
766                        # Alter the outer graph to use the output of the
767                        # external Dot instead of the output of scan
768                        # Modify the outer graph to add the outer Dot
769                        outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
770                        subtensor_node = outer_sitsot.clients[0][0]
771                        outer_sitsot_last_step = subtensor_node.outputs[0]
772
773                        fgraph.replace_all([
774                            (outer_sitsot_last_step, replacement)],
775                            reason="scanOp_pushout_output")
776
777                        break
778        return new_scan_node
779
780    def inner_sitsot_only_last_step_used(self, var, scan_args):
781        """
782        Given a inner nit_sot output of scan, return True iff the outer
783        nit_sot output has only one client and that client is a Subtensor
784        instance that takes only the last step (last element along the first
785        axis).
786
787        """
788        idx = scan_args.inner_out_sit_sot.index(var)
789        outer_var = scan_args.outer_out_sit_sot[idx]
790
791        if len(outer_var.clients) == 1:
792            client = outer_var.clients[0][0]
793            if (client != 'output' and isinstance(client.op,
794                                                  theano.tensor.Subtensor)):
795                lst = theano.tensor.subtensor.get_idx_list(
796                    client.inputs, client.op.idx_list)
797                if (len(lst) == 1 and
798                        theano.tensor.extract_constant(lst[0]) == -1):
799                    return True
800
801        return False
802
803    def get_outer_ndim(self, var, scan_args):
804
805        # Given a variable, determine the number of dimension it would have if
806        # it was pushed out of scan
807        if (var in scan_args.inner_in_non_seqs or
808                isinstance(var, theano.Constant)):
809
810            outer_ndim = var.ndim
811        else:
812            outer_ndim = var.ndim + 1
813
814        return outer_ndim
815
816    def push_out_inner_vars(self, fgraph, inner_vars, old_scan_node,
817                            old_scan_args):
818
819        outer_vars = [None] * len(inner_vars)
820        new_scan_node = old_scan_node
821        new_scan_args = old_scan_args
822
823        # For the inner_vars that already exist in the outer graph,
824        # simply obtain a reference to them
825        for idx in range(len(inner_vars)):
826
827            var = inner_vars[idx]
828
829            if var in old_scan_args.inner_in_seqs:
830                idx_seq = old_scan_args.inner_in_seqs.index(var)
831                outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq]
832
833            elif var in old_scan_args.inner_in_non_seqs:
834                idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
835                outer_vars[idx] = old_scan_args.outer_in_non_seqs[idx_non_seq]
836
837            elif isinstance(var, theano.Constant):
838                outer_vars[idx] = var.clone()
839
840            elif var in old_scan_args.inner_out_nit_sot:
841                idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
842                outer_vars[idx] = old_scan_args.outer_out_nit_sot[idx_nitsot]
843
844        # For the inner_vars that don't already exist in the outer graph, add
845        # them as new nitsot outputs to the scan node.
846        idx_add_as_nitsots = [i for i in range(len(outer_vars))
847                              if outer_vars[i] is None]
848        add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
849
850        if len(add_as_nitsots) > 0:
851
852            new_scan_node = self.add_nitsot_outputs(fgraph, old_scan_node,
853                                                    old_scan_args,
854                                                    add_as_nitsots)
855
856            new_scan_args = scan_args(new_scan_node.inputs,
857                                      new_scan_node.outputs,
858                                      new_scan_node.op.inputs,
859                                      new_scan_node.op.outputs,
860                                      new_scan_node.op.info)
861
862            new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots):]
863            for i in range(len(new_outs)):
864                outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
865
866        return outer_vars, new_scan_node, new_scan_args
867
868    def add_nitsot_outputs(self, fgraph, old_scan_node,
869                           old_scan_args, new_outputs_inner):
870
871        nb_new_outs = len(new_outputs_inner)
872
873        # Create the initial values for the new nitsot outputs
874        # (the initial value is the nb of steps to store. For a nistot,
875        # it should be the number of steps performed by scan)
876        new_nitsots_initial_value = [old_scan_node.inputs[0]
877                                     for i in range(nb_new_outs)]
878
879        # Create the scan_args corresponding to the new scan op to
880        # create
881        new_scan_args = copy.copy(old_scan_args)
882        new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
883        new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
884
885        # Create the scan op from the scan_args
886        new_scan_op = scan_op.Scan(new_scan_args.inner_inputs,
887                                   new_scan_args.inner_outputs,
888                                   new_scan_args.info)
889
890        # Create the Apply node for the scan op
891        new_scan_node = new_scan_op(*new_scan_args.outer_inputs,
892                                    **dict(return_list=True))[0].owner
893
894        # Modify the outer graph to make sure the outputs of the new scan are
895        # used instead of the outputs of the old scan
896        new_node_new_outputs_idx = (len(old_scan_args.outer_outputs) -
897                                    len(old_scan_args.outer_out_shared))
898
899        new_node_old_outputs = (
900            new_scan_node.outputs[:new_node_new_outputs_idx] +
901            new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])
902
903        fgraph.replace_all_validate_remove(
904            list(zip(old_scan_node.outputs, new_node_old_outputs)),
905            remove=[old_scan_node],
906            reason='scanOp_pushout_output')
907
908        return new_scan_node
909
910
911class ScanInplaceOptimizer(Optimizer):
912    """
913    Graph optimizer for Scan (makes it run inplace).
914
915    """
916
917    def __init__(self, typeInfer=None, gpua_flag=False):
918        Optimizer.__init__(self)
919        self.typeInfer = typeInfer
920        self.gpua_flag = gpua_flag
921
922    def add_requirements(self, fgraph):
923        fgraph.attach_feature(toolbox.ReplaceValidate())
924        fgraph.attach_feature(DestroyHandler())
925
926    def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
927        """Attempts to replace a Scan node by one which computes the specified
928        outputs inplace.
929
930        Parameters
931        ----------
932        fgraph : FunctionGraph
933            Function graph in which to attempt the replacement
934        node : Apply node
935            Scan node to replace by an inplace version
936        output_indices : list of integers
937            Indices of the outputs to attempt to compute inplace
938        alloc_ops : list of Op classes
939            Classes that represent operation that allocate new memory and
940            that the optimization should duplicate so it can operate inplace
941            on them.
942        """
943
944        op = node.op
945
946        info = copy.deepcopy(op.info)
947        if 'destroy_map' not in info:
948            info['destroy_map'] = OrderedDict()
949
950        for out_idx in output_indices:
951            info['destroy_map'][out_idx] = [out_idx + 1 + op.info['n_seqs']]
952
953        # inputs corresponding to sequences and n_steps
954        ls_begin = node.inputs[:1 + op.n_seqs]
955        ls = op.outer_mitmot(node.inputs)
956        ls += op.outer_mitsot(node.inputs)
957        ls += op.outer_sitsot(node.inputs)
958        ls_end = op.outer_shared(node.inputs)
959        ls_end += op.outer_nitsot(node.inputs)
960        ls_end += op.outer_non_seqs(node.inputs)
961
962        # In `ls`, duplicate any input which has more then one client and is
963        # the output of an eligible allocation op
964        for i in range(len(ls)):
965            inp = ls[i]
966            if (len(inp.clients) > 1 and inp.owner and
967                    isinstance(inp.owner.op, alloc_ops)):
968                ls[i] = inp.owner.op(*inp.owner.inputs)
969
970        n_outs = len(ls)
971        for idx in xrange(n_outs):
972            if ls[idx] in ls[:idx]:
973                ls[idx] = deep_copy_op(ls[idx])
974
975        inputs = ls_begin + ls + ls_end
976        if self.typeInfer is None:
977            typeConstructor = None
978        else:
979            typeConstructor = self.typeInfer(node)
980
981        new_op = scan_op.Scan(op.inputs,
982                              op.outputs,
983                              info,
984                              typeConstructor=typeConstructor)
985
986        # Do not call make_node for test_value
987        new_outs = new_op(*inputs, **dict(return_list=True))
988        try:
989            fgraph.replace_all_validate_remove(
990                list(zip(node.outputs, new_outs)),
991                remove=[node],
992                reason='scanOp_make_inplace')
993            return new_outs[0].owner
994        except InconsistencyError:
995            # Failed moving output to be computed inplace
996            return node
997
998    def apply(self, fgraph):
999
1000        # Depending on the value of gpua_flag, get the list of memory
1001        # allocation ops that the optimization should be able to
1002        # handle
1003        alloc_ops = (Alloc, AllocEmpty)
1004        if self.gpua_flag:
1005            # gpuarray might be imported but not its GpuAlloc and
1006            # GpuAllopEmpty ops.
1007            try:
1008                alloc_ops += (theano.gpuarray.GpuAlloc,
1009                              theano.gpuarray.GpuAllocEmpty)
1010            except Exception:
1011                pass
1012
1013        nodes = fgraph.toposort()[::-1]
1014        scan_nodes = [x for x in nodes
1015                      if (isinstance(x.op, scan_op.Scan) and
1016                          x.op.info['gpua'] == self.gpua_flag)]
1017        for scan_idx in xrange(len(scan_nodes)):
1018
1019            # First attempt to make the Scan compute inplace every recurrent
1020            # output that seems like it could be computed inplace. If that
1021            # fails, go through these outputs individually, trying each of
1022            # them.
1023            original_node = scan_nodes[scan_idx]
1024            op = original_node.op
1025            n_outs = (op.info['n_mit_mot'] +
1026                      op.info['n_mit_sot'] +
1027                      op.info['n_sit_sot'])
1028
1029            # Generate a list of outputs on which the node could potentially
1030            # operate inplace.
1031            out_indices = []
1032            for out_idx in range(n_outs):
1033                inp_idx = 1 + op.n_seqs + out_idx
1034                inp = original_node.inputs[inp_idx]
1035
1036                # If the input is from an eligible allocation node, attempt to
1037                # be inplace on it, even if other nodes are modifying it
1038                # inplace.
1039                if inp.owner and isinstance(inp.owner.op, alloc_ops):
1040                    out_indices.append(out_idx)
1041                    continue
1042
1043                # If the input is not from an eligible allocation node, only
1044                # attempt to be inplace on it if nothing else is currently
1045                # inplace on it.
1046                input_used_inplace = False
1047                for c in original_node.inputs[inp_idx].clients:
1048                    client = c[0]
1049
1050                    # Get the indices of this client's inputs on which it
1051                    # operates inplace
1052                    if hasattr(client.op, 'destroy_map'):
1053                        # This flattens the content of destroy_map.values()
1054                        # which is a list of lists
1055                        inplace_inp_indices = sum(client.op.destroy_map.values(), [])
1056
1057                        inplace_inps = [client.inputs[i] for i in inplace_inp_indices]
1058                        if original_node.inputs[inp_idx] in inplace_inps:
1059                            input_used_inplace = True
1060                            break
1061
1062                if not input_used_inplace:
1063                    out_indices.append(out_idx)
1064
1065            node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
1066                                             out_indices, alloc_ops)
1067
1068            if node is original_node:
1069                # Making the scan compute all plausible recurrent outputs
1070                # inplace has failed. Attempt all plausible recurrent output
1071                # individually.
1072                for pos in out_indices:
1073                    node = self.attempt_scan_inplace(fgraph, node, [pos],
1074                                                     alloc_ops)
1075
1076
1077class ScanSaveMem(gof.Optimizer):
1078    """
1079    Graph Optimizer that reduces scan memory consumption.
1080
1081    """
1082
1083    def __init__(self):
1084        gof.Optimizer.__init__(self)
1085
1086    def add_requirements(self, fgraph):
1087        fgraph.attach_feature(gof.toolbox.ReplaceValidate())
1088
1089    def process_node(self, fgraph, node):
1090
1091        # helpful functions
1092        def select_min(x, y):
1093            if x is None:
1094                return y
1095            if y is None:
1096                return x
1097            return tensor.minimum(x, y)
1098
1099        def select_max(x, y):
1100            if x is None:
1101                return y
1102            if y is None:
1103                return x
1104            return tensor.maximum(x, y)
1105
1106        def sanitize(x):
1107            if x is None:
1108                return None
1109            else:
1110                return tensor.as_tensor_variable(x)
1111
1112        if hasattr(fgraph, 'shape_feature'):
1113            shape_of = node.fgraph.shape_feature.shape_of
1114        else:
1115            # Each access to shape_of is in a try..except block in order to
1116            # use a default version when the variable is not in the shape_of
1117            # dictionary.
1118            shape_of = OrderedDict()
1119        # 1. Initialization of variables
1120        # Note 1) We do not actually care about outputs representing shared
1121        # variables (those have no intermediate values) so it is safer to
1122        # ignore them and not change them in any way. To simplify the
1123        # optimizations I construct the variable ``c_outs`` ( that counts
1124        # outputs up to those we care) and the list ``init_l`` which for any
1125        # output we care says the length of its initial state. Note that
1126        # defining ``init_l`` for mit_mot sequences is a bit trickier but
1127        # it is safe to set it to 0
1128        op = node.op
1129        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
1130
1131        init_l = [0 for x in xrange(op.n_mit_mot)]
1132        init_l += [abs(min(v)) for v in op.tap_array[op.n_mit_mot:]]
1133        init_l += [0 for x in xrange(op.n_nit_sot)]
1134        # 2. Check the clients of each output and see for how many steps
1135        # does scan need to run
1136
1137        # This comparison checks if there is any uncounted output, which
1138        # can only be an output corresponding to a shared variable
1139
1140        # 2.1 Initialize
1141        # global_nsteps is a dictionary having two fields ( 'real' deals
1142        # with int values, 'sym' with symbolic ones) or None
1143        # given that a scan op has k outputs o_1, .. o_k and each
1144        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
1145        # global_nsteps is None if any of the clients is different
1146        # from a subtensor or its real and sym field equal to
1147        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
1148        # index(step) for any output scan actually needs to compute
1149        # In other words n_steps should be equal to this maximal !
1150        # Note: if we have a shared variable that gets updated at every step
1151        # of the loop, reducing the number of steps will affect the the
1152        # value of the shared variable after the loop so we need not to
1153        # change the number of steps in that case. To do this we set
1154        # global_nsteps to None which is seen as a flag that nothing needs
1155        # to be done
1156        assert len(node.outputs) >= c_outs
1157        if len(node.outputs) == c_outs:
1158            global_nsteps = {'real': -1, 'sym': []}
1159        else:
1160            global_nsteps = None
1161
1162        # Keeps track of the original slices that each client represent
1163        slices = [None for o in node.outputs]
1164
1165        # A list for each output indicating how many intermediate values
1166        # should be stored. If negative it means none of the intermediate
1167        # values (i.e. the output can be removed since it is not used
1168        # afterwards in the computations), if 0 it means that all
1169        # intermediate values are required, otherwise is up to that number
1170        # of intermediate values
1171        # Note that for mit_mot outputs and shared outputs we can not change
1172        # the number of intermediate steps stored without affecting the
1173        # result of the op
1174        store_steps = [0 for o in xrange(op.n_mit_mot)]
1175        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
1176        # Flag that says if an input has changed and we need to do something
1177        # or not
1178        flag_store = False
1179
1180        # 2.2 Loop over the clients
1181        for i, out in enumerate(node.outputs[:c_outs]):
1182            # look at all its clients
1183            slices[i] = []
1184            for cl, _ in out.clients:
1185
1186                # 2.1 outputs of the function
1187                # => output needs all its intermediate values
1188                if type(cl) == str:
1189                    # if the node is actually an output, then
1190                    # we need to store the entire thing
1191                    global_nsteps = None
1192                    slices[i] = None
1193                    break
1194                # 2.2 non-subtensor nodes
1195                # => output needs all its intermediate values
1196                elif not isinstance(cl.op, tensor.Subtensor):
1197                    global_nsteps = None
1198                    slices[i] = None
1199                    break
1200                # 2.3 subtensor nodes
1201                # => output might need to store just a subset of its values
1202                else:
1203                    # 2.3.1 extract idx list of subtensor
1204                    this_slice = tensor.get_idx_list(cl.inputs,
1205                                                     cl.op.idx_list)
1206                    if this_slice is None:
1207                        # if unable to extract idx_list
1208                        # => outputs needs all its intermediate values
1209                        global_nsteps = None
1210                        slices[i] = None
1211                        break
1212
1213                    # 2.3.2 extract the begin/end of the first dimension
1214                    if i >= op.n_mit_mot:
1215                        try:
1216                            length = shape_of[out][0]
1217                        except KeyError:
1218                            length = node.inputs[0] + init_l[i]
1219                    else:
1220                        try:
1221                            length = shape_of[out][0]
1222                        except KeyError:
1223                            length = out.shape[0]
1224                    cf_slice = tensor.get_canonical_form_slice(
1225                        this_slice[0], length)
1226                    slices[i] += [(cf_slice, this_slice)]
1227
1228                    if (isinstance(this_slice[0], slice) and
1229                            this_slice[0].stop is None):
1230                        global_nsteps = None
1231                    if isinstance(cf_slice[0], slice):
1232                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
1233                    else:
1234                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
1235                    if stop == maxsize or stop == length:
1236                        stop = None
1237                    else:
1238                        # there is a **gotcha** here ! Namely, scan returns an
1239                        # array that contains the initial state of the output
1240                        # as well. Which means that if have a initial state of
1241                        # length 3, and you look for 5 steps you get an output
1242                        # y of length 8. If you only use y[:5], this does not
1243                        # mean that you only need to loop for 5 steps but
1244                        # actually only for 2 steps ( the first 3 are the
1245                        # initial state)
1246                        stop = stop - init_l[i]
1247
1248                    # 2.3.3 we might get away with less number of steps
1249                    if stop is not None and global_nsteps is not None:
1250                        # yes if it is a tensor
1251                        if isinstance(stop, tensor.Variable):
1252                            global_nsteps['sym'] += [stop]
1253                        # not if it is maxsize
1254                        elif (type(stop) in integer_types and
1255                              stop == maxsize):
1256                            global_nsteps = None
1257                        # yes if it is a int k, 0 < k < maxsize
1258                        elif (type(stop) in integer_types and
1259                              global_nsteps['real'] < stop):
1260                            global_nsteps['real'] = stop
1261                        # yes if it is a int k, 0 < k < maxsize
1262                        elif (type(stop) in integer_types and stop > 0):
1263                            pass
1264                        # not otherwise
1265                        else:
1266                            global_nsteps = None
1267
1268        # 2.3. Analyze global_nsteps to figure out for how many steps scan
1269        # needs to iterate
1270        if global_nsteps is not None:
1271            nw_steps = node.inputs[0]
1272
1273            # there are some symbolic tensors that limit the number of
1274            # steps
1275            if len(global_nsteps['sym']) == 0:
1276                sym_steps = None
1277            else:
1278                sym_steps = global_nsteps['sym'][0]
1279                for c in global_nsteps['sym'][1:]:
1280                    sym_steps = tensor.maximum(sym_steps, c)
1281
1282            if global_nsteps['real'] >= 0:
1283                real_steps = global_nsteps['real']
1284            else:
1285                real_steps = None
1286            nw_steps = select_min(select_max(sym_steps, real_steps),
1287                                  node.inputs[0])
1288
1289            # Make sure the ScanSaveMem optimization never makes the new
1290            # number of steps to be 0 (this could happen, for instance, if
1291            # the optimization detects that the outputs of the Scan go through
1292            # subtensor nodes that end up taking no elements) because Scan with
1293            # 0 iterations are not supported. Make sure the new number of steps
1294            # is at least 1.
1295            nw_steps = select_max(nw_steps, 1)
1296        else:
1297            nw_steps = node.inputs[0]
1298            global_nsteps = None
1299
1300        # 2.4 Loop over the clients again now looking just to see how many
1301        # intermediate steps to store
1302        for i, out in enumerate(node.outputs[:c_outs]):
1303            # look at all its clients
1304            for cl, _ in out.clients:
1305                if type(cl) == str:
1306                    store_steps[i] = 0
1307                    break
1308                elif not isinstance(cl.op, tensor.Subtensor):
1309                    store_steps[i] = 0
1310                    break
1311                else:
1312                    this_slice = tensor.get_idx_list(cl.inputs,
1313                                                     cl.op.idx_list)
1314                    if this_slice is None:
1315                        store_steps[i] = 0
1316                        break
1317
1318                    if (isinstance(this_slice[0], slice) and
1319                            this_slice[0].start is None):
1320                        store_steps[i] = 0
1321                        break
1322
1323                    if i > op.n_mit_mot:
1324                        length = node.inputs[0] + init_l[i]
1325                    else:
1326                        try:
1327                            length = shape_of[out][0]
1328                        except KeyError:
1329                            length = out.shape[0]
1330                    cf_slice = tensor.get_canonical_form_slice(
1331                        this_slice[0], length)
1332
1333                    if isinstance(cf_slice[0], slice):
1334                        start = tensor.basic.extract_constant(
1335                            cf_slice[0].start)
1336                    else:
1337                        start = tensor.basic.extract_constant(cf_slice[0])
1338                    if start == 0 or store_steps[i] == 0:
1339                        store_steps[i] = 0
1340                    else:
1341                        # The "+ 1" is because of the memory pre-allocation
1342                        # mechanism used to in the Scan op to reduce overhead.
1343                        # To prevent aliasing between the inputs and outputs
1344                        # of recurrent states, it requires that the buffer be
1345                        # large enough to that, the new state and the oldest
1346                        # tap needed don't occupy the sample place in the
1347                        # circular buffer. For now, this only needs to be done
1348                        # for mitsots and sitsots (because mitmots are not
1349                        # currently supported by the mechanism) and only if
1350                        # the pre-allocation mechanism is activated.
1351                        prealloc_outs = theano.config.scan.allow_output_prealloc
1352
1353                        first_mitsot_idx = node.op.n_mit_mot
1354                        last_sitsot_idx = (node.op.n_mit_mot +
1355                                           node.op.n_mit_sot +
1356                                           node.op.n_sit_sot - 1)
1357                        preallocable_output = (first_mitsot_idx <= i <= last_sitsot_idx)
1358
1359                        if (prealloc_outs and preallocable_output):
1360                            pval = select_max(nw_steps - start + init_l[i],
1361                                              init_l[i] + 1)
1362                        else:
1363                            pval = select_max(nw_steps - start + init_l[i],
1364                                              init_l[i])
1365
1366                        if store_steps[i] != -1:
1367                            pval = select_max(pval, store_steps[i])
1368
1369                        # TODO: Simplify the number of steps needed.
1370                        # FB: This need good testing, left to later.
1371                        #     call get_scalar_constant_value()? it can
1372                        # return python/numpy scalar or np.ndarray
1373                        # currently.
1374                        # pval = pre_greedy_local_optimizer(list_opt_slice,
1375                        #                                  pval)
1376                        # pval = pre_constant_merge([pval])[0]
1377                        # if (isinstance(pval, theano.tensor.TensorConstant)
1378                        # and
1379                        #    pval.dtype.startswith('int')):
1380                        #    try:
1381                        #        pval = int(pval.data)
1382                        #    except Exception:
1383                        #        pass
1384
1385                        store_steps[i] = pval
1386                        flag_store = True
1387
1388        orphane_outs = [i for i, x in enumerate(store_steps)
1389                        if (type(x) is int) and (x < 0)]
1390        flag_store = flag_store or (len(orphane_outs) > 0)
1391        # 3. is there anything to change ?
1392        if (flag_store or global_nsteps is not None):
1393            # 3.1 initialize inputs for the new scan
1394            old_outputs = []
1395            nw_inputs = list(node.inputs)
1396            nw_inputs[0] = nw_steps
1397
1398            # 3.2 check orphane outputs to see if we can eliminate any
1399            required, not_required = scan_utils.scan_can_remove_outs(
1400                node.op, orphane_outs)
1401            # 3.3. compose replace pairs for those nodes that need not
1402            # to store everything in memory ( or ar orphane and required
1403            # by the inner function .. )
1404            replaced_outs = []
1405            offset = 1 + op.n_seqs + op.n_mit_mot
1406            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
1407                i = idx + op.n_mit_mot
1408                if not(type(_val) is int and _val <= 0 and i not in required):
1409
1410                    if idx + op.n_mit_mot in required:
1411                        val = 1
1412                    else:
1413                        val = _val
1414                    # If the memory for this output has been pre-allocated
1415                    # before going into the scan op (by an alloc node)
1416                    if idx < op.n_mit_sot + op.n_sit_sot:
1417                        # In case the input is still an alloc node, we
1418                        # actually have two options:
1419                        #   a) the input is a set_subtensor, in that case we
1420                        #      can replace the initial tensor by a slice,
1421                        #   b) it is not, and we simply take a slice of it.
1422                        # TODO: commit change below with Razvan
1423                        if (nw_inputs[offset + idx].owner and
1424                            isinstance(nw_inputs[offset + idx].owner.op,
1425                                       tensor.IncSubtensor) and
1426                            isinstance(
1427                                nw_inputs[offset + idx].owner.op.idx_list[0],
1428                                slice)):
1429
1430                            assert isinstance(nw_inputs[offset + idx].owner.op,
1431                                              tensor.IncSubtensor)
1432                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
1433                            cval = tensor.as_tensor_variable(val)
1434                            initl = tensor.as_tensor_variable(init_l[i])
1435                            tmp_idx = tensor.switch(cval < initl,
1436                                                    cval + initl,
1437                                                    cval - initl)
1438                            tmp = pre_greedy_local_optimizer(list_opt_slice,
1439                                                             tmp_idx)
1440                            tmp = pre_constant_merge([tmp])[0]
1441
1442                            nw_input = scan_utils.expand_empty(_nw_input, tmp)
1443                        else:
1444                            tmp = tensor.as_tensor_variable(val)
1445                            initl = tensor.as_tensor_variable(init_l[i])
1446                            tmp = tensor.maximum(tmp, initl)
1447                            tmp = pre_greedy_local_optimizer(list_opt_slice,
1448                                                             tmp)
1449                            tmp = pre_constant_merge([tmp])[0]
1450                            nw_input = nw_inputs[offset + idx][:tmp]
1451
1452                        nw_inputs[offset + idx] = nw_input
1453                        replaced_outs.append(op.n_mit_mot + idx)
1454                        odx = op.n_mit_mot + idx
1455                        old_outputs += [(odx, [x[0].outputs[0] for x in
1456                                        node.outputs[odx].clients])]
1457                    # If there is no memory pre-allocated for this output
1458                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
1459
1460                        pos = (op.n_mit_mot + idx + op.n_seqs +
1461                               1 + op.n_shared_outs)
1462                        if nw_inputs[pos] == node.inputs[0]:
1463                            nw_inputs[pos] = val
1464                        odx = op.n_mit_mot + idx
1465                        replaced_outs.append(odx)
1466                        old_outputs += [(odx, [x[0].outputs[0] for x in
1467                                        node.outputs[odx].clients])]
1468            # 3.4. Recompute inputs for everything else based on the new
1469            # number of steps
1470            if global_nsteps is not None:
1471                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
1472                    if val == 0:
1473                        # val == 0 means that we want to keep all intermediate
1474                        # results for that state, including the initial values.
1475                        if idx < op.n_mit_sot + op.n_sit_sot:
1476                            in_idx = offset + idx
1477                            # Number of steps in the initial state
1478                            initl = init_l[op.n_mit_mot + idx]
1479
1480                            # If the initial buffer has the form
1481                            # inc_subtensor(zeros(...)[...], _nw_input)
1482                            # we want to make the zeros tensor as small as
1483                            # possible (nw_steps + initl), and call
1484                            # inc_subtensor on that instead.
1485                            # Otherwise, simply take 0:(nw_steps+initl).
1486                            if ((nw_inputs[in_idx].owner and
1487                                 isinstance(nw_inputs[in_idx].owner.op,
1488                                            tensor.IncSubtensor) and
1489                                 isinstance(
1490                                     nw_inputs[in_idx].owner.op.idx_list[0],
1491                                     slice))):
1492                                _nw_input = nw_inputs[in_idx].owner.inputs[1]
1493                                nw_input = scan_utils.expand_empty(_nw_input,
1494                                                                   nw_steps)
1495                                nw_inputs[in_idx] = nw_input
1496                            else:
1497                                nw_input = nw_inputs[in_idx][:(initl + nw_steps)]
1498
1499                        elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
1500                            in_idx = offset + idx + op.n_shared_outs
1501                            if nw_inputs[in_idx] == node.inputs[0]:
1502                                nw_inputs[in_idx] = nw_steps
1503
1504            # 3.5 Remove unwanted orphane outputs
1505            (inps, outs, info, node_ins, compress_map) = \
1506                scan_utils.compress_outs(op, not_required, nw_inputs)
1507            inv_compress_map = OrderedDict()
1508            for k, v in iteritems(compress_map):
1509                inv_compress_map[v] = k
1510
1511            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
1512                        node_ins]
1513            node_ins = pre_constant_merge(node_ins)
1514            # 3.6 Compose the new scan
1515            # TODO: currently we don't support scan with 0 step. So
1516            # don't create one.
1517            # For test, mark that savemem have optimized this node
1518            info['_scan_savemem_visited'] = True
1519            if theano.tensor.extract_constant(node_ins[0]) == 0:
1520                return
1521
1522            # Do not call make_node for test_value
1523            new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
1524                                                      **dict(return_list=True))
1525
1526            old_new = []
1527            # 3.7 Get replace pairs for those outputs that do not change
1528            # the number of intermediate steps stored
1529            for idx, sl in enumerate(slices):
1530                if global_nsteps and sl is not None and store_steps[idx] == 0:
1531                    for hdx, cl in enumerate(node.outputs[idx].clients):
1532                        cnf_slice, old_slices = sl[hdx]
1533                        # Sanitize the nw_slice by converting ints back into
1534                        # constants :) I only need to do this for the first
1535                        # slice since that is the only slice
1536
1537                        if isinstance(cnf_slice[0], slice):
1538                            fslice = slice(
1539                                sanitize(cnf_slice[0].start),
1540                                sanitize(cnf_slice[0].stop),
1541                                sanitize(cnf_slice[0].step))
1542                        else:
1543                            fslice = sanitize(cnf_slice[0])
1544
1545                        nw_slice = (fslice,) + tuple(old_slices[1:])
1546                        nw_pos = inv_compress_map[idx]
1547
1548                        subtens = tensor.Subtensor(nw_slice)
1549                        # slice inputs
1550                        sl_ins = tensor.Subtensor.collapse(
1551                            nw_slice,
1552                            lambda entry: isinstance(entry,
1553                                                     tensor.Variable))
1554                        new_o = subtens(new_outs[nw_pos], *sl_ins)
1555                        if new_o.ndim > 0:
1556                            new_o = new_o[::cnf_slice[1]]
1557                        replaced_outs.append(idx)
1558                        old_new += [(cl[0].outputs[0], new_o)]
1559            # 3.8. Get replace pairs for those outputs that change
1560            # the number of stored intermediate steps
1561            for pos, old_outs in old_outputs:
1562                if len(old_outs) > 0:
1563                    nw_pos = compress_map[pos]
1564                    for k, old in enumerate(old_outs):
1565                        # Get the correct slice
1566                        cnf_slice, old_slices = slices[pos][k]
1567                        if type(cnf_slice[0]) is slice:
1568                            start = (cnf_slice[0].start - nw_steps -
1569                                     init_l[pos] + store_steps[pos])
1570                            if (cnf_slice[0].stop is not None and
1571                                    cnf_slice[0].stop != maxsize):
1572                                stop = (cnf_slice[0].stop - nw_steps -
1573                                        init_l[pos] + store_steps[pos])
1574                            else:
1575                                stop = None
1576                            nw_slice = ((slice(sanitize(start),
1577                                               sanitize(stop),
1578                                               sanitize(cnf_slice[0].step)),) +
1579                                        tuple(old_slices[1:]))
1580
1581                        else:
1582                            position = (cnf_slice[0] - nw_steps -
1583                                        init_l[pos] + store_steps[pos])
1584
1585                            nw_slice = (sanitize(position),) + tuple(
1586                                old_slices[1:])
1587                        subtens = tensor.Subtensor(nw_slice)
1588                        sl_ins = tensor.Subtensor.collapse(
1589                            nw_slice,
1590                            lambda entry: isinstance(entry,
1591                                                     tensor.Variable))
1592                        new_o = subtens(new_outs[nw_pos], *sl_ins)
1593                        if new_o.ndim > 0:
1594                            new_o = new_o[::cnf_slice[1]]
1595                        old_new += [(old, new_o)]
1596
1597            # 3.9. Get replace pairs for all other nodes
1598            if flag_store or global_nsteps is not None:
1599                for idx, o in enumerate(node.outputs):
1600                    if not (idx in replaced_outs) and idx not in not_required:
1601                        nw_pos = compress_map[idx]
1602                        old_new += [(o, new_outs[nw_pos])]
1603                # Check if the new outputs depend on the old scan node
1604                old_scan_is_used = [gof.graph.is_in_ancestors(new.owner, node)
1605                                    for old, new in old_new]
1606                if any(old_scan_is_used):
1607                    return False
1608                remove = [old.owner for (old, new) in old_new]
1609                # As Fred suggested assert that also the old node is not in
1610                # the Graph as that will make things suboptimal
1611                remove.append(node)
1612                fgraph.replace_all_validate_remove(old_new,
1613                                                   remove,
1614                                                   reason='scanOp_save_mem')
1615
1616    def apply(self, fgraph):
1617
1618        nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
1619                                                               scan_op.Scan)]
1620        for node in nodelist:
1621            self.process_node(fgraph, node)
1622
1623
1624class ScanMerge(gof.Optimizer):
1625    """
1626    Graph Optimizer that merges different scan ops.
1627
1628    """
1629
1630    def add_requirements(self, fgraph):
1631        fgraph.attach_feature(gof.toolbox.ReplaceValidate())
1632
1633    def merge(self, nodes):
1634
1635        if nodes[0].op.as_while:
1636            as_while = True
1637            condition = nodes[0].op.outputs[-1]
1638        else:
1639            as_while = False
1640
1641        info = OrderedDict()
1642        info['tap_array'] = []
1643        info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
1644        info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
1645        info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
1646        info['mit_mot_out_slices'] = []
1647        info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
1648        info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
1649        info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
1650        info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
1651        info['truncate_gradient'] = nodes[0].op.truncate_gradient
1652        info['name'] = '&'.join([nd.op.name for nd in nodes])
1653        info['mode'] = nodes[0].op.mode
1654        info['gpua'] = False
1655        info['as_while'] = as_while
1656        info['profile'] = nodes[0].op.profile
1657        info['allow_gc'] = nodes[0].op.allow_gc
1658
1659        # We keep the inner_ins and inner_outs of each original node separated.
1660        # To be able to recombine them in the right order after the clone,
1661        # we also need to split them by types (seq, mitmot, ...).
1662        # On the other hand, outer_ins, outer_outs and info are held together.
1663        inner_ins = [[] for nd in nodes]
1664        outer_ins = []
1665        inner_outs = [[] for nd in nodes]
1666        outer_outs = []
1667
1668        def rename(ls, suffix):
1669            for k in ls:
1670                if k.name:
1671                    k.name += str(suffix)
1672            return ls
1673
1674        for idx, nd in enumerate(nodes):
1675            # Seq
1676            inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
1677            outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
1678
1679        for idx, nd in enumerate(nodes):
1680            # MitMot
1681            inner_ins[idx].append(
1682                rename(nd.op.inner_mitmot(nd.op.inputs), idx))
1683            inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
1684            info['tap_array'] += nd.op.mitmot_taps()
1685            info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
1686            outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
1687            outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
1688
1689        for idx, nd in enumerate(nodes):
1690            # MitSot
1691            inner_ins[idx].append(
1692                rename(nd.op.inner_mitsot(nd.op.inputs), idx))
1693            inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
1694            info['tap_array'] += nd.op.mitsot_taps()
1695            outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
1696            outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
1697
1698        for idx, nd in enumerate(nodes):
1699            # SitSot
1700            inner_ins[idx].append(
1701                rename(nd.op.inner_sitsot(nd.op.inputs), idx))
1702            info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
1703            inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
1704            outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
1705            outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
1706
1707        for idx, nd in enumerate(nodes):
1708            # Shared
1709            inner_ins[idx].append(
1710                rename(nd.op.inner_shared(nd.op.inputs), idx))
1711            outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
1712
1713        for idx, nd in enumerate(nodes):
1714            # NitSot
1715            inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.outputs))
1716            outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
1717            outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
1718
1719        for idx, nd in enumerate(nodes):
1720            # Shared
1721            outer_outs += nd.op.outer_shared_outs(nd.outputs)
1722            inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))
1723
1724        for idx, nd in enumerate(nodes):
1725            # Non Seqs
1726            inner_ins[idx].append(
1727                rename(nd.op.inner_non_seqs(nd.op.inputs), idx))
1728            outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
1729
1730        # Add back the number of steps
1731        outer_ins = [nodes[0].inputs[0]] + outer_ins
1732
1733        if as_while:
1734            # add the condition, which was the one of nodes[0]
1735            inner_outs[0].append([condition])
1736
1737        # Clone the inner graph of each node independently
1738        for idx, nd in enumerate(nodes):
1739            # concatenate all inner_ins and inner_outs of nd
1740            flat_inner_ins = sum(inner_ins[idx], [])
1741            flat_inner_outs = sum(inner_outs[idx], [])
1742            # clone
1743            flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
1744                flat_inner_ins, flat_inner_outs)
1745            # split the new inner variables again in seq, mitmot, etc.
1746            new_inner_ins = []
1747            count = 0
1748            for nl in inner_ins[idx]:
1749                seq_len = len(nl)
1750                new_inner_ins.append(flat_inner_ins[count:(count + seq_len)])
1751                count += seq_len
1752
1753            new_inner_outs = []
1754            count = 0
1755            for nl in inner_outs[idx]:
1756                seq_len = len(nl)
1757                new_inner_outs.append(flat_inner_outs[count:(count + seq_len)])
1758                count += seq_len
1759
1760            inner_ins[idx] = new_inner_ins
1761            inner_outs[idx] = new_inner_outs
1762
1763        # Flatten inner_ins and inner_outs so that all seqs are first,
1764        # then mitmot, etc.
1765        new_inner_ins = []
1766        new_inner_outs = []
1767        nb_ins_groups = len(inner_ins[0])
1768        nb_outs_groups = len(inner_outs[0])
1769        for idx, nd in enumerate(nodes):
1770            # All inner_ins should have the same length
1771            assert len(inner_ins[idx]) == nb_ins_groups
1772
1773            # All inner_outs should have the same length, except if as_while,
1774            # in which case the first one should have one more element
1775            if as_while and idx > 0:
1776                assert len(inner_outs[idx]) == nb_outs_groups - 1
1777            else:
1778                assert len(inner_outs[idx]) == nb_outs_groups
1779
1780        for gr_idx in range(nb_ins_groups):
1781            for idx, nd in enumerate(nodes):
1782                new_inner_ins += inner_ins[idx][gr_idx]
1783
1784        for gr_idx in range(nb_outs_groups):
1785            for idx, nd in enumerate(nodes):
1786                if as_while and idx > 0 and gr_idx == (nb_outs_groups - 1):
1787                    # There is no condition on that node, skip it
1788                    pass
1789                else:
1790                    new_inner_outs += inner_outs[idx][gr_idx]
1791
1792        new_op = scan_op.Scan(new_inner_ins, new_inner_outs, info)
1793        new_outs = new_op(*outer_ins)
1794
1795        if not isinstance(new_outs, (list, tuple)):
1796            new_outs = [new_outs]
1797
1798        return list(zip(outer_outs, new_outs))
1799
1800    def belongs_to_set(self, node, set_nodes):
1801        """
1802        This function checks if node `node` belongs to `set_nodes`, in the
1803        sense that it can be merged together with every other node in
1804        `set_nodes`. In order for two nodes to be mergeable, they have to go
1805        over the same number of steps, have the same condition (if any),
1806        have the same value for truncate_gradient, and have the same mode.
1807        Questionable, we should also consider profile ?
1808
1809        """
1810        rep = set_nodes[0]
1811        if (rep.op.as_while != node.op.as_while or
1812                node.op.truncate_gradient != rep.op.truncate_gradient or
1813                node.op.mode != rep.op.mode):
1814            return False
1815
1816        nsteps = node.inputs[0]
1817        try:
1818            nsteps = int(get_scalar_constant_value(nsteps))
1819        except tensor.NotScalarConstantError:
1820            pass
1821
1822        rep_nsteps = rep.inputs[0]
1823        try:
1824            rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
1825        except tensor.NotScalarConstantError:
1826            pass
1827
1828        if nsteps != rep_nsteps:
1829            return False
1830
1831        # Check to see if it is an input of a different node
1832        for nd in set_nodes:
1833            if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node):
1834                return False
1835
1836        if not node.op.as_while:
1837            return True
1838        cond = node.op.outputs[-1]
1839        rep_cond = rep.op.outputs[-1]
1840        return scan_utils.equal_computations([cond], [rep_cond],
1841                                             node.op.inputs,
1842                                             rep.op.inputs)
1843
1844    def apply(self, fgraph):
1845        # Collect all scan nodes ordered according to toposort
1846        scan_nodes = [nd for nd in fgraph.toposort()
1847                      if isinstance(nd.op, scan_op.Scan)]
1848
1849        # All sets of possibly mergeable nodes
1850        all_sets = []
1851
1852        for nd in scan_nodes:
1853            belongs_to_set_idx = -1
1854            for pos, subset in enumerate(all_sets):
1855                if self.belongs_to_set(nd, subset):
1856                    belongs_to_set_idx = pos
1857                    # It is possible that nd belongs to more than one subset.
1858                    # For instance, if we have 3 Scan nodes X, Y and Z, if Z
1859                    # depends on the output of X, then X and Z are incompatible
1860                    # and would create different subsets, but Y could be
1861                    # compatible with both X and Z. We choose the first one.
1862                    break
1863
1864            if belongs_to_set_idx == -1:
1865                all_sets.append([nd])
1866            else:
1867                all_sets[belongs_to_set_idx].append(nd)
1868
1869        for subset in all_sets:
1870            if len(subset) > 1:
1871                proposal = self.merge(subset)
1872                fgraph.replace_all_validate_remove(proposal,
1873                                                   remove=subset,
1874                                                   reason='scanOp_merge')
1875
1876
1877def has_duplicates(l):
1878    """
1879    Returns true if l has any duplicates (according to __eq__).
1880
1881    """
1882    return len(set(l)) < len(l)
1883
1884
1885def make_equiv(lo, li):
1886    """
1887    Builds a dictionary of equivalences between inner inputs based on
1888    the equivalence of their corresponding outer inputs.
1889
1890    """
1891    seeno = OrderedDict()
1892    left = []
1893    right = []
1894    for o, i in zip(lo, li):
1895        if o in seeno:
1896            left += [i]
1897            right += [o]
1898        else:
1899            seeno[o] = i
1900    return left, right
1901
1902
1903@gof.local_optimizer([scan_op.Scan])
1904def scan_merge_inouts(node):
1905    if not isinstance(node.op, scan_op.Scan):
1906        return False
1907
1908    # Do a first pass to merge identical external inputs.
1909    # Equivalent inputs will be stored in inp_equiv, then a new
1910    # scan node created without duplicates.
1911    a = scan_args(node.inputs, node.outputs,
1912                  node.op.inputs, node.op.outputs, node.op.info)
1913
1914    inp_equiv = OrderedDict()
1915
1916    if has_duplicates(a.outer_in_seqs):
1917        new_outer_seqs = []
1918        new_inner_seqs = []
1919        for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
1920            if out_seq in new_outer_seqs:
1921                i = new_outer_seqs.index(out_seq)
1922                inp_equiv[in_seq] = new_inner_seqs[i]
1923            else:
1924                new_outer_seqs.append(out_seq)
1925                new_inner_seqs.append(in_seq)
1926        a.outer_in_seqs = new_outer_seqs
1927        a.inner_in_seqs = new_inner_seqs
1928
1929    if has_duplicates(a.outer_in_non_seqs):
1930        new_outer_nseqs = []
1931        new_inner_nseqs = []
1932        for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs):
1933            if out_nseq in new_outer_nseqs:
1934                i = new_outer_nseqs.index(out_nseq)
1935                inp_equiv[in_nseq] = new_inner_nseqs[i]
1936            else:
1937                new_outer_nseqs.append(out_nseq)
1938                new_inner_nseqs.append(in_nseq)
1939        a.outer_in_non_seqs = new_outer_nseqs
1940        a.inner_in_non_seqs = new_inner_nseqs
1941
1942    if len(inp_equiv) > 0:
1943        # do the replacement now. The rest will be left to ScanSaveMem
1944        inner_inputs = a.inner_inputs
1945        outer_inputs = a.outer_inputs
1946        info = a.info
1947        a_inner_outs = a.inner_outputs
1948        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)
1949
1950        op = scan_op.Scan(inner_inputs, inner_outputs, info)
1951        outputs = op(*outer_inputs)
1952
1953        if not isinstance(outputs, (list, tuple)):
1954            outputs = [outputs]
1955
1956        na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
1957        remove = [node]
1958    else:
1959        na = a
1960        remove = []
1961
1962    # Now that the identical external inputs have been merged, we do a new
1963    # loop in order to merge external outputs that compute the same things
1964    # from the same inputs.
1965    left = []
1966    right = []
1967
1968    if has_duplicates(na.outer_in_shared):
1969        _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
1970        left += _left
1971        right += _right
1972    if has_duplicates(na.outer_in_sit_sot):
1973        _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
1974        left += _left
1975        right += _right
1976    if has_duplicates(na.outer_in_mit_mot):
1977        seen = OrderedDict()
1978        for omm, imm, _sl in zip(na.outer_in_mit_mot,
1979                                 na.inner_in_mit_mot, na.mit_mot_in_slices):
1980            sl = tuple(_sl)
1981            if (omm, sl) in seen:
1982                simm = seen[(omm, sl)]
1983                left += imm
1984                right += simm
1985            else:
1986                seen[(omm, sl)] = imm
1987
1988    if has_duplicates(na.outer_in_mit_sot):
1989        seen = OrderedDict()
1990        for oms, ims, _sl in zip(na.outer_in_mit_sot,
1991                                 na.inner_in_mit_sot,
1992                                 na.mit_sot_in_slices):
1993            sl = tuple(_sl)
1994            if (oms, sl) in seen:
1995                sims = seen[(oms, sl)]
1996                left += ims
1997                right += sims
1998            else:
1999                seen[(oms, sl)] = ims
2000
2001    def map_out(outer_i, inner_o, outer_o, seen):
2002        # Return the outer input corresponding to an
2003        # (outer input, inner output) pair. If we see that pair for the first
2004        # time, return the provided outer output. If an equivalent pair had
2005        # already been seen, return that one instead.
2006        # Note that we need to check that the outer input match as well,
2007        # because they could have different sizes, and the corresponding
2008        # outer outputs cannot be merged in that case.
2009        for s_outer_i, s_inner_o, s_outer_o in seen:
2010            if (equal_computations([inner_o], [s_inner_o], left, right) and
2011                    outer_i == s_outer_i):
2012                return s_outer_o
2013        seen.append((outer_i, inner_o, outer_o))
2014        return outer_o
2015
2016    seen = []
2017
2018    assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
2019    assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
2020    na.outer_out_nit_sot = [
2021        map_out(outer_i, inner_o, outer_o, seen)
2022        for outer_i, inner_o, outer_o in zip(na.outer_in_nit_sot,
2023                                             na.inner_out_nit_sot,
2024                                             na.outer_out_nit_sot)]
2025
2026    seen = []
2027    assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
2028    assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot)
2029    na.outer_out_sit_sot = [
2030        map_out(outer_i, inner_o, outer_o, seen)
2031        for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot,
2032                                             na.inner_out_sit_sot,
2033                                             na.outer_out_sit_sot)]
2034
2035    seen = []
2036    assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot)
2037    assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot)
2038    na.outer_out_mit_sot = [
2039        map_out(outer_i, inner_o, outer_o, seen)
2040        for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot,
2041                                             na.inner_out_mit_sot,
2042                                             na.outer_out_mit_sot)]
2043
2044    seen = []
2045    new_outer_out_mit_mot = []
2046    assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot)
2047    assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot)
2048    assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices)
2049    for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot,
2050                                                    na.inner_out_mit_mot,
2051                                                    na.outer_out_mit_mot,
2052                                                    na.mit_mot_out_slices):
2053        for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen:
2054            if (osl == sosl and
2055                equal_computations(inner_omm, s_inner_omm, left, right) and
2056                    outer_imm == s_outer_imm):
2057
2058                new_outer_out_mit_mot.append(s_outer_omm)
2059                break
2060        else:
2061            seen.append((outer_imm, inner_omm, outer_omm, osl))
2062            new_outer_out_mit_mot.append(outer_omm)
2063    na.outer_out_mit_mot = new_outer_out_mit_mot
2064    if remove:
2065        return OrderedDict([("remove", remove)] +
2066                           list(zip(node.outputs, na.outer_outputs)))
2067    return na.outer_outputs
2068
2069
2070class PushOutDot1(gof.Optimizer):
2071    """
2072    Graph optimizer for Scan(makes it run inplace).
2073
2074    """
2075
2076    def __init__(self):
2077        Optimizer.__init__(self)
2078
2079    def add_requirements(self, fgraph):
2080        fgraph.attach_feature(toolbox.ReplaceValidate())
2081
2082    def apply(self, fgraph):
2083
2084        nodes = fgraph.toposort()
2085        scan_nodes = [x for x in nodes if (isinstance(x.op, scan_op.Scan))]
2086        for node in scan_nodes:
2087            self.apply_opt(fgraph, node)
2088
2089    def apply_opt(self, fgraph, node):
2090        # Replace pattern of the form
2091        # x[t] = x[t-1] + dot(seq[t], value)
2092        # with Sequence.reshape((-1, seq.shape[2])) \dot Value
2093        # When seq[t] is a vector/matrix  and `value` is a matrix
2094        # Note that this works when only you need X[-1] in the end
2095        # and assumes dimshuffle are applied to vectors before calling dot
2096        op = node.op
2097        sitsot_ins = op.inner_sitsot(op.inputs)
2098        sitsot_outs = op.inner_sitsot_outs(op.outputs)
2099        outer_sitsot = op.outer_sitsot_outs(node)
2100        seqs = op.inner_seqs(op.inputs)
2101        for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
2102
2103            if (out.owner and
2104                isinstance(out.owner.op, theano.tensor.Elemwise) and
2105                isinstance(out.owner.op.scalar_op, theano.scalar.Add) and
2106                inp in out.owner.inputs and
2107                len(outer_out.clients) == 1 and
2108                not isinstance(outer_out.clients[0][0], str) and
2109                isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and
2110                    outer_out.clients[0][0].op.idx_list == (-1,)):
2111
2112                x = out.owner.inputs[0]
2113                if x == inp:
2114                    x = out.owner.inputs[1]
2115                # We need to check if x is the result of an outer product
2116                if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and
2117                        x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2):
2118
2119                    # We need to check if any of the inputs are a sequence
2120                    inp1 = x.owner.inputs[0]
2121                    inp2 = x.owner.inputs[1]
2122
2123                    if inp1 in seqs or inp2 in seqs:
2124                        new_scan_out = inp1
2125
2126                        if inp1 in seqs:
2127                            new_scan_out = inp2
2128                        idx = sitsot_outs.index(out)
2129                        # We've found our pattern and need to construct a new
2130                        # scan node to replace this one. For this we need to
2131                        # replace the sit_sot output with a nit_sot output
2132
2133                        # First let us split all arguments according to their
2134                        # corresponding categories
2135
2136                        inner_seqs = op.inner_seqs(op.inputs)
2137                        outer_seqs = op.outer_seqs(node)
2138                        inner_mitmot = op.inner_mitmot(op.inputs)
2139                        outer_mitmot = op.outer_mitmot(node)
2140                        inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
2141                        inner_mitsot = op.inner_mitsot(op.inputs)
2142                        outer_mitsot = op.outer_mitsot(node)
2143                        inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
2144                        inner_sitsot = op.inner_sitsot(op.inputs)
2145                        outer_sitsot = op.outer_sitsot(node)
2146                        inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
2147                        outer_nitsot = op.outer_nitsot(node)
2148                        inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
2149                        inner_shared = op.inner_shared(op.inputs)
2150                        outer_shared = op.outer_shared(node)
2151                        inner_shared_outs = op.inner_shared_outs(op.outputs)
2152                        inner_non_seqs = op.inner_non_seqs(op.inputs)
2153                        outer_non_seqs = op.outer_non_seqs(node)
2154
2155                        new_info = op.info.copy()
2156                        st = len(op.mitmot_taps()) + len(op.mitsot_taps())
2157
2158                        new_info['tap_array'] = (
2159                            new_info['tap_array'][:st + idx] +
2160                            new_info['tap_array'][st + idx + 1:])
2161                        new_info['n_sit_sot'] -= 1
2162                        new_info['n_nit_sot'] += 1
2163                        inner_sitsot = (inner_sitsot[:idx] +
2164                                        inner_sitsot[idx + 1:])
2165                        outer_sitsot = (outer_sitsot[:idx] +
2166                                        outer_sitsot[idx + 1:])
2167                        inner_sitsot_outs = (inner_sitsot_outs[:idx] +
2168                                             inner_sitsot_outs[idx + 1:])
2169                        # add n_steps as the length
2170                        inner_nitsot_outs.append(new_scan_out)
2171
2172                        _new_inner_inps = (inner_seqs +
2173                                           inner_mitmot +
2174                                           inner_mitsot +
2175                                           inner_sitsot +
2176                                           inner_shared +
2177                                           inner_non_seqs)
2178                        _new_inner_outs = (inner_mitmot_outs +
2179                                           inner_mitsot_outs +
2180                                           inner_sitsot_outs +
2181                                           inner_nitsot_outs +
2182                                           inner_shared_outs)
2183                        new_inner_inps, new_inner_outs =\
2184                            scan_utils.reconstruct_graph(_new_inner_inps,
2185                                                         _new_inner_outs)
2186                        new_op = scan_op.Scan(new_inner_inps, new_inner_outs,
2187                                              new_info)
2188                        _scan_inputs = ([node.inputs[0]] +
2189                                        outer_seqs +
2190                                        outer_mitmot +
2191                                        outer_mitsot +
2192                                        outer_sitsot +
2193                                        outer_shared +
2194                                        outer_nitsot +
2195                                        [node.inputs[0]] +
2196                                        outer_non_seqs)
2197
2198                        new_outs = new_op(*_scan_inputs)
2199                        if type(new_outs) not in (list, tuple):
2200                            new_outs = [new_outs]
2201
2202                        # We need now to pair correctly the new outputs
2203                        # with the old ones
2204
2205                        outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs)
2206
2207                        _val = outer_nitsot_outs[-1]
2208                        outer_nitsot_outs = outer_nitsot_outs[:-1]
2209                        if inp1 in seqs:
2210                            _out_seq = op.outer_seqs(node)[seqs.index(inp1)]
2211                            # We need to clip the seq to the number of steps
2212                            _out_seq = _out_seq[:node.inputs[0]]
2213                            sh0 = _out_seq.shape[0]
2214                            sh1 = _out_seq.shape[1]
2215                            sh2 = _out_seq.shape[2]
2216                            out_seq = _out_seq.dimshuffle(1, 0, 2)
2217                            out_seq = out_seq.reshape((sh1, sh0 * sh2))
2218                            sh0 = _val.shape[0]
2219                            sh1 = _val.shape[1]
2220                            sh2 = _val.shape[2]
2221
2222                            val = _val.reshape((sh0 * sh1, sh2))
2223                            new_out = tensor.dot(out_seq, val)
2224                        else:
2225                            _out_seq = op.outer_seqs(node)[seqs.index(inp2)]
2226                            out_seq = _out_seq.reshape(
2227                                (_out_seq.shape[0] * _out_seq.shape[1],
2228                                 _out_seq.shape[2]))
2229
2230                            val = _val.dimshuffle(1, 0, 2).reshape(
2231                                (_val.shape[1],
2232                                 _val.shape[0] * _val.shape[2]))
2233                            new_out = tensor.dot(val, out_seq)
2234
2235                        pos = node.outputs.index(outer_out)
2236                        old_new = list(zip(node.outputs[:pos], new_outs[:pos]))
2237                        old = node.outputs[pos].clients[0][0].outputs[0]
2238                        old_new.append((old, new_out))
2239                        old_new += list(zip(node.outputs[pos + 1:],
2240                                            new_outs[pos:]))
2241                        fgraph.replace_all_validate_remove(
2242                            old_new, remove=[node], reason='scan_pushout_dot1')
2243
2244
2245# I've added an equilibrium because later scan optimization in the sequence
2246# can make it such that earlier optimizations should apply. However, in
2247# general I do not expect the sequence to run more then once
2248scan_eqopt1 = theano.gof.EquilibriumDB()
2249scan_seqopt1 = theano.gof.SequenceDB()
2250scan_eqopt2 = theano.gof.EquilibriumDB()
2251
2252# scan_eqopt1 before ShapeOpt at 0.1
2253# This is needed to don't have ShapeFeature trac old Scan that we
2254# don't want to reintroduce.
2255optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan')
2256# We run before blas opt at 1.7 and specialize 2.0
2257# but after stabilize at 1.5. Should we put it before stabilize?
2258optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
2259# ScanSaveMem should execute only once per node.
2260optdb.register('scanOp_save_mem', ScanSaveMem(), 1.61, 'fast_run', 'scan')
2261optdb.register('scanOp_make_inplace',
2262               ScanInplaceOptimizer(typeInfer=None),
2263               75,
2264               'fast_run',
2265               'inplace',
2266               'scan')
2267
2268scan_eqopt1.register(
2269    'all_pushout_opt', scan_seqopt1, 1, 'fast_run', 'scan')
2270
2271
2272scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
2273                      opt.in2out(remove_constants_and_unused_inputs_scan,
2274                                 ignore_newtrees=True),
2275                      1,
2276                      'remove_constants_and_unused_inputs_scan',
2277                      'fast_run',
2278                      'scan')
2279
2280
2281scan_seqopt1.register('scanOp_pushout_nonseqs_ops',
2282                      PushOutNonSeqScan(),
2283                      2,
2284                      'fast_run',
2285                      'scan')
2286
2287
2288scan_seqopt1.register('scanOp_pushout_seqs_ops',
2289                      PushOutSeqScan(),
2290                      3,
2291                      'fast_run',
2292                      'scan')
2293
2294
2295scan_seqopt1.register('scan_pushout_dot1',
2296                      PushOutDot1(),
2297                      4,
2298                      'fast_run',
2299                      'more_mem',
2300                      'scan')
2301
2302
2303scan_seqopt1.register('scanOp_pushout_output',
2304                      PushOutScanOutput(),
2305                      5,
2306                      'fast_run',
2307                      'more_mem',
2308                      'scan')
2309
2310
2311scan_eqopt2.register('constant_folding_for_scan2',
2312                     opt.in2out(tensor.opt.constant_folding,
2313                                ignore_newtrees=True),
2314                     1,
2315                     'fast_run',
2316                     'scan')
2317
2318
2319scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1',
2320                     opt.in2out(remove_constants_and_unused_inputs_scan,
2321                                ignore_newtrees=True),
2322                     2,
2323                     'remove_constants_and_unused_inputs_scan',
2324                     'fast_run',
2325                     'scan')
2326
2327
2328# after const merge but before stabilize so that we can have identity
2329# for equivalent nodes but we still have the chance to hoist stuff out
2330# of the scan later.
2331scan_eqopt2.register('scanOp_merge',
2332                     ScanMerge(),
2333                     4,
2334                     'fast_run',
2335                     'scan')
2336
2337# After Merge optimization
2338scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2',
2339                     opt.in2out(remove_constants_and_unused_inputs_scan,
2340                                ignore_newtrees=True),
2341                     5,
2342                     'remove_constants_and_unused_inputs_scan',
2343                     'fast_run',
2344                     'scan')
2345
2346scan_eqopt2.register('scanOp_merge_inouts',
2347                     opt.in2out(scan_merge_inouts, ignore_newtrees=True),
2348                     6,
2349                     'scan_merge_inouts',
2350                     'fast_run',
2351                     'scan')
2352
2353# After everything else
2354scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
2355                     opt.in2out(remove_constants_and_unused_inputs_scan,
2356                                ignore_newtrees=True),
2357                     8,
2358                     'remove_constants_and_unused_inputs_scan',
2359                     'fast_run',
2360                     'scan')
2361