1""" 2IfElse introduces lazy evaluation in Theano (coupled with the CVM/VM 3linkers). It resembles the if clause of any programming language, that 4has a `then` and `else` branch, and executes either one or the other 5according to the condition provided. 6 7This op differs from the already existent `switch` op, that evaluates both 8branches of the clause and afterwards picks (according to the condition) 9which value to report. Note also that `switch` is an elemwise operation (so 10it picks each entry of a matrix according to the condition) while `ifelse` 11is a global operation with a scalar condition. 12""" 13from __future__ import absolute_import, print_function, division 14from copy import deepcopy 15from theano.compat import izip 16import logging 17 18import numpy as np 19 20import theano.tensor 21from theano.tensor import TensorType 22from theano import gof 23from theano.gof import Op, Apply 24 25from six import iteritems 26from six.moves import xrange 27from theano.compile import optdb 28from theano.tensor import opt 29from theano.scan_module.scan_utils import clone 30 31 32__docformat__ = 'restructedtext en' 33__authors__ = ("Razvan Pascanu " 34 "James Bergstra " 35 "Dumitru Erhan " 36 "David Warde-Farley") 37__copyright__ = "(c) 2010, Universite de Montreal" 38__contact__ = "Razvan Pascanu <r.pascanu@gmail>" 39 40_logger = logging.getLogger('theano.ifelse') 41 42 43class IfElse(Op): 44 """ 45 Op that provides conditional graph evaluation if used with the CVM/VM 46 linkers. Note that there exist a helpful function `ifelse` that should 47 be used to instantiate the op! 48 49 According to a scalar condition `condition` the op evaluates and then 50 returns all the tensors provided on the `then` branch, otherwise it 51 evaluates and returns the tensors provided on the `else` branch. The op 52 supports multiple tensors on each branch, with the condition that the same 53 number of tensors are on the `then` as on the `else` and there is a one 54 to one correspondence between them (shape and dtype wise). 55 56 The `then` branch is defined as the first N tensors (after the 57 condition), while the `else` branch is defined as the last N tensors. 58 59 Example usage: 60 61 ``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN, 62 rval_if_false1, rval_if_false2, .., rval_if_falseN)`` 63 64 :note: 65 Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and 66 will ignore its lazy characteristic, computing both the True and 67 False branch before picking one. 68 69 """ 70 def __init__(self, n_outs, as_view=False, gpu=False, name=None): 71 if as_view: 72 # check destroyhandler and others to ensure that a view_map with 73 # multiple inputs can work 74 view_map = {} 75 for idx in xrange(n_outs): 76 view_map[idx] = [idx + 1] 77 self.view_map = view_map 78 self.as_view = as_view 79 self.gpu = gpu 80 self.n_outs = n_outs 81 self.name = name 82 83 def __eq__(self, other): 84 if not type(self) == type(other): 85 return False 86 if not self.as_view == other.as_view: 87 return False 88 if not self.gpu == other.gpu: 89 return False 90 if not self.n_outs == other.n_outs: 91 return False 92 return True 93 94 def __hash__(self): 95 rval = (hash(type(self)) ^ 96 hash(self.as_view) ^ 97 hash(self.gpu) ^ 98 hash(self.n_outs)) 99 return rval 100 101 def __str__(self): 102 args = [] 103 if self.name is not None: 104 args.append(self.name) 105 if self.as_view: 106 args.append('inplace') 107 if self.gpu: 108 args.append('gpu') 109 return 'if{%s}' % ','.join(args) 110 111 def infer_shape(self, node, inputs_shapes): 112 # By construction, corresponding then/else pairs have the same number 113 # of dimensions 114 115 ts_shapes = inputs_shapes[1:][:self.n_outs] 116 fs_shapes = inputs_shapes[1:][self.n_outs:] 117 # All elements of all shape tuples for the true and false outputs are 118 # unpacked into the inputs of a separate ifelse, and then the outputs 119 # of that ifelse are packed back into shape tuples. 120 new_ts_inputs = [] 121 for ts_shape in ts_shapes: 122 if isinstance(ts_shape, (list, tuple)): 123 new_ts_inputs += list(ts_shape) 124 else: 125 # It can be None for generic objects 126 return [None] * self.n_outs 127 128 new_fs_inputs = [] 129 for fs_shape in fs_shapes: 130 if isinstance(fs_shape, (list, tuple)): 131 new_fs_inputs += list(fs_shape) 132 else: 133 # It can be None for generic objects 134 return [None] * self.n_outs 135 136 assert len(new_ts_inputs) == len(new_fs_inputs) 137 if len(new_ts_inputs + new_fs_inputs) > 0: 138 name_tokens = ['shape'] 139 if self.name is not None: 140 name_tokens.append(self.name) 141 142 new_ifelse = IfElse( 143 n_outs=len(new_ts_inputs), 144 as_view=False, 145 gpu=False, 146 name='_'.join(name_tokens)) 147 new_outs = new_ifelse(node.inputs[0], 148 *(new_ts_inputs + new_fs_inputs), 149 **dict(return_list=True)) 150 else: 151 new_outs = [] 152 153 # generate pairs of shapes 154 out_shapes = [] 155 for out in node.outputs: 156 out_shapes.append(tuple(new_outs[:out.ndim])) 157 new_outs = new_outs[out.ndim:] 158 159 # new_outs should be an empty list after last iteration 160 assert len(new_outs) == 0 161 162 return out_shapes 163 164 def make_node(self, c, *args): 165 assert len(args) == 2 * self.n_outs, ( 166 "Wrong number of arguments to make_node: " 167 "expected %d, got %d" % (2 * self.n_outs, len(args)) 168 ) 169 c = theano.tensor.as_tensor_variable(c) 170 if not self.gpu: 171 # When gpu is true, we are given only gpuarrays, and we want 172 # to keep them as gpuarrays 173 nw_args = [] 174 for x in args: 175 if hasattr(x, '_as_TensorVariable'): 176 nw_args.append(x._as_TensorVariable()) 177 elif isinstance(x, theano.Variable): 178 nw_args.append(x) 179 else: 180 nw_args.append(theano.tensor.as_tensor_variable(x)) 181 args = nw_args 182 ts = args[:self.n_outs] 183 fs = args[self.n_outs:] 184 185 for t, f in izip(ts, fs): 186 if t.type != f.type: 187 raise TypeError(('IfElse requires same types for true and ' 188 'false return values'), t, f, t.type, f.type) 189 if c.ndim > 0: 190 raise TypeError(('Condition given to the op has to be a scalar ' 191 'with 0 standing for False, anything else ' 192 'for True')) 193 return Apply(self, [c] + list(args), [t.type() for t in ts]) 194 195 def R_op(self, inputs, eval_points): 196 return self(inputs[0], *eval_points[1:], **dict(return_list=True)) 197 198 def grad(self, ins, grads): 199 ts = ins[1:][:self.n_outs] 200 fs = ins[1:][self.n_outs:] 201 if self.name is not None: 202 nw_name_t = self.name + '_grad_t' 203 nw_name_f = self.name + '_grad_f' 204 else: 205 nw_name_t = None 206 nw_name_f = None 207 if_true_op = IfElse(n_outs=self.n_outs, 208 as_view=self.as_view, 209 gpu=self.gpu, 210 name=nw_name_t) 211 212 if_false_op = IfElse(n_outs=self.n_outs, 213 as_view=self.as_view, 214 gpu=self.gpu, 215 name=nw_name_f) 216 217 # The grads can have a different dtype then the inputs. 218 # As inputs true/false pair must have the same dtype, 219 # we must cast the zeros to the corresponding grad dtype 220 # and not the input dtype. 221 if_true = ([ins[0]] + 222 grads + 223 [theano.tensor.zeros_like(t, dtype=grads[i].dtype) 224 for i, t in enumerate(ts)]) 225 if_false = ([ins[0]] + 226 [theano.tensor.zeros_like(f, dtype=grads[i].dtype) 227 for i, f in enumerate(fs)] + 228 grads) 229 230 condition = ins[0] 231 # condition does affect the elements of the output so it is connected. 232 # For the sake of making the gradient convenient we assume that 233 # condition + epsilon always triggers the same branch as condition 234 condition_grad = condition.zeros_like().astype(theano.config.floatX) 235 return ([condition_grad] + 236 if_true_op(*if_true, **dict(return_list=True)) + 237 if_false_op(*if_false, **dict(return_list=True))) 238 239 def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): 240 cond = node.inputs[0] 241 ts = node.inputs[1:][:self.n_outs] 242 fs = node.inputs[1:][self.n_outs:] 243 outputs = node.outputs 244 245 def thunk(): 246 if not compute_map[cond][0]: 247 return [0] 248 else: 249 truthval = storage_map[cond][0] 250 if truthval != 0: 251 ls = [idx + 1 for idx in xrange(self.n_outs) 252 if not compute_map[ts[idx]][0]] 253 if len(ls) > 0: 254 return ls 255 else: 256 for out, t in izip(outputs, ts): 257 compute_map[out][0] = 1 258 val = storage_map[t][0] 259 if self.as_view: 260 storage_map[out][0] = val 261 # Work around broken numpy deepcopy 262 elif type(val) in (np.ndarray, np.memmap): 263 storage_map[out][0] = val.copy() 264 else: 265 storage_map[out][0] = deepcopy(val) 266 return [] 267 else: 268 ls = [1 + idx + self.n_outs for idx in xrange(self.n_outs) 269 if not compute_map[fs[idx]][0]] 270 if len(ls) > 0: 271 return ls 272 else: 273 for out, f in izip(outputs, fs): 274 compute_map[out][0] = 1 275 # can't view both outputs unless destroyhandler 276 # improves 277 # Work around broken numpy deepcopy 278 val = storage_map[f][0] 279 if type(val) in (np.ndarray, np.memmap): 280 storage_map[out][0] = val.copy() 281 else: 282 storage_map[out][0] = deepcopy(val) 283 return [] 284 285 thunk.lazy = True 286 thunk.inputs = [storage_map[v] for v in node.inputs] 287 thunk.outputs = [storage_map[v] for v in node.outputs] 288 return thunk 289 290 291def ifelse(condition, then_branch, else_branch, name=None): 292 """ 293 This function corresponds to an if statement, returning (and evaluating) 294 inputs in the ``then_branch`` if ``condition`` evaluates to True or 295 inputs in the ``else_branch`` if ``condition`` evalutates to False. 296 297 :type condition: scalar like 298 :param condition: 299 ``condition`` should be a tensor scalar representing the condition. 300 If it evaluates to 0 it corresponds to False, anything else stands 301 for True. 302 303 :type then_branch: list of theano expressions/ theano expression 304 :param then_branch: 305 A single theano variable or a list of theano variables that the 306 function should return as the output if ``condition`` evaluates to 307 true. The number of variables should match those in the 308 ``else_branch``, and there should be a one to one correspondance 309 (type wise) with the tensors provided in the else branch 310 311 :type else_branch: list of theano expressions/ theano expressions 312 :param else_branch: 313 A single theano variable or a list of theano variables that the 314 function should return as the output if ``condition`` evaluates to 315 false. The number of variables should match those in the then branch, 316 and there should be a one to one correspondace (type wise) with the 317 tensors provided in the then branch. 318 319 :return: 320 A list of theano variables or a single variable (depending on the 321 nature of the ``then_branch`` and ``else_branch``). More exactly if 322 ``then_branch`` and ``else_branch`` is a tensor, then 323 the return variable will be just a single variable, otherwise a 324 list. The value returns correspond either to the values in the 325 ``then_branch`` or in the ``else_branch`` depending on the value of 326 ``cond``. 327 """ 328 329 rval_type = None 330 if type(then_branch) is list: 331 rval_type = list 332 elif type(then_branch) is tuple: 333 rval_type = tuple 334 335 if type(then_branch) not in (list, tuple): 336 then_branch = [then_branch] 337 if type(else_branch) not in (list, tuple): 338 else_branch = [else_branch] 339 340 # Some of the elements might be converted into another type, 341 # we will store them in these new_... lists. 342 new_then_branch = [] 343 new_else_branch = [] 344 for then_branch_elem, else_branch_elem in izip(then_branch, else_branch): 345 if not isinstance(then_branch_elem, theano.Variable): 346 then_branch_elem = theano.tensor.as_tensor_variable( 347 then_branch_elem) 348 if not isinstance(else_branch_elem, theano.Variable): 349 else_branch_elem = theano.tensor.as_tensor_variable( 350 else_branch_elem) 351 352 if then_branch_elem.type != else_branch_elem.type: 353 # If one of them is a TensorType, and the other one can be 354 # converted into one, then we try to do that. 355 # This case happens when one of the elements has a GPU type, 356 # for instance a shared variable that was silently moved to GPU. 357 if (isinstance(then_branch_elem.type, TensorType) and not 358 isinstance(else_branch_elem.type, TensorType)): 359 else_branch_elem = then_branch_elem.type.filter_variable( 360 else_branch_elem) 361 362 elif (isinstance(else_branch_elem.type, TensorType) and not 363 isinstance(then_branch_elem.type, TensorType)): 364 then_branch_elem = else_branch_elem.type.filter_variable( 365 then_branch_elem) 366 367 if then_branch_elem.type != else_branch_elem.type: 368 # If the types still don't match, there is a problem. 369 raise TypeError( 370 'The two branches should have identical types, but ' 371 'they are %s and %s respectively. This error could be ' 372 'raised if for example you provided a one element ' 373 'list on the `then` branch but a tensor on the `else` ' 374 'branch.' % 375 (then_branch_elem.type, else_branch_elem.type)) 376 377 new_then_branch.append(then_branch_elem) 378 new_else_branch.append(else_branch_elem) 379 380 if len(then_branch) != len(else_branch): 381 raise ValueError(('The number of values on the `then` branch' 382 ' should have the same number of variables as ' 383 'the `else` branch : (variables on `then` ' 384 '%d' % len(then_branch) + ', variables on `else` ' 385 '%d' % len(else_branch) + ')')) 386 387 new_ifelse = IfElse(n_outs=len(then_branch), 388 as_view=False, 389 gpu=False, 390 name=name) 391 392 ins = [condition] + list(new_then_branch) + list(new_else_branch) 393 rval = new_ifelse(*ins, **dict(return_list=True)) 394 395 if rval_type is None: 396 return rval[0] 397 elif rval_type is list: 398 return list(rval) 399 else: 400 return tuple(rval) 401 402 403@gof.local_optimizer([IfElse]) 404def cond_make_inplace(node): 405 op = node.op 406 if (isinstance(op, IfElse) and 407 not op.as_view and 408 # For big graph, do not make inplace scalar to speed up 409 # optimization. 410 (len(node.fgraph.apply_nodes) < 500 or 411 not all([getattr(o.type, 'ndim', -1) == 0 412 for o in node.outputs]))): 413 return IfElse(n_outs=op.n_outs, 414 as_view=True, 415 gpu=op.gpu, 416 name=op.name)(*node.inputs, **dict(return_list=True)) 417 return False 418 419 420optdb.register('cond_make_inplace', opt.in2out(cond_make_inplace, 421 ignore_newtrees=True), 95, 'fast_run', 'inplace') 422 423# XXX: Optimizations commented pending further debugging (certain optimizations 424# make computation less lazy than it should be currently). 425# 426# ifelse_equilibrium = gof.EquilibriumDB() 427# ifelse_seqopt = gof.SequenceDB() 428# ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run', 429# 'ifelse') 430''' Comments: 431I've wrote this comments to explain how the optimization of ifelse function 432(for future developers that need to parse this part of code. Please try to 433keep this comments in sync with whatever changes you add to the code. 434 435ifelse optimization are registered before canonicalize ! 436 437The optimizations are called in sequence as follows: 438 * equilibrium shell (runs until no change): 439 * ifelse_lift 440 * ifelse_merge_ifs 441 * ifelse_merge_nodes 442 * ifelse_remove_identical_inside 443 * ifelse_sameCondTrue_inside 444 * ifelse_sameCondFalse_inside 445 * merge_nodes_1 446 * ifelse_sameCondTrue 447 * ifelse_sameCondFalse 448 * ifelse_removeIdentical 449 450where, each of the optimization do the following things: 451 `ifelse_lift` (def cond_lift_single_if): 452 453''' 454# optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', 455# 'ifelse') 456 457acceptable_ops = (theano.tensor.basic.Dot, 458 theano.tensor.basic.Reshape, 459 theano.tensor.basic.Shape, 460 theano.tensor.SpecifyShape, 461 theano.tensor.basic.MaxAndArgmax, 462 theano.tensor.Subtensor, 463 theano.tensor.IncSubtensor, 464 theano.tensor.basic.Rebroadcast, 465 theano.tensor.basic.Alloc, 466 theano.tensor.elemwise.Elemwise, 467 theano.tensor.elemwise.DimShuffle) 468 469 470@gof.local_optimizer(acceptable_ops) 471def ifelse_lift_single_if_through_acceptable_ops(main_node): 472 """This optimization lifts up certain ifelse instances. 473 474 op(ifelse(c, x, y)) -> ifelse(c, op(x), op(y)) 475 476 if `op` is in the `acceptable_ops` list, and there is no other if as 477 input to that specific `op`, and the if has no other clients !? 478 """ 479 if not (isinstance(main_node.op, acceptable_ops)): 480 return False 481 all_inp_nodes = set() 482 for inp in main_node.inputs: 483 all_inp_nodes.add(inp.owner) 484 ifnodes = [x for x in list(all_inp_nodes) 485 if x and isinstance(x.op, IfElse)] 486 # if we have multiple ifs as inputs .. it all becomes quite complicated 487 # :) 488 if len(ifnodes) != 1: 489 return False 490 node = ifnodes[0] 491 op = node.op 492 493 ts = node.inputs[1:][:op.n_outs] 494 fs = node.inputs[1:][op.n_outs:] 495 496 # outs = main_node.outputs 497 mop = main_node.op 498 true_ins = [] 499 false_ins = [] 500 501 for x in main_node.inputs: 502 if x in node.outputs: 503 idx = node.outputs.index(x) 504 true_ins.append(ts[idx]) 505 false_ins.append(fs[idx]) 506 else: 507 true_ins.append(x) 508 false_ins.append(x) 509 true_eval = mop(*true_ins, **dict(return_list=True)) 510 false_eval = mop(*false_ins, **dict(return_list=True)) 511 # true_eval = clone(outs, replace = dict(zip(node.outputs, ts))) 512 # false_eval = clone(outs, replace = dict(zip(node.outputs, fs))) 513 514 nw_outs = ifelse(node.inputs[0], true_eval, false_eval, return_list=True) 515 return nw_outs 516 517 518@gof.local_optimizer([IfElse]) 519def cond_merge_ifs_true(node): 520 op = node.op 521 if not isinstance(op, IfElse): 522 return False 523 t_ins = node.inputs[1:][:op.n_outs] 524 525 replace = {} 526 for idx, tval in enumerate(t_ins): 527 if (tval.owner and isinstance(tval.owner.op, IfElse) and 528 tval.owner.inputs[0] == node.inputs[0]): 529 ins_op = tval.owner.op 530 ins_t = tval.owner.inputs[1:][:ins_op.n_outs] 531 replace[idx + 1] = ins_t[tval.owner.outputs.index(tval)] 532 533 if len(replace) == 0: 534 return False 535 536 old_ins = list(node.inputs) 537 for pos, var in iteritems(replace): 538 old_ins[pos] = var 539 return op(*old_ins, **dict(return_list=True)) 540 541 542@gof.local_optimizer([IfElse]) 543def cond_merge_ifs_false(node): 544 op = node.op 545 if not isinstance(op, IfElse): 546 return False 547 f_ins = node.inputs[1:][op.n_outs:] 548 549 replace = {} 550 for idx, fval in enumerate(f_ins): 551 if (fval.owner and isinstance(fval.owner.op, IfElse) and 552 fval.owner.inputs[0] == node.inputs[0]): 553 ins_op = fval.owner.op 554 ins_t = fval.owner.inputs[1:][ins_op.n_outs:] 555 replace[idx + 1 + op.n_outs] = \ 556 ins_t[fval.owner.outputs.index(fval)] 557 558 if len(replace) == 0: 559 return False 560 561 old_ins = list(node.inputs) 562 for pos, var in iteritems(replace): 563 old_ins[pos] = var 564 return op(*old_ins, **dict(return_list=True)) 565 566 567class CondMerge(gof.Optimizer): 568 """ Graph Optimizer that merges different cond ops """ 569 def add_requirements(self, fgraph): 570 fgraph.add_feature(gof.toolbox.ReplaceValidate()) 571 572 def apply(self, fgraph): 573 nodelist = list(fgraph.toposort()) 574 cond_nodes = [s for s in nodelist if isinstance(s.op, IfElse)] 575 if len(cond_nodes) < 2: 576 return False 577 merging_node = cond_nodes[0] 578 for proposal in cond_nodes[1:]: 579 if (proposal.inputs[0] == merging_node.inputs[0] and 580 not gof.graph.is_in_ancestors(proposal, merging_node)): 581 # Create a list of replacements for proposal 582 mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs] 583 mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:] 584 pl_ts = proposal.inputs[1:][:proposal.op.n_outs] 585 pl_fs = proposal.inputs[1:][proposal.op.n_outs:] 586 new_ins = ([merging_node.inputs[0]] + 587 mn_ts + pl_ts + mn_fs + pl_fs) 588 mn_name = '?' 589 if merging_node.op.name: 590 mn_name = merging_node.op.name 591 pl_name = '?' 592 # mn_n_ts = len(mn_ts) 593 # mn_n_fs = len(mn_fs) 594 if proposal.op.name: 595 pl_name = proposal.op.name 596 new_ifelse = IfElse( 597 n_outs=len(mn_ts + pl_ts), 598 as_view=False, 599 gpu=False, 600 name=mn_name + '&' + pl_name) 601 print('here') 602 new_outs = new_ifelse(*new_ins, **dict(return_list=True)) 603 new_outs = [clone(x) for x in new_outs] 604 old_outs = [] 605 if type(merging_node.outputs) not in (list, tuple): 606 old_outs += [merging_node.outputs] 607 else: 608 old_outs += merging_node.outputs 609 if type(proposal.outputs) not in (list, tuple): 610 old_outs += [proposal.outputs] 611 else: 612 old_outs += proposal.outputs 613 pairs = list(zip(old_outs, new_outs)) 614 fgraph.replace_all_validate(pairs, reason='cond_merge') 615 616 617@gof.local_optimizer([IfElse]) 618def cond_remove_identical(node): 619 op = node.op 620 621 if not isinstance(op, IfElse): 622 return False 623 ts = node.inputs[1:][:op.n_outs] 624 fs = node.inputs[1:][op.n_outs:] 625 626 # sync outs 627 out_map = {} 628 for idx in xrange(len(node.outputs)): 629 if idx not in out_map: 630 for jdx in xrange(idx + 1, len(node.outputs)): 631 if (ts[idx] == ts[jdx] and 632 fs[idx] == fs[jdx] and 633 jdx not in out_map): 634 out_map[jdx] = idx 635 636 if len(out_map) == 0: 637 return False 638 639 nw_ts = [] 640 nw_fs = [] 641 inv_map = {} 642 pos = 0 643 for idx in xrange(len(node.outputs)): 644 if idx not in out_map: 645 inv_map[idx] = pos 646 pos = pos + 1 647 nw_ts.append(ts[idx]) 648 nw_fs.append(fs[idx]) 649 650 new_ifelse = IfElse(n_outs=len(nw_ts), 651 as_view=op.as_view, 652 gpu=op.gpu, 653 name=op.name) 654 655 new_ins = [node.inputs[0]] + nw_ts + nw_fs 656 new_outs = new_ifelse(*new_ins, **dict(return_list=True)) 657 658 rval = [] 659 for idx in xrange(len(node.outputs)): 660 if idx in out_map: 661 rval += [new_outs[inv_map[out_map[idx]]]] 662 else: 663 rval += [new_outs[inv_map[idx]]] 664 665 return rval 666 667 668@gof.local_optimizer([IfElse]) 669def cond_merge_random_op(main_node): 670 if isinstance(main_node.op, IfElse): 671 return False 672 673 all_inp_nodes = set() 674 for inp in main_node.inputs: 675 all_inp_nodes.add(inp.owner) 676 cond_nodes = [x for x in list(all_inp_nodes) 677 if x and isinstance(x.op, IfElse)] 678 679 if len(cond_nodes) < 2: 680 return False 681 682 merging_node = cond_nodes[0] 683 for proposal in cond_nodes[1:]: 684 if (proposal.inputs[0] == merging_node.inputs[0] and 685 not gof.graph.is_in_ancestors(proposal, merging_node) and 686 not gof.graph.is_in_ancestors(merging_node, proposal)): 687 # Create a list of replacements for proposal 688 mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs] 689 mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:] 690 pl_ts = proposal.inputs[1:][:proposal.op.n_outs] 691 pl_fs = proposal.inputs[1:][proposal.op.n_outs:] 692 new_ins = ([merging_node.inputs[0]] + 693 mn_ts + pl_ts + mn_fs + pl_fs) 694 mn_name = '?' 695 if merging_node.op.name: 696 mn_name = merging_node.op.name 697 pl_name = '?' 698 # mn_n_ts = len(mn_ts) 699 # mn_n_fs = len(mn_fs) 700 if proposal.op.name: 701 pl_name = proposal.op.name 702 new_ifelse = IfElse( 703 n_outs=len(mn_ts + pl_ts), 704 as_view=False, 705 gpu=False, 706 name=mn_name + '&' + pl_name) 707 new_outs = new_ifelse(*new_ins, **dict(return_list=True)) 708 old_outs = [] 709 if type(merging_node.outputs) not in (list, tuple): 710 old_outs += [merging_node.outputs] 711 else: 712 old_outs += merging_node.outputs 713 if type(proposal.outputs) not in (list, tuple): 714 old_outs += [proposal.outputs] 715 else: 716 old_outs += proposal.outputs 717 pairs = list(zip(old_outs, new_outs)) 718 main_outs = clone(main_node.outputs, replace=pairs) 719 return main_outs 720 721 722# XXX: Optimizations commented pending further debugging (certain optimizations 723# make computation less lazy than it should be currently). 724# 725# pushout_equilibrium = gof.EquilibriumDB() 726# 727# XXX: This optimization doesn't seem to exist anymore? 728# pushout_equilibrium.register("cond_lift_single_if", 729# opt.in2out(cond_lift_single_if, 730# ignore_newtrees=True), 731# 'fast_run', 'ifelse') 732# 733# pushout_equilibrium.register("cond_merge_random_op", 734# opt.in2out(cond_merge_random_op, 735# ignore_newtrees=True), 736# 'fast_run', 'ifelse') 737# 738# 739# pushout_equilibrium.register("ifelse_merge", 740# gof.MergeOptimizer(skip_const_merge=False), 741# 'fast_run', 'ifelse') 742# 743# pushout_equilibrium.register("ifelse_remove_identical_inside", 744# opt.in2out(cond_remove_identical, 745# ignore_newtrees=True), 746# 'fast_run', 'ifelse') 747# 748# pushout_equilibrium.register('ifelse_sameCondTrue_inside', 749# opt.in2out(cond_merge_ifs_true, 750# ignore_newtrees=True), 751# 'fast_run', 'ifelse') 752# 753# pushout_equilibrium.register('ifelse_sameCondFalse_inside', 754# opt.in2out(cond_merge_ifs_false, 755# ignore_newtrees=True), 756# 'fast_run', 'ifelse') 757# 758# ifelse_seqopt.register('ifelse_condPushOut_equilibrium', 759# pushout_equilibrium, 760# 1, 'fast_run', 'ifelse') 761# 762# ifelse_seqopt.register('merge_nodes_1', 763# gof.MergeOptimizer(skip_const_merge=False), 764# 2, 'fast_run', 'ifelse') 765# 766# 767# ifelse_seqopt.register('ifelse_sameCondTrue', 768# opt.in2out(cond_merge_ifs_true, 769# ignore_newtrees=True), 770# 3, 'fast_run', 'ifelse') 771# 772# 773# ifelse_seqopt.register('ifelse_sameCondFalse', 774# opt.in2out(cond_merge_ifs_false, 775# ignore_newtrees=True), 776# 4, 'fast_run', 'ifelse') 777# 778# 779# ifelse_seqopt.register('ifelse_removeIdenetical', 780# opt.in2out(cond_remove_identical, 781# ignore_newtrees=True), 782# 7, 'fast_run', 'ifelse') 783