1"""
2Defines the base class for optimizations as well as a certain
3amount of useful generic optimization tools.
4
5"""
6from __future__ import absolute_import, print_function, division
7
8from collections import deque, defaultdict, OrderedDict
9import contextlib
10import copy
11import inspect
12import logging
13import pdb
14import sys
15import time
16import warnings
17import traceback
18
19import numpy as np
20
21import theano
22from theano import config
23from theano.compat import izip
24from six import string_types, iteritems, itervalues, integer_types
25from six.moves import reduce
26from theano.gof import graph, op, utils, unify, toolbox
27from theano.gof.fg import InconsistencyError
28from theano.misc.ordered_set import OrderedSet
29
30from . import destroyhandler as dh
31
# Module-level logger for this optimization framework.
_logger = logging.getLogger('theano.gof.opt')
# One-element mutable counter used to hand out a unique, stable index to
# each Optimizer instance; Optimizer.__hash__ returns that index.
_optimizer_idx = [0]
34
35
def _list_of_nodes(fgraph):
    """Return the apply nodes of `fgraph` in topological order, as a list."""
    toposorted = graph.io_toposort(fgraph.inputs, fgraph.outputs)
    return list(toposorted)
38
39
class LocalMetaOptimizerSkipAssertionError(AssertionError):
    """This is an AssertionError, but instead of having the
    LocalMetaOptimizer print the error, it just skips that
    compilation.

    """
    pass
47
48
class Optimizer(object):
    """
    An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
    It can represent an optimization or in general any kind
    of transformation you could apply to an L{FunctionGraph}.

    """

    def __hash__(self):
        # Lazily assign a process-wide unique index so that the hash is
        # stable for the lifetime of the instance.
        if not hasattr(self, '_optimizer_idx'):
            self._optimizer_idx = _optimizer_idx[0]
            _optimizer_idx[0] += 1
        return self._optimizer_idx

    def __eq__(self, other):
        # added to override the  __eq__ implementation that may be inherited
        # in subclasses from other bases.
        return id(self) == id(other)

    def __ne__(self, other):
        # added to override the  __ne__ implementation that may be inherited
        # in subclasses from other bases.
        return id(self) != id(other)

    def apply(self, fgraph):
        """
        Applies the optimization to the provided L{FunctionGraph}. It may
        use all the methods defined by the L{FunctionGraph}. If the
        L{Optimizer} needs to use a certain tool, such as an
        L{InstanceFinder}, it can do so in its L{add_requirements} method.

        """
        pass

    def optimize(self, fgraph, *args, **kwargs):
        """
        This is meant as a shortcut to:
          opt.add_requirements(fgraph)
          opt.apply(fgraph)

        The theano.tensor.basic.constant.enable flag is forced to False
        while apply() runs and restored afterwards.

        """
        self.add_requirements(fgraph)
        # Read the flag *before* entering the try block: if this lookup
        # raised inside the try, the finally clause would fail with an
        # UnboundLocalError on `orig`, masking the real exception.
        orig = theano.tensor.basic.constant.enable
        try:
            theano.tensor.basic.constant.enable = False
            ret = self.apply(fgraph, *args, **kwargs)
        finally:
            theano.tensor.basic.constant.enable = orig
        return ret

    def __call__(self, fgraph):
        """
        Same as self.optimize(fgraph).

        """
        return self.optimize(fgraph)

    def add_requirements(self, fgraph):
        """
        Add features to the fgraph that are required to apply the optimization.
        For example:
          fgraph.attach_feature(History())
          fgraph.attach_feature(MyFeature())
          etc.

        """
        pass

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        # Print a one-line, indented description of this optimizer.
        name = getattr(self, 'name', None)
        print("%s%s %s id=%i" % (
            (' ' * level), self.__class__.__name__, name, id(self)), file=stream)

    @staticmethod
    def print_profile(stream, prof, level=0):
        # Subclasses that return a profile from apply() must override this.
        if prof is not None:
            raise NotImplementedError(
                "The function print_profile must be overridden if the"
                " optimizer returns profiling information.")
133
134
class FromFunctionOptimizer(Optimizer):
    """
    An Optimizer built from a plain function applied to the fgraph.

    Instances are normally created via the `optimizer` and
    `inplace_optimizer` decorators, which also set the `__name__`
    attribute that __str__ relies on.

    """
    def __init__(self, fn, requirements=()):
        # Bind the function both as `apply` (the Optimizer entry point)
        # and as `fn` (used by __call__). Previously only `apply` was
        # set, so __call__ always failed with AttributeError on self.fn.
        self.fn = fn
        self.apply = fn
        self.requirements = requirements

    def add_requirements(self, fgraph):
        # Each requirement is a callable that attaches a feature to fgraph.
        for req in self.requirements:
            req(fgraph)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        print("%s%s id=%i" % (
            ' ' * level,
            str(self.apply),
            id(self)), file=stream)

    def __call__(self, *args, **kwargs):
        # Delegate directly to the wrapped function.
        return self.fn(*args, **kwargs)

    def __str__(self):
        # __name__ is set by the decorators that construct instances.
        return self.__name__
159
160
def optimizer(f):
    """
    Decorator that wraps a graph-transformation function in a
    FromFunctionOptimizer, carrying over the function's name.

    """
    wrapped = FromFunctionOptimizer(f)
    wrapped.__name__ = f.__name__
    return wrapped
169
170
def inplace_optimizer(f):
    """
    Decorator like `optimizer`, but the resulting FromFunctionOptimizer
    also requires a DestroyHandler feature on the fgraph (needed by
    optimizations that work in place).

    """
    def attach_destroy_handler(fgraph):
        # Requirement callback: give the fgraph a DestroyHandler feature.
        fgraph.attach_feature(dh.DestroyHandler())

    wrapped = FromFunctionOptimizer(f, (attach_destroy_handler,))
    wrapped.__name__ = f.__name__
    return wrapped
182
183
class SeqOptimizer(Optimizer, list):
    # inherit from Optimizer first to get Optimizer.__hash__
    """

    Takes a list of L{Optimizer} instances and applies them
    sequentially.

    """
    @staticmethod
    def warn(exc, self, optimizer):
        """
        Default failure_callback for SeqOptimizer.

        Logs the failing optimizer and the traceback, then obeys the
        config.on_opt_error flag: 'raise' re-raises the exception, 'pdb'
        drops into the post-mortem debugger, anything else only logs.

        """
        _logger.error("SeqOptimizer apply %s" % str(optimizer))
        _logger.error("Traceback:")
        _logger.error(traceback.format_exc())
        if config.on_opt_error == 'raise':
            raise exc
        elif config.on_opt_error == 'pdb':
            pdb.post_mortem(sys.exc_info()[2])

    def __init__(self, *opts, **kw):
        """
        Parameters
        ----------
        *opts :
            The List of optimizers to be applied to a node
        failure_callback : callable or None
            Keyword only argument. A callback used when a failure
            happen during optimization.

        """
        # A single list/tuple argument is unpacked as the optimizer list.
        if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
            opts = opts[0]
        self[:] = opts
        self.failure_callback = kw.pop('failure_callback', None)
        assert len(kw) == 0

    def apply(self, fgraph):
        """

        Applies each L{Optimizer} in self in turn.

        Returns the profile tuple (also stored in self.pre_profile);
        see print_profile for its layout.

        """
        l = []  # wall-clock time spent in each sub-optimizer
        if fgraph.profile:
            validate_before = fgraph.profile.validate_time
            sub_validate_time = [validate_before]
            callbacks_before = fgraph.execute_callbacks_times.copy()
        else:
            sub_validate_time = []
            callbacks_before = []
        callback_before = fgraph.execute_callbacks_time
        nb_node_before = len(fgraph.apply_nodes)
        sub_profs = []
        nb_nodes = []

        # Provisional profile, so partial data is available if we raise.
        self.pre_profile = (
            self, l, -1, -1, nb_node_before,
            -1, sub_profs, sub_validate_time,
            nb_nodes, {})
        try:
            for optimizer in self:
                try:
                    nb_nodes_before = len(fgraph.apply_nodes)
                    t0 = time.time()
                    sub_prof = optimizer.optimize(fgraph)
                    l.append(float(time.time() - t0))
                    sub_profs.append(sub_prof)
                    nb_nodes.append((nb_nodes_before,
                                     len(fgraph.apply_nodes)))
                    if fgraph.profile:
                        sub_validate_time.append(fgraph.profile.validate_time)
                except AssertionError:
                    # do not catch Assertion failures
                    raise
                except Exception as e:
                    if self.failure_callback:
                        self.failure_callback(e, self, optimizer)
                        continue
                    else:
                        raise
        finally:
            # Always finalize the profile, even if an optimizer raised.
            if fgraph.profile:
                validate_time = fgraph.profile.validate_time - validate_before
                callbacks_time = {}
                for k, v in iteritems(fgraph.execute_callbacks_times):
                    if k in callbacks_before:
                        t = v - callbacks_before[k]
                        if t > 0:
                            callbacks_time[k] = t
                    else:
                        callbacks_time[k] = v
            else:
                validate_time = None
                callbacks_time = {}
            callback_time = fgraph.execute_callbacks_time - callback_before
            self.pre_profile = (
                self, l, validate_time, callback_time, nb_node_before,
                len(fgraph.apply_nodes), sub_profs, sub_validate_time,
                nb_nodes, callbacks_time)
        return self.pre_profile

    def __str__(self):
        return "SeqOpt(%s)" % list.__str__(self)

    def __repr__(self):
        return list.__repr__(self)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        name = getattr(self, 'name', None)
        print("%s%s %s id=%i" % (
            (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
        # This way, -1 will do all depth
        if depth != 0:
            depth -= 1
            for opt in self:
                opt.print_summary(stream, level=(level + 2), depth=depth)

    @staticmethod
    def print_profile(stream, prof, level=0):
        """
        Print a human-readable report of a profile tuple built by apply().

        """
        (opts, prof, validate_time, callback_time,
         nb_node_before, nb_node_after, sub_profs, sub_validate_time,
         nb_nodes, callbacks_time) = prof
        blanc = ('    ' * level)

        print(blanc, "SeqOptimizer", end=' ', file=stream)
        if hasattr(opts, "name"):
            print(blanc, opts.name, end=' ', file=stream)
        elif hasattr(opts, "__name__"):
            print(blanc, opts.__name__, end=' ', file=stream)
        print((" time %.3fs for %d/%d nodes"
               " before/after optimization" % (
                   sum(prof), nb_node_before, nb_node_after)), file=stream)
        print(blanc, "  %.3fs for callback" % (callback_time), file=stream)
        print(blanc, "      %.3fs for fgraph.validate()" % (validate_time),
              file=stream)
        if callback_time > 1:
            print(blanc, "  callbacks_time", file=stream)
            for i in sorted(iteritems(callbacks_time), key=lambda a: -a[1]):
                if i[1] > 0:
                    # We want to have the __str__ called, so we can't
                    # just print i.
                    print(blanc, "      ", i[0], ',', i[1], file=stream)

        if level == 0:
            print(blanc,
                  "  time      - (name, class, index, nodes before, nodes after) - validate time",
                  file=stream)
        ll = []
        for (opt, nb_n) in zip(opts, nb_nodes):
            if hasattr(opt, "__name__"):
                name = opt.__name__
            else:
                name = opt.name
            idx = opts.index(opt)
            ll.append((name, opt.__class__.__name__,
                       idx) + nb_n)
        # Sort by time, then print slowest-first.
        lll = sorted(zip(prof, ll), key=lambda a: a[0])

        for (t, opt) in lll[::-1]:
            i = opt[2]
            if sub_validate_time:
                val_time = sub_validate_time[i + 1] - sub_validate_time[i]
                print(blanc, '  %.6fs - %s - %.3fs' % (
                    t, opt, val_time), file=stream)
            else:
                print(blanc, '  %.6fs - %s' % (t, opt), file=stream)

            if sub_profs[i]:
                opts[i].print_profile(stream, sub_profs[i],
                                      level=level + 1)
        print(file=stream)

    @staticmethod
    def merge_profile(prof1, prof2):
        """
        Merge 2 profiles returned by this class' apply() fct.

        """
        new_t = []  # the time for the optimization
        new_l = []  # the optimization
        new_sub_profile = []
        # merge common(same object) opt
        for l in set(prof1[0]).intersection(set(prof2[0])):
            idx1 = prof1[0].index(l)
            idx2 = prof2[0].index(l)
            new_t.append(prof1[1][idx1] +
                         prof2[1][idx2])
            new_l.append(l)
            if hasattr(l, 'merge_profile'):
                assert len(prof1[6][idx1]) == len(prof2[6][idx2])
                new_sub_profile.append(l.merge_profile(prof1[6][idx1],
                                                       prof2[6][idx2]))
            else:
                new_sub_profile.append(None)

        # merge not common opt
        from six import StringIO
        for l in set(prof1[0]).symmetric_difference(set(prof2[0])):
            # The set trick above only work for the same object optimization
            # It don't work for equivalent optimization.
            # So we try to merge equivalent optimization here.
            new_l_names = [o.name for o in new_l]
            if l.name in new_l_names:
                idx = new_l_names.index(l.name)
                io1 = StringIO()
                io2 = StringIO()
                l.print_summary(io1)
                new_l[idx].print_summary(io2)
                # Compare with getvalue(), not read(): print_summary
                # leaves the stream position at the end, so read() would
                # always return '' and the summaries would (wrongly)
                # always compare equal.
                if io1.getvalue() == io2.getvalue():
                    if l in prof1[0]:
                        p = prof1
                    else:
                        p = prof2
                    new_t[idx] += p[1][p[0].index(l)]
                    if hasattr(l, 'merge_profile'):
                        assert len(p[6][p[0].index(l)]) == \
                            len(new_sub_profile[idx])
                        new_sub_profile[idx] = l.merge_profile(
                            new_sub_profile[idx], p[6][p[0].index(l)])
                    else:
                        new_sub_profile[idx] = None
                continue
            if l in prof1[0]:
                p = prof1
            else:
                p = prof2
            new_t.append(p[1][p[0].index(l)])
            idx = p[0].index(l)
            new_l.append(l)
            new_sub_profile.append(p[6][idx])

        new_opt = SeqOptimizer(*new_l)
        new_nb_nodes = []
        for p1, p2 in zip(prof1[8], prof2[8]):
            new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1]))
        new_nb_nodes.extend(prof1[8][len(new_nb_nodes):])
        new_nb_nodes.extend(prof2[8][len(new_nb_nodes):])

        new_callbacks_times = merge_dict(prof1[9], prof2[9])
        # We need to assert based on the name as we merge also based on
        # the name.
        assert set([l.name for l in prof1[0]]).issubset(
            set([l.name for l in new_l]))
        assert set([l.name for l in prof2[0]]).issubset(
            set([l.name for l in new_l]))
        assert len(new_t) == len(new_opt) == len(new_sub_profile)
        return (new_opt, new_t, prof1[2] + prof2[2],
                prof1[3] + prof2[3],
                -1, -1, new_sub_profile, [],
                new_nb_nodes,
                new_callbacks_times)
439
440
441class _metadict:
442    """
443    WRITEME
444
445    """
446
447    # dict that accepts unhashable keys
448    # uses an associative list
449    # for internal use only
450    def __init__(self):
451        self.d = {}
452        self.l = []
453
454    def __getitem__(self, item):
455        return self.get(item, None)
456
457    def __setitem__(self, item, value):
458        try:
459            self.d[item] = value
460        except Exception:
461            for i, (key, val) in enumerate(self.l):
462                if key == item:
463                    self.l[i] = (item, value)
464                    return
465            self.l.append((item, value))
466
467    def __delitem__(self, item):
468        try:
469            if item in self.d:
470                del self.d[item]
471                return
472        except TypeError as e:
473            assert "unhashable type" in str(e)
474        for i, (key, val) in enumerate(self.l):
475            if key == item:
476                del self.l[i]
477                return
478            raise KeyError(item)
479
480    def discard(self, item):
481        try:
482            if item in self.d:
483                del self.d[item]
484                return
485        except TypeError as e:
486            assert "unhashable type" in str(e)
487        for i, (key, val) in enumerate(self.l):
488            if key == item:
489                del self.l[i]
490                return
491
492    def get(self, item, default):
493        try:
494            return self.d[item]
495        except Exception:
496            for item2, value in self.l:
497                try:
498                    if item == item2:
499                        return value
500                    if item.equals(item2):
501                        return value
502                except Exception:
503                    if item is item2:
504                        return value
505            return default
506
507    def clear(self):
508        self.d = {}
509        self.l = []
510
511    def __str__(self):
512        return "(%s, %s)" % (self.d, self.l)
513
514
class MergeFeature(object):
    """
    Keeps track of variables in fgraph that cannot be merged together.

    That way, the MergeOptimizer can remember the result of the last merge
    pass on the fgraph.

    """
    def on_attach(self, fgraph):
        """
        Feature callback: register on `fgraph`, initialize the merge
        bookkeeping, then scan every existing node once via on_import.

        """
        assert not hasattr(fgraph, 'merge_feature')
        fgraph.merge_feature = self

        # For constants
        self.seen_constants = set()
        # variable -> signature (for constants)
        self.const_sig = _metadict()
        # signature -> variable (for constants)
        self.const_sig_inv = _metadict()

        # For all Apply nodes
        # Set of distinct (not mergeable) nodes
        self.nodes_seen = set()
        # Ordered set of distinct (not mergeable) nodes without any input
        self.noinput_nodes = OrderedSet()

        # Each element of scheduled is a list of list of (out, new_out) pairs.
        # Each list of pairs represent the substitution needed to replace all
        # the outputs of a node with the outputs of a replacement candidate.
        # Each node can have several candidates. For instance, if "node" has
        # 2 outputs, and there are 3 replacement candidates, we will have:
        # shelf.scheduled = [
        #    [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
        #     [(node.out1, cand2.out1), (node.out2, cand2.out2)],
        #     [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
        self.scheduled = []

        # List of (node, candidate) pairs, where we tried to replace node by
        # candidate, but it failed. This is used to avoid infinite loops
        # during the replacement phase.
        self.blacklist = []

        for node in fgraph.toposort():
            self.on_import(fgraph, node, "on_attach")

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        """
        Feature callback: re-examine `node` after input `i` changed from
        `r` to `new_r`, and register `new_r` if it is a constant.

        """
        # If inputs to node change, it is not guaranteed that it is distinct
        # from the other nodes in nodes_seen
        if node in self.nodes_seen:
            self.nodes_seen.discard(node)
            self.process_node(fgraph, node)

        # Since we are in on_change_input, node should have inputs.
        # NOTE(review): `node` can apparently also be a string
        # pseudo-node here (hence the isinstance guard) -- presumably
        # "output"; confirm against the FunctionGraph callback protocol.
        if not isinstance(node, string_types):
            assert node.inputs

        if isinstance(new_r, graph.Constant):
            self.process_constant(fgraph, new_r)

    def on_import(self, fgraph, node, reason):
        """
        Feature callback: process a newly imported node, as well as any
        constants among its inputs.

        """
        for c in node.inputs:
            if isinstance(c, graph.Constant):
                self.process_constant(fgraph, c)

        self.process_node(fgraph, node)

    def on_prune(self, fgraph, node, reason):
        """
        Feature callback: forget a pruned node, and drop the constant
        bookkeeping for any constant it was the last user of.

        """
        self.nodes_seen.discard(node)
        if not node.inputs:
            self.noinput_nodes.discard(node)
        for c in node.inputs:
            if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
                # This was the last node using this constant
                sig = self.const_sig[c]
                self.const_sig.discard(c)
                self.const_sig_inv.discard(sig)
                self.seen_constants.discard(id(c))

    def process_constant(self, fgraph, c):
        """
        Check if a constant can be merged, and queue that replacement.

        """
        # Identity check: the same constant object is only processed once.
        if id(c) in self.seen_constants:
            return
        sig = c.merge_signature()
        other_c = self.const_sig_inv.get(sig, None)
        if other_c is not None:
            # multiple names will clobber each other..
            # we adopt convention to keep the last name
            if c.name:
                other_c.name = c.name
            self.scheduled.append([[(c, other_c, 'merge')]])
        else:
            # this is a new constant
            self.const_sig[c] = sig
            self.const_sig_inv[sig] = c
            self.seen_constants.add(id(c))

    def process_node(self, fgraph, node):
        """
        Check if a node can be merged, and queue that replacement.

        """
        if node in self.nodes_seen:
            return

        node_has_assert = False

        # These asserts ensure that the fgraph has set the clients field
        # properly.
        # The clients should at least contain `node` itself!
        if node.inputs:
            # Take the smallest clients list. Some ops like elemwise
            # have optimization that put constant as the first inputs.
            # As constant have in general more clients than other type of nodes
            # using always inputs[0] make us look at more nodes.
            # Always pick the smallest clints list between inputs 0
            # and -1 speed up optimization.

            if len(node.inputs[0].clients) < len(node.inputs[-1].clients):
                clients = node.inputs[0].clients
            else:
                clients = node.inputs[-1].clients
            assert len(clients) > 0

            merge_candidates = [c for c, i in clients if c in self.nodes_seen]

            # Put all clients of Assert inputs (if exist) into merge_candidates
            # TODO: Deactivated for now as this cause cycle in the graph.
            # (There is a second deactivation part below.)
            # NOTE(review): the empty list below intentionally disables
            # this loop; the original iterable is kept in the comment.
            for i in []:  # node.inputs:
                if i.owner and isinstance(i.owner.op,
                                          theano.tensor.opt.Assert):
                    node_has_assert = True
                    assert_clients = [c for (c, _) in i.owner.inputs[0].clients
                                      if c in self.nodes_seen]

                    for idx in range(len(assert_clients)):
                        client = assert_clients[idx]
                        if isinstance(i.owner.op, theano.tensor.opt.Assert):
                            for c in client.outputs[0].clients:
                                if c[0] in self.nodes_seen:
                                    assert_clients.append(c[0])

                    merge_candidates.extend(assert_clients)
        else:
            # If two nodes have no input, but perform the same operation,
            # they are not always constant-folded, so we want to merge them.
            # In that case, the candidates are all the nodes without inputs.
            merge_candidates = self.noinput_nodes

        replacement_candidates = []
        for candidate in merge_candidates:
            if candidate is node:
                continue
            if len(node.inputs) != len(candidate.inputs):
                continue

            cand_has_assert = False

            # Get input list of the candidate with assert removed
            cand_inputs_assert_removed = []
            # TODO: Deactivated while Assert merging is disabled. (See above and below.)
            for i in []:  # candidate.inputs:
                if i.owner and isinstance(i.owner.op,
                                          theano.tensor.opt.Assert):
                    cand_has_assert = True
                    cand_inputs_assert_removed.append(i.owner.inputs[0])
                else:
                    cand_inputs_assert_removed.append(i)

            # TODO: Remove this when Assert merging is re-enabled. (See above.)
            # Without Assert merging we can still look for identical Asserts,
            # so we should not treat Asserts separately for now.
            cand_inputs_assert_removed = candidate.inputs

            # Get input list of the node with assert removed
            if node_has_assert:
                node_inputs_assert_removed = []
                for i in node.inputs:
                    if i.owner and isinstance(i.owner.op,
                                              theano.tensor.opt.Assert):
                        node_inputs_assert_removed.append(i.owner.inputs[0])
                    else:
                        node_inputs_assert_removed.append(i)
            else:
                node_inputs_assert_removed = node.inputs

            # Inputs must be pairwise *identical* objects, not merely equal.
            inputs_match = all(node_in is cand_in
                               for node_in, cand_in
                               in zip(node_inputs_assert_removed,
                                      cand_inputs_assert_removed))

            if inputs_match and node.op == candidate.op:
                if (node, candidate) in self.blacklist:
                    # They were already tried, and there was an error
                    continue

                # replace node with candidate
                if not (node_has_assert or cand_has_assert):
                    # Schedule transfer of clients from node to candidate
                    pairs = list(zip(node.outputs,
                                     candidate.outputs,
                                     ['merge'] * len(node.outputs)))

                # if the current node has assert input, it should not be
                # replaced with a candidate node which has no assert input
                elif node_has_assert and not cand_has_assert:
                    pairs = list(zip(candidate.outputs,
                                     node.outputs,
                                     ['merge'] * len(node.outputs)))
                else:
                    # Both sides have Assert inputs: build a fresh node
                    # whose Asserts carry the union of the conditions.
                    new_inputs = self.get_merged_assert_input(node, candidate)
                    new_node = node.op(*new_inputs)
                    pairs = list(zip(node.outputs,
                                     new_node.owner.outputs,
                                     ['new_node'] * len(node.outputs))) +\
                        list(zip(candidate.outputs,
                                 new_node.owner.outputs,
                                 ['new_node'] * len(node.outputs)))

                # transfer names
                for pair in pairs:
                    node_output, cand_output = pair[:2]
                    # clobber old name with new one
                    # it's arbitrary... one of the names has to go
                    if node_output.name:
                        cand_output.name = node_output.name

                replacement_candidates.append(pairs)

        if replacement_candidates:
            self.scheduled.append(replacement_candidates)
        else:
            self.nodes_seen.add(node)
            if not node.inputs:
                self.noinput_nodes.add(node)

    def get_merged_assert_input(self, node, candidate):
        """
        Build the merged input list for `node` and `candidate`, combining
        the conditions of corresponding Assert inputs into single Asserts.

        """
        new_inputs = []
        for node_i, cand_i in zip(node.inputs, candidate.inputs):
            # if node_i is assert
            if (node_i.owner and
                    isinstance(node_i.owner.op,
                               theano.tensor.opt.Assert)):
                # node_i is assert, cand_i is assert
                if (cand_i.owner and
                        isinstance(cand_i.owner.op,
                                   theano.tensor.opt.Assert)):
                    # Here two assert nodes are merged.
                    # Step 1. Merge conditions of both assert nodes.
                    # Step 2. Make the new assert node
                    node_cond = node_i.owner.inputs[1:]
                    cand_cond = cand_i.owner.inputs[1:]
                    new_cond = list(set(node_cond + cand_cond))
                    new_inputs.append(
                        theano.tensor.opt.assert_op(
                            node_i.owner.inputs[0],
                            *new_cond))

                # node_i is assert, cand_i is not assert
                else:
                    new_inputs.append(node_i)
            else:
                # if node_i is not an assert node, append cand_i
                new_inputs.append(cand_i)

        return new_inputs
783
784
785class MergeOptimizer(Optimizer):
786    """
787    Merges parts of the graph that are identical and redundant.
788
789    The basic principle is that if two Applies have ops that compare equal, and
790    identical inputs, then they do not both need to be computed. The clients of
791    one are transferred to the other and one of them is removed from the graph.
792    This procedure is carried out in input->output order through the graph.
793
794    The first step of merging is constant-merging, so that all clients of an
795    int(1) for example, are transferred to a particular instance of int(1).
796
797    """
798
    def add_requirements(self, fgraph):
        """
        Attach a MergeFeature to `fgraph` if it does not already carry one.

        """
        # Added by default
        # fgraph.attach_feature(toolbox.ReplaceValidate())
        if not hasattr(fgraph, 'merge_feature'):
            fgraph.attach_feature(MergeFeature())
804
805    def apply(self, fgraph):
806        # Constant and non-constant are now applied in the same phase.
807        # I am not sure why, but it seems to be faster this way.
808        sched = fgraph.merge_feature.scheduled
809        nb_fail = 0
810        t0 = time.time()
811        if fgraph.profile:
812            validate_before = fgraph.profile.validate_time
813            callback_before = fgraph.execute_callbacks_time
814            callbacks_before = fgraph.execute_callbacks_times.copy()
815
816        nb_merged = 0
817        nb_constant = 0
818        while sched:
819            pairs_list = sched.pop()
820            success = True
821            for pairs_ in pairs_list:
822                # We must check again the equivalence, as the graph
823                # can have changed. If so, doing the replacement can
824                # introduce node that depend on itself.  Doing the
825                # full check of such cycle everytimes is very time
826                # consumming. I think this double check is faster then
827                # doing the full cycle check. The full cycle check is
828                # skipped by validate() if the graph don't contain
829                # destroyers.
830                var, candidate, merge_mode = pairs_[0]
831                if merge_mode == "new_node" and hasattr(var, 'fgraph'):
832                    pass
833                elif (not hasattr(var, 'fgraph') or
834                      not hasattr(candidate, 'fgraph')):
835                    continue
836
837                # Keep len(item) == 2 for item in pairs
838                pairs = [pair[:2] for pair in pairs_]
839
840                if var.owner and candidate.owner:
841                    node = var.owner
842                    candidate = candidate.owner
843
844                    # Get input list of the candidate node with assert
845                    # nodes removed
846                    cand_inputs_assert_removed = []
847                    for i in candidate.inputs:
848                        if i.owner and isinstance(i.owner.op,
849                                                  theano.tensor.opt.Assert):
850                            cand_inputs_assert_removed.append(
851                                i.owner.inputs[0])
852                        else:
853                            cand_inputs_assert_removed.append(i)
854
855                    # Get input list of the node with assert nodes removed
856                    node_inputs_assert_removed = []
857                    for i in node.inputs:
858                        if i.owner and isinstance(i.owner.op,
859                                                  theano.tensor.opt.Assert):
860                            node_inputs_assert_removed.append(
861                                i.owner.inputs[0])
862                        else:
863                            node_inputs_assert_removed.append(i)
864
865                    if merge_mode == "new_node":
866                        inputs_match = True
867                    else:
868                        inputs_match = all(node_in is cand_in
869                                           for node_in, cand_in in
870                                           zip(node_inputs_assert_removed,
871                                               cand_inputs_assert_removed))
872
873                    # No need to compare the op again, as it don't change.
874                    if not inputs_match:
875                        continue
876
877                    if hasattr(pairs[0][0].fgraph, 'destroy_handler'):
878                        # If both nodes have clients that destroy
879                        # them, we can't merge them.
880                        clients = pairs[0][0].clients + pairs[0][1].clients
881                        if sum([i in utils.flatten(c.op.destroy_map.values())
882                                for c, i in clients
883                                if c != 'output' and
884                                hasattr(c.op, 'destroy_map')]) > 1:
885                            continue
886
887                if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
888                    res = pairs[0][0].type.convert_variable(pairs[0][1])
889
890                    # Since the fgraph.replace only checks the convert_variable
891                    # in one way, we change the order in the case that
892                    # convert_variable will not be successful.
893                    if not res:
894                        pairs = [(pairs[0][1], pairs[0][0])]
895
896                try:
897                    # If all Constants, no need to call validate.
898                    # Only need to check one of the var of each pairs.
899                    # If it is a Constant, the other must also be a Constant as we merge them.
900                    if all([isinstance(old, graph.Constant) for old, new in pairs]):
901                        fgraph.replace_all(pairs, 'MergeOptimizer')
902                    else:
903                        fgraph.replace_all_validate(pairs, 'MergeOptimizer')
904                except InconsistencyError:
905                    success = False
906                    nb_fail += 1
907                    fgraph.merge_feature.blacklist.append(
908                        (pairs[0][0].owner, pairs[0][1].owner))
909
910                if success:
911                    nb_merged += len(pairs)
912                    if isinstance(pairs[0][0], graph.Constant):
913                        nb_constant += 1
914                        # print pairs, pairs[0][0].type
915                    break
916
917        if fgraph.profile:
918            validate_time = fgraph.profile.validate_time - validate_before
919            callback_time = fgraph.execute_callbacks_time - callback_before
920            callbacks_time = {}
921            for k, v in iteritems(fgraph.execute_callbacks_times):
922                if k in callbacks_before:
923                    t = v - callbacks_before[k]
924                    if t > 0:
925                        callbacks_time[k] = t
926                else:
927                    callbacks_time[k] = v
928        else:
929            validate_time = None
930            callback_time = None
931            callbacks_time = {}
932        # clear blacklist
933        fgraph.merge_feature.blacklist = []
934        return (nb_fail, time.time() - t0, validate_time,
935                callback_time, callbacks_time, nb_merged, nb_constant)
936
937    def __str__(self):
938        return self.__class__.__name__
939
940    @staticmethod
941    def print_profile(stream, prof, level=0):
942
943        (nb_fail, replace_time, validate_time,
944         callback_time, callbacks_time, nb_merged, nb_constant) = prof
945
946        blanc = ('    ' * level)
947        print(blanc, "MergeOptimizer", file=stream)
948        print(blanc, "  nb fail=%5d merged=%5d constant=%5d" % (
949              nb_fail, nb_merged, nb_constant), file=stream)
950        print(blanc, "  time replace=%2.2f validate=%2.2f callback=%2.2f" % (
951              replace_time, validate_time, callback_time), file=stream)
952        if callback_time > 1:
953            print(blanc, "  callbacks_time", file=stream)
954            for i in sorted(iteritems(callbacks_time), key=lambda a: a[1]):
955                if i[1] > 0:
956                    # We want to have the __str__ called, so we can't
957                    # just print i.
958                    print(blanc, "      ", i[0], ',', i[1], file=stream)
959
960    @staticmethod
961    def merge_profile(prof1, prof2):
962        def merge_none_number(v1, v2):
963            if v1 is None:
964                return v2
965            if v2 is None:
966                return v1
967            return v1 + v2
968        nb_fail = prof1[0] + prof2[0]
969        replace_time = prof1[1] + prof2[1]
970        validate_time = merge_none_number(prof1[2], prof2[2])
971        callback_time = merge_none_number(prof1[3], prof2[3])
972        callbacks_time = merge_dict(prof1[4], prof2[4])
973        nb_merged = prof1[5] + prof2[5]
974        nb_constant = prof1[6] + prof2[6]
975        return (nb_fail, replace_time, validate_time,
976                callback_time, callbacks_time, nb_merged, nb_constant)
977
978
def is_same_graph_with_merge(var1, var2, givens=None):
    """
    Merge-based implementation of `theano.gof.graph.is_same_graph`.

    See help on `theano.gof.graph.is_same_graph` for additional documentation.

    """
    if givens is None:
        givens = {}
    # Work on deep copies: the MergeOptimizer mutates the graphs it visits,
    # and we must not touch the caller's variables.
    var1_copy, var2_copy, givens = copy.deepcopy([var1, var2, givens])
    vars = [var1_copy, var2_copy]
    # Build a FunctionGraph over both copied outputs.
    inputs = theano.gof.graph.inputs(vars)
    # No clone here: the deepcopy already isolated the graphs, and cloning
    # would break the mapping stored in `givens`.
    fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
    # Apply the requested substitutions before merging.
    for to_replace, replace_by in iteritems(givens):
        fgraph.replace(to_replace, replace_by)
    # Merge identical sub-graphs.
    MergeOptimizer().optimize(fgraph)
    # After merging, two variables computing the same thing share an owner.
    # Variables listed in `givens` must be compared through their
    # replacement.
    vars_replaced = [givens.get(v, v) for v in vars]
    owner1, owner2 = [v.owner for v in vars_replaced]
    if owner1 is None and owner2 is None:
        # Both graphs are a single ownerless Variable: equal iff they are
        # the same Variable.
        return vars_replaced[0] == vars_replaced[1]
    return owner1 is owner2
1016
1017
def pre_constant_merge(vars):
    """
    Merge constants in the subgraph used to compute nodes in `vars`.

    `vars` is a list of nodes, and we want to merge together nodes
    that are constant inputs used to compute nodes in that list.

    Notes
    -----
    This function will ignore nodes that are in an fgraph.
    It is used to pre-merge nodes generated inside an optimization,
    before it is inserted in the fgraph.
    It is useful if there are many such replacements to make,
    so that DebugMode will not check each of them.

    """
    # Accept a single Variable as well as a list of them.
    if isinstance(vars, graph.Variable):
        vars = [vars]
    seen_var = set()
    # Maps a constant's signature to the canonical variable kept for it.
    const_sig_inv = {}

    def merge(var):
        # Depth-first walk; returns the (possibly canonicalized) variable.
        if var in seen_var:
            return var
        if not hasattr(var, 'owner'):
            return var
        if var.owner and hasattr(var.owner, "fgraph"):
            # Already attached to an fgraph: leave it alone.
            return var
        seen_var.add(var)
        if isinstance(var, graph.Constant):
            sig = var.signature()
            try:
                canonical = const_sig_inv.get(sig)
                if canonical is not None:
                    return canonical
                const_sig_inv[sig] = var
            except TypeError:  # unhashable type
                warnings.warn(
                    "We work around a problem, the following variable"
                    " signature isn't hashable. Please, report this to"
                    " theano-dev so that the better fix is done. %s" % var)
                # Some python object like slice aren't hashable. So
                # don't merge them here.
            return var
        if var.owner:
            owner = var.owner
            # Rewrite the inputs in place with their canonical versions.
            for idx in range(len(owner.inputs)):
                owner.inputs[idx] = merge(owner.inputs[idx])
        return var

    return [merge(v) for v in vars]
1069
1070
1071########################
1072#   Local Optimizers   #
1073########################
1074
class LocalOptimizer(object):
    """
    A class for node-based optimizations.

    Instances should implement the transform function,
    and be passed to configure a fgraph-based Optimizer instance.

    """

    def __hash__(self):
        # Lazily assign a process-wide unique index on first use, so the
        # hash stays stable for the lifetime of this optimizer object.
        try:
            return self._optimizer_idx
        except AttributeError:
            self._optimizer_idx = _optimizer_idx[0]
            _optimizer_idx[0] += 1
            return self._optimizer_idx

    def tracks(self):
        """
        Return the list of op classes that this opt applies to.

        Return None to apply to all nodes.

        """
        return None

    def transform(self, node):
        """
        Transform a subgraph whose output is `node`.

        Subclasses should implement this function so that it returns one of two
        kinds of things:

        - False to indicate that no optimization can be applied to this `node`;
          or
        - <list of variables> to use in place of `node`'s outputs in the
          greater graph.
        - dict(old variables -> new variables). A dictionary that map
          from old variables to new variables to replace.

        Parameters
        ----------
        node : an Apply instance

        """
        raise utils.MethodNotDefined("transform",
                                     type(self), self.__class__.__name__)

    def add_requirements(self, fgraph):
        """
        If this local optimization wants to add some requirements to the
        fgraph, this is the place to do it.

        """
        # Added by default
        # fgraph.attach_feature(toolbox.ReplaceValidate())
        pass

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        """Print a one-line, indented description of this optimizer."""
        line = "%s%s id=%i" % (' ' * level, self.__class__.__name__, id(self))
        print(line, file=stream)
1135
1136
class LocalMetaOptimizer(LocalOptimizer):
    """
    Base class for meta-optimizers that try a set of LocalOptimizers
    to replace a node and choose the one that executes the fastest.

    If the error LocalMetaOptimizerSkipAssertionError is raised during
    compilation, we will skip that function compilation and not print
    the error.

    """

    def __init__(self):
        # verbose: 0 = silent, 1 = report failures, >1 = report timings too.
        self.verbose = config.metaopt.verbose
        # op class -> registered optimizers tracking it
        self.track_dict = defaultdict(lambda: [])
        # tag -> registered optimizers carrying that tag
        self.tag_dict = defaultdict(lambda: [])
        self._tracks = []
        self.optimizers = []

    def register(self, optimizer, tag_list):
        """Register `optimizer` under each of its tracked op classes and
        under every tag in `tag_list`."""
        self.optimizers.append(optimizer)
        for c in optimizer.tracks():
            self.track_dict[c].append(optimizer)
            self._tracks.append(c)
        for tag in tag_list:
            self.tag_dict[tag].append(optimizer)

    def tracks(self):
        # Union of all op classes tracked by the registered optimizers.
        return self._tracks

    def transform(self, node):
        """Compile and time each candidate replacement for `node`, and
        return the outputs of the fastest one (or None if none applies)."""
        # safety check: depending on registration, tracks may have been ignored
        if self._tracks is not None:
            if not isinstance(node.op, tuple(self._tracks)):
                return
        # first, we need to provide dummy values for all inputs
        # to the node that are not shared variables anyway
        givens = {}
        missing = set()
        for input in node.inputs:
            if isinstance(input, theano.compile.SharedVariable):
                pass
            elif hasattr(input.tag, 'test_value'):
                # Wrap the test value in a shared variable so the
                # benchmark functions take no explicit inputs.
                givens[input] = theano.shared(
                    input.type.filter(input.tag.test_value),
                    input.name,
                    broadcastable=input.broadcastable,
                    borrow=True)
            else:
                missing.add(input)
        if missing:
            # Let the subclass supply dummy data for the remaining inputs.
            givens.update(self.provide_inputs(node, missing))
            missing.difference_update(givens.keys())
        # ensure we have data for all input variables that need it
        if missing:
            if self.verbose > 0:
                print(("%s cannot meta-optimize %s, "
                       "%d of %d input shapes unknown" %
                       (self.__class__.__name__, node, len(missing), node.nin)))
            return
        # now we can apply the different optimizations in turn,
        # compile the resulting subgraphs and time their execution
        if self.verbose > 1:
            print(("%s meta-optimizing %s (%d choices):" %
                   (self.__class__.__name__, node, len(self.get_opts(node)))))
        timings = []
        for opt in self.get_opts(node):
            outputs = opt.transform(node)
            if outputs:
                try:
                    fn = theano.function([], outputs, givens=givens,
                                         on_unused_input='ignore')
                    fn.trust_input = True
                    # Best of two runs, to smooth out warm-up noise.
                    timing = min(self.time_call(fn) for _ in range(2))
                except LocalMetaOptimizerSkipAssertionError:
                    # This candidate asked to be skipped silently.
                    continue
                except Exception as e:
                    if self.verbose > 0:
                        print("* %s: exception" % opt, e)
                    continue
                else:
                    if self.verbose > 1:
                        print("* %s: %.5g sec" % (opt, timing))
                    timings.append((timing, outputs, opt))
            else:
                if self.verbose > 0:
                    print("* %s: not applicable" % opt)
        # finally, we choose the fastest one
        if timings:
            timings.sort()
            if self.verbose > 1:
                print("= %s" % timings[0][2])
            return timings[0][1]
        return

    def provide_inputs(self, node, inputs):
        """
        If implemented, returns a dictionary mapping all symbolic variables
        in ``inputs`` to SharedVariable instances of suitable dummy values.
        The ``node`` can be inspected to infer required input shapes.

        """
        raise NotImplementedError()

    def get_opts(self, node):
        """
        Can be overrided to change the way opts are selected
        """
        return self.track_dict[type(node.op)]

    def time_call(self, fn):
        # Wall-clock one call of the compiled function.
        start = time.time()
        fn()
        return time.time() - start
1250
1251
class FromFunctionLocalOptimizer(LocalOptimizer):
    """
    A LocalOptimizer built from a plain transform function.

    Parameters
    ----------
    fn : callable
        Used directly as the `transform` method.
    tracks : list or None
        Value returned by `tracks()`; None means "apply to all nodes".
    requirements : iterable of callables
        Each is called with the fgraph by `add_requirements`.

    """
    def __init__(self, fn, tracks=None, requirements=()):
        self.transform = fn
        self._tracks = tracks
        self.requirements = requirements

    def tracks(self):
        return self._tracks

    def add_requirements(self, fgraph):
        # Give every requirement callback a chance to configure the fgraph.
        for req in self.requirements:
            req(fgraph)

    def __str__(self):
        return getattr(self, '__name__',
                       '<FromFunctionLocalOptimizer instance>')

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        """Print one indented line naming the wrapped transform."""
        line = "%s%s id=%i" % (' ' * level, str(self.transform), id(self))
        print(line, file=stream)
1278
1279
def local_optimizer(tracks, inplace=False, requirements=()):
    """Decorator: wrap a transform function in a FromFunctionLocalOptimizer.

    `tracks` lists the op classes/instances the optimizer applies to
    (None = all nodes). When `inplace` is True, a DestroyHandler
    requirement is appended so destructive replacements are validated.
    """
    def decorator(f):
        if tracks is not None:
            if not tracks:
                raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
            for t in tracks:
                # Short-circuit on instances so issubclass is only called
                # on classes.
                is_op_instance = isinstance(t, op.Op)
                if not (is_op_instance or issubclass(t, op.PureOp)):
                    raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__)
        reqs = requirements
        if inplace:
            handler_cls = dh.DestroyHandler
            reqs = tuple(requirements) + (
                lambda fgraph: fgraph.attach_feature(handler_cls()),)
        local_opt = FromFunctionLocalOptimizer(f, tracks, reqs)
        local_opt.__name__ = f.__name__
        return local_opt
    return decorator
1302
1303
class LocalOptGroup(LocalOptimizer):
    """Takes a list of LocalOptimizer and applies them to the node.

    Parameters
    ----------
    optimizers :
        The List of optimizers to be applied to a node
    reentrant : bool (Default True)
        Keyword only argument. Reentrant information. Some global
        optimizer like NavigatorOptimizer can use this value to
        determine if it ignore new nodes during a pass on the
        nodes. Sometimes, ignore_newtrees is not reentrant.
    apply_all_opts : bool (Default False)
        If False, it will return after the new node after the first optimizer
        applied. Otherwise, it will start again with the new node until no new
        optimization apply.

    """

    def __init__(self, *optimizers, **kwargs):
        if len(optimizers) == 1 and isinstance(optimizers[0], list):
            # This happen when created by LocalGroupDB.
            optimizers = tuple(optimizers[0])
        self.opts = optimizers
        assert isinstance(self.opts, tuple)

        # The group is reentrant if any member is; it retains inputs only
        # if every member does.
        self.reentrant = any(getattr(opt, 'reentrant', True)
                             for opt in optimizers)
        self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
                                  for opt in optimizers)

        self.apply_all_opts = kwargs.pop('apply_all_opts', False)
        self.profile = kwargs.pop('profile', False)
        # Maps op class / op instance / None -> optimizers tracking it.
        self.track_map = defaultdict(lambda: [])
        assert len(kwargs) == 0
        if self.profile:
            # Per-optimizer profiling counters.
            self.time_opts = {}
            self.process_count = {}
            self.applied_true = {}
            self.node_created = {}

        for o in self.opts:
            if self.profile:
                self.time_opts.setdefault(o, 0)
                self.process_count.setdefault(o, 0)
                self.applied_true.setdefault(o, 0)
                self.node_created.setdefault(o, 0)
            tracks = o.tracks()
            if tracks is None:
                self.track_map[None].append(o)
            else:
                for c in tracks:
                    self.track_map[c].append(o)

    def __str__(self):
        return getattr(self, '__name__',
                       ('LocalOptGroup(%s)' %
                        ','.join([str(o) for o in self.opts])))

    def tracks(self):
        """Union of the tracks of all member optimizers."""
        t = []
        for l in self.opts:
            tt = l.tracks()
            if tt:
                t.extend(tt)
        return t

    def transform(self, node):
        """Try the member optimizers on `node` and return the replacement.

        Returns the first successful replacement (list/tuple of variables
        or a dict), or None if no optimizer applied. When
        `apply_all_opts` is True, optimizers are re-applied to the
        replacement node until none applies, and the last replacement is
        returned.
        """
        if len(self.opts) == 0:
            return
        fgraph = node.fgraph
        repl = None
        while True:
            # Candidates tracking this op's class, this exact op instance,
            # and those applying to every node.
            opts = (self.track_map[type(node.op)] +
                    self.track_map[node.op] +
                    self.track_map[None])
            new_repl = None
            for opt in opts:
                opt_start = time.time()
                new_repl = opt.transform(node)
                opt_finish = time.time()
                if self.profile:
                    # Accumulate elapsed time; the previous code summed
                    # `opt_start - opt_finish`, producing negative totals.
                    self.time_opts[opt] += opt_finish - opt_start
                    self.process_count[opt] += 1
                if not new_repl:
                    continue
                if isinstance(new_repl, (tuple, list)):
                    new_vars = new_repl
                else:  # It must be a dict
                    new_vars = list(new_repl.values())
                if self.profile:
                    self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
                    self.applied_true[opt] += 1
                break  # break from the for loop over optimization.
            if not new_repl:  # No optimization applied in the last iteration
                return repl
            # only 1 iteration
            if not self.apply_all_opts:
                return new_repl
            if not new_vars[0].owner:
                # We are at the start of the graph.
                return new_repl
            if len(new_repl) > 1:
                # All replacement outputs must come from a single new node.
                s = set([v.owner for v in new_repl])
                assert len(s) == 1
            repl = new_repl
            node = new_vars[0].owner

    @staticmethod
    def print_profile(stream, prof, level=0):
        """Pretty-print the profiling tuple gathered for this group."""
        (time_opts, process_count, applied_true, node_created, profile) = prof

        if not profile:
            return

        blanc = ('    ' * int(level))
        print(blanc, "LocalOptGroup", file=stream)
        print(blanc, "---------------------", file=stream)
        count_opt = []
        not_used = []
        not_used_time = 0
        for o, count in iteritems(process_count):
            if count > 0:
                count_opt.append((time_opts[o], applied_true[o], count, o, node_created[o]))
            else:
                not_used.append((time_opts[o], o))
                not_used_time += time_opts[o]
        if count_opt:
            print(blanc,
                  '  time taken - times applied - times tried - name - node_created:',
                  file=stream)
            # Most expensive optimizers first.
            count_opt.sort()
            for (t, a_t, count, o, n_c) in count_opt[::-1]:
                print(blanc, '  %.3fs - %d - %d - %s - %d' % (
                      t, a_t, count, o, n_c), file=stream)
            print(blanc, '  %.3fs - in %d optimization that were not used (display those with runtime greater than 0)' % (
                not_used_time, len(not_used)), file=stream)
            not_used.sort(key=lambda nu: (nu[0], str(nu[1])))
            for (t, o) in not_used[::-1]:
                if t > 0:
                    # Skip opt that have 0 times, they probably wasn't even tried.
                    print(blanc + "  ", '  %.3fs - %s' % (t, o), file=stream)
        else:
            print(blanc, " The Optimizer wasn't successful ", file=stream)

        print(file=stream)

    @staticmethod
    def merge_profile(prof1, prof2):
        # Made static for consistency with print_profile (it never used
        # self); merging grouped profiles is not supported.
        raise NotImplementedError

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        """Print this group and, up to `depth` levels, its members."""
        print("%s%s id=%i" % (
            (' ' * level), self.__class__.__name__, id(self)), file=stream)
        if depth != 0:
            depth -= 1
            for lopt in self.opts:
                lopt.print_summary(stream, level=(level + 2), depth=depth)

    def add_requirements(self, fgraph):
        # Delegate to every member optimizer.
        for opt in self.opts:
            opt.add_requirements(fgraph)
1463
1464
class GraphToGPULocalOptGroup(LocalOptGroup):
    """This is the equivalent of LocalOptGroup for GraphToGPU.

    The main different is the function signature of the local
    optimizer that use the GraphToGPU signature and not the normal
    LocalOptimizer signature.

    apply_all_opts=True is not supported

    """
    def __init__(self, *optimizers, **kwargs):
        super(GraphToGPULocalOptGroup, self).__init__(*optimizers, **kwargs)
        assert self.apply_all_opts is False

    def transform(self, op, context_name, inputs, outputs):
        """Try the member optimizers with the GraphToGPU signature.

        Returns the first non-falsy replacement produced by a member
        optimizer, or None if none applies.
        """
        if len(self.opts) == 0:
            return
        fgraph = outputs[0].fgraph
        # Candidates tracking this op's class, this exact op instance, and
        # those applying to every node.
        opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
        for opt in opts:
            opt_start = time.time()
            new_repl = opt.transform(op, context_name, inputs, outputs)
            opt_finish = time.time()
            if self.profile:
                # Accumulate elapsed time; the previous code summed
                # `opt_start - opt_finish`, producing negative totals.
                self.time_opts[opt] += opt_finish - opt_start
                self.process_count[opt] += 1
            if not new_repl:
                continue
            if self.profile:
                self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
                self.applied_true[opt] += 1

            return new_repl
1498
1499
class OpSub(LocalOptimizer):
    """
    Replaces the application of a certain op by the application of
    another op that takes the same inputs as what they are replacing.

    Parameters
    ----------
    op1, op2
        op1.make_node and op2.make_node must take the same number of
        inputs and have the same number of outputs.

    Examples
    --------
    OpSub(add, sub) ==>
        add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))

    """

    # an OpSub does not apply to the nodes it produces
    reentrant = False
    # all the inputs of the original node are transferred to the outputs
    retains_inputs = True

    def __init__(self, op1, op2, transfer_tags=True):
        self.op1 = op1
        self.op2 = op2
        self.transfer_tags = transfer_tags

    def tracks(self):
        return [self.op1]

    def op_key(self):
        return self.op1

    def transform(self, node):
        if node.op != self.op1:
            return False
        # Rebuild the node with op2 on the very same inputs.
        new_node = self.op2.make_node(*node.inputs)
        if self.transfer_tags:
            # Carry the node tag and each output's tag over to the copy.
            new_node.tag = copy.copy(node.tag)
            for old_out, new_out in zip(node.outputs, new_node.outputs):
                new_out.tag = copy.copy(old_out.tag)
        return new_node.outputs

    def __str__(self):
        return "{0} -> {1}".format(self.op1, self.op2)
1547
1548
class OpRemove(LocalOptimizer):
    """
    Removes all applications of an op by transferring each of its
    outputs to the corresponding input.

    """

    # No new nodes are ever created, so this optimizer is not reentrant.
    reentrant = False

    def __init__(self, op):
        self.op = op

    def tracks(self):
        return [self.op]

    def op_key(self):
        return self.op

    def transform(self, node):
        # Replace every output of the node with its corresponding input.
        if node.op == self.op:
            return node.inputs
        return False

    def __str__(self):
        return "{0}(x) -> x".format(self.op)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        """Print one indented line naming the removed op."""
        line = "%s%s(%s) id=%i" % (
            ' ' * level, self.__class__.__name__, str(self.op), id(self))
        print(line, file=stream)
1582
1583
class PatternSub(LocalOptimizer):
    """
    Replaces all occurrences of the input pattern by the output pattern:

    input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
    input_pattern ::= dict(pattern = <input_pattern>,
                            constraint = <constraint>)
    sub_pattern ::= input_pattern
    sub_pattern ::= string
    sub_pattern ::= a Constant instance
    sub_pattern ::= int
    sub_pattern ::= float
    constraint ::= lambda expr: additional matching condition

    output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
    output_pattern ::= string
    output_pattern ::= int
    output_pattern ::= float

    Each string in the input pattern is a variable that will be set to
    whatever expression is found in its place. If the same string is
    used more than once, the same expression must be found in those
    places. If a string used in the input pattern is used in the
    output pattern, the matching expression will be inserted in its
    place. The input pattern cannot just be a string but the output
    pattern can.

    If you put a constant variable in the input pattern, there will be a
    match iff a constant variable with the same value and the same type
    is found in its place.

    You can add a constraint to the match by using the dict(...)  form
    described above with a 'constraint' key. The constraint must be a
    function that takes the current Variable that we are trying to
    match and returns True or False according to an arbitrary
    criterion.

    The constructor creates a PatternSub that replaces occurrences of
    in_pattern by occurrences of out_pattern.

    Parameters
    ----------
    in_pattern
        The input pattern that we want to replace.
    out_pattern
        The replacement pattern.
    allow_multiple_clients : bool
        If False, the pattern matching will fail if one of the subpatterns has
        more than one client.
    skip_identities_fn : callable, optional
        Called with an expression that failed to match; may return an
        equivalent expression to retry the match on, or None to give up.
    name
        Allows to override this optimizer name.
    pdb : bool
        If True, we invoke pdb when the first node in the pattern matches.
    tracks : optional
        The values that self.tracks() will return. Useful to speed up
        optimization sometimes.
    get_nodes : optional
        If you provide `tracks`, you must provide this parameter. It must be a
        function that takes the tracked node and returns a list of nodes on
        which we will try this optimizer.

    Notes
    -----
    `tracks` and `get_nodes` can be used to make this optimizer track a less
    frequent Op, so this will make this optimizer tried less frequently.

    Examples
    --------
    PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
    PatternSub((multiply, 'x', 'x'), (square, 'x'))
    PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
    PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
    PatternSub((boggle, {'pattern': 'x',
                         'constraint': lambda expr: expr.type == scrabble}),
               (scrabble, 'x'))
    """

    def __init__(self, in_pattern, out_pattern,
                 allow_multiple_clients=False,
                 skip_identities_fn=None, name=None, pdb=False,
                 tracks=(), get_nodes=None,
                 values_eq_approx=None):
        self.in_pattern = in_pattern
        self.out_pattern = out_pattern
        self.values_eq_approx = values_eq_approx
        # The head of the input pattern must be a concrete Op; it is what
        # this optimizer tracks by default.
        if isinstance(in_pattern, (list, tuple)):
            self.op = self.in_pattern[0]
        elif isinstance(in_pattern, dict):
            self.op = self.in_pattern['pattern'][0]
        else:
            raise TypeError("The pattern to search for must start with "
                            "a specific Op instance.")
        # Embed a description of this particular rewrite in the docstring.
        self.__doc__ = (self.__class__.__doc__ +
                        "\n\nThis instance does: " +
                        str(self) + "\n")
        self.allow_multiple_clients = allow_multiple_clients
        self.skip_identities_fn = skip_identities_fn
        if name:
            self.__name__ = name
        self.pdb = pdb
        self._tracks = tracks
        self.get_nodes = get_nodes
        if tracks != ():
            # Custom tracking only makes sense together with a function
            # mapping tracked nodes to candidate nodes.
            assert get_nodes

    def op_key(self):
        # Used by OpKeyOptimizer to index this optimizer by Op.
        return self.op

    def tracks(self):
        # Custom tracks (if provided) take precedence over the pattern head.
        if self._tracks != ():
            return self._tracks
        return [self.op]

    def transform(self, node, get_nodes=True):
        """
        Checks if the graph from node corresponds to in_pattern. If it does,
        constructs out_pattern and performs the replacement.

        Returns False when nothing matched; otherwise a list with the
        replacement variable (or, in the `get_nodes` indirection case, a
        dict mapping old outputs to their replacements).

        """
        if get_nodes and self.get_nodes is not None:
            # We track a different Op: map the tracked node to the node(s)
            # this pattern actually applies to and recurse without mapping.
            for real_node in self.get_nodes(node):
                if real_node == "output":
                    continue
                ret = self.transform(real_node, get_nodes=False)
                if ret is not False and ret is not None:
                    assert len(real_node.outputs) == len(ret)
                    if self.values_eq_approx:
                        ret.tag.values_eq_approx = self.values_eq_approx
                    return dict(izip(real_node.outputs, ret))

        if node.op != self.op:
            return False
        # TODO: if we remove pdb, do this speed things up?

        def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
            # Recursively match `pattern` against `expr`, extending the
            # unification `u`.  Returns the updated unification on success
            # and False on failure.
            # TODO move outside match
            def retry_with_equiv():
                # On failure, optionally retry the match on an equivalent
                # expression supplied by skip_identities_fn.
                if not self.skip_identities_fn:
                    return False
                expr_equiv = self.skip_identities_fn(expr)
                if expr_equiv is None:
                    return False
                # TODO: Not sure how to handle multiple_clients flag
                # print 'retrying match', pattern, expr_equiv
                return match(pattern, expr_equiv, u,
                             allow_multiple_clients=allow_multiple_clients)

            if isinstance(pattern, (list, tuple)):
                # (op, sub_pattern, ...): expr must be the output of an
                # apply of that op, with matching arity and inputs.
                if expr.owner is None:
                    return False
                if (not (expr.owner.op == pattern[0]) or
                        (not allow_multiple_clients and len(expr.clients) > 1)):
                    return retry_with_equiv()
                if len(pattern) - 1 != len(expr.owner.inputs):
                    return retry_with_equiv()
                for p, v in zip(pattern[1:], expr.owner.inputs):
                    u = match(p, v, u, self.allow_multiple_clients)
                    if not u:
                        return False
            elif isinstance(pattern, dict):
                try:
                    real_pattern = pattern['pattern']
                except KeyError:
                    raise KeyError(
                        "Malformed pattern: %s (expected key 'pattern')"
                        % pattern)
                constraint = pattern.get('constraint', lambda expr: True)
                if constraint(expr):
                    return match(real_pattern, expr, u,
                                 pattern.get('allow_multiple_clients',
                                             allow_multiple_clients))
                else:
                    return retry_with_equiv()
            elif isinstance(pattern, string_types):
                # A named wildcard: must unify consistently with any
                # previous binding of the same name.
                v = unify.Var(pattern)
                if u[v] is not v and u[v] is not expr:
                    return retry_with_equiv()
                else:
                    u = u.merge(expr, v)
            elif (isinstance(pattern, (integer_types, float)) and
                    isinstance(expr, graph.Constant)):
                # A literal number matches a Constant of equal value.
                if np.all(theano.tensor.constant(pattern).value == expr.value):
                    return u
                else:
                    return retry_with_equiv()
            elif (isinstance(pattern, graph.Constant) and
                    isinstance(expr, graph.Constant) and
                    pattern.equals(expr)):
                return u
            else:
                return retry_with_equiv()
            if pdb:
                # NOTE: this `import pdb` rebinds the local (boolean)
                # parameter `pdb` to the module before breaking in.
                import pdb
                pdb.set_trace()
            return u

        u = match(self.in_pattern, node.out, unify.Unification(), True,
                  self.pdb)
        if u:
            # Build the replacement graph from out_pattern, substituting
            # the variables bound during matching.
            def build(pattern, u):
                if isinstance(pattern, (list, tuple)):
                    args = [build(p, u) for p in pattern[1:]]
                    return pattern[0](*args)
                elif isinstance(pattern, string_types):
                    return u[unify.Var(pattern)]
                elif isinstance(pattern, (integer_types, float)):
                    return pattern
                else:
                    return pattern.clone()
            p = self.out_pattern
            ret = build(p, u)
            if self.values_eq_approx:
                ret.tag.values_eq_approx = self.values_eq_approx
            return [ret]
        else:
            return False

    def __str__(self):
        # Prefer an explicit name; otherwise render "in -> out".
        if getattr(self, '__name__', None):
            return self.__name__

        def pattern_to_str(pattern):
            if isinstance(pattern, (list, tuple)):
                return "%s(%s)" % (
                    str(pattern[0]),
                    ", ".join([pattern_to_str(p) for p in pattern[1:]]))
            elif isinstance(pattern, dict):
                return "%s subject to %s" % (
                    pattern_to_str(pattern['pattern']),
                    str(pattern.get('constraint', 'no conditions')))
            else:
                return str(pattern)
        return "%s -> %s" % (
            pattern_to_str(self.in_pattern),
            pattern_to_str(self.out_pattern))

    def __repr__(self):
        return str(self)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        # One indented line describing the rewrite and this instance.
        name = getattr(self, '__name__', getattr(self, 'name', None))
        print("%s%s %s(%s, %s) id=%i" % (
            ' ' * level,
            self.__class__.__name__,
            name,
            str(self.in_pattern),
            str(self.out_pattern),
            id(self)), file=stream)
1836
1837
1838##################
1839#   Navigators   #
1840##################
1841
1842# Use the following classes to apply LocalOptimizers
1843
class Updater:
    """
    Minimal FunctionGraph feature that forwards import/prune/change-input
    events to optional user-supplied callbacks.

    Any of the callbacks may be None (or falsy), in which case the
    corresponding event is simply ignored.

    """

    def __init__(self, importer, pruner, chin, name=None):
        self.importer = importer  # called with each newly imported node
        self.pruner = pruner      # called with each pruned node
        self.chin = chin          # called whenever a node input changes
        self.name = name

    def __str__(self):
        return "Updater{%s}" % str(self.name)

    def on_import(self, fgraph, node, reason):
        if not self.importer:
            return
        self.importer(node)

    def on_prune(self, fgraph, node, reason):
        if not self.pruner:
            return
        self.pruner(node)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        if not self.chin:
            return
        self.chin(node, i, r, new_r, reason)

    def on_detach(self, fgraph):
        # Drop the callbacks so the object stays picklable.
        self.importer = self.pruner = self.chin = None
1871
1872
class NavigatorOptimizer(Optimizer):
    """
    Abstract class that applies a LocalOptimizer to the nodes of a
    FunctionGraph; subclasses decide how nodes are selected and ordered.

    Parameters
    ----------
    local_opt
        A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
    ignore_newtrees
        - True: new subgraphs returned by an optimization is not a
          candidate for optimization.
        - False: new subgraphs returned by an optimization is a candidate
          for optimization.
        - 'auto': let the local_opt set this parameter via its 'reentrant'
          attribute.
    failure_callback
            A function that takes (exception, navigator, [(old, new),
            (old,new),...]) and we call it if there's an exception.

            If the trouble is from local_opt.transform(), the new variables
            will be 'None'.

            If the trouble is from validation (the new types don't match for
            example) then the new variables will be the ones created by
            transform().

            If this parameter is None, then exceptions are not caught here
            (raised normally).

    """
    @staticmethod
    def warn(exc, nav, repl_pairs, local_opt, node):
        """
        Failure_callback for NavigatorOptimizer: print traceback.

        Honors config.on_opt_error: 'ignore' silences the log, 'pdb'
        drops into a post-mortem, 'raise' re-raises the exception.

        """
        if config.on_opt_error != 'ignore':
            _logger.error("Optimization failure due to: %s" % str(local_opt))
            _logger.error("node: %s" % str(node))
            _logger.error("TRACEBACK:")
            _logger.error(traceback.format_exc())
        if config.on_opt_error == 'pdb':
            pdb.post_mortem(sys.exc_info()[2])
        elif isinstance(exc, AssertionError) or config.on_opt_error == 'raise':
            # We always crash on AssertionError because something may be
            # seriously wrong if such an exception is raised.
            raise exc

    @staticmethod
    def warn_inplace(exc, nav, repl_pairs, local_opt, node):
        """
        Failure_callback for NavigatorOptimizer.

        Ignore InconsistencyErrors, print traceback.

        If error during replacement repl_pairs is set. Otherwise None.

        """
        if isinstance(exc, InconsistencyError):
            return
        return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node)

    @staticmethod
    def warn_ignore(exc, nav, repl_pairs, local_opt, node):
        """
        Failure_callback for NavigatorOptimizer: ignore all errors.

        """
        pass

    def __init__(self, local_opt, ignore_newtrees='auto',
                 failure_callback=None):
        self.local_opt = local_opt
        if ignore_newtrees == 'auto':
            # Follow new subgraphs unless the local optimizer declares
            # itself non-reentrant.
            self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
        else:
            self.ignore_newtrees = ignore_newtrees
        self.failure_callback = failure_callback

    def attach_updater(self, fgraph, importer, pruner, chin=None, name=None):
        """
        Install some FunctionGraph listeners to help the navigator deal with
        the ignore_trees-related functionality.

        Parameters
        ----------
        importer
            Function that will be called whenever optimizations add stuff
            to the graph.
        pruner
            Function to be called when optimizations remove stuff
            from the graph.
        chin
            "on change input" called whenever a node's inputs change.
        name
            name of the Updater to attach.

        Returns
        -------
        object
            The FunctionGraph plugin that handles the three tasks.
            Keep this around so that you can detach later!
            May be None when there is nothing to listen for.

        """
        if self.ignore_newtrees:
            # New nodes are not candidates: don't report imports at all.
            importer = None

        if importer is None and pruner is None:
            return None

        u = Updater(importer, pruner, chin, name=name)
        fgraph.attach_feature(u)
        return u

    def detach_updater(self, fgraph, u):
        """
        Undo the work of attach_updater.

        Parameters
        ----------
        u
            A return-value of attach_updater.

        Returns
        -------
        None

        """
        if u is not None:
            fgraph.remove_feature(u)

    def process_node(self, fgraph, node, lopt=None):
        """
        This function will use `lopt` to `transform` the `node`. The
        `transform` method will return either False or a list of Variables
        that are intended to replace `node.outputs`.

        If the fgraph accepts the replacement, then the optimization is
        successful, and this function returns True.

        If there are no replacement candidates or the fgraph rejects the
        replacements, this function returns False.

        Parameters
        ----------
        fgraph
            A FunctionGraph.
        node
            An Apply instance in `fgraph`
        lopt
            A LocalOptimizer instance that may have a better idea for
            how to compute node's outputs.

        Returns
        -------
        bool
            True iff the `node`'s outputs were replaced in the `fgraph`.

        """
        lopt = lopt or self.local_opt
        try:
            replacements = lopt.transform(node)
        except Exception as e:
            # Delegate transform() errors to the failure callback, if any.
            if self.failure_callback is not None:
                self.failure_callback(e, self,
                                      [(x, None) for x in node.outputs],
                                      lopt, node)
                return False
            else:
                raise
        if replacements is False or replacements is None:
            return False
        old_vars = node.outputs
        remove = []
        if isinstance(replacements, dict):
            # A dict maps each old variable to its replacement; the
            # optional "remove" entry lists variables to delete outright.
            if "remove" in replacements:
                remove = replacements.pop("remove")
            old_vars = list(replacements.keys())
            replacements = list(replacements.values())
        elif not isinstance(replacements, (tuple, list)):
            raise TypeError('Optimizer %s gave wrong type of replacement. '
                            'Expected list or tuple. Got %s' % (
                                lopt, replacements))
        if len(old_vars) != len(replacements):
            raise ValueError('Optimizer %s gave wrong number of replacements'
                             % lopt)
        # None in the replacement mean that this variable isn't used
        # and we want to remove it
        for r, rnew in zip(old_vars, replacements):
            if rnew is None and len(r.clients) > 0:
                raise ValueError("A local optimizer tried to remove a Variable that is used")
        # If an output would be replaced by itself, no need to perform
        # the replacement
        repl_pairs = [(r, rnew) for r, rnew in zip(old_vars, replacements)
                      if rnew is not r and rnew is not None]

        if len(repl_pairs) == 0:
            return False
        try:
            fgraph.replace_all_validate_remove(repl_pairs,
                                               reason=lopt,
                                               remove=remove)
            return True
        except Exception as e:
            # This means the replacements were rejected by the fgraph.
            #
            # This is not supposed to happen.  The default failure_callback
            # will print a traceback as a warning.
            if self.failure_callback is not None:
                self.failure_callback(e, self, repl_pairs, lopt, node)
                return False
            else:
                raise

    def add_requirements(self, fgraph):
        # Propagate feature requirements of the wrapped local optimizer.
        super(NavigatorOptimizer, self).add_requirements(fgraph)
        # Added by default
        # fgraph.attach_feature(toolbox.ReplaceValidate())
        if self.local_opt:
            self.local_opt.add_requirements(fgraph)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        # Print this navigator, then (unless depth is exhausted) the local
        # optimizer it wraps, indented one level deeper.
        print("%s%s (%i)" % (
            (' ' * level), self.__class__.__name__, id(self)), file=stream)
        if depth != 0:
            self.local_opt.print_summary(stream, level=(level + 2),
                                         depth=(depth - 1))
2100
2101
class TopoOptimizer(NavigatorOptimizer):
    """
    Walk the graph in topological order (or in reverse) and try one local
    optimizer on every node.

    Whenever the local optimizer fires, the node is replaced on the spot
    and the traversal continues with the next queued node; nodes imported
    during a rewrite are appended to the queue (unless ignore_newtrees).

    """

    def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
                 failure_callback=None):
        # `order` decides which end of the toposort queue is consumed.
        if order not in ['out_to_in', 'in_to_out']:
            raise ValueError("order must be 'out_to_in' or 'in_to_out'")
        self.order = order
        NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                    failure_callback)

    def apply(self, fgraph, start_from=None):
        """Run the local optimizer over every node reachable from
        `start_from` (defaults to the graph outputs); return a profile
        tuple consumed by `print_profile`."""
        if start_from is None:
            start_from = fgraph.outputs
        callback_before = fgraph.execute_callbacks_time
        nb_nodes_start = len(fgraph.apply_nodes)
        tick = time.time()
        pending = deque(graph.io_toposort(fgraph.inputs, start_from))
        io_t = time.time() - tick

        def importer(node):
            # Also visit nodes created while optimizing, except the one
            # currently being rewritten.
            if node is not node_in_progress:
                pending.append(node)

        u = self.attach_updater(fgraph, importer, None,
                                name=getattr(self, 'name', None))
        nb_applied = 0
        try:
            tick = time.time()
            while pending:
                if self.order == 'out_to_in':
                    node = pending.pop()
                else:
                    node = pending.popleft()
                if node not in fgraph.apply_nodes:
                    # The node was removed from the graph in the meantime.
                    continue
                node_in_progress = node
                nb_applied += self.process_node(fgraph, node)
            loop_t = time.time() - tick
        finally:
            self.detach_updater(fgraph, u)

        callback_time = fgraph.execute_callbacks_time - callback_before
        nb_nodes_end = len(fgraph.apply_nodes)
        return (self, nb_applied, nb_nodes_start, nb_nodes_end,
                io_t, loop_t, callback_time, self.local_opt)

    @staticmethod
    def print_profile(stream, prof, level=0):
        indent = ('    ' * level)
        if prof is None:  # Happen as merge_profile() isn't implemented
            print(indent, "TopoOptimizer merge_profile not implemented",
                  file=stream)
            return

        (opt, nb, nb_nodes_start, nb_nodes_end,
         io_t, loop_t, callback_time, lopt) = prof

        print(indent, "TopoOptimizer ",
              getattr(opt, "name", getattr(opt, "__name__", "")), file=stream)
        print(indent, "  nb_node (start, end, changed)", (
            nb_nodes_start, nb_nodes_end, nb), file=stream)
        print(indent, "  init io_toposort", io_t, file=stream)
        print(indent, "  loop time", loop_t, file=stream)
        print(indent, "  callback_time", callback_time, file=stream)
        # Recurse into the wrapped group's own profile, when it has one.
        if isinstance(lopt, LocalOptGroup) and lopt.profile:
            lopt.print_profile(stream, (lopt.time_opts,
                                        lopt.process_count,
                                        lopt.applied_true,
                                        lopt.node_created,
                                        lopt.profile),
                               level=level + 1)

    def __str__(self):
        return getattr(self, '__name__', '<TopoOptimizer instance>')
2184
2185
def out2in(*local_opts, **kwargs):
    """
    Build a TopoOptimizer applying `local_opts` from the output nodes of
    the graph towards its inputs.
    """
    name = kwargs.pop('name', None)
    if len(local_opts) > 1:
        # Several optimizers: bundle them into a single LocalOptGroup.
        local_opts = LocalOptGroup(*local_opts)
    else:
        # A single optimizer is used as-is; borrow its name if none given.
        local_opts, = local_opts
        name = name or local_opts.__name__
    opt = TopoOptimizer(local_opts,
                        order='out_to_in',
                        failure_callback=TopoOptimizer.warn_inplace,
                        **kwargs)
    if name:
        opt.__name__ = name
    return opt
2205
2206
def in2out(*local_opts, **kwargs):
    """
    Build a TopoOptimizer applying `local_opts` from the input nodes of
    the graph towards its outputs.
    """
    name = kwargs.pop('name', None)
    if len(local_opts) > 1:
        # Several optimizers: bundle them into a single LocalOptGroup.
        local_opts = LocalOptGroup(*local_opts)
    else:
        # A single optimizer is used as-is; borrow its name if none given.
        local_opts, = local_opts
        name = name or local_opts.__name__
    opt = TopoOptimizer(local_opts,
                        order='in_to_out',
                        failure_callback=TopoOptimizer.warn_inplace,
                        **kwargs)
    if name:
        opt.__name__ = name
    return opt
2226
2227
class OpKeyOptimizer(NavigatorOptimizer):
    """
    Apply a local optimizer only to the nodes whose Op matches the
    optimizer's `op_key`.

    The wrapped LocalOptimizer must expose an `op_key` method returning
    either a single Op or a list/tuple of Ops.

    """

    def __init__(self, local_opt, ignore_newtrees=False,
                 failure_callback=None):
        if not hasattr(local_opt, 'op_key'):
            raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
                            "an 'op_key' method.")
        NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                    failure_callback)

    def apply(self, fgraph):
        """Process every node of `fgraph` whose op matches op_key()."""
        op = self.local_opt.op_key()
        # Seed the queue with every current node applying the key op(s).
        # Always accumulate into a fresh list: the previous
        # reduce(list.__iadd__, ...) extended the first list returned by
        # fgraph.get_nodes in place, corrupting the NodeFinder's internal
        # bookkeeping when it hands out its own lists.
        if isinstance(op, (list, tuple)):
            q = []
            for an_op in op:
                q += fgraph.get_nodes(an_op)
        else:
            q = list(fgraph.get_nodes(op))

        def importer(node):
            # Queue freshly imported nodes of the tracked op, skipping
            # the node currently being processed.  (When `op` is a
            # list/tuple, `node.op == op` is never true, so imported
            # nodes are not re-queued -- unchanged historical behavior.)
            if node is not current_node:
                if node.op == op:
                    q.append(node)

        u = self.attach_updater(fgraph, importer, None,
                                name=getattr(self, 'name', None))
        try:
            while q:
                node = q.pop()
                if node not in fgraph.apply_nodes:
                    # Node was removed from the graph in the meantime.
                    continue
                current_node = node
                self.process_node(fgraph, node)
        finally:
            self.detach_updater(fgraph, u)

    def add_requirements(self, fgraph):
        """
        Requires the following features:
          - NodeFinder
          - ReplaceValidate(Added by default)

        """
        super(OpKeyOptimizer, self).add_requirements(fgraph)
        fgraph.attach_feature(toolbox.NodeFinder())
2275
2276
class ChangeTracker:
    """
    FunctionGraph feature recording whether the graph was modified (node
    imported or input changed) since the last `reset`, plus a lifetime
    count of imported nodes.

    """

    def __init__(self):
        self.changed = False  # any modification since the last reset()
        self.nb_imported = 0  # total imports; never reset

    def on_import(self, fgraph, node, reason):
        self.changed = True
        self.nb_imported += 1

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        self.changed = True

    def reset(self):
        # nb_imported is deliberately left untouched.
        self.changed = False

    def on_attach(self, fgraph):
        # Make the tracker reachable from the graph itself.
        fgraph.change_tracker = self

    def on_detach(self, fgraph):
        del fgraph.change_tracker
2297
2298
def merge_dict(d1, d2):
    """
    Merge two dicts into a new one, adding the values of shared keys.

    Parameters
    ----------
    d1, d2 : dict
        Input dictionaries; values of keys present in both must support
        ``+``.  Neither input is modified.

    Returns
    -------
    dict
        A copy of `d1` extended with `d2`; keys present in both map to
        ``d1[k] + d2[k]``.

    """
    merged = d1.copy()
    # dict.items() behaves identically on Python 2 and 3 here, so the
    # six `iteritems` shim is unnecessary.
    for key, value in d2.items():
        if key in merged:
            merged[key] += value
        else:
            merged[key] = value
    return merged
2310
2311
2312class EquilibriumOptimizer(NavigatorOptimizer):
2313    """
2314    Apply optimizations until equilibrium point.
2315
2316    Parameters
2317    ----------
2318    optimizers : list or set
2319        Local or global optimizations to apply until equilibrium.
2320        The global optimizer will be run at the start of each iteration before
2321        the local optimizer.
2322    max_use_ratio : int or float
2323        Each optimizer can be applied at most (size of graph * this number)
2324        times.
2325    ignore_newtrees
2326        See EquilibriumDB ignore_newtrees parameter definition.
2327    final_optimizers
2328        Global optimizers that will be run after each iteration.
2329    cleanup_optimizers
2330        Global optimizers that apply a list of pre determined optimization.
2331        They must not traverse the graph as they are called very frequently.
2332        The MergeOptimizer is one example of optimization that respect this.
2333        They are applied after all global optimizer, then when one local optimizer is applied, then after all final optimizer.
2334
2335    """
2336
2337    def __init__(self,
2338                 optimizers,
2339                 failure_callback=None,
2340                 ignore_newtrees=True,
2341                 tracks_on_change_inputs=False,
2342                 max_use_ratio=None,
2343                 final_optimizers=None,
2344                 cleanup_optimizers=None):
2345        super(EquilibriumOptimizer, self).__init__(
2346            None,
2347            ignore_newtrees=ignore_newtrees,
2348            failure_callback=failure_callback)
2349        self.local_optimizers_map = OrderedDict()
2350        self.local_optimizers_all = []
2351        self.global_optimizers = []
2352        self.final_optimizers = []
2353        self.cleanup_optimizers = []
2354        self.tracks_on_change_inputs = tracks_on_change_inputs
2355        for opt in optimizers:
2356            if isinstance(opt, LocalOptimizer):
2357                if opt.tracks() is None:
2358                    self.local_optimizers_all.append(opt)
2359                else:
2360                    for c in opt.tracks():
2361                        self.local_optimizers_map.setdefault(c, []).append(opt)
2362            else:
2363                self.global_optimizers.append(opt)
2364        if final_optimizers:
2365            self.final_optimizers = final_optimizers
2366        if cleanup_optimizers:
2367            self.cleanup_optimizers = cleanup_optimizers
2368        self.max_use_ratio = max_use_ratio
2369        assert self.max_use_ratio is not None, (
2370            'max_use_ratio has to be a number')
2371
2372    def get_local_optimizers(self):
2373        for opt in self.local_optimizers_all:
2374            yield opt
2375        # if repeat is not a problem we can drop the set
2376        s = set()
2377        for lopt in itervalues(self.local_optimizers_map):
2378            for opt in lopt:
2379                if opt not in s:
2380                    yield opt
2381                    s.add(opt)
2382
2383    def add_requirements(self, fgraph):
2384        super(EquilibriumOptimizer, self).add_requirements(fgraph)
2385        for opt in self.get_local_optimizers():
2386            opt.add_requirements(fgraph)
2387        for opt in self.global_optimizers:
2388            opt.add_requirements(fgraph)
2389        for opt in self.final_optimizers:
2390            opt.add_requirements(fgraph)
2391        for opt in self.cleanup_optimizers:
2392            opt.add_requirements(fgraph)
2393
    def apply(self, fgraph, start_from=None):
        """Optimize `fgraph` in place until an equilibrium is reached.

        Each pass applies, in order: the global optimizers, the cleanup
        optimizers, the local optimizers (over a toposort of the graph,
        re-queueing nodes imported by replacements), the final optimizers,
        and the cleanup optimizers once more.  Passes repeat until a full
        pass makes no change, or until some optimizer has been applied more
        than ``max_nb_nodes * self.max_use_ratio`` times (abort; raise or
        log depending on ``theano.config.on_opt_error``).

        Parameters
        ----------
        fgraph : FunctionGraph
            The graph to optimize in place.
        start_from : sequence of variables, optional
            Outputs from which the local-optimizer traversal starts; each
            must be in ``fgraph.outputs``.  Defaults to all outputs.

        Returns
        -------
        tuple
            Profiling data; see `print_profile` for the exact layout.
        """
        change_tracker = ChangeTracker()
        fgraph.attach_feature(change_tracker)
        if start_from is None:
            start_from = fgraph.outputs
        else:
            for node in start_from:
                assert node in fgraph.outputs

        changed = True
        max_use_abort = False
        opt_name = None
        global_process_count = {}
        start_nb_nodes = len(fgraph.apply_nodes)
        max_nb_nodes = len(fgraph.apply_nodes)
        # Abort threshold: grows as the graph grows (see loop below).
        max_use = max_nb_nodes * self.max_use_ratio

        # Per-pass profiling accumulators; one entry appended per pass.
        loop_timing = []
        loop_process_count = []
        global_opt_timing = []
        time_opts = {}
        io_toposort_timing = []
        nb_nodes = []
        node_created = {}
        global_sub_profs = []
        final_sub_profs = []
        cleanup_sub_profs = []
        for opt in (self.global_optimizers +
                    list(self.get_local_optimizers()) +
                    self.final_optimizers +
                    self.cleanup_optimizers):
            global_process_count.setdefault(opt, 0)
            time_opts.setdefault(opt, 0)
            node_created.setdefault(opt, 0)

        def apply_cleanup(profs_dict):
            # Run every cleanup optimizer once, recording per-optimizer
            # timings and sub-profiles.  NOTE: reads and mutates
            # `process_count` of the *current* while-loop iteration (it is
            # a late-bound closure variable rebound at the top of each
            # pass).
            changed = False
            for copt in self.cleanup_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = copt.apply(fgraph)
                time_opts[copt] += time.time() - t_opt
                profs_dict[copt].append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(copt, 0)
                    process_count[copt] += 1
                    global_process_count[copt] += 1
                    changed = True
                    node_created[copt] += change_tracker.nb_imported - nb
            return changed

        while changed and not max_use_abort:
            process_count = {}
            t0 = time.time()
            changed = False
            iter_cleanup_sub_profs = {}
            for copt in self.cleanup_optimizers:
                iter_cleanup_sub_profs[copt] = []

            # apply global optimizers
            sub_profs = []
            for gopt in self.global_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = gopt.apply(fgraph)
                time_opts[gopt] += time.time() - t_opt
                sub_profs.append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(gopt, 0)
                    process_count[gopt] += 1
                    global_process_count[gopt] += 1
                    changed = True
                    node_created[gopt] += change_tracker.nb_imported - nb
                    if global_process_count[gopt] > max_use:
                        max_use_abort = True
                        opt_name = (getattr(gopt, "name", None) or
                                    getattr(gopt, "__name__", ""))
            global_sub_profs.append(sub_profs)

            global_opt_timing.append(float(time.time() - t0))

            # apply clean up as global opt can have done changes that
            # request that
            changed |= apply_cleanup(iter_cleanup_sub_profs)

            # apply local optimizer
            topo_t0 = time.time()
            q = deque(graph.io_toposort(fgraph.inputs, start_from))
            io_toposort_timing.append(time.time() - topo_t0)

            nb_nodes.append(len(q))
            max_nb_nodes = max(max_nb_nodes, len(q))
            max_use = max_nb_nodes * self.max_use_ratio

            def importer(node):
                # Re-queue nodes imported into the graph by a replacement.
                # `current_node` is late-bound from the while-loop below;
                # the node currently being processed is not re-queued.
                if node is not current_node:
                    q.append(node)

            chin = None
            if self.tracks_on_change_inputs:
                def chin(node, i, r, new_r, reason):
                    # Revisit nodes whose inputs changed ("output" pseudo
                    # clients are strings, hence the isinstance check).
                    if node is not current_node and not isinstance(node, str):
                        q.append(node)
            u = self.attach_updater(fgraph, importer, None,
                                    chin=chin,
                                    name=getattr(self, 'name', None))
            try:
                while q:
                    # LIFO pop: freshly imported nodes are processed first.
                    node = q.pop()
                    if node not in fgraph.apply_nodes:
                        continue
                    current_node = node
                    # Try the always-on local opts, then those registered
                    # for this op's class, then those for this op instance.
                    for lopt in (self.local_optimizers_all +
                                 self.local_optimizers_map.get(type(node.op), []) +
                                 self.local_optimizers_map.get(node.op, [])):
                        nb = change_tracker.nb_imported
                        t_opt = time.time()
                        lopt_change = self.process_node(fgraph, node, lopt)
                        time_opts[lopt] += time.time() - t_opt
                        if not lopt_change:
                            continue
                        process_count.setdefault(lopt, 0)
                        process_count[lopt] += 1
                        global_process_count[lopt] += 1
                        changed = True
                        node_created[lopt] += change_tracker.nb_imported - nb
                        changed |= apply_cleanup(iter_cleanup_sub_profs)
                        if global_process_count[lopt] > max_use:
                            max_use_abort = True
                            opt_name = (getattr(lopt, "name", None) or
                                        getattr(lopt, "__name__", ""))
                        if node not in fgraph.apply_nodes:
                            # go to next node
                            break
            finally:
                self.detach_updater(fgraph, u)

            # Apply final optimizers
            sub_profs = []
            t_before_final_opt = time.time()
            for gopt in self.final_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = gopt.apply(fgraph)
                time_opts[gopt] += time.time() - t_opt
                sub_profs.append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(gopt, 0)
                    process_count[gopt] += 1
                    global_process_count[gopt] += 1
                    changed = True
                    node_created[gopt] += change_tracker.nb_imported - nb
                    if global_process_count[gopt] > max_use:
                        max_use_abort = True
                        opt_name = (getattr(gopt, "name", None) or
                                    getattr(gopt, "__name__", ""))
            final_sub_profs.append(sub_profs)

            # Final optimizers' time is folded into this pass's global
            # optimizer timing bucket.
            global_opt_timing[-1] += time.time() - t_before_final_opt
            # apply clean up as final opt can have done changes that
            # request that
            changed |= apply_cleanup(iter_cleanup_sub_profs)
            # merge clean up profiles during that iteration.
            c_sub_profs = []
            for copt, sub_profs in iteritems(iter_cleanup_sub_profs):
                sub_prof = sub_profs[0]
                for s_p in sub_profs[1:]:
                    sub_prof = copt.merge_profile(sub_prof, s_p)
                c_sub_profs.append(sub_prof)
            cleanup_sub_profs.append(c_sub_profs)

            loop_process_count.append(process_count)
            loop_timing.append(float(time.time() - t0))

        end_nb_nodes = len(fgraph.apply_nodes)

        if max_use_abort:
            msg = ("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
                   ". You can safely raise the current threshold of " +
                   "%f with the theano flag 'optdb.max_use_ratio'." %
                   config.optdb.max_use_ratio)
            if theano.config.on_opt_error == 'raise':
                raise AssertionError(msg)
            else:
                _logger.error(msg)
        fgraph.remove_feature(change_tracker)
        # All per-pass profiling lists must stay in lock-step.
        assert len(loop_process_count) == len(loop_timing)
        assert len(loop_process_count) == len(global_opt_timing)
        assert len(loop_process_count) == len(nb_nodes)
        assert len(loop_process_count) == len(io_toposort_timing)
        assert len(loop_process_count) == len(global_sub_profs)
        assert len(loop_process_count) == len(final_sub_profs)
        assert len(loop_process_count) == len(cleanup_sub_profs)
        return (self, loop_timing, loop_process_count,
                (start_nb_nodes, end_nb_nodes, max_nb_nodes),
                global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
                node_created, global_sub_profs, final_sub_profs,
                cleanup_sub_profs)
2595
2596    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
2597        name = getattr(self, 'name', None)
2598        print("%s%s %s id=%i" % (
2599            (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
2600        if depth != 0:
2601            for lopt in self.get_local_optimizers():
2602                lopt.print_summary(stream, level=(level + 2),
2603                                   depth=(depth - 1))
2604
    @staticmethod
    def print_profile(stream, prof, level=0):
        """Pretty-print to `stream` the profile tuple returned by `apply`.

        Parameters
        ----------
        stream : file-like
            Destination of the report.
        prof : tuple
            The 12-element tuple produced by `EquilibriumOptimizer.apply`.
        level : int
            Indentation level (4 spaces per level).
        """
        (opt, loop_timing, loop_process_count,
         (start_nb_nodes, end_nb_nodes, max_nb_nodes),
         global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
         node_created, global_sub_profs, final_sub_profs,
         cleanup_sub_profs) = prof

        blanc = ('    ' * level)
        print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
        print(blanc, getattr(opt, "name",
                             getattr(opt, "__name__", "")), file=stream)
        print(blanc, "  time %.3fs for %d passes" % (
            sum(loop_timing), len(loop_timing)), file=stream)
        print(blanc, "  nb nodes (start, end,  max) %d %d %d" % (
            start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
        print(blanc, "  time io_toposort %.3fs" % sum(
            io_toposort_timing), file=stream)
        # Aggregate time spent in each category of sub-optimizer.
        s = sum([time_opts[o] for o in opt.get_local_optimizers()])
        print(blanc, "  time in local optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.global_optimizers])
        print(blanc, "  time in global optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.final_optimizers])
        print(blanc, "  time in final optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.cleanup_optimizers])
        print(blanc, "  time in cleanup optimizers %.3fs" % s, file=stream)
        # One line per pass: timing, node count, and the 5 most-applied
        # optimizers of that pass.
        for i in range(len(loop_timing)):
            lopt = ""
            if loop_process_count[i]:
                d = list(reversed(sorted(iteritems(loop_process_count[i]),
                                         key=lambda a: a[1])))
                lopt = " ".join([str((str(k), v)) for k, v
                                 in d[:5]])
                if len(d) > 5:
                    lopt += " ..."
            print(blanc, ('  %2d - %.3fs %d (%.3fs in global opts, '
                          '%.3fs io_toposort) - %d nodes - %s' % (
                              i, loop_timing[i],
                              sum(loop_process_count[i].values()),
                              global_opt_timing[i],
                              io_toposort_timing[i], nb_nodes[i],
                              lopt)), file=stream)

        # Totals over all passes: which optimizers were applied, how many
        # nodes they created, and which were never applied at all.
        count_opt = []
        not_used = []
        not_used_time = 0
        process_count = {}
        for o in (opt.global_optimizers +
                  list(opt.get_local_optimizers()) +
                  list(opt.final_optimizers) +
                  list(opt.cleanup_optimizers)):
            process_count.setdefault(o, 0)
        for count in loop_process_count:
            for o, v in iteritems(count):
                process_count[o] += v
        for o, count in iteritems(process_count):
            if count > 0:
                count_opt.append((time_opts[o], count,
                                  node_created[o], o))
            else:
                not_used.append((time_opts[o], o))
                not_used_time += time_opts[o]

        if count_opt:
            print(blanc,
                  '  times - times applied - nb node created - name:',
                  file=stream)
            count_opt.sort()
            for (t, count, n_created, o) in count_opt[::-1]:
                print(blanc, '  %.3fs - %d - %d - %s' % (
                    t, count, n_created, o), file=stream)
            print(blanc, '  %.3fs - in %d optimization that were not used (display only those with a runtime > 0)' % (
                not_used_time, len(not_used)), file=stream)
            not_used.sort(key=lambda nu: (nu[0], str(nu[1])))
            for (t, o) in not_used[::-1]:
                if t > 0:
                    # Skip opt that have 0 times, they probably wasn't even tried.
                    print(blanc + "  ", '  %.3fs - %s' % (t, o), file=stream)
            print(file=stream)
        # Recurse into sub-profiles, but only for optimizers that actually
        # override Optimizer.print_profile.
        gf_opts = [o for o in (opt.global_optimizers +
                               list(opt.final_optimizers) +
                               list(opt.cleanup_optimizers))
                   if o.print_profile.__code__ is not
                   Optimizer.print_profile.__code__]
        if not gf_opts:
            return
        print(blanc, "Global, final and clean up optimizers", file=stream)
        for i in range(len(loop_timing)):
            print(blanc, "Iter %d" % i, file=stream)
            for o, prof in zip(opt.global_optimizers, global_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    # NOTE(review): no file=stream here (nor below) -- this
                    # message goes to stdout, not `stream`; presumably an
                    # oversight -- confirm before changing.
                    print(blanc, "merge not implemented for ", o)
            for o, prof in zip(opt.final_optimizers, final_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    print(blanc, "merge not implemented for ", o)
            for o, prof in zip(opt.cleanup_optimizers, cleanup_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    print(blanc, "merge not implemented for ", o)
2709
    @staticmethod
    def merge_profile(prof1, prof2):
        """Merge two profile tuples (as returned by `apply`) into one.

        Profile tuple layout (by index): 0 optimizer, 1 loop_timing,
        2 loop_process_count, 3 (start, end, max) node counts,
        4 global_opt_timing, 5 nb_nodes, 6 time_opts,
        7 io_toposort_timing, 8 node_created, 9 global_sub_profs,
        10 final_sub_profs, 11 cleanup_sub_profs.
        """
        # Union of the sub-optimizers of both profiled optimizers; a new
        # EquilibriumOptimizer is built to own the merged profile.
        local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
            prof2[0].get_local_optimizers())
        global_optimizers = OrderedSet(prof1[0].global_optimizers).union(
            prof2[0].global_optimizers)
        final_optimizers = list(OrderedSet(prof1[0].final_optimizers).union(
            prof2[0].final_optimizers))
        cleanup_optimizers = list(OrderedSet(prof1[0].cleanup_optimizers).union(
            prof2[0].cleanup_optimizers))
        new_opt = EquilibriumOptimizer(
            local_optimizers.union(global_optimizers),
            max_use_ratio=1,
            final_optimizers=final_optimizers,
            cleanup_optimizers=cleanup_optimizers)

        def add_append_list(l1, l2):
            # Element-wise sum of two lists of numbers; the longer list's
            # tail is kept as-is.
            l = copy.copy(l1)
            for idx, nb in enumerate(l2):
                if idx < len(l):
                    l[idx] += nb
                else:
                    l.append(nb)
            return l

        loop_timing = add_append_list(prof1[1], prof2[1])

        loop_process_count = list(prof1[2])
        global_sub_profs = []
        final_sub_profs = []
        cleanup_sub_profs = []

        # Merge per-pass data for the passes both profiles have.
        for i in range(min(len(loop_process_count), len(prof2[2]))):
            process_count = loop_process_count[i]
            for process, count in iteritems(prof2[2][i]):
                if process in process_count:
                    process_count[process] += count
                else:
                    process_count[process] = count

            def merge(opts, attr, idx):
                # For each optimizer in `opts`, merge the i-th sub-profile
                # from prof1 and prof2 (tuple index `idx`, optimizer list
                # attribute `attr`), or take the one that exists.
                tmp = []
                for opt in opts:
                    o1 = getattr(prof1[0], attr)
                    o2 = getattr(prof2[0], attr)
                    if opt in o1 and opt in o2:
                        p1 = prof1[idx][i][o1.index(opt)]
                        p2 = prof2[idx][i][o2.index(opt)]
                        m = None
                        if hasattr(opt, 'merge_profile'):
                            m = opt.merge_profile(p1, p2)
                    elif opt in o1:
                        m = prof1[idx][i][o1.index(opt)]
                    else:
                        m = prof2[idx][i][o2.index(opt)]
                    tmp.append(m)
                return tmp
            global_sub_profs.append(merge(global_optimizers, 'global_optimizers', 9))
            final_sub_profs.append(merge(final_optimizers, 'final_optimizers', 10))
            cleanup_sub_profs.append(merge(cleanup_optimizers, 'cleanup_optimizers', 11))

        # Add the iteration done by only one of the profile.
        loop_process_count.extend(prof1[2][len(loop_process_count):])
        global_sub_profs.extend(prof1[9][len(global_sub_profs):])
        final_sub_profs.extend(prof1[10][len(final_sub_profs):])
        cleanup_sub_profs.extend(prof1[11][len(cleanup_sub_profs):])

        # NOTE(review): these slices use len(loop_process_count) *after* it
        # was extended with prof1's surplus above, so when prof2 had more
        # passes than prof1 its extra sub-profiles may be skipped -- and
        # loop_process_count is never extended from prof2 at all.  Looks
        # suspicious; confirm before changing.
        global_sub_profs.extend(prof2[9][len(loop_process_count):])
        final_sub_profs.extend(prof2[10][len(loop_process_count):])
        cleanup_sub_profs.extend(prof2[11][len(loop_process_count):])

        # NOTE(review): prof[3] is the (start, end, max) triple, so this is
        # a lexicographic max of tuples, not a max of the "max" fields --
        # presumably good enough for reporting; confirm.
        max_nb_nodes = max(prof1[3], prof2[3])

        global_opt_timing = add_append_list(prof1[4], prof2[4])

        nb_nodes = add_append_list(prof1[5], prof2[5])

        time_opts = merge_dict(prof1[6], prof2[6])
        io_toposort_timing = add_append_list(prof1[7], prof2[7])
        assert (len(loop_timing) == len(global_opt_timing) ==
                len(global_sub_profs) ==
                len(io_toposort_timing) == len(nb_nodes))
        assert len(loop_timing) == max(len(prof1[1]), len(prof2[1]))

        node_created = merge_dict(prof1[8], prof2[8])
        return (new_opt,
                loop_timing,
                loop_process_count,
                max_nb_nodes,
                global_opt_timing,
                nb_nodes,
                time_opts,
                io_toposort_timing,
                node_created,
                global_sub_profs,
                final_sub_profs,
                cleanup_sub_profs)
2809
2810#################
2811#   Utilities   #
2812#################
2813
2814
2815def _check_chain(r, chain):
2816    """
2817    WRITEME
2818
2819    """
2820    chain = list(reversed(chain))
2821    while chain:
2822        elem = chain.pop()
2823        if elem is None:
2824            if r.owner is not None:
2825                return False
2826        elif r.owner is None:
2827            return False
2828        elif isinstance(elem, op.Op):
2829            if not r.owner.op == elem:
2830                return False
2831        else:
2832            try:
2833                if (issubclass(elem, op.Op) and
2834                        not isinstance(r.owner.op, elem)):
2835                    return False
2836            except TypeError:
2837                return False
2838        if chain:
2839            r = r.owner.inputs[chain.pop()]
2840    # print 'check_chain', _check_chain.n_calls
2841    # _check_chain.n_calls += 1
2842
2843    # The return value will be used as a Boolean, but some Variables cannot
2844    # be used as Booleans (the results of comparisons, for instance)
2845    return (r is not None)
2846# _check_chain.n_calls = 0
2847
2848
def check_chain(r, *chain):
    """
    Convenience wrapper around `_check_chain`: takes the op chain as
    varargs and accepts an Apply node (its first output is then used).
    Each link descends through input 0 of the previous owner.
    """
    var = r.outputs[0] if isinstance(r, graph.Apply) else r
    # Interleave each link with the input index 0 it descends through.
    interleaved = reduce(list.__iadd__, ([link, 0] for link in chain))
    return _check_chain(var, interleaved)
2857
2858
def pre_greedy_local_optimizer(list_optimizations, out):
    """
    This function traverses the computation graph described by all
    ``node`` in the graph before the variable out but that are not in the
    fgraph. It applies each of the local_optimizations on the traversed graph.

    Its main use is to apply locally constant folding when generating
    the graph of the indices of a subtensor.

    We should not apply optimizations on node that are in fgraph.
    So we don't optimize node that have an attribute fgraph.

    Parameters
    ----------
    list_optimizations
        List of local optimizers; each must expose a ``transform(node)``
        method returning either new outputs or False/None.
    out
        Variable whose ancestors (outside any fgraph) are optimized.

    Returns
    -------
    The (possibly replaced) variable corresponding to `out`.

    Notes
    -----
    This doesn't do an equilibrium... So if there is optimization
    like local_upcast_elemwise_constant_inputs in the list, that
    adds additional node to the inputs of the node, it can
    be needed to call this function multiple times.

    """
    def local_recursive_function(list_opt, out, optimized_vars, depth):
        # Depth-first post-order: optimize the inputs of `out.owner`
        # first, then try each optimizer on the node itself.
        # `optimized_vars` caches variable -> replacement mappings.
        if not getattr(out, 'owner', None):
            return [out], optimized_vars
        node = out.owner

        # Nodes already in an fgraph must not be optimized here.
        if hasattr(node, 'fgraph'):
            return node.outputs, optimized_vars
        for idx, inp in enumerate(node.inputs):
            if inp in optimized_vars:
                nw_in = optimized_vars[inp]
            else:
                if inp.owner:
                    outs, optimized_vars = local_recursive_function(
                        list_opt,
                        inp,
                        optimized_vars,
                        depth + 1)
                    for k, v in zip(inp.owner.outputs, outs):
                        optimized_vars[k] = v
                    nw_in = outs[inp.owner.outputs.index(inp)]

                else:
                    nw_in = inp
                    optimized_vars[inp] = inp
            # NOTE: mutates the node's inputs in place (the node is not in
            # an fgraph, so no fgraph bookkeeping is involved).
            node.inputs[idx] = nw_in

        results = node.outputs
        for opt in list_opt:
            ret = opt.transform(node)
            if ret is not False and ret is not None:
                assert len(ret) == len(node.outputs), opt
                for k, v in zip(node.outputs, ret):
                    optimized_vars[k] = v
                results = ret
                if ret[0].owner:
                    # NOTE(review): this rebinds `node` to the *original*
                    # owner (`out.owner`), not to the owner of the new
                    # output `ret[0]`, so the remaining optimizers keep
                    # inspecting the old node.  Looks suspicious; confirm
                    # before changing.
                    node = out.owner
                else:
                    break
        return results, optimized_vars
    if out.owner:
        out_index = out.owner.outputs.index(out)
    else:
        out_index = 0
    final_outs, optimized_nodes = local_recursive_function(
        list_optimizations, out, {}, 0)
    return final_outs[out_index]
2925
2926
def copy_stack_trace(from_var, to_var):
    """
    Copy the stack trace(s) from one or more variables onto one or more
    destination variables, and return the destination(s).

    Parameters
    ----------
    from_var
        Tensor variable or list of tensor variables to copy stack traces
        from.
    to_var
        Tensor variable or list of tensor variables to copy stack traces
        to.  Existing traces on the destinations are kept; the copied
        traces are appended after them.

    Notes
    -----
    A stack trace is a list of lists of tuples; each tuple holds the
    filename, line number, function name and so on, and each list of
    tuples belongs to one particular variable.

    """
    # Gather the source trace(s).
    if type(from_var) is list:
        # Concatenate the traces of every source variable.
        gathered = []
        for v in from_var:
            gathered += getattr(v.tag, 'trace', [])
    else:
        gathered = getattr(from_var.tag, 'trace', [])

    if gathered and isinstance(gathered[0], tuple):
        # A single flat stack trace: wrap it in a list.
        gathered = [gathered]

    # Append the gathered traces to every destination variable.
    destinations = to_var if type(to_var) is list else [to_var]
    for v in destinations:
        v.tag.trace = getattr(v.tag, 'trace', []) + gathered
    return to_var
2975
2976
@contextlib.contextmanager
def inherit_stack_trace(from_var):
    """
    Context manager: every variable node created inside the ``with`` body
    (as tracked by ``graph.nodes_constructed()``) receives the stack
    trace(s) of `from_var` when the body finishes.

    Parameters
    ----------
    from_var
        Variable node or list of variable nodes to copy stack traces from.

    """
    with graph.nodes_constructed() as created_nodes:
        yield
    copy_stack_trace(from_var, created_nodes)
2993
2994
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
    """
    Return True iff every output of the selected ops carries a stack trace.

    Parameters
    ----------
    f_or_fgraph : theano.compile.function_module.Function or
          theano.gof.fg.FunctionGraph
        The compiled function or the function graph to be analysed.
    ops_to_check
        Which apply nodes to inspect.  One of:
          - 'last': only the nodes producing the graph outputs;
          - 'all': every apply node in the graph;
          - an Op instance or class, or a tuple/list of them: nodes whose
            op matches by equality (instances) or by ``isinstance``
            (classes, also checked against ``op.scalar_op``);
          - a callable taking an apply node and returning a boolean.
    bug_print : {'raise', 'warn', 'ignore'}
        Behaviour when no node in the graph matches `ops_to_check`.

    Returns
    -------
    boolean
        True if the outputs of the specified ops have a stack trace,
        False otherwise.

    """
    if isinstance(f_or_fgraph, theano.compile.function_module.Function):
        fgraph = f_or_fgraph.maker.fgraph
    elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
        fgraph = f_or_fgraph
    else:
        raise ValueError('The type of f_or_fgraph is not supported')

    # Normalize a single op instance/class into a one-element tuple.
    if (isinstance(ops_to_check, theano.gof.Op) or
            (inspect.isclass(ops_to_check) and
                issubclass(ops_to_check, theano.gof.Op))):
        ops_to_check = (ops_to_check,)

    if isinstance(ops_to_check, string_types):
        if ops_to_check == 'last':
            apply_nodes_to_check = [
                output.owner for output in fgraph.outputs]
        elif ops_to_check == 'all':
            apply_nodes_to_check = fgraph.apply_nodes
        else:
            raise ValueError('The string ops_to_check is not recognised')

    elif isinstance(ops_to_check, (tuple, list)):
        # Instances are matched by equality against the whole collection;
        # classes are matched with isinstance (also on op.scalar_op).
        op_classes = tuple(obj for obj in ops_to_check
                           if not isinstance(obj, theano.gof.Op))
        by_instance = [node for node in fgraph.apply_nodes
                       if node.op in ops_to_check]
        by_class = [node for node in fgraph.apply_nodes
                    if isinstance(node.op, op_classes) or
                    (hasattr(node.op, 'scalar_op') and
                     isinstance(node.op.scalar_op, op_classes))]
        apply_nodes_to_check = by_instance + by_class

    elif hasattr(ops_to_check, '__call__'):
        apply_nodes_to_check = [node for node in fgraph.apply_nodes
                                if ops_to_check(node)]

    else:
        raise ValueError('ops_to_check does not have the right type')

    if not apply_nodes_to_check:
        msg = ('Provided op instances/classes are not in the graph or the '
               'graph is empty')
        if bug_print == 'raise':
            raise Exception(msg)
        elif bug_print == 'warn':
            warnings.warn(msg)
        elif bug_print != 'ignore':
            raise ValueError('The string bug_print is not recognised')

    for node in apply_nodes_to_check:
        for output in node.outputs:
            if not getattr(output.tag, 'trace', None):
                return False
    return True
3098
3099
class CheckStrackTraceFeature(object):
    """FunctionGraph feature that checks, on every node import, that all
    variables in the graph carry a stack trace.

    Depending on ``theano.config.check_stack_trace`` it raises, warns, or
    fills in a placeholder trace naming the offending optimization.
    (The class name's misspelling is historical and kept for backward
    compatibility.)
    """

    def on_import(self, fgraph, node, reason):
        """Called by the fgraph when `node` is imported; `reason` names the
        optimization responsible for the import."""
        # In optdb we only register the CheckStackTraceOptimization when
        # theano.config.check_stack_trace is not off, but we double check
        # here anyway.
        if (theano.config.check_stack_trace == 'off' or
                check_stack_trace(fgraph, 'all')):
            return
        if theano.config.check_stack_trace == 'raise':
            raise AssertionError(
                'Empty stack trace! The optimization that inserted this '
                'variable is ' + str(reason))
        elif theano.config.check_stack_trace in ['log', 'warn']:
            # Fill every empty trace with a placeholder entry so later
            # reports can point at the responsible optimization.
            # (Loop variable renamed so it does not shadow the `node`
            # parameter.)
            for checked_node in fgraph.apply_nodes:
                for output in checked_node.outputs:
                    if not getattr(output.tag, 'trace', None):
                        output.tag.trace = [[(
                            '', 0,
                            'Empty stack trace! The optimization that '
                            'inserted this variable is ' + str(reason),
                            '')]]
            if theano.config.check_stack_trace == 'warn':
                warnings.warn(
                    'Empty stack trace! The optimization that inserted '
                    'this variable is ' + str(reason))
3118
3119
class CheckStackTraceOptimization(Optimizer):
    """Optimizer whose only effect is to attach a `CheckStrackTraceFeature`
    to the fgraph; that feature then validates stack traces on every node
    import."""

    def add_requirements(self, fgraph):
        # NOTE(review): this guards on an fgraph *attribute* named
        # 'CheckStrackTraceFeature'; nothing visible here sets such an
        # attribute when the feature is attached, so the guard may never
        # prevent duplicates -- confirm against fg.attach_feature.
        if not hasattr(fgraph, 'CheckStrackTraceFeature'):
            fgraph.attach_feature(CheckStrackTraceFeature())

    def apply(self, fgraph):
        # All the work happens in the feature's on_import hook.
        pass
3129