1from __future__ import absolute_import, print_function, division 2""" Tensor optimizations addressing the ops in basic.py. 3""" 4# TODO: intelligent merge for mul/add 5# TODO: 0*x -> 0 6 7from collections import defaultdict 8import logging 9import itertools 10import operator 11import sys 12import time 13import traceback 14import warnings 15 16import numpy as np 17from six import integer_types, iteritems 18from six.moves import reduce, xrange 19 20import theano 21from theano import gof 22from theano.compat import izip 23from theano.gof import opt, InconsistencyError, TopoOptimizer, graph 24from theano.gof import Variable, Constant 25from theano.gof.opt import copy_stack_trace, in2out 26from theano.gof.utils import MethodNotDefined 27from theano.gradient import DisconnectedType 28from theano import config 29from theano.tensor.elemwise import Elemwise, DimShuffle 30from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, 31 Subtensor, IncSubtensor, make_constant, 32 AdvancedIncSubtensor1, 33 AdvancedIncSubtensor, 34 AdvancedSubtensor1, 35 advanced_subtensor, 36 advanced_subtensor1, 37 advanced_inc_subtensor1) 38from theano.tensor.sort import TopKOp 39from theano import scalar 40from theano.scalar import basic 41from theano.tensor import basic as T 42from theano import compile # to register the optimizer built by this file 43from theano.compile.ops import Shape, Shape_i 44from theano.tensor.type import (values_eq_approx_remove_inf, 45 values_eq_approx_remove_nan, 46 values_eq_approx_remove_inf_nan) 47 48from theano.gof.opt import (Optimizer, pre_constant_merge, 49 pre_greedy_local_optimizer) 50from theano.gof import toolbox 51from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError, 52 extract_constant, NotScalarConstantError, 53 Reshape) 54from six import StringIO 55 56_logger = logging.getLogger('theano.tensor.opt') 57 58# Utilities 59 60 61def _fill_chain(new_out, orig_inputs): 62 for i in orig_inputs: 63 new_out = T.fill(i, new_out) 64 return [new_out] 65 66 67def encompasses_broadcastable(b1, b2): 68 """ 69 70 Parameters 71 ---------- 72 b1 73 The broadcastable attribute of a tensor type. 74 b2 75 The broadcastable attribute of a tensor type. 76 77 Returns 78 ------- 79 bool 80 True if the broadcastable patterns b1 and b2 are such that b2 is 81 broadcasted to b1's shape and not the opposite. 82 83 """ 84 if len(b1) < len(b2): 85 return False 86 b1 = b1[-len(b2):] 87 return not any(v1 and not v2 for v1, v2 in zip(b1, b2)) 88 89 90def merge_broadcastables(broadcastables): 91 return [all(bcast) for bcast in zip(*broadcastables)] 92 93 94def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): 95 """Partition a list of variables into two kinds: 96 scalar constants, and the rest.""" 97 consts = [] 98 origconsts = [] 99 nonconsts = [] 100 for i in inputs: 101 try: 102 v = get_scalar_constant_value(i, elemwise=elemwise, 103 only_process_constants=only_process_constants) 104 consts.append(v) 105 origconsts.append(i) 106 except NotScalarConstantError: 107 nonconsts.append(i) 108 return consts, origconsts, nonconsts 109 110 111def broadcast_like(value, template, fgraph, dtype=None): 112 """ 113 Return a Variable with the same shape and dtype as the template, 114 filled by broadcasting value through it. `value` will be cast as 115 necessary. 
116 117 """ 118 value = T.as_tensor_variable(value) 119 if value.type == template.type: 120 return value 121 if template not in fgraph.variables: 122 raise NotImplementedError('broadcast_like currently requires the ' 123 'template Variable to be in the fgraph already') 124 if dtype is None: 125 dtype = template.dtype 126 value = T.cast(value, dtype) 127 if value.type == template.type: 128 return value 129 if hasattr(fgraph, 'shape_feature'): 130 new_shape = fgraph.shape_feature.shape_of[template] 131 else: 132 new_shape = template.shape 133 rval = T.alloc(value, *new_shape) 134 # the template may have 1s in its shape without being broadcastable 135 if rval.broadcastable != template.broadcastable: 136 rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) 137 if rval.broadcastable[i] and 138 not template.broadcastable[i]]) 139 assert rval.type.dtype == dtype 140 141 if rval.type.broadcastable != template.broadcastable: 142 raise AssertionError("rval.type.broadcastable is " + 143 str(rval.type.broadcastable) + 144 " but template.broadcastable is" + 145 str(template.broadcastable)) 146 147 return rval 148 149 150class InplaceElemwiseOptimizer(Optimizer): 151 """ 152 We parametrise it to make it work for Elemwise and GpuElemwise op. 153 """ 154 def __init__(self, OP): 155 self.op = OP 156 157 def add_requirements(self, fgraph): 158 fgraph.attach_feature(theano.gof.destroyhandler.DestroyHandler()) 159 160 @staticmethod 161 def print_profile(stream, prof, level=0): 162 blanc = (' ' * level) 163 print(blanc, "InplaceElemwiseOptimizer ", prof['opt'].op, file=stream) 164 for k in ['node_before', 165 'nb_call_replace', 166 'nb_call_validate', 167 'nb_inconsistent']: 168 print(blanc, k, prof[k], file=stream) 169 ndim = prof['ndim'] 170 if ndim: 171 print(blanc, "ndim", "nb", file=stream) 172 for n in sorted(ndim.keys()): 173 print(blanc, n, ndim[n], file=stream) 174 175 def apply(self, fgraph): 176 """ 177 Usage: InplaceElemwiseOptimizer(op).optimize(fgraph) 178 179 Attempts to replace all Broadcast ops by versions of them 180 that operate inplace. It operates greedily: for each Broadcast 181 Op that is encountered, for each output, tries each input to 182 see if it can operate inplace on that input. If so, makes the 183 change and go to the next output or Broadcast Op. 184 185 Examples 186 -------- 187 188 `x + y + z -> x += y += z` 189 190 `(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)` 191 192 """ 193 # We should not validate too often as this takes too much time to 194 # execute! 195 # It is the _dfs_toposort() fct in theano/gof/destroyhandler.py 196 # that takes so much time. 197 # Should we try to use another lib that does toposort? 198 # igraph: http://igraph.sourceforge.net/ 199 # networkx: https://networkx.lanl.gov/ 200 # Should we try to use cython? 201 # Compiling only that fct is not enough, should we try to add the 202 # deque class too? 203 # And init the deque and other list to an upper bound number of 204 # elements? 205 # Maybe Theano should do online toposort as in 206 # http://code.google.com/p/acyclic 207 # 208 # The next longest optimizer is the canonizer phase. 209 # Then I think it is the [io_?]toposort (need to validate) so check if 210 # the solution is also applicable there. 211 212 # We execute `validate` after this number of change. 
213 prof = {'opt': self, 214 'node_before': len(fgraph.apply_nodes), 215 'nb_call_replace': 0, 216 'nb_call_validate': 0, 217 'nb_inconsistent': 0, 218 'ndim': defaultdict(lambda: 0)} 219 220 check_each_change = config.tensor.insert_inplace_optimizer_validate_nb 221 if check_each_change == -1: 222 if len(fgraph.apply_nodes) > 500: 223 check_each_change = 10 224 else: 225 check_each_change = 1 226 227 nb_change_no_validate = 0 228 chk = fgraph.checkpoint() 229 230 if fgraph.update_mapping: 231 update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping] 232 else: 233 update_outs = [] 234 235 protected_inputs = [ 236 f.protected for f in fgraph._features if 237 isinstance(f, theano.compile.function_module.Supervisor)] 238 protected_inputs = sum(protected_inputs, []) # flatten the list 239 protected_inputs.extend(fgraph.outputs) 240 for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)): 241 op = node.op 242 # gpuarray GpuElemwise inherit from Elemwise 243 if not type(op) == self.op: 244 continue 245 # If big graph and the outputs are scalar, do not make it 246 # inplace. 247 if (check_each_change != 1 and 248 # If multiple outputs, they must all have the same size, 249 # so only check the first. 250 getattr(node.outputs[0].type, 'ndim', -1) == 0): 251 continue 252 253 if op.inplace_pattern: 254 # Maybe this isn't needed anymore, but I don't want to 255 # rish regression now. This case only happen if the 256 # original node add already some inplace patter and we 257 # still try to add more pattern. 258 259 baseline = op.inplace_pattern 260 candidate_outputs = [i for i in xrange(len(node.outputs)) 261 if i not in baseline] 262 # node inputs that are Constant, already destroyed, 263 # or fgraph protected inputs and fgraph outputs can't be used as 264 # inplace target. 265 # Remove here as faster. 266 candidate_inputs = [i for i in xrange(len(node.inputs)) 267 if i not in baseline.values() and 268 not isinstance(node.inputs[i], Constant) and 269 # the next line should not be costly most of the time. 270 not fgraph.has_destroyers([node.inputs[i]]) and 271 node.inputs[i] not in protected_inputs] 272 else: 273 baseline = [] 274 candidate_outputs = list(range(len(node.outputs))) 275 # node inputs that are Constant, already destroyed, 276 # fgraph protected inputs and fgraph outputs can't be used as inplace 277 # target. 278 # Remove here as faster. 279 candidate_inputs = [i for i in xrange(len(node.inputs)) 280 if not isinstance(node.inputs[i], Constant) and 281 not fgraph.has_destroyers([node.inputs[i]]) and 282 node.inputs[i] not in protected_inputs] 283 284 verbose = False 285 286 raised_warning = not verbose 287 288 for candidate_output in candidate_outputs: 289 290 # If the output of the node can be established as an update 291 # output of the fgraph, visit the candidate_inputs in an order 292 # that will improve the chances of making the node operate 293 # inplace on the input it's meant to update 294 candidate_out_var = node.outputs[candidate_output] 295 sorted_candidate_inputs = candidate_inputs 296 297 if candidate_out_var in update_outs: 298 299 # The candidate output is an update. 
Sort the 300 # variables in candidate_inputs in the following order: 301 # - Vars corresponding to the actual updated input 302 # (best case scenario is for the node that procudes 303 # an update to operate inplace on the variable to 304 # update) 305 # - Vars computed inplace on the updates input (second 306 # best scenario if for the node to work inplace on 307 # a variable obtained by a chain of inplace on the 308 # variable to update. In some cases, this will be 309 # equivalent to operating inplace on the variable to 310 # update) 311 # - Remaining variables 312 updated_inputs = [] 313 for i, f_out in enumerate(fgraph.outputs): 314 if (f_out is candidate_out_var and i in fgraph.update_mapping): 315 updated_inp_idx = fgraph.update_mapping[i] 316 updated_inputs.append(fgraph.inputs[updated_inp_idx]) 317 318 updated_vars = [] 319 vars_from_inplace = [] 320 other_vars = [] 321 for inp_idx in candidate_inputs: 322 inp = node.inputs[inp_idx] 323 if inp in updated_inputs: 324 # the candidate input is the actual updated input 325 updated_vars.append(inp_idx) 326 elif (hasattr(fgraph, 'destroy_handler') and 327 inp.owner and 328 any([fgraph.destroy_handler.root_destroyer.get(up_inp, None) is inp.owner 329 for up_inp in updated_inputs])): 330 331 # the candidate input is a variable computed 332 # inplace on the updated input via a sequence of 333 # one or more inplace operations 334 vars_from_inplace.append(inp_idx) 335 else: 336 other_vars.append(inp_idx) 337 338 sorted_candidate_inputs = (updated_vars + 339 vars_from_inplace + other_vars) 340 341 for candidate_input in sorted_candidate_inputs: 342 # remove inputs that don't have the same dtype as the output 343 if node.inputs[candidate_input].type != node.outputs[ 344 candidate_output].type: 345 continue 346 347 inplace_pattern = dict(baseline) 348 inplace_pattern[candidate_output] = candidate_input 349 try: 350 if hasattr(op.scalar_op, "make_new_inplace"): 351 new_scal = op.scalar_op.make_new_inplace( 352 scalar.transfer_type( 353 *[inplace_pattern.get(i, o.dtype) 354 for i, o in enumerate(node.outputs)])) 355 else: 356 new_scal = op.scalar_op.__class__( 357 scalar.transfer_type( 358 *[inplace_pattern.get(i, None) 359 for i in xrange(len(node.outputs))])) 360 new_outputs = self.op(new_scal, inplace_pattern)( 361 *node.inputs, **dict(return_list=True)) 362 new_node = new_outputs[0].owner 363 364 for r, new_r in zip(node.outputs, new_outputs): 365 prof['nb_call_replace'] += 1 366 fgraph.replace(r, new_r, 367 reason="inplace_elemwise_optimizer") 368 nb_change_no_validate += 1 369 prof['ndim'][candidate_out_var.ndim] += 1 370 if nb_change_no_validate >= check_each_change: 371 prof['nb_call_validate'] += 1 372 fgraph.validate() 373 chk = fgraph.checkpoint() 374 nb_change_no_validate = 0 375 except (ValueError, InconsistencyError) as e: 376 prof['nb_inconsistent'] += 1 377 if check_each_change != 1 and not raised_warning: 378 print(("Some inplace optimization was not " 379 "performed due to unexpected error:"), 380 file=sys.stderr) 381 print(e, file=sys.stderr) 382 raised_warning = True 383 fgraph.revert(chk) 384 continue 385 candidate_inputs.remove(candidate_input) 386 node = new_node 387 baseline = inplace_pattern 388 break 389 390 if nb_change_no_validate > 0: 391 try: 392 fgraph.validate() 393 except Exception: 394 if not raised_warning: 395 print(("Some inplace optimization was not " 396 "performed due to unexpected error"), 397 file=sys.stderr) 398 fgraph.revert(chk) 399 return prof 400 401 def print_summary(self, stream=sys.stdout, level=0, 
depth=-1): 402 print("%s%s (%s)" % ( 403 (' ' * level), self.__class__.__name__, self.op), file=stream) 404 return inplace_elemwise_optimizer 405 406inplace_elemwise_optimizer = InplaceElemwiseOptimizer(T.Elemwise) 407compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75, 408 'inplace_opt', # for historic reason 409 'inplace_elemwise_optimizer', 410 'fast_run', 'inplace') 411 412 413def register_useless(lopt, *tags, **kwargs): 414 if type(lopt) == str: 415 def register(inner_lopt): 416 return register_useless(inner_lopt, lopt, *tags, **kwargs) 417 return register 418 else: 419 name = kwargs.pop('name', None) or lopt.__name__ 420 421 compile.mode.local_useless.register(name, lopt, 'last', 'fast_run', 422 *tags, **kwargs) 423 return lopt 424 425 426def register_canonicalize(lopt, *tags, **kwargs): 427 if type(lopt) == str: 428 def register(inner_lopt): 429 return register_canonicalize(inner_lopt, lopt, *tags, **kwargs) 430 return register 431 else: 432 name = kwargs.pop('name', None) or lopt.__name__ 433 compile.optdb['canonicalize'].register(name, lopt, 'fast_run', 434 *tags, **kwargs) 435 return lopt 436 437 438def register_stabilize(lopt, *tags, **kwargs): 439 if type(lopt) == str: 440 def register(inner_lopt): 441 return register_stabilize(inner_lopt, lopt, *tags, **kwargs) 442 return register 443 else: 444 name = kwargs.pop('name', None) or lopt.__name__ 445 compile.optdb['stabilize'].register(name, lopt, 'fast_run', 446 *tags, **kwargs) 447 return lopt 448 449 450def register_specialize(lopt, *tags, **kwargs): 451 if type(lopt) == str: 452 def register(inner_lopt): 453 return register_specialize(inner_lopt, lopt, *tags, **kwargs) 454 return register 455 else: 456 name = kwargs.pop('name', None) or lopt.__name__ 457 compile.optdb['specialize'].register(name, lopt, 'fast_run', 458 *tags, **kwargs) 459 return lopt 460 461 462def register_uncanonicalize(lopt, *tags, **kwargs): 463 if type(lopt) == str: 464 def register(inner_lopt): 465 return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs) 466 return register 467 else: 468 name = (kwargs and kwargs.pop('name', None)) or lopt.__name__ 469 compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags, 470 **kwargs) 471 return lopt 472 473 474def register_specialize_device(lopt, *tags, **kwargs): 475 if type(lopt) == str: 476 def register(inner_lopt): 477 return register_specialize_device(inner_lopt, lopt, *tags, **kwargs) 478 return register 479 else: 480 name = (kwargs and kwargs.pop('name', None)) or lopt.__name__ 481 compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags, 482 **kwargs) 483 return lopt 484 485 486##################### 487# Dot optimizations # 488##################### 489 490@register_canonicalize 491@register_stabilize 492@gof.local_optimizer([T.Dot]) 493def local_0_dot_x(node): 494 if not isinstance(node.op, T.Dot): 495 return False 496 497 x = node.inputs[0] 498 y = node.inputs[1] 499 replace = False 500 try: 501 if get_scalar_constant_value(x, only_process_constants=True) == 0: 502 replace = True 503 except NotScalarConstantError: 504 pass 505 506 try: 507 if get_scalar_constant_value(y, only_process_constants=True) == 0: 508 replace = True 509 except NotScalarConstantError: 510 pass 511 512 if replace: 513 constant_zero = T.constant(0, dtype=node.outputs[0].type.dtype) 514 if x.ndim == 2 and y.ndim == 2: 515 constant_zero = assert_(constant_zero, 516 T.eq(x.shape[1], y.shape[0])) 517 return [T.alloc(constant_zero, x.shape[0], y.shape[1])] 518 elif x.ndim == 
1 and y.ndim == 2: 519 constant_zero = assert_(constant_zero, 520 T.eq(x.shape[0], y.shape[0])) 521 return [T.alloc(constant_zero, y.shape[1])] 522 elif x.ndim == 2 and y.ndim == 1: 523 constant_zero = assert_(constant_zero, 524 T.eq(x.shape[1], y.shape[0])) 525 return [T.alloc(constant_zero, x.shape[0])] 526 elif x.ndim == 1 and y.ndim == 1: 527 constant_zero = assert_(constant_zero, 528 T.eq(x.shape[0], y.shape[0])) 529 return [constant_zero] 530 else: 531 _logger.warning("Optimization Warning: " 532 "Optimization theano/opt.py:local_0_dot_x Found " 533 "that it could apply, but was not implemented " 534 "for dot product with these input types:\n" 535 "(%s, %s)", 536 x.type, y.type) 537 538###################### 539# DimShuffle lifters # 540###################### 541 542 543def apply_local_dimshuffle_lift(var): 544 # return var 545 # lift recursively 546 if not var.owner: 547 return var 548 new = local_dimshuffle_lift.transform(var.owner) 549 if new: 550 return new[0] 551 return var 552 553 554# Checks for two types of useless dimshuffles: 555# 1 - dimshuffle all dimensions in order. 556# 2 - dimshuffle a broadcastable dimension. 557def is_dimshuffle_useless(new_order, input): 558 is_useless = True 559 if len(new_order) == input.type.ndim: 560 all_broadcastable_dims = [i for (i, is_broadcastable) 561 in enumerate(input.type.broadcastable) 562 if is_broadcastable] + ['x'] 563 for i in range(input.type.ndim): 564 if (new_order[i] == i or 565 (i in all_broadcastable_dims and 566 new_order[i] in all_broadcastable_dims)): 567 is_useless = True 568 else: 569 is_useless = False 570 break 571 else: 572 is_useless = False 573 return is_useless 574 575 576@gof.local_optimizer([DimShuffle]) 577def local_dimshuffle_lift(node): 578 """ 579 "Lifts" DimShuffle through Elemwise operations and merges 580 consecutive DimShuffles. Basically, applies the following 581 transformations on the whole graph: 582 583 DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y)) 584 DimShuffle(DimShuffle(x)) => DimShuffle(x) 585 DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing) 586 587 After this transform, clusters of Elemwise operations are 588 void of DimShuffle operations. 589 590 """ 591 op = node.op 592 if not isinstance(op, DimShuffle): 593 return False 594 595 input = node.inputs[0] 596 inode = input.owner 597 new_order = op.new_order 598 if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): 599 # Don't use make_node to have tag.test_value set. 
600 new_inputs = [] 601 for inp in inode.inputs: 602 new_inp = op.__class__(inp.type.broadcastable, 603 op.new_order)(inp) 604 new_inputs.append(apply_local_dimshuffle_lift(new_inp)) 605 copy_stack_trace(node.outputs[0], new_inputs) 606 ret = inode.op(*new_inputs, **dict(return_list=True)) 607 return ret 608 if inode and isinstance(inode.op, DimShuffle): 609 new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in 610 new_order] 611 input = inode.inputs[0] 612 613 if is_dimshuffle_useless(new_order, input): 614 return [input] 615 elif inode and isinstance(inode.op, DimShuffle): 616 ret = op.__class__(input.type.broadcastable, new_order)(input) 617 ret = apply_local_dimshuffle_lift(ret) 618 copy_stack_trace(node.outputs[0], ret) 619 return [ret] 620 621 622@register_canonicalize 623@gof.local_optimizer([Reshape]) 624def local_useless_dimshuffle_in_reshape(node): 625 """ 626 Removes useless DimShuffle operation inside Reshape: 627 628 reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) 629 reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) 630 reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) 631 reshape(col.dimshuffle(0), shp) => reshape(col, shp) 632 633 """ 634 op = node.op 635 if not isinstance(op, Reshape): 636 return False 637 if not (node.inputs[0].owner is not None and 638 isinstance(node.inputs[0].owner.op, DimShuffle)): 639 return False 640 641 new_order = node.inputs[0].owner.op.new_order 642 input = node.inputs[0].owner.inputs[0] 643 broadcastables = node.inputs[0].broadcastable 644 new_order_of_nonbroadcast = [] 645 for i, bd in zip(new_order, broadcastables): 646 if not bd: 647 new_order_of_nonbroadcast.append(i) 648 no_change_in_order = all( 649 new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] 650 for i in xrange(len(new_order_of_nonbroadcast) - 1)) 651 if no_change_in_order: 652 shape = node.inputs[1] 653 ret = op.__class__(node.outputs[0].ndim)(input, shape) 654 copy_stack_trace(node.outputs[0], ret) 655 return [ret] 656 657 658@register_canonicalize 659@gof.local_optimizer([DimShuffle]) 660def local_lift_transpose_through_dot(node): 661 """ 662 dot(x,y).T -> dot(y.T, x.T) 663 664 These optimizations "lift" (propagate towards the inputs) DimShuffle 665 through dot product. It allows to put the graph in a more standard shape, 666 and to later merge consecutive DimShuffles. 667 668 The transformation should be apply whether or not the transpose is 669 inplace. The newly-introduced transpositions are not inplace, this will 670 be taken care of in a later optimization phase. 
671 672 """ 673 if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)): 674 return False 675 if not (node.inputs[0].owner and 676 isinstance(node.inputs[0].owner.op, T.Dot)): 677 return False 678 x, y = node.inputs[0].owner.inputs 679 680 if x.ndim == y.ndim == 2: 681 # Output is dot product of transposed inputs in reverse order 682 ret = [T.dot(y.T, x.T)] 683 684 # Copy over stack trace to output from result of dot-product 685 copy_stack_trace(node.inputs[0], ret) 686 return ret 687 688register_canonicalize(local_dimshuffle_lift) 689register_specialize(local_dimshuffle_lift) 690 691###################### 692# Casting operations # 693###################### 694 695 696@register_canonicalize 697@register_specialize 698@gof.local_optimizer([T.TensorFromScalar]) 699def local_tensor_scalar_tensor(node): 700 '''tensor_from_scalar(scalar_from_tensor(x)) -> x''' 701 if isinstance(node.op, T.TensorFromScalar): 702 s = node.inputs[0] 703 if s.owner and isinstance(s.owner.op, T.ScalarFromTensor): 704 t = s.owner.inputs[0] 705 706 # We don't need to copy over any stack traces here 707 return [t] 708 709 710@register_canonicalize 711@register_specialize 712@gof.local_optimizer([T.ScalarFromTensor]) 713def local_scalar_tensor_scalar(node): 714 '''scalar_from_tensor(tensor_from_scalar(x)) -> x''' 715 if isinstance(node.op, T.ScalarFromTensor): 716 t = node.inputs[0] 717 if t.owner and isinstance(t.owner.op, T.TensorFromScalar): 718 s = t.owner.inputs[0] 719 720 # We don't need to copy over any stack traces here 721 return [s] 722 723##################################### 724# ShapeFeature, Shape optimizations 725##################################### 726 727 728class MakeVector(T.Op): 729 """Concatenate a number of scalars together into a vector. 730 731 This is a simple version of stack() that introduces far less cruft 732 into the graph. Should work with 0 inputs. The constant_folding 733 optimization will remove it. 734 735 """ 736 737 __props__ = ("dtype",) 738 739 def __init__(self, dtype='int64'): 740 self.dtype = dtype 741 742 def make_node(self, *inputs): 743 inputs = list(map(T.as_tensor_variable, inputs)) 744 if (not all(a.type == inputs[0].type for a in inputs) or 745 (len(inputs) > 0 and inputs[0].dtype != self.dtype)): 746 dtype = theano.scalar.upcast(self.dtype, *[i.dtype for i in inputs]) 747 # upcast the input to the determined dtype, 748 # but don't downcast anything 749 assert dtype == self.dtype, ( 750 "The upcast of the inputs to MakeVector should match the " 751 "dtype given in __init__.") 752 if not all(self.dtype == T.cast(i, dtype=dtype).dtype 753 for i in inputs): 754 raise TypeError("MakeVector.make_node expected inputs" 755 " upcastable to %s. got %s" % 756 (self.dtype, str([i.dtype for i in inputs]))) 757 inputs = [T.cast(i, dtype=dtype) for i in inputs] 758 assert all(self.dtype == a.dtype for a in inputs) 759 assert all(a.ndim == 0 for a in inputs) 760 761 if inputs: 762 dtype = inputs[0].type.dtype 763 else: 764 dtype = self.dtype 765 # bcastable = (len(inputs) == 1) 766 bcastable = False 767 otype = T.TensorType(broadcastable=(bcastable,), dtype=dtype) 768 return T.Apply(self, inputs, [otype()]) 769 770 def perform(self, node, inputs, out_): 771 out, = out_ 772 # not calling theano._asarray as optimization 773 if (out[0] is None) or (out[0].size != len(inputs)): 774 out[0] = theano._asarray(inputs, dtype=node.outputs[0].dtype) 775 else: 776 # assume that out has correct dtype. there is no cheap way to check 777 out[0][...] 
= inputs 778 779 def c_code_cache_version(self): 780 return (2,) 781 782 def c_code(self, node, name, inp, out_, sub): 783 out, = out_ 784 # Shouldn't use PyArray_TYPE(inp[0]) for the dtype 785 # when len(inp) == 0 (we need to support this case. 786 # So there will be (1 * nb_dtype) + ((nb len(inp) - 1 )) 787 # different c code with the following algo 788 out_shape = len(inp) 789 out_num = np.dtype(node.outputs[0].dtype).num 790 # don't use dtype_%(out)s as when check_input=False, it isn't defined. 791 out_dtype = node.outputs[0].type.dtype_specs()[1] 792 if len(inp) > 0: 793 assert self.dtype == node.inputs[0].dtype 794 out_num = 'PyArray_TYPE(%s)' % inp[0] 795 796 ret = """ 797 npy_intp dims[1]; 798 dims[0] = %(out_shape)s; 799 if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){ 800 Py_XDECREF(%(out)s); 801 %(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0); 802 } 803 """ % locals() 804 for idx, i in enumerate(inp): 805 ret += """ 806 *((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s)); 807 """ % locals() 808 return ret 809 810 def infer_shape(self, node, ishapes): 811 return [(len(ishapes),)] 812 813 def grad(self, inputs, output_gradients): 814 # If the output is of an integer dtype, no gradient shall pass 815 if self.dtype in theano.tensor.discrete_dtypes: 816 return [ipt.zeros_like().astype(theano.config.floatX) 817 for ipt in inputs] 818 819 grads = [] 820 for i, inp in enumerate(inputs): 821 grads.append(output_gradients[0][i]) 822 return grads 823 824 def R_op(self, inputs, eval_points): 825 if None in eval_points: 826 return [None] 827 return self.make_node(*eval_points).outputs 828 829make_vector = MakeVector() 830 831 832class MakeVectorPrinter: 833 def process(self, r, pstate): 834 if r.owner is None: 835 raise TypeError("Can only print make_vector.") 836 elif isinstance(r.owner.op, MakeVector): 837 old_precedence = getattr(pstate, 'precedence', None) 838 try: 839 pstate.precedence = 1000 840 s = [pstate.pprinter.process(input) 841 for input in r.owner.inputs] 842 finally: 843 pstate.precedence = old_precedence 844 return "[%s]" % ", ".join(s) 845 else: 846 raise TypeError("Can only print make_vector.") 847 848T.pprint.assign(MakeVector, MakeVectorPrinter()) 849 850 851class ShapeFeature(object): 852 """Graph optimizer for removing all calls to shape(). 853 854 This optimizer replaces all Shapes and Subtensors of Shapes with 855 Shape_i and MakeVector Ops. 856 857 This optimizer has several goals: 858 859 1. to 'lift' Shapes to as close to the inputs as possible. 860 861 2. to infer the shape of every node in the graph in terms of the 862 input shapes. 863 864 3. remove all fills (T.second, T.fill) from the graph 865 866 Lifting shapes as close to the inputs as possible is important for 867 canonicalization because it is very bad form to have to compute 868 something just to know how big it will be. Firstly, it is a waste 869 of time to compute such outputs. But it is important to get rid 870 of these outputs as early as possible in the compilation process 871 because the extra computations make it appear as if many internal 872 graph nodes have multiple clients. Many optimizations refuse to 873 work on nodes with multiple clients. 874 875 Lifting is done by using an `<Op>.infer_shape` function if one is 876 present, or else using a conservative default. An Op that 877 supports shape-lifting should define a infer_shape(self, node, 878 input_shapes) function. The argument input_shapes is a tuple of 879 tuples... 
there is an interior tuple for each input to the node. 880 The tuple has as many elements as dimensions. The element in 881 position i of tuple j represents the i'th shape component of the 882 j'th input. The function should return a tuple of tuples. One 883 output tuple for each node.output. Again, the i'th element of the 884 j'th output tuple represents the output[j].shape[i] of the 885 function. If an output is not a TensorType, then None should be 886 returned instead of a tuple for that output. 887 888 For example the infer_shape for a matrix-matrix product would accept 889 input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). 890 891 Inferring the shape of internal nodes in the graph is important 892 for doing size-driven optimizations. If we know how big various 893 intermediate results will be, we can estimate the cost of many Ops 894 accurately, and generate c-code that is specific [e.g. unrolled] 895 to particular sizes. 896 897 In cases where you cannot figure out the shape, raise a ShapeError. 898 899 Notes 900 ----- 901 Right now there is only the ConvOp that could really take 902 advantage of this shape inference, but it is worth it even 903 just for the ConvOp. All that's necessary to do shape 904 inference is 1) to mark shared inputs as having a particular 905 shape, either via a .tag or some similar hacking; and 2) to 906 add an optional In() argument to promise that inputs will 907 have a certain shape (or even to have certain shapes in 908 certain dimensions). We can't automatically infer the shape of 909 shared variables as they can change of shape during the 910 execution by default. (NOT IMPLEMENTED YET, BUT IS IN TRAC) 911 912 913 **Using Shape information in Optimizations** 914 915 To use this shape information in OPTIMIZATIONS, use the 916 ``shape_of`` dictionary. 917 918 For example: 919 920 .. code-block:: python 921 922 try: 923 shape_of = node.fgraph.shape_feature.shape_of 924 except AttributeError: 925 # This can happen when the mode doesn't include the ShapeFeature. 926 return 927 928 shape_of_output_zero = shape_of[node.output[0]] 929 930 The ``shape_of_output_zero`` symbol will contain a tuple, whose 931 elements are either integers or symbolic integers. 932 933 TODO: check to see if the symbols are necessarily 934 non-constant... or are integer literals sometimes Theano 935 constants?? That would be confusing. 936 937 """ 938 def get_node_infer_shape(self, node): 939 try: 940 shape_infer = node.op.infer_shape 941 except AttributeError: 942 shape_infer = self.default_infer_shape 943 944 try: 945 o_shapes = shape_infer(node, 946 [self.shape_of[r] for r in node.inputs]) 947 except ShapeError: 948 o_shapes = self.default_infer_shape(node, [self.shape_of[r] for 949 r in node.inputs]) 950 except NotImplementedError as e: 951 raise NotImplementedError( 952 'Code called by infer_shape failed raising a ' 953 'NotImplementedError. Raising NotImplementedError to ' 954 'indicate that a shape cannot be computed is no longer ' 955 'supported, and one should now use tensor.ShapeError ' 956 'instead. 
The original exception message is: %s' % e) 957 except Exception as e: 958 msg = ('Failed to infer_shape from Op %s.\nInput shapes: ' 959 '%s\nException encountered during infer_shape: ' 960 '%s\nException message: %s\nTraceback: %s') % ( 961 node.op, [self.shape_of[r] for r in node.inputs], 962 type(e), str(e), traceback.format_exc()) 963 if config.on_shape_error == "raise": 964 raise Exception(msg) 965 else: 966 _logger.warning(msg) 967 o_shapes = self.default_infer_shape( 968 node, [self.shape_of[r] for r in node.inputs]) 969 970 return o_shapes 971 972 def get_shape(self, var, idx): 973 """ Optimization can call this to get the current shape_i 974 975 It is better to call this then use directly shape_of[var][idx] 976 as this method should update shape_of if needed. 977 978 TODO: Up to now, we don't update it in all cases. Update in all cases. 979 """ 980 r = self.shape_of[var][idx] 981 if (r.owner and 982 isinstance(r.owner.op, Shape_i) and 983 r.owner.inputs[0] not in var.fgraph.variables): 984 assert var.owner 985 node = var.owner 986 # recur on inputs 987 for i in node.inputs: 988 if getattr(i, 'ndim', None) > 0: 989 self.get_shape(i, 0) 990 o_shapes = self.get_node_infer_shape(node) 991 assert len(o_shapes) == len(node.outputs) 992 993 # Only change the variables and dimensions that would introduce 994 # extra computation 995 for new_shps, out in zip(o_shapes, node.outputs): 996 if not hasattr(out, 'ndim'): 997 continue 998 999 merged_shps = list(self.shape_of[out]) 1000 changed = False 1001 for i in range(out.ndim): 1002 n_r = merged_shps[i] 1003 if (n_r.owner and 1004 isinstance(n_r.owner.op, Shape_i) and 1005 n_r.owner.inputs[0] not in var.fgraph.variables): 1006 changed = True 1007 merged_shps[i] = new_shps[i] 1008 if changed: 1009 self.set_shape(out, merged_shps, override=True) 1010 r = self.shape_of[var][idx] 1011 return r 1012 1013 def shape_ir(self, i, r): 1014 """Return symbolic r.shape[i] for tensor variable r, int i.""" 1015 if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: 1016 return self.lscalar_one 1017 else: 1018 # Do not call make_node for test_value 1019 s = Shape_i(i)(r) 1020 try: 1021 s = get_scalar_constant_value(s) 1022 except NotScalarConstantError: 1023 pass 1024 return s 1025 1026 def shape_tuple(self, r): 1027 """Return a tuple of symbolic shape vars for tensor variable r.""" 1028 if not hasattr(r, 'ndim'): 1029 # This happen for NoneConst. 1030 return None 1031 return tuple([self.shape_ir(i, r) for i in xrange(r.ndim)]) 1032 1033 def default_infer_shape(self, node, i_shapes): 1034 """Return a list of shape tuple or None for the outputs of node. 1035 1036 This function is used for Ops that don't implement infer_shape. 1037 Ops that do implement infer_shape should use the i_shapes parameter, 1038 but this default implementation ignores it. 1039 1040 """ 1041 rval = [] 1042 for r in node.outputs: 1043 try: 1044 rval.append(self.shape_tuple(r)) 1045 except AttributeError: 1046 rval.append(None) 1047 return rval 1048 1049 def unpack(self, s_i, var): 1050 """Return a symbolic integer scalar for the shape element s_i. 1051 1052 The s_i argument was produced by the infer_shape() of an Op subclass. 1053 1054 var: the variable that correspond to s_i. This is just for 1055 error reporting. 
1056 1057 """ 1058 # unpack the s_i that the Op returned 1059 assert s_i is not None 1060 if s_i == 1: 1061 # don't make the optimizer merge a zillion ones together 1062 # by always returning the same object to represent 1 1063 return self.lscalar_one 1064 if type(s_i) is float and int(s_i) == s_i: 1065 s_i = int(s_i) 1066 if (type(s_i) in integer_types or 1067 isinstance(s_i, np.integer) or 1068 (isinstance(s_i, np.ndarray) and s_i.ndim == 0)): 1069 # this shape is a constant 1070 if s_i < 0: 1071 msg = "There is a negative shape in the graph!" 1072 msg += gof.utils.get_variable_trace_string(var) 1073 # The rest of the pipeline don't handle correctly this 1074 # case. So we have 2 choices, stop compilation or 1075 # consider the shape as unknow. As we have more 1076 # chance to give the stack trace here then later, I 1077 # choose that options as it would give better error 1078 # message. 1079 raise AssertionError(msg) 1080 return T.constant(s_i, dtype='int64') 1081 if type(s_i) in (tuple, list): 1082 # this dimension is the same as many of the inputs 1083 # which tells us that if one of the inputs is known, 1084 # the others all become known. 1085 # TODO: should be implemented in Elemwise, and Dot 1086 # 1087 # worst case, we loop over shape_of and replace things 1088 raise NotImplementedError(s_i) 1089 1090 # s_i is x.shape[i] for some x, we change it to shape_of[x][i] 1091 if (s_i.owner and 1092 isinstance(s_i.owner.op, Subtensor) and 1093 s_i.owner.inputs[0].owner and 1094 isinstance(s_i.owner.inputs[0].owner.op, T.Shape)): 1095 assert s_i.ndim == 0 1096 assert len(s_i.owner.op.idx_list) == 1 1097 1098 # The current Subtensor always put constant index in the graph. 1099 # This was not True in the past. So call the Subtensor function 1100 # that will return the right index. 1101 idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) 1102 assert len(idx) == 1 1103 idx = idx[0] 1104 try: 1105 i = get_scalar_constant_value(idx) 1106 except NotScalarConstantError: 1107 pass 1108 else: 1109 # Executed only if no exception was raised 1110 x = s_i.owner.inputs[0].owner.inputs[0] 1111 # x should already have been imported, and should be in shape_of. 1112 s_i = self.shape_of[x][i] 1113 1114 if s_i.type.dtype in theano.tensor.integer_dtypes: 1115 if getattr(s_i.type, 'ndim', 0): 1116 raise TypeError('Shape element must be scalar', s_i) 1117 return s_i 1118 else: 1119 raise TypeError('Unsupported shape element', 1120 s_i, type(s_i), getattr(s_i, 'type', None)) 1121 1122 def set_shape(self, r, s, override=False): 1123 """Assign the shape `s` to previously un-shaped variable `r`. 1124 1125 Parameters 1126 ---------- 1127 r : a variable 1128 s : None or a tuple of symbolic integers 1129 override : If False, it mean r is a new object in the fgraph. 1130 If True, it mean r is already in the fgraph and we want to 1131 override its shape. 
1132 1133 """ 1134 if not override: 1135 assert r not in self.shape_of, 'r already in shape_of' 1136 if s is None: 1137 self.shape_of[r] = s 1138 else: 1139 if not isinstance(s, (tuple, list)): 1140 raise TypeError('shapes must be tuple/list', (r, s)) 1141 1142 if r.ndim != len(s): 1143 sio = StringIO() 1144 theano.printing.debugprint(r, file=sio, print_type=True) 1145 raise AssertionError( 1146 "Something inferred a shape with %d dimensions " 1147 "for a variable with %d dimensions" 1148 " for the variable:\n%s" % ( 1149 len(s), r.ndim, sio.getvalue())) 1150 1151 shape_vars = [] 1152 for i in xrange(r.ndim): 1153 if (hasattr(r.type, 'broadcastable') and 1154 r.type.broadcastable[i]): 1155 shape_vars.append(self.lscalar_one) 1156 else: 1157 shape_vars.append(self.unpack(s[i], r)) 1158 assert all([not hasattr(r.type, "broadcastable") or 1159 not r.type.broadcastable[i] or 1160 # The two following comparison are a speed optimization 1161 # But we never timed this speed optimization! 1162 self.lscalar_one.equals(shape_vars[i]) or 1163 self.lscalar_one.equals( 1164 T.extract_constant(shape_vars[i])) 1165 for i in xrange(r.ndim)]) 1166 self.shape_of[r] = tuple(shape_vars) 1167 for sv in shape_vars: 1168 self.shape_of_reverse_index.setdefault(sv, set()).add(r) 1169 1170 def update_shape(self, r, other_r): 1171 """Replace shape of r by shape of other_r. 1172 1173 If, on some dimensions, the shape of other_r is not informative, 1174 keep the shape of r on those dimensions. 1175 1176 """ 1177 # other_r should already have a shape 1178 assert other_r in self.shape_of, ('other_r not in shape_of', other_r) 1179 other_shape = self.shape_of[other_r] 1180 1181 # If other_shape has no information, call is pointless. 1182 if other_shape is None: 1183 return 1184 1185 if r in self.shape_of: 1186 r_shape = self.shape_of[r] 1187 else: 1188 # If no info is known on r's shape, use other_shape 1189 self.set_shape(r, other_shape) 1190 return 1191 if (other_r.owner and r.owner and 1192 other_r.owner.inputs == r.owner.inputs and 1193 other_r.owner.op == r.owner.op): 1194 # We are doing a merge. So the 2 shapes graph will be the 1195 # same. This is only a speed optimization to call 1196 # ancestors() less frequently. 1197 return 1198 1199 # Merge other_shape with r_shape, giving the priority to other_shape 1200 merged_shape = [] 1201 for i, ps in enumerate(other_shape): 1202 if r_shape is None and other_shape: 1203 merged_shape.append(other_shape[i]) 1204 elif (ps.owner and 1205 isinstance(getattr(ps.owner, 'op', None), Shape_i) and 1206 ps.owner.op.i == i and 1207 ps.owner.inputs[0] in (r, other_r)): 1208 # If other_shape[i] is uninformative, use r_shape[i]. 1209 # For now, we consider 2 cases of uninformative other_shape[i]: 1210 # - Shape_i(i)(other_r); 1211 # - Shape_i(i)(r). 1212 merged_shape.append(r_shape[i]) 1213 elif isinstance(r_shape[i], (Constant, integer_types)): 1214 # We do this to call less often ancestors and make 1215 # sure we have the simplest shape possible. 1216 merged_shape.append(r_shape[i]) 1217 elif isinstance(other_shape[i], (Constant, integer_types)): 1218 # We do this to call less often ancestors and make 1219 # sure we have the simplest shape possible. 
1220 merged_shape.append(other_shape[i]) 1221 elif other_shape[i] == r_shape[i]: 1222 # This mean the shape is equivalent 1223 # We do not want to do the ancestor check in those cases 1224 merged_shape.append(r_shape[i]) 1225 elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]): 1226 # Another case where we want to use r_shape[i] is when 1227 # other_shape[i] actually depends on r_shape[i]. In that case, 1228 # we do not want to substitute an expression with another that 1229 # is strictly more complex. Such a substitution could also lead 1230 # to cycles: if (in the future) r_shape[i] gets replaced by an 1231 # expression of other_shape[i], other_shape[i] may end up 1232 # depending on itself. 1233 merged_shape.append(r_shape[i]) 1234 else: 1235 merged_shape.append(other_shape[i]) 1236 assert all([(not hasattr(r.type, "broadcastable") or 1237 not r.type.broadcastable[i] and 1238 not other_r.type.broadcastable[i]) or 1239 # The two following comparison are a speed optimization 1240 # But we never timed this speed optimization! 1241 self.lscalar_one.equals(merged_shape[i]) or 1242 self.lscalar_one.equals( 1243 T.extract_constant(merged_shape[i], only_process_constants=True)) 1244 for i in xrange(r.ndim)]) 1245 self.shape_of[r] = tuple(merged_shape) 1246 for sv in self.shape_of[r]: 1247 self.shape_of_reverse_index.setdefault(sv, set()).add(r) 1248 1249 def set_shape_i(self, r, i, s_i): 1250 '''Replace element i of shape_of[r] by s_i''' 1251 assert r in self.shape_of 1252 prev_shape = self.shape_of[r] 1253 # prev_shape is a tuple, so we cannot change it inplace, 1254 # so we build another one. 1255 new_shape = [] 1256 for j, s_j in enumerate(prev_shape): 1257 if j == i: 1258 new_shape.append(self.unpack(s_i, r)) 1259 else: 1260 new_shape.append(s_j) 1261 assert all([not hasattr(r.type, "broadcastable") or 1262 not r.type.broadcastable[idx] or 1263 # The two following comparison are a speed optimization 1264 # But we never timed this speed optimization! 1265 self.lscalar_one.equals(new_shape[idx]) or 1266 self.lscalar_one.equals(T.extract_constant(new_shape[idx])) 1267 for idx in xrange(r.ndim)]) 1268 self.shape_of[r] = tuple(new_shape) 1269 for sv in self.shape_of[r]: 1270 self.shape_of_reverse_index.setdefault(sv, set()).add(r) 1271 1272 def init_r(self, r): 1273 '''Register r's shape in the shape_of dictionary.''' 1274 if r not in self.shape_of: 1275 try: 1276 self.set_shape(r, self.shape_tuple(r)) 1277 except AttributeError: # XXX: where would this come from? 1278 self.set_shape(r, None) 1279 1280 def make_vector_shape(self, r): 1281 return make_vector(*self.shape_of[r]) 1282 1283 # 1284 # Feature interface 1285 # 1286 # 1287 def on_attach(self, fgraph): 1288 assert not hasattr(fgraph, 'shape_feature') 1289 fgraph.shape_feature = self 1290 # Must be local to the object as otherwise we reuse the same 1291 # variable for multiple fgraph! 
1292 self.lscalar_one = T.constant(1, dtype='int64') 1293 assert self.lscalar_one.type == T.lscalar 1294 1295 self.shape_of = {} 1296 # Variable -> tuple(scalars) or None (All tensor vars map to tuple) 1297 1298 self.scheduled = {} 1299 # Variable -> 1300 1301 self.shape_of_reverse_index = {} 1302 # shape var -> graph v 1303 1304 for node in fgraph.toposort(): 1305 self.on_import(fgraph, node, reason='on_attach') 1306 1307 def on_detach(self, fgraph): 1308 self.shape_of = {} 1309 self.scheduled = {} 1310 self.shape_of_reverse_index = {} 1311 del fgraph.shape_feature 1312 1313 def on_import(self, fgraph, node, reason): 1314 if node.outputs[0] in self.shape_of: 1315 # this is a revert, not really an import 1316 for r in node.outputs + node.inputs: 1317 assert r in self.shape_of 1318 return 1319 1320 for i, r in enumerate(node.inputs): 1321 # make sure we have shapes for the inputs 1322 self.init_r(r) 1323 1324 o_shapes = self.get_node_infer_shape(node) 1325 1326 # this is packed information 1327 # an element of o_shapes is either None or a tuple 1328 # elements of the tuple can be either strings, or ints 1329 if len(o_shapes) != len(node.outputs): 1330 raise Exception( 1331 ('The infer_shape method for the Op "%s" returned a list ' + 1332 'with the wrong number of element: len(o_shapes) = %d ' + 1333 ' != len(node.outputs) = %d') % (str(node.op), 1334 len(o_shapes), 1335 len(node.outputs))) 1336 1337 # Ensure shapes are in 'int64'. This is to make sure the assert 1338 # found in the `local_useless_subtensor` optimization does not fail. 1339 for sh_idx, sh in enumerate(o_shapes): 1340 if sh is None: 1341 continue 1342 if not isinstance(sh, (list, tuple)): 1343 raise ValueError("infer_shape of %s didn't return a list of" 1344 " list. It returned '%s'" % (str(node), str(o_shapes))) 1345 new_shape = [] 1346 for i, d in enumerate(sh): 1347 # Note: we ignore any shape element that is not typed (i.e., 1348 # does not have a 'dtype' attribute). This means there may 1349 # still remain int elements that are int32 on 32-bit platforms, 1350 # but this works with `local_useless_subtensor`, so for now we 1351 # keep it this way. See #266 for a better long-term fix. 1352 if getattr(d, 'dtype', 'int64') != 'int64': 1353 assert d.dtype in theano.tensor.discrete_dtypes, (node, d.dtype) 1354 assert str(d.dtype) != 'uint64', node 1355 new_shape += sh[len(new_shape):i + 1] 1356 if isinstance(d, T.Constant): 1357 casted_d = T.constant(d.data, dtype='int64') 1358 else: 1359 casted_d = theano.tensor.cast(d, 'int64') 1360 new_shape[i] = casted_d 1361 if new_shape: 1362 # We replace the shape with wrong dtype by the one with 1363 # 'int64'. 1364 new_shape += sh[len(new_shape):] 1365 o_shapes[sh_idx] = tuple(new_shape) 1366 1367 for r, s in izip(node.outputs, o_shapes): 1368 self.set_shape(r, s) 1369 1370 def on_change_input(self, fgraph, node, i, r, new_r, reason): 1371 if new_r not in self.shape_of: 1372 # It happen that the fgraph didn't called on_import for some 1373 # new_r. This happen when new_r don't have an 1374 # owner(i.e. it is a constant or an input of the graph) 1375 # update_shape suppose that r and new_r are in shape_of. 1376 self.init_r(new_r) 1377 1378 # This tells us that r and new_r must have the same shape if 1379 # we didn't know that the shapes are related, now we do. 1380 self.update_shape(new_r, r) 1381 1382 # change_input happens in two cases: 1383 # 1) we are trying to get rid of r, or 1384 # 2) we are putting things back after a failed transaction. 
1385 1386 # In case 1, if r has a shape_i client, we will want to 1387 # replace the shape_i of r with the shape of new_r. Say that 1388 # r is *scheduled*. 1389 # At that point, node is no longer a client of r, but of new_r 1390 for (shpnode, idx) in (r.clients + [(node, i)]): 1391 if isinstance(getattr(shpnode, 'op', None), Shape_i): 1392 idx = shpnode.op.i 1393 repl = self.shape_of[new_r][idx] 1394 if repl.owner is shpnode: 1395 # This mean the replacement shape object is 1396 # exactly the same as the current shape object. So 1397 # no need for replacement. This happen for example 1398 # with the InputToGpuOptimizer optimizer. 1399 continue 1400 if (repl.owner and 1401 repl.owner.inputs[0] is shpnode.inputs[0] and 1402 isinstance(repl.owner.op, Shape_i) and 1403 repl.owner.op.i == shpnode.op.i): 1404 # The replacement is a shape_i of the same 1405 # input. So no need to do this equivalent 1406 # replacement. 1407 continue 1408 1409 if shpnode.outputs[0] in theano.gof.graph.ancestors([repl]): 1410 raise InconsistencyError( 1411 "This substitution would insert a cycle in the graph:" 1412 "node: %s, i: %i, r: %s, new_r: %s" 1413 % (node, i, r, new_r)) 1414 1415 self.scheduled[shpnode] = new_r 1416 # In case 2, if r is a variable that we've scheduled for shape update, 1417 # then we should cancel it. 1418 unscheduled = [k for k, v in self.scheduled.items() if v == r] 1419 for k in unscheduled: 1420 del self.scheduled[k] 1421 1422 # In either case, r could be in shape_of.values(), that is, r itself 1423 # is the shape of something. In that case, we want to update 1424 # the value in shape_of, to keep it up-to-date. 1425 for v in self.shape_of_reverse_index.get(r, []): 1426 # The reverse index is only approximate. It is not updated on 1427 # deletion of variables, or on change_input so it might be the 1428 # case that there are a few extra `v`'s in it that no longer have 1429 # a shape of r or possibly have been deleted from shape_of 1430 # entirely. The important thing is that it permits to recall 1431 # all variables with r in their shape. 1432 for ii, svi in enumerate(self.shape_of.get(v, [])): 1433 if svi == r: 1434 self.set_shape_i(v, ii, new_r) 1435 self.shape_of_reverse_index[r] = set() 1436 1437 def same_shape(self, x, y, dim_x=None, dim_y=None): 1438 """Return True if we are able to assert that x and y have the 1439 same shape. 1440 1441 dim_x and dim_y are optional. If used, they should be an index 1442 to compare only 1 dimension of x and y. 1443 1444 """ 1445 sx = self.shape_of[x] 1446 sy = self.shape_of[y] 1447 if sx is None or sy is None: 1448 return False 1449 if dim_x is not None: 1450 sx = [sx[dim_x]] 1451 if dim_y is not None: 1452 sy = [sy[dim_y]] 1453 assert len(sx) == len(sy) 1454 1455 # We look on each dimensions we want to compare. 1456 # If any of them can't be asserted to be equal, return False. 1457 # Otherwise, we return True at the end. 1458 for dx, dy in zip(sx, sy): 1459 if dx is dy: 1460 continue 1461 # Need to try to find that they are the same shape. We 1462 # need to compare the full graph. It could be slow. So I 1463 # just implement for now the case of Shape_i. 1464 if not dx.owner or not dy.owner: 1465 return False 1466 if (not isinstance(dx.owner.op, Shape_i) or 1467 not isinstance(dy.owner.op, Shape_i)): 1468 return False 1469 opx = dx.owner.op 1470 opy = dy.owner.op 1471 if not (opx.i == opy.i): 1472 return False 1473 # FB I'm not sure if this handle correctly constants. 
1474 if dx.owner.inputs[0] == dy.owner.inputs[0]: 1475 continue 1476 # To be sure to cover all case, call equal_computation. 1477 # Can't use theano.gof.graph.is_same_graph(dx, dy) 1478 # As it currently expect that dx and dy aren't in a FunctionGraph 1479 from theano.scan_module.scan_utils import equal_computations 1480 if not equal_computations([dx], [dy]): 1481 return False 1482 return True 1483 1484 1485class ShapeOptimizer(Optimizer): 1486 """Optimizer that serves to add ShapeFeature as an fgraph feature.""" 1487 def add_requirements(self, fgraph): 1488 fgraph.attach_feature(ShapeFeature()) 1489 1490 def apply(self, fgraph): 1491 pass 1492 1493 1494class UnShapeOptimizer(Optimizer): 1495 """Optimizer remove ShapeFeature as an fgraph feature.""" 1496 def apply(self, fgraph): 1497 for feature in fgraph._features: 1498 if isinstance(feature, ShapeFeature): 1499 fgraph.remove_feature(feature) 1500 1501# Register it after merge1 optimization at 0. We don't want to track 1502# the shape of merged node. 1503theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(), 1504 0.1, 'fast_run', 'fast_compile') 1505# Not enabled by default for now. Some crossentropy opt use the 1506# shape_feature. They are at step 2.01. uncanonicalize is at step 1507# 3. After it goes to 48.5 that move to the gpu. So 10 seem resonable. 1508theano.compile.mode.optdb.register('UnShapeOpt', UnShapeOptimizer(), 1509 10) 1510 1511 1512def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): 1513 def local_elemwise_alloc(node): 1514 """ 1515 elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) 1516 -> elemwise(x, y.TensorType(BROADCAST CONDITION)) 1517 1518 elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION)) 1519 -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION)) 1520 1521 BROADCAST CONDITION: the condition is that the one input that are 1522 not to be optimized to have the same broadcast pattern as the 1523 output. 1524 1525 We can change the alloc by a dimshuffle as the elemwise 1526 already have the shape info. The dimshuffle will be faster 1527 to exec. 1528 1529 """ 1530 if not isinstance(node.op, ElemwiseOP): 1531 return False 1532 1533 if len(node.outputs) > 1: 1534 # Ensure all outputs have the same broadcast pattern 1535 # This is a supposition that I'm not sure is always true. 1536 assert all([o.type.broadcastable == 1537 node.outputs[0].type.broadcastable for o in 1538 node.outputs[1:]]) 1539 1540 # The broadcast pattern of the ouptut must match the broadcast 1541 # pattern of at least one of the inputs. 1542 if not any([i.type.broadcastable == 1543 node.outputs[0].type.broadcastable for i in node.inputs]): 1544 return False 1545 1546 def dimshuffled_alloc(i): 1547 return (isinstance(i.owner.op, DimShuffleOP) and 1548 i.owner.inputs[0].owner and 1549 isinstance(i.owner.inputs[0].owner.op, AllocOP)) 1550 1551 # At least one input must have an owner that is either a AllocOP or a 1552 # DimShuffleOP with an owner that is a AllocOP -- otherwise there is 1553 # nothing to optimize. 1554 if not any([i.owner and (isinstance(i.owner.op, AllocOP) or 1555 dimshuffled_alloc(i)) for i in node.inputs]): 1556 return False 1557 1558 # Search for input that we can use as a baseline for the dimensions. 1559 assert_op_idx = -1 1560 for idx, i in enumerate(node.inputs): 1561 if i.type.broadcastable == node.outputs[0].type.broadcastable: 1562 # Prefer an input that is not a AllocOP nor a DimShuffleOP of a 1563 # AllocOP so that all allocs can be optimized. 
1564 if not (i.owner and (isinstance(i.owner.op, AllocOP) or 1565 dimshuffled_alloc(i))): 1566 assert_op_idx = idx 1567 break 1568 1569 # It may be the case that only AllocOP and DimShuffleOP of AllocOP exist. 1570 if assert_op_idx < 0: 1571 # We want to optimize as many allocs as possible. When 1572 # there is more than one then do all but one. number of 1573 # inputs with alloc or dimshuffle alloc 1574 l2 = [i for i in node.inputs 1575 if (i.owner and (isinstance(i.owner.op, AllocOP) or 1576 dimshuffled_alloc(i)))] 1577 # If only 1 alloc or dimshuffle alloc, it is the one we 1578 # will use for the shape. So no alloc would be removed. 1579 if len(l2) > 1: 1580 # l containt inputs with alloc or dimshuffle alloc 1581 # only. Its length will always be at least one, as we 1582 # checked that before 1583 l = [idx for idx, i in enumerate(node.inputs) 1584 if i.broadcastable == node.outputs[0].broadcastable] 1585 assert_op_idx = l[0] # The first one is as good as any to use. 1586 else: 1587 # Nothing would be optimized! 1588 return False 1589 1590 assert_op = node.inputs[assert_op_idx] 1591 cmp_op = assert_op 1592 new_i = [] 1593 same_shape = node.fgraph.shape_feature.same_shape 1594 for i in node.inputs: 1595 # Remove alloc 1596 if (i.owner and isinstance(i.owner.op, AllocOP) and 1597 i.owner.inputs[0].type != i.owner.outputs[0].type): 1598 # when i.owner.inputs[0].type == i.owner.outputs[0].type we 1599 # will remove that alloc later 1600 assert i.type.ndim == cmp_op.ndim 1601 if theano.config.experimental.local_alloc_elemwise_assert: 1602 get_shape = node.fgraph.shape_feature.get_shape 1603 cond = [] 1604 for idx in xrange(i.type.ndim): 1605 if (not i.type.broadcastable[idx] and 1606 not same_shape(i, cmp_op, idx, idx)): 1607 i_shp = get_shape(i, idx) 1608 cmp_shp = get_shape(cmp_op, idx) 1609 cond.append(T.eq(i_shp, cmp_shp)) 1610 if cond: 1611 assert_op = assert_(assert_op, *cond) 1612 new_i.append(i.owner.inputs[0]) 1613 1614 # Remove Alloc in DimShuffle 1615 elif i.owner and dimshuffled_alloc(i): 1616 assert i.type.ndim == cmp_op.type.ndim 1617 if theano.config.experimental.local_alloc_elemwise_assert: 1618 assert_cond = [T.eq(i.shape[idx], cmp_op.shape[idx]) 1619 for idx in xrange(i.type.ndim) 1620 if not i.type.broadcastable[idx] and 1621 not same_shape(i, cmp_op, idx, idx)] 1622 if assert_cond: 1623 assert_op = assert_(assert_op, *assert_cond) 1624 alloc_input = i.owner.inputs[0].owner.inputs[0] 1625 if alloc_input.ndim != i.owner.inputs[0].ndim: 1626 # The alloc can add dimension to the value 1627 # We add a dimshuffle to add them. 1628 # We let later optimization merge the multiple dimshuffle 1629 nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim 1630 alloc_input = alloc_input.dimshuffle( 1631 ['x'] * nb_dim_to_add + 1632 list(range(alloc_input.ndim))) 1633 1634 # We need to keep the dimshuffle. It could swap axes or 1635 # add dimensions anywhere. 1636 r_i = i.owner.op(alloc_input) 1637 1638 # Copy stack trace from i to new_i 1639 copy_stack_trace(i, r_i) 1640 new_i.append(r_i) 1641 else: 1642 new_i.append(i) 1643 new_i[assert_op_idx] = assert_op 1644 1645 ret = node.op(*new_i, return_list=True) 1646 1647 # Copy over stack trace from previous outputs to new outputs. 1648 copy_stack_trace(node.outputs, ret) 1649 return ret 1650 1651 return local_elemwise_alloc 1652 1653# TODO, global optimizer that lift the assert to the beginning of the graph. 1654# TODO, optimize all inputs when possible -- currently when all inputs have 1655# an alloc all but one is optimized. 

local_elemwise_alloc = register_specialize(
    gof.local_optimizer([T.Elemwise])(
        local_elemwise_alloc_op(T.Elemwise, T.Alloc, T.DimShuffle)),
    'local_alloc_elemwise')


@gof.local_optimizer([T.Elemwise])
def local_fill_sink(node):
    """
    f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
    f must be an elemwise op that isn't fill.
    """
    if (not hasattr(node, 'op') or
            not isinstance(node.op, T.Elemwise) or
            node.op == T.fill):
        return False
    models = []
    inputs = []
    for input in node.inputs:
        if input.owner and input.owner.op == T.fill:
            models.append(input.owner.inputs[0])
            inputs.append(input.owner.inputs[1])
        else:
            inputs.append(input)
    if not models:
        return False
    c = node.op(*inputs)
    for model in models:
        if model.type != c.type:
            c = T.fill(model, c)

    # The newly created node c doesn't have 'clients' yet,
    # so we iterate over the clients of node.outputs[0] instead.
    replacements = {node.outputs[0]: c}
    for client, cl_idx in node.outputs[0].clients:
        if (hasattr(client, 'op') and
                isinstance(client.op, T.Elemwise) and
                not client.op == T.fill):
            client_inputs = client.inputs[:]
            client_inputs[cl_idx] = c
            new_client = client.op(*client_inputs)

            # Add clients to new_client
            new_client.owner.outputs[0].clients = client.outputs[0].clients
            r = local_fill_sink.transform(new_client.owner)
            if not r:
                continue
            replacements.update(r)
    return replacements

register_canonicalize(local_fill_sink)


@register_specialize
@register_stabilize
# @register_canonicalize  # We make a full pass after the canonizer phase.
@gof.local_optimizer([T.fill])
def local_fill_to_alloc(node):
    """fill(s, v) -> alloc(v, shape(s))

    This is an important optimization because with the shape_to_shape_i
    optimization, the dependency on 's' is often removed.

    """
    if node.op == T.fill:
        r, v = node.inputs
        if v.type == node.outputs[0].type:
            # this is a useless fill, erase it.
            rval = [v]
        elif v.type.broadcastable == node.outputs[0].type.broadcastable:
            # this is a cast
            rval = [T.cast(v, node.outputs[0].type.dtype)]
        elif r.type.broadcastable == node.outputs[0].type.broadcastable:
            # we are broadcasting v somehow, but not r
            o = broadcast_like(v, r, node.fgraph, dtype=v.dtype)
            copy_stack_trace(node.outputs[0], o)
            rval = [o]
        else:
            # we are broadcasting both v and r,
            # so the output shape must be computed.
            #
            # TODO: implement this case (including a test!)
            #
            # I think the strategy should be to extend the shorter
            # shape vector with 1s (how?) and then take the
            # elementwise max of the two. - how to flag an error of
            # shape mismatch where broadcasting should be illegal?
            return
            # TODO: cut out unnecessary dimshuffles of v

        assert rval[0].type == node.outputs[0].type, (
            'rval', rval[0].type, 'orig', node.outputs[0].type, 'node',
            node,)  # theano.printing.debugprint(node.outputs[0], file='str')
        return rval

# Register this after stabilize at 1.5 to make sure stabilize doesn't
# get affected by a less canonicalized graph due to alloc.
1754compile.optdb.register('local_fill_to_alloc', 1755 in2out(local_fill_to_alloc), 1756 1.51, 'fast_run') 1757# Needed to clean some extra alloc added by local_fill_to_alloc 1758compile.optdb.register('local_elemwise_alloc', 1759 in2out(local_elemwise_alloc), 1760 1.52, 'fast_run') 1761 1762 1763@register_canonicalize("fast_compile") 1764@register_useless 1765@gof.local_optimizer([T.fill]) 1766def local_useless_fill(node): 1767 """fill(s,v) -> v 1768 1769 This optimization is only needed in FAST_COMPILE to make the code 1770 more readable. Normally, it is done by the local_fill_to_alloc 1771 opt. 1772 1773 """ 1774 if node.op == T.fill: 1775 r, v = node.inputs 1776 if v.type == node.outputs[0].type: 1777 # this is a useless fill, erase it. 1778 # also, we don't need to copy over any stack traces here 1779 return [v] 1780 1781 1782@register_specialize 1783@register_stabilize 1784@register_canonicalize 1785@register_useless 1786@gof.local_optimizer([T.alloc]) 1787def local_useless_alloc(node): 1788 """ 1789 If the input type is the same as the output type (dtype and broadcast) 1790 there is no change in the shape of the input. So this is just a simple copy 1791 of the input. This is not needed. 1792 1793 """ 1794 op = node.op 1795 if not isinstance(op, Alloc): 1796 return False 1797 1798 input = node.inputs[0] 1799 output = node.outputs[0] 1800 1801 # Check if dtype and broadcast remain the same. 1802 if input.type == output.type: 1803 # We don't need to copy over any stack traces here 1804 return [input] 1805 1806 1807@register_specialize 1808@register_stabilize 1809@register_canonicalize 1810@gof.local_optimizer([T.alloc]) 1811def local_canonicalize_alloc(node): 1812 """If the input type is the same as the output type (dtype and broadcast) 1813 there is no change in the shape of the input. So this is just a simple copy 1814 of the input. This is not needed. (as local_useless_alloc) 1815 1816 Also, it will canonicalize alloc by creating Dimshuffle after the 1817 alloc to introduce the dimensions of constant size 1. 1818 1819 See https://github.com/Theano/Theano/issues/4072 to know why this 1820 is needed. 1821 1822 """ 1823 op = node.op 1824 if not isinstance(op, Alloc): 1825 return False 1826 1827 input = node.inputs[0] 1828 output = node.outputs[0] 1829 1830 # Check if dtype and broadcast remain the same. 1831 if input.type == output.type: 1832 # We don't need to copy over any stack traces here 1833 return [input] 1834 1835 # Allow local_merge_alloc to do its work first 1836 clients = getattr(output, 'clients', []) 1837 for client, i in clients: 1838 if client != "output" and isinstance(client.op, Alloc): 1839 return 1840 1841 # Check if alloc adds a broadcastable dimension with shape 1. 
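    # Hedged sketch of the intent (added comment, not in the original
    # source): e.g. alloc(x, 1, 1, s0), where x already has shape (s0,), is
    # rewritten below as x.dimshuffle('x', 'x', 0), exposing the leading
    # size-1 dimensions as broadcastable.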
1842 1843 output_shape = node.inputs[1:] 1844 num_dims_with_size_1_added_to_left = 0 1845 for i in range(len(output_shape) - input.ndim): 1846 if extract_constant(output_shape[i], only_process_constants=True) == 1: 1847 num_dims_with_size_1_added_to_left += 1 1848 else: 1849 break 1850 new_output_shape = output_shape[num_dims_with_size_1_added_to_left:] 1851 if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= input.ndim: 1852 if output.broadcastable[num_dims_with_size_1_added_to_left:] == input.broadcastable: 1853 inner = input 1854 else: 1855 inner = op(*([input] + new_output_shape)) 1856 dimshuffle_new_order = (['x'] * num_dims_with_size_1_added_to_left + 1857 list(xrange(len(new_output_shape)))) 1858 return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] 1859 1860 1861# Don't register by default. 1862@gof.local_optimizer([T.AllocEmpty]) 1863def local_alloc_empty_to_zeros(node): 1864 """This convert AllocEmpty to Alloc of 0. 1865 1866 This help investigate NaN with NanGuardMode. Not registered by 1867 default. To activate it, use the Theano flag 1868 optimizer_including=alloc_empty_to_zeros. This also enable 1869 the GPU version of this optimizations. 1870 1871 """ 1872 if isinstance(node.op, T.AllocEmpty): 1873 return [T.zeros(node.inputs, dtype=node.outputs[0].dtype)] 1874compile.optdb.register('local_alloc_empty_to_zeros', 1875 in2out(local_alloc_empty_to_zeros), 1876 # After move to gpu and merge2, before inplace. 1877 49.3, 1878 'alloc_empty_to_zeros',) 1879 1880 1881@register_specialize 1882@register_canonicalize 1883@gof.local_optimizer([T.Shape]) 1884def local_shape_to_shape_i(node): 1885 if node.op == T.shape: 1886 # This optimization needs ShapeOpt and fgraph.shape_feature 1887 if not hasattr(node.fgraph, 'shape_feature'): 1888 return 1889 shape_feature = node.fgraph.shape_feature 1890 ret = shape_feature.make_vector_shape(node.inputs[0]) 1891 1892 # We need to copy over stack trace from input to output 1893 copy_stack_trace(node.outputs[0], ret) 1894 return [ret] 1895 1896 1897# TODO: Not sure what type of node we are expecting here 1898@register_specialize 1899@register_canonicalize 1900@gof.local_optimizer(None) 1901def local_track_shape_i(node): 1902 try: 1903 shape_feature = node.fgraph.shape_feature 1904 except AttributeError: 1905 return 1906 if node in shape_feature.scheduled: 1907 # Don't unschedule node as it could be reinserted in the 1908 # fgraph as we don't change it in the shapefeature internal 1909 # structure. 
1910 assert isinstance(node.op, Shape_i) 1911 replacement = shape_feature.scheduled[node] 1912 return [shape_feature.shape_of[replacement][node.op.i]] 1913 1914 1915@register_specialize 1916@register_canonicalize 1917@gof.local_optimizer([Subtensor]) 1918def local_subtensor_inc_subtensor(node): 1919 """ 1920 Subtensor(SetSubtensor(x, y, idx), idx) -> y 1921 1922 """ 1923 if isinstance(node.op, Subtensor): 1924 x = node.inputs[0] 1925 if not x.owner or not isinstance(x.owner.op, IncSubtensor): 1926 return 1927 if not x.owner.op.set_instead_of_inc: 1928 return 1929 1930 if (x.owner.inputs[2:] == node.inputs[1:] and 1931 tuple(x.owner.op.idx_list) == tuple(node.op.idx_list)): 1932 out = node.outputs[0] 1933 y = x.owner.inputs[1] 1934 # If the dtypes differ, cast y into x.dtype 1935 if x.dtype != y.dtype: 1936 y = y.astype(x.dtype) 1937 if out.type == y.type: 1938 # if x[idx] and y have the same type, directly return y 1939 return [y] 1940 else: 1941 # The difference is related to broadcasting pattern 1942 assert out.broadcastable != y.broadcastable 1943 # We have to alloc y to the shape of x[idx] 1944 x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:]) 1945 return [T.alloc(y, *x_subtensor.shape)] 1946 else: 1947 return 1948 1949 1950@register_specialize 1951@register_canonicalize 1952@gof.local_optimizer([Subtensor]) 1953def local_subtensor_remove_broadcastable_index(node): 1954 """ 1955 Remove broadcastable dimension with index 0 or -1 1956 a[:,:,:,0] -> a.dimshuffle(0,1,2), when 1957 a.broadcastable = (False, False, False, True) 1958 a[0,:,-1,:] -> a.dimshuffle(1,3), when 1959 a.broadcastable = (True, False, True, False) 1960 1961 """ 1962 if isinstance(node.op, Subtensor): 1963 idx = node.op.idx_list 1964 else: 1965 return 1966 1967 remove_dim = [] 1968 node_inputs_idx = 1 1969 for dim, elem in enumerate(idx): 1970 if isinstance(elem, (scalar.Scalar)): 1971 # The idx is a Scalar, ie a Type. This means the actual index 1972 # is contained in node.inputs[1] 1973 dim_index = node.inputs[node_inputs_idx] 1974 if type(dim_index) == theano.scalar.basic.ScalarConstant: 1975 dim_index = dim_index.value 1976 if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]: 1977 remove_dim.append(dim) 1978 node_inputs_idx += 1 1979 else: 1980 return 1981 elif isinstance(elem, slice): 1982 if elem != slice(None): 1983 return 1984 elif isinstance(elem, (integer_types, np.integer)): 1985 if elem in [0, -1] and node.inputs[0].broadcastable[dim]: 1986 remove_dim.append(dim) 1987 else: 1988 raise TypeError('case not expected') 1989 1990 if len(remove_dim) == 0: 1991 return 1992 else: 1993 all_dim = range(node.inputs[0].ndim) 1994 remain_dim = [x for x in all_dim if x not in remove_dim] 1995 return [node.inputs[0].dimshuffle(tuple(remain_dim))] 1996 1997 1998@register_specialize 1999@register_canonicalize('fast_compile_gpu') 2000@register_useless 2001@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) 2002def local_subtensor_make_vector(node): 2003 """ 2004 Replace all subtensor(make_vector) like: 2005 [a,b,c][0] -> a 2006 [a,b,c][0:2] -> [a,b] 2007 2008 Replace all AdvancedSubtensor1(make_vector) like: 2009 [a,b,c][[0,2]] -> [a,c] 2010 2011 We can do this for constant indexes. 
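    A slice with constant bounds is handled the same way, e.g.
    [a,b,c][1:] -> [b,c]; a symbolic index is resolved when it wraps a
    constant value.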
2012 2013 """ 2014 x = node.inputs[0] 2015 if not x.owner or x.owner.op != make_vector: 2016 return 2017 2018 if isinstance(node.op, Subtensor): 2019 # This optimization needs ShapeOpt and fgraph.shape_feature 2020 try: 2021 idx, = node.op.idx_list 2022 except Exception: 2023 # 'how can you have multiple indexes into a shape?' 2024 raise 2025 2026 if isinstance(idx, (scalar.Scalar, T.TensorType)): 2027 # The idx is a Scalar, ie a Type. This means the actual index 2028 # is contained in node.inputs[1] 2029 old_idx, idx = idx, node.inputs[1] 2030 assert idx.type == old_idx 2031 elif isinstance(node.op, AdvancedSubtensor1): 2032 idx = node.inputs[1] 2033 else: 2034 return 2035 2036 if isinstance(idx, (integer_types, np.integer)): 2037 # We don't need to copy over any stack traces here 2038 return [x.owner.inputs[idx]] 2039 elif isinstance(idx, Variable): 2040 if idx.ndim == 0: 2041 # if it is a constant we can do something with it 2042 try: 2043 v = get_scalar_constant_value(idx, only_process_constants=True) 2044 if isinstance(v, np.integer): 2045 # Python 2.4 wants to index only with Python integers 2046 v = int(v) 2047 # We don't need to copy over any stack traces here 2048 try: 2049 ret = [x.owner.inputs[v]] 2050 except IndexError: 2051 raise NotScalarConstantError("Bad user graph!") 2052 return ret 2053 except NotScalarConstantError: 2054 pass 2055 elif idx.ndim == 1 and isinstance(idx, T.Constant): 2056 values = list(map(int, list(idx.value))) 2057 ret = make_vector(*[x.owner.inputs[v] for v in values]) 2058 2059 # Copy over stack trace from previous output to new output 2060 copy_stack_trace(node.outputs[0], ret) 2061 ret = T.patternbroadcast(ret, node.outputs[0].broadcastable) 2062 return [ret] 2063 else: 2064 raise TypeError('case not expected') 2065 elif isinstance(idx, slice): 2066 # it is a slice of ints and/or Variables 2067 # check subtensor to see if it can contain constant variables, and if 2068 # it can, then try to unpack them. 2069 try: 2070 const_slice = node.op.get_constant_idx(node.inputs, 2071 allow_partial=False)[0] 2072 ret = make_vector(*x.owner.inputs[const_slice]) 2073 # Copy over stack trace from previous outputs to new output 2074 copy_stack_trace(node.outputs, ret) 2075 ret = T.patternbroadcast(ret, node.outputs[0].broadcastable) 2076 return [ret] 2077 except NotScalarConstantError: 2078 pass 2079 else: 2080 raise TypeError('case not expected') 2081 2082 2083# TODO: the other optimization for and, or, xor, le and ge see ticket #496. 2084 2085@register_useless 2086@register_canonicalize('fast_compile') 2087@register_specialize 2088@gof.local_optimizer([T.Elemwise]) 2089def local_useless_elemwise(node): 2090 """ 2091 eq(x, x) -> 1 2092 neq(x, x) -> 0 2093 mul(x) -> x 2094 add(x) -> x 2095 identity(x) -> x 2096 and(x, 1) -> x (if x.dtype == 'bool') 2097 and(x, 0) -> zeros_like(x) 2098 or(x, 0) -> x 2099 or(x, 1) -> ones_like(x) (if x.dtype == 'bool') 2100 xor(x, x) -> zeros_like(x) 2101 2102 """ 2103 if isinstance(node.op, T.Elemwise): 2104 # We call zeros_like and one_like with opt=True to generate a 2105 # cleaner graph. 2106 dtype = node.outputs[0].dtype 2107 2108 if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2: 2109 if node.inputs[0] == node.inputs[1]: 2110 # it is the same var in the graph. 
That will always be true 2111 ret = T.ones_like(node.inputs[0], dtype=dtype, opt=True) 2112 2113 # Copy stack trace from input to constant output 2114 copy_stack_trace(node.outputs[0], ret) 2115 return [ret] 2116 elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: 2117 if node.inputs[0] == node.inputs[1]: 2118 # it is the same var in the graph. That will always be false 2119 ret = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 2120 2121 # Copy stack trace from input to constant output 2122 copy_stack_trace(node.outputs[0], ret) 2123 return [ret] 2124 2125 elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1: 2126 # No need to copy over any stack trace 2127 return [node.inputs[0]] 2128 2129 elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: 2130 # No need to copy over any stack trace 2131 return [node.inputs[0]] 2132 elif (node.op.scalar_op == theano.scalar.identity and 2133 len(node.inputs) == 1): 2134 return [node.inputs[0]] 2135 2136 elif (isinstance(node.op.scalar_op, scalar.AND) and 2137 len(node.inputs) == 2): 2138 2139 if isinstance(node.inputs[0], T.TensorConstant): 2140 const_val = T.extract_constant(node.inputs[0], only_process_constants=True) 2141 if not isinstance(const_val, Variable): 2142 if const_val == 0: 2143 return [T.zeros_like(node.inputs[1], dtype=dtype, 2144 opt=True)] 2145 elif node.outputs[0].dtype == 'bool': 2146 # If the output is not Boolean, it is the bitwise AND, 2147 # and this optimization would be wrong 2148 return [node.inputs[1].astype(node.outputs[0].dtype)] 2149 2150 if isinstance(node.inputs[1], T.TensorConstant): 2151 const_val = T.extract_constant(node.inputs[1], only_process_constants=True) 2152 if not isinstance(const_val, Variable): 2153 if const_val == 0: 2154 return [T.zeros_like(node.inputs[0], dtype=dtype, 2155 opt=True)] 2156 elif node.outputs[0].dtype == 'bool': 2157 # If the output is not Boolean, it is the bitwise AND, 2158 # and this optimization would be wrong 2159 return [node.inputs[0].astype(node.outputs[0].dtype)] 2160 2161 elif (isinstance(node.op.scalar_op, scalar.OR) and 2162 len(node.inputs) == 2): 2163 2164 if isinstance(node.inputs[0], T.TensorConstant): 2165 const_val = T.extract_constant(node.inputs[0], only_process_constants=True) 2166 if not isinstance(const_val, Variable): 2167 if const_val == 0: 2168 return [node.inputs[1].astype(node.outputs[0].dtype)] 2169 elif node.outputs[0].dtype == 'bool': 2170 # If the output is not Boolean, it is the bitwise OR, 2171 # and this optimization would be wrong 2172 return [T.ones_like(node.inputs[1], dtype=dtype, 2173 opt=True)] 2174 2175 if isinstance(node.inputs[1], T.TensorConstant): 2176 const_val = T.extract_constant(node.inputs[1], only_process_constants=True) 2177 if not isinstance(const_val, Variable): 2178 if const_val == 0: 2179 return [node.inputs[0].astype(node.outputs[0].dtype)] 2180 elif node.outputs[0].dtype == 'bool': 2181 # If the output is not Boolean, it is the bitwise OR, 2182 # and this optimization would be wrong 2183 return [T.ones_like(node.inputs[0], dtype=dtype, 2184 opt=True)] 2185 2186 elif (isinstance(node.op.scalar_op, scalar.XOR) and 2187 len(node.inputs) == 2): 2188 if node.inputs[0] is node.inputs[1]: 2189 return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)] 2190 2191 2192@register_specialize 2193@gof.local_optimizer([T.Elemwise]) 2194def local_alloc_unary(node): 2195 """unary(alloc(x, shp)) -> alloc(unary(x), shp)""" 2196 if isinstance(node.op, T.Elemwise) and len(node.inputs) == 1: 
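        # Only unary Elemwise ops are handled; the Alloc is lifted outside so
        # the scalar op is applied to the pre-broadcast value, e.g.
        # exp(alloc(x, shp)) -> alloc(exp(x), shp) (with a cast to the
        # output dtype).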
2197 a = node.inputs[0] 2198 if a.owner and isinstance(a.owner.op, T.Alloc): 2199 x = a.owner.inputs[0] 2200 shp = a.owner.inputs[1:] 2201 v = node.op(x) 2202 # T.alloc does not preserve the stacktrace of v, 2203 # so we need to copy it over from x. 2204 copy_stack_trace(node.outputs[0], v) 2205 ret = T.alloc(T.cast(v, node.outputs[0].dtype), *shp) 2206 2207 # T.cast does not preserve the stacktrace of x, 2208 # so we need to copy it over to the output. 2209 copy_stack_trace([node.outputs[0], a], ret) 2210 return [ret] 2211 2212 2213@register_canonicalize 2214@register_specialize 2215@gof.local_optimizer([T.Elemwise]) 2216def local_cast_cast(node): 2217 """cast(cast(x, dtype1), dtype2) 2218 2219 when those contrain: 2220 dtype1 == dtype2 2221 OR the base dtype is the same (int, uint, float, complex) 2222 and the first cast cause an upcast. 2223 2224 """ 2225 if (not isinstance(node.op, T.Elemwise) or 2226 not isinstance(node.op.scalar_op, scalar.Cast)): 2227 return 2228 x = node.inputs[0] 2229 if (not x.owner or 2230 not isinstance(x.owner.op, T.Elemwise) or 2231 not isinstance(x.owner.op.scalar_op, scalar.Cast)): 2232 return 2233 2234 type1 = x.owner.op.scalar_op.o_type 2235 type2 = node.op.scalar_op.o_type 2236 base = x.owner.inputs[0] 2237 2238 if type1 == type2: 2239 # We don't need to copy over any stack traces here 2240 return [x] 2241 2242 if(is_an_upcast(base.dtype, type1.dtype)): 2243 # Checking for further redundancy. Eg: int8 -> int32 -> int8 2244 if(type2.dtype == base.dtype): 2245 return x.owner.inputs 2246 else: 2247 # Apply the second cast only 2248 v = node.op(base) 2249 # Copy stack trace from the output of the original cast 2250 copy_stack_trace(node.outputs[0], v) 2251 return [v] 2252 2253 2254def is_an_upcast(type1, type2): 2255 """Given two data types (as strings), check if converting to 2256 type2 from type1 constitutes an upcast. 2257 Differs from theano.scalar.upcast 2258 2259 """ 2260 category = { 2261 # The first number in the pair is the dtype (bool, uint, int, float, 2262 # complex). Conversion from higher to lower is never an upcast. 2263 # The second number roughly indicates the precision. Again, conversion 2264 # from higher to lower is never an upcast. 2265 2266 'bool': (0, 0), 2267 'uint8': (1, 1), 'uint16': (1, 2), 'uint32': (1, 3), 'uint64': (1, 4), 2268 'int8': (2, 1), 'int16': (2, 2), 'int32': (2, 3), 'int64': (2, 4), 2269 'float16': (3, 1.5), 'float32': (3, 2.5), 'float64': (3, 3.5), 2270 'complex64': (4, 3), 'complex128': (4, 4) 2271 } 2272 2273 cat1 = category[type1] 2274 cat2 = category[type2] 2275 2276 if(cat2[0] >= cat1[0] and cat2[1] > cat1[1]): 2277 return True 2278 else: 2279 return False 2280 2281 2282@register_canonicalize 2283@register_specialize 2284@gof.local_optimizer([T.Elemwise]) 2285def local_func_inv(node): 2286 """ 2287 Check for two consecutive operations that are functional inverses 2288 and remove them from the function graph. 
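    For example, deg2rad(rad2deg(x)) is replaced by x, based on the
    inverse pairs listed below.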
2289 2290 """ 2291 inv_pairs = ( 2292 (basic.Deg2Rad, basic.Rad2Deg), 2293 (basic.Cosh, basic.ArcCosh), 2294 (basic.Tanh, basic.ArcTanh), 2295 (basic.Sinh, basic.ArcSinh), 2296 (basic.Conj, basic.Conj), 2297 (basic.Neg, basic.Neg), 2298 (basic.Inv, basic.Inv), 2299 ) 2300 x = node.inputs[0] 2301 2302 if not isinstance(node.op, T.Elemwise): 2303 return 2304 if (not x.owner or not isinstance(x.owner.op, T.Elemwise)): 2305 return 2306 2307 prev_op = x.owner.op.scalar_op 2308 node_op = node.op.scalar_op 2309 2310 for inv_pair in inv_pairs: 2311 if is_inverse_pair(node_op, prev_op, inv_pair): 2312 # We don't need to copy stack trace, because the optimization 2313 # is trivial and maintains the earlier stack trace 2314 return x.owner.inputs 2315 2316 return 2317 2318 2319def is_inverse_pair(node_op, prev_op, inv_pair): 2320 """ 2321 Given two consecutive operations, check if they are the 2322 provided pair of inverse functions. 2323 2324 """ 2325 node_is_op0 = isinstance(node_op, inv_pair[0]) 2326 node_is_op1 = isinstance(node_op, inv_pair[1]) 2327 prev_is_op0 = isinstance(prev_op, inv_pair[0]) 2328 prev_is_op1 = isinstance(prev_op, inv_pair[1]) 2329 2330 return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0) 2331 2332 2333class Assert(T.Op): 2334 """ 2335 Implements assertion in a computational graph. 2336 2337 Returns the first parameter if the condition is true, otherwise, triggers 2338 AssertionError. 2339 2340 Notes 2341 ----- 2342 This Op is a debugging feature. It can be removed from the graph 2343 because of optimizations, and can hide some possible optimizations to 2344 the optimizer. Specifically, removing happens if it can be determined 2345 that condition will always be true. Also, the output of the Op must be 2346 used in the function computing the graph, but it doesn't have to be 2347 returned. 2348 2349 Examples 2350 -------- 2351 >>> import theano 2352 >>> T = theano.tensor 2353 >>> x = T.vector('x') 2354 >>> assert_op = T.opt.Assert() 2355 >>> func = theano.function([x], assert_op(x, x.size<2)) 2356 2357 """ 2358 _f16_ok = True 2359 __props__ = ('msg',) 2360 view_map = {0: [0]} 2361 2362 check_input = False 2363 2364 def __init__(self, msg="Theano Assert failed!"): 2365 self.msg = msg 2366 2367 def __setstate__(self, attrs): 2368 self.__dict__.update(attrs) 2369 if not hasattr(self, 'msg'): 2370 self.msg = "Theano Assert failed!" 
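    # The methods below implement the Op: the first input is passed through
    # unchanged (view_map shares its storage with the output), and every
    # remaining input must be a scalar condition that holds at run time.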
2371 2372 def make_node(self, value, *conds): 2373 if not isinstance(value, Variable): 2374 value = T.as_tensor_variable(value) 2375 cond = [T.as_tensor_variable(c) for c in conds] 2376 assert np.all([c.type.ndim == 0 for c in cond]) 2377 return gof.Apply(self, [value] + cond, [value.type()]) 2378 2379 def perform(self, node, inputs, out_): 2380 out, = out_ 2381 v = inputs[0] 2382 out[0] = v 2383 assert np.all(inputs[1:]), self.msg 2384 2385 def grad(self, input, output_gradients): 2386 return output_gradients + [DisconnectedType()()] * (len(input) - 1) 2387 2388 def connection_pattern(self, node): 2389 return [[1]] + [[0]] * (len(node.inputs) - 1) 2390 2391 def c_code(self, node, name, inames, onames, sub): 2392 value = inames[0] 2393 out = onames[0] 2394 check = [] 2395 fail = sub['fail'] 2396 msg = self.msg.replace('"', '\\"').replace('\n', '\\n') 2397 for idx in xrange(len(inames) - 1): 2398 i = inames[idx + 1] 2399 dtype = node.inputs[idx + 1].dtype 2400 check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])' 2401 '{PyErr_SetString(PyExc_AssertionError,"%(msg)s");' 2402 '%(fail)s}' % locals()) 2403 check = "\n".join(check) 2404 return """ 2405 %(check)s 2406 Py_XDECREF(%(out)s); 2407 %(out)s = %(value)s; 2408 Py_INCREF(%(value)s); 2409 """ % locals() 2410 2411 def c_code_cache_version(self): 2412 return (3, 0) 2413 2414 def infer_shape(self, node, input_shapes): 2415 return [input_shapes[0]] 2416 2417assert_ = Assert() 2418# Unittest.assert_ is a deprecated name for assertTrue. 2419# 2to3 convert theano.tensor.opt.assert_ to theano.tensor.opt.assertTrue 2420# So I define a new name as a work around. 2421assert_op = assert_ 2422 2423 2424@register_specialize 2425@gof.local_optimizer([Assert]) 2426def local_remove_useless_assert(node): 2427 if isinstance(node.op, Assert): 2428 cond = [] 2429 for c in node.inputs[1:]: 2430 try: 2431 const = get_scalar_constant_value(c) 2432 2433 if 0 != const.ndim or const == 0: 2434 # Should we raise an error here? How to be sure it 2435 # is not catched? 2436 cond.append(c) 2437 except NotScalarConstantError: 2438 cond.append(c) 2439 2440 if len(cond) == 0: 2441 # We don't need to copy over any stack traces here 2442 return [node.inputs[0]] 2443 if len(cond) != len(node.inputs) - 1: 2444 ret = assert_(node.inputs[0], *cond) 2445 2446 # We copy over stack trace from the output of the original assert 2447 copy_stack_trace(node.outputs[0], ret) 2448 return [ret] 2449 2450 2451@gof.local_optimizer([Assert]) 2452def local_remove_all_assert(node): 2453 """An optimization disabled by default that removes all asserts from 2454 the graph. 2455 2456 Notes 2457 ----- 2458 See the :ref:`unsafe` section to know how to enable it. 
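    When enabled, a graph such as assert_op(x, cond) is simply replaced
    by x, dropping the runtime check.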
2459 2460 """ 2461 if not isinstance(node.op, Assert): 2462 return 2463 2464 # We don't need to copy over any stack traces here 2465 return [node.inputs[0]] 2466# Disabled by default 2467compile.optdb['canonicalize'].register('local_remove_all_assert', 2468 local_remove_all_assert, 2469 'unsafe', 2470 use_db_name_as_tag=False) 2471compile.optdb['stabilize'].register('local_remove_all_assert', 2472 local_remove_all_assert, 2473 'unsafe', 2474 use_db_name_as_tag=False) 2475compile.optdb['specialize'].register('local_remove_all_assert', 2476 local_remove_all_assert, 2477 'unsafe', 2478 use_db_name_as_tag=False) 2479compile.optdb['useless'].register('local_remove_all_assert', 2480 local_remove_all_assert, 2481 'unsafe', 2482 use_db_name_as_tag=False) 2483 2484####################### 2485# Constant Canonicalization 2486############################ 2487 2488 2489@register_canonicalize 2490@gof.local_optimizer([T.Elemwise]) 2491def local_upcast_elemwise_constant_inputs(node): 2492 """This explicitly upcasts constant inputs to elemwise Ops, when 2493 those Ops do implicit upcasting anyway. 2494 2495 Rationale: it helps merge things like (1-x) and (1.0 - x). 2496 2497 """ 2498 if len(node.outputs) > 1: 2499 return 2500 try: 2501 shape_i = node.fgraph.shape_feature.shape_i 2502 except AttributeError: 2503 shape_i = None 2504 if isinstance(node.op, T.Elemwise): 2505 scalar_op = node.op.scalar_op 2506 # print "aa", scalar_op.output_types_preference 2507 if (getattr(scalar_op, 'output_types_preference', None) 2508 in (T.scal.upgrade_to_float, T.scal.upcast_out)): 2509 # this is the kind of op that we can screw with the input 2510 # dtypes by upcasting explicitly 2511 output_dtype = node.outputs[0].type.dtype 2512 new_inputs = [] 2513 for i in node.inputs: 2514 if i.type.dtype == output_dtype: 2515 new_inputs.append(i) 2516 else: 2517 try: 2518 # works only for scalars 2519 cval_i = get_scalar_constant_value(i, 2520 only_process_constants=True) 2521 if all(i.broadcastable): 2522 new_inputs.append(T.shape_padleft( 2523 T.cast(cval_i, output_dtype), 2524 i.ndim)) 2525 else: 2526 if shape_i is None: 2527 return 2528 new_inputs.append( 2529 T.alloc(T.cast(cval_i, output_dtype), 2530 *[shape_i(d)(i) 2531 for d in xrange(i.ndim)])) 2532 # print >> sys.stderr, "AAA", 2533 # *[Shape_i(d)(i) for d in xrange(i.ndim)] 2534 except NotScalarConstantError: 2535 # for the case of a non-scalar 2536 if isinstance(i, T.TensorConstant): 2537 new_inputs.append(T.cast(i, output_dtype)) 2538 else: 2539 new_inputs.append(i) 2540 2541 if new_inputs != node.inputs: 2542 rval = [node.op(*new_inputs)] 2543 if rval[0].type != node.outputs[0].type: 2544 # This can happen for example when floatX=float32 2545 # and we do the true division between and int64 2546 # and a constant that will get typed as int8. 2547 2548 # As this is just to allow merging more case, if 2549 # the upcast don't work, we can just skip it. 2550 return 2551 2552 # Copy over output stacktrace from before upcasting 2553 copy_stack_trace(node.outputs[0], rval) 2554 return rval 2555 2556################## 2557# Subtensor opts # 2558################## 2559 2560 2561@register_useless 2562@register_canonicalize 2563@register_specialize 2564@gof.local_optimizer([IncSubtensor]) 2565def local_useless_inc_subtensor(node): 2566 """ 2567 Remove IncSubtensor, when we overwrite the full inputs with the 2568 new value. 
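    For example, set_subtensor(x[::], y) -> y when x and y have the same
    shape; an inc_subtensor into a tensor of zeros is treated the same way.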
2569 2570 """ 2571 if not isinstance(node.op, IncSubtensor): 2572 return 2573 if node.op.set_instead_of_inc is False: 2574 # This is an IncSubtensor, so the init value must be zeros 2575 try: 2576 c = get_scalar_constant_value(node.inputs[0], 2577 only_process_constants=True) 2578 if c != 0: 2579 return 2580 except NotScalarConstantError: 2581 return 2582 if (node.inputs[0].ndim != node.inputs[1].ndim or 2583 node.inputs[0].broadcastable != node.inputs[1].broadcastable): 2584 # FB: I didn't check if this case can happen, but this opt 2585 # don't support it. 2586 return 2587 # We have a SetSubtensor or an IncSubtensor on zeros 2588 # If is this IncSubtensor useful? 2589 2590 # Check that we keep all the original data. 2591 # Put the constant inputs in the slice. 2592 idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list) 2593 if all(isinstance(e, slice) and e.start is None and 2594 e.stop is None and (e.step is None or T.extract_constant(e.step, 2595 only_process_constants=True) == -1) 2596 for e in idx_cst): 2597 # IncSubtensor broadcast node.inputs[1] on node.inputs[0] 2598 # based on run time shapes, so we must check they are the same. 2599 if not hasattr(node.fgraph, 'shape_feature'): 2600 return 2601 if not node.fgraph.shape_feature.same_shape(node.inputs[0], 2602 node.inputs[1]): 2603 return 2604 # There is no reverse, so we don't need a replacement. 2605 if all(e.step is None 2606 for e in node.op.idx_list): 2607 # They are the same shape, so we can remore this IncSubtensor 2608 return [node.inputs[1]] 2609 ret = Subtensor(node.op.idx_list)(*node.inputs[1:]) 2610 # Copy over previous output stacktrace 2611 copy_stack_trace(node.outputs, ret) 2612 return [ret] 2613 2614 2615@register_canonicalize 2616@gof.local_optimizer([AdvancedIncSubtensor1]) 2617def local_set_to_inc_subtensor(node): 2618 """ 2619 AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> 2620 AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) 2621 2622 """ 2623 if (isinstance(node.op, AdvancedIncSubtensor1) and 2624 node.op.set_instead_of_inc and 2625 node.inputs[1].owner and 2626 isinstance(node.inputs[1].owner.op, Elemwise) and 2627 isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)): 2628 addn = node.inputs[1].owner 2629 subn = None 2630 other = None 2631 2632 if (addn.inputs[0].owner and 2633 isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)): 2634 subn = addn.inputs[0].owner 2635 other = addn.inputs[1] 2636 elif (addn.inputs[1].owner and 2637 isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)): 2638 subn = addn.inputs[1].owner 2639 other = addn.inputs[0] 2640 else: 2641 return 2642 if (subn.inputs[1] != node.inputs[2] or 2643 subn.inputs[0] != node.inputs[0]): 2644 return 2645 ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) 2646 # Copy over previous output stacktrace 2647 # Julian: I'm not sure about this at all... 
2648 copy_stack_trace(node.outputs, ret) 2649 return [ret] 2650 2651 2652@register_useless 2653@register_canonicalize 2654@register_specialize 2655@gof.local_optimizer([Subtensor]) 2656def local_useless_slice(node): 2657 """ 2658 Remove Subtensor of the form X[0, :] -> X[0] 2659 """ 2660 if isinstance(node.op, Subtensor): 2661 slices = get_idx_list(node.inputs, node.op.idx_list) 2662 last_slice = len(slices) 2663 for s in slices[::-1]: 2664 # check if slice and then check slice indices 2665 if (isinstance(s, slice) and s.start is None and s.stop is None and 2666 (s.step is None or T.extract_constant(s.step, 2667 only_process_constants=True) == 1)): 2668 last_slice -= 1 2669 else: 2670 break 2671 # check if we removed something 2672 if last_slice < len(slices): 2673 subtens = Subtensor(slices[:last_slice]) 2674 sl_ins = Subtensor.collapse(slices[:last_slice], 2675 lambda x: isinstance(x, T.Variable)) 2676 out = subtens(node.inputs[0], *sl_ins) 2677 # Copy over previous output stacktrace 2678 copy_stack_trace(node.outputs, out) 2679 return [out] 2680 2681 2682@register_canonicalize 2683@register_specialize 2684@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) 2685def local_useless_subtensor(node): 2686 """ 2687 Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the 2688 AdvancedSubtensor1 case, the full input is taken when the indices are 2689 equivalent to `arange(0, input.shape[0], 1)` using either an explicit 2690 list/vector or the ARange op. 2691 2692 """ 2693 2694 # If the optimization is tried over a node that is not a part of graph before 2695 if not hasattr(node, 'fgraph'): 2696 return 2697 2698 # This optimization needs ShapeOpt and fgraph.shape_feature 2699 if not hasattr(node.fgraph, 'shape_feature'): 2700 return 2701 2702 shape_of = node.fgraph.shape_feature.shape_of 2703 2704 if isinstance(node.op, Subtensor): 2705 cdata = node.op.get_constant_idx(node.inputs, allow_partial=True, 2706 only_process_constants=True) 2707 for pos, idx in enumerate(cdata): 2708 if not isinstance(idx, slice): 2709 # If idx is not a slice, this means we remove this dimension 2710 # from the output, so the subtensor is not useless 2711 return False 2712 if idx.start is not None and idx.start != 0: 2713 # If the start of the slice is different from 0, or is a 2714 # variable, then we assume the subtensor is not useless 2715 return False 2716 if idx.step is not None and idx.step != 1: 2717 # If we are going backwards, or skipping elements, then this 2718 # is not a useless subtensor 2719 return False 2720 2721 for pos, idx in enumerate(cdata): 2722 2723 length_pos = shape_of[node.inputs[0]][pos] 2724 2725 if isinstance(idx.stop, (integer_types, np.integer)): 2726 length_pos_data = sys.maxsize 2727 try: 2728 length_pos_data = get_scalar_constant_value(length_pos, 2729 only_process_constants=True) 2730 except NotScalarConstantError: 2731 pass 2732 2733 if idx.stop < length_pos_data: 2734 return False 2735 elif isinstance(idx.stop, gof.Variable): 2736 length_pos_shape_i = idx.stop 2737 # length_pos is a tensor variable, but length_pos_shape_i 2738 # is a scalar variable. We try to see if they represent 2739 # the same underlying variable. 
2740 if (length_pos_shape_i.owner and 2741 isinstance(length_pos_shape_i.owner.op, 2742 T.ScalarFromTensor)): 2743 length_pos_shape_i = length_pos_shape_i.owner.inputs[0] 2744 elif (length_pos.owner and 2745 isinstance(length_pos.owner.op, T.TensorFromScalar)): 2746 length_pos = length_pos.owner.inputs[0] 2747 else: 2748 # We did not find underlying variables of the same type 2749 return False 2750 2751 # The type can be different: int32 vs int64. length_pos 2752 # should always be int64 as that is what the shape 2753 # tracker keep. Subtensor accept any scalar int{8,16,32,64} 2754 # as index type. 2755 assert str(length_pos.type.dtype) == "int64" 2756 assert str(length_pos_shape_i.type.dtype) in ["int8", "int16", 2757 "int32", "int64"] 2758 2759 # length_pos_shape_i cannot be None 2760 if length_pos_shape_i != length_pos: 2761 return False 2762 elif idx.stop is None: 2763 pass 2764 else: 2765 return False 2766 elif isinstance(node.op, AdvancedSubtensor1): 2767 # get length of the indexed tensor along the first axis 2768 try: 2769 length = get_scalar_constant_value(shape_of[node.inputs[0]][0], 2770 only_process_constants=True) 2771 except NotScalarConstantError: 2772 return False 2773 2774 # get index (which must be a vector by definition) 2775 idx = node.inputs[1] 2776 2777 # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for 2778 # this optimization 2779 if isinstance(idx, T.Constant): 2780 idx = idx.value 2781 if len(idx) != length: 2782 return False 2783 if np.any(idx != np.arange(length)): 2784 return False 2785 elif idx.owner is not None and isinstance(idx.owner.op, T.ARange): 2786 try: 2787 start, stop, step = map(lambda x: get_scalar_constant_value(x, 2788 only_process_constants=True), 2789 idx.owner.inputs) 2790 except NotScalarConstantError: 2791 return False 2792 2793 if start != 0: 2794 return False 2795 if stop != length: 2796 return False 2797 if step != 1: 2798 return False 2799 else: 2800 return False 2801 else: 2802 return False 2803 2804 # We don't need to copy over any stacktrace here, 2805 # because previous stacktrace should suffice. 2806 return [node.inputs[0]] 2807 2808 2809# fast_compile to allow opt subtensor(cast{float32}(make_vector)) 2810@register_canonicalize('fast_compile') 2811@gof.local_optimizer([Subtensor]) 2812def local_subtensor_lift(node): 2813 """ 2814 unary(x)[idx] -> unary(x[idx])#any broadcast pattern. 2815 2816 Handles the following unary ops: 2817 elemwise(x,...)[idx] -> elemwise(x[idx],...) 2818 when x,... 
are broadcasted scalar or not broadcasted at all 2819 rebroadcast(x)[idx] => rebroadcast(x[idx]) 2820 2821 """ 2822 if isinstance(node.op, Subtensor): 2823 u = node.inputs[0] 2824 if not u.owner or len(u.clients) > 1: 2825 return False 2826 2827 if isinstance(u.owner.op, T.Elemwise) and len(u.owner.inputs) == 1: 2828 idx = node.inputs[1:] 2829 x_idx = node.op(u.owner.inputs[0], *idx) 2830 # Copy over previous output stacktrace 2831 copy_stack_trace(node.outputs, x_idx) 2832 ret = u.owner.op(x_idx) 2833 # Copy over previous output stacktrace 2834 # and stacktrace from previous unary operation 2835 copy_stack_trace([node.outputs[0], node.inputs[0]], ret) 2836 return [ret] 2837 2838 if isinstance(u.owner.op, T.Elemwise): 2839 new_inputs = [] 2840 if all([sum(i.type.broadcastable) == 0 for i in u.owner.inputs]): 2841 # There is no broadcastable in the inputs 2842 idx = node.inputs[1:] 2843 new_inputs = [node.op(i, *idx) for i in u.owner.inputs] 2844 # Copy over previous output stacktrace 2845 copy_stack_trace(node.outputs[0], new_inputs) 2846 2847 ret = u.owner.op(*new_inputs) 2848 # Copy over previous output stacktrace 2849 # and stacktrace from previous unary operation 2850 copy_stack_trace([node.outputs[0], node.inputs[0]], ret) 2851 return [ret] 2852 elif all([sum(i.type.broadcastable) in [i.ndim, 0] 2853 for i in u.owner.inputs]): 2854 # There is no broadcastable in the inputs or it is scalar 2855 idx = node.inputs[1:] 2856 new_inputs = [] 2857 for i in u.owner.inputs: 2858 if sum(i.type.broadcastable) == 0: 2859 new_inputs.append(node.op(i, *idx)) 2860 else: 2861 # If the subtensor remove some dims, we must 2862 # lower the number of dimensions of this scalar. 2863 if node.outputs[0].ndim == i.ndim: 2864 new_inputs.append(i) 2865 else: 2866 new_inputs.append( 2867 i.dimshuffle(['x'] * node.outputs[0].ndim)) 2868 2869 # Copy over previous output stacktrace 2870 copy_stack_trace(node.outputs[0], new_inputs) 2871 2872 ret = u.owner.op(*new_inputs) 2873 # Copy over previous output stacktrace 2874 # and stacktrace from previous unary operation 2875 copy_stack_trace([node.outputs[0], node.inputs[0]], ret) 2876 return [ret] 2877 2878 if isinstance(u.owner.op, T.Rebroadcast): 2879 # make sure that Rebroadcast has only 1 input 2880 assert len(u.owner.inputs) == 1 2881 2882 # Subtensor might reduce dim., adapt broadcast pattern accordingly 2883 new_axis = [] 2884 2885 # loop through indices being subtensor-ed 2886 # i indexes broadcastable pattern before subtensor 2887 # j indexes broadcastable pattern after subtensor 2888 j = 0 2889 for (i, x) in enumerate(node.op.idx_list): 2890 # if its not a slice, it will reduce the dimension, should 2891 # not appear in the broascastable dimensions 2892 if isinstance(x, slice): 2893 new_axis += [(j, u.broadcastable[i])] 2894 j += 1 2895 # now keep the broadcastable pattern of all 2896 # items not appearing in subtensor list 2897 for i in xrange(len(node.op.idx_list), len(u.broadcastable)): 2898 new_axis += [(j, u.broadcastable[i])] 2899 j += 1 2900 2901 subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) 2902 # Copy over previous output stacktrace 2903 copy_stack_trace(node.outputs[0], subt_x) 2904 2905 rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) 2906 # Copy over previous output stacktrace 2907 # and stacktrace from previous unary operation 2908 copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) 2909 2910 return [rbcast_subt_x] 2911 2912 2913def merge_two_slices(slice1, len1, slice2, len2): 2914 """ 2915 This function merges two slices 
into a single slice. The code works on 2916 the assumption that: 2917 2918 a) slice1 is actually a slice and not an index, while slice2 2919 can be just an index. 2920 2921 b) the two slices **have been applied consecutively** on the same 2922 tensor 2923 2924 The output slice is **not** in canonical form, but actually just a slice 2925 that can be applied to a tensor to produce the same output as applying 2926 the two consecutive slices. 2927 ``len1`` is the length of the tensor **before** applying the first slice, 2928 while ``len2`` is the length **after** applying the first slice. 2929 """ 2930 list_opt = [local_abs_merge, local_mul_switch_sink, 2931 local_upcast_elemwise_constant_inputs, 2932 local_useless_switch, constant_folding] 2933 2934 if type(slice1) is not slice: 2935 raise ValueError(('First provided slice should actually be of type' 2936 'slice and not an index !'), slice1) 2937 sl1, reverse1 = get_canonical_form_slice(slice1, len1) 2938 sl2, reverse2 = get_canonical_form_slice(slice2, len2) 2939 2940 if type(sl2) is not slice: 2941 if reverse1 is None: 2942 # The first slice is not in reverse, which makes things a lot 2943 # more clear. 2944 # In this case we need to take care only of the special cases: 2945 # len2 <=0 -> throw index error regardless of sl2 2946 # sl2 > len2 -> throw index error 2947 # sl2 < -len2 -> throw index error 2948 # To get a index error we simply use len1+1 to indicate we are 2949 # out of bounds, because passing this index through the formula 2950 # of getting the mixed slice is not guaranteed to result in an 2951 # index error. The **issue though** if that the error will 2952 # complain about accessing element len1+1 which is probably not 2953 # too intuitive for the user 2954 val = sl1.start + sl2 * sl1.step 2955 val = T.switch(T.le(len2, 0), len1 + 1, val) 2956 val = T.switch(T.ge(sl2, len2), len1 + 1, val) 2957 val = T.switch(T.lt(sl2, 0), - len1 - 1, val) 2958 if sl1.step: 2959 val = T.switch(T.eq(sl1.step, 0), len1 + 1, val) 2960 val = pre_greedy_local_optimizer(list_opt, val) 2961 return val 2962 else: 2963 # We are in the more complex case when we do not actually know 2964 # if the first slice was in reverse or not. 2965 # in case it was not in reverse: 2966 p_val = sl1.start + sl2 * sl1.step 2967 # case it was in reverse we need to realize that we do not want 2968 # the k-th element from sl.start but the k-th element from 2969 # sl.stop backwards 2970 n_val = sl1.stop - 1 - sl2 * sl1.step 2971 if config.warn.subtensor_merge_bug: 2972 warnings.warn(( 2973 'Your current code is fine, but Theano versions ' 2974 'prior to 0.5rc2 might have given an incorrect result. ' 2975 'To disable this warning, set the Theano flag ' 2976 'warn.subtensor_merge_bug to False.')) 2977 # we need to pick either n_val or p_val and then follow same 2978 # steps as above for covering the index error cases 2979 val = T.switch(T.lt(reverse1, 0), n_val, p_val) 2980 val = T.switch(T.le(len2, 0), len1 + 1, val) 2981 val = T.switch(T.ge(sl2, len2), len1 + 1, val) 2982 val = T.switch(T.lt(sl2, 0), - len1 - 1, val) 2983 if sl1.step: 2984 val = T.switch(T.eq(sl1.step, 0), len1 + 1, val) 2985 val = pre_greedy_local_optimizer(list_opt, val) 2986 return val 2987 else: 2988 # We are deleaing with two slices that need to be put together 2989 # according to the two steps we have 4 different combinations of 2990 # positive/negative. 
I will denote the case I'm looking at by 2991 # suffixes to the variables (nn,np,pn,pp): 2992 flen = sl2.stop - sl2.start 2993 p_step = sl1.step * sl2.step 2994 n_step = sl1.step * sl2.step * -1 2995 2996 pp_start = T.minimum(sl1.start + sl2.start * sl1.step, sl1.stop) 2997 pp_stop = T.minimum(sl1.start + sl2.stop * sl1.step, sl1.stop) 2998 2999 pn_stop = sl1.start + (sl2.start - 1) * sl1.step 3000 pn_stop = T.switch(T.and_(T.lt(pn_stop, 0), 3001 T.gt(flen, 0)), 3002 -len1 - 1, 3003 T.minimum(pn_stop, sl1.stop)) 3004 pn_start = sl1.start + (sl2.stop - 1) * sl1.step 3005 pn_start = T.minimum(pn_start, sl1.stop) 3006 pn_start = T.maximum(pn_start, 0) 3007 3008 np_stop = sl1.stop - sl2.stop * sl1.step - 1 3009 np_stop = T.switch(T.and_(T.lt(np_stop, 0), 3010 T.gt(flen, 0)), 3011 -len1 - 1, 3012 T.maximum(sl1.start - 1, np_stop)) 3013 np_start = T.maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1) 3014 3015 nn_start = T.maximum(sl1.start, 3016 (sl1.stop - 1) - (sl2.stop - 1) * sl1.step) 3017 nn_stop = T.maximum(sl1.start, sl1.stop - sl2.start * sl1.step) 3018 3019 start = T.switch(T.lt(reverse2 * reverse1, 0), 3020 T.switch(T.lt(reverse1, 0), np_start, pn_start), 3021 T.switch(T.lt(reverse1, 0), nn_start, 3022 pp_start)) 3023 3024 stop = T.switch(T.lt(reverse2 * reverse1, 0), 3025 T.switch(T.lt(reverse1, 0), np_stop, pn_stop), 3026 T.switch(T.lt(reverse1, 0), nn_stop, pp_stop)) 3027 3028 step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step) 3029 start = T.switch(T.le(flen, 0), 0, start) 3030 stop = T.switch(T.le(flen, 0), 0, stop) 3031 3032 # The canonical form of the slice is pretty complicated 3033 # and is not simplified. We simplify it in advance here 3034 # as otherwise this create too many useless optimization that 3035 # DebugMode must check. 3036 start = pre_greedy_local_optimizer(list_opt, start) 3037 stop = pre_greedy_local_optimizer(list_opt, stop) 3038 step = pre_greedy_local_optimizer(list_opt, step) 3039 start = pre_greedy_local_optimizer(list_opt, start) 3040 stop = pre_greedy_local_optimizer(list_opt, stop) 3041 step = pre_greedy_local_optimizer(list_opt, step) 3042 3043 # Pre merge constant for the same reason. 3044 start, stop, step = pre_constant_merge([start, stop, step]) 3045 3046 return slice(start, stop, step) 3047 3048 3049@register_canonicalize 3050@register_specialize 3051@gof.local_optimizer([Subtensor]) 3052def local_subtensor_merge(node): 3053 """ 3054 Refactored optimization to deal with all cases of tensor merging. 3055 Given a subgraph of the form Subtensor(Subtensor(u)), the optimization 3056 expresses all slices in a canonical form, and then merges them together. 3057 3058 """ 3059 3060 if isinstance(node.op, Subtensor): 3061 u = node.inputs[0] 3062 if u.owner and isinstance(u.owner.op, Subtensor): 3063 # We can merge :) 3064 # x actual tensor on which we are picking slices 3065 x = u.owner.inputs[0] 3066 # slices of the first applied subtensor 3067 slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) 3068 slices2 = get_idx_list(node.inputs, node.op.idx_list) 3069 # Get the shapes of the vectors ! 
3070 try: 3071 # try not to introduce new shape into the graph 3072 xshape = node.fgraph.shape_feature.shape_of[x] 3073 ushape = node.fgraph.shape_feature.shape_of[u] 3074 except AttributeError: 3075 # Following the suggested use of shape_feature which should 3076 # consider the case when the compilation mode doesn't 3077 # include the ShapeFeature 3078 xshape = x.shape 3079 ushape = u.shape 3080 3081 merged_slices = [] 3082 pos_2 = 0 3083 pos_1 = 0 3084 while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): 3085 slice1 = slices1[pos_1] 3086 if type(slice1) is slice: 3087 merged_slices.append( 3088 merge_two_slices(slice1, 3089 xshape[pos_1], 3090 slices2[pos_2], 3091 ushape[pos_2])) 3092 pos_2 += 1 3093 else: 3094 merged_slices.append(slice1) 3095 pos_1 += 1 3096 3097 if pos_2 < len(slices2): 3098 merged_slices += slices2[pos_2:] 3099 else: 3100 merged_slices += slices1[pos_1:] 3101 3102 merged_slices = make_constant(merged_slices) 3103 subtens = Subtensor(merged_slices) 3104 3105 sl_ins = Subtensor.collapse( 3106 merged_slices, 3107 lambda x: isinstance(x, T.Variable)) 3108 # Do not call make_node for test_value 3109 out = subtens(x, *sl_ins) 3110 3111 # Copy over previous output stacktrace 3112 # and stacktrace from previous slicing operation. 3113 # Why? Because, the merged slicing operation could have failed 3114 # because of either of the two original slicing operations 3115 orig_out = node.outputs[0] 3116 copy_stack_trace([orig_out, node.inputs[0]], out) 3117 3118 # Restore original broadcastable dimensions that `subtens()` may 3119 # have been unable to infer again 3120 if out.type != orig_out.type: 3121 assert out.dtype == orig_out.dtype 3122 assert out.ndim == orig_out.ndim 3123 out = T.patternbroadcast(out, orig_out.broadcastable) 3124 copy_stack_trace([orig_out, node.inputs[0]], out) 3125 return [out] 3126 3127 3128@register_useless 3129@register_canonicalize 3130@register_specialize 3131@gof.local_optimizer([Subtensor]) 3132def local_subtensor_of_alloc(node): 3133 """ 3134 3135 alloc(val)[x:y] -> alloc(val[...]) 3136 alloc(val)[x:y] -> alloc(val) 3137 This can be seen as a lift, but it also reduce the number of computation/memory. 3138 3139 """ 3140 if not isinstance(node.op, Subtensor): 3141 return False 3142 u = node.inputs[0] 3143 if u.owner is None: 3144 return False 3145 if not isinstance(u.owner.op, T.Alloc): 3146 return False 3147 slices = get_idx_list(node.inputs, node.op.idx_list) 3148 val = u.owner.inputs[0] 3149 dims = u.owner.inputs[1:] 3150 assert len(slices) <= len(dims) 3151 3152 # Number of dimensions added to val 3153 n_added_dims = u.ndim - val.ndim 3154 # Dimensions of the returned alloc 3155 nw_dims = [] 3156 # Slices to take from val 3157 val_slices = [] 3158 3159 for i, (sl, dim) in enumerate(zip(slices, dims)): 3160 # If val was not copied over that dim, 3161 # we need to take the appropriate subtensor on it. 3162 if i >= n_added_dims: 3163 # We check that the corresponding val dimensions was 3164 # not a broadcasted dimensions. 3165 if (val.type.ndim > (i - n_added_dims) and 3166 val.type.broadcastable[i - n_added_dims]): 3167 val_slices.append(slice(None)) 3168 else: 3169 val_slices.append(sl) 3170 3171 csl, _ = get_canonical_form_slice(sl, dim) 3172 if type(csl) is not slice: 3173 # That dimension is removed. 3174 pass 3175 else: 3176 nw_dim = csl.stop - csl.start 3177 3178 if csl.step != 1: 3179 # Do not add the ceil_intdiv() graphs in the graphs 3180 # when this is not needed as it prevent detecting the 3181 # correct broadcast pattern. 
3182 nw_dim = T.ceil_intdiv(nw_dim, csl.step) 3183 nw_dims += [nw_dim] 3184 3185 nw_val = val[tuple(val_slices)] 3186 nw_dims += dims[len(slices):] 3187 if nw_val.ndim > len(nw_dims): 3188 return False 3189 rval = T.alloc(nw_val, *nw_dims) 3190 if type(rval) not in (list, tuple): 3191 rval = [rval] 3192 if rval[0].type != node.outputs[0].type: 3193 # It happen that the make_node() isn't able to infer the same pattern. 3194 # We know it is safe, so fix that. 3195 rval[0] = T.patternbroadcast(rval[0], node.outputs[0].broadcastable) 3196 3197 return rval 3198 3199 3200@register_canonicalize 3201@register_stabilize 3202@register_specialize 3203@gof.local_optimizer([Subtensor]) 3204def local_subtensor_of_dot(node): 3205 """ 3206 This optimization translates T.dot(A, B)[idxs] into T.dot(A[idxs_a], B[idxs_b]), 3207 where idxs_a and idxs_b are defined appropriately. 3208 3209 idxs_a is the first A.ndim-1 entries of idxs, 3210 and idxs_b is the remaining entries of idxs (if any), 3211 modified to skip the second-to-last dimension of B 3212 (because dot sums over this dimension). 3213 3214 """ 3215 if not isinstance(node.op, Subtensor): 3216 return 3217 if (not node.inputs[0].owner or 3218 not isinstance(node.inputs[0].owner.op, T.Dot)): 3219 return 3220 # If there is other node that use the outputs of the dot 3221 # We don't want to compute twice the sub part. 3222 if len(node.inputs[0].clients) > 1: 3223 return 3224 3225 a = node.inputs[0].owner.inputs[0] 3226 b = node.inputs[0].owner.inputs[1] 3227 3228 idx_list = get_idx_list(node.inputs, node.op.idx_list) 3229 3230 num_a_indices = min(a.ndim - 1, len(idx_list)) 3231 a_indices = idx_list[:num_a_indices] 3232 b_indices = idx_list[num_a_indices:] 3233 3234 # This is necessary because np.dot sums the last index of a with the second to last of b 3235 # so we want to skip the second-to-last index into b. 3236 # This wasn't necessary for a, because we just omitted the last index. 3237 # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] 3238 # (dot also handles b.ndim < 2 as a special case) 3239 if b.ndim > 1 and len(b_indices) >= b.ndim - 1: 3240 b_indices = (b_indices[:b.ndim - 2] + 3241 (slice(None, None, None),) + b_indices[b.ndim - 2:]) 3242 3243 a_sub = a.__getitem__(tuple(a_indices)) 3244 b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b 3245 3246 # Copy over previous output stacktrace to a_sub and b_sub, 3247 # because an error in the subtensor operation (e.g. an index error) 3248 # on either a or b must correspond to an error in the 3249 # subtensor operation on their dot product. 3250 copy_stack_trace(node.outputs[0], [a_sub, b_sub]) 3251 3252 # Copy over previous output stacktrace and previous dot product stacktrace, 3253 # because an error here may correspond to an either in either the original 3254 # dot product, or in the dot product after the subtensor operation. 3255 r = T.dot(a_sub, b_sub) 3256 copy_stack_trace([node.outputs[0], node.inputs[0]], r) 3257 3258 return [r] 3259 3260 3261@register_canonicalize 3262@gof.local_optimizer([T.add]) 3263def local_IncSubtensor_serialize(node): 3264 """ 3265 When using Subtensor, gradient graphs can be ugly. 3266 3267 If we ask for grad(f(a[0]), a), we are going to get something like 3268 3269 IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) 3270 3271 This might be ugly, but at least it's as fast as you could want. 3272 If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse... 
3273 3274 Elemwise{Add} 3275 IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) 3276 IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1]) 3277 IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2]) 3278 3279 This is much worse because this time we have to produce 3 matrices 3280 the size of 'a', just so we can add them together. 3281 3282 This Op rearranges IncSubtensor's that all work on the same 3283 initial argument (here, Elemwise{second}(a,0)) into a chain. The 3284 advantage of the chain structure is that each one can be optimized 3285 later in the pipeline to operate inplace. 3286 3287 Ideally, the op will do something like this: 3288 3289 # 3290 # add(x, incsubtensor(b, c), incsubtensor(b, d)) 3291 # -> incsubtensor(incsubtensor(add(x,b,b), c), d) 3292 3293 """ 3294 def movable(i): 3295 # Return True iff this is a incsubtensor that we can move 3296 return (i.owner and 3297 isinstance(i.owner.op, (IncSubtensor, 3298 AdvancedIncSubtensor1, 3299 AdvancedIncSubtensor,)) and 3300 i.type == o_type and 3301 len(i.clients) == 1 and 3302 not i.owner.op.set_instead_of_inc) 3303 3304 if node.op == T.add: 3305 o_type = node.outputs[0].type 3306 3307 movable_inputs = [i for i in node.inputs if movable(i)] 3308 3309 if movable_inputs: 3310 new_inputs = ([i for i in node.inputs if not movable(i)] + 3311 [mi.owner.inputs[0] for mi in movable_inputs]) 3312 if len(new_inputs) == 0: 3313 new_add = new_inputs[0] 3314 else: 3315 new_add = T.add(*new_inputs) 3316 3317 # Copy over stacktrace from original output, as an error 3318 # (e.g. an index error) in this add operation should 3319 # correspond to an error in the original add operation. 3320 copy_stack_trace(node.outputs[0], new_add) 3321 3322 # stack up the new incsubtensors 3323 tip = new_add 3324 for mi in movable_inputs: 3325 assert tip.type == o_type 3326 assert tip.type == mi.owner.inputs[0].type 3327 tip = mi.owner.op(tip, *mi.owner.inputs[1:]) 3328 # Copy over stacktrace from outputs of the original 3329 # "movable" operation to the new operation. 3330 copy_stack_trace(node.outputs + mi.owner.outputs, tip) 3331 3332 return [tip] 3333 3334 # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] 3335 3336# We register it in a TopoOptimizer inside the canonizer EQ optimizer. 3337# Otherwise in some cases it was making the EQ optimizer use 45. In 3338# the TopoOptimizer, the EQ only use 5 passes. 3339compile.optdb.register('pre_local_IncSubtensor_serialize', 3340 in2out(local_IncSubtensor_serialize), 3341 # Just before canonizer 3342 0.99, 'fast_run') 3343 3344 3345# after priority 50 Destructive inplace operations 3346# gemm is the first one now, at priority 70 3347 3348@gof.local_optimizer([IncSubtensor], inplace=True) 3349def local_inplace_setsubtensor(node): 3350 """ 3351 Also work for GpuIncSubtensor. 3352 3353 """ 3354 if isinstance(node.op, IncSubtensor) and not node.op.inplace: 3355 dta = node.op.destroyhandler_tolerate_aliased 3356 new_op = node.op.__class__( 3357 node.op.idx_list, inplace=True, 3358 set_instead_of_inc=node.op.set_instead_of_inc, 3359 destroyhandler_tolerate_aliased=dta) 3360 new_node = new_op(*node.inputs) 3361 val = getattr(node.outputs[0].tag, 'nan_guard_mode_check', True) 3362 new_node.tag.nan_guard_mode_check = val 3363 3364 # Copy stacktrace from original outputs to new outputs. 3365 # This is sensible, because the new operation is the 3366 # same as the old one, but now with different attributes. 
3367 copy_stack_trace(node.outputs, new_node) 3368 return [new_node] 3369 return False 3370compile.optdb.register('local_inplace_setsubtensor', 3371 TopoOptimizer( 3372 local_inplace_setsubtensor, 3373 failure_callback=TopoOptimizer.warn_inplace), 3374 60, 'fast_run', 'inplace') # DEBUG 3375 3376 3377@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True) 3378def local_inplace_incsubtensor1(node): 3379 """ 3380 Also work for GpuAdvancedIncSubtensor1. 3381 3382 """ 3383 if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: 3384 new_op = node.op.clone_inplace() 3385 new_node = new_op(*node.inputs) 3386 3387 # Copy stacktrace from original outputs to new outputs. 3388 # This is sensible, because the new operation is the 3389 # same as the old one, but now with different attributes. 3390 copy_stack_trace(node.outputs, new_node) 3391 return [new_node] 3392 return False 3393compile.optdb.register('local_inplace_incsubtensor1', 3394 TopoOptimizer( 3395 local_inplace_incsubtensor1, 3396 failure_callback=TopoOptimizer.warn_inplace), 3397 60, 'fast_run', 'inplace') # DEBUG 3398 3399 3400# Register old name 3401@register_canonicalize("local_incsubtensor_of_allocs") 3402@register_stabilize("local_incsubtensor_of_allocs") 3403@gof.local_optimizer([IncSubtensor, 3404 AdvancedIncSubtensor, 3405 AdvancedIncSubtensor1]) 3406def local_incsubtensor_of_zeros(node): 3407 """ 3408 IncSubtensor(x, zeros, idx) -> x 3409 3410 """ 3411 if (isinstance(node.op, (IncSubtensor, 3412 AdvancedIncSubtensor, 3413 AdvancedIncSubtensor1)) and 3414 not node.op.set_instead_of_inc): 3415 x = node.inputs[0] 3416 y = node.inputs[1] 3417 try: 3418 # Don't use only_process_constants=True. We need to 3419 # investigate Alloc of 0s but with non constant shape. 3420 if get_scalar_constant_value(y, elemwise=False) == 0: 3421 # No need to copy over the stacktrace, 3422 # because x should already have a stacktrace 3423 return [x] 3424 except NotScalarConstantError: 3425 return 3426 3427 3428@register_canonicalize 3429@register_specialize 3430@gof.local_optimizer([IncSubtensor]) 3431def local_incsubtensor_of_zeros_to_setsubtensor(node): 3432 """ 3433 IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...) 3434 """ 3435 if (isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc): 3436 x = node.inputs[0] 3437 3438 if isinstance(x, T.Constant) and not np.any(x.data): 3439 return [IncSubtensor(node.op.idx_list, 3440 node.op.inplace, 3441 set_instead_of_inc=True, 3442 destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased, 3443 )(*node.inputs)] 3444 3445 3446@register_canonicalize('local_setsubtensor_of_allocs') 3447@register_stabilize('local_setsubtensor_of_allocs') 3448@gof.local_optimizer([IncSubtensor]) 3449def local_setsubtensor_of_constants(node): 3450 """ 3451 SetSubtensor(x, x[idx], idx) -> x 3452 3453 when x is constant or alloc. 3454 3455 """ 3456 if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: 3457 x = node.inputs[0] 3458 y = node.inputs[1] 3459 3460 # Don't use only_process_constants=True. We need to 3461 # investigate Alloc of 0s but with non constant shape. 
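        # Illustrative sketch of the case handled here (not from the
        # original comments): if both x and y are filled with the same
        # scalar constant, e.g.
        #     x = T.zeros((5, 5)); y = T.zeros((2, 5))
        #     z = T.set_subtensor(x[:2], y)
        # then writing y into x changes nothing, so z can be replaced by x.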
3462 try: 3463 replace_x = get_scalar_constant_value(x, elemwise=False) 3464 except NotScalarConstantError: 3465 return 3466 3467 try: 3468 replace_y = get_scalar_constant_value(y, elemwise=False) 3469 except NotScalarConstantError: 3470 return 3471 3472 if replace_x == replace_y: 3473 3474 # No need to copy over the stacktrace, 3475 # because x should already have a stacktrace 3476 return [x] 3477 else: 3478 return False 3479 3480 3481@register_canonicalize 3482@register_stabilize 3483@gof.local_optimizer([AdvancedSubtensor1]) 3484def local_adv_sub1_adv_inc_sub1(node): 3485 """Optimize the possible AdvSub1(AdvSetSub1(...), ...). 3486 3487 AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y 3488 3489 Notes 3490 ----- 3491 This opt add AssertOp. Otherwise, it would remove shape and 3492 index error. If you want to get rid of them, see the 3493 :ref:`unsafe_optimization` section. 3494 3495 WARNING: 3496 A previous version of this optimization also matched 3497 AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y 3498 This is incorrect when there are duplicate indices. 3499 The current version warns the user about potential past issues. 3500 3501 """ 3502 if not isinstance(node.op, AdvancedSubtensor1): 3503 return 3504 inp = node.inputs[0] 3505 if (not inp.owner or 3506 not isinstance(inp.owner.op, AdvancedIncSubtensor1)): 3507 return 3508 idx = node.inputs[1] 3509 idx2 = inp.owner.inputs[2] 3510 x = inp.owner.inputs[0] 3511 y = inp.owner.inputs[1] 3512 if idx is not idx2: 3513 return 3514 if (not inp.owner.op.set_instead_of_inc and 3515 # Don't use only_process_constants=True. We need to 3516 # investigate Alloc of 0s but with non constant shape. 3517 T.extract_constant(x, elemwise=False) != 0): 3518 return 3519 3520 if not inp.owner.op.set_instead_of_inc: 3521 if config.warn.inc_subtensor1_opt: 3522 warnings.warn( 3523 'Your current code is fine, but Theano versions ' 3524 'between 0.7rc1 and 0.10 (or development versions ' 3525 'between Nov. 2014 and May 2017) ' 3526 'might have given incorrect results. This graph has ' 3527 'following pattern: inc_subtensor(zeros[idx], x)[idx], ' 3528 'where idx is an array of integers. This used to be ' 3529 'optimized to "x", which is incorrect if there are ' 3530 'duplicated indices in idx. ' 3531 'To disable this warning, set the Theano flag ' 3532 'warn.inc_subtensor1_opt to False.') 3533 return 3534 3535 cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))] 3536 if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0): 3537 cond.append(T.eq(idx.shape[0], y.shape[0])) 3538 r = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 " 3539 "that was optimized away")(y, *cond) 3540 copy_stack_trace(y, r) 3541 3542 if r.dtype == node.outputs[0].dtype: 3543 return [r] 3544 # It is possible that y is upcast or downcast to x.dtype. 3545 # In all case, as we set or add with 0, we can just cast y. 3546 r2 = T.cast(r, node.outputs[0].dtype) 3547 3548 # Copy over stacktrace from before casting, since 3549 # we don't expect problems in the casting operation, 3550 # and any problems in the indexing would have been spotted above. 
3551 copy_stack_trace(r, r2) 3552 return [r2] 3553 3554 3555@register_specialize 3556@register_stabilize 3557@register_canonicalize 3558@register_useless 3559@gof.local_optimizer([IncSubtensor, 3560 AdvancedIncSubtensor, 3561 AdvancedIncSubtensor1]) 3562def local_useless_inc_subtensor_alloc(node): 3563 """ 3564 Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of 3565 a fully or partially broadcastable variable, by one that skips the 3566 intermediate `alloc` where possible. 3567 3568 """ 3569 if isinstance(node.op, (IncSubtensor, 3570 AdvancedIncSubtensor, 3571 AdvancedIncSubtensor1)): 3572 x = node.inputs[0] 3573 y = node.inputs[1] 3574 i = node.inputs[2:] 3575 3576 if y.owner is not None and isinstance(y.owner.op, T.Alloc): 3577 # `z` is the input of the Alloc op, i.e. T.alloc(z, <shape>) 3578 z = y.owner.inputs[0] 3579 3580 try: 3581 shape_feature = node.fgraph.shape_feature 3582 except AttributeError: 3583 # The shape feature may not be available in some mode, but we 3584 # need it for this optimization, so don't continue. 3585 return False 3586 3587 shape_of = shape_feature.shape_of 3588 same_shape = shape_feature.same_shape 3589 3590 # Get the subtensor of `x` indexed by `i` in order to compare 3591 # shapes later. 3592 if isinstance(node.op, IncSubtensor): 3593 xi = Subtensor(node.op.idx_list)(x, *i) 3594 elif isinstance(node.op, AdvancedIncSubtensor): 3595 xi = advanced_subtensor(x, *i) 3596 elif isinstance(node.op, AdvancedIncSubtensor1): 3597 xi = advanced_subtensor1(x, *i) 3598 else: 3599 raise Exception('Should never happen!') 3600 3601 reason = 'local_useless_incsubtensor_alloc' 3602 3603 # Add `xi` to the shape feature `fgraph`. This is important for 3604 # shape inference later because the variable must be part of the 3605 # function graph in order to call `same_shape` on it. 3606 if xi not in shape_of: 3607 shape_feature.on_import(node.fgraph, xi.owner, 3608 '%s: add `xi`' % reason) 3609 3610 # `xi` may have more dimensions than `y` since the subtensor ops 3611 # do automatic broadcasting of the increment internally. Thus, we 3612 # need to make the leading implicitly broadcasted dimensions 3613 # explicit for shape comparison later. 3614 if xi.ndim > y.ndim: 3615 y = T.shape_padleft(y, xi.ndim - y.ndim) 3616 if y not in shape_of: 3617 shape_feature.on_import(node.fgraph, y.owner, 3618 '%s: add `y`' % reason) 3619 3620 # Build `z_broad` explicitly to include extra implicit dimensions. 3621 z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable) 3622 3623 cond = [ 3624 # The shapes of `y` and `xi` must either agree or `y` may 3625 # also have shape equal to 1 which may be treated as a 3626 # broadcastable dimension by the subtensor op. 3627 T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k])) 3628 # Loop over all dimensions. 
3629 for k in xrange(xi.ndim) 3630 # We need to check the above shapes, if 3631 # * the pre-alloc increment `z` is broadcastable in 3632 # dimension `k` (if it isn't, then the shapes of `z` and 3633 # `y` are the same by the definition of the `Alloc` op in 3634 # this dimension and replacing `y` by `z` will not hide a 3635 # shape error), and 3636 # * `xi` and `y` do not have the same shape in dimension 3637 # `k` or we cannot infer the shape statically (if the 3638 # shapes of `xi` and `y` are not the same, then replacing 3639 # `y` by `z` will hide the shape error of `y`), and 3640 # * the shape of `y` is not equal to 1 or we cannot infer 3641 # the shape statically (if the shape of `y` is equal to 3642 # 1, then `y` is broadcasted by the inc_subtensor op 3643 # internally, so the shapes of `xi` and `y` do not need 3644 # to match in dimension `k`; else we need to check at 3645 # runtime that the shape of `y` is either 1 or the same 3646 # as `xi` or otherwise replacing `y` by `z` will hide a 3647 # shape error). 3648 if (z_broad[k] and 3649 not same_shape(xi, y, dim_x=k, dim_y=k) and 3650 shape_of[y][k] != 1)] 3651 3652 if len(cond) > 0: 3653 msg = '`x[i]` and `y` do not have the same shape.' 3654 z = Assert(msg)(z, *cond) 3655 3656 r = node.op(x, z, *i) 3657 # Copy over stacktrace from previous output, since 3658 # we don't expect problems when removing the intermediate 3659 # alloc operation and so we still want to point at the line 3660 # of the inc_subtensor operation. 3661 copy_stack_trace(node.outputs, r) 3662 3663 return [r] 3664 3665 3666#################### 3667# Rebroadcast opts # 3668#################### 3669 3670@register_useless 3671@register_canonicalize 3672@register_specialize 3673@gof.local_optimizer([T.Rebroadcast]) 3674def local_useless_rebroadcast(node): 3675 """ 3676 Remove Rebroadcast if id does not actually change the broadcasting pattern. 3677 3678 """ 3679 if isinstance(node.op, T.Rebroadcast): 3680 x = node.inputs[0] 3681 if np.all(x.broadcastable == node.outputs[0].broadcastable): 3682 # No broadcastable flag was modified 3683 # No need to copy over stack trace, 3684 # because x should already have a stack trace. 3685 return [x] 3686 else: 3687 # Keep the flags that modify something 3688 new_axis = {} 3689 for dim, bc in list(node.op.axis.items()): 3690 if x.broadcastable[dim] != bc: 3691 new_axis[dim] = bc 3692 if new_axis == node.op.axis: 3693 # All flags are useful 3694 return 3695 else: 3696 r = T.Rebroadcast(*list(new_axis.items()))(x) 3697 # Copy over stacktrace from previous output 3698 copy_stack_trace(node.outputs, r) 3699 return [r] 3700 3701 3702@register_canonicalize 3703@register_specialize 3704@gof.local_optimizer([T.Rebroadcast]) 3705def local_rebroadcast_lift(node): 3706 """ 3707 Lifts Rebroadcast through unary Elemwise operations, 3708 and merges consecutive Rebroadcasts. 3709 3710 Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x)) 3711 Rebroadcast(Rebroadcast(x)) => Rebroadcast(x) 3712 3713 """ 3714 op = node.op 3715 if not isinstance(op, T.Rebroadcast): 3716 return False 3717 3718 input = node.inputs[0] 3719 inode = input.owner 3720 if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: 3721 # It may happen that `input` has no client because this optimization 3722 # is called from `apply_rebroadcast_opt`, which in particular is used 3723 # by the `unbroadcast` function before we are in the actual function 3724 # compilation phase. 
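        # Added note: the lift is only applied when this Rebroadcast is the
        # sole client of the Elemwise output. Otherwise rewriting, say,
        # Rebroadcast(exp(x)) into exp(Rebroadcast(x)) would leave the
        # original exp(x) in place for its other clients and so duplicate
        # the elemwise computation.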
3725 if hasattr(input, 'clients') and len(input.clients) == 1: 3726 rebroadcasted = T.Rebroadcast(*list(op.axis.items()))( 3727 inode.inputs[0]) 3728 # Copy over stacktrace from previous output (after rebroadcasting) 3729 # to new output, because an error in the new graph right after 3730 # rebroadcasting must have been caused by the previous rebroadcasting. 3731 copy_stack_trace(node.outputs, rebroadcasted) 3732 3733 rval = inode.op.make_node(rebroadcasted).outputs 3734 3735 # Copy over stacktrace from previous output (after rebroadcasting) 3736 # and input (after elemwise operation) to new output, because an 3737 # error in the new graph could have been caused by either of the 3738 # two ops. 3739 copy_stack_trace(node.outputs + node.inputs, rval) 3740 3741 return rval 3742 if inode and isinstance(inode.op, T.Rebroadcast): 3743 # the "axis" specification in the outer Rebroadcast overrides 3744 # the axis of the inner one 3745 axis = inode.op.axis.copy() 3746 axis.update(op.axis) 3747 iinput = inode.inputs[0] 3748 3749 rval = [T.Rebroadcast(*list(axis.items()))(iinput)] 3750 3751 # Copy over stacktrace from previous output (after second rebroadcast) 3752 # and from previous input (after first rebroadcast op) because an error in 3753 # the new graph could have been caused by either of the two 3754 # rebroadcast ops. 3755 copy_stack_trace(node.outputs + node.inputs, rval) 3756 return rval 3757 3758 3759def apply_rebroadcast_opt(rval): 3760 """ 3761 Apply as many times as required the optimization local_useless_rebroadcast 3762 and local_rebroadcast_lift. 3763 3764 Parameters 3765 ---------- 3766 rval: a Variable 3767 3768 Returns 3769 ------- 3770 A Variable (the same if no optimization can be applied) 3771 3772 """ 3773 3774 changed = True 3775 while changed and rval.owner: 3776 changed = False 3777 rval2 = theano.tensor.opt.local_useless_rebroadcast.transform( 3778 rval.owner) 3779 if rval2: 3780 assert len(rval2) == 1 3781 rval = rval2[0] 3782 changed = True 3783 if rval.owner: 3784 rval2 = theano.tensor.opt.local_rebroadcast_lift.transform( 3785 rval.owner) 3786 if rval2: 3787 assert len(rval2) == 1 3788 rval = rval2[0] 3789 changed = True 3790 return rval 3791 3792 3793############# 3794# Join opts # 3795############# 3796@register_specialize 3797@register_canonicalize 3798@register_useless 3799@gof.local_optimizer([T.Join]) 3800def local_join_1(node): 3801 """Join(i, x) => x 3802 3803 Remove Join() when only one element is joined. 3804 3805 """ 3806 if not isinstance(node.op, T.Join): 3807 return 3808 tensors = node.inputs[1:] 3809 if len(tensors) == 1: 3810 # We don't need to copy over any stacktrace here, because the 3811 # input variable should already have its own stacktrace. 3812 return [tensors[0]] 3813 3814 3815# TODO: merge in local_useless_join 3816@register_useless 3817@register_specialize 3818@register_canonicalize 3819@gof.local_optimizer([T.Join]) 3820def local_join_empty(node): 3821 """Join(i, x, y, empty) => Join(i, x, y) 3822 3823 Remove empty inputs to joins. The empty inputs can be anywhere. 3824 3825 """ 3826 if not isinstance(node.op, T.Join): 3827 return 3828 new_inputs = [] 3829 try: 3830 join_idx = get_scalar_constant_value(node.inputs[0], 3831 only_process_constants=True) 3832 except NotScalarConstantError: 3833 return 3834 for idx in xrange(1, len(node.inputs)): 3835 inp = node.inputs[idx] 3836 # We can not use size == 0,, as this can change shape from 3,0 3837 # to 2,0. This trigger DebugMode error. 
This happen with 3838 # stack(...,[]) as this add a dimshuffle on [], that add a 3839 # dimensions with shape 1. 3840 if isinstance(inp, theano.Constant) and inp.data.shape[join_idx] == 0: 3841 continue 3842 new_inputs.append(inp) 3843 if len(new_inputs) < len(node.inputs) - 1: 3844 if len(new_inputs) == 0: 3845 # T.join do not work in that case. 3846 # constant folding will take care of this case. 3847 return 3848 ret = T.join(node.inputs[0], *new_inputs) 3849 o = node.outputs[0] 3850 if ret.dtype != o.dtype: 3851 # Join can upcast some inputs 3852 return 3853 3854 # Copy over stacktrace from previous output (after join op) 3855 # to new output, because an error in the new op must be caused 3856 # by an error in the old join op. 3857 copy_stack_trace(node.outputs, ret) 3858 3859 if ret.type != o.type: 3860 assert ret.dtype == o.dtype 3861 assert ret.ndim == o.ndim 3862 ret = T.patternbroadcast(ret, node.outputs[0].broadcastable) 3863 3864 # Copy over stacktrace from previous output 3865 # (after patternbroadcast op) for same reasons as before. 3866 copy_stack_trace(node.outputs, ret) 3867 3868 return [ret] 3869 3870 3871@register_specialize 3872@register_canonicalize 3873@register_useless 3874@gof.local_optimizer([T.Join]) 3875def local_join_make_vector(node): 3876 """Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) 3877 3878 Merge MakeVector inputs to Join. This can make the join completly 3879 disapear with the local_join_1 opt. 3880 3881 """ 3882 if not isinstance(node.op, T.Join) or node.outputs[0].ndim != 1: 3883 return 3884 new_inputs = [node.inputs[1]] 3885 for idx in xrange(2, len(node.inputs)): 3886 inp = node.inputs[idx] 3887 if (inp.owner and 3888 isinstance(inp.owner.op, MakeVector) and 3889 new_inputs[-1].owner and 3890 isinstance(new_inputs[-1].owner.op, MakeVector) and 3891 # MakeVector have a dtype parameter 3892 inp.owner.op == new_inputs[-1].owner.op): 3893 inps = new_inputs[-1].owner.inputs + inp.owner.inputs 3894 new_inputs[-1] = inp.owner.op(*inps) 3895 3896 # Copy over stacktrace from previous output (after join op) 3897 # to new intermediate output, because an error in the intermediate 3898 # op must be caused by an error in the old join op. 3899 copy_stack_trace(node.outputs, new_inputs[-1]) 3900 else: 3901 new_inputs.append(inp) 3902 if len(new_inputs) < len(node.inputs) - 1: 3903 ret = T.join(node.inputs[0], *new_inputs) 3904 3905 # Copy over stacktrace from previous output (after join op) 3906 # to new output, because an error in the new op must be caused 3907 # by an error in the old join op. 3908 copy_stack_trace(node.outputs, ret) 3909 return [ret] 3910 3911 3912################# 3913# speed/memory # 3914################# 3915@register_canonicalize 3916@register_specialize 3917@gof.local_optimizer([T.elemwise.Sum]) 3918def local_sumsqr2dot(node): 3919 """ 3920 This optimization detects T.sqr( W.dimshuffle('x',0,1) * G.dimshuffle(0,'x',1) ).sum(axis=(1,2)) 3921 and converts this to T.dot(T.sqr(G), T.sqr(W).sum(axis=0)). 
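    Shape sketch (added illustration): if W has shape (n, d) and G has shape
    (m, d), both expressions give a length-m vector whose i-th entry is the
    sum over j, k of (W[j, k] * G[i, k]) ** 2. The rewritten form computes it
    with a single matrix-vector dot instead of reducing an (m, n, d)
    elementwise product.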
3922 """ 3923 if (isinstance(node.op, T.elemwise.Sum) and 3924 isinstance(node.op.scalar_op, theano.scalar.basic.Add) and node.op.axis == (1, 2)): 3925 in1 = node.inputs[0] 3926 out = node.outputs[0] 3927 3928 if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Sqr)): 3929 in_sqr = in1.owner.inputs[0] 3930 if (in_sqr.owner and isinstance(in_sqr.owner.op, T.Elemwise) and 3931 isinstance(in_sqr.owner.op.scalar_op, theano.scalar.basic.Mul) and len(in_sqr.owner.inputs) == 2): 3932 in_mul1, in_mul2 = in_sqr.owner.inputs 3933 3934 if (isinstance(in_mul1.owner.op, T.elemwise.DimShuffle) and in_mul1.owner.op.new_order == ('x', 0, 1) and 3935 isinstance(in_mul2.owner.op, T.elemwise.DimShuffle) and in_mul2.owner.op.new_order == (0, 'x', 1)): 3936 W = in_mul1.owner.inputs[0] 3937 G = in_mul2.owner.inputs[0] 3938 3939 new_out = T.dot(T.sqr(G), T.sqr(W).sum(axis=0)) 3940 if new_out.dtype != out.dtype: 3941 new_out = T.cast(new_out, dtype=out.dtype) 3942 return [new_out] 3943 3944 3945################# 3946# Exp stability # 3947################# 3948@register_stabilize 3949@register_specialize 3950@register_canonicalize 3951@gof.local_optimizer([T.Elemwise]) 3952def local_expm1(node): 3953 """ 3954 This optimization detects exp(a)-1 and converts this to expm1(a). 3955 """ 3956 if (isinstance(node.op, T.Elemwise) and 3957 isinstance(node.op.scalar_op, theano.scalar.basic.Sub)): 3958 in1, in2 = node.inputs 3959 out = node.outputs[0] 3960 3961 if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and 3962 T.extract_constant(in2, only_process_constants=False) == 1): 3963 in11 = in1.owner.inputs[0] 3964 new_out = T.expm1(in11) 3965 3966 if new_out.dtype != out.dtype: 3967 new_out = T.cast(new_out, dtype=out.dtype) 3968 if new_out.type != out.type: 3969 return 3970 return [new_out] 3971 3972 3973############### 3974# Switch opts # 3975############### 3976@register_useless('local_remove_switch_const_cond') 3977@register_canonicalize('fast_compile', 'local_remove_switch_const_cond') 3978@register_specialize 3979@gof.local_optimizer([T.Elemwise]) 3980def local_useless_switch(node): 3981 """ 3982 This optimization makes the following changes in the graph: 3983 T.switch(cond,left,right) --> 3984 if cond is constant and cond == 0: right 3985 if cond is constant and cond != 0: left 3986 if left is right -> left 3987 3988 T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) 3989 """ 3990 if (isinstance(node.op, T.Elemwise) and 3991 isinstance(node.op.scalar_op, scalar.basic.Switch)): 3992 cond = T.extract_constant(node.inputs[0], 3993 only_process_constants=True) 3994 if ((type(cond) is np.ndarray and cond.ndim == 0) or 3995 isinstance(cond, np.number)): 3996 if cond == 0: 3997 correct_out = node.inputs[2] 3998 else: 3999 correct_out = node.inputs[1] 4000 4001 if correct_out.ndim != node.outputs[0].ndim: 4002 # TODO: broadcast? 4003 return False 4004 if correct_out.dtype != node.outputs[0].dtype: 4005 out = T.cast(correct_out, node.outputs[0].dtype) 4006 else: 4007 out = correct_out 4008 4009 if out.type.broadcastable != node.outputs[0].type.broadcastable: 4010 # We need to copy data to the new dimensions during execution 4011 4012 # We should not depend on node.outputs as this would 4013 # make the new node depend on the old one that will 4014 # get optimized again. So this create a cycle. 
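                # The loop below rebuilds the broadcasted output shape one
                # dimension at a time from the inputs only: where the two
                # broadcast patterns already agree it keeps out.shape[idx],
                # otherwise it borrows the length from whichever switch
                # branch (node.inputs[1] or node.inputs[2]) is not
                # broadcastable in that dimension, and finally allocates
                # with T.alloc(out, *shps).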
4015 shps = [] 4016 for idx, (b1, b2), in enumerate(zip(out.type.broadcastable, 4017 node.outputs[0].type.broadcastable)): 4018 if b1 == b2: 4019 shps.append(out.shape[idx]) 4020 elif not node.inputs[1].type.broadcastable[idx]: 4021 shps.append(node.inputs[1].shape[idx]) 4022 else: 4023 shps.append(node.inputs[2].shape[idx]) 4024 out = T.alloc(out, *shps) 4025 else: 4026 out = out 4027 4028 # Copy over stacktrace from selected output to new output 4029 copy_stack_trace(node.outputs + correct_out, out) 4030 return [out] 4031 # if left is right -> left 4032 if node.inputs[1] is node.inputs[2]: 4033 # Note: No need to copy over stacktrace, because the input node 4034 # already has its own stacktrace 4035 if cond.type == node.inputs[1].type: 4036 return [node.inputs[1]] 4037 4038 ret = T.fill(cond, node.inputs[1]) 4039 4040 # Copy over stacktrace from switch output and correct branch 4041 copy_stack_trace(node.outputs + node.inputs[1], ret) 4042 return [ret] 4043 4044 # This case happens with scan. 4045 # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) 4046 left = node.inputs[1] 4047 right = node.inputs[2] 4048 cond_var = node.inputs[0] 4049 if cond_var.owner and \ 4050 isinstance(cond_var.owner.op, T.Elemwise) and \ 4051 isinstance(cond_var.owner.op.scalar_op, scalar.LE) and \ 4052 cond_var.owner.inputs[0].owner and \ 4053 isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and \ 4054 T.extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 and \ 4055 T.extract_constant(left, only_process_constants=True) == 0 and \ 4056 right is cond_var.owner.inputs[0]: 4057 assert right.type == node.outputs[0].type 4058 # No need to copy over stacktrace, because the right input node 4059 # already has its own stacktrace 4060 return [right] 4061 return False 4062 return False 4063 4064 4065@register_specialize 4066@register_canonicalize 4067@gof.local_optimizer([T.mul]) 4068def local_mul_switch_sink(node): 4069 """ 4070 This optimization makes the following changes in the graph: 4071 T.mul(A,T.switch(cond,0,iff),B) --> T.switch(cond,0,T.mul(A,B,iff)) 4072 T.mul(A,T.switch(cond,ift,0),B) --> T.switch(cond,T.mul(A,B,ift),0) 4073 A and B being several (or none) symbolic variables. 4074 This is useful because A and B may not be numerically stable and give 4075 NaN or inf values for cases where the switch returns 0. 4076 With this optimization T.grad(T.switch(...)) has the right behavior. 4077 4078 Examples 4079 -------- 4080 x -> f(x) 4081 x -> g(x) 4082 y = T.switch(cond,f(x),g(x)) 4083 **without the optimization 4084 T.grad(y,x) -> grad(f(x),x) * grad(y,f(x)) + grad(g(x),x) * grad(y,g(x)) 4085 **with the optimization 4086 T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x)) 4087 This will be particularly useful for the lazyif because we skip 4088 an entire part of the graph. 4089 4090 """ 4091 if node.op != T.mul: 4092 return False 4093 for idx, i in enumerate(node.inputs): 4094 if i.owner and i.owner.op == T.switch: 4095 switch = i.owner 4096 try: 4097 if (get_scalar_constant_value( 4098 switch.inputs[1], only_process_constants=True) == 0.): 4099 listmul = node.inputs[:idx] + node.inputs[idx + 1:] 4100 fmul = T.mul(*(listmul + [switch.inputs[2]])) 4101 4102 # Copy over stacktrace for elementwise multiplication op 4103 # from previous elementwise multiplication op. 4104 # An error in the multiplication (e.g. errors due to 4105 # inconsistent shapes), will point to the 4106 # multiplication op. 
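                    # (Here switch.inputs[1] is known to be the constant 0,
                    # so the product is 0 wherever the condition is true and
                    # only the other branch needs the full multiplication;
                    # hence the switch(cond, 0, fmul) constructed just below.)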
4107 copy_stack_trace(node.outputs, fmul) 4108 4109 fct = [T.switch(switch.inputs[0], 0, 4110 fmul)] 4111 fct[0].tag.values_eq_approx = values_eq_approx_remove_nan 4112 4113 # Copy over stacktrace for switch op from both previous 4114 # elementwise multiplication op and previous switch op, 4115 # because an error in this part can be caused by either 4116 # of the two previous ops. 4117 copy_stack_trace(node.outputs + switch.outputs, fct) 4118 return fct 4119 except NotScalarConstantError: 4120 pass 4121 try: 4122 if (get_scalar_constant_value( 4123 switch.inputs[2], only_process_constants=True) == 0.): 4124 listmul = node.inputs[:idx] + node.inputs[idx + 1:] 4125 fmul = T.mul(*(listmul + [switch.inputs[1]])) 4126 # Copy over stacktrace for elementwise multiplication op 4127 # from previous elementwise multiplication op. 4128 # An error in the multiplication (e.g. errors due to 4129 # inconsistent shapes), will point to the 4130 # multiplication op. 4131 copy_stack_trace(node.outputs, fmul) 4132 4133 fct = [T.switch(switch.inputs[0], 4134 fmul, 0)] 4135 fct[0].tag.values_eq_approx = values_eq_approx_remove_nan 4136 4137 # Copy over stacktrace for switch op from both previous 4138 # elementwise multiplication op and previous switch op, 4139 # because an error in this part can be caused by either 4140 # of the two previous ops. 4141 copy_stack_trace(node.outputs + switch.outputs, fct) 4142 return fct 4143 except NotScalarConstantError: 4144 pass 4145 return False 4146 4147 4148@register_canonicalize 4149@gof.local_optimizer([T.true_div, T.int_div]) 4150def local_div_switch_sink(node): 4151 """ 4152 This optimization makes the following changes in the graph: 4153 T.div(T.switch(cond,0,iff),A) --> T.switch(cond,0,T.div(iff,A)) 4154 T.div(T.switch(cond,ift,0),A) --> T.switch(cond,T.div(ift,A),0) 4155 4156 A being a symbolic variable. 4157 This is useful because A may not be numerically stable and give 4158 NaN or inf values for cases where the switch returns 0. 4159 See local_mul_switch_sink for more details. 4160 4161 """ 4162 if (node.op != T.true_div and node.op != T.int_div): 4163 return False 4164 op = node.op 4165 if node.inputs[0].owner and node.inputs[0].owner.op == T.switch: 4166 switch = node.inputs[0].owner 4167 try: 4168 if get_scalar_constant_value(switch.inputs[1], 4169 only_process_constants=True) == 0.: 4170 fdiv = op(switch.inputs[2], node.inputs[1]) 4171 # Copy over stacktrace for elementwise division op 4172 # from previous elementwise multiplication op. 4173 # An error in the division (e.g. errors due to 4174 # inconsistent shapes or division by zero), 4175 # will point to the new division op. 4176 copy_stack_trace(node.outputs, fdiv) 4177 4178 fct = [T.switch(switch.inputs[0], 0, 4179 fdiv)] 4180 fct[0].tag.values_eq_approx = values_eq_approx_remove_nan 4181 4182 # Copy over stacktrace for switch op from both previous 4183 # elementwise division op and previous switch op, 4184 # because an error in this part can be caused by either 4185 # of the two previous ops. 4186 copy_stack_trace(node.outputs + switch.outputs, fct) 4187 return fct 4188 except NotScalarConstantError: 4189 pass 4190 try: 4191 if get_scalar_constant_value(switch.inputs[2], 4192 only_process_constants=True) == 0.: 4193 fdiv = op(switch.inputs[1], node.inputs[1]) 4194 # Copy over stacktrace for elementwise division op 4195 # from previous elementwise multiplication op. 4196 # An error in the division (e.g. 
errors due to 4197 # inconsistent shapes or division by zero), 4198 # will point to the new division op. 4199 copy_stack_trace(node.outputs, fdiv) 4200 4201 fct = [T.switch(switch.inputs[0], 4202 fdiv, 0)] 4203 fct[0].tag.values_eq_approx = values_eq_approx_remove_nan 4204 4205 # Copy over stacktrace for switch op from both previous 4206 # elementwise division op and previous switch op, 4207 # because an error in this part can be caused by either 4208 # of the two previous ops. 4209 copy_stack_trace(node.outputs + switch.outputs, fct) 4210 return fct 4211 except NotScalarConstantError: 4212 pass 4213 return False 4214 4215 4216# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same 4217# condition, to enable further simplification of their branches 4218# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) 4219@register_canonicalize 4220@gof.local_optimizer([T.Elemwise]) 4221def local_merge_switch_same_cond(node): 4222 scal = theano.scalar 4223 # node must be binary elemwise or add or mul 4224 if not isinstance(node.op, T.Elemwise) or not isinstance( 4225 node.op.scalar_op, (scal.BinaryScalarOp, scal.Add, scal.Mul)): 4226 return 4227 # all inputs must be switch 4228 if not all(s.owner and isinstance(s.owner.op, T.Elemwise) and 4229 isinstance(s.owner.op.scalar_op, scal.Switch) 4230 for s in node.inputs): 4231 return 4232 # all switch conditions must be the same 4233 cond = node.inputs[0].owner.inputs[0] 4234 if not all(s.owner.inputs[0] is cond for s in node.inputs[1:]): 4235 return 4236 # pull out switch 4237 return [T.switch(cond, 4238 node.op(*[s.owner.inputs[1] for s in node.inputs]), 4239 node.op(*[s.owner.inputs[2] for s in node.inputs]))] 4240 4241 4242############# 4243# Tile Opts # 4244############# 4245@register_useless 4246@register_canonicalize 4247@register_stabilize 4248@gof.local_optimizer([T.Tile]) 4249def local_useless_tile(node): 4250 """Tile(x, (1,)*N) -> x 4251 4252 This is useless tile. (1,)*N, just mean a vector with all element 4253 being 1. 4254 4255 """ 4256 if isinstance(node.op, T.Tile): 4257 try: 4258 a = T.get_scalar_constant_value(node.inputs[1], 4259 only_process_constants=True) 4260 if a == 1: 4261 try: 4262 l = T.get_vector_length(node.inputs[1]) 4263 if l == node.inputs[0].ndim: 4264 # No need to copy over any stacktrace as previous 4265 # input variable already has a stacktrace 4266 return [node.inputs[0]] 4267 elif l < node.inputs[0].ndim: 4268 # The Op don't support that case, so we can't 4269 # implement the opt and test it. 4270 return 4271 return [node.inputs[0]] 4272 else: 4273 # The Op don't support that case, so we can't 4274 # implement the opt and test it. 4275 return 4276 x_nd = node.inputs[0].ndim 4277 broad = ['x'] * (l - x_nd) + xrange(x_nd) 4278 ret = node.inputs[0].dimshuffle(broad) 4279 # Copy over stacktrace from previous output node, 4280 # and from node before tiling operation. 4281 copy_stack_trace(node.outputs + node.inputs[0], ret) 4282 return [ret] 4283 except ValueError: 4284 return 4285 except NotScalarConstantError: 4286 return 4287 4288 4289############## 4290# Split Opts # 4291############## 4292@register_useless 4293@register_canonicalize 4294@register_specialize 4295@gof.local_optimizer([T.Split]) 4296def local_useless_split(node): 4297 """ Split{n_splits=1}(x, y) -> x 4298 4299 Remove Split with only 1 split. 
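    For example (added illustration), T.split(x, splits, n_splits=1) yields a
    single output that is just x; the rewrite returns x guarded by asserts
    that splits has a single entry and that this entry equals x.shape[axis].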
4300 4301 """ 4302 if isinstance(node.op, T.Split): 4303 if node.op.len_splits == 1: 4304 x, axis, splits = node.inputs 4305 out = assert_op(x, T.eq(splits.shape[0], 1)) 4306 # Copy over stacktrace from previous output node. 4307 copy_stack_trace(node.outputs, out) 4308 out2 = assert_op(out, T.eq(x.shape[axis], splits[0])) 4309 # Copy over stacktrace from previous output node. 4310 copy_stack_trace(out, out2) 4311 4312 return [out2] 4313 4314 4315################ 4316# Flatten Opts # 4317################ 4318@register_canonicalize 4319@register_stabilize 4320@gof.local_optimizer([T.Flatten]) 4321def local_flatten_lift(node): 4322 """ 4323 Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x)) 4324 4325 This optimization is needed by optimization 4326 nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten. 4327 4328 """ 4329 if (isinstance(node.op, T.Flatten) and 4330 node.inputs[0].owner and 4331 isinstance(node.inputs[0].owner.op, T.Elemwise) and 4332 len(node.inputs[0].owner.inputs) == 1): 4333 f = node.op(node.inputs[0].owner.inputs[0]) 4334 4335 # Copy over stacktrace from previous output node (flatten op), 4336 # since this is the op which may cause an error for f. 4337 copy_stack_trace(node.outputs, f) 4338 4339 e = node.inputs[0].owner.op(f) 4340 4341 # Copy over stacktrace from previous output node and from unary 4342 # elementwise output node since if there was an error, it would 4343 # probably have come from that operation. 4344 copy_stack_trace(node.outputs + [node.inputs[0]], e) 4345 4346 return [e] 4347 4348################## 4349# Reshape opts # 4350################## 4351 4352 4353def local_reshape_chain(op): 4354 @gof.local_optimizer([op]) 4355 def f(node): 4356 """ 4357 Reshape(Reshape(shape1),shape2) -> Reshape(shape2) 4358 4359 """ 4360 if not opt.check_chain(node, op, op): 4361 return False 4362 4363 # TODO: this can permit a failing program to run by eliminating 4364 # the lower reshape 4365 rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) 4366 4367 # Copy over stacktrace from previous output node, as any error 4368 # in new computational graph would have been caused by last op 4369 # in the old computational graph. 4370 copy_stack_trace(node.outputs, rval) 4371 4372 # It might happen that the desired output of this node has a 4373 # broadcastable pattern that does not match that of 'rval'. This is 4374 # when originally, we were able to figure out that one of the 4375 # dimensions of the reshape is one, but some other transformation 4376 # replaced the shape by one for which this cannot be guessed. 4377 # We should try to figure out why we lost the information about this 4378 # constant value... but in the meantime, better not apply this 4379 # optimization. 4380 if rval.broadcastable == node.outputs[0].broadcastable: 4381 return [rval] 4382 else: 4383 return False 4384 4385 return f 4386register_canonicalize(local_reshape_chain(T.Reshape), 4387 name='local_reshape_chain') 4388 4389 4390@register_useless 4391@register_canonicalize 4392@register_stabilize 4393@gof.local_optimizer([T.Reshape]) 4394def local_useless_reshape(node): 4395 """ 4396 Remove two kinds of useless reshape. 4397 4398 Remove Reshape when both the input and output have a single dimension. 4399 Remove Reshape when reshaping to the shape of the input. 
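    For example (added illustrations):
    - reshape(x, x.shape) --> x
    - reshape(x, [x.shape[0], ..., x.shape[-1]]) --> x, also accepting a
      constant 1 for a broadcastable dimension and at most one -1 entry.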
4400 4401 """ 4402 op = node.op 4403 if not isinstance(op, Reshape): 4404 return False 4405 4406 input = node.inputs[0] 4407 output = node.outputs[0] 4408 output_shape = node.inputs[1] 4409 4410 if input.ndim != output.ndim: 4411 return False 4412 4413 # Simple case: both input and output have a single dimension. 4414 # This could hide errors if the user provides inconsistent shapes. 4415 if (input.ndim == 1 and output.ndim == 1 and 4416 input.broadcastable == output.broadcastable): 4417 return [input] 4418 4419 # Second case: all the shapes match the input shape 4420 # Match Reshape(x, x.shape) 4421 if output_shape.owner and isinstance(output_shape.owner.op, Shape): 4422 shape_input = output_shape.owner.inputs[0] 4423 if shape_input == input: 4424 return [input] 4425 4426 # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for 4427 # broadcastable and constant dimensions 4428 if output_shape.owner and isinstance(output_shape.owner.op, MakeVector): 4429 output_shape_is = output_shape.owner.inputs 4430 4431 if not hasattr(node, 'fgraph'): 4432 shape_feature = None 4433 else: 4434 shape_feature = getattr(node.fgraph, 'shape_feature', None) 4435 4436 nb_m1 = 0 4437 shape_match = [False] * input.ndim 4438 for dim in xrange(input.ndim): 4439 outshp_i = output_shape_is[dim] 4440 # Match Shape_i{dim}(input) 4441 if (outshp_i.owner and isinstance(outshp_i.owner.op, Shape_i) and 4442 outshp_i.owner.op.i == dim and 4443 outshp_i.owner.inputs[0] == input): 4444 shape_match[dim] = True 4445 continue 4446 4447 # Match Shape(input)[dim] 4448 if (outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and 4449 len(outshp_i.owner.inputs) == 2 and 4450 extract_constant(outshp_i.owner.inputs[1]) == dim): 4451 subtensor_inp = outshp_i.owner.inputs[0] 4452 if (subtensor_inp.owner and 4453 isinstance(subtensor_inp.owner.op, Shape)): 4454 shape_input_i = subtensor_inp.owner.inputs[0] 4455 if shape_input_i == input: 4456 shape_match[dim] = True 4457 continue 4458 4459 # Match 1 if input.broadcastable[dim] is True 4460 cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) 4461 if input.broadcastable[dim] and cst_outshp_i == 1: 4462 shape_match[dim] = True 4463 continue 4464 4465 # Match -1 4466 if cst_outshp_i == -1: 4467 shape_match[dim] = True 4468 nb_m1 += 1 4469 continue 4470 4471 # Match shape_of[input][dim] or its constant equivalent 4472 if shape_feature: 4473 inpshp_i = shape_feature.get_shape(input, dim) 4474 if (inpshp_i == outshp_i or 4475 (extract_constant(inpshp_i, only_process_constants=1) == 4476 extract_constant(outshp_i, only_process_constants=1))): 4477 shape_match[dim] = True 4478 continue 4479 4480 if all(shape_match) and nb_m1 <= 1: 4481 return [input] 4482 4483 # TODO later: if all the shapes except one match, we may want to 4484 # consider it useless as well, like we do in the 1-dim case. 4485 4486 4487@register_canonicalize 4488@gof.local_optimizer([T.Reshape]) 4489def local_reshape_to_dimshuffle(node): 4490 """ 4491 Broadcastable dimensions in Reshape are replaced with dimshuffle. 4492 4493 The goal is to avoid using reshape to add or remove broadcastable 4494 dimensions, but use dimshuffle instead, so dimshuffles can cancel out 4495 or be removed later on. 
4496 4497 For example: 4498 - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,)) 4499 - reshape(x, (1, m, 1, n, 1, 1)) 4500 --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n))) 4501 """ 4502 op = node.op 4503 if not isinstance(op, Reshape): 4504 return False 4505 4506 input = node.inputs[0] 4507 output = node.outputs[0] 4508 output_shape = node.inputs[1] 4509 4510 dimshuffle_new_order = [] 4511 new_output_shape = [] 4512 index = 0 # index over the output of the new reshape 4513 for i in xrange(output.ndim): 4514 # Since output_shape is a symbolic vector, we trust extract_constant 4515 # to go through however it is formed to see if its i-th element is 1. 4516 # We need only_process_constants=False for that. 4517 dim = extract_constant(output_shape[i], only_process_constants=False, 4518 elemwise=False) 4519 if dim == 1: 4520 dimshuffle_new_order.append('x') 4521 else: 4522 dimshuffle_new_order.append(index) 4523 new_output_shape.append(dim) 4524 index = index + 1 4525 if index != output.ndim: 4526 inner = op.__class__(len(new_output_shape))(input, new_output_shape) 4527 copy_stack_trace(output, inner) 4528 new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] 4529 copy_stack_trace(output, new_node) 4530 return new_node 4531 4532 4533@register_canonicalize 4534@register_stabilize 4535@gof.local_optimizer([T.Reshape]) 4536def local_reshape_lift(node): 4537 """ 4538 Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) 4539 4540 This optimization is needed by optimization 4541 nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape. 4542 4543 """ 4544 if (isinstance(node.op, T.Reshape) and 4545 node.inputs[0].owner and 4546 isinstance(node.inputs[0].owner.op, T.Elemwise) and 4547 len(node.inputs[0].owner.inputs) == 1): 4548 r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) 4549 # Copy stacktrace from previous Reshape op, as an error in new 4550 # Reshape op could only have been caused by old one. 4551 copy_stack_trace(node.outputs, r) 4552 4553 e = node.inputs[0].owner.op(r) 4554 # Copy stacktrace from both previous Reshape and UnaryElemwise op 4555 # because an error in new cg could have been caused by either ops. 4556 copy_stack_trace(node.outputs + node.inputs, e) 4557 4558 # In rare case the original broadcast was (False, True), but 4559 # the new one is (False, False). So don't crash in that case. 4560 if e.type != node.outputs[0].type: 4561 re = T.patternbroadcast(e, node.outputs[0].broadcastable) 4562 4563 # Copy over stack trace. 4564 # If the graph fails it is usually due to the fact that a dimension 4565 # that should be broadcastable does not actually have length 1, 4566 copy_stack_trace(e, re) 4567 else: 4568 re = e 4569 4570 return [re] 4571 4572 4573################## 4574# Middleman cuts # 4575################## 4576 4577register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy') 4578 4579################ 4580# Canonization # 4581################ 4582 4583 4584class Canonizer(gof.LocalOptimizer): 4585 r""" 4586 Simplification tool. The variable is a local_optimizer. It is best used 4587 with a TopoOptimizer in in_to_out order. 4588 4589 Usage: Canonizer(main, inverse, reciprocal, calculate) 4590 4591 Parameters 4592 ---------- 4593 main 4594 A suitable Op class that is commutative, associative and 4595 takes one to an arbitrary number of inputs, e.g. add or 4596 mul 4597 inverse 4598 An Op class such that inverse(main(x, y), y) == x 4599 e.g. 
sub or true_div 4600 reciprocal 4601 A function such that main(x, reciprocal(y)) == inverse(x, y) 4602 e.g. neg or inv 4603 calculate 4604 Function that takes a list of numpy.ndarray instances 4605 for the numerator, another list for the denumerator, 4606 and calculates inverse(main(\*num), main(\*denum)). It 4607 takes a keyword argument, aslist. If True, the value 4608 should be returned as a list of one element, unless 4609 the value is such that value = main(). In that case, 4610 the return value should be an empty list. 4611 4612 Examples 4613 -------- 4614 >>> import theano.tensor as T 4615 >>> from theano.tensor.opt import Canonizer 4616 >>> add_canonizer = Canonizer(T.add, T.sub, T.neg, \ 4617 ... lambda n, d: sum(n) - sum(d)) 4618 >>> mul_canonizer = Canonizer(T.mul, T.true_div, T.inv, \ 4619 ... lambda n, d: prod(n) / prod(d)) 4620 4621 Examples of optimizations mul_canonizer can perform: 4622 4623 | x / x -> 1 4624 | (x * y) / x -> y 4625 | x / y / x -> 1 / y 4626 | x / y / z -> x / (y * z) 4627 | x / (y / z) -> (x * z) / y 4628 | (a / b) * (b / c) * (c / d) -> a / d 4629 | (2.0 * x) / (4.0 * y) -> (0.5 * x) / y 4630 | 2 * x / 2 -> x 4631 | x * y * z -> Elemwise(T.mul){x,y,z} #only one pass over the memory. 4632 | !-> Elemwise(T.mul){x,Elemwise(T.mul){y,z}} 4633 4634 """ 4635 4636 def __init__(self, main, inverse, reciprocal, calculate, 4637 use_reciprocal=True): 4638 self.main = main 4639 self.inverse = inverse 4640 self.reciprocal = reciprocal 4641 self.calculate = calculate 4642 self.use_reciprocal = use_reciprocal 4643 4644 self.external_simplifiers = [] 4645 4646 def add_simplifier(self, simplifier, reason): 4647 self.external_simplifiers.append((reason, simplifier)) 4648 4649 def tracks(self): 4650 return [self.main, self.inverse, self.reciprocal] 4651 4652 def get_num_denum(self, input): 4653 r""" 4654 This extract two lists, num and denum, such that the input is: 4655 self.inverse(self.main(\*num), self.main(\*denum)). It returns 4656 the two lists in a (num, denum) pair. 4657 4658 For example, for main, inverse and reciprocal = \*, / and inv(), 4659 4660 | input -> returned value (num, denum) 4661 4662 | x*y -> ([x, y], []) 4663 | inv(x) -> ([], [x]) 4664 | inv(x) * inv(y) -> ([], [x, y]) 4665 | x*y/z -> ([x, y], [z]) 4666 | log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y]) 4667 | (((a / b) * c) / d) -> ([a, c], [b, d]) 4668 | a / (b / c) -> ([a, c], [b]) 4669 | log(x) -> ([log(x)], []) 4670 | x**y -> ([x**y], []) 4671 | x * y * z -> ([x, y, z], []) 4672 4673 """ 4674 # This function is recursive. The idea is that there is a 4675 # get_num_denum recursion in which the internal ops are all 4676 # one of (main, inverse, reciprocal, DimShuffle) and the 4677 # internal data nodes all have the dtype of the 'input' 4678 # argument. The leaf-Variables of the graph covered by the 4679 # recursion may be of any Variable type. 
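        # For instance, with main/inverse/reciprocal = mul/true_div/inv, a
        # graph like (a / b) * (c / d) unfolds to ([a, c], [b, d]); the
        # recursion bottoms out at any variable that is not produced by one
        # of those three ops (or by a left-padding DimShuffle, handled below).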
4680 4681 if input.owner is None or input.owner.op not in [ 4682 self.main, self.inverse, self.reciprocal]: 4683 if input.owner and isinstance(input.owner.op, T.DimShuffle): 4684 # If input is a DimShuffle of some input which does 4685 # something like this: 4686 4687 # * change a vector of length N into a 1xN row matrix 4688 # * change a scalar into a 1x1x1 tensor 4689 # * in general, complete the shape of a tensor 4690 # with broadcastable 1s to the *left* 4691 # Then we will simply discard the DimShuffle and return 4692 # the num/denum of its input 4693 dsn = input.owner # dimshuffle node 4694 dsop = dsn.op # dimshuffle op 4695 4696 # the first input of the dimshuffle i.e. the ndarray to redim 4697 dsi0 = dsn.inputs[0] 4698 4699 # The compatible order is a DimShuffle "new_order" of the form: 4700 # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim) 4701 4702 # That kind of DimShuffle only adds broadcastable 4703 # dimensions on the left, without discarding any 4704 # existing broadcastable dimension and is inserted 4705 # automatically by Elemwise when the inputs have 4706 # different numbers of dimensions (hence why we can 4707 # discard its information - we know we can retrieve it 4708 # later on). 4709 compatible_order = (('x',) * 4710 (input.type.ndim - dsi0.type.ndim) + 4711 tuple(range(dsi0.type.ndim))) 4712 if dsop.new_order == compatible_order: 4713 # If the "new_order" is the one we recognize, 4714 # we return the num_denum of the dimshuffled input. 4715 return self.get_num_denum(input.owner.inputs[0]) 4716 else: 4717 # This is when the input isn't produced by main, 4718 # inverse or reciprocal. 4719 return [input], [] 4720 else: 4721 return [input], [] 4722 num = [] 4723 denum = [] 4724 parent = input.owner 4725 4726 # We get the (num, denum) pairs for each input 4727 # pairs = [self.get_num_denum(input2) if input2.type.dtype == 4728 # input.type.dtype else ([input2], []) for input2 in 4729 # parent.inputs] 4730 pairs = [self.get_num_denum(input2) for input2 in parent.inputs] 4731 4732 if parent.op == self.main: 4733 # If we have main(x, y, ...), numx, denumx, numy, denumy, ... 4734 # then num is concat(numx, numy, num...) and denum is 4735 # concat(denumx, denumy, denum...) note that main() can have any 4736 # number of arguments >= 0 concat is list concatenation 4737 num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) 4738 denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) 4739 elif parent.op == self.inverse: 4740 # If we have inverse(x, y), numx, denumx, numy and denumy 4741 # then num is concat(numx, denumy) and denum is 4742 # concat(denumx, numy) note that inverse() is binary 4743 num = pairs[0][0] + pairs[1][1] 4744 denum = pairs[0][1] + pairs[1][0] 4745 elif parent.op == self.reciprocal: 4746 # If we have reciprocal(x), numx, denumx 4747 # then num is denumx and denum is numx 4748 # note that reciprocal() is unary 4749 num = pairs[0][1] 4750 denum = pairs[0][0] 4751 return num, denum 4752 4753 def merge_num_denum(self, num, denum): 4754 r""" 4755 Utility function which takes two lists, num and denum, and 4756 returns something which is equivalent to inverse(main(\*num), 4757 main(\*denum)), but depends on the length of num and the length 4758 of denum (in order to minimize the number of operations). 
4759 4760 Let n = len(num) and d = len(denum): 4761 4762 | n=0, d=0: neutral element (given by self.calculate([], [])) 4763 | (for example, this would be 0 if main is addition 4764 | and 1 if main is multiplication) 4765 | n=1, d=0: num[0] 4766 | n=0, d=1: reciprocal(denum[0]) 4767 | n=1, d=1: inverse(num[0], denum[0]) 4768 | n=0, d>1: reciprocal(main(\*denum)) 4769 | n>1, d=0: main(\*num) 4770 | n=1, d>1: inverse(num[0], main(\*denum)) 4771 | n>1, d=1: inverse(main(\*num), denum[0]) 4772 | n>1, d>1: inverse(main(\*num), main(\*denum)) 4773 4774 Given the values of n and d to which they are associated, all 4775 of the above are equivalent to: 4776 inverse(main(\*num), main(\*denum)) 4777 4778 """ 4779 4780 ln, ld = len(num), len(denum) 4781 if not ln and not ld: 4782 return T.as_tensor_variable(self.calculate([], [])) 4783 if not ln: 4784 if self.use_reciprocal: 4785 return self.reciprocal(self.merge_num_denum(denum, [])) 4786 else: 4787 ln = [self.calculate([], [], aslist=False)] 4788 if not ld: 4789 if ln == 1: 4790 # num[0] should always be a variable 4791 assert isinstance(num[0], gof.Variable) 4792 return num[0] 4793 else: 4794 return self.main(*num) 4795 return self.inverse(self.merge_num_denum(num, []), 4796 self.merge_num_denum(denum, [])) 4797 4798 @staticmethod 4799 def get_constant(v): 4800 """ 4801 4802 Returns 4803 ------- 4804 object 4805 A numeric constant if v is a Constant or, well, a 4806 numeric constant. If v is a plain Variable, returns None. 4807 4808 """ 4809 if isinstance(v, Constant): 4810 if getattr(v.tag, 'unique_value', None) is not None: 4811 data = v.tag.unique_value 4812 else: 4813 data = v.data 4814 if data.ndim == 0: 4815 return data 4816 else: 4817 return None 4818 elif isinstance(v, Variable): 4819 return None 4820 else: 4821 return v 4822 4823 def simplify(self, num, denum, out_type): 4824 """ 4825 Shorthand for: 4826 4827 .. code-block:: python 4828 4829 self.simplify_constants(*self.simplify_factors(num, denum)) 4830 4831 """ 4832 rval = self.simplify_constants(*self.simplify_factors(num, denum), 4833 out_type=out_type) 4834 for reason, simplifier in self.external_simplifiers: 4835 # TODO: document that 'reason' is associated with this 4836 # simplification to help auditing when things go 4837 # wrong 4838 rval = simplifier(*rval) 4839 return rval 4840 4841 def simplify_factors(self, num, denum): 4842 """ 4843 For any Variable r which is both in num and denum, removes it 4844 from both lists. Modifies the lists inplace. Returns the 4845 modified lists. For example: 4846 4847 | [x], [x] -> [], [] 4848 | [x, y], [x] -> [y], [] 4849 | [a, b], [c, d] -> [a, b], [c, d] 4850 4851 """ 4852 ln = len(num) 4853 ld = len(denum) 4854 if (ld > 2 and ln > 2): 4855 # Faster version for "big" inputs. 4856 while True: 4857 s = set(num) 4858 # Inputs can appear multiple times 4859 redo = len(s) != len(num) 4860 inter = s.intersection(denum) 4861 for v in inter: 4862 num.remove(v) 4863 denum.remove(v) 4864 if not redo or not inter: 4865 break 4866 else: 4867 for v in list(num): 4868 if v in denum: 4869 num.remove(v) 4870 denum.remove(v) 4871 return num, denum 4872 4873 def simplify_constants(self, orig_num, orig_denum, out_type=None): 4874 """ 4875 Find all constants and put them together into a single constant. 4876 4877 Finds all constants in orig_num and orig_denum (using 4878 get_constant) and puts them together into a single 4879 constant. The constant is inserted as the first element of the 4880 numerator. 
If the constant is the neutral element, it is 4881 removed from the numerator. 4882 4883 Examples 4884 -------- 4885 Let main be multiplication: 4886 4887 | [2, 3, x], [] -> [6, x], [] 4888 | [x, y, 2], [4, z] -> [0.5, x, y], [z] 4889 | [x, 2, y], [z, 2] -> [x, y], [z] 4890 4891 """ 4892 # Lists representing the numerator and denumerator 4893 num, denum = [], [] 4894 4895 # Lists representing the *constant* elements of num and denum 4896 numct, denumct = [], [] 4897 4898 for v in orig_num: 4899 ct = self.get_constant(v) 4900 if ct is not None: 4901 # We found a constant in the numerator! 4902 # We add it to numct 4903 numct.append(ct) 4904 else: 4905 num.append(v) 4906 for v in orig_denum: 4907 ct = self.get_constant(v) 4908 if ct is not None: 4909 denumct.append(ct) 4910 else: 4911 denum.append(v) 4912 4913 if self.use_reciprocal or num: 4914 # This will calculate either: 4915 # [inverse(main(*numct), main(*denumct))] 4916 # [] - if inverse(main(*numct), main(*denumct)) is the 4917 # neutral element 4918 ct = self.calculate(numct, denumct, aslist=True, 4919 out_type=out_type) 4920 else: 4921 # This happens if we don't allow the reciprocal and the 4922 # numerator is empty. That means we will need to represent 4923 # reciprocal(x) like inverse(neutral_element, x) so 4924 # we can't allow ct == [] 4925 # TODO: why is this branch needed when merge_num_denum 4926 # does it for us? 4927 ct = [self.calculate(numct, denumct, aslist=False, 4928 out_type=out_type)] 4929 4930 # Wrapping ct in a Constant with the right dtype 4931 ct = [T.constant(c, dtype=out_type.dtype) for c in ct] 4932 4933 if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: 4934 # In that case we should only have one constant in `ct`. 4935 assert len(ct) == 1 4936 first_num_ct = self.get_constant(orig_num[0]) 4937 if first_num_ct is not None and ct[0].type.values_eq(ct[0].data, 4938 first_num_ct): 4939 # This is an important trick :( if it so happens that: 4940 # * there's exactly one constant on the numerator and none on 4941 # the denominator 4942 # * it's not the neutral element (ct is an empty list in that 4943 # case) 4944 # * the constant is the same as the first argument in the 4945 # numerator (we only check the first argument because the 4946 # canonizer puts the computed constants first) 4947 # -> then we return very exactly the original num/denum. 4948 # If we don't do that the optimizer will just loop 4949 # infinitely because it will not catch on that there are 4950 # no changes to be made and every time it will want to 4951 # replace something by the same thing... 4952 # Note that it is important to use `values_eq` instead of 4953 # the == operator, to handle NaN values correctly. 4954 return orig_num, orig_denum 4955 4956 return ct + num, denum 4957 4958 def transform(self, node): 4959 op = node.op 4960 if op not in [self.main, self.inverse, self.reciprocal]: 4961 return False 4962 4963 assert len(node.outputs) == 1 4964 out = node.outputs[0] 4965 4966 # out won't have a clients field when we didn't commit a 4967 # started change in the graph. We can't do the check if we 4968 # want to skip it, so we force the skip it. It should be 4969 # reapplied later. 4970 if not hasattr(out, 'clients'): 4971 return 4972 4973 # check if any of the clients of this node would be part of 4974 # this canonized graph... if so, we do nothing and wait for 4975 # them to be transformed. 
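        # In other words, canonicalization is only triggered from the topmost
        # node of a chain of main/inverse/reciprocal ops: for a graph such as
        # (x * y) / z we skip the mul node and let the true_div node pull in
        # the whole numerator/denominator through get_num_denum in one pass.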
4976 for c, c_idx in out.clients: 4977 if c == 'output': 4978 continue 4979 while (isinstance(getattr(c, 'op', None), DimShuffle) and 4980 len(c.outputs[0].clients) <= 1): 4981 c = c.outputs[0].clients[0][0] 4982 if getattr(c, 'op', '') in [self.main, self.inverse, 4983 self.reciprocal]: 4984 return False 4985 4986 # Here we make the canonical version of the graph around this node 4987 # See the documentation of get_num_denum and simplify 4988 orig_num, orig_denum = self.get_num_denum(node.outputs[0]) 4989 num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) 4990 4991 def same(x, y): 4992 return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in 4993 zip(x, y)) 4994 4995 if same(orig_num, num) and same(orig_denum, denum): 4996 # We return False if there are no changes 4997 return False 4998 4999 new = self.merge_num_denum(num, denum) 5000 if new.type.dtype != out.type.dtype: 5001 new = T.cast(new, out.type.dtype) 5002 5003 assert (new.type == out.type) == (not (new.type != out.type)) 5004 5005 if not (new.type == out.type): 5006 new = _fill_chain(new, node.inputs)[0] 5007 5008 if new.type == out.type: 5009 # This happen with test 5010 # theano/tensor/tests/test_opt.py:T_local_switch_sink 5011 new.tag.values_eq_approx = values_eq_approx_remove_inf_nan 5012 5013 # We need to implement the copy over of the stacktrace. 5014 # See issue #5104. 5015 return [new] 5016 else: 5017 _logger.warning(' '.join(('CANONIZE FAILED: new, out = ', 5018 new, ',', out, 'types', 5019 new.type, ',', out.type))) 5020 return False 5021 5022 def __str__(self): 5023 return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % ( 5024 self.main, self.inverse, self.reciprocal)) 5025 5026 5027def mul_calculate(num, denum, aslist=False, out_type=None): 5028 if not num and not denum: 5029 # Smallest 1 possible. 5030 if aslist: 5031 return [] 5032 else: 5033 return np.int8(1) 5034 5035 # Make sure we do not accidentally upcast data types. 5036 if out_type is None: 5037 out_dtype = scalar.upcast(*[v.dtype for v in (num + denum)]) 5038 else: 5039 out_dtype = out_type.dtype 5040 one = theano._asarray(1, dtype=out_dtype) 5041 5042 v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one) 5043 if aslist: 5044 if np.all(v == 1): 5045 return [] 5046 else: 5047 return [v] 5048 return v 5049 5050local_mul_canonizer = Canonizer(T.mul, T.true_div, T.inv, mul_calculate, False) 5051register_canonicalize(local_mul_canonizer, name='local_mul_canonizer') 5052 5053 5054@gof.local_optimizer([T.neg]) 5055def local_neg_to_mul(node): 5056 if node.op == T.neg: 5057 return [T.mul(np.array(-1, dtype=node.inputs[0].dtype), 5058 node.inputs[0])] 5059register_canonicalize(local_neg_to_mul) 5060 5061 5062@register_specialize 5063@gof.local_optimizer([T.Sum, T.elemwise.Prod]) 5064def local_sum_prod_mul_by_scalar(node): 5065 """ 5066 sum(scalar * smth) -> scalar * sum(smth) 5067 sum(-smth) -> -sum(smth) 5068 5069 or 5070 5071 prod(scalar * smth) -> scalar ** size(smth) * prod(smth) 5072 prod(-smth) -> -1 ** size(smth) * prod(smth) 5073 5074 """ 5075 # TODO: if the the thing inside the Sum is a division, 5076 # we should get at the numerator.... 
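    # Illustrative instances of the rewrites described in the docstring:
    #     sum(2. * x)  -> 2. * sum(x)
    #     prod(2. * x) -> (2. ** x.size) * prod(x)
    #     sum(-x)      -> -sum(x)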
5077 if isinstance(node.op, (T.Sum, T.elemwise.Prod)): 5078 node_inps, = node.inputs 5079 if node_inps.owner and node_inps.owner.op == T.mul: 5080 terms = node_inps.owner.inputs 5081 scalars = [t.dimshuffle() for t in terms if 5082 np.all(t.type.broadcastable)] 5083 5084 if len(scalars) == 0: 5085 # Nothing to optimize here 5086 return 5087 5088 non_scalars = [t for t in terms if not np.all(t.broadcastable)] 5089 5090 # Perform the op only on the non-scalar inputs, if applicable 5091 if len(non_scalars) == 0: 5092 new_op_input_nb_elements = 1 5093 new_op_output = 1 5094 elif len(non_scalars) == 1: 5095 new_op_input_nb_elements = non_scalars[0].size 5096 new_op_output = node.op(non_scalars[0]) 5097 else: 5098 new_op_input = T.mul(*non_scalars) 5099 # We assume that errors always come from the prod/mul op in the 5100 # original computational graph, and therefore need to only 5101 # copy over its output stacktrace. 5102 copy_stack_trace(node.outputs, new_op_input) 5103 5104 new_op_input_nb_elements = new_op_input.size 5105 new_op_output = node.op(new_op_input) 5106 5107 if not len(non_scalars) == 0: 5108 # Copy over stacktrace from previous output to new mul op, 5109 # for same reason as above. 5110 copy_stack_trace(node.outputs, new_op_output) 5111 5112 # If node.op is a T.elemwise.Prod, then the scalars need to be 5113 # raised to the power of the number of elements in the input 5114 # to the Prod 5115 if (isinstance(node.op, T.elemwise.Prod) and 5116 new_op_input_nb_elements != 1): 5117 5118 scalars = [s ** new_op_input_nb_elements for s in scalars] 5119 5120 # Scale the output of the op by the scalars and return as 5121 # replacement for the original output 5122 mul_inputs = scalars 5123 if new_op_input_nb_elements != 1: 5124 mul_inputs.append(new_op_output) 5125 5126 if len(mul_inputs) == 1: 5127 # Copy over stacktrace from previous output to new mul op, 5128 # for same reason as above. 5129 copy_stack_trace(node.outputs, mul_inputs) 5130 5131 return mul_inputs 5132 else: 5133 ret = T.mul(*mul_inputs) 5134 # Copy over stacktrace from previous output to new mul op, 5135 # for same reason as above. 5136 copy_stack_trace(node.outputs, [ret] + mul_inputs) 5137 5138 return [ret] 5139 5140 if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg: 5141 s = node.op(node_inps.owner.inputs[0]) 5142 ret = T.neg(s) 5143 # There are never errors in the negative op, thus 5144 # we need only to copy over stacktrace from previous output node to 5145 # the two new ops. 5146 copy_stack_trace(node.outputs, [s, ret]) 5147 5148 return [ret] 5149 5150 5151@register_specialize 5152@gof.local_optimizer([T.Elemwise]) 5153def local_elemwise_sub_zeros(node): 5154 """ 5155 Elemwise{sub}(X,X) -> zeros_like(X) 5156 """ 5157 if (isinstance(node.op, T.Elemwise) and 5158 node.op.scalar_op.nin == 2 and 5159 node.op.scalar_op == scalar.sub and 5160 node.inputs[0] == node.inputs[1]): 5161 res = T.zeros_like(node.inputs[0]) 5162 # Copy over stacktrace from previous output. 5163 # This could help for failures due to out-of-memory. 5164 copy_stack_trace(node.outputs, res) 5165 return [res] 5166 5167 5168@register_useless 5169@register_specialize 5170@register_stabilize 5171@register_canonicalize 5172@gof.local_optimizer([T.Elemwise]) 5173def local_useless_elemwise_comparison(node): 5174 """... 5175 5176 :note: These cases appear in the graph generated by scan. 5177 These optimizations will make the graph easier to read. 
5178 # Comparing to itself is constant 5179 Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) 5180 Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) 5181 Elemwise[{minimum,maximum}](X, X) -> X 5182 5183 # Comparing shape to 0 can be constant 5184 Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) 5185 Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) 5186 Elemwise[maximum](X.shape[i], 0) -> X.shape[i] 5187 Elemwise[maximum](0, X.shape[i]) -> X.shape[i] 5188 Elemwise[minimum](X.shape[i], 0) -> 0 5189 Elemwise[minimum](0, X.shape[i]) -> 0 5190 5191 # The shape can be replaced with sum of shapes 5192 Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) 5193 Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) 5194 5195 # Shapes are never negative 5196 # Needed by Reshape.infer_shape 5197 Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X) 5198 5199 """ 5200 if not isinstance(node.op, T.Elemwise): 5201 return 5202 if node.op.scalar_op.nin != 2: 5203 return 5204 5205 # We call zeros_like and one_like with opt=True to generate a 5206 # cleaner graph. 5207 dtype = node.outputs[0].dtype 5208 5209 # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) 5210 if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \ 5211 node.inputs[0] is node.inputs[1]: 5212 res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 5213 # Copy over stacktrace from previous output. 5214 copy_stack_trace(node.outputs, res) 5215 return [res] 5216 # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) 5217 if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \ 5218 node.inputs[0] is node.inputs[1]: 5219 res = T.ones_like(node.inputs[0], dtype=dtype, opt=True) 5220 5221 # Copy over stacktrace from previous output. 5222 copy_stack_trace(node.outputs, res) 5223 return [res] 5224 # Elemwise[{minimum,maximum}](X, X) -> X 5225 if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \ 5226 node.inputs[0] is node.inputs[1]: 5227 res = node.inputs[0] 5228 # Copy over stacktrace from previous output. 5229 copy_stack_trace(node.outputs, res) 5230 return [res] 5231 5232 # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) 5233 if isinstance(node.op.scalar_op, scalar.LT) and \ 5234 node.inputs[0].owner and \ 5235 isinstance(node.inputs[0].owner.op, Shape_i) and \ 5236 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5237 res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 5238 # Copy over stacktrace from previous output. 5239 copy_stack_trace(node.outputs, res) 5240 return [res] 5241 # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) 5242 if isinstance(node.op.scalar_op, scalar.GE) and \ 5243 node.inputs[0].owner and \ 5244 isinstance(node.inputs[0].owner.op, Shape_i) and \ 5245 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5246 res = T.ones_like(node.inputs[0], dtype=dtype, opt=True) 5247 # Copy over stacktrace from previous output. 5248 copy_stack_trace(node.outputs, res) 5249 return [res] 5250 # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] 5251 if isinstance(node.op.scalar_op, scalar.Maximum) and \ 5252 node.inputs[0].owner and \ 5253 isinstance(node.inputs[0].owner.op, Shape_i) and \ 5254 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5255 # No need to copy over stacktrace. 
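        # (the returned variable already exists in the graph, so it keeps
        # whatever stack trace it already carries)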
5256 return [node.inputs[0]] 5257 # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] 5258 if isinstance(node.op.scalar_op, scalar.Maximum) and \ 5259 T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ 5260 node.inputs[1].owner and \ 5261 isinstance(node.inputs[1].owner.op, Shape_i): 5262 # No need to copy over stacktrace. 5263 return [node.inputs[1]] 5264 # Elemwise[minimum](X.shape[i], 0) -> 0 5265 if isinstance(node.op.scalar_op, scalar.Minimum) and \ 5266 node.inputs[0].owner and \ 5267 isinstance(node.inputs[0].owner.op, Shape_i) and \ 5268 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5269 res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 5270 # Copy over stacktrace from previous output. 5271 copy_stack_trace(node.outputs, res) 5272 return [res] 5273 5274 # Elemwise[minimum](0, X.shape[i]) -> 0 5275 if isinstance(node.op.scalar_op, scalar.Minimum) and \ 5276 T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ 5277 node.inputs[1].owner and \ 5278 isinstance(node.inputs[1].owner.op, Shape_i): 5279 res = T.zeros_like(node.inputs[1], dtype=dtype, opt=True) 5280 # Copy over stacktrace from previous output. 5281 copy_stack_trace(node.outputs, res) 5282 return [res] 5283 5284 # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) 5285 if isinstance(node.op.scalar_op, scalar.LT) and \ 5286 node.inputs[0].owner and \ 5287 isinstance(node.inputs[0].owner.op, Elemwise) and \ 5288 isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ 5289 all([isinstance(var.owner and var.owner.op, Shape_i) 5290 for var in node.inputs[0].owner.inputs]) and \ 5291 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5292 res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 5293 # Copy over stacktrace from previous output. 5294 copy_stack_trace(node.outputs, res) 5295 return [res] 5296 # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) 5297 if isinstance(node.op.scalar_op, scalar.GE) and \ 5298 node.inputs[0].owner and \ 5299 isinstance(node.inputs[0].owner.op, Elemwise) and \ 5300 isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ 5301 all([isinstance(var.owner and var.owner.op, Shape_i) 5302 for var in node.inputs[0].owner.inputs]) and \ 5303 T.extract_constant(node.inputs[1], only_process_constants=True) == 0: 5304 res = T.ones_like(node.inputs[0], dtype=dtype, opt=True) 5305 5306 # Copy over stacktrace from previous output. 
5307 copy_stack_trace(node.outputs, res) 5308 return [res] 5309 5310 # Elemwise[EQ](Subtensor(Shape(x)), -N) 5311 # Elemwise[EQ](somegraph that only depend of shape, -N) 5312 # TODO: handle the case where the -N is on either side 5313 """ 5314 |Elemwise{eq,no_inplace} [id B] '' 5315 | |Subtensor{int64} [id C] '' 5316 | | |Join [id D] '' 5317 | | | |TensorConstant{0} [id E] 5318 | | | |Subtensor{int64:int64:} [id F] '' 5319 | | | | |Shape [id G] '' 5320 """ 5321 def investigate(node): 5322 " Return True if values will be shapes, so >= 0" 5323 if isinstance(node.op, (T.Shape, Shape_i)): 5324 return True 5325 elif isinstance(node.op, Subtensor) and node.inputs[0].owner: 5326 return investigate(node.inputs[0].owner) 5327 elif isinstance(node.op, T.Join): 5328 return all(v.owner and 5329 investigate(v.owner) for v in node.inputs[1:]) 5330 elif isinstance(node.op, MakeVector): 5331 return all(v.owner and 5332 investigate(v.owner) for v in node.inputs) 5333 5334 if (isinstance(node.op.scalar_op, scalar.EQ) and 5335 node.inputs[0].owner and 5336 investigate(node.inputs[0].owner)): 5337 try: 5338 cst = get_scalar_constant_value(node.inputs[1], 5339 only_process_constants=True) 5340 5341 res = T.zeros_like(node.inputs[0], dtype=dtype, opt=True) 5342 5343 if cst < 0: 5344 # Copy over stacktrace from previous output. 5345 copy_stack_trace(node.outputs, res) 5346 5347 return [res] 5348 5349 except NotScalarConstantError: 5350 pass 5351 return 5352 5353 5354@register_canonicalize 5355@register_specialize 5356@gof.local_optimizer([T.Sum, T.elemwise.Prod]) 5357def local_sum_prod_div_dimshuffle(node): 5358 """ 5359 sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, 5360 if dimension l of the DimShuffle is 'x' 5361 5362 or 5363 5364 prod(a / dimshuffle{...}(b), axis=l) -> 5365 prod(a, axis={...}) / b ** a.shape[l], 5366 if dimension l of the DimShuffle is 'x' 5367 """ 5368 5369 # It does not make much sense now to extend it to the case where the 5370 # dimshuffle is in the numerator, since elemwise inversion of the 5371 # denominator would still be needed before the summation or production. 5372 5373 if isinstance(node.op, (T.Sum, T.elemwise.Prod)): 5374 axis = node.op.axis 5375 if axis is None: 5376 axis = list(range(node.inputs[0].ndim)) 5377 node_input = node.inputs[0] 5378 if node_input.owner and node_input.owner.op == T.true_div: 5379 numerator, denominator = node_input.owner.inputs 5380 5381 # Old, bugged logic, reproduced here only to warn users 5382 if (config.warn.sum_div_dimshuffle_bug and 5383 isinstance(node.op, T.Sum) and 5384 numerator.owner and 5385 isinstance(numerator.owner.op, T.DimShuffle)): 5386 # Check compatibility 5387 new_order = numerator.owner.op.new_order 5388 compatible_dims = True 5389 for ax in axis: 5390 if len(new_order) <= ax or new_order[ax] != 'x': 5391 compatible_dims = False 5392 break 5393 5394 if compatible_dims: 5395 _logger.warn('WARNING: Your current code is fine, but' 5396 ' Theano versions between ' 5397 'rev. 3bd9b789f5e8 (2010-06-16) and' 5398 ' cfc6322e5ad4 (2010-08-03) would ' 5399 'have given an incorrect result. 
' 5400 'To disable this warning, set the Theano' 5401 ' flag warn.sum_div_dimshuffle_bug to' 5402 ' False.') 5403 5404 if denominator.owner and isinstance(denominator.owner.op, 5405 T.DimShuffle): 5406 dimshuffle_input = denominator.owner.inputs[0] 5407 dimshuffle_order = denominator.owner.op.new_order 5408 5409 compatible_dims = [] 5410 incompatible_dims = [] 5411 for ax in axis: 5412 if (ax < len(dimshuffle_order) and 5413 dimshuffle_order[ax] == 'x'): 5414 compatible_dims.append(ax) 5415 else: 5416 incompatible_dims.append(ax) 5417 reordered_incompatible_dims = [] 5418 for ic_ax in incompatible_dims: 5419 reordered_incompatible_dims.append( 5420 ic_ax - sum( 5421 [1 for c_ax in compatible_dims if c_ax < ic_ax])) 5422 5423 if len(compatible_dims) > 0: 5424 optimized_dimshuffle_order = list( 5425 ax for i, ax in enumerate(dimshuffle_order) 5426 if (i not in axis) or (ax != 'x')) 5427 5428 # Removing leading 'x' (since it will be done automatically) 5429 while (len(optimized_dimshuffle_order) > 0 and 5430 optimized_dimshuffle_order[0] == 'x'): 5431 del optimized_dimshuffle_order[0] 5432 5433 # if optimized_dimshuffle_order is sorted with 5434 # not 'x', then dimshuffle is useless. 5435 if all(i == e for i, e in 5436 enumerate(optimized_dimshuffle_order)): 5437 optimized_dimshuffle = dimshuffle_input 5438 else: 5439 optimized_dimshuffle = T.DimShuffle( 5440 dimshuffle_input.type.broadcastable, 5441 optimized_dimshuffle_order)(dimshuffle_input) 5442 5443 if (config.warn.sum_div_dimshuffle_bug and 5444 isinstance(node.op, T.Sum)): 5445 _logger.warn('WARNING: Your current code is fine,' 5446 ' but Theano versions between ' 5447 'rev. 3bd9b789f5e8 (2010-06-16) and' 5448 ' cfc6322e5ad4 (2010-08-03) would ' 5449 'have given an incorrect result. ' 5450 'To disable this warning, set the' 5451 ' Theano flag ' 5452 'warn.sum_div_dimshuffle_bug' 5453 ' to False.') 5454 5455 if isinstance(node.op, T.Sum): 5456 op_on_compatible_dims = T.sum( 5457 numerator, axis=compatible_dims) 5458 rval = T.true_div( 5459 op_on_compatible_dims, 5460 optimized_dimshuffle) 5461 if len(reordered_incompatible_dims) > 0: 5462 rval = T.sum(rval, 5463 axis=reordered_incompatible_dims) 5464 elif isinstance(node.op, T.elemwise.Prod): 5465 op_on_compatible_dims = T.prod( 5466 numerator, axis=compatible_dims) 5467 dtype = numerator.dtype 5468 rval = T.true_div( 5469 op_on_compatible_dims, 5470 (optimized_dimshuffle ** 5471 T.prod([numerator.shape[ax].astype(dtype) 5472 for ax in compatible_dims]))) 5473 if len(reordered_incompatible_dims) > 0: 5474 rval = T.prod(rval, 5475 axis=reordered_incompatible_dims) 5476 return [rval] 5477 5478 5479@register_canonicalize 5480@gof.local_optimizer([T.Sum, T.elemwise.Prod]) 5481def local_sum_prod_all_to_none(node): 5482 """ 5483 Sum{0,1,...N} -> Sum{} or 5484 Prod{0,1,...N} -> Prod{} 5485 5486 """ 5487 if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): 5488 opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod 5489 # if all the axes are named, then use None as a shorthand 5490 # this permits more merging 5491 if node.op.axis is None: 5492 return 5493 if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): 5494 return [opt_type(axis=None, dtype=node.op.dtype)(node.inputs[0])] 5495 5496 5497@register_canonicalize 5498@gof.local_optimizer([T.Sum, T.elemwise.Prod]) 5499def local_op_of_op(node): 5500 """ 5501 Prod(Prod()) -> single Prod() 5502 or 5503 Sum(Sum()) -> single Sum() 5504 5505 """ 5506 if isinstance(node.op, T.elemwise.Prod) or 
isinstance(node.op, T.Sum):
        opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
        node_inps, = node.inputs
        out_dtype = node.op.dtype
        # Only apply the rewrite when the inner reduction has a single
        # client, so that removing it cannot affect other computations.
        if len(node_inps.clients) == 1:
            if (node_inps.owner and
                    (isinstance(node_inps.owner.op, node.op.__class__))):

                # Check whether the inner or the outer reduction is done
                # over all axes, in which case the two can be merged into
                # a single reduction over everything.
                if node_inps.owner.op.axis is None or node.op.axis is None:
                    return [opt_type(None, dtype=out_dtype)(
                        node_inps.owner.inputs[0])]

                # figure out which axes were in the original sum
                newaxis = list(tuple(node_inps.owner.op.axis))
                for i in node.op.axis:
                    new_i = i
                    for ii in node_inps.owner.op.axis:
                        if new_i >= ii:
                            new_i += 1
                    assert new_i not in newaxis
                    newaxis.append(new_i)

                assert len(newaxis) == len(list(node_inps.owner.op.axis) +
                                           list(node.op.axis))

                # The old, bugged logic. We keep it here only to emit a
                # warning when the old code would have generated bad
                # results.
                alldims = list(range(node_inps.owner.inputs[0].type.ndim))
                alldims = [d for i, d in enumerate(alldims) if i
                           in node_inps.owner.op.axis]
                alldims = [d for i, d in enumerate(alldims)
                           if i in node.op.axis]
                newaxis_old = [i for i in
                               xrange(node_inps.owner.inputs[0].type.ndim)
                               if i not in alldims]

                if (theano.config.warn.sum_sum_bug and
                        newaxis != newaxis_old and
                        len(newaxis) == len(newaxis_old)):
                    _logger.warn(
                        "WARNING (YOUR CURRENT CODE IS FINE): Theano "
                        "versions between version 9923a40c7b7a and August "
                        "2nd, 2010 generated bugged code in this case. "
                        "This happens when there are two consecutive sums "
                        "in the graph and the intermediate sum is not "
                        "used elsewhere in the code. Some safeguard "
                        "removed some bad code, but not in all cases. You "
                        "are in one such case. To disable this warning "
                        "(that you can safely ignore since this bug has "
                        "been fixed) set the theano flag "
                        "`warn.sum_sum_bug` to False.")

                combined = opt_type(newaxis, dtype=out_dtype)
                return [combined(node_inps.owner.inputs[0])]


ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
              T.elemwise.Sum, T.elemwise.Prod,
              T.elemwise.ProdWithoutZeros]


@register_canonicalize
@register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
@gof.local_optimizer(ALL_REDUCE)
def local_reduce_join(node):
    """
    Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)

    Notes
    -----
    Supported scalar.op are Maximum and Minimum in some cases, and Add and
    Mul in all cases.

    Currently we must reduce on axis 0. It is probably extensible to the
    case where we join and reduce on the same set of axes.

    """
    if (isinstance(node.op, T.CAReduce) and
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, T.Join)):
        join = node.inputs[0].owner
        if T.extract_constant(join.inputs[0], only_process_constants=True) != 0:
            return

        if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
            # Support only 2 inputs for now
            if len(join.inputs) != 3:
                return
        elif not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul)):
            return
        elif len(join.inputs) <= 2:
            # This is a useless join that will get removed by another opt.
            return

        new_inp = []
        for inp in join.inputs[1:]:
            inp = inp.owner
            if not inp:
                return
            if (not isinstance(inp.op, DimShuffle) or
                    inp.op.new_order != ('x',) +
                    tuple(range(inp.inputs[0].ndim))):
                return
            new_inp.append(inp.inputs[0])
        ret = Elemwise(node.op.scalar_op)(*new_inp)

        if ret.dtype != node.outputs[0].dtype:
            # The reduction changes the dtype, so we cannot do the rewrite.
            return

        reduce_axis = node.op.axis
        if reduce_axis is None:
            reduce_axis = tuple(xrange(node.inputs[0].ndim))

        # This warning is emitted late to avoid raising spurious warnings.
        if len(reduce_axis) != 1 or 0 not in reduce_axis:
            if theano.config.warn.reduce_join:
                warnings.warn((
                    'Your current code is fine, but Theano versions '
                    'prior to 0.7 (or this development version Sept 2014) '
                    'might have given an incorrect result for this code. '
                    'To disable this warning, set the Theano flag '
                    'warn.reduce_join to False. The problem was an '
                    'optimization that modified the pattern '
                    '"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)" '
                    'but did not check the reduction axis. So if the '
                    'reduction axis was not 0, you got a wrong answer.'))
            return

        # The new check below is also done late to avoid emitting an extra
        # warning.
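        # Illustrative sketch (not executed): with vectors a and b, a graph
        # built as
        #     T.sum(T.join(0, a.dimshuffle('x', 0), b.dimshuffle('x', 0)),
        #           axis=0)
        # is expected to be rewritten here into Elemwise{add}(a, b),
        # i.e. a + b.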
5640 try: 5641 join_axis = get_scalar_constant_value(join.inputs[0], 5642 only_process_constants=True) 5643 5644 if join_axis != reduce_axis[0]: 5645 return 5646 except NotScalarConstantError: 5647 return 5648 5649 return [ret] 5650 5651 5652@register_canonicalize('fast_compile', 'local_cut_useless_reduce') 5653@register_useless('local_cut_useless_reduce') 5654@gof.local_optimizer(ALL_REDUCE) 5655def local_useless_reduce(node): 5656 """Sum(a, axis=[]) -> a """ 5657 if isinstance(node.op, T.CAReduce): 5658 summed, = node.inputs 5659 # if reduce were doing anything, the output ndim would be reduced 5660 if summed.type == node.outputs[0].type: 5661 return [summed] 5662 5663 5664@register_canonicalize 5665@register_uncanonicalize 5666@register_specialize 5667@gof.local_optimizer(ALL_REDUCE) 5668def local_reduce_broadcastable(node): 5669 """Remove reduction over broadcastable dimensions.""" 5670 if isinstance(node.op, T.CAReduce): 5671 reduced, = node.inputs 5672 odtype = node.outputs[0].dtype 5673 if node.op.axis is None: 5674 if all(reduced.broadcastable): 5675 return [reduced.dimshuffle().astype(odtype)] 5676 else: 5677 axis = list(node.op.axis) 5678 cuttable = [a for a in axis if reduced.broadcastable[a]] 5679 if cuttable: 5680 # -- we can remove some axes of summation, 5681 # which simplifies the codegen for sum, especially on GPU 5682 new_axis = [] 5683 pattern = [] 5684 ii = 0 5685 for p in xrange(reduced.ndim): 5686 if p not in cuttable: 5687 if p in axis: 5688 new_axis.append(ii) 5689 pattern.append(p) 5690 ii += 1 5691 new_reduced = reduced.dimshuffle(*pattern) 5692 if new_axis: 5693 if type(node.op) == theano.tensor.elemwise.CAReduce: 5694 # This happen for tensor.max(), tensor.min() 5695 new_op = node.op.__class__(node.op.scalar_op, 5696 axis=new_axis) 5697 else: 5698 new_op = node.op.__class__(axis=new_axis) 5699 return [new_op(new_reduced)] 5700 else: 5701 # -- in this case we can remove the reduction completely 5702 return [new_reduced.astype(odtype)] 5703 5704 5705@register_specialize 5706@gof.local_optimizer([T.Sum, T.elemwise.Prod]) 5707def local_opt_alloc(node): 5708 """ 5709 sum(alloc(constant,shapes...)) => constant*prod(shapes) 5710 or 5711 prod(alloc(constant,shapes...)) => constant**prod(shapes) 5712 5713 """ 5714 if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): 5715 node_inps, = node.inputs 5716 if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc): 5717 input = node_inps.owner.inputs[0] 5718 shapes = node_inps.owner.inputs[1:] 5719 try: 5720 val = get_scalar_constant_value(input, 5721 only_process_constants=True) 5722 assert val.size == 1 5723 val = val.reshape(1)[0] 5724 # check which type of op 5725 size = T.mul(*shapes) 5726 if input.dtype in ["float16", "float32"]: 5727 # shapes are ints and normally int64. 5728 # We don't want to have a float64 upcast 5729 # We don't want to downcast to float16 5730 # as we fear it could loose too much precision 5731 # that will be amplified by the mul/pow below. 5732 size = size.astype('float32') 5733 if (node.op.axis is None or 5734 node.op.axis == tuple(range(input.ndim))): 5735 if isinstance(node.op, T.Sum): 5736 val = val * size 5737 else: 5738 val = val ** size 5739 # Sum can change the input dtype (upcast or bool 5740 # -> float32) by default or by user request. 5741 # We can ignore the acc_dtype, as there is only 1 5742 # elemwise we will do and not a sequence, so there is no 5743 # accumulation of errors. 5744 # So mostly, we just need to cast the output to the old 5745 # dtype. 
5746 val = val.astype(node.outputs[0].dtype) 5747 return [val] 5748 to_prod = [shapes[i] for i in xrange(len(shapes)) 5749 if i in node.op.axis] 5750 if to_prod: 5751 size = T.mul(*to_prod) 5752 if isinstance(node.op, T.Sum): 5753 val *= size 5754 else: 5755 val = val ** size 5756 # See comments above. 5757 val = val.astype(node.outputs[0].dtype) 5758 return [T.alloc(val, 5759 *[shapes[i] for i in xrange(len(shapes)) 5760 if i not in node.op.axis])] 5761 except NotScalarConstantError: 5762 pass 5763 5764 5765@register_specialize 5766@gof.local_optimizer([T.neg]) 5767def local_neg_neg(node): 5768 # other specializations shouldn't put this in, 5769 # but sometimes they do 5770 if node.op == T.neg: 5771 if node.inputs[0].owner and node.inputs[0].owner.op == T.neg: 5772 return [node.inputs[0].owner.inputs[0]] 5773 5774 5775@register_specialize 5776@gof.local_optimizer([T.neg]) 5777def local_neg_div_neg(node): 5778 """ 5779 - (-a / b) -> a / b 5780 5781 Also performs - (c / b) -> ((-c) / b) when c is a scalar constant. 5782 5783 """ 5784 if node.op == T.neg: 5785 if node.inputs[0].owner and node.inputs[0].owner.op == T.true_div: 5786 frac = node.inputs[0] 5787 num, denom = frac.owner.inputs 5788 if num.owner and num.owner.op == T.neg: 5789 if len(frac.clients) == 1: 5790 # No other clients of the original division 5791 new_num = num.owner.inputs[0] 5792 return [T.true_div(new_num, denom)] 5793 elif np.all(num.broadcastable) and isinstance(num, Constant): 5794 if len(frac.clients) == 1: 5795 new_num = -num.data 5796 return [T.true_div(new_num, denom)] 5797 5798 5799@gof.local_optimizer([T.mul]) 5800def local_mul_zero(node): 5801 """ 5802 As part of canonicalization, we replace multiplication by zero 5803 with zero. 5804 5805 """ 5806 if node.op == T.mul: 5807 otype = node.outputs[0].type 5808 5809 for i in node.inputs: 5810 try: 5811 value = get_scalar_constant_value(i) 5812 except NotScalarConstantError: 5813 continue 5814 # print 'MUL by value', value, node.inputs 5815 if value == 0: 5816 # print '... 
returning zeros' 5817 return _fill_chain(theano._asarray(0, dtype=otype.dtype), 5818 node.inputs) 5819register_canonicalize(local_mul_zero) 5820 5821 5822@gof.local_optimizer([T.true_div]) 5823def local_div_to_inv(node): 5824 if node.op == T.true_div and np.all( 5825 local_mul_canonizer.get_constant(node.inputs[0]) == 1.0): 5826 out = node.outputs[0] 5827 new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], 5828 [])) 5829 # The ones could have forced upcasting 5830 if new_out.dtype != out.dtype: 5831 new_out = T.cast(new_out, dtype=out.dtype) 5832 # The ones could have forced a specific length 5833 if new_out.type != out.type: 5834 new_out = broadcast_like(new_out, out, node.fgraph) 5835 return [new_out] 5836 else: 5837 return False 5838register_specialize(local_div_to_inv) 5839 5840 5841@gof.local_optimizer([T.inv]) 5842def local_inv_canon(node): 5843 if node.op == T.inv: 5844 return [T.pow(node.inputs[0], -1.0)] 5845 else: 5846 return False 5847register_canonicalize(local_inv_canon) 5848 5849 5850@gof.local_optimizer([T.pow]) 5851def local_pow_canonicalize(node): 5852 if node.op == T.pow: 5853 cst = local_mul_canonizer.get_constant(node.inputs[1]) 5854 if cst == 0: 5855 return [broadcast_like(1, node.outputs[0], node.fgraph)] 5856 if cst == 1: 5857 return [broadcast_like(node.inputs[0], node.outputs[0], node.fgraph)] 5858 else: 5859 return False 5860register_canonicalize(local_pow_canonicalize) 5861 5862 5863@register_specialize 5864@gof.local_optimizer([T.mul]) 5865def local_mul_to_sqr(node): 5866 """ 5867 x*x -> sqr(x) 5868 5869 This is faster on the GPU when memory fetching is a big part of 5870 the computation time. 5871 5872 """ 5873 if node.op == T.mul: 5874 if len(node.inputs) == 2: 5875 if node.inputs[0] is node.inputs[1]: 5876 return [T.sqr(node.inputs[0])] 5877 5878 5879@register_canonicalize 5880@gof.local_optimizer([T.int_div]) 5881def local_intdiv_by_one(node): 5882 """x // 1 -> x 5883 """ 5884 if node.op in [T.int_div]: 5885 if isinstance(node.inputs[1], T.TensorConstant) and \ 5886 np.all(node.inputs[1].value == 1): 5887 return [node.inputs[0].astype(node.outputs[0].dtype)] 5888 5889 5890@register_canonicalize 5891@register_specialize 5892@gof.local_optimizer([T.int_div, T.true_div]) 5893def local_zero_div(node): 5894 """0 / x -> 0 5895 """ 5896 if isinstance(node.op, T.Elemwise) and isinstance( 5897 node.op.scalar_op, (theano.scalar.IntDiv, theano.scalar.TrueDiv)): 5898 if local_mul_canonizer.get_constant(node.inputs[0]) == 0: 5899 ret = broadcast_like(0, node.outputs[0], node.fgraph) 5900 ret.tag.values_eq_approx = values_eq_approx_remove_nan 5901 return [ret] 5902 5903 5904@gof.local_optimizer([T.pow]) 5905def local_pow_specialize(node): 5906 # here, we are past the point of canonicalization, so we don't want 5907 # to put in un-necessary fills. 
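    # Illustrative sketch of the constant-exponent rewrites below
    # (not executed):
    #     x ** 2   -> T.sqr(x)
    #     x ** 0.5 -> T.sqrt(x)
    #     x ** -1  -> T.inv(x)
    # They only apply when the exponent is a constant and x's broadcastable
    # pattern encompasses the exponent's.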
5908 if node.op == T.pow: 5909 # the idea here is that we have pow(x, y) 5910 odtype = node.outputs[0].dtype 5911 xsym = node.inputs[0] 5912 ysym = node.inputs[1] 5913 y = local_mul_canonizer.get_constant(ysym) 5914 if (y is not None) \ 5915 and encompasses_broadcastable(xsym.type.broadcastable, 5916 ysym.type.broadcastable): 5917 rval = None 5918 5919 if np.all(y == 2): 5920 rval = [T.sqr(xsym)] 5921 if np.all(y == 1): 5922 rval = [xsym] 5923 if np.all(y == 0): 5924 rval = [T.fill(xsym, np.asarray(1, dtype=odtype))] 5925 if np.all(y == 0.5): 5926 rval = [T.sqrt(xsym)] 5927 if np.all(y == -0.5): 5928 rval = [T.inv(T.sqrt(xsym))] 5929 if np.all(y == -1): 5930 rval = [T.inv(xsym)] 5931 if np.all(y == -2): 5932 rval = [T.inv(T.sqr(xsym))] 5933 if rval: 5934 rval[0] = T.cast(rval[0], odtype) 5935 assert rval[0].type == node.outputs[0].type, ( 5936 rval, node.outputs) 5937 return rval 5938 else: 5939 return False 5940register_specialize(local_pow_specialize) 5941 5942 5943@register_specialize_device 5944@gof.local_optimizer([T.pow]) 5945def local_pow_specialize_device(node): 5946 """ 5947 This optimization is not the same on all device. We do it only on cpu here. 5948 """ 5949 if node.op == T.pow: 5950 # the idea here is that we have pow(x, y) 5951 odtype = node.outputs[0].dtype 5952 xsym = node.inputs[0] 5953 ysym = node.inputs[1] 5954 y = local_mul_canonizer.get_constant(ysym) 5955 5956 # the next line is needed to fix a strange case that I don't 5957 # know how to make a separate test. 5958 # That happen in the test_opt.py:test_log_erfc test. 5959 # y is a ndarray with dtype int8 and value 2,4 or 6. This make 5960 # the abs(y) <= 512 fail! 5961 # taking the value outside ndarray solve the problem. 5962 # it could be that in that case, numpy make the comparaison 5963 # into the wrong type(do in int8 that overflow.) 5964 if isinstance(y, np.ndarray): 5965 assert y.size == 1 5966 try: 5967 y = y[0] 5968 except IndexError: 5969 pass 5970 if (y is not None) \ 5971 and encompasses_broadcastable(xsym.type.broadcastable, 5972 ysym.type.broadcastable): 5973 rval = None 5974 # 512 is too small for the cpu and too big for some gpu! 5975 if abs(y) == int(abs(y)) and abs(y) <= 512: 5976 pow2 = [xsym] 5977 pow2_scal = [theano.scalar.get_scalar_type(xsym.dtype)()] 5978 y_to_do = abs(y) 5979 for i in xrange(int(np.log2(y_to_do))): 5980 pow2.append(T.sqr(pow2[i])) 5981 pow2_scal.append(theano.scalar.sqr(pow2_scal[i])) 5982 rval1 = None 5983 rval1_scal = None 5984 while y_to_do > 0: 5985 log_to_do = int(np.log2(y_to_do)) 5986 if rval1: 5987 rval1 *= pow2[log_to_do] 5988 rval1_scal *= pow2_scal[log_to_do] 5989 else: 5990 rval1 = pow2[log_to_do] 5991 rval1_scal = pow2_scal[log_to_do] 5992 y_to_do -= 2 ** log_to_do 5993 5994 if abs(y) > 2: 5995 # We fuse all the pow together here to make 5996 # compilation faster 5997 rval1 = Elemwise( 5998 theano.scalar.Composite( 5999 [pow2_scal[0]], [rval1_scal])).make_node(xsym) 6000 if y < 0: 6001 rval = [T.inv(rval1)] 6002 else: 6003 rval = [rval1] 6004 if rval: 6005 rval[0] = T.cast(rval[0], odtype) 6006 assert rval[0].type == node.outputs[0].type, ( 6007 rval, node.outputs) 6008 return rval 6009 6010 6011@gof.local_optimizer([T.mul]) 6012def local_mul_specialize(node): 6013 """ 6014 Remove special-case constants from mul arguments and useless neg in inputs. 6015 6016 mul(-1, x) -> neg(x) 6017 mul(1, x, y) -> mul(x, y) 6018 mul(0, ...) -> alloc(0, shapes...) 
6019 6020 This is not done if we would add more nodes in the graph, like with: 6021 6022 mul(-1, x, y) -/-> neg(mul(x, y)) 6023 6024 """ 6025 # here, we are past the point of canonicalization, so we don't 6026 # want to put in un-necessary fills. 6027 # 6028 # at this point [post canonicalize], mul() may have many inputs. 6029 if node.op == T.mul: 6030 # the idea here is that we have pow(x, y) 6031 neg = False 6032 new_inputs = [] 6033 nb_neg_node = 0 6034 nb_cst = 0 6035 for input in node.inputs: 6036 # remove any neg arguments 6037 while input.owner and input.owner.op == T.neg: 6038 neg ^= True 6039 input = input.owner.inputs[0] 6040 nb_neg_node += 1 6041 6042 # remove special case arguments of 1, -1 or 0 6043 y = local_mul_canonizer.get_constant(input) 6044 if y == 1.0: 6045 nb_cst += 1 6046 elif y == -1.0: 6047 nb_cst += 1 6048 neg ^= True # toggles 6049 elif y == 0.0: 6050 # if we find any zero, we just return right away 6051 return [broadcast_like(0, node.outputs[0], node.fgraph)] 6052 else: 6053 new_inputs.append(input) 6054 6055 if new_inputs != node.inputs: 6056 if new_inputs: 6057 if len(new_inputs) == 1: 6058 if neg: 6059 if new_inputs[0].dtype in (T.uint_dtypes + ['bool']): 6060 return 6061 else: 6062 rval = -new_inputs[0] 6063 else: 6064 rval = new_inputs[0] 6065 else: 6066 # The next case would cause a replace by an equivalent case. 6067 if (neg and 6068 nb_neg_node == 0 and 6069 nb_cst == 1): 6070 return 6071 elif neg: 6072 # Don't add an extra neg node as we can't 6073 # fully replace this mul by a neg. 6074 m1 = np.asarray(-1, dtype=node.outputs[0].dtype) 6075 new_inputs = [m1] + new_inputs 6076 rval = T.mul(*new_inputs) 6077 6078 return [broadcast_like(rval, node.outputs[0], node.fgraph)] 6079 else: 6080 # there are no variable inputs to mul 6081 # N.B. this could have been constant-folded... 6082 if neg: 6083 return [broadcast_like(-1, node.outputs[0], node.fgraph)] 6084 else: 6085 return [broadcast_like(1, node.outputs[0], node.fgraph)] 6086 6087register_specialize(local_mul_specialize) 6088 6089 6090@gof.local_optimizer([T.add]) 6091def local_add_specialize(node): 6092 def fill_chain(v): 6093 out = _fill_chain(v, node.inputs) 6094 return out 6095 6096 # here, we are past the point of canonicalization, so we don't want 6097 # to put in un-necessary fills. 6098 if node.op == T.add: 6099 new_inputs = [] 6100 for input in node.inputs: 6101 try: 6102 y = get_scalar_constant_value(input) 6103 except NotScalarConstantError: 6104 y = input 6105 if np.all(y == 0.0): 6106 continue 6107 new_inputs.append(input) 6108 6109 if len(new_inputs) < len(node.inputs): 6110 dtype = node.outputs[0].type.dtype 6111 if len(new_inputs) == 0: 6112 # we got rid of the entire expression! 6113 ndim = node.outputs[0].type.ndim 6114 # Reuse call to constant for cache() 6115 cst = T.constant(np.zeros((1,) * ndim, dtype=dtype)) 6116 assert cst.type.broadcastable == (True,) * ndim 6117 return fill_chain(cst) 6118 6119 if len(new_inputs) == 1: 6120 ret = fill_chain(new_inputs[0]) 6121 else: 6122 ret = fill_chain(T.add(*new_inputs)) 6123 # The dtype should not be changed. It can happen if the input 6124 # that was forcing upcasting was equal to 0. 
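            # e.g. add(x, 0.) with x float32 and 0. float64 has dtype
            # float64; once the zero is removed only x is left, so we cast
            # back to preserve the original output dtype.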
6125 if ret[0].dtype != dtype: 6126 ret = [T.cast(ret[0], dtype)] 6127 return ret 6128 else: 6129 return False 6130register_specialize(local_add_specialize) 6131 6132mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, 6133 local_fill_sink, apply_all_opts=True), 6134 name='mul_canonizer_groups') 6135 6136 6137def check_for_x_over_absX(numerators, denominators): 6138 """Convert x/abs(x) into sign(x). """ 6139 # TODO: this function should dig/search through dimshuffles 6140 # This won't catch a dimshuffled absolute value 6141 for den in list(denominators): 6142 if (den.owner and den.owner.op == T.abs_ and 6143 den.owner.inputs[0] in numerators): 6144 if den.owner.inputs[0].type.dtype.startswith('complex'): 6145 # TODO: Make an Op that projects a complex number to 6146 # have unit length but projects 0 to 0. That 6147 # would be a weird Op, but consistent with the 6148 # special case below. I heard there's some 6149 # convention in Matlab that is similar to 6150 # this... but not sure. 6151 pass 6152 else: 6153 denominators.remove(den) 6154 numerators.remove(den.owner.inputs[0]) 6155 numerators.append(T.sgn(den.owner.inputs[0])) 6156 return numerators, denominators 6157local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX') 6158 6159 6160@register_canonicalize 6161@gof.local_optimizer([T.abs_]) 6162def local_abs_lift(node): 6163 """ 6164 Move the abs toward the input. 6165 6166 This is needed for check_for_x_over_absX to apply in more case. 6167 6168 """ 6169 if node.op == T.abs_ and node.inputs[0].owner: 6170 assert node.nin == 1 6171 if node.inputs[0].owner.op == T.mul: 6172 return [T.mul(*[T.abs_(i) for i in node.inputs[0].owner.inputs])] 6173 if node.inputs[0].owner.op == T.true_div: 6174 i = node.inputs[0].owner.inputs 6175 return [T.true_div(T.abs_(i[0]), T.abs_(i[1]))] 6176 6177 6178@register_specialize 6179@gof.local_optimizer([T.mul, T.true_div]) 6180def local_abs_merge(node): 6181 """ 6182 Merge abs generated by local_abs_lift when the canonizer don't 6183 need it anymore 6184 6185 """ 6186 if node.op == T.mul and sum([i.owner.op == T.abs_ for i in node.inputs 6187 if i.owner]) > 1: 6188 inputs = [] 6189 for i in node.inputs: 6190 if i.owner and i.owner.op == T.abs_: 6191 inputs.append(i.owner.inputs[0]) 6192 elif isinstance(i, Constant): 6193 try: 6194 const = get_scalar_constant_value(i, 6195 only_process_constants=True) 6196 except NotScalarConstantError: 6197 return False 6198 if not (const >= 0).all(): 6199 return False 6200 inputs.append(i) 6201 else: 6202 return False 6203 return [T.abs_(T.mul(*inputs))] 6204 if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in 6205 node.inputs if i.owner]) == 2: 6206 return [T.abs_(T.true_div(node.inputs[0].owner.inputs[0], 6207 node.inputs[1].owner.inputs[0]))] 6208 6209 6210@register_stabilize 6211@register_specialize 6212@gof.local_optimizer([T.log]) 6213def local_log1p(node): 6214 # log(1+x) -> log1p(x) 6215 # log(1-x) -> log1p(-x) 6216 if node.op == T.log: 6217 log_arg, = node.inputs 6218 if log_arg.owner and log_arg.owner.op == T.add: 6219 scalars, scalar_inputs, nonconsts = scalarconsts_rest( 6220 log_arg.owner.inputs, only_process_constants=True) 6221 # scalar_inputs are potentially dimshuffled and fill'd scalars 6222 if scalars and np.allclose(np.sum(scalars), 1): 6223 if nonconsts: 6224 if len(nonconsts) > 1: 6225 ninp = T.add(*nonconsts) 6226 else: 6227 ninp = nonconsts[0] 6228 if ninp.dtype != log_arg.type.dtype: 6229 ninp = ninp.astype(node.outputs[0].dtype) 6230 return 
_fill_chain(T.log1p(ninp), scalar_inputs) 6231 6232 elif log_arg.owner and log_arg.owner.op == T.sub: 6233 one = T.extract_constant(log_arg.owner.inputs[0], 6234 only_process_constants=True) 6235 if one != 1: 6236 return 6237 other = log_arg.owner.inputs[1] 6238 if other.dtype != log_arg.dtype: 6239 other = other.astype(log_arg.dtype) 6240 return [T.log1p(T.neg(other))] 6241 6242 6243# TODO: in canonicalize, change log10 and log2 -> log 6244@register_stabilize 6245@register_specialize 6246@gof.local_optimizer([T.log]) 6247def local_log_add(node): 6248 # log(exp(x)+exp(y)) 6249 # 6250 # Suppose x >= y 6251 # log(exp(x) + exp(y)) 6252 # log(exp(x) * (1 + exp(y)/exp(x))) 6253 # x + log(1 + exp(y)/exp(x)) 6254 # x + log1p(exp(y)/exp(x)) 6255 # x + log1p(exp(y-x)) 6256 if node.op == T.log: 6257 z = node.inputs[0] 6258 if z.owner and z.owner.op == T.add: 6259 zi = z.owner.inputs 6260 if len(zi) != 2: 6261 # -- upgrading Maximum to handle multiple inputs wasn't trivial 6262 # TODO 6263 # raise NotImplementedError() 6264 return 6265 pre_exp = [x.owner.inputs[0] for x in zi 6266 if x.owner and x.owner.op == T.exp] 6267 if len(pre_exp) == len(zi): 6268 # all arguments to add are exp(<something>) 6269 max_pre = T.maximum(*pre_exp) 6270 6271 ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre 6272 for p in pre_exp]))) 6273 ret.tag.values_eq_approx = values_eq_approx_remove_inf 6274 return [ret] 6275 6276 6277@gof.local_optimizer([T.log]) 6278def local_log_sum_exp(node): 6279 # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) 6280 6281 if node.op != T.log: 6282 return 6283 6284 sum_node = node.inputs[0].owner 6285 # If the sum has keepdims=True, there might be a dimshuffle 6286 if sum_node and isinstance(sum_node.op, T.DimShuffle): 6287 dimshuffle_op = sum_node.op 6288 sum_node = sum_node.inputs[0].owner 6289 else: 6290 dimshuffle_op = None 6291 6292 if not sum_node or not isinstance(sum_node.op, T.Sum): 6293 return 6294 6295 exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis 6296 if not exp_node or not ( 6297 isinstance(exp_node.op, Elemwise) and 6298 isinstance(exp_node.op.scalar_op, scalar.Exp)): 6299 return 6300 6301 pre_exp = exp_node.inputs[0] 6302 max_pre_exp = T.max(pre_exp, axis=axis) 6303 max_pre_exp_keepdims = T.makeKeepDims(pre_exp, max_pre_exp, axis) 6304 6305 ret = (max_pre_exp + 6306 T.log(T.sum(T.exp(pre_exp - max_pre_exp_keepdims), axis=axis))) 6307 6308 # Restore the dimshuffle op, if any. 6309 if dimshuffle_op: 6310 ret = dimshuffle_op(ret) 6311 6312 return [ret] 6313 6314 6315compile.optdb.register('local_log_sum_exp', 6316 in2out(local_log_sum_exp, ignore_newtrees=True), 6317 1.6, 'fast_run') 6318 6319 6320def add_calculate(num, denum, aslist=False, out_type=None): 6321 # TODO: make sure that this function and mul_calculate are similar 6322 if out_type is None: 6323 zero = 0.0 6324 else: 6325 zero = theano._asarray(0, dtype=out_type.dtype) 6326 # zero = 0.0 if out_type is None else theano._asarray(0, 6327 # dtype=out_type.dtype) 6328 if out_type and out_type.dtype == 'bool': 6329 if len(denum) == 0: 6330 # NumPy 1.14 do not accept to do "bool - bool" 6331 v = reduce(np.add, num, zero) 6332 else: 6333 raise Exception( 6334 "bool subtraction not supported. 
This should not happen as" 6335 " an earlier error should have been raised") 6336 else: 6337 v = reduce(np.add, num, zero) - reduce(np.add, denum, zero) 6338 if aslist: 6339 if np.all(v == 0): 6340 return [] 6341 else: 6342 return [v] 6343 return v 6344 6345 6346local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate) 6347add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, 6348 local_fill_sink, apply_all_opts=True), 6349 name='add_canonizer_group') 6350 6351 6352register_canonicalize(local_add_canonizer, name='local_add_canonizer') 6353 6354 6355################## 6356# Distributivity # 6357################## 6358 6359 6360def distribute_greedy(pos_pairs, neg_pairs, num, denum, 6361 out_type, minscore=0): 6362 # each pair in pos_pairs and neg_pairs is a num/denum pair. this 6363 # function attempts to add num and denum to the corresponding parts 6364 # of each pair, and counts how many multiplications/divisions can 6365 # be saved in that way. 6366 6367 # each division is counted like div_cost multiplications 6368 # (typically, division costs more so we are willing to multiply more 6369 # in order to divide less) 6370 # 1.5 was obtained through an informal test and may very well be 6371 # platform dependent 6372 div_cost = 1.5 6373 6374 # score is number of operations saved, higher is better 6375 score = len(num) + div_cost * len(denum) 6376 new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify, 6377 [(n + num, d + denum, out_type) for (n, d) 6378 in pos_pairs])) 6379 new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify, 6380 [(n + num, d + denum, out_type) for (n, d) 6381 in neg_pairs])) 6382 for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + 6383 new_neg_pairs): 6384 # We calculate how many operations we are saving with the new 6385 # num and denum 6386 score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd) 6387 if score <= minscore: 6388 # the change is not applied because it adds too many operations 6389 return False, pos_pairs, neg_pairs 6390 return True, new_pos_pairs, new_neg_pairs 6391 6392 6393def attempt_distribution(factor, num, denum, out_type): 6394 # we try to insert each num and each denum in the factor 6395 # returns: changes?, new_factor, new_num, new_denum 6396 # if there are changes, new_num and new_denum contain all the numerators 6397 # and denumerators that could not be distributed in the factor 6398 pos, neg = local_add_canonizer.get_num_denum(factor) 6399 if len(pos) == 1 and not neg: 6400 return False, factor, num, denum 6401 pos_pairs = list(map(local_mul_canonizer.get_num_denum, pos)) 6402 neg_pairs = list(map(local_mul_canonizer.get_num_denum, neg)) 6403 change = False 6404 for n in list(num): 6405 success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, 6406 neg_pairs, [n], [], out_type) 6407 if success: 6408 change = True 6409 num.remove(n) 6410 for d in list(denum): 6411 success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, 6412 neg_pairs, [], [d], out_type) 6413 if success: 6414 change = True 6415 denum.remove(d) 6416 if not change: 6417 return change, factor, num, denum 6418 else: 6419 return change, local_add_canonizer.merge_num_denum( 6420 list(itertools.starmap(local_mul_canonizer.merge_num_denum, 6421 pos_pairs)), 6422 list(itertools.starmap(local_mul_canonizer.merge_num_denum, 6423 neg_pairs))), num, denum 6424 6425 6426@register_canonicalize 6427@register_stabilize 6428@gof.local_optimizer([T.mul, T.true_div, T.inv]) 6429def local_greedy_distributor(node): 6430 
""" 6431 Optimize by reducing the number of multiplications and/or divisions. 6432 6433 This optimization tries to apply distributivity of multiplication 6434 to addition in order to reduce the number of multiplications 6435 and/or divisions that must be done. The algorithm weighs division 6436 more than multiplication to account for the former's slightly 6437 greater computational cost. 6438 6439 The following expressions are simplified: 6440 1. ((a/x + b/y) * x * y) --> a*y + b*x 6441 2. ((a/x + b) * x) --> a + b*x 6442 3. There are other forms too where node is a true_div. 6443 6444 The following expressions are not simplified: 6445 4. ((a + b) * x) -/-> a*x + b*x 6446 6447 This optimization aims to reduce computational cost. It may also 6448 increase numerical stability, e.g. when x and/or y tend to 0 in 6449 example 1. 6450 6451 """ 6452 6453 out = node.outputs[0] 6454 num, denum = local_mul_canonizer.get_num_denum(out) 6455 if len(num) == 1 and not denum: 6456 return False 6457 6458 new_num, new_denum = [], [] 6459 6460 change = False 6461 6462 out_type = out.type 6463 for candidate in list(num): 6464 if candidate not in num: 6465 continue 6466 num.remove(candidate) 6467 _change, candidate, num, denum = attempt_distribution( 6468 candidate, num, denum, out_type,) 6469 6470 change |= _change 6471 new_num.append(candidate) 6472 6473 for candidate in list(denum): 6474 if candidate not in denum: 6475 continue 6476 denum.remove(candidate) 6477 _change, candidate, denum, num = attempt_distribution( 6478 candidate, denum, num, out_type) 6479 change |= _change 6480 new_denum.append(candidate) 6481 if not change: 6482 return False 6483 6484 new_num += num 6485 new_denum += denum 6486 6487 rval = local_mul_canonizer.merge_num_denum(new_num, new_denum) 6488 6489 if not (rval.type == out.type): 6490 # WHY DOES THIS HAPPEN? 6491 return False 6492 6493 return [rval] 6494 6495 6496@gof.local_optimizer(None) 6497def constant_folding(node): 6498 for input in node.inputs: 6499 if not isinstance(input, Constant): 6500 return False 6501 # condition: all inputs are constant 6502 if not node.op.do_constant_folding(node): 6503 # The op asks not to be constant folded. 
6504 return False 6505 6506 storage_map = dict([(i, [i.data]) for i in node.inputs]) 6507 compute_map = dict([(i, [True]) for i in node.inputs]) 6508 for o in node.outputs: 6509 storage_map[o] = [None] 6510 compute_map[o] = [False] 6511 impl = None 6512 if (hasattr(node.op, 'python_constant_folding') and 6513 node.op.python_constant_folding(node)): 6514 impl = 'py' 6515 thunk = node.op.make_thunk(node, storage_map, compute_map, 6516 no_recycling=[], impl=impl) 6517 6518 required = thunk() 6519 assert not required # a node whose inputs are all provided should always 6520 # return successfully 6521 rval = [] 6522 for output in node.outputs: 6523 assert compute_map[output][0], (output, storage_map[output][0]) 6524 try: 6525 constant = output.type.Constant 6526 except AttributeError: 6527 constant = Constant 6528 6529 v = constant(output.type, storage_map[output][0]) 6530 copy_stack_trace(output, v) 6531 6532 rval.append(v) 6533 return rval 6534 6535 6536topo_constant_folding = in2out(constant_folding, ignore_newtrees=True, 6537 name="topo_constant_folding") 6538register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True) 6539register_uncanonicalize(topo_constant_folding, 'fast_compile', final_opt=True) 6540register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True) 6541register_specialize(topo_constant_folding, 'fast_compile', final_opt=True) 6542 6543 6544def get_clients(node): 6545 """ 6546 Used by erf/erfc opt to track less frequent op. 6547 6548 """ 6549 return [c for c, i in node.outputs[0].clients 6550 if c != "output"] 6551 6552 6553def get_clients2(node): 6554 """ 6555 Used by erf/erfc opt to track less frequent op. 6556 6557 """ 6558 l = [] 6559 for c, i in node.outputs[0].clients: 6560 if c != "output": 6561 for var in c.outputs: 6562 l.extend([cc for cc, ii in var.clients if cc != "output"]) 6563 return l 6564 6565# 1+erf(x)=>erfc(-x) 6566local_one_plus_erf = gof.PatternSub((T.add, 6567 1, 6568 (T.erf, 'x')), 6569 (T.erfc, (T.neg, 'x')), 6570 allow_multiple_clients=True, 6571 name='local_one_plus_erf', 6572 tracks=[T.erf], 6573 get_nodes=get_clients) 6574register_canonicalize(local_one_plus_erf) 6575register_stabilize(local_one_plus_erf) 6576register_specialize(local_one_plus_erf) 6577 6578# 1-erf(x)=>erfc(x) 6579local_one_minus_erf = gof.PatternSub((T.sub, 6580 1, 6581 (T.erf, 'x')), 6582 (T.erfc, 'x'), 6583 allow_multiple_clients=True, 6584 name='local_one_minus_erf',) 6585register_canonicalize(local_one_minus_erf) 6586register_stabilize(local_one_minus_erf) 6587register_specialize(local_one_minus_erf) 6588 6589local_one_minus_erf2 = gof.PatternSub((T.add, 6590 1, 6591 (T.mul, -1, (T.erf, 'x'))), 6592 (T.erfc, 'x'), 6593 allow_multiple_clients=True, 6594 name='local_one_minus_erf2') 6595register_canonicalize(local_one_minus_erf2) 6596register_stabilize(local_one_minus_erf2) 6597register_specialize(local_one_minus_erf2) 6598 6599# 1+(-erf(x))=>erfc(x) This is a different graph then the previous as 6600# the canonicalize don't work completly 6601local_one_plus_neg_erf = gof.PatternSub((T.add, 6602 1, 6603 (T.neg, (T.erf, 'x'))), 6604 (T.erfc, 'x'), 6605 allow_multiple_clients=True, 6606 name='local_one_plus_neg_erf', 6607 tracks=[T.erf], 6608 get_nodes=get_clients2) 6609register_canonicalize(local_one_plus_neg_erf) 6610register_stabilize(local_one_plus_neg_erf) 6611register_specialize(local_one_plus_neg_erf) 6612 6613# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize 6614# will put the -1 as the first argument. 
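# Illustrative sketch (not executed): gof.PatternSub replaces a subgraph that
# matches the first pattern with the second one, so a graph built as
# T.erf(x) - 1 (which canonicalization is expected to express as
# add(-1, erf(x))) should end up compiled as -T.erfc(x).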
6615local_erf_minus_one = gof.PatternSub((T.add, 6616 -1, 6617 (T.erf, 'x')), 6618 (T.neg, (T.erfc, 'x')), 6619 allow_multiple_clients=True, 6620 name='local_erf_minus_one', 6621 tracks=[T.erf], 6622 get_nodes=get_clients) 6623register_canonicalize(local_erf_minus_one) 6624register_stabilize(local_erf_minus_one) 6625register_specialize(local_erf_minus_one) 6626 6627# 1-erfc(x) => erf(x) 6628local_one_minus_erfc = gof.PatternSub((T.sub, 6629 1, 6630 (T.erfc, 'x')), 6631 (T.erf, 'x'), 6632 allow_multiple_clients=True, 6633 name='local_one_minus_erfc', 6634 tracks=[T.erfc], 6635 get_nodes=get_clients) 6636register_canonicalize(local_one_minus_erfc) 6637register_stabilize(local_one_minus_erfc) 6638register_specialize(local_one_minus_erfc) 6639 6640local_one_minus_erfc2 = gof.PatternSub((T.add, 6641 1, 6642 (T.neg, (T.erfc, 'x'))), 6643 (T.erf, 'x'), 6644 allow_multiple_clients=True, 6645 name='local_one_minus_erfc2', 6646 tracks=[T.erfc], 6647 get_nodes=get_clients2) 6648register_canonicalize(local_one_minus_erfc2) 6649register_stabilize(local_one_minus_erfc2) 6650register_specialize(local_one_minus_erfc2) 6651 6652local_one_minus_erfc3 = gof.PatternSub((T.add, 6653 1, 6654 (T.mul, -1, (T.erfc, 'x'))), 6655 (T.erf, 'x'), 6656 allow_multiple_clients=True, 6657 name='local_one_minus_erfc3', 6658 tracks=[T.erfc], 6659 get_nodes=get_clients2) 6660register_canonicalize(local_one_minus_erfc3) 6661register_stabilize(local_one_minus_erfc3) 6662register_specialize(local_one_minus_erfc3) 6663 6664# 1+(-erfc(x)) => erf(x) This is a different graph then the previous as 6665# the canonicalize don't work completly 6666local_one_add_neg_erfc = gof.PatternSub((T.add, 6667 1, 6668 (T.neg, (T.erfc, 'x'))), 6669 (T.erf, 'x'), 6670 allow_multiple_clients=True, 6671 name='local_one_add_neg_erfc', 6672 tracks=[T.erfc], 6673 get_nodes=get_clients2) 6674 6675register_canonicalize(local_one_add_neg_erfc) 6676register_stabilize(local_one_add_neg_erfc) 6677register_specialize(local_one_add_neg_erfc) 6678 6679# (-1)+erfc(-x)=>erf(x) 6680local_erf_neg_minus_one = gof.PatternSub((T.add, 6681 -1, 6682 (T.erfc, (T.neg, 'x'))), 6683 (T.erf, 'x'), 6684 allow_multiple_clients=True, 6685 name='local_erf_neg_minus_one', 6686 tracks=[T.erfc], 6687 get_nodes=get_clients) 6688register_canonicalize(local_erf_neg_minus_one) 6689register_stabilize(local_erf_neg_minus_one) 6690register_specialize(local_erf_neg_minus_one) 6691 6692# (-1)+erfc(-1*x)=>erf(x) 6693local_erf_neg_minus_one2 = gof.PatternSub((T.add, 6694 -1, 6695 (T.erfc, (T.mul, -1, 'x'))), 6696 (T.erf, 'x'), 6697 allow_multiple_clients=True, 6698 name='local_erf_neg_minus_one2', 6699 tracks=[T.erfc], 6700 get_nodes=get_clients) 6701register_canonicalize(local_erf_neg_minus_one2) 6702register_stabilize(local_erf_neg_minus_one2) 6703register_specialize(local_erf_neg_minus_one2) 6704 6705 6706# Stability optimization 6707# log(erfc(x)) => when x>threashold, 6708# -x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)) 6709# for float64: threshold=26.641747557 was choosed with: 6710# [(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64')))) 6711# for i in numpy.arange(26.641747557,26.6417475571,.00000000001)] 6712# for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc( 6713# numpy.asarray([i],dtype='float32')))) for i in numpy.arange( 6714# 10.0541948,10.0541951,.0000001)] 6715@register_stabilize 6716@register_specialize 6717@gof.local_optimizer([T.log]) 6718def local_log_erfc(node): 6719 if node.op != T.log: 6720 return False 6721 if not 
node.inputs[0].owner or node.inputs[0].owner.op != T.erfc: 6722 return False 6723 6724 if hasattr(node.tag, 'local_log_erfc_applied'): 6725 # We use that flag to don't apply the optimization recursively 6726 return False 6727 node.tag.local_log_erfc_applied = True 6728 6729 x = node.inputs[0].owner.inputs[0] 6730 stab_value = (-x ** 2 - T.log(x) - .5 * T.log(np.pi) + 6731 T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) - 6732 15 / (8 * x ** 6))) 6733 6734 if (node.outputs[0].dtype == 'float32' or 6735 node.outputs[0].dtype == 'float16'): 6736 threshold = 10.0541949 6737 elif node.outputs[0].dtype == 'float64': 6738 threshold = 26.641747557 6739 6740 ret = T.switch(x < threshold, node.outputs[0], stab_value) 6741 ret.tag.values_eq_approx = values_eq_approx_remove_inf 6742 return [ret] 6743 6744 6745# Stability optimization of the grad of log(erfc(x)) 6746# ([y*]exp(-(x**2)))/erfc(x) # The y* is optional 6747# ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold, 6748# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) 6749# for float64: threshold=26.63 see at the end of the fct for the explanation 6750# for float32: threshold=9.3 see at the end of the fct for the explanation 6751# TODO: remove the contraint that there are only 2 inputs to exp(x**2) 6752# is the second. 6753# TODO: at the test point 10 in float32, there is instability in the original 6754# value. The original gives -30.0, the stab -20.1 and in float64 -18.1. 6755# Make it so that the test does not generate an error in that case! 6756@register_stabilize 6757@register_specialize 6758@gof.local_optimizer([T.true_div]) 6759def local_grad_log_erfc_neg(node): 6760 if node.op != T.true_div: 6761 return False 6762 if not node.inputs[1].owner or node.inputs[1].owner.op != T.erfc: 6763 return False 6764 erfc = node.inputs[1] 6765 erfc_x = erfc.owner.inputs[0] 6766 if not node.inputs[0].owner: 6767 return False 6768 6769 # The mul is optional. 6770 if node.inputs[0].owner.op != T.mul: 6771 mul = None 6772 y = [] 6773 if not node.inputs[0].owner or node.inputs[0].owner.op != T.exp: 6774 return False 6775 exp = node.inputs[0] 6776 else: 6777 mul = node.inputs[0] 6778 exp = None 6779 for idx, inp in enumerate(mul.owner.inputs): 6780 if inp.owner and inp.owner.op == T.exp: 6781 exp = inp 6782 break 6783 if len(mul.owner.inputs) == 2: 6784 y = [mul.owner.inputs[1 - idx]] 6785 else: 6786 y = mul.owner.inputs[:] 6787 del y[idx] 6788 del mul 6789 if not exp.owner.inputs[0].owner: 6790 return False 6791 6792 if exp.owner.inputs[0].owner.op == T.neg: 6793 neg = exp.owner.inputs[0] 6794 if (not neg.owner.inputs[0].owner or 6795 neg.owner.inputs[0].owner.op != T.sqr): 6796 return False 6797 sqr = neg.owner.inputs[0] 6798 x = sqr.owner.inputs[0] 6799 elif exp.owner.inputs[0].owner.op == T.mul: 6800 # We should compare that -(erfc_x**2) is equivalent to mul_neg. 6801 # There is currently no easy way to do this in the general case, 6802 # so we implement some common case for now. 6803 6804 # In many cases the neg are replaced by mul in the graph. 6805 # This also allows to stabilize log(erfc(cst*x)). 6806 mul_neg = exp.owner.inputs[0] 6807 6808 # In case that multiple mul are not fused together, we do it here. 6809 def check_input(inputs): 6810 new_inputs = [] 6811 for i in inputs: 6812 if i.owner and i.owner.op == T.mul: 6813 new_inputs.extend(check_input(i.owner.inputs)) 6814 else: 6815 new_inputs.append(i) 6816 return new_inputs 6817 mul_inputs = check_input(mul_neg.owner.inputs) 6818 6819 # Put the constant first. 
        for i in xrange(len(mul_inputs)):
            if isinstance(mul_inputs[i], Constant):
                if i == 0:
                    break
                else:
                    tmp = mul_inputs[0]
                    mul_inputs[0] = mul_inputs[i]
                    mul_inputs[i] = tmp
                    break
        mul_neg = T.mul(*mul_inputs)

        try:
            cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0],
                                             only_process_constants=True)
        except NotScalarConstantError:
            return False

        if len(mul_neg.owner.inputs) == 2:
            if (not mul_neg.owner.inputs[1].owner or
                    mul_neg.owner.inputs[1].owner.op != T.sqr):
                return False
            sqr = mul_neg.owner.inputs[1]
            x = sqr.owner.inputs[0]
        elif len(mul_neg.owner.inputs) == 3:
            if mul_neg.owner.inputs[1] is not mul_neg.owner.inputs[2]:
                return False
            x = mul_neg.owner.inputs[1]
        else:
            return False

        if cst2 != -1:
            if (not erfc_x.owner or erfc_x.owner.op != T.mul or
                    len(erfc_x.owner.inputs) != 2):
                # TODO: implement that case.
                return False
            if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
                return False

            x = erfc_x
            try:
                cst = get_scalar_constant_value(erfc_x.owner.inputs[0],
                                                only_process_constants=True)
            except NotScalarConstantError:
                return False
            if cst2 != -cst * 2:
                return False

            # The constant is valid.
        elif erfc_x is not x:
            return False

    else:
        return False

    if hasattr(node.tag, 'local_grad_log_erfc_neg'):
        # We use this flag to avoid applying the optimization recursively.
        return False

    # We move the y outside of the div.
    true_div_no_mul = T.true_div(exp, erfc)
    true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True

    # Aaron's stabilized value.
    stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) +
                            3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) *
                  T.cast(T.sqrt(np.pi), dtype=x.dtype))

    if x.dtype == 'float32' or x.dtype == 'float16':
        threshold = 9.3
        # threshold = 10.1
    elif x.dtype == 'float64':
        threshold = 26.641747557
    ret = T.switch(x < threshold, true_div_no_mul, stab_value)
    if y:
        ret = T.mul(ret, *y)
    ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan
    return [ret]
    """
The libm used for the test is amdlibm.

# ([y*]exp(-(x**2)))/erfc(x)  # The mul is optional.
# exp(x**2)/erfc(-x) => when x > threshold,
#     -x*(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))*sqrt(pi)
# for float64: threshold=26.63, see below
# for float32: threshold=9.3, see below
# TODO: remove the constraint that there are only 2 inputs to mul
# TODO: should we cast numpy.pi to x.dtype?

# float32 threshold is 9.3 as the approximation is more precise at that
# point and more stable.
import numpy, scipy.special
r = numpy.arange(9, 10.06, .01)

p64 = [(numpy.exp(-(x ** 2))) / scipy.special.erfc(x) for x in r]
p32 = [(numpy.exp(-(x ** 2))) / scipy.special.erfc(x)
       for x in numpy.asarray(r, dtype='float32')]
a64 = [x * ((1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
             15 / (8 * x ** 6)) ** (-1)) * numpy.sqrt(numpy.pi)
       for x in r]
a32 = [x * ((1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
             15 / (8 * x ** 6)) ** (-1))
       * numpy.float32(numpy.sqrt(numpy.pi))
       for x in numpy.asarray(r, dtype='float32')]

for idx, (a, b, c, d, e) in enumerate(zip(r, p64, p32, a64, a32)):
    print a, b, c, d, e, c - b, e - b, numpy.absolute(c - b) < numpy.absolute(e - b)

# The following shows that the value does not look stable at some point
# before inf.
6923for i in xrange(1,len(p32)): print r[i], p32[i]-p32[i-1] 6924 6925#float64 threshold is 26.63 the approx seam more precise at that 6926point. r = numpy.arange(26.2,26.7,.001) 6927#scipy.special.erfc(numpy.float128(x)) don't work 6928#p128=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in 6929numpy.float128(r)] #those value have been computed with g++ 6930theano/misc/erfc_stability_threshold.c && ./a.out 6931p128=numpy.float128(['46.47206725', '46.47383842', '46.47560959', 6932'46.47738076', '46.47915193', '46.48092309', '46.48269426', 6933'46.48446543', '46.48623660', '46.48800777', '46.48977894', 6934'46.49155011', '46.49332128', '46.49509245', '46.49686362', 6935'46.49863479', '46.50040596', '46.50217713', '46.50394830', 6936'46.50571947', '46.50749064', '46.50926181', '46.51103298', 6937'46.51280415', '46.51457532', '46.51634649', '46.51811766', 6938'46.51988883', '46.52166000', '46.52343118', '46.52520235', 6939'46.52697352', '46.52874469', '46.53051586', '46.53228703', 6940'46.53405820', '46.53582938', '46.53760055', '46.53937172', 6941'46.54114289', '46.54291407', '46.54468524', '46.54645641', 6942'46.54822758', '46.54999876', '46.55176993', '46.55354110', 6943'46.55531227', '46.55708345', '46.55885462', '46.56062579', 6944'46.56239697', '46.56416814', '46.56593931', '46.56771049', 6945'46.56948166', '46.57125283', '46.57302401', '46.57479518', 6946'46.57656636', '46.57833753', '46.58010871', '46.58187988', 6947'46.58365105', '46.58542223', '46.58719340', '46.58896458', 6948'46.59073575', '46.59250693', '46.59427810', '46.59604928', 6949'46.59782045', '46.59959163', '46.60136280', '46.60313398', 6950'46.60490516', '46.60667633', '46.60844751', '46.61021868', 6951'46.61198986', '46.61376104', '46.61553221', '46.61730339', 6952'46.61907456', '46.62084574', '46.62261692', '46.62438809', 6953'46.62615927', '46.62793045', '46.62970163', '46.63147280', 6954'46.63324398', '46.63501516', '46.63678633', '46.63855751', 6955'46.64032869', '46.64209987', '46.64387104', '46.64564222', 6956'46.64741340', '46.64918458', '46.65095576', '46.65272693', 6957'46.65449811', '46.65626929', '46.65804047', '46.65981165', 6958'46.66158283', '46.66335401', '46.66512519', '46.66689636', 6959'46.66866754', '46.67043872', '46.67220990', '46.67398108', 6960'46.67575226', '46.67752344', '46.67929462', '46.68106580', 6961'46.68283698', '46.68460816', '46.68637934', '46.68815052', 6962'46.68992170', '46.69169288', '46.69346406', '46.69523524', 6963'46.69700642', '46.69877760', '46.70054878', '46.70231997', 6964'46.70409115', '46.70586233', '46.70763351', '46.70940469', 6965'46.71117587', '46.71294705', '46.71471824', '46.71648942', 6966'46.71826060', '46.72003178', '46.72180296', '46.72357414', 6967'46.72534533', '46.72711651', '46.72888769', '46.73065887', 6968'46.73243006', '46.73420124', '46.73597242', '46.73774361', 6969'46.73951479', '46.74128597', '46.74305715', '46.74482834', 6970'46.74659952', '46.74837070', '46.75014189', '46.75191307', 6971'46.75368426', '46.75545544', '46.75722662', '46.75899781', 6972'46.76076899', '46.76254018', '46.76431136', '46.76608254', 6973'46.76785373', '46.76962491', '46.77139610', '46.77316728', 6974'46.77493847', '46.77670965', '46.77848084', '46.78025202', 6975'46.78202321', '46.78379439', '46.78556558', '46.78733677', 6976'46.78910795', '46.79087914', '46.79265032', '46.79442151', 6977'46.79619269', '46.79796388', '46.79973507', '46.80150625', 6978'46.80327744', '46.80504863', '46.80681981', '46.80859100', 6979'46.81036219', '46.81213337', '46.81390456', '46.81567575', 
6980'46.81744693', '46.81921812', '46.82098931', '46.82276050', 6981'46.82453168', '46.82630287', '46.82807406', '46.82984525', 6982'46.83161644', '46.83338762', '46.83515881', '46.83693000', 6983'46.83870119', '46.84047238', '46.84224357', '46.84401475', 6984'46.84578594', '46.84755713', '46.84932832', '46.85109951', 6985'46.85287070', '46.85464189', '46.85641308', '46.85818427', 6986'46.85995546', '46.86172665', '46.86349784', '46.86526903', 6987'46.86704022', '46.86881141', '46.87058260', '46.87235379', 6988'46.87412498', '46.87589617', '46.87766736', '46.87943855', 6989'46.88120974', '46.88298093', '46.88475212', '46.88652331', 6990'46.88829450', '46.89006569', '46.89183688', '46.89360807', 6991'46.89537927', '46.89715046', '46.89892165', '46.90069284', 6992'46.90246403', '46.90423522', '46.90600642', '46.90777761', 6993'46.90954880', '46.91131999', '46.91309119', '46.91486238', 6994'46.91663357', '46.91840476', '46.92017596', '46.92194715', 6995'46.92371834', '46.92548953', '46.92726073', '46.92903192', 6996'46.93080311', '46.93257431', '46.93434550', '46.93611669', 6997'46.93788789', '46.93965908', '46.94143028', '46.94320147', 6998'46.94497266', '46.94674386', '46.94851505', '46.95028625', 6999'46.95205744', '46.95382864', '46.95559983', '46.95737103', 7000'46.95914222', '46.96091341', '46.96268461', '46.96445581', 7001'46.96622700', '46.96799820', '46.96976939', '46.97154059', 7002'46.97331178', '46.97508298', '46.97685417', '46.97862537', 7003'46.98039657', '46.98216776', '46.98393896', '46.98571015', 7004'46.98748135', '46.98925255', '46.99102374', '46.99279494', 7005'46.99456614', '46.99633733', '46.99810853', '46.99987973', 7006'47.00165092', '47.00342212', '47.00519332', '47.00696452', 7007'47.00873571', '47.01050691', '47.01227811', '47.01404931', 7008'47.01582050', '47.01759170', '47.01936290', '47.02113410', 7009'47.02290530', '47.02467649', '47.02644769', '47.02821889', 7010'47.02999009', '47.03176129', '47.03353249', '47.03530369', 7011'47.03707489', '47.03884608', '47.04061728', '47.04238848', 7012'47.04415968', '47.04593088', '47.04770208', '47.04947328', 7013'47.05124448', '47.05301568', '47.05478688', '47.05655808', 7014'47.05832928', '47.06010048', '47.06187168', '47.06364288', 7015'47.06541408', '47.06718528', '47.06895648', '47.07072768', 7016'47.07249888', '47.07427009', '47.07604129', '47.', '47.07958369', 7017'47.08135489', '47.08312609', '47.08489729', '47.08666850', 7018'47.08843970', '47.09021090', '47.09198210', '47.09375330', 7019'47.09552450', '47.09729571', '47.09906691', '47.10083811', 7020'47.10260931', '47.10438052', '47.10615172', '47.10792292', 7021'47.10969412', '47.11146533', '47.11323653', '47.11500773', 7022'47.11677894', '47.11855014', '47.12032134', '47.12209255', 7023'47.12386375', '47.12563495', '47.12740616', '47.12917736', 7024'47.13094857', '47.13271977', '47.13449097', '47.13626218', 7025'47.13803338', '47.13980459', '47.14157579', '47.14334700', 7026'47.14511820', '47.14688941', '47.14866061', '47.15043182', 7027'47.15220302', '47.15397423', '47.15574543', '47.15751664', 7028'47.15928784', '47.16105905', '47.16283025', '47.16460146', 7029'47.16637266', '47.16814387', '47.16991508', '47.17168628', 7030'47.17345749', '47.17522869', '47.17699990', '47.17877111', 7031'47.18054231', '47.18231352', '47.18408473', '47.18585593', 7032'47.18762714', '47.18939835', '47.19116956', '47.19294076', 7033'47.19471197', '47.19648318', '47.19825439', '47.20002559', 7034'47.20179680', '47.20356801', '47.20533922', '47.20711042', 7035'47.20888163', 
'47.21065284', '47.21242405', '47.21419526', 7036'47.21596647', '47.21773767', '47.21950888', '47.22128009', 7037'47.22305130', '47.22482251', '47.22659372', '47.22836493', 7038'47.23013614', '47.23190735', '47.23367855', '47.23544976', 7039'47.23722097', '47.23899218', '47.24076339', '47.24253460', 7040'47.24430581', '47.24607702', '47.24784823', '47.24961944', 7041'47.25139065', '47.25316186', '47.25493307', '47.25670429', 7042'47.25847550', '47.26024671', '47.26201792', '47.26378913', 7043'47.26556034', '47.26733155', '47.26910276', '47.27087397', 7044'47.27264518', '47.27441640', '47.27618761', '47.27795882', 7045'47.27973003', '47.28150124', '47.28327246', '47.28504367', 7046'47.28681488', '47.28858609', '47.29035730', '47.29212852', 7047'47.29389973', '47.29567094', '47.29744215', '47.29921337', 7048'47.30098458', '47.30275579', '47.30452701', '47.30629822', 7049'47.30806943', '47.30984065', '47.31161186', '47.31338307', 7050'47.31515429', '47.31692550', '47.31869671', '47.32046793', 7051'47.32223914', '47.32401036', '47.32578157', '47.32755278', 7052'47.32932400', '47.33109521', '47.33286643', '47.33463764', 7053'47.33640886', '47.33818007', '47.33995129', '47.34172250', 7054'47.34349372', '47.34526493', '47.34703615', '47.34880736', 7055'47.35057858', '47.35234979', '47.35412101', '47.35589223']) 7056p64=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in r] 7057a128=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1)) 7058 *numpy.float128(numpy.sqrt(numpy.pi)) 7059 for x in numpy.asarray(r,dtype='float128')] 7060a64=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)+63/(7*x**8))**(-1)) 7061 *numpy.sqrt(numpy.pi) 7062 for x in r] for a,b,c,d in zip(r,p128,p64,a64):print a,b,c,d,c-b,d-b 7063 7064for i in xrange(1,len(p64)): print i, 64[i]-p64[i-1] 7065 """ 7066 7067 7068# ############### 7069# # Loop fusion # 7070# ############### 7071def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, 7072 maker=None): 7073 """ 7074 We parametrize it to make it work for Elemwise and GpuElemwise op. 7075 7076 Parameters 7077 ---------- 7078 OP 7079 GpuElemwise or Elemwise class (the one that we want to fuse) 7080 max_input_fct 7081 A function that returns the maximum number of inputs 7082 that this elemwise can take (useful for GpuElemwise). 7083 GPU kernel currently has a limit of 256 bytes for 7084 the size of all parameters passed to it. As currently 7085 we pass many information only by parameter, we must 7086 limit how many ops we fuse together to avoid busting 7087 that 256 limit. 7088 7089 On the CPU we limit to 32 input variables 7090 since that is the maximum numpy support. 7091 7092 """ 7093 if maker is None: 7094 def maker(node, scalar_op): 7095 return OP(scalar_op) 7096 7097 def local_fuse(node): 7098 """ 7099 As part of specialization, we fuse two consecutive elemwise Ops of the 7100 same shape. 7101 7102 For mixed dtype, we let the Composite op do the cast. It lets the C 7103 compiler do the cast. 7104 The number of dimensions is validated at call time by theano itself. 7105 7106 """ 7107 # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!! 7108 # TODO: use broadcast flag? 7109 7110 # TODO: don't do this optimization as a localOptimizer. 7111 # Analyze the graph in terms of elemwise subgraphs, and then 7112 # replace each subgraph with a Composite version. 7113 7114 # TODO: use malloc and copy to transfer arguments that don't 7115 # fit within the parameter space of 256 bytes 7116 # 7117 # TODO: Merge with multiple output to merge when an inputs 7118 # have multiple clients. 
This can't be done with a local 7119 # optimiser. 7120 7121 # TODO: Related: Support composites with multiple outputs 7122 7123 # TODO: Use Composite to combine Elemwise and Reduce 7124 # operations. We have to loop over the data anyway... might 7125 # as well sum it up while we're at it (this can be trickier 7126 # than i'm making it seound here. The data-traversal should be 7127 # done contiguously, and the summing-up might not be easy or 7128 # worthwhile if the summation axis doesn't line up with a 7129 # contiguous dimension) 7130 7131 if type(node.op) is not OP: 7132 return False 7133 7134 if len(node.outputs) > 1: 7135 # We don't support the fusion for node with multiple outputs. 7136 return 7137 inputs = [] # inputs of the new Elemwise op. 7138 s_inputs = [] # inputs of the new scalar op used by the Composite. 7139 # Inputs of the new scalar op that represents the current node. 7140 s_g = [] 7141 7142 # There is a hard limit of 256 bytes for the formal argument list to a 7143 # GPU kernel function. 7144 max_nb_input = max_input_fct(node) 7145 # The number of inputs to the new fused op if we do not fuse more 7146 # inputs. 7147 new_nb_input = len(node.inputs) 7148 # Did we fuse something? 7149 # Needed as we can fuse unary op that don't change the number of 7150 # inputs. 7151 # And there is a case where the inputs are the same as the current 7152 # node. That won't change the number of inputs of the new op. 7153 fused = False 7154 7155 for i in node.inputs: 7156 do_fusion = False 7157 catch = False 7158 # Will store inputs of the fused node that are not currently inputs 7159 # of the node we want to create (to avoid duplicating inputs). 7160 tmp_input = [] 7161 # Same as tmp_input, but for scalars. 7162 tmp_scalar = [] 7163 7164 # We should not check the number of inputs here 7165 # As fusing op don't always change the number of input. 7166 # If a variable is used as multiple into to the same node, 7167 # we still want to fusion. So we take the set. 7168 if (i.owner and 7169 isinstance(i.owner.op, OP) and 7170 len(set([n for n, idx in i.clients])) == 1 and 7171 # Do not merge elemwise that don't have the same 7172 # broadcastable pattern to don't redo duplicate 7173 # computation due to broadcast. 7174 i.owner.outputs[0].broadcastable == 7175 node.outputs[0].broadcastable): 7176 do_fusion = True 7177 try: 7178 tmp_s_input = [] 7179 # we should not put duplicate input into s_inputs and inputs 7180 for ii in i.owner.inputs: 7181 if ii in inputs: 7182 tmp_s_input.append(s_inputs[inputs.index(ii)]) 7183 elif ii in tmp_input: 7184 tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) 7185 else: 7186 tmp = scalar.get_scalar_type(ii.dtype).make_variable() 7187 try: 7188 tv = gof.op.get_test_value(ii) 7189 if tv.size > 0: 7190 tmp.tag.test_value = tv.flatten()[0] 7191 else: 7192 tmp.tag.test_value = tv 7193 except AttributeError: 7194 pass 7195 tmp_s_input.append(tmp) 7196 tmp_input.append(ii) 7197 tmp_scalar.append(tmp_s_input[-1]) 7198 s_op = i.owner.op.scalar_op(*tmp_s_input, 7199 return_list=True) 7200 7201 # if the scalar_op don't have a c implementation, 7202 # we skip its fusion to allow the fusion of the 7203 # other ops. 7204 i.owner.op.scalar_op.c_code(s_op[0].owner, 7205 "test_presence_of_c_code", 7206 ["x" for x in i.owner.inputs], 7207 ["z" for z in i.owner.outputs], 7208 {"fail": "%(fail)s"}) 7209 except MethodNotDefined: 7210 catch = True 7211 except NotImplementedError: 7212 catch = True 7213 if catch: 7214 _logger.info(("%s does not implement the c_code function." 
7215 " As well as being potentially slow, this" 7216 " disables loop fusion of this op.") % 7217 str(i.owner.op.scalar_op)) 7218 do_fusion = False 7219 7220 # Compute the number of inputs in case we fuse this input. 7221 # We subtract 1 because we replace the existing input with the new 7222 # inputs from `tmp_input`. 7223 new_nb_input_ = new_nb_input + len(tmp_input) - 1 7224 7225 # If the new input is already an input of the current node, it was 7226 # already counted when `new_nb_input` was initialized to 7227 # len(node.inputs). 7228 # This can happen when a variable is used both by the Elemwise to 7229 # fuse and the current node. 7230 for x in tmp_input: 7231 if x in node.inputs: 7232 new_nb_input_ -= 1 7233 7234 if do_fusion and (new_nb_input_ <= max_nb_input): 7235 fused = True 7236 new_nb_input = new_nb_input_ 7237 inputs.extend(tmp_input) 7238 s_inputs.extend(tmp_scalar) 7239 s_g.extend(s_op) 7240 else: 7241 # We must support the case where the same variable appear many 7242 # time in the inputs 7243 if inputs.count(i) == node.inputs.count(i): 7244 s = s_inputs[inputs.index(i)] 7245 else: 7246 s = scalar.get_scalar_type(i.dtype).make_variable() 7247 try: 7248 if theano.config.compute_test_value != 'off': 7249 v = gof.op.get_test_value(i) 7250 if v.size > 0: 7251 s.tag.test_value = v.flatten()[0] 7252 except AttributeError: 7253 pass 7254 7255 inputs.append(i) 7256 s_inputs.append(s) 7257 s_g.append(s) 7258 7259 if not fused: 7260 return False 7261 7262 if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): 7263 raise Exception("""Something has gone wrong with the elemwise 7264fusion optimization. We skip this optimization. You can ignore this message, 7265your code will run correctly, but may be slower.""") 7266 7267 s_new_out = node.op.scalar_op(*s_g, return_list=True) 7268 try: 7269 s_new_out[0].owner.op.c_code(s_new_out[0].owner, 7270 "test_presence_of_c_code", 7271 ["x" for x in s_g], 7272 ["z" for x in s_new_out], 7273 {"fail": "%(fail)s"}) 7274 except MethodNotDefined: 7275 _logger.info(("%s does not implement the c_code function." 7276 " As well as being potentially slow, this disables " 7277 "loop fusion of this op.") % str( 7278 s_new_out[0].owner.op)) 7279 return False 7280 except NotImplementedError: 7281 _logger.info(("%s does not implement the c_code function. As well" 7282 " as being potentially slow, this disables loop" 7283 " fusion of this op.") % str(s_new_out[0].owner.op)) 7284 return False 7285 7286 # create the composite op. 7287 C = scalar.Composite(s_inputs, s_new_out) 7288 7289 # create the new node. 7290 # Do not call make_node to have test_value 7291 n = maker(node, C)(*inputs).owner 7292 assert len(n.outputs) == 1 7293 assert node.outputs[0].dtype == n.outputs[0].dtype 7294 7295 if len(n.inputs) > max_nb_input: 7296 _logger.info('loop fusion failed because Op would exceed' 7297 ' kernel argument limit.') 7298 return False 7299 7300 # we fuse as many that we can at the same time to make debug mode faster 7301 # debug mode will be faster as it won't test all intermediate step. 7302 while True: 7303 ret = local_fuse(n) 7304 if ret is not False and ret is not None: 7305 # print n,ret 7306 assert len(ret) == len(n.outputs) 7307 assert len(ret) == 1 7308 n = ret[0].owner 7309 else: 7310 break 7311 7312 return n.outputs 7313 return local_fuse 7314 7315 7316def elemwise_max_input_fct(node): 7317 # The Elemwise.perform use numpy ufunc and they are limited to 31 7318 # inputs. 
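    # numpy ufuncs accept at most NPY_MAXARGS (32 at the time of writing)
    # operands in total, i.e. 31 inputs plus 1 output; that limit applies
    # when Elemwise.perform falls back to numpy. With a C++ compiler the
    # generated Composite C code has no such limit, so a much larger cap
    # is used below.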
    if not theano.config.cxx:
        return 31
    return 1024


local_elemwise_fusion = local_elemwise_fusion_op(T.Elemwise,
                                                 elemwise_max_input_fct)


class FusionOptimizer(Optimizer):
    """Graph optimizer for fusion of elemwise operations."""
    def __init__(self, local_optimizer):
        Optimizer.__init__(self)
        self.optimizer = local_optimizer

    def add_requirements(self, fgraph):
        fgraph.attach_feature(toolbox.ReplaceValidate())

    def apply(self, fgraph):
        did_something = True
        nb_iter = 0
        nb_replacement = 0
        nb_inconsistency_replace = 0
        time_toposort = 0
        if fgraph.profile:
            validate_before = fgraph.profile.validate_time
            callbacks_before = fgraph.execute_callbacks_times.copy()
            callback_before = fgraph.execute_callbacks_time
        while did_something:
            t0 = time.time()
            nodelist = list(fgraph.toposort())
            time_toposort += time.time() - t0
            nodelist.reverse()
            did_something = False
            for node in nodelist:
                # Don't try to fuse nodes that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.optimizer(node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
                            fgraph.replace_all_validate(
                                list(zip(node.outputs, new_outputs)),
                                reason=self.__class__.__name__)
                            did_something = True
                            nb_replacement += 1
                        except InconsistencyError:
                            nb_inconsistency_replace += 1
                            pass
            nb_iter += 1

        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in iteritems(fgraph.execute_callbacks_times):
                if k in callbacks_before:
                    callbacks_time[k] = v - callbacks_before[k]
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        return (self, nb_iter, nb_replacement,
                nb_inconsistency_replace,
                validate_time, callback_time, callbacks_time,
                time_toposort)

    @staticmethod
    def print_profile(stream, prof, level=0):
        blanc = (' ' * level)
        print(blanc, "FusionOptimizer", file=stream)
        print(blanc, " nb_iter", prof[1], file=stream)
        print(blanc, " nb_replacement", prof[2], file=stream)
        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
        print(blanc, " validate_time", prof[4], file=stream)
        print(blanc, " callback_time", prof[5], file=stream)
        if prof[5] > 1:
            print(blanc, " callbacks_time", file=stream)
            for i in sorted(iteritems(prof[6]), key=lambda a: a[1])[::-1]:
                if i[1] > 0:
                    print(blanc, "  ", i, file=stream)
        print(blanc, " time_toposort", prof[7], file=stream)


def local_add_mul_fusion(node):
    """Fuse consecutive add or mul into one such node with more inputs.

    It is better to fuse add/mul this way than inside a Composite node,
    because it keeps the inner graph of the Composite smaller. This
    allows more computation to be put in a Composite before hitting the
    maximum recursion limit when pickling the Composite.
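
    For example, a graph like ``add(add(x, y), z)`` can be rewritten as a
    single ``add(x, y, z)`` node (and similarly for ``mul``), provided the
    inner node has no other client and the merged node stays within
    ``max_inputs``.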
7412 7413 """ 7414 if (not isinstance(node.op, Elemwise) or 7415 not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))): 7416 return False 7417 7418 s_op = node.op.scalar_op.__class__ 7419 new_inp = [] 7420 fused = False 7421 nb_inputs = len(node.inputs) 7422 max_inputs = float('inf') 7423 if hasattr(node.op, 'max_inputs'): 7424 max_inputs = node.op.max_inputs(node) 7425 for inp in node.inputs: 7426 if (inp.owner and 7427 isinstance(inp.owner.op, Elemwise) and 7428 isinstance(inp.owner.op.scalar_op, s_op) and 7429 # Do not duplicate the operation. 7430 len(inp.clients) == 1 and 7431 (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs): 7432 new_inp.extend(inp.owner.inputs) 7433 fused = True 7434 else: 7435 new_inp.append(inp) 7436 7437 # We can not compare the number of inputs as Mul and Add could have 7438 # 0 or 1 inputs in some corner cases. 7439 if fused: 7440 output = node.op(*new_inp) 7441 copy_stack_trace(node.outputs[0], output) 7442 7443 # Do the recursion here to help lower the number of 7444 # FusionOptimizer iteration. 7445 if output.owner: 7446 output2 = local_add_mul_fusion(output.owner) 7447 if output2: 7448 return output2 7449 return [output] 7450 7451if config.tensor.local_elemwise_fusion: 7452 _logger.debug("enabling optimization fusion elemwise in fast_run") 7453 # Must be after gpu(48.5) and before AddDestroyHandler(49.5) 7454 fuse_seqopt = gof.SequenceDB() 7455 fuse_seqopt.register('local_add_mul_fusion', 7456 FusionOptimizer(local_add_mul_fusion), 7457 0, 'fast_run', 'fusion') 7458 fuse_seqopt.register('composite_elemwise_fusion', 7459 FusionOptimizer(local_elemwise_fusion), 7460 1, 'fast_run', 'fusion') 7461 compile.optdb.register('elemwise_fusion', 7462 fuse_seqopt, 49, 7463 'fast_run', 'fusion', 'local_elemwise_fusion', 7464 'FusionOptimizer') 7465else: 7466 _logger.debug("not enabling optimization fusion elemwise in fast_run") 7467 compile.optdb.register('elemwise_fusion', 7468 FusionOptimizer(local_elemwise_fusion), 49, 7469 'fusion', 'local_elemwise_fusion', 7470 'FusionOptimizer') 7471 7472 7473@register_canonicalize 7474@gof.local_optimizer([Elemwise]) 7475def local_useless_composite(node): 7476 """For elemwise Composite that have multiple outputs, remove the 7477 outputs that are not used. 
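
    For example, if a Composite node computes both ``a + b`` and ``a * b``
    but only the first output has clients, the node is rebuilt around a
    Composite that computes only ``a + b``.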

    """
    if (not isinstance(node.op, Elemwise) or
            not isinstance(node.op.scalar_op, scalar.Composite)):
        return
    comp = node.op.scalar_op
    idx = [i for i, o_extern in enumerate(node.outputs)
           if o_extern.clients]
    if len(idx) < len(node.outputs):
        new_outputs = [comp.outputs[i] for i in idx]
        c = scalar.Composite(inputs=comp.inputs,
                             outputs=new_outputs)
        e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
        return dict(zip([node.outputs[i] for i in idx], e))

# ############################
# # Remove consider_constant #
# ############################


# Although the ops ConsiderConstant, ZeroGrad and DisconnectedGrad
# just return their input, they should be removed from the graph so
# that they do not block other optimizations.
@register_canonicalize('fast_compile')
@register_useless('fast_compile')
@gof.local_optimizer(None)
def local_view_op(node):
    if isinstance(node.op, theano.compile.ops.ViewOp):
        return node.inputs


@register_useless
@register_canonicalize
@register_stabilize
@register_specialize
@gof.local_optimizer([T.Alloc])
def local_merge_alloc(node):
    # This opt takes care of several cases:
    # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
    #     Alloc(m, x, assert(y1, y1==y2), z, w)
    if not isinstance(node.op, T.Alloc):
        return False
    if not node.inputs[0].owner or not isinstance(
            node.inputs[0].owner.op, T.Alloc):
        return False
    inputs_outer = node.inputs
    inputs_inner = node.inputs[0].owner.inputs
    dims_outer = inputs_outer[1:]
    dims_inner = inputs_inner[1:]
    dims_outer_rev = dims_outer[::-1]
    dims_inner_rev = dims_inner[::-1]
    # Check if the pattern of broadcasting is matched, in the reversed
    # ordering. The reversed ordering is needed when an Alloc adds
    # implicit new broadcast dimensions to its inputs[0]. E.g.:
    #     Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    i = 0
    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
        if dim_inner != dim_outer:
            if isinstance(dim_inner, Constant) and dim_inner.data == 1:
                pass
            else:
                dims_outer[-1 - i] = Assert(
                    "You have a shape error in your graph. To see a better"
                    " error message and a stack trace of where in your code"
                    " the error is created, use the Theano flags"
                    " optimizer=None or optimizer=fast_compile.")(
                    dim_outer, T.eq(dim_outer, dim_inner))
        i += 1
    return [T.alloc(inputs_inner[0], *dims_outer)]


@register_useless('fast_compile')
@gof.local_optimizer([TopKOp])
def local_useless_topk(node):
    """
    TopKOp generates two outputs by default.
    This opt removes the one that is not used.

    """
    op = node.op
    if not isinstance(op, TopKOp):
        return
    if not (op.return_values and op.return_indices):
        return False

    x, k = node.inputs
    ret_val = bool(node.outputs[0].clients)
    ret_idx = bool(node.outputs[1].clients)

    if not (ret_val ^ ret_idx):
        # both true -> nothing to remove
        # both false -> let pruner handle
        return False

    old_output = node.outputs[ret_idx]
    new_output = TopKOp(
        axis=op.axis,
        sorted=op.sorted,
        idx_dtype=op.idx_dtype,
        return_values=ret_val,
        return_indices=ret_idx)(x, k)
    copy_stack_trace(node.outputs[0], new_output)
    return {old_output: new_output}
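
# ---------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not executed on import): the
# stabilization rewrites registered above fire automatically when a
# graph is compiled in a mode that includes the 'stabilize' and
# 'specialize' optimizations (e.g. the default FAST_RUN mode):
#
#     import numpy as np
#     import theano
#     import theano.tensor as tt
#
#     x = tt.dvector('x')
#     f = theano.function([x], tt.log(tt.erfc(x)))
#     f(np.array([30.0]))  # about -903.97 thanks to local_log_erfc;
#                          # without it, erfc(30.) underflows to 0 and
#                          # log(0) gives -inf.
#
# The exact optimized graph depends on the active mode and flags.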