"""
This module provides optimizations for scan.
The optimizations provided in this file:

local opt: remove_constants_and_unused_inputs_scan,
           constant_folding_for_scan2,
           scan_merge_inouts
           They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
            PushOutNonSeqScan,
            PushOutSeqScan,
            PushOutDot1,
            ScanMerge,
            ScanSaveMem

How they are registered:

optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
                PushOutNonSeqScan(2),
                PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizers (in2out converts local to
               global). This is important, as the order is important and
               all global optimizers run before local optimizers, in the
               order they were registered. (So don't change the order we
               register them!)
               If we convert to local optimizers, we must convert all of
               them to local optimizers. But:
               1) Can ScanMerge be made local? Can we keep only this one
                  global?
               2) ScanSaveMem asserts that we remove all of a node's
                  outputs; we need to keep this.
               3) ScanSaveMem supposes that the others ran before it.
                  I added an assert at one place, but didn't look for
                  other places.
               4) Moving this to a local opt could speed this opt up
                  significantly, as we currently pass over all nodes in
                  the graph repeatedly for no good reason.
               5) We register remove_constant_* in many places, as some
                  opts create such nodes and let this one clean up the
                  mess. Doing it that way makes things simpler for those
                  already complex opts.

                in2out(constant_folding),
                in2out(remove_constants_and_unused_inputs_scan1),
                ScanMerge,
                in2out(remove_constants_and_unused_inputs_scan2),
                in2out(scan_merge_inouts),
                ScanSaveMem,
                in2out(remove_constants_and_unused_inputs_scan3)
"""
from __future__ import absolute_import, print_function, division
import logging
import copy
from sys import maxsize
from collections import OrderedDict
import numpy as np

import theano
from theano import tensor, scalar
from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof
from six import integer_types, iteritems
from six.moves import xrange
from theano.compile import optdb
from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer

from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import equal_computations, scan_args

__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
               "Frederic Bastien "
               "James Bergstra "
               "Pascal Lamblin "
               "Arnaud Bergeron ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"


# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_opt')

# Local optimizations applied when slicing, used elsewhere in this module.
list_opt_slice = [tensor.opt.local_abs_merge,
                  tensor.opt.local_mul_switch_sink,
                  tensor.opt.local_upcast_elemwise_constant_inputs,
                  tensor.opt.local_useless_switch,
                  tensor.opt.constant_folding]


def warning(*msg):
    # Convenience wrapper: joins the message parts and prefixes them so
    # scan-related warnings are identifiable in the logs.
    _logger.warning('WARNING theano.scan: ' + ' '.join(msg))


def info(*msg):
    # Convenience wrapper mirroring `warning` for informational messages.
    _logger.info('INFO theano.scan: ' + ' '.join(msg))


@gof.local_optimizer([scan_op.Scan])
def remove_constants_and_unused_inputs_scan(node):
    """
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    Parameters
    ----------
    node : Apply
        An apply node whose op is expected to be a Scan.

    Returns
    -------
    OrderedDict or bool
        An OrderedDict with a ``("remove", [node])`` entry plus the
        output-replacement pairs when the scan was simplified, or False
        when nothing changed (or ``node`` is not a Scan).

    """
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # We only need to take care of sequences and other arguments
    # `st` counts the inner inputs preceding the non-sequences: the
    # sequences plus all initial-state slices (mit-mot/mit-sot taps,
    # sit-sot states and shared outputs).
    st = op.n_seqs
    st += int(sum([len(x) for x in
                   op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs

    op_ins = op.inputs
    op_outs = op.outputs

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs:st]

    non_seqs = op_ins[st:]
    # Re-purpose `st` for the OUTER input list, which additionally holds
    # n_steps at index 0 and one nit-sot length per nit-sot output.
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs:st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = OrderedDict()
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    # Pass 1: sequences.  Fold constant sequences into the inner graph,
    # drop unused ones, and deduplicate sequences computing the same thing.
    for idx in xrange(op.n_seqs):
        node_inp = node.inputs[idx + 1]
        if (isinstance(node_inp, tensor.TensorConstant) and
                node_inp.tag.unique_value is not None):
            try:
                # This works if input is a constant that has all entries
                # equal
                givens[op_ins[idx]] = node_inp.clone()[0]
            except TypeError:
                pass
        elif op_ins[idx] in all_ins:
            # Check for identical other sequence
            identical_seqs = [x for x in nw_outer
                              if scan_utils.equal_computations(
                                  [x], [node_inp])]
            if identical_seqs:
                # -1 because nw_outer[0] is n_steps, not a sequence.
                index = node.inputs.index(identical_seqs[0]) - 1
                givens[op_ins[idx]] = op_ins[index]
            else:
                nw_inner.append(op_ins[idx])
                nw_outer.append(node_inp)

    nw_n_seqs = len(nw_inner)
    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer

    # Pass 2: non-sequences.  Same treatment: inline constants, drop
    # unused inputs and deduplicate equivalent non-sequences.
    nw_inner_nonseq = []
    nw_outer_nonseq = []
    for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            # Indices of elements of nw_outer_nonseq that are equivalent
            # to nw_out.
            identical_nonseq_idx = [
                i for (i, x) in enumerate(nw_outer_nonseq)
                if scan_utils.equal_computations([x], [nw_out])]
            if identical_nonseq_idx:
                givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
            else:
                nw_inner_nonseq.append(nw_in)
                nw_outer_nonseq.append(nw_out)

    nw_inner.extend(nw_inner_nonseq)
    nw_outer.extend(nw_outer_nonseq)

    # Rebuild the scan only if some inner input was actually removed.
    if len(nw_inner) != len(op_ins):
        op_outs = scan_utils.clone(op_outs, replace=givens)
        nw_info = copy.deepcopy(op.info)
        nw_info['n_seqs'] = nw_n_seqs
        # DEBUG CHECK
        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
        nw_outs = nwScan(*nw_outer, **dict(return_list=True))
        return OrderedDict([("remove", [node])] +
                           list(zip(node.outputs, nw_outs)))
    else:
        return False


# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutNonSeqScan(gof.Optimizer):
    """
    A global optimizer for pushing out the variables inside the scan that depend
    only on non-sequences.
    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        # Replacements below go through replace_all_validate*, which this
        # feature provides.
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def apply(self, fgraph):
        # Visit every Scan node of the graph, in topological order.
        nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
                                                               scan_op.Scan)]
        for node in nodelist:
            self.process_node(fgraph, node)

    def process_node(self, fgraph, node):
        """
        IMPORTANT NOTE: This function uses set and dictionary data structures.
        By default they are not ordered for efficiency reasons. Take care
        and make sure of changing them with their Ordered counterparts if you
        need to iterate over these variables.

        """
        # Work on a fresh clone of the inner graph so the original op is
        # left untouched while we decide what to move out.
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
            node.op.inputs, node.op.outputs)

        local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
                                                         clean_outputs)
        local_fgraph_outs_set = set(clean_outputs)
        local_fgraph_outs_map = dict([(v, k) for k, v in
                                      enumerate(clean_outputs)])

        # Inner nodes scheduled for removal, inner outputs to replace, and
        # the order in which replacements were recorded.
        to_remove_set = set()
        to_replace_set = set()
        to_replace_map = OrderedDict()

        def add_to_replace(y):
            # Record `y` as replaceable; its index ties it to the
            # corresponding entries of replace_with_in/replace_with_out.
            to_replace_set.add(y)
            to_replace_map[y] = add_to_replace.n
            add_to_replace.n += 1
        add_to_replace.n = 0

        replace_with_in = []
        replace_with_out = []

        op = node.op
        # Construct the list of non_sequences to simplify a few things
        inner_non_seqs = op.inner_non_seqs(clean_inputs)
        inner_non_seqs_set = set(inner_non_seqs)
        inner_non_seqs_map = dict([(v, k) for k, v in
                                   enumerate(inner_non_seqs)])

        outer_non_seqs = op.outer_non_seqs(node.inputs)

        inner_seqs = op.inner_seqs(clean_inputs)
        outer_seqs = op.outer_seqs(node.inputs)

        assert len(inner_non_seqs) == len(outer_non_seqs)
        assert len(inner_seqs) == len(outer_seqs)

        for nd in local_fgraph_topo:
            if (  # we haven't already looked at this node
                    nd not in to_remove_set and
                    # every input is a non-sequence, a constant, or the
                    # output of a node already scheduled to move out
                    all([((x in inner_non_seqs_set) or
                          (x.owner in to_remove_set) or
                          isinstance(x, tensor.Constant))
                         for x in nd.inputs]) and
                    # we can do this because the assumption is that a
                    # viewOp or deepCopyOp will be just at the end of the
                    # function and not somewhere in the middle ..
                    not isinstance(nd.op, theano.compile.ViewOp) and
                    not isinstance(nd.op, theano.compile.DeepCopyOp)):

                # We have a candidate node to removable
                # Step 1. Reconstruct it on outside
                to_remove_set.add(nd)
                outside_ins = []
                for x in nd.inputs:
                    if x in inner_non_seqs_set:
                        _idx = inner_non_seqs_map[x]
                        outside_ins.append(outer_non_seqs[_idx])
                    elif x in to_replace_set:
                        outside_ins.append(
                            replace_with_out[to_replace_map[x]])
                    elif isinstance(x, theano.Constant):
                        outside_ins.append(x.clone())
                    else:
                        raise Exception(
                            ('Error in the `scan_pushout_non_seq_'
                             'operations`. The optimization tries '
                             'to move some computation fron scan '
                             'which is not allowed to move. Report '
                             'this on theano-users list'), x)
                outside_ins = [x.type.filter_variable(y) for x, y in
                               zip(nd.inputs, outside_ins)]

                # Do not call make_node for test_value
                nw_outer_node = nd.op(*outside_ins,
                                      **dict(return_list=True))[0].owner

                # Step 2. Create variables for replacements
                for idx, y in enumerate(nd.outputs):
                    y_place_holder = scan_utils.safe_new(y, '_replace')
                    add_to_replace(y)
                    replace_with_in.append(y_place_holder)
                    assert isinstance(y, type(nw_outer_node.outputs[idx]))
                    replace_with_out.append(nw_outer_node.outputs[idx])

        # We need to check all candidate replacements and choose those that
        # make sense for us
        # Step 1. which elements of `to_replace` are used by remaining
        # components of the inner function
        clean_to_replace = []
        clean_replace_with_in = []
        clean_replace_with_out = []
        existent_nodes = [nd for nd in local_fgraph_topo
                          if nd not in to_remove_set]
        existent_nodes_set = set(existent_nodes)

        to_keep_set = set([])
        for nd in existent_nodes:
            to_keep_set.update(nd.inputs)

        for out, idx in to_replace_map.items():
            if (  # If types are different, conversion Op will be inserted,
                    # and it may trigger an infinite loop.
                    replace_with_in[idx].type == out.type and
                    out in to_keep_set and
                    out.owner not in existent_nodes_set):
                clean_to_replace.append(out)
                clean_replace_with_in.append(replace_with_in[idx])
                clean_replace_with_out.append(replace_with_out[idx])

        if len(clean_to_replace) > 0:
            # We can finally put an end to all this madness
            givens = OrderedDict()
            nw_outer = []
            nw_inner = []
            for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                                  clean_replace_with_in,
                                                  clean_replace_with_out):
                if isinstance(repl_out, theano.Constant):
                    repl_in = repl_out.clone()
                else:
                    nw_inner.append(repl_in)
                    nw_outer.append(repl_out)
                givens[to_repl] = repl_in

            op_outs = scan_utils.clone(clean_outputs, replace=givens)
            op_ins = clean_inputs + nw_inner

            # Reconstruct node
            nwScan = scan_op.Scan(op_ins, op_outs, op.info)

            # Do not call make_node for test_value
            nw_node = nwScan(*(node.inputs + nw_outer),
                             **dict(return_list=True))[0].owner

            fgraph.replace_all_validate_remove(
                list(zip(node.outputs, nw_node.outputs)),
                remove=[node],
                reason='scanOp_pushout_nonseqs_ops')
            return True
        elif not to_keep_set:
            # Nothing in the inner graph should be kept
            replace_with = OrderedDict()
            for out, idx in to_replace_map.items():
                if out in local_fgraph_outs_set:
                    x = node.outputs[local_fgraph_outs_map[out]]
                    y = replace_with_out[idx]
                    shape = [shp for shp in y.shape]
                    # Broadcast the per-step value over n_steps
                    # (node.inputs[0]) to mimic the stacked scan output.
                    replace_with[x] = tensor.alloc(y,
                                                   node.inputs[0],
                                                   *shape)

            # We need to add one extra dimension to the outputs
            # because the scan op expects for a tensor3, to which an
            # subtensor is applied that takes only the last element
            if replace_with:
                if len(node.outputs) == len(replace_with):
                    # Every output of the node has a replacement, the Scan
                    # node can be removed from the graph
                    fgraph.replace_all_validate_remove(
                        replace_with.items(),
                        remove=[node],
                        reason='scanOp_pushout_nonseqs_ops')
                else:
                    # The node has some outputs for which no replacement has
                    # been established. This can occur for outputs that are
                    # not produced by apply nodes (since the optimizations
                    # only visits apply nodes) such as constants or inputs
                    # passed directly as outputs. The replacements can be
                    # performed but the Scan node can't be removed at this
                    # point.
                    fgraph.replace_all_validate(
                        replace_with.items(),
                        reason='scanOp_pushout_nonseqs_ops')

        else:
            return False


# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutSeqScan(gof.Optimizer):
    """
    A global optimizer for pushing out the variables inside the
    scan that depend only on constants and sequences.
    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        # Replacements below go through replace_all_validate*, which this
        # feature provides.
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def apply(self, fgraph):
        # Visit every Scan node of the graph, in topological order.
        nodelist = [x for x in fgraph.toposort()
                    if isinstance(x.op, scan_op.Scan)]
        for node in nodelist:
            self.process_node(fgraph, node)

    def process_node(self, fgraph, node):
        """
        IMPORTANT NOTE: This function uses set and dictionary data structure.
        By default they are not ordered for efficiency reasons.
        Take care
        and make sure of changing them to Ordered versions if you need to
        iterate over those variables.

        """
        # Work on a fresh clone of the inner graph so the original op is
        # left untouched while we decide what to move out.
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
            node.op.inputs, node.op.outputs)

        local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
                                                         clean_outputs)
        local_fgraph_outs_set = set(clean_outputs)
        local_fgraph_outs_map = dict([(v, k) for k, v in
                                      enumerate(clean_outputs)])

        # Inner nodes scheduled for removal, inner outputs to replace, and
        # the order in which replacements were recorded.
        to_remove_set = set()
        to_replace_set = set()
        to_replace_map = OrderedDict()

        def add_to_replace(y):
            # Record `y` as replaceable; its index ties it to the
            # corresponding entries of replace_with_in/replace_with_out.
            to_replace_set.add(y)
            to_replace_map[y] = add_to_replace.n
            add_to_replace.n += 1
        add_to_replace.n = 0

        replace_with_in = []
        replace_with_out = []

        op = node.op
        # Construct the list of non_sequences to simplify a few things
        inner_non_seqs = op.inner_non_seqs(clean_inputs)
        inner_non_seqs_set = set(inner_non_seqs)
        inner_non_seqs_map = dict([(v, k) for k, v in
                                   enumerate(inner_non_seqs)])

        outer_non_seqs = op.outer_non_seqs(node.inputs)
        inner_seqs = op.inner_seqs(clean_inputs)
        inner_seqs_set = set(inner_seqs)
        inner_seqs_map = dict([(v, k) for k, v in
                               enumerate(inner_seqs)])

        outer_seqs = op.outer_seqs(node.inputs)
        assert len(inner_non_seqs) == len(outer_non_seqs)
        assert len(inner_seqs) == len(outer_seqs)

        for nd in local_fgraph_topo:
            # Case 1: an Elemwise whose every input is a sequence, a
            # non-sequence, a constant, or the output of a node already
            # scheduled to move out.
            if (nd not in to_remove_set and
                    all([(x in inner_non_seqs_set) or
                         (x.owner in to_remove_set) or
                         isinstance(x, tensor.Constant) or
                         (x in inner_seqs_set) for x in nd.inputs]) and
                    isinstance(nd.op, theano.tensor.Elemwise)):

                outside_ins = []
                depends_on_seqs = False

                for x in nd.inputs:
                    if x in inner_non_seqs_set:
                        _idx = inner_non_seqs_map[x]
                        outside_ins.append(outer_non_seqs[_idx])
                    elif x in inner_seqs_set:
                        outside_ins.append(outer_seqs[inner_seqs_map[x]])
                        depends_on_seqs = True
                    elif x in to_replace_set:
                        outside_ins.append(replace_with_out[
                            to_replace_map[x]])
                        depends_on_seqs = True
                    elif isinstance(x, theano.Constant):
                        outside_ins.append(x.clone())
                    else:
                        raise Exception(
                            ('Error in the `scan_pushout_seq_'
                             'operations`. The optimization tries '
                             'to move some computation fron scan '
                             'which is not allowed to move. Report '
                             'this on theano-users list'), x)

                if not depends_on_seqs:
                    # Removing this node from the inner graph of scan
                    # should be handled by the PushOutNonSeqScan
                    # optimization. The current optimization only tries
                    # to pull sequence-dependant computation out of
                    # scan.
                    continue

                to_remove_set.add(nd)

                # Do not call make_node for test_value
                nw_outer_node = nd.op(*outside_ins,
                                      **dict(return_list=True))[0].owner

                # Step 2. Create variables for replacements
                for idx, y in enumerate(nd.outputs):
                    y_place_holder = scan_utils.safe_new(y, '_replace')
                    add_to_replace(y)
                    replace_with_in.append(y_place_holder)
                    replace_with_out.append(nw_outer_node.outputs[idx])

            # Case 2: a DimShuffle applied directly to a sequence (or to
            # something already being moved out).
            elif (nd not in to_remove_set and
                  isinstance(nd.op, theano.tensor.DimShuffle) and
                  (nd.inputs[0] in inner_seqs_set or
                   nd.inputs[0].owner in to_remove_set)):

                to_remove_set.add(nd)
                x = nd.inputs[0]
                if x in inner_seqs_set:
                    outside_ins = outer_seqs[inner_seqs_map[x]]
                elif x in to_replace_set:
                    outside_ins = replace_with_out[to_replace_map[x]]
                # Outside of scan the variable gains a leading time axis,
                # so shift every axis of the shuffle pattern by one and
                # keep axis 0 (time) in front.
                new_ord = (0,)
                for old_ord in nd.op.new_order:
                    if (old_ord == 'x'):
                        new_ord += (old_ord,)
                    else:
                        new_ord += (old_ord + 1,)
                new_outer = outside_ins.dimshuffle(new_ord)
                y = nd.outputs[0]
                y_place_holder = scan_utils.safe_new(y, '_replace')
                add_to_replace(y)
                replace_with_in.append(y_place_holder)
                replace_with_out.append(new_outer)

                if hasattr(new_outer.tag, "test_value"):
                    new_sh = new_outer.tag.test_value.shape
                    ref_sh = (outside_ins.tag.test_value.shape[0],)
                    ref_sh += nd.outputs[0].tag.test_value.shape
                    assert new_sh == ref_sh

        # We need to check all candidate replacements and choose those that
        # make sense for us
        # Step 1. which elements of `to_replace` are used by remaining
        # components of the inner function
        clean_to_replace = []
        clean_replace_with_in = []
        clean_replace_with_out = []

        existent_nodes = [nd for nd in local_fgraph_topo
                          if nd not in to_remove_set]
        existent_nodes_set = set(existent_nodes)

        to_keep_set = set([])
        for nd in existent_nodes:
            to_keep_set.update(nd.inputs)

        for out, idx in to_replace_map.items():
            if (out in to_keep_set and out.owner not in existent_nodes_set and
                    # If types are different, conversion Op will be inserted,
                    # and it may trigger an infinite loop.
                    replace_with_in[idx].type == out.type):

                clean_to_replace.append(out)
                clean_replace_with_in.append(replace_with_in[idx])
                clean_replace_with_out.append(replace_with_out[idx])

        if len(clean_to_replace) > 0:
            # We can finally put an end to all this madness
            givens = OrderedDict()
            nw_outer = []
            nw_inner = []
            for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                                  clean_replace_with_in,
                                                  clean_replace_with_out):
                if isinstance(repl_out, theano.Constant):
                    repl_in = repl_out.clone()
                else:
                    nw_inner.append(repl_in)
                    nw_outer.append(repl_out)

                givens[to_repl] = repl_in

            op_outs = scan_utils.clone(clean_outputs, replace=givens)
            op_ins = nw_inner + clean_inputs

            # Reconstruct node: the values moved out re-enter the scan as
            # additional sequences.
            nw_info = op.info.copy()
            nw_info['n_seqs'] += len(nw_inner)
            nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
            # Do not call make_node for test_value
            nw_node = nwScan(*(node.inputs[:1] + nw_outer + node.inputs[1:]),
                             **dict(return_list=True))[0].owner

            fgraph.replace_all_validate_remove(
                list(zip(node.outputs, nw_node.outputs)),
                remove=[node],
                reason='scanOp_pushout_seqs_ops')
            return True
        elif (not to_keep_set and
              not op.as_while and
              not op.outer_mitmot(node)):
            # Nothing in the inner graph should be kept
            replace_with = OrderedDict()
            for out, idx in to_replace_map.items():
                if out in local_fgraph_outs_set:
                    x = node.outputs[local_fgraph_outs_map[out]]
                    _y = replace_with_out[idx]
                    ls = clean_outputs
                    if out in op.inner_mitsot_outs(ls):
                        odx = op.inner_mitsot_outs(ls).index(out)
                        inp = op.outer_mitsot(node)[odx]
                        st = abs(np.min(op.mitsot_taps()))
                        y = tensor.set_subtensor(inp[st:], _y)
                    elif out in op.inner_sitsot_outs(ls):
                        odx = op.inner_sitsot_outs(ls).index(out)
                        inp = op.outer_sitsot(node)[odx]
                        y = tensor.set_subtensor(inp[1:], _y)
                    elif out in op.inner_nitsot_outs(ls):
                        y = _y
                    else:
                        y = _y[-1]
                    replace_with[x] = y

            # We need to add one extra dimension to the outputs
            if replace_with and len(replace_with) == len(node.outputs):
                fgraph.replace_all_validate_remove(
                    list(replace_with.items()),
                    remove=[node],
                    reason='scanOp_pushout_seqs_ops')
                return True
        else:
            return False


class PushOutScanOutput(gof.Optimizer):
    """
    This is an optimization that can push operations performed
    at the end of the inner graph of scan to outside of scan.
    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        # Replacements below go through replace_all_validate*, which this
        # feature provides.
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def apply(self, fgraph):
        # Don't perform the optimization on as_while scans. Because these scans
        # don't run for a predetermined number of steps, handling them is
        # more complicated and this optimization doesn't support it at the
        # moment.
        nodelist = [x for x in fgraph.toposort()
                    if (isinstance(x.op, scan_op.Scan) and
                        not x.op.as_while)]
        for node in nodelist:
            # Process the node as long as something gets optimized
            while node is not None:
                node = self.process_node(fgraph, node)

    def process_node(self, fgraph, node):
        """
        Look for a sit-sot accumulation pattern ``state + dot(A, B)`` in the
        inner graph and replace it by a single Dot performed outside of the
        scan.

        Returns the new Scan node when one was created (so the caller can
        keep iterating on it), otherwise None.

        """
        op = node.op

        # Use scan_args to parse the inputs and outputs of scan for ease of
        # use
        args = scan_args(node.inputs, node.outputs,
                         op.inputs, op.outputs, op.info)

        new_scan_node = None
        clients = {}
        local_fgraph_topo = theano.gof.graph.io_toposort(args.inner_inputs,
                                                         args.inner_outputs,
                                                         clients=clients)

        for nd in local_fgraph_topo:
            if (isinstance(nd.op, theano.tensor.elemwise.Elemwise) and
                    isinstance(nd.op.scalar_op, scalar.Add) and
                    nd.out in args.inner_out_sit_sot and
                    self.inner_sitsot_only_last_step_used(nd.out, args)):

                # Ensure that one of the input to the add is the output of
                # the add from a previous iteration of the inner function
                sitsot_idx = args.inner_out_sit_sot.index(nd.out)
                if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:

                    # Ensure that the other input to the add is a dot product
                    # between 2 matrices which will become a tensor3 and a
                    # matrix if pushed outside of the scan. Also make sure
                    # that the output of the Dot is ONLY used by the 'add'
                    # otherwise doing a Dot in the outer graph will only
                    # duplicate computation.

                    sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[
                        sitsot_idx])

                    # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
                    dot_in_idx = 1 - sitsot_in_idx

                    dot_input = nd.inputs[dot_in_idx]

                    if (dot_input.owner is not None and
                            isinstance(dot_input.owner.op,
                                       theano.tensor.Dot) and
                            len(clients[dot_input]) == 1 and
                            dot_input.owner.inputs[0].ndim == 2 and
                            dot_input.owner.inputs[1].ndim == 2 and
                            self.get_outer_ndim(dot_input.owner.inputs[0],
                                                args) == 3 and
                            self.get_outer_ndim(dot_input.owner.inputs[1],
                                                args) == 3):

                        # The optimization can be be applied in this case.

                        # Move out of scan the two inputs to the Dot and
                        # perform a dot outside of scan on these two inputs
                        inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
                        (outer_dot_inputs,
                         new_scan_node,
                         new_scan_args) = \
                            self.push_out_inner_vars(fgraph,
                                                     inner_dot_inputs,
                                                     node, args)

                        # Collapse some of the dimensions of the tensors
                        # so that they become matrices. This is because a
                        # dot is usually faster on two large matrices than
                        # a bunch of small ones
                        outer_dot_inputs[0] = theano.tensor.flatten(
                            outer_dot_inputs[0].dimshuffle(1, 0, 2),
                            ndim=2)

                        shape_input1 = theano.tensor.shape(
                            outer_dot_inputs[1])
                        outer_dot_inputs[1] =\
                            outer_dot_inputs[1].reshape(
                                (shape_input1[0] * shape_input1[1],
                                 shape_input1[2]))

                        # Perform the dot on the newly obtained matrices and
                        # add the initial value
                        outer_dot_output = theano.tensor.dot(
                            *outer_dot_inputs)
                        init_value = new_scan_args.outer_in_sit_sot[
                            sitsot_idx][0]
                        replacement = outer_dot_output + init_value

                        # Alter the outer graph to use the output of the
                        # external Dot instead of the output of scan
                        # Modify the outer graph to add the outer Dot
                        outer_sitsot = new_scan_args.outer_out_sit_sot[
                            sitsot_idx]
                        subtensor_node = outer_sitsot.clients[0][0]
                        outer_sitsot_last_step = subtensor_node.outputs[0]

                        fgraph.replace_all([
                            (outer_sitsot_last_step, replacement)],
                            reason="scanOp_pushout_output")

                        break
        return new_scan_node

    def inner_sitsot_only_last_step_used(self, var, scan_args):
        """
        Given an inner sit_sot output of scan, return True iff the outer
        sit_sot output has only one client and that client is a Subtensor
        instance that takes only the last step (last element along the first
        axis).

        """
        idx = scan_args.inner_out_sit_sot.index(var)
        outer_var = scan_args.outer_out_sit_sot[idx]

        if len(outer_var.clients) == 1:
            client = outer_var.clients[0][0]
            if (client != 'output' and isinstance(client.op,
                                                  theano.tensor.Subtensor)):
                lst = theano.tensor.subtensor.get_idx_list(
                    client.inputs, client.op.idx_list)
                if (len(lst) == 1 and
                        theano.tensor.extract_constant(lst[0]) == -1):
                    return True

        return False

    def get_outer_ndim(self, var, scan_args):
        # Given a variable, determine the number of dimension it would have
        # if it was pushed out of scan.  Non-sequences and constants keep
        # their rank; everything else gains a leading time axis.
        if (var in scan_args.inner_in_non_seqs or
                isinstance(var, theano.Constant)):

            outer_ndim = var.ndim
        else:
            outer_ndim = var.ndim + 1

        return outer_ndim

    def push_out_inner_vars(self, fgraph, inner_vars, old_scan_node,
                            old_scan_args):
        """
        Return outer-graph counterparts of ``inner_vars``, adding new
        nitsot outputs to the scan node for those that have none.

        Returns a tuple (outer_vars, new_scan_node, new_scan_args).

        """
        outer_vars = [None] * len(inner_vars)
        new_scan_node = old_scan_node
        new_scan_args = old_scan_args

        # For the inner_vars that already exist in the outer graph,
        # simply obtain a reference to them
        for idx in range(len(inner_vars)):

            var = inner_vars[idx]

            if var in old_scan_args.inner_in_seqs:
                idx_seq = old_scan_args.inner_in_seqs.index(var)
                outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq]

            elif var in old_scan_args.inner_in_non_seqs:
                idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
                outer_vars[idx] = old_scan_args.outer_in_non_seqs[
                    idx_non_seq]

            elif isinstance(var, theano.Constant):
                outer_vars[idx] = var.clone()

            elif var in old_scan_args.inner_out_nit_sot:
                idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
                outer_vars[idx] = old_scan_args.outer_out_nit_sot[
                    idx_nitsot]

        # For the inner_vars that don't already exist in the outer graph,
        # add them as new nitsot outputs to the scan node.
        idx_add_as_nitsots = [i for i in range(len(outer_vars))
                              if outer_vars[i] is None]
        add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]

        if len(add_as_nitsots) > 0:

            new_scan_node = self.add_nitsot_outputs(fgraph, old_scan_node,
                                                    old_scan_args,
                                                    add_as_nitsots)

            new_scan_args = scan_args(new_scan_node.inputs,
                                      new_scan_node.outputs,
                                      new_scan_node.op.inputs,
                                      new_scan_node.op.outputs,
                                      new_scan_node.op.info)

            new_outs = new_scan_args.outer_out_nit_sot[
                -len(add_as_nitsots):]
            for i in range(len(new_outs)):
                outer_vars[idx_add_as_nitsots[i]] = new_outs[i]

        return outer_vars, new_scan_node, new_scan_args

    def add_nitsot_outputs(self, fgraph, old_scan_node,
                           old_scan_args, new_outputs_inner):
        """
        Rebuild the scan node with ``new_outputs_inner`` appended as nitsot
        outputs and replace the old node in ``fgraph``.

        """
        nb_new_outs = len(new_outputs_inner)

        # Create the initial values for the new nitsot outputs
        # (the initial value is the nb of steps to store. For a nistot,
        # it should be the number of steps performed by scan)
        new_nitsots_initial_value = [old_scan_node.inputs[0]
                                     for i in range(nb_new_outs)]

        # Create the scan_args corresponding to the new scan op to
        # create
        new_scan_args = copy.copy(old_scan_args)
        new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
        new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)

        # Create the scan op from the scan_args
        new_scan_op = scan_op.Scan(new_scan_args.inner_inputs,
                                   new_scan_args.inner_outputs,
                                   new_scan_args.info)

        # Create the Apply node for the scan op
        new_scan_node = new_scan_op(*new_scan_args.outer_inputs,
                                    **dict(return_list=True))[0].owner

        # Modify the outer graph to make sure the outputs of the new scan
        # are used instead of the outputs of the old scan
        new_node_new_outputs_idx = (len(old_scan_args.outer_outputs) -
                                    len(old_scan_args.outer_out_shared))

        new_node_old_outputs = (
            new_scan_node.outputs[:new_node_new_outputs_idx] +
            new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs:])

        fgraph.replace_all_validate_remove(
            list(zip(old_scan_node.outputs, new_node_old_outputs)),
            remove=[old_scan_node],
            reason='scanOp_pushout_output')

        return new_scan_node


class ScanInplaceOptimizer(Optimizer):
    """
    Graph optimizer for Scan (makes it run inplace).

    """

    def __init__(self, typeInfer=None, gpua_flag=False):
        Optimizer.__init__(self)
        self.typeInfer = typeInfer
        self.gpua_flag = gpua_flag

    def add_requirements(self, fgraph):
        # ReplaceValidate for safe replacements; DestroyHandler because
        # this optimizer introduces destructive (inplace) operations.
        fgraph.attach_feature(toolbox.ReplaceValidate())
        fgraph.attach_feature(DestroyHandler())

    def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
        """Attempts to replace a Scan node by one which computes the specified
        outputs inplace.

        Parameters
        ----------
        fgraph : FunctionGraph
            Function graph in which to attempt the replacement
        node : Apply node
            Scan node to replace by an inplace version
        output_indices : list of integers
            Indices of the outputs to attempt to compute inplace
        alloc_ops : list of Op classes
            Classes that represent operation that allocate new memory and
            that the optimization should duplicate so it can operate inplace
            on them.
        """
        op = node.op

        info = copy.deepcopy(op.info)
        if 'destroy_map' not in info:
            info['destroy_map'] = OrderedDict()

        # Each selected output destroys its corresponding initial-state
        # input (offset by 1 for n_steps plus the sequences).
        for out_idx in output_indices:
            info['destroy_map'][out_idx] = [out_idx + 1 + op.info['n_seqs']]

        # inputs corresponding to sequences and n_steps
        ls_begin = node.inputs[:1 + op.n_seqs]
        ls = op.outer_mitmot(node.inputs)
        ls += op.outer_mitsot(node.inputs)
        ls += op.outer_sitsot(node.inputs)
        ls_end = op.outer_shared(node.inputs)
        ls_end += op.outer_nitsot(node.inputs)
        ls_end += op.outer_non_seqs(node.inputs)

        # In `ls`, duplicate any input which has more then one client and is
        # the output of an eligible allocation op
        for i in range(len(ls)):
            inp = ls[i]
            if (len(inp.clients) > 1 and inp.owner and
                    isinstance(inp.owner.op, alloc_ops)):
                ls[i] = inp.owner.op(*inp.owner.inputs)

        # A variable may not appear twice in the list of destroyed inputs;
        # deep-copy duplicates so each slot owns distinct storage.
        n_outs = len(ls)
        for idx in xrange(n_outs):
            if ls[idx] in ls[:idx]:
                ls[idx] = deep_copy_op(ls[idx])

        inputs = ls_begin + ls + ls_end
        if self.typeInfer is None:
            typeConstructor = None
        else:
            typeConstructor = self.typeInfer(node)

        new_op = scan_op.Scan(op.inputs,
                              op.outputs,
                              info,
                              typeConstructor=typeConstructor)

        # Do not call make_node for test_value
        new_outs = new_op(*inputs, **dict(return_list=True))
        try:
            fgraph.replace_all_validate_remove(
                list(zip(node.outputs, new_outs)),
                remove=[node],
                reason='scanOp_make_inplace')
            return new_outs[0].owner
        except InconsistencyError:
            # Failed moving output to be computed inplace
            return node

    def apply(self, fgraph):

        # Depending on the value of gpua_flag, get the list of memory
        # allocation ops that the optimization should be able to
        # handle
        alloc_ops = (Alloc, AllocEmpty)
        if self.gpua_flag:
            # gpuarray might be imported but not its GpuAlloc and
            # GpuAllopEmpty ops.
            try:
                alloc_ops += (theano.gpuarray.GpuAlloc,
                              theano.gpuarray.GpuAllocEmpty)
            except Exception:
                pass

        nodes = fgraph.toposort()[::-1]
        scan_nodes = [x for x in nodes
                      if (isinstance(x.op, scan_op.Scan) and
                          x.op.info['gpua'] == self.gpua_flag)]
        for scan_idx in xrange(len(scan_nodes)):

            # First attempt to make the Scan compute inplace every recurrent
            # output that seems like it could be computed inplace. If that
            # fails, go through these outputs individually, trying each of
            # them.
            original_node = scan_nodes[scan_idx]
            op = original_node.op
            n_outs = (op.info['n_mit_mot'] +
                      op.info['n_mit_sot'] +
                      op.info['n_sit_sot'])

            # Generate a list of outputs on which the node could potentially
            # operate inplace.
            out_indices = []
            for out_idx in range(n_outs):
                inp_idx = 1 + op.n_seqs + out_idx
                inp = original_node.inputs[inp_idx]

                # If the input is from an eligible allocation node, attempt to
                # be inplace on it, even if other nodes are modifying it
                # inplace.
                if inp.owner and isinstance(inp.owner.op, alloc_ops):
                    out_indices.append(out_idx)
                    continue

                # If the input is not from an eligible allocation node, only
                # attempt to be inplace on it if nothing else is currently
                # inplace on it.
                input_used_inplace = False
                for c in original_node.inputs[inp_idx].clients:
                    client = c[0]

                    # Get the indices of this client's inputs on which it
                    # operates inplace
                    if hasattr(client.op, 'destroy_map'):
                        # This flattens the content of destroy_map.values()
                        # which is a list of lists
                        inplace_inp_indices = sum(
                            client.op.destroy_map.values(), [])

                        inplace_inps = [client.inputs[i]
                                        for i in inplace_inp_indices]
                        if original_node.inputs[inp_idx] in inplace_inps:
                            input_used_inplace = True
                            break

                if not input_used_inplace:
                    out_indices.append(out_idx)

            node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
                                             out_indices, alloc_ops)

            if node is original_node:
                # Making the scan compute all plausible recurrent outputs
                # inplace has failed. Attempt all plausible recurrent output
                # individually.
                for pos in out_indices:
                    node = self.attempt_scan_inplace(fgraph, node, [pos],
                                                     alloc_ops)


class ScanSaveMem(gof.Optimizer):
    """
    Graph Optimizer that reduces scan memory consumption.

    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def process_node(self, fgraph, node):
        """Analyze the clients of `node`'s outputs and, when possible,
        rewrite the Scan so it runs fewer steps and/or stores fewer
        intermediate values. Works by side effect on `fgraph`; returns
        False when the candidate replacement would still depend on the old
        scan node."""

        # helpful functions
        def select_min(x, y):
            # tensor.minimum of x and y, treating None as "no constraint"
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            # tensor.maximum of x and y, treating None as "no constraint"
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        def sanitize(x):
            # Convert to a symbolic variable, passing None through untouched.
            if x is None:
                return None
            else:
                return tensor.as_tensor_variable(x)

        if hasattr(fgraph, 'shape_feature'):
            # NOTE(review): reads the feature through `node.fgraph` although
            # the guard checks `fgraph` — presumably the same graph; confirm.
            shape_of = node.fgraph.shape_feature.shape_of
        else:
            # Each access to shape_of is in a try..except block in order to
            # use a default version when the variable is not in the shape_of
            # dictionary.
            shape_of = OrderedDict()
        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        init_l = [0 for x in xrange(op.n_mit_mot)]
        init_l += [abs(min(v)) for v in op.tap_array[op.n_mit_mot:]]
        init_l += [0 for x in xrange(op.n_nit_sot)]
        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # This comparison checks if there is any uncounted output, which
        # can only be an output corresponding to a shared variable

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None
        # given that a scan op has k outputs o_1, .. o_k and each
        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
        # global_nsteps is None if any of the clients is different
        # from a subtensor or its real and sym field equal to
        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
        # of the loop, reducing the number of steps will affect the the
        # value of the shared variable after the loop so we need not to
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
        # to be done
        assert len(node.outputs) >= c_outs
        if len(node.outputs) == c_outs:
            global_nsteps = {'real': -1, 'sym': []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0 for o in xrange(op.n_mit_mot)]
        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.1 outputs of the function
                # => output needs all its intermediate values
                if type(cl) == str:
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2 non-subtensor nodes
                # => output needs all its intermediate values
                elif not isinstance(cl.op, tensor.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.3 subtensor nodes
                # => output might need to store just a subset of its values
                else:
                    # 2.3.1 extract idx list of subtensor
                    this_slice = tensor.get_idx_list(cl.inputs,
                                                     cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        # => outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # 2.3.2 extract the begin/end of the first dimension
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.get_canonical_form_slice(
                        this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    if (isinstance(this_slice[0], slice) and
                            this_slice[0].stop is None):
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    if stop == maxsize or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # 2.3.3 we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps['sym'] += [stop]
                        # not if it is maxsize
                        elif (type(stop) in integer_types and
                              stop == maxsize):
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in integer_types and
                              global_nsteps['real'] < stop):
                            global_nsteps['real'] = stop
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in integer_types and stop > 0):
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            nw_steps = node.inputs[0]

            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps['sym']) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps['sym'][0]
                for c in global_nsteps['sym'][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps['real'] >= 0:
                real_steps = global_nsteps['real']
            else:
                real_steps = None
            nw_steps = select_min(select_max(sym_steps, real_steps),
                                  node.inputs[0])

            # Make sure the ScanSaveMem optimization never makes the new
            # number of steps to be 0 (this could happen, for instance, if
            # the optimization detects that the outputs of the Scan go through
            # subtensor nodes that end up taking no elements) because Scan with
            # 0 iterations are not supported. Make sure the new number of steps
            # is at least 1.
            nw_steps = select_max(nw_steps, 1)
        else:
            nw_steps = node.inputs[0]
            global_nsteps = None

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if type(cl) == str:
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.get_idx_list(cl.inputs,
                                                     cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if (isinstance(this_slice[0], slice) and
                            this_slice[0].start is None):
                        store_steps[i] = 0
                        break

                    # NOTE(review): the analogous test in the first client
                    # loop (2.3.2) uses `i >= op.n_mit_mot`; confirm whether
                    # excluding i == n_mit_mot here is intentional.
                    if i > op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.get_canonical_form_slice(
                        this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(
                            cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        # The "+ 1" is because of the memory pre-allocation
                        # mechanism used to in the Scan op to reduce overhead.
                        # To prevent aliasing between the inputs and outputs
                        # of recurrent states, it requires that the buffer be
                        # large enough to that, the new state and the oldest
                        # tap needed don't occupy the sample place in the
                        # circular buffer. For now, this only needs to be done
                        # for mitsots and sitsots (because mitmots are not
                        # currently supported by the mechanism) and only if
                        # the pre-allocation mechanism is activated.
                        prealloc_outs = theano.config.scan.allow_output_prealloc

                        first_mitsot_idx = node.op.n_mit_mot
                        last_sitsot_idx = (node.op.n_mit_mot +
                                           node.op.n_mit_sot +
                                           node.op.n_sit_sot - 1)
                        preallocable_output = (first_mitsot_idx <= i <= last_sitsot_idx)

                        if (prealloc_outs and preallocable_output):
                            pval = select_max(nw_steps - start + init_l[i],
                                              init_l[i] + 1)
                        else:
                            pval = select_max(nw_steps - start + init_l[i],
                                              init_l[i])

                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        # TODO: Simplify the number of steps needed.
                        # FB: This need good testing, left to later.
                        #     call get_scalar_constant_value()? it can
                        #     return python/numpy scalar or np.ndarray
                        #     currently.
                        # pval = pre_greedy_local_optimizer(list_opt_slice,
                        #                                   pval)
                        # pval = pre_constant_merge([pval])[0]
                        # if (isinstance(pval, theano.tensor.TensorConstant)
                        #     and
                        #     pval.dtype.startswith('int')):
                        #     try:
                        #         pval = int(pval.data)
                        #     except Exception:
                        #         pass

                        store_steps[i] = pval
                        flag_store = True

        orphane_outs = [i for i, x in enumerate(store_steps)
                        if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)
        # 3. is there anything to change ?
        if (flag_store or global_nsteps is not None):
            # 3.1 initialize inputs for the new scan
            old_outputs = []
            nw_inputs = list(node.inputs)
            nw_inputs[0] = nw_steps

            # 3.2 check orphane outputs to see if we can eliminate any
            required, not_required = scan_utils.scan_can_remove_outs(
                node.op, orphane_outs)
            # 3.3. compose replace pairs for those nodes that need not
            # to store everything in memory ( or ar orphane and required
            # by the inner function .. )
            replaced_outs = []
            offset = 1 + op.n_seqs + op.n_mit_mot
            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
                i = idx + op.n_mit_mot
                if not(type(_val) is int and _val <= 0 and i not in required):

                    if idx + op.n_mit_mot in required:
                        val = 1
                    else:
                        val = _val
                    # If the memory for this output has been pre-allocated
                    # before going into the scan op (by an alloc node)
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        # In case the input is still an alloc node, we
                        # actually have two options:
                        #   a) the input is a set_subtensor, in that case we
                        #      can replace the initial tensor by a slice,
                        #   b) it is not, and we simply take a slice of it.
                        # TODO: commit change below with Razvan
                        if (nw_inputs[offset + idx].owner and
                            isinstance(nw_inputs[offset + idx].owner.op,
                                       tensor.IncSubtensor) and
                            isinstance(
                                nw_inputs[offset + idx].owner.op.idx_list[0],
                                slice)):

                            assert isinstance(nw_inputs[offset + idx].owner.op,
                                              tensor.IncSubtensor)
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            cval = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp_idx = tensor.switch(cval < initl,
                                                    cval + initl,
                                                    cval - initl)
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                                             tmp_idx)
                            tmp = pre_constant_merge([tmp])[0]

                            nw_input = scan_utils.expand_empty(_nw_input, tmp)
                        else:
                            tmp = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = tensor.maximum(tmp, initl)
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                                             tmp)
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = nw_inputs[offset + idx][:tmp]

                        nw_inputs[offset + idx] = nw_input
                        replaced_outs.append(op.n_mit_mot + idx)
                        odx = op.n_mit_mot + idx
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
                    # If there is no memory pre-allocated for this output
                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:

                        pos = (op.n_mit_mot + idx + op.n_seqs +
                               1 + op.n_shared_outs)
                        if nw_inputs[pos] == node.inputs[0]:
                            nw_inputs[pos] = val
                        odx = op.n_mit_mot + idx
                        replaced_outs.append(odx)
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
            # 3.4. Recompute inputs for everything else based on the new
            # number of steps
            if global_nsteps is not None:
                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
                    if val == 0:
                        # val == 0 means that we want to keep all intermediate
                        # results for that state, including the initial values.
                        if idx < op.n_mit_sot + op.n_sit_sot:
                            in_idx = offset + idx
                            # Number of steps in the initial state
                            initl = init_l[op.n_mit_mot + idx]

                            # If the initial buffer has the form
                            # inc_subtensor(zeros(...)[...], _nw_input)
                            # we want to make the zeros tensor as small as
                            # possible (nw_steps + initl), and call
                            # inc_subtensor on that instead.
                            # Otherwise, simply take 0:(nw_steps+initl).
                            if ((nw_inputs[in_idx].owner and
                                 isinstance(nw_inputs[in_idx].owner.op,
                                            tensor.IncSubtensor) and
                                 isinstance(
                                     nw_inputs[in_idx].owner.op.idx_list[0],
                                     slice))):
                                _nw_input = nw_inputs[in_idx].owner.inputs[1]
                                nw_input = scan_utils.expand_empty(_nw_input,
                                                                   nw_steps)
                                nw_inputs[in_idx] = nw_input
                            else:
                                # NOTE(review): `nw_input` is computed here
                                # but never written back to
                                # `nw_inputs[in_idx]` (unlike the branch
                                # above) — verify this is intentional.
                                nw_input = nw_inputs[in_idx][:(initl + nw_steps)]

                        elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
                            in_idx = offset + idx + op.n_shared_outs
                            if nw_inputs[in_idx] == node.inputs[0]:
                                nw_inputs[in_idx] = nw_steps

            # 3.5 Remove unwanted orphane outputs
            (inps, outs, info, node_ins, compress_map) = \
                scan_utils.compress_outs(op, not_required, nw_inputs)
            inv_compress_map = OrderedDict()
            for k, v in iteritems(compress_map):
                inv_compress_map[v] = k

            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
                        node_ins]
            node_ins = pre_constant_merge(node_ins)
            # 3.6 Compose the new scan
            # TODO: currently we don't support scan with 0 step. So
            # don't create one.
            # For test, mark that savemem have optimized this node
            info['_scan_savemem_visited'] = True
            if theano.tensor.extract_constant(node_ins[0]) == 0:
                return

            # Do not call make_node for test_value
            new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
                                                      **dict(return_list=True))

            old_new = []
            # 3.7 Get replace pairs for those outputs that do not change
            # the number of intermediate steps stored
            for idx, sl in enumerate(slices):
                if global_nsteps and sl is not None and store_steps[idx] == 0:
                    for hdx, cl in enumerate(node.outputs[idx].clients):
                        cnf_slice, old_slices = sl[hdx]
                        # Sanitize the nw_slice by converting ints back into
                        # constants :) I only need to do this for the first
                        # slice since that is the only slice

                        if isinstance(cnf_slice[0], slice):
                            fslice = slice(
                                sanitize(cnf_slice[0].start),
                                sanitize(cnf_slice[0].stop),
                                sanitize(cnf_slice[0].step))
                        else:
                            fslice = sanitize(cnf_slice[0])

                        nw_slice = (fslice,) + tuple(old_slices[1:])
                        nw_pos = inv_compress_map[idx]

                        subtens = tensor.Subtensor(nw_slice)
                        # slice inputs
                        sl_ins = tensor.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                     tensor.Variable))
                        new_o = subtens(new_outs[nw_pos], *sl_ins)
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        replaced_outs.append(idx)
                        old_new += [(cl[0].outputs[0], new_o)]
            # 3.8. Get replace pairs for those outputs that change
            # the number of stored intermediate steps
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice
                        cnf_slice, old_slices = slices[pos][k]
                        if type(cnf_slice[0]) is slice:
                            start = (cnf_slice[0].start - nw_steps -
                                     init_l[pos] + store_steps[pos])
                            if (cnf_slice[0].stop is not None and
                                    cnf_slice[0].stop != maxsize):
                                stop = (cnf_slice[0].stop - nw_steps -
                                        init_l[pos] + store_steps[pos])
                            else:
                                stop = None
                            nw_slice = ((slice(sanitize(start),
                                               sanitize(stop),
                                               sanitize(cnf_slice[0].step)),) +
                                        tuple(old_slices[1:]))

                        else:
                            position = (cnf_slice[0] - nw_steps -
                                        init_l[pos] + store_steps[pos])

                            nw_slice = (sanitize(position),) + tuple(
                                old_slices[1:])
                        subtens = tensor.Subtensor(nw_slice)
                        sl_ins = tensor.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                     tensor.Variable))
                        new_o = subtens(new_outs[nw_pos], *sl_ins)
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        old_new += [(old, new_o)]

            # 3.9. Get replace pairs for all other nodes
            if flag_store or global_nsteps is not None:
                for idx, o in enumerate(node.outputs):
                    if not (idx in replaced_outs) and idx not in not_required:
                        nw_pos = compress_map[idx]
                        old_new += [(o, new_outs[nw_pos])]
                # Check if the new outputs depend on the old scan node
                old_scan_is_used = [gof.graph.is_in_ancestors(new.owner, node)
                                    for old, new in old_new]
                if any(old_scan_is_used):
                    return False
                remove = [old.owner for (old, new) in old_new]
                # As Fred suggested assert that also the old node is not in
                # the Graph as that will make things suboptimal
                remove.append(node)
                fgraph.replace_all_validate_remove(old_new,
                                                   remove,
                                                   reason='scanOp_save_mem')

    def apply(self, fgraph):
        """Run process_node on every Scan node of the graph, in
        topological order."""

        nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
                                                               scan_op.Scan)]
        for node in nodelist:
            self.process_node(fgraph, node)


class ScanMerge(gof.Optimizer):
    """
    Graph Optimizer that merges different scan ops.

    """

    def add_requirements(self, fgraph):
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def merge(self, nodes):
        """Build a single Scan equivalent to the (mergeable) Scan `nodes`
        and return the list of (old outer output, new outer output)
        replacement pairs for the caller to apply."""

        if nodes[0].op.as_while:
            as_while = True
            condition = nodes[0].op.outputs[-1]
        else:
            as_while = False

        # The merged op's info is the field-wise combination of the infos
        # of all the nodes being merged.
        info = OrderedDict()
        info['tap_array'] = []
        info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
        info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
        info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
        info['mit_mot_out_slices'] = []
        info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
        info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
        info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
        info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
        info['truncate_gradient'] = nodes[0].op.truncate_gradient
        info['name'] = '&'.join([nd.op.name for nd in nodes])
        info['mode'] = nodes[0].op.mode
        info['gpua'] = False
        info['as_while'] = as_while
        info['profile'] = nodes[0].op.profile
        info['allow_gc'] = nodes[0].op.allow_gc

        # We keep the inner_ins and inner_outs of each original node separated.
        # To be able to recombine them in the right order after the clone,
        # we also need to split them by types (seq, mitmot, ...).
        # On the other hand, outer_ins, outer_outs and info are held together.
        inner_ins = [[] for nd in nodes]
        outer_ins = []
        inner_outs = [[] for nd in nodes]
        outer_outs = []

        def rename(ls, suffix):
            # Append `suffix` to the name of every named variable in `ls`
            # (mutates the variables in place) and return `ls`.
            for k in ls:
                if k.name:
                    k.name += str(suffix)
            return ls

        for idx, nd in enumerate(nodes):
            # Seq
            inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
            outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)

        for idx, nd in enumerate(nodes):
            # MitMot
            inner_ins[idx].append(
                rename(nd.op.inner_mitmot(nd.op.inputs), idx))
            inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
            info['tap_array'] += nd.op.mitmot_taps()
            info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
            outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
            outer_outs += nd.op.outer_mitmot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # MitSot
            inner_ins[idx].append(
                rename(nd.op.inner_mitsot(nd.op.inputs), idx))
            inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
            info['tap_array'] += nd.op.mitsot_taps()
            outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_mitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # SitSot
            inner_ins[idx].append(
                rename(nd.op.inner_sitsot(nd.op.inputs), idx))
            info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
            inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
            outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_sitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # Shared
            inner_ins[idx].append(
                rename(nd.op.inner_shared(nd.op.inputs), idx))
            outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)

        for idx, nd in enumerate(nodes):
            # NitSot
            inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.outputs))
            outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_nitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # Shared
            outer_outs += nd.op.outer_shared_outs(nd.outputs)
            inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))

        for idx, nd in enumerate(nodes):
            # Non Seqs
            inner_ins[idx].append(
                rename(nd.op.inner_non_seqs(nd.op.inputs), idx))
            outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)

        # Add back the number of steps
        outer_ins = [nodes[0].inputs[0]] + outer_ins

        if as_while:
            # add the condition, which was the one of nodes[0]
            inner_outs[0].append([condition])

        # Clone the inner graph of each node independently
        for idx, nd in enumerate(nodes):
            # concatenate all inner_ins and inner_outs of nd
            flat_inner_ins = sum(inner_ins[idx], [])
            flat_inner_outs = sum(inner_outs[idx], [])
            # clone
            flat_inner_ins, flat_inner_outs = scan_utils.reconstruct_graph(
                flat_inner_ins, flat_inner_outs)
            # split the new inner variables again in seq, mitmot, etc.
            new_inner_ins = []
            count = 0
            for nl in inner_ins[idx]:
                seq_len = len(nl)
                new_inner_ins.append(flat_inner_ins[count:(count + seq_len)])
                count += seq_len

            new_inner_outs = []
            count = 0
            for nl in inner_outs[idx]:
                seq_len = len(nl)
                new_inner_outs.append(flat_inner_outs[count:(count + seq_len)])
                count += seq_len

            inner_ins[idx] = new_inner_ins
            inner_outs[idx] = new_inner_outs

        # Flatten inner_ins and inner_outs so that all seqs are first,
        # then mitmot, etc.
        new_inner_ins = []
        new_inner_outs = []
        nb_ins_groups = len(inner_ins[0])
        nb_outs_groups = len(inner_outs[0])
        for idx, nd in enumerate(nodes):
            # All inner_ins should have the same length
            assert len(inner_ins[idx]) == nb_ins_groups

            # All inner_outs should have the same length, except if as_while,
            # in which case the first one should have one more element
            if as_while and idx > 0:
                assert len(inner_outs[idx]) == nb_outs_groups - 1
            else:
                assert len(inner_outs[idx]) == nb_outs_groups

        for gr_idx in range(nb_ins_groups):
            for idx, nd in enumerate(nodes):
                new_inner_ins += inner_ins[idx][gr_idx]

        for gr_idx in range(nb_outs_groups):
            for idx, nd in enumerate(nodes):
                if as_while and idx > 0 and gr_idx == (nb_outs_groups - 1):
                    # There is no condition on that node, skip it
                    pass
                else:
                    new_inner_outs += inner_outs[idx][gr_idx]

        new_op = scan_op.Scan(new_inner_ins, new_inner_outs, info)
        new_outs = new_op(*outer_ins)

        if not isinstance(new_outs, (list, tuple)):
            new_outs = [new_outs]

        return list(zip(outer_outs, new_outs))

    def belongs_to_set(self, node, set_nodes):
        """
        This function checks if node `node` belongs to `set_nodes`, in the
        sense that it can be merged together with every other node in
        `set_nodes`. In order for two nodes to be mergeable, they have to go
        over the same number of steps, have the same condition (if any),
        have the same value for truncate_gradient, and have the same mode.
        Questionable, we should also consider profile ?

        """
        rep = set_nodes[0]
        if (rep.op.as_while != node.op.as_while or
                node.op.truncate_gradient != rep.op.truncate_gradient or
                node.op.mode != rep.op.mode):
            return False

        # Reduce both step counts to Python ints when they are constants so
        # that equal constants compare equal.
        nsteps = node.inputs[0]
        try:
            nsteps = int(get_scalar_constant_value(nsteps))
        except tensor.NotScalarConstantError:
            pass

        rep_nsteps = rep.inputs[0]
        try:
            rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
        except tensor.NotScalarConstantError:
            pass

        if nsteps != rep_nsteps:
            return False

        # Check to see if it is an input of a different node
        for nd in set_nodes:
            if gof.graph.is_in_ancestors(node, nd) or gof.graph.is_in_ancestors(nd, node):
                return False

        if not node.op.as_while:
            return True
        cond = node.op.outputs[-1]
        rep_cond = rep.op.outputs[-1]
        return scan_utils.equal_computations([cond], [rep_cond],
                                             node.op.inputs,
                                             rep.op.inputs)

    def apply(self, fgraph):
        # Collect all scan nodes ordered according to toposort
        scan_nodes = [nd for nd in fgraph.toposort()
                      if isinstance(nd.op, scan_op.Scan)]

        # All sets of possibly mergeable nodes
        all_sets = []

        for nd in scan_nodes:
            belongs_to_set_idx = -1
            for pos, subset in enumerate(all_sets):
                if self.belongs_to_set(nd, subset):
                    belongs_to_set_idx = pos
                    # It is possible that nd belongs to more than one subset.
                    # For instance, if we have 3 Scan nodes X, Y and Z, if Z
                    # depends on the output of X, then X and Z are incompatible
                    # and would create different subsets, but Y could be
                    # compatible with both X and Z. We choose the first one.
                    break

            if belongs_to_set_idx == -1:
                all_sets.append([nd])
            else:
                all_sets[belongs_to_set_idx].append(nd)

        for subset in all_sets:
            if len(subset) > 1:
                proposal = self.merge(subset)
                fgraph.replace_all_validate_remove(proposal,
                                                   remove=subset,
                                                   reason='scanOp_merge')


def has_duplicates(l):
    """
    Returns true if l has any duplicates (according to __eq__).

    """
    return len(set(l)) < len(l)


def make_equiv(lo, li):
    """
    Builds a dictionary of equivalences between inner inputs based on
    the equivalence of their corresponding outer inputs.

    """
    seeno = OrderedDict()
    left = []
    right = []
    for o, i in zip(lo, li):
        if o in seeno:
            # NOTE(review): `right` collects the duplicate *outer* variable
            # `o` rather than the canonical inner input stored in
            # `seeno[o]`, which looks surprising given the docstring; this
            # matches the long-standing behavior — confirm before changing.
            left += [i]
            right += [o]
        else:
            seeno[o] = i
    return left, right


@gof.local_optimizer([scan_op.Scan])
def scan_merge_inouts(node):
    if not isinstance(node.op, scan_op.Scan):
        return False

    # Do a first pass to merge identical external inputs.
    # Equivalent inputs will be stored in inp_equiv, then a new
    # scan node created without duplicates.
1911 a = scan_args(node.inputs, node.outputs, 1912 node.op.inputs, node.op.outputs, node.op.info) 1913 1914 inp_equiv = OrderedDict() 1915 1916 if has_duplicates(a.outer_in_seqs): 1917 new_outer_seqs = [] 1918 new_inner_seqs = [] 1919 for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): 1920 if out_seq in new_outer_seqs: 1921 i = new_outer_seqs.index(out_seq) 1922 inp_equiv[in_seq] = new_inner_seqs[i] 1923 else: 1924 new_outer_seqs.append(out_seq) 1925 new_inner_seqs.append(in_seq) 1926 a.outer_in_seqs = new_outer_seqs 1927 a.inner_in_seqs = new_inner_seqs 1928 1929 if has_duplicates(a.outer_in_non_seqs): 1930 new_outer_nseqs = [] 1931 new_inner_nseqs = [] 1932 for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs): 1933 if out_nseq in new_outer_nseqs: 1934 i = new_outer_nseqs.index(out_nseq) 1935 inp_equiv[in_nseq] = new_inner_nseqs[i] 1936 else: 1937 new_outer_nseqs.append(out_nseq) 1938 new_inner_nseqs.append(in_nseq) 1939 a.outer_in_non_seqs = new_outer_nseqs 1940 a.inner_in_non_seqs = new_inner_nseqs 1941 1942 if len(inp_equiv) > 0: 1943 # do the replacement now. The rest will be left to ScanSaveMem 1944 inner_inputs = a.inner_inputs 1945 outer_inputs = a.outer_inputs 1946 info = a.info 1947 a_inner_outs = a.inner_outputs 1948 inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv) 1949 1950 op = scan_op.Scan(inner_inputs, inner_outputs, info) 1951 outputs = op(*outer_inputs) 1952 1953 if not isinstance(outputs, (list, tuple)): 1954 outputs = [outputs] 1955 1956 na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info) 1957 remove = [node] 1958 else: 1959 na = a 1960 remove = [] 1961 1962 # Now that the identical external inputs have been merged, we do a new 1963 # loop in order to merge external outputs that compute the same things 1964 # from the same inputs. 
1965 left = [] 1966 right = [] 1967 1968 if has_duplicates(na.outer_in_shared): 1969 _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared) 1970 left += _left 1971 right += _right 1972 if has_duplicates(na.outer_in_sit_sot): 1973 _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot) 1974 left += _left 1975 right += _right 1976 if has_duplicates(na.outer_in_mit_mot): 1977 seen = OrderedDict() 1978 for omm, imm, _sl in zip(na.outer_in_mit_mot, 1979 na.inner_in_mit_mot, na.mit_mot_in_slices): 1980 sl = tuple(_sl) 1981 if (omm, sl) in seen: 1982 simm = seen[(omm, sl)] 1983 left += imm 1984 right += simm 1985 else: 1986 seen[(omm, sl)] = imm 1987 1988 if has_duplicates(na.outer_in_mit_sot): 1989 seen = OrderedDict() 1990 for oms, ims, _sl in zip(na.outer_in_mit_sot, 1991 na.inner_in_mit_sot, 1992 na.mit_sot_in_slices): 1993 sl = tuple(_sl) 1994 if (oms, sl) in seen: 1995 sims = seen[(oms, sl)] 1996 left += ims 1997 right += sims 1998 else: 1999 seen[(oms, sl)] = ims 2000 2001 def map_out(outer_i, inner_o, outer_o, seen): 2002 # Return the outer input corresponding to an 2003 # (outer input, inner output) pair. If we see that pair for the first 2004 # time, return the provided outer output. If an equivalent pair had 2005 # already been seen, return that one instead. 2006 # Note that we need to check that the outer input match as well, 2007 # because they could have different sizes, and the corresponding 2008 # outer outputs cannot be merged in that case. 
2009 for s_outer_i, s_inner_o, s_outer_o in seen: 2010 if (equal_computations([inner_o], [s_inner_o], left, right) and 2011 outer_i == s_outer_i): 2012 return s_outer_o 2013 seen.append((outer_i, inner_o, outer_o)) 2014 return outer_o 2015 2016 seen = [] 2017 2018 assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot) 2019 assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot) 2020 na.outer_out_nit_sot = [ 2021 map_out(outer_i, inner_o, outer_o, seen) 2022 for outer_i, inner_o, outer_o in zip(na.outer_in_nit_sot, 2023 na.inner_out_nit_sot, 2024 na.outer_out_nit_sot)] 2025 2026 seen = [] 2027 assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot) 2028 assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot) 2029 na.outer_out_sit_sot = [ 2030 map_out(outer_i, inner_o, outer_o, seen) 2031 for outer_i, inner_o, outer_o in zip(na.outer_in_sit_sot, 2032 na.inner_out_sit_sot, 2033 na.outer_out_sit_sot)] 2034 2035 seen = [] 2036 assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot) 2037 assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot) 2038 na.outer_out_mit_sot = [ 2039 map_out(outer_i, inner_o, outer_o, seen) 2040 for outer_i, inner_o, outer_o in zip(na.outer_in_mit_sot, 2041 na.inner_out_mit_sot, 2042 na.outer_out_mit_sot)] 2043 2044 seen = [] 2045 new_outer_out_mit_mot = [] 2046 assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot) 2047 assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot) 2048 assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices) 2049 for outer_imm, inner_omm, outer_omm, osl in zip(na.outer_in_mit_mot, 2050 na.inner_out_mit_mot, 2051 na.outer_out_mit_mot, 2052 na.mit_mot_out_slices): 2053 for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: 2054 if (osl == sosl and 2055 equal_computations(inner_omm, s_inner_omm, left, right) and 2056 outer_imm == s_outer_imm): 2057 2058 new_outer_out_mit_mot.append(s_outer_omm) 2059 break 2060 else: 2061 seen.append((outer_imm, inner_omm, outer_omm, 
osl)) 2062 new_outer_out_mit_mot.append(outer_omm) 2063 na.outer_out_mit_mot = new_outer_out_mit_mot 2064 if remove: 2065 return OrderedDict([("remove", remove)] + 2066 list(zip(node.outputs, na.outer_outputs))) 2067 return na.outer_outputs 2068 2069 2070class PushOutDot1(gof.Optimizer): 2071 """ 2072 Graph optimizer for Scan(makes it run inplace). 2073 2074 """ 2075 2076 def __init__(self): 2077 Optimizer.__init__(self) 2078 2079 def add_requirements(self, fgraph): 2080 fgraph.attach_feature(toolbox.ReplaceValidate()) 2081 2082 def apply(self, fgraph): 2083 2084 nodes = fgraph.toposort() 2085 scan_nodes = [x for x in nodes if (isinstance(x.op, scan_op.Scan))] 2086 for node in scan_nodes: 2087 self.apply_opt(fgraph, node) 2088 2089 def apply_opt(self, fgraph, node): 2090 # Replace pattern of the form 2091 # x[t] = x[t-1] + dot(seq[t], value) 2092 # with Sequence.reshape((-1, seq.shape[2])) \dot Value 2093 # When seq[t] is a vector/matrix and `value` is a matrix 2094 # Note that this works when only you need X[-1] in the end 2095 # and assumes dimshuffle are applied to vectors before calling dot 2096 op = node.op 2097 sitsot_ins = op.inner_sitsot(op.inputs) 2098 sitsot_outs = op.inner_sitsot_outs(op.outputs) 2099 outer_sitsot = op.outer_sitsot_outs(node) 2100 seqs = op.inner_seqs(op.inputs) 2101 for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): 2102 2103 if (out.owner and 2104 isinstance(out.owner.op, theano.tensor.Elemwise) and 2105 isinstance(out.owner.op.scalar_op, theano.scalar.Add) and 2106 inp in out.owner.inputs and 2107 len(outer_out.clients) == 1 and 2108 not isinstance(outer_out.clients[0][0], str) and 2109 isinstance(outer_out.clients[0][0].op, theano.tensor.Subtensor) and 2110 outer_out.clients[0][0].op.idx_list == (-1,)): 2111 2112 x = out.owner.inputs[0] 2113 if x == inp: 2114 x = out.owner.inputs[1] 2115 # We need to check if x is the result of an outer product 2116 if (x.owner and isinstance(x.owner.op, theano.tensor.Dot) and 
2117 x.owner.inputs[0].ndim == 2 and x.owner.inputs[1].ndim == 2): 2118 2119 # We need to check if any of the inputs are a sequence 2120 inp1 = x.owner.inputs[0] 2121 inp2 = x.owner.inputs[1] 2122 2123 if inp1 in seqs or inp2 in seqs: 2124 new_scan_out = inp1 2125 2126 if inp1 in seqs: 2127 new_scan_out = inp2 2128 idx = sitsot_outs.index(out) 2129 # We've found our pattern and need to construct a new 2130 # scan node to replace this one. For this we need to 2131 # replace the sit_sot output with a nit_sot output 2132 2133 # First let us split all arguments according to their 2134 # corresponding categories 2135 2136 inner_seqs = op.inner_seqs(op.inputs) 2137 outer_seqs = op.outer_seqs(node) 2138 inner_mitmot = op.inner_mitmot(op.inputs) 2139 outer_mitmot = op.outer_mitmot(node) 2140 inner_mitmot_outs = op.inner_mitmot_outs(op.outputs) 2141 inner_mitsot = op.inner_mitsot(op.inputs) 2142 outer_mitsot = op.outer_mitsot(node) 2143 inner_mitsot_outs = op.inner_mitsot_outs(op.outputs) 2144 inner_sitsot = op.inner_sitsot(op.inputs) 2145 outer_sitsot = op.outer_sitsot(node) 2146 inner_sitsot_outs = op.inner_sitsot_outs(op.outputs) 2147 outer_nitsot = op.outer_nitsot(node) 2148 inner_nitsot_outs = op.inner_nitsot_outs(op.outputs) 2149 inner_shared = op.inner_shared(op.inputs) 2150 outer_shared = op.outer_shared(node) 2151 inner_shared_outs = op.inner_shared_outs(op.outputs) 2152 inner_non_seqs = op.inner_non_seqs(op.inputs) 2153 outer_non_seqs = op.outer_non_seqs(node) 2154 2155 new_info = op.info.copy() 2156 st = len(op.mitmot_taps()) + len(op.mitsot_taps()) 2157 2158 new_info['tap_array'] = ( 2159 new_info['tap_array'][:st + idx] + 2160 new_info['tap_array'][st + idx + 1:]) 2161 new_info['n_sit_sot'] -= 1 2162 new_info['n_nit_sot'] += 1 2163 inner_sitsot = (inner_sitsot[:idx] + 2164 inner_sitsot[idx + 1:]) 2165 outer_sitsot = (outer_sitsot[:idx] + 2166 outer_sitsot[idx + 1:]) 2167 inner_sitsot_outs = (inner_sitsot_outs[:idx] + 2168 inner_sitsot_outs[idx + 1:]) 2169 # add 
n_steps as the length 2170 inner_nitsot_outs.append(new_scan_out) 2171 2172 _new_inner_inps = (inner_seqs + 2173 inner_mitmot + 2174 inner_mitsot + 2175 inner_sitsot + 2176 inner_shared + 2177 inner_non_seqs) 2178 _new_inner_outs = (inner_mitmot_outs + 2179 inner_mitsot_outs + 2180 inner_sitsot_outs + 2181 inner_nitsot_outs + 2182 inner_shared_outs) 2183 new_inner_inps, new_inner_outs =\ 2184 scan_utils.reconstruct_graph(_new_inner_inps, 2185 _new_inner_outs) 2186 new_op = scan_op.Scan(new_inner_inps, new_inner_outs, 2187 new_info) 2188 _scan_inputs = ([node.inputs[0]] + 2189 outer_seqs + 2190 outer_mitmot + 2191 outer_mitsot + 2192 outer_sitsot + 2193 outer_shared + 2194 outer_nitsot + 2195 [node.inputs[0]] + 2196 outer_non_seqs) 2197 2198 new_outs = new_op(*_scan_inputs) 2199 if type(new_outs) not in (list, tuple): 2200 new_outs = [new_outs] 2201 2202 # We need now to pair correctly the new outputs 2203 # with the old ones 2204 2205 outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) 2206 2207 _val = outer_nitsot_outs[-1] 2208 outer_nitsot_outs = outer_nitsot_outs[:-1] 2209 if inp1 in seqs: 2210 _out_seq = op.outer_seqs(node)[seqs.index(inp1)] 2211 # We need to clip the seq to the number of steps 2212 _out_seq = _out_seq[:node.inputs[0]] 2213 sh0 = _out_seq.shape[0] 2214 sh1 = _out_seq.shape[1] 2215 sh2 = _out_seq.shape[2] 2216 out_seq = _out_seq.dimshuffle(1, 0, 2) 2217 out_seq = out_seq.reshape((sh1, sh0 * sh2)) 2218 sh0 = _val.shape[0] 2219 sh1 = _val.shape[1] 2220 sh2 = _val.shape[2] 2221 2222 val = _val.reshape((sh0 * sh1, sh2)) 2223 new_out = tensor.dot(out_seq, val) 2224 else: 2225 _out_seq = op.outer_seqs(node)[seqs.index(inp2)] 2226 out_seq = _out_seq.reshape( 2227 (_out_seq.shape[0] * _out_seq.shape[1], 2228 _out_seq.shape[2])) 2229 2230 val = _val.dimshuffle(1, 0, 2).reshape( 2231 (_val.shape[1], 2232 _val.shape[0] * _val.shape[2])) 2233 new_out = tensor.dot(val, out_seq) 2234 2235 pos = node.outputs.index(outer_out) 2236 old_new = 
list(zip(node.outputs[:pos], new_outs[:pos])) 2237 old = node.outputs[pos].clients[0][0].outputs[0] 2238 old_new.append((old, new_out)) 2239 old_new += list(zip(node.outputs[pos + 1:], 2240 new_outs[pos:])) 2241 fgraph.replace_all_validate_remove( 2242 old_new, remove=[node], reason='scan_pushout_dot1') 2243 2244 2245# I've added an equilibrium because later scan optimization in the sequence 2246# can make it such that earlier optimizations should apply. However, in 2247# general I do not expect the sequence to run more then once 2248scan_eqopt1 = theano.gof.EquilibriumDB() 2249scan_seqopt1 = theano.gof.SequenceDB() 2250scan_eqopt2 = theano.gof.EquilibriumDB() 2251 2252# scan_eqopt1 before ShapeOpt at 0.1 2253# This is needed to don't have ShapeFeature trac old Scan that we 2254# don't want to reintroduce. 2255optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan') 2256# We run before blas opt at 1.7 and specialize 2.0 2257# but after stabilize at 1.5. Should we put it before stabilize? 2258optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan') 2259# ScanSaveMem should execute only once per node. 
# Registered at 1.61 so it runs right after scan_eqopt2 (1.6), exactly
# once per node (it is not part of an equilibrium database).
optdb.register('scanOp_save_mem', ScanSaveMem(), 1.61, 'fast_run', 'scan')
# In-place optimization runs very late (75), with the other inplace opts.
optdb.register('scanOp_make_inplace',
               ScanInplaceOptimizer(typeInfer=None),
               75,
               'fast_run',
               'inplace',
               'scan')

# scan_seqopt1 holds the push-out optimizations; it is run from within
# the scan_eqopt1 equilibrium.
scan_eqopt1.register(
    'all_pushout_opt', scan_seqopt1, 1, 'fast_run', 'scan')


# Clean up constants/unused inputs first (position 1), so the push-out
# optimizations below see a minimal scan node.
scan_seqopt1.register('scanOp_remove_constants_and_unused_inputs0',
                      opt.in2out(remove_constants_and_unused_inputs_scan,
                                 ignore_newtrees=True),
                      1,
                      'remove_constants_and_unused_inputs_scan',
                      'fast_run',
                      'scan')


scan_seqopt1.register('scanOp_pushout_nonseqs_ops',
                      PushOutNonSeqScan(),
                      2,
                      'fast_run',
                      'scan')


scan_seqopt1.register('scanOp_pushout_seqs_ops',
                      PushOutSeqScan(),
                      3,
                      'fast_run',
                      'scan')


# 'more_mem': trades memory for speed by moving the dot out of the loop.
scan_seqopt1.register('scan_pushout_dot1',
                      PushOutDot1(),
                      4,
                      'fast_run',
                      'more_mem',
                      'scan')


scan_seqopt1.register('scanOp_pushout_output',
                      PushOutScanOutput(),
                      5,
                      'fast_run',
                      'more_mem',
                      'scan')


# The remaining optimizations run in the scan_eqopt2 equilibrium; their
# position numbers fix the relative order within one pass (see the
# module docstring for why this order matters).
scan_eqopt2.register('constant_folding_for_scan2',
                     opt.in2out(tensor.opt.constant_folding,
                                ignore_newtrees=True),
                     1,
                     'fast_run',
                     'scan')


scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs1',
                     opt.in2out(remove_constants_and_unused_inputs_scan,
                                ignore_newtrees=True),
                     2,
                     'remove_constants_and_unused_inputs_scan',
                     'fast_run',
                     'scan')


# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2.register('scanOp_merge',
                     ScanMerge(),
                     4,
                     'fast_run',
                     'scan')

# After Merge optimization
scan_eqopt2.register('scanop_remove_constants_and_unused_inputs2',
                     opt.in2out(remove_constants_and_unused_inputs_scan,
                                ignore_newtrees=True),
                     5,
                     'remove_constants_and_unused_inputs_scan',
                     'fast_run',
                     'scan')

scan_eqopt2.register('scanOp_merge_inouts',
                     opt.in2out(scan_merge_inouts, ignore_newtrees=True),
                     6,
                     'scan_merge_inouts',
                     'fast_run',
                     'scan')

# After everything else
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
                     opt.in2out(remove_constants_and_unused_inputs_scan,
                                ignore_newtrees=True),
                     8,
                     'remove_constants_and_unused_inputs_scan',
                     'fast_run',
                     'scan')