1"""
2IfElse introduces lazy evaluation in Theano (coupled with the CVM/VM
3linkers). It resembles the if clause of any programming language, that
4has a `then` and `else` branch, and executes either one or the other
5according to the condition provided.
6
7This op differs from the already existent `switch` op, that evaluates both
8branches of the clause and afterwards picks (according to the condition)
9which value to report. Note also that `switch` is an elemwise operation (so
10it picks each entry of a matrix according to the condition) while `ifelse`
11is a global operation with a scalar condition.
12"""
13from __future__ import absolute_import, print_function, division
14from copy import deepcopy
15from theano.compat import izip
16import logging
17
18import numpy as np
19
20import theano.tensor
21from theano.tensor import TensorType
22from theano import gof
23from theano.gof import Op, Apply
24
25from six import iteritems
26from six.moves import xrange
27from theano.compile import optdb
28from theano.tensor import opt
29from theano.scan_module.scan_utils import clone
30
31
32__docformat__ = 'restructedtext en'
33__authors__ = ("Razvan Pascanu "
34               "James Bergstra "
35               "Dumitru Erhan "
36               "David Warde-Farley")
37__copyright__ = "(c) 2010, Universite de Montreal"
38__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
39
40_logger = logging.getLogger('theano.ifelse')
41
42
43class IfElse(Op):
44    """
45    Op that provides conditional graph evaluation if used with the CVM/VM
46    linkers. Note that there exist a helpful function `ifelse` that should
47    be used to instantiate the op!
48
49    According to a scalar condition `condition` the op evaluates and then
50    returns all the tensors provided on the `then` branch, otherwise it
51    evaluates and returns the tensors provided on the `else` branch. The op
52    supports multiple tensors on each branch, with the condition that the same
53    number of tensors are on the `then` as on the `else` and there is a one
54    to one correspondence between them (shape and dtype wise).
55
56    The `then` branch is defined as the first N tensors (after the
57    condition), while the `else` branch is defined as the last N tensors.
58
59    Example usage:
60
61        ``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN,
62                        rval_if_false1, rval_if_false2, .., rval_if_falseN)``
63
64    :note:
65        Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and
66        will ignore its lazy characteristic, computing both the True and
67        False branch before picking one.
68
69    """
70    def __init__(self, n_outs, as_view=False, gpu=False, name=None):
71        if as_view:
72            # check destroyhandler and others to ensure that a view_map with
73            # multiple inputs can work
74            view_map = {}
75            for idx in xrange(n_outs):
76                view_map[idx] = [idx + 1]
77            self.view_map = view_map
78        self.as_view = as_view
79        self.gpu = gpu
80        self.n_outs = n_outs
81        self.name = name
82
83    def __eq__(self, other):
84        if not type(self) == type(other):
85            return False
86        if not self.as_view == other.as_view:
87            return False
88        if not self.gpu == other.gpu:
89            return False
90        if not self.n_outs == other.n_outs:
91            return False
92        return True
93
94    def __hash__(self):
95        rval = (hash(type(self)) ^
96                hash(self.as_view) ^
97                hash(self.gpu) ^
98                hash(self.n_outs))
99        return rval
100
101    def __str__(self):
102        args = []
103        if self.name is not None:
104            args.append(self.name)
105        if self.as_view:
106            args.append('inplace')
107        if self.gpu:
108            args.append('gpu')
109        return 'if{%s}' % ','.join(args)
110
111    def infer_shape(self, node, inputs_shapes):
112        # By construction, corresponding then/else pairs have the same number
113        # of dimensions
114
115        ts_shapes = inputs_shapes[1:][:self.n_outs]
116        fs_shapes = inputs_shapes[1:][self.n_outs:]
117        # All elements of all shape tuples for the true and false outputs are
118        # unpacked into the inputs of a separate ifelse, and then the outputs
119        # of that ifelse are packed back into shape tuples.
120        new_ts_inputs = []
121        for ts_shape in ts_shapes:
122            if isinstance(ts_shape, (list, tuple)):
123                new_ts_inputs += list(ts_shape)
124            else:
125                # It can be None for generic objects
126                return [None] * self.n_outs
127
128        new_fs_inputs = []
129        for fs_shape in fs_shapes:
130            if isinstance(fs_shape, (list, tuple)):
131                new_fs_inputs += list(fs_shape)
132            else:
133                # It can be None for generic objects
134                return [None] * self.n_outs
135
136        assert len(new_ts_inputs) == len(new_fs_inputs)
137        if len(new_ts_inputs + new_fs_inputs) > 0:
138            name_tokens = ['shape']
139            if self.name is not None:
140                name_tokens.append(self.name)
141
142            new_ifelse = IfElse(
143                n_outs=len(new_ts_inputs),
144                as_view=False,
145                gpu=False,
146                name='_'.join(name_tokens))
147            new_outs = new_ifelse(node.inputs[0],
148                                  *(new_ts_inputs + new_fs_inputs),
149                                  **dict(return_list=True))
150        else:
151            new_outs = []
152
153        # generate pairs of shapes
154        out_shapes = []
155        for out in node.outputs:
156            out_shapes.append(tuple(new_outs[:out.ndim]))
157            new_outs = new_outs[out.ndim:]
158
159        # new_outs should be an empty list after last iteration
160        assert len(new_outs) == 0
161
162        return out_shapes
163
164    def make_node(self, c, *args):
165        assert len(args) == 2 * self.n_outs, (
166            "Wrong number of arguments to make_node: "
167            "expected %d, got %d" % (2 * self.n_outs, len(args))
168        )
169        c = theano.tensor.as_tensor_variable(c)
170        if not self.gpu:
171            # When gpu is true, we are given only gpuarrays, and we want
172            # to keep them as gpuarrays
173            nw_args = []
174            for x in args:
175                if hasattr(x, '_as_TensorVariable'):
176                    nw_args.append(x._as_TensorVariable())
177                elif isinstance(x, theano.Variable):
178                    nw_args.append(x)
179                else:
180                    nw_args.append(theano.tensor.as_tensor_variable(x))
181            args = nw_args
182        ts = args[:self.n_outs]
183        fs = args[self.n_outs:]
184
185        for t, f in izip(ts, fs):
186            if t.type != f.type:
187                raise TypeError(('IfElse requires same types for true and '
188                                'false return values'), t, f, t.type, f.type)
189        if c.ndim > 0:
190            raise TypeError(('Condition given to the op has to be a scalar '
191                             'with 0 standing for False, anything else '
192                             'for True'))
193        return Apply(self, [c] + list(args), [t.type() for t in ts])
194
195    def R_op(self, inputs, eval_points):
196        return self(inputs[0], *eval_points[1:], **dict(return_list=True))
197
198    def grad(self, ins, grads):
199        ts = ins[1:][:self.n_outs]
200        fs = ins[1:][self.n_outs:]
201        if self.name is not None:
202            nw_name_t = self.name + '_grad_t'
203            nw_name_f = self.name + '_grad_f'
204        else:
205            nw_name_t = None
206            nw_name_f = None
207        if_true_op = IfElse(n_outs=self.n_outs,
208                            as_view=self.as_view,
209                            gpu=self.gpu,
210                            name=nw_name_t)
211
212        if_false_op = IfElse(n_outs=self.n_outs,
213                             as_view=self.as_view,
214                             gpu=self.gpu,
215                             name=nw_name_f)
216
217        # The grads can have a different dtype then the inputs.
218        # As inputs true/false pair must have the same dtype,
219        # we must cast the zeros to the corresponding grad dtype
220        # and not the input dtype.
221        if_true = ([ins[0]] +
222                   grads +
223                   [theano.tensor.zeros_like(t, dtype=grads[i].dtype)
224                    for i, t in enumerate(ts)])
225        if_false = ([ins[0]] +
226                    [theano.tensor.zeros_like(f, dtype=grads[i].dtype)
227                     for i, f in enumerate(fs)] +
228                    grads)
229
230        condition = ins[0]
231        # condition does affect the elements of the output so it is connected.
232        # For the sake of making the gradient convenient we assume that
233        # condition + epsilon always triggers the same branch as condition
234        condition_grad = condition.zeros_like().astype(theano.config.floatX)
235        return ([condition_grad] +
236                if_true_op(*if_true, **dict(return_list=True)) +
237                if_false_op(*if_false, **dict(return_list=True)))
238
239    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
240        cond = node.inputs[0]
241        ts = node.inputs[1:][:self.n_outs]
242        fs = node.inputs[1:][self.n_outs:]
243        outputs = node.outputs
244
245        def thunk():
246            if not compute_map[cond][0]:
247                return [0]
248            else:
249                truthval = storage_map[cond][0]
250                if truthval != 0:
251                    ls = [idx + 1 for idx in xrange(self.n_outs)
252                          if not compute_map[ts[idx]][0]]
253                    if len(ls) > 0:
254                        return ls
255                    else:
256                        for out, t in izip(outputs, ts):
257                            compute_map[out][0] = 1
258                            val = storage_map[t][0]
259                            if self.as_view:
260                                storage_map[out][0] = val
261                            # Work around broken numpy deepcopy
262                            elif type(val) in (np.ndarray, np.memmap):
263                                storage_map[out][0] = val.copy()
264                            else:
265                                storage_map[out][0] = deepcopy(val)
266                        return []
267                else:
268                    ls = [1 + idx + self.n_outs for idx in xrange(self.n_outs)
269                          if not compute_map[fs[idx]][0]]
270                    if len(ls) > 0:
271                        return ls
272                    else:
273                        for out, f in izip(outputs, fs):
274                            compute_map[out][0] = 1
275                            # can't view both outputs unless destroyhandler
276                            # improves
277                            # Work around broken numpy deepcopy
278                            val = storage_map[f][0]
279                            if type(val) in (np.ndarray, np.memmap):
280                                storage_map[out][0] = val.copy()
281                            else:
282                                storage_map[out][0] = deepcopy(val)
283                        return []
284
285        thunk.lazy = True
286        thunk.inputs = [storage_map[v] for v in node.inputs]
287        thunk.outputs = [storage_map[v] for v in node.outputs]
288        return thunk
289
290
291def ifelse(condition, then_branch, else_branch, name=None):
292    """
293    This function corresponds to an if statement, returning (and evaluating)
294    inputs in the ``then_branch`` if ``condition`` evaluates to True or
295    inputs in the ``else_branch`` if ``condition`` evalutates to False.
296
297    :type condition: scalar like
298    :param condition:
299        ``condition`` should be a tensor scalar representing the condition.
300        If it evaluates to 0 it corresponds to False, anything else stands
301        for True.
302
303    :type then_branch: list of theano expressions/ theano expression
304    :param then_branch:
305        A single theano variable or a list of theano variables that the
306        function should return as the output if ``condition`` evaluates to
307        true. The number of variables should match those in the
308        ``else_branch``, and there should be a one to one correspondance
309        (type wise) with the tensors provided in the else branch
310
311    :type else_branch: list of theano expressions/ theano expressions
312    :param else_branch:
313        A single theano variable or a list of theano variables that the
314        function should return as the output if ``condition`` evaluates to
315        false. The number of variables should match those in the then branch,
316        and there should be a one to one correspondace (type wise) with the
317        tensors provided in the then branch.
318
319    :return:
320        A list of theano variables or a single variable (depending on the
321        nature of the ``then_branch`` and ``else_branch``). More exactly if
322        ``then_branch`` and ``else_branch`` is a tensor, then
323        the return variable will be just a single variable, otherwise a
324        list. The value returns correspond either to the values in the
325        ``then_branch`` or in the ``else_branch`` depending on the value of
326        ``cond``.
327    """
328
329    rval_type = None
330    if type(then_branch) is list:
331        rval_type = list
332    elif type(then_branch) is tuple:
333        rval_type = tuple
334
335    if type(then_branch) not in (list, tuple):
336        then_branch = [then_branch]
337    if type(else_branch) not in (list, tuple):
338        else_branch = [else_branch]
339
340    # Some of the elements might be converted into another type,
341    # we will store them in these new_... lists.
342    new_then_branch = []
343    new_else_branch = []
344    for then_branch_elem, else_branch_elem in izip(then_branch, else_branch):
345        if not isinstance(then_branch_elem, theano.Variable):
346            then_branch_elem = theano.tensor.as_tensor_variable(
347                then_branch_elem)
348        if not isinstance(else_branch_elem, theano.Variable):
349            else_branch_elem = theano.tensor.as_tensor_variable(
350                else_branch_elem)
351
352        if then_branch_elem.type != else_branch_elem.type:
353            # If one of them is a TensorType, and the other one can be
354            # converted into one, then we try to do that.
355            # This case happens when one of the elements has a GPU type,
356            # for instance a shared variable that was silently moved to GPU.
357            if (isinstance(then_branch_elem.type, TensorType) and not
358                    isinstance(else_branch_elem.type, TensorType)):
359                else_branch_elem = then_branch_elem.type.filter_variable(
360                    else_branch_elem)
361
362            elif (isinstance(else_branch_elem.type, TensorType) and not
363                    isinstance(then_branch_elem.type, TensorType)):
364                then_branch_elem = else_branch_elem.type.filter_variable(
365                    then_branch_elem)
366
367            if then_branch_elem.type != else_branch_elem.type:
368                # If the types still don't match, there is a problem.
369                raise TypeError(
370                    'The two branches should have identical types, but '
371                    'they are %s and %s respectively. This error could be '
372                    'raised if for example you provided a one element '
373                    'list on the `then` branch but a tensor on the `else` '
374                    'branch.' %
375                    (then_branch_elem.type, else_branch_elem.type))
376
377        new_then_branch.append(then_branch_elem)
378        new_else_branch.append(else_branch_elem)
379
380    if len(then_branch) != len(else_branch):
381        raise ValueError(('The number of values on the `then` branch'
382                          ' should have the same number of variables as '
383                          'the `else` branch : (variables on `then` '
384                          '%d' % len(then_branch) + ', variables on `else` '
385                          '%d' % len(else_branch) + ')'))
386
387    new_ifelse = IfElse(n_outs=len(then_branch),
388                        as_view=False,
389                        gpu=False,
390                        name=name)
391
392    ins = [condition] + list(new_then_branch) + list(new_else_branch)
393    rval = new_ifelse(*ins, **dict(return_list=True))
394
395    if rval_type is None:
396        return rval[0]
397    elif rval_type is list:
398        return list(rval)
399    else:
400        return tuple(rval)
401
402
403@gof.local_optimizer([IfElse])
404def cond_make_inplace(node):
405    op = node.op
406    if (isinstance(op, IfElse) and
407        not op.as_view and
408        # For big graph, do not make inplace scalar to speed up
409        # optimization.
410        (len(node.fgraph.apply_nodes) < 500 or
411         not all([getattr(o.type, 'ndim', -1) == 0
412                  for o in node.outputs]))):
413        return IfElse(n_outs=op.n_outs,
414                      as_view=True,
415                      gpu=op.gpu,
416                      name=op.name)(*node.inputs, **dict(return_list=True))
417    return False
418
419
420optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace,
421               ignore_newtrees=True), 95, 'fast_run', 'inplace')
422
423# XXX: Optimizations commented pending further debugging (certain optimizations
424# make computation less lazy than it should be currently).
425#
426# ifelse_equilibrium = gof.EquilibriumDB()
427# ifelse_seqopt = gof.SequenceDB()
428# ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run',
429#                             'ifelse')
430''' Comments:
431I've wrote this comments to explain how the optimization of ifelse function
432(for future developers that need to parse this part of code. Please try to
433keep this comments in sync with whatever changes you add to the code.
434
435ifelse optimization are registered before canonicalize !
436
437The optimizations are called in sequence as follows:
438    * equilibrium shell (runs until no change):
439        * ifelse_lift
440        * ifelse_merge_ifs
441        * ifelse_merge_nodes
442        * ifelse_remove_identical_inside
443        * ifelse_sameCondTrue_inside
444        * ifelse_sameCondFalse_inside
445    * merge_nodes_1
446    * ifelse_sameCondTrue
447    * ifelse_sameCondFalse
448    * ifelse_removeIdentical
449
450where, each of the optimization do the following things:
451    `ifelse_lift` (def cond_lift_single_if):
452
453'''
454# optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run',
455#                'ifelse')
456
457acceptable_ops = (theano.tensor.basic.Dot,
458                  theano.tensor.basic.Reshape,
459                  theano.tensor.basic.Shape,
460                  theano.tensor.SpecifyShape,
461                  theano.tensor.basic.MaxAndArgmax,
462                  theano.tensor.Subtensor,
463                  theano.tensor.IncSubtensor,
464                  theano.tensor.basic.Rebroadcast,
465                  theano.tensor.basic.Alloc,
466                  theano.tensor.elemwise.Elemwise,
467                  theano.tensor.elemwise.DimShuffle)
468
469
470@gof.local_optimizer(acceptable_ops)
471def ifelse_lift_single_if_through_acceptable_ops(main_node):
472    """This optimization lifts up certain ifelse instances.
473
474        op(ifelse(c, x, y)) -> ifelse(c, op(x), op(y))
475
476    if `op` is in the `acceptable_ops` list, and there is no other if as
477    input to that specific `op`, and the if has no other clients !?
478    """
479    if not (isinstance(main_node.op, acceptable_ops)):
480        return False
481    all_inp_nodes = set()
482    for inp in main_node.inputs:
483        all_inp_nodes.add(inp.owner)
484    ifnodes = [x for x in list(all_inp_nodes)
485               if x and isinstance(x.op, IfElse)]
486    # if we have multiple ifs as inputs .. it all becomes quite complicated
487    # :)
488    if len(ifnodes) != 1:
489        return False
490    node = ifnodes[0]
491    op = node.op
492
493    ts = node.inputs[1:][:op.n_outs]
494    fs = node.inputs[1:][op.n_outs:]
495
496    # outs = main_node.outputs
497    mop = main_node.op
498    true_ins = []
499    false_ins = []
500
501    for x in main_node.inputs:
502        if x in node.outputs:
503            idx = node.outputs.index(x)
504            true_ins.append(ts[idx])
505            false_ins.append(fs[idx])
506        else:
507            true_ins.append(x)
508            false_ins.append(x)
509    true_eval = mop(*true_ins, **dict(return_list=True))
510    false_eval = mop(*false_ins, **dict(return_list=True))
511    # true_eval  = clone(outs, replace = dict(zip(node.outputs, ts)))
512    # false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
513
514    nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True)
515    return nw_outs
516
517
518@gof.local_optimizer([IfElse])
519def cond_merge_ifs_true(node):
520    op = node.op
521    if not isinstance(op, IfElse):
522        return False
523    t_ins = node.inputs[1:][:op.n_outs]
524
525    replace = {}
526    for idx, tval in enumerate(t_ins):
527        if (tval.owner and isinstance(tval.owner.op, IfElse) and
528                tval.owner.inputs[0] == node.inputs[0]):
529                ins_op = tval.owner.op
530                ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
531                replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)]
532
533    if len(replace) == 0:
534        return False
535
536    old_ins = list(node.inputs)
537    for pos, var in iteritems(replace):
538        old_ins[pos] = var
539    return op(*old_ins, **dict(return_list=True))
540
541
542@gof.local_optimizer([IfElse])
543def cond_merge_ifs_false(node):
544    op = node.op
545    if not isinstance(op, IfElse):
546        return False
547    f_ins = node.inputs[1:][op.n_outs:]
548
549    replace = {}
550    for idx, fval in enumerate(f_ins):
551        if (fval.owner and isinstance(fval.owner.op, IfElse) and
552                fval.owner.inputs[0] == node.inputs[0]):
553                ins_op = fval.owner.op
554                ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
555                replace[idx + 1 + op.n_outs] = \
556                    ins_t[fval.owner.outputs.index(fval)]
557
558    if len(replace) == 0:
559        return False
560
561    old_ins = list(node.inputs)
562    for pos, var in iteritems(replace):
563        old_ins[pos] = var
564    return op(*old_ins, **dict(return_list=True))
565
566
567class CondMerge(gof.Optimizer):
568    """ Graph Optimizer that merges different cond ops """
569    def add_requirements(self, fgraph):
570        fgraph.add_feature(gof.toolbox.ReplaceValidate())
571
572    def apply(self, fgraph):
573        nodelist = list(fgraph.toposort())
574        cond_nodes = [s for s in nodelist if isinstance(s.op, IfElse)]
575        if len(cond_nodes) < 2:
576            return False
577        merging_node = cond_nodes[0]
578        for proposal in cond_nodes[1:]:
579            if (proposal.inputs[0] == merging_node.inputs[0] and
580                    not gof.graph.is_in_ancestors(proposal, merging_node)):
581                # Create a list of replacements for proposal
582                mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
583                mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
584                pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
585                pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
586                new_ins = ([merging_node.inputs[0]] +
587                           mn_ts + pl_ts + mn_fs + pl_fs)
588                mn_name = '?'
589                if merging_node.op.name:
590                    mn_name = merging_node.op.name
591                pl_name = '?'
592                # mn_n_ts = len(mn_ts)
593                # mn_n_fs = len(mn_fs)
594                if proposal.op.name:
595                    pl_name = proposal.op.name
596                new_ifelse = IfElse(
597                    n_outs=len(mn_ts + pl_ts),
598                    as_view=False,
599                    gpu=False,
600                    name=mn_name + '&' + pl_name)
601                print('here')
602                new_outs = new_ifelse(*new_ins, **dict(return_list=True))
603                new_outs = [clone(x) for x in new_outs]
604                old_outs = []
605                if type(merging_node.outputs) not in (list, tuple):
606                    old_outs += [merging_node.outputs]
607                else:
608                    old_outs += merging_node.outputs
609                if type(proposal.outputs) not in (list, tuple):
610                    old_outs += [proposal.outputs]
611                else:
612                    old_outs += proposal.outputs
613                pairs = list(zip(old_outs, new_outs))
614                fgraph.replace_all_validate(pairs, reason='cond_merge')
615
616
617@gof.local_optimizer([IfElse])
618def cond_remove_identical(node):
619    op = node.op
620
621    if not isinstance(op, IfElse):
622        return False
623    ts = node.inputs[1:][:op.n_outs]
624    fs = node.inputs[1:][op.n_outs:]
625
626    # sync outs
627    out_map = {}
628    for idx in xrange(len(node.outputs)):
629        if idx not in out_map:
630            for jdx in xrange(idx + 1, len(node.outputs)):
631                if (ts[idx] == ts[jdx] and
632                        fs[idx] == fs[jdx] and
633                        jdx not in out_map):
634                    out_map[jdx] = idx
635
636    if len(out_map) == 0:
637        return False
638
639    nw_ts = []
640    nw_fs = []
641    inv_map = {}
642    pos = 0
643    for idx in xrange(len(node.outputs)):
644        if idx not in out_map:
645            inv_map[idx] = pos
646            pos = pos + 1
647            nw_ts.append(ts[idx])
648            nw_fs.append(fs[idx])
649
650    new_ifelse = IfElse(n_outs=len(nw_ts),
651                        as_view=op.as_view,
652                        gpu=op.gpu,
653                        name=op.name)
654
655    new_ins = [node.inputs[0]] + nw_ts + nw_fs
656    new_outs = new_ifelse(*new_ins, **dict(return_list=True))
657
658    rval = []
659    for idx in xrange(len(node.outputs)):
660        if idx in out_map:
661            rval += [new_outs[inv_map[out_map[idx]]]]
662        else:
663            rval += [new_outs[inv_map[idx]]]
664
665    return rval
666
667
668@gof.local_optimizer([IfElse])
669def cond_merge_random_op(main_node):
670    if isinstance(main_node.op, IfElse):
671        return False
672
673    all_inp_nodes = set()
674    for inp in main_node.inputs:
675        all_inp_nodes.add(inp.owner)
676    cond_nodes = [x for x in list(all_inp_nodes)
677                  if x and isinstance(x.op, IfElse)]
678
679    if len(cond_nodes) < 2:
680        return False
681
682    merging_node = cond_nodes[0]
683    for proposal in cond_nodes[1:]:
684        if (proposal.inputs[0] == merging_node.inputs[0] and
685                not gof.graph.is_in_ancestors(proposal, merging_node) and
686                not gof.graph.is_in_ancestors(merging_node, proposal)):
687            # Create a list of replacements for proposal
688            mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
689            mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
690            pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
691            pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
692            new_ins = ([merging_node.inputs[0]] +
693                       mn_ts + pl_ts + mn_fs + pl_fs)
694            mn_name = '?'
695            if merging_node.op.name:
696                mn_name = merging_node.op.name
697            pl_name = '?'
698            # mn_n_ts = len(mn_ts)
699            # mn_n_fs = len(mn_fs)
700            if proposal.op.name:
701                pl_name = proposal.op.name
702            new_ifelse = IfElse(
703                n_outs=len(mn_ts + pl_ts),
704                as_view=False,
705                gpu=False,
706                name=mn_name + '&' + pl_name)
707            new_outs = new_ifelse(*new_ins, **dict(return_list=True))
708            old_outs = []
709            if type(merging_node.outputs) not in (list, tuple):
710                old_outs += [merging_node.outputs]
711            else:
712                old_outs += merging_node.outputs
713            if type(proposal.outputs) not in (list, tuple):
714                old_outs += [proposal.outputs]
715            else:
716                old_outs += proposal.outputs
717            pairs = list(zip(old_outs, new_outs))
718            main_outs = clone(main_node.outputs, replace=pairs)
719            return main_outs
720
721
722# XXX: Optimizations commented pending further debugging (certain optimizations
723# make computation less lazy than it should be currently).
724#
725# pushout_equilibrium = gof.EquilibriumDB()
726#
727# XXX: This optimization doesn't seem to exist anymore?
728# pushout_equilibrium.register("cond_lift_single_if",
729#                              opt.in2out(cond_lift_single_if,
730#                                         ignore_newtrees=True),
731#                              'fast_run', 'ifelse')
732#
733# pushout_equilibrium.register("cond_merge_random_op",
734#                              opt.in2out(cond_merge_random_op,
735#                                         ignore_newtrees=True),
736#                              'fast_run', 'ifelse')
737#
738#
739# pushout_equilibrium.register("ifelse_merge",
740#                              gof.MergeOptimizer(skip_const_merge=False),
741#                              'fast_run', 'ifelse')
742#
743# pushout_equilibrium.register("ifelse_remove_identical_inside",
744#                              opt.in2out(cond_remove_identical,
745#                                         ignore_newtrees=True),
746#                              'fast_run', 'ifelse')
747#
748# pushout_equilibrium.register('ifelse_sameCondTrue_inside',
749#                              opt.in2out(cond_merge_ifs_true,
750#                                         ignore_newtrees=True),
751#                              'fast_run', 'ifelse')
752#
753# pushout_equilibrium.register('ifelse_sameCondFalse_inside',
754#                              opt.in2out(cond_merge_ifs_false,
755#                                         ignore_newtrees=True),
756#                              'fast_run', 'ifelse')
757#
758# ifelse_seqopt.register('ifelse_condPushOut_equilibrium',
759#                        pushout_equilibrium,
760#                        1, 'fast_run', 'ifelse')
761#
762# ifelse_seqopt.register('merge_nodes_1',
763#                        gof.MergeOptimizer(skip_const_merge=False),
764#                        2, 'fast_run', 'ifelse')
765#
766#
767# ifelse_seqopt.register('ifelse_sameCondTrue',
768#                        opt.in2out(cond_merge_ifs_true,
769#                                   ignore_newtrees=True),
770#                        3, 'fast_run', 'ifelse')
771#
772#
773# ifelse_seqopt.register('ifelse_sameCondFalse',
774#                        opt.in2out(cond_merge_ifs_false,
775#                                   ignore_newtrees=True),
776#                        4, 'fast_run', 'ifelse')
777#
778#
779# ifelse_seqopt.register('ifelse_removeIdenetical',
780#                        opt.in2out(cond_remove_identical,
781#                                   ignore_newtrees=True),
782#                        7, 'fast_run', 'ifelse')
783