1""" 2Defines the base class for optimizations as well as a certain 3amount of useful generic optimization tools. 4 5""" 6from __future__ import absolute_import, print_function, division 7 8from collections import deque, defaultdict, OrderedDict 9import contextlib 10import copy 11import inspect 12import logging 13import pdb 14import sys 15import time 16import warnings 17import traceback 18 19import numpy as np 20 21import theano 22from theano import config 23from theano.compat import izip 24from six import string_types, iteritems, itervalues, integer_types 25from six.moves import reduce 26from theano.gof import graph, op, utils, unify, toolbox 27from theano.gof.fg import InconsistencyError 28from theano.misc.ordered_set import OrderedSet 29 30from . import destroyhandler as dh 31 32_logger = logging.getLogger('theano.gof.opt') 33_optimizer_idx = [0] 34 35 36def _list_of_nodes(fgraph): 37 return list(graph.io_toposort(fgraph.inputs, fgraph.outputs)) 38 39 40class LocalMetaOptimizerSkipAssertionError(AssertionError): 41 """This is an AssertionError, but instead of having the 42 LocalMetaOptimizer print the error, it just skip that 43 compilation. 44 45 """ 46 pass 47 48 49class Optimizer(object): 50 """ 51 52 An L{Optimizer} can be applied to an L{FunctionGraph} to transform it. 53 It can represent an optimization or in general any kind 54 of transformation you could apply to an L{FunctionGraph}. 55 56 """ 57 58 def __hash__(self): 59 if not hasattr(self, '_optimizer_idx'): 60 self._optimizer_idx = _optimizer_idx[0] 61 _optimizer_idx[0] += 1 62 return self._optimizer_idx 63 64 def __eq__(self, other): 65 # added to override the __eq__ implementation that may be inherited 66 # in subclasses from other bases. 67 return id(self) == id(other) 68 69 def __ne__(self, other): 70 # added to override the __ne__ implementation that may be inherited 71 # in subclasses from other bases. 
72 return id(self) != id(other) 73 74 def apply(self, fgraph): 75 """ 76 77 Applies the optimization to the provided L{FunctionGraph}. It may 78 use all the methods defined by the L{FunctionGraph}. If the 79 L{Optimizer} needs to use a certain tool, such as an 80 L{InstanceFinder}, it can do so in its L{add_requirements} method. 81 82 """ 83 pass 84 85 def optimize(self, fgraph, *args, **kwargs): 86 """ 87 88 This is meant as a shortcut to: 89 opt.add_requirements(fgraph) 90 opt.apply(fgraph) 91 92 """ 93 self.add_requirements(fgraph) 94 try: 95 orig = theano.tensor.basic.constant.enable 96 theano.tensor.basic.constant.enable = False 97 ret = self.apply(fgraph, *args, **kwargs) 98 finally: 99 theano.tensor.basic.constant.enable = orig 100 return ret 101 102 def __call__(self, fgraph): 103 """ 104 105 Same as self.optimize(fgraph). 106 107 """ 108 return self.optimize(fgraph) 109 110 def add_requirements(self, fgraph): 111 """ 112 113 Add features to the fgraph that are required to apply the optimization. 114 For example: 115 fgraph.attach_feature(History()) 116 fgraph.attach_feature(MyFeature()) 117 etc. 
118 119 """ 120 pass 121 122 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 123 name = getattr(self, 'name', None) 124 print("%s%s %s id=%i" % ( 125 (' ' * level), self.__class__.__name__, name, id(self)), file=stream) 126 127 @staticmethod 128 def print_profile(stream, prof, level=0): 129 if prof is not None: 130 raise NotImplementedError( 131 "The function print_profile must be overrided if the" 132 " optimizer return profiling information.") 133 134 135class FromFunctionOptimizer(Optimizer): 136 """ 137 WRITEME 138 139 """ 140 def __init__(self, fn, requirements=()): 141 self.apply = fn 142 self.requirements = requirements 143 144 def add_requirements(self, fgraph): 145 for req in self.requirements: 146 req(fgraph) 147 148 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 149 print("%s%s id=%i" % ( 150 ' ' * level, 151 str(self.apply), 152 id(self)), file=stream) 153 154 def __call__(self, *args, **kwargs): 155 return self.fn(*args, **kwargs) 156 157 def __str__(self): 158 return self.__name__ 159 160 161def optimizer(f): 162 """ 163 Decorator for FromFunctionOptimizer. 164 165 """ 166 rval = FromFunctionOptimizer(f) 167 rval.__name__ = f.__name__ 168 return rval 169 170 171def inplace_optimizer(f): 172 """ 173 Decorator for FromFunctionOptimizer. 174 175 """ 176 dh_handler = dh.DestroyHandler 177 requirements = (lambda fgraph: 178 fgraph.attach_feature(dh_handler()),) 179 rval = FromFunctionOptimizer(f, requirements) 180 rval.__name__ = f.__name__ 181 return rval 182 183 184class SeqOptimizer(Optimizer, list): 185 # inherit from Optimizer first to get Optimizer.__hash__ 186 """ 187 188 Takes a list of L{Optimizer} instances and applies them 189 sequentially. 190 191 """ 192 @staticmethod 193 def warn(exc, self, optimizer): 194 """ 195 Default failure_callback for SeqOptimizer. 
196 197 """ 198 _logger.error("SeqOptimizer apply %s" % str(optimizer)) 199 _logger.error("Traceback:") 200 _logger.error(traceback.format_exc()) 201 if config.on_opt_error == 'raise': 202 raise exc 203 elif config.on_opt_error == 'pdb': 204 pdb.post_mortem(sys.exc_info()[2]) 205 206 def __init__(self, *opts, **kw): 207 """ 208 Parameters 209 ---------- 210 *opts : 211 The List of optimizers to be applied to a node 212 failure_callback : callable or None 213 Keyword only argument. A callback used when a failure 214 happen during optimization. 215 216 """ 217 if len(opts) == 1 and isinstance(opts[0], (list, tuple)): 218 opts = opts[0] 219 self[:] = opts 220 self.failure_callback = kw.pop('failure_callback', None) 221 assert len(kw) == 0 222 223 def apply(self, fgraph): 224 """ 225 226 Applies each L{Optimizer} in self in turn. 227 228 """ 229 l = [] 230 if fgraph.profile: 231 validate_before = fgraph.profile.validate_time 232 sub_validate_time = [validate_before] 233 callbacks_before = fgraph.execute_callbacks_times.copy() 234 else: 235 sub_validate_time = [] 236 callbacks_before = [] 237 callback_before = fgraph.execute_callbacks_time 238 nb_node_before = len(fgraph.apply_nodes) 239 sub_profs = [] 240 nb_nodes = [] 241 242 self.pre_profile = ( 243 self, l, -1, -1, nb_node_before, 244 -1, sub_profs, sub_validate_time, 245 nb_nodes, {}) 246 try: 247 for optimizer in self: 248 try: 249 nb_nodes_before = len(fgraph.apply_nodes) 250 t0 = time.time() 251 sub_prof = optimizer.optimize(fgraph) 252 l.append(float(time.time() - t0)) 253 sub_profs.append(sub_prof) 254 nb_nodes.append((nb_nodes_before, 255 len(fgraph.apply_nodes))) 256 if fgraph.profile: 257 sub_validate_time.append(fgraph.profile.validate_time) 258 except AssertionError: 259 # do not catch Assertion failures 260 raise 261 except Exception as e: 262 if self.failure_callback: 263 self.failure_callback(e, self, optimizer) 264 continue 265 else: 266 raise 267 finally: 268 269 if fgraph.profile: 270 validate_time 
= fgraph.profile.validate_time - validate_before 271 callbacks_time = {} 272 for k, v in iteritems(fgraph.execute_callbacks_times): 273 if k in callbacks_before: 274 t = v - callbacks_before[k] 275 if t > 0: 276 callbacks_time[k] = t 277 else: 278 callbacks_time[k] = v 279 else: 280 validate_time = None 281 callbacks_time = {} 282 callback_time = fgraph.execute_callbacks_time - callback_before 283 self.pre_profile = ( 284 self, l, validate_time, callback_time, nb_node_before, 285 len(fgraph.apply_nodes), sub_profs, sub_validate_time, 286 nb_nodes, callbacks_time) 287 return self.pre_profile 288 289 def __str__(self): 290 return "SeqOpt(%s)" % list.__str__(self) 291 292 def __repr__(self): 293 return list.__repr__(self) 294 295 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 296 name = getattr(self, 'name', None) 297 print("%s%s %s id=%i" % ( 298 (' ' * level), self.__class__.__name__, name, id(self)), file=stream) 299 # This way, -1 will do all depth 300 if depth != 0: 301 depth -= 1 302 for opt in self: 303 opt.print_summary(stream, level=(level + 2), depth=depth) 304 305 @staticmethod 306 def print_profile(stream, prof, level=0): 307 (opts, prof, validate_time, callback_time, 308 nb_node_before, nb_node_after, sub_profs, sub_validate_time, 309 nb_nodes, callbacks_time) = prof 310 blanc = (' ' * level) 311 312 print(blanc, "SeqOptimizer", end=' ', file=stream) 313 if hasattr(opts, "name"): 314 print(blanc, opts.name, end=' ', file=stream) 315 elif hasattr(opts, "__name__"): 316 print(blanc, opts.__name__, end=' ', file=stream) 317 print((" time %.3fs for %d/%d nodes" 318 " before/after optimization" % ( 319 sum(prof), nb_node_before, nb_node_after)), file=stream) 320 print(blanc, " %.3fs for callback" % (callback_time), file=stream) 321 print(blanc, " %.3fs for fgraph.validate()" % (validate_time), 322 file=stream) 323 if callback_time > 1: 324 print(blanc, " callbacks_time", file=stream) 325 for i in sorted(iteritems(callbacks_time), key=lambda a: 
-a[1]): 326 if i[1] > 0: 327 # We want to have the __str__ called, so we can't 328 # just print i. 329 print(blanc, " ", i[0], ',', i[1], file=stream) 330 331 if level == 0: 332 print(blanc, 333 " time - (name, class, index, nodes before, nodes after) - validate time", 334 file=stream) 335 ll = [] 336 for (opt, nb_n) in zip(opts, nb_nodes): 337 if hasattr(opt, "__name__"): 338 name = opt.__name__ 339 else: 340 name = opt.name 341 idx = opts.index(opt) 342 ll.append((name, opt.__class__.__name__, 343 idx) + nb_n) 344 lll = sorted(zip(prof, ll), key=lambda a: a[0]) 345 346 for (t, opt) in lll[::-1]: 347 i = opt[2] 348 if sub_validate_time: 349 val_time = sub_validate_time[i + 1] - sub_validate_time[i] 350 print(blanc, ' %.6fs - %s - %.3fs' % ( 351 t, opt, val_time), file=stream) 352 else: 353 print(blanc, ' %.6fs - %s' % (t, opt), file=stream) 354 355 if sub_profs[i]: 356 opts[i].print_profile(stream, sub_profs[i], 357 level=level + 1) 358 print(file=stream) 359 360 @staticmethod 361 def merge_profile(prof1, prof2): 362 """ 363 Merge 2 profiles returned by this cass apply() fct. 364 365 """ 366 new_t = [] # the time for the optimization 367 new_l = [] # the optimization 368 new_sub_profile = [] 369 # merge common(same object) opt 370 for l in set(prof1[0]).intersection(set(prof2[0])): 371 idx1 = prof1[0].index(l) 372 idx2 = prof2[0].index(l) 373 new_t.append(prof1[1][idx1] + 374 prof2[1][idx2]) 375 new_l.append(l) 376 if hasattr(l, 'merge_profile'): 377 assert len(prof1[6][idx1]) == len(prof2[6][idx2]) 378 new_sub_profile.append(l.merge_profile(prof1[6][idx1], 379 prof2[6][idx2])) 380 else: 381 new_sub_profile.append(None) 382 383 # merge not common opt 384 from six import StringIO 385 for l in set(prof1[0]).symmetric_difference(set(prof2[0])): 386 # The set trick above only work for the same object optimization 387 # It don't work for equivalent optimization. 388 # So we try to merge equivalent optimization here. 
389 new_l_names = [o.name for o in new_l] 390 if l.name in new_l_names: 391 idx = new_l_names.index(l.name) 392 io1 = StringIO() 393 io2 = StringIO() 394 l.print_summary(io1) 395 new_l[idx].print_summary(io2) 396 if io1.read() == io2.read(): 397 if l in prof1[0]: 398 p = prof1 399 else: 400 p = prof2 401 new_t[idx] += p[1][p[0].index(l)] 402 if hasattr(l, 'merge_profile'): 403 assert len(p[6][p[0].index(l)]) == \ 404 len(new_sub_profile[idx]) 405 new_sub_profile[idx] = l.merge_profile( 406 new_sub_profile[idx], p[6][p[0].index(l)]) 407 else: 408 new_sub_profile[idx] = None 409 continue 410 if l in prof1[0]: 411 p = prof1 412 else: 413 p = prof2 414 new_t.append(p[1][p[0].index(l)]) 415 idx = p[0].index(l) 416 new_l.append(l) 417 new_sub_profile.append(p[6][idx]) 418 419 new_opt = SeqOptimizer(*new_l) 420 new_nb_nodes = [] 421 for p1, p2 in zip(prof1[8], prof2[8]): 422 new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1])) 423 new_nb_nodes.extend(prof1[8][len(new_nb_nodes):]) 424 new_nb_nodes.extend(prof2[8][len(new_nb_nodes):]) 425 426 new_callbacks_times = merge_dict(prof1[9], prof2[9]) 427 # We need to assert based on the name as we merge also based on 428 # the name. 
429 assert set([l.name for l in prof1[0]]).issubset( 430 set([l.name for l in new_l])) 431 assert set([l.name for l in prof2[0]]).issubset( 432 set([l.name for l in new_l])) 433 assert len(new_t) == len(new_opt) == len(new_sub_profile) 434 return (new_opt, new_t, prof1[2] + prof2[2], 435 prof1[3] + prof2[3], 436 -1, -1, new_sub_profile, [], 437 new_nb_nodes, 438 new_callbacks_times) 439 440 441class _metadict: 442 """ 443 WRITEME 444 445 """ 446 447 # dict that accepts unhashable keys 448 # uses an associative list 449 # for internal use only 450 def __init__(self): 451 self.d = {} 452 self.l = [] 453 454 def __getitem__(self, item): 455 return self.get(item, None) 456 457 def __setitem__(self, item, value): 458 try: 459 self.d[item] = value 460 except Exception: 461 for i, (key, val) in enumerate(self.l): 462 if key == item: 463 self.l[i] = (item, value) 464 return 465 self.l.append((item, value)) 466 467 def __delitem__(self, item): 468 try: 469 if item in self.d: 470 del self.d[item] 471 return 472 except TypeError as e: 473 assert "unhashable type" in str(e) 474 for i, (key, val) in enumerate(self.l): 475 if key == item: 476 del self.l[i] 477 return 478 raise KeyError(item) 479 480 def discard(self, item): 481 try: 482 if item in self.d: 483 del self.d[item] 484 return 485 except TypeError as e: 486 assert "unhashable type" in str(e) 487 for i, (key, val) in enumerate(self.l): 488 if key == item: 489 del self.l[i] 490 return 491 492 def get(self, item, default): 493 try: 494 return self.d[item] 495 except Exception: 496 for item2, value in self.l: 497 try: 498 if item == item2: 499 return value 500 if item.equals(item2): 501 return value 502 except Exception: 503 if item is item2: 504 return value 505 return default 506 507 def clear(self): 508 self.d = {} 509 self.l = [] 510 511 def __str__(self): 512 return "(%s, %s)" % (self.d, self.l) 513 514 515class MergeFeature(object): 516 """ 517 Keeps track of variables in fgraph that cannot be merged together. 
class MergeFeature(object):
    """
    Keeps track of variables in fgraph that cannot be merged together.

    That way, the MergeOptimizer can remember the result of the last merge
    pass on the fgraph.

    """
    def on_attach(self, fgraph):
        # One MergeFeature serves exactly one fgraph.
        assert not hasattr(fgraph, 'merge_feature')
        fgraph.merge_feature = self

        # For constants
        self.seen_constants = set()
        # variable -> signature (for constants)
        self.const_sig = _metadict()
        # signature -> variable (for constants)
        self.const_sig_inv = _metadict()

        # For all Apply nodes
        # Set of distinct (not mergeable) nodes
        self.nodes_seen = set()
        # Ordered set of distinct (not mergeable) nodes without any input
        self.noinput_nodes = OrderedSet()

        # Each element of scheduled is a list of list of (out, new_out) pairs.
        # Each list of pairs represent the substitution needed to replace all
        # the outputs of a node with the outputs of a replacement candidate.
        # Each node can have several candidates. For instance, if "node" has
        # 2 outputs, and there are 3 replacement candidates, we will have:
        # shelf.scheduled = [
        #     [[(node.out1, cand1.out1), (node.out2, cand1.out2)],
        #      [(node.out1, cand2.out1), (node.out2, cand2.out2)],
        #      [(node.out1, cand3.out1), (node.out2, cand3.out2)]]]
        self.scheduled = []

        # List of (node, candidate) pairs, where we tried to replace node by
        # candidate, but it failed. This is used to avoid infinite loops
        # during the replacement phase.
        self.blacklist = []

        # Seed the feature with the current content of the fgraph.
        for node in fgraph.toposort():
            self.on_import(fgraph, node, "on_attach")

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        # If inputs to node change, it is not guaranteed that it is distinct
        # from the other nodes in nodes_seen
        if node in self.nodes_seen:
            self.nodes_seen.discard(node)
            self.process_node(fgraph, node)

        # Since we are in on_change_input, node should have inputs.
        if not isinstance(node, string_types):
            assert node.inputs

        if isinstance(new_r, graph.Constant):
            self.process_constant(fgraph, new_r)

    def on_import(self, fgraph, node, reason):
        # Register the node's constant inputs, then the node itself.
        for c in node.inputs:
            if isinstance(c, graph.Constant):
                self.process_constant(fgraph, c)

        self.process_node(fgraph, node)

    def on_prune(self, fgraph, node, reason):
        # Forget a node that left the graph, and drop constants it was the
        # last user of.
        self.nodes_seen.discard(node)
        if not node.inputs:
            self.noinput_nodes.discard(node)
        for c in node.inputs:
            if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
                # This was the last node using this constant
                sig = self.const_sig[c]
                self.const_sig.discard(c)
                self.const_sig_inv.discard(sig)
                self.seen_constants.discard(id(c))

    def process_constant(self, fgraph, c):
        """
        Check if a constant can be merged, and queue that replacement.

        """
        if id(c) in self.seen_constants:
            return
        sig = c.merge_signature()
        other_c = self.const_sig_inv.get(sig, None)
        if other_c is not None:
            # multiple names will clobber each other..
            # we adopt convention to keep the last name
            if c.name:
                other_c.name = c.name
            self.scheduled.append([[(c, other_c, 'merge')]])
        else:
            # this is a new constant
            self.const_sig[c] = sig
            self.const_sig_inv[sig] = c
            self.seen_constants.add(id(c))

    def process_node(self, fgraph, node):
        """
        Check if a node can be merged, and queue that replacement.

        """
        if node in self.nodes_seen:
            return

        node_has_assert = False

        # These asserts ensure that the fgraph has set the clients field
        # properly.
        # The clients should at least contain `node` itself!
        if node.inputs:
            # Take the smallest clients list. Some ops like elemwise
            # have optimization that put constant as the first inputs.
            # As constant have in general more clients than other type of
            # nodes, using always inputs[0] make us look at more nodes.
            # Always pick the smallest clients list between inputs 0
            # and -1 to speed up optimization.

            if len(node.inputs[0].clients) < len(node.inputs[-1].clients):
                clients = node.inputs[0].clients
            else:
                clients = node.inputs[-1].clients
            assert len(clients) > 0

            merge_candidates = [c for c, i in clients if c in self.nodes_seen]

            # Put all clients of Assert inputs (if exist) into merge_candidates
            # TODO: Deactivated for now as this cause cycle in the graph.
            # (There is a second deactivation part below.)
            for i in []:  # node.inputs:
                if i.owner and isinstance(i.owner.op,
                                          theano.tensor.opt.Assert):
                    node_has_assert = True
                    assert_clients = [c for (c, _) in i.owner.inputs[0].clients
                                      if c in self.nodes_seen]

                    for idx in range(len(assert_clients)):
                        client = assert_clients[idx]
                        if isinstance(i.owner.op, theano.tensor.opt.Assert):
                            for c in client.outputs[0].clients:
                                if c[0] in self.nodes_seen:
                                    assert_clients.append(c[0])

                    merge_candidates.extend(assert_clients)
        else:
            # If two nodes have no input, but perform the same operation,
            # they are not always constant-folded, so we want to merge them.
            # In that case, the candidates are all the nodes without inputs.
            merge_candidates = self.noinput_nodes

        replacement_candidates = []
        for candidate in merge_candidates:
            if candidate is node:
                continue
            if len(node.inputs) != len(candidate.inputs):
                continue

            cand_has_assert = False

            # Get input list of the candidate with assert removed
            cand_inputs_assert_removed = []
            # TODO: Deactivated while Assert merging is disabled.
            # (See above and below.)
            for i in []:  # candidate.inputs:
                if i.owner and isinstance(i.owner.op,
                                          theano.tensor.opt.Assert):
                    cand_has_assert = True
                    cand_inputs_assert_removed.append(i.owner.inputs[0])
                else:
                    cand_inputs_assert_removed.append(i)

            # TODO: Remove this when Assert merging is re-enabled. (See above.)
            # Without Assert merging we can still look for identical Asserts,
            # so we should not treat Asserts separately for now.
            cand_inputs_assert_removed = candidate.inputs

            # Get input list of the node with assert removed
            if node_has_assert:
                node_inputs_assert_removed = []
                for i in node.inputs:
                    if i.owner and isinstance(i.owner.op,
                                              theano.tensor.opt.Assert):
                        node_inputs_assert_removed.append(i.owner.inputs[0])
                    else:
                        node_inputs_assert_removed.append(i)
            else:
                node_inputs_assert_removed = node.inputs

            # Inputs must be *identical* objects, position by position.
            inputs_match = all(node_in is cand_in
                               for node_in, cand_in
                               in zip(node_inputs_assert_removed,
                                      cand_inputs_assert_removed))

            if inputs_match and node.op == candidate.op:
                if (node, candidate) in self.blacklist:
                    # They were already tried, and there was an error
                    continue

                # replace node with candidate
                if not (node_has_assert or cand_has_assert):
                    # Schedule transfer of clients from node to candidate
                    pairs = list(zip(node.outputs,
                                     candidate.outputs,
                                     ['merge'] * len(node.outputs)))

                # if the current node has assert input, it should not be
                # replaced with a candidate node which has no assert input
                elif node_has_assert and not cand_has_assert:
                    pairs = list(zip(candidate.outputs,
                                     node.outputs,
                                     ['merge'] * len(node.outputs)))
                else:
                    # Both have asserts: build a new node whose assert
                    # merges both conditions, and redirect both to it.
                    new_inputs = self.get_merged_assert_input(node, candidate)
                    new_node = node.op(*new_inputs)
                    pairs = list(zip(node.outputs,
                                     new_node.owner.outputs,
                                     ['new_node'] * len(node.outputs))) +\
                        list(zip(candidate.outputs,
                                 new_node.owner.outputs,
                                 ['new_node'] * len(node.outputs)))

                # transfer names
                for pair in pairs:
                    node_output, cand_output = pair[:2]
                    # clobber old name with new one
                    # it's arbitrary... one of the names has to go
                    if node_output.name:
                        cand_output.name = node_output.name

                replacement_candidates.append(pairs)

        if replacement_candidates:
            self.scheduled.append(replacement_candidates)
        else:
            self.nodes_seen.add(node)
            if not node.inputs:
                self.noinput_nodes.add(node)

    def get_merged_assert_input(self, node, candidate):
        # Build the input list for a merged replacement of `node` and
        # `candidate`, combining the conditions of paired Assert inputs.
        new_inputs = []
        for node_i, cand_i in zip(node.inputs, candidate.inputs):
            # if node_i is assert
            if (node_i.owner and
                    isinstance(node_i.owner.op,
                               theano.tensor.opt.Assert)):
                # node_i is assert, cand_i is assert
                if (cand_i.owner and
                        isinstance(cand_i.owner.op,
                                   theano.tensor.opt.Assert)):
                    # Here two assert nodes are merged.
                    # Step 1. Merge conditions of both assert nodes.
                    # Step 2. Make the new assert node
                    node_cond = node_i.owner.inputs[1:]
                    cand_cond = cand_i.owner.inputs[1:]
                    new_cond = list(set(node_cond + cand_cond))
                    new_inputs.append(
                        theano.tensor.opt.assert_op(
                            node_i.owner.inputs[0],
                            *new_cond))

                # node_i is assert, cand_i is not assert
                else:
                    new_inputs.append(node_i)
            else:
                # if node_i is not an assert node, append cand_i
                new_inputs.append(cand_i)

        return new_inputs
class MergeOptimizer(Optimizer):
    """
    Merges parts of the graph that are identical and redundant.

    The basic principle is that if two Applies have ops that compare equal, and
    identical inputs, then they do not both need to be computed. The clients of
    one are transferred to the other and one of them is removed from the graph.
    This procedure is carried out in input->output order through the graph.

    The first step of merging is constant-merging, so that all clients of an
    int(1) for example, are transferred to a particular instance of int(1).

    """

    def add_requirements(self, fgraph):
        # Added by default
        # fgraph.attach_feature(toolbox.ReplaceValidate())
        if not hasattr(fgraph, 'merge_feature'):
            fgraph.attach_feature(MergeFeature())

    def apply(self, fgraph):
        """
        Drain the MergeFeature's scheduled replacements and apply them.

        Returns a profile tuple: (nb_fail, replace_time, validate_time,
        callback_time, callbacks_time, nb_merged, nb_constant).

        """
        # Constant and non-constant are now applied in the same phase.
        # I am not sure why, but it seems to be faster this way.
        sched = fgraph.merge_feature.scheduled
        nb_fail = 0
        t0 = time.time()
        if fgraph.profile:
            validate_before = fgraph.profile.validate_time
            callback_before = fgraph.execute_callbacks_time
            callbacks_before = fgraph.execute_callbacks_times.copy()

        nb_merged = 0
        nb_constant = 0
        while sched:
            pairs_list = sched.pop()
            success = True
            for pairs_ in pairs_list:
                # We must check again the equivalence, as the graph
                # can have changed. If so, doing the replacement can
                # introduce node that depend on itself. Doing the
                # full check of such cycle everytimes is very time
                # consumming. I think this double check is faster then
                # doing the full cycle check. The full cycle check is
                # skipped by validate() if the graph don't contain
                # destroyers.
                var, candidate, merge_mode = pairs_[0]
                if merge_mode == "new_node" and hasattr(var, 'fgraph'):
                    pass
                elif (not hasattr(var, 'fgraph') or
                        not hasattr(candidate, 'fgraph')):
                    # One side already left the graph; skip this candidate.
                    continue

                # Keep len(item) == 2 for item in pairs
                pairs = [pair[:2] for pair in pairs_]

                if var.owner and candidate.owner:
                    node = var.owner
                    candidate = candidate.owner

                    # Get input list of the candidate node with assert
                    # nodes removed
                    cand_inputs_assert_removed = []
                    for i in candidate.inputs:
                        if i.owner and isinstance(i.owner.op,
                                                  theano.tensor.opt.Assert):
                            cand_inputs_assert_removed.append(
                                i.owner.inputs[0])
                        else:
                            cand_inputs_assert_removed.append(i)

                    # Get input list of the node with assert nodes removed
                    node_inputs_assert_removed = []
                    for i in node.inputs:
                        if i.owner and isinstance(i.owner.op,
                                                  theano.tensor.opt.Assert):
                            node_inputs_assert_removed.append(
                                i.owner.inputs[0])
                        else:
                            node_inputs_assert_removed.append(i)

                    if merge_mode == "new_node":
                        inputs_match = True
                    else:
                        inputs_match = all(node_in is cand_in
                                           for node_in, cand_in in
                                           zip(node_inputs_assert_removed,
                                               cand_inputs_assert_removed))

                    # No need to compare the op again, as it don't change.
                    if not inputs_match:
                        continue

                    if hasattr(pairs[0][0].fgraph, 'destroy_handler'):
                        # If both nodes have clients that destroy
                        # them, we can't merge them.
                        clients = pairs[0][0].clients + pairs[0][1].clients
                        if sum([i in utils.flatten(c.op.destroy_map.values())
                                for c, i in clients
                                if c != 'output' and
                                hasattr(c.op, 'destroy_map')]) > 1:
                            continue

                if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
                    res = pairs[0][0].type.convert_variable(pairs[0][1])

                    # Since the fgraph.replace only checks the convert_variable
                    # in one way, we change the order in the case that
                    # convert_variable will not be successful.
                    if not res:
                        pairs = [(pairs[0][1], pairs[0][0])]

                try:
                    # If all Constants, no need to call validate.
                    # Only need to check one of the var of each pairs.
                    # If it is a Constant, the other must also be a Constant
                    # as we merge them.
                    if all([isinstance(old, graph.Constant)
                            for old, new in pairs]):
                        fgraph.replace_all(pairs, 'MergeOptimizer')
                    else:
                        fgraph.replace_all_validate(pairs, 'MergeOptimizer')
                except InconsistencyError:
                    success = False
                    nb_fail += 1
                    # Remember the failed pair so we don't retry it forever.
                    fgraph.merge_feature.blacklist.append(
                        (pairs[0][0].owner, pairs[0][1].owner))

                if success:
                    nb_merged += len(pairs)
                    if isinstance(pairs[0][0], graph.Constant):
                        nb_constant += 1
                        # print pairs, pairs[0][0].type
                    # One candidate succeeded; stop trying the others.
                    break

        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in iteritems(fgraph.execute_callbacks_times):
                if k in callbacks_before:
                    t = v - callbacks_before[k]
                    if t > 0:
                        callbacks_time[k] = t
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        # clear blacklist
        fgraph.merge_feature.blacklist = []
        return (nb_fail, time.time() - t0, validate_time,
                callback_time, callbacks_time, nb_merged, nb_constant)

    def __str__(self):
        return self.__class__.__name__

    @staticmethod
    def print_profile(stream, prof, level=0):
        (nb_fail, replace_time, validate_time,
         callback_time, callbacks_time, nb_merged, nb_constant) = prof

        blanc = ('    ' * level)
        print(blanc, "MergeOptimizer", file=stream)
        print(blanc, "  nb fail=%5d merged=%5d constant=%5d" % (
              nb_fail, nb_merged, nb_constant), file=stream)
        print(blanc, "  time replace=%2.2f validate=%2.2f callback=%2.2f" % (
              replace_time, validate_time, callback_time), file=stream)
        if callback_time > 1:
            print(blanc, "  callbacks_time", file=stream)
            for i in sorted(iteritems(callbacks_time), key=lambda a: a[1]):
                if i[1] > 0:
                    # We want to have the __str__ called, so we can't
                    # just print i.
                    print(blanc, "      ", i[0], ',', i[1], file=stream)

    @staticmethod
    def merge_profile(prof1, prof2):
        # Combine two apply() profiles, summing counters and times;
        # validate/callback times may be None when profiling was off.
        def merge_none_number(v1, v2):
            if v1 is None:
                return v2
            if v2 is None:
                return v1
            return v1 + v2
        nb_fail = prof1[0] + prof2[0]
        replace_time = prof1[1] + prof2[1]
        validate_time = merge_none_number(prof1[2], prof2[2])
        callback_time = merge_none_number(prof1[3], prof2[3])
        callbacks_time = merge_dict(prof1[4], prof2[4])
        nb_merged = prof1[5] + prof2[5]
        nb_constant = prof1[6] + prof2[6]
        return (nb_fail, replace_time, validate_time,
                callback_time, callbacks_time, nb_merged, nb_constant)
def is_same_graph_with_merge(var1, var2, givens=None):
    """
    Merge-based implementation of `theano.gof.graph.is_same_graph`.

    See help on `theano.gof.graph.is_same_graph` for additional documentation.

    """
    if givens is None:
        givens = {}
    # Work on deep copies: the MergeOptimizer mutates the graph it visits.
    var1_c, var2_c, givens = copy.deepcopy([var1, var2, givens])
    outputs = [var1_c, var2_c]
    # Build a FunctionGraph over the copies.  No clone: we already did a
    # deepcopy, and cloning would break the mapping in `givens`.
    fgraph = theano.gof.fg.FunctionGraph(
        theano.gof.graph.inputs(outputs), outputs, clone=False)
    # Apply the requested substitutions, then merge identical subgraphs.
    for to_replace, replace_by in iteritems(givens):
        fgraph.replace(to_replace, replace_by)
    MergeOptimizer().optimize(fgraph)
    # When two variables perform the same computations, they end up with
    # the same owner in the optimized graph.  A variable present in
    # `givens` has been substituted, so look up its replacement first.
    effective = [givens.get(v, v) for v in outputs]
    owner1, owner2 = [v.owner for v in effective]
    if owner1 is None and owner2 is None:
        # Two single-Variable graphs (no owner): they are equal exactly
        # when they are the same Variable.
        return effective[0] == effective[1]
    return owner1 is owner2


def pre_constant_merge(vars):
    """
    Merge constants in the subgraph used to compute nodes in `vars`.

    `vars` is a list of nodes, and we want to merge together nodes
    that are constant inputs used to compute nodes in that list.

    Notes
    -----
    This function will ignore nodes that are in an fgraph.
    It is used to pre-merge nodes generated inside an optimization,
    before it is inserted in the fgraph.
    It is useful if there are many such replacements to make,
    so that DebugMode will not check each of them.

    """
    visited = set()
    # signature -> variable (for constants)
    sig_to_var = {}
    if isinstance(vars, graph.Variable):
        vars = [vars]

    def _merge(var):
        # Guard clauses: already handled, not a Variable-like object,
        # or already owned by an fgraph (which we must not touch).
        if var in visited:
            return var
        if not hasattr(var, 'owner'):
            return var
        if var.owner and hasattr(var.owner, "fgraph"):
            return var
        visited.add(var)
        if isinstance(var, graph.Constant):
            sig = var.signature()
            try:
                if sig in sig_to_var:
                    # A constant with the same signature was already seen;
                    # reuse it.
                    return sig_to_var[sig]
                sig_to_var[sig] = var
            except TypeError:  # unhashable type
                warnings.warn(
                    "We work around a problem, the following variable"
                    " signature isn't hashable. Please, report this to"
                    " theano-dev so that the better fix is done. %s" % var)
                # Some python object like slice aren't hashable. So
                # don't merge them here.
                pass
            return var
        if var.owner:
            # Rewrite the owner's inputs in place, merging recursively.
            for pos, inp in enumerate(var.owner.inputs):
                var.owner.inputs[pos] = _merge(inp)
        return var

    return list(map(_merge, vars))
1061 pass 1062 return var 1063 if var.owner: 1064 for idx, inp in enumerate(var.owner.inputs): 1065 var.owner.inputs[idx] = recursive_merge(inp) 1066 return var 1067 1068 return list(map(recursive_merge, vars)) 1069 1070 1071######################## 1072# Local Optimizers # 1073######################## 1074 1075class LocalOptimizer(object): 1076 """ 1077 A class for node-based optimizations. 1078 1079 Instances should implement the transform function, 1080 and be passed to configure a fgraph-based Optimizer instance. 1081 1082 """ 1083 1084 def __hash__(self): 1085 if not hasattr(self, '_optimizer_idx'): 1086 self._optimizer_idx = _optimizer_idx[0] 1087 _optimizer_idx[0] += 1 1088 return self._optimizer_idx 1089 1090 def tracks(self): 1091 """ 1092 Return the list of op classes that this opt applies to. 1093 1094 Return None to apply to all nodes. 1095 1096 """ 1097 return None 1098 1099 def transform(self, node): 1100 """ 1101 Transform a subgraph whose output is `node`. 1102 1103 Subclasses should implement this function so that it returns one of two 1104 kinds of things: 1105 1106 - False to indicate that no optimization can be applied to this `node`; 1107 or 1108 - <list of variables> to use in place of `node`'s outputs in the 1109 greater graph. 1110 - dict(old variables -> new variables). A dictionary that map 1111 from old variables to new variables to replace. 1112 1113 Parameters 1114 ---------- 1115 node : an Apply instance 1116 1117 """ 1118 1119 raise utils.MethodNotDefined("transform", 1120 type(self), self.__class__.__name__) 1121 1122 def add_requirements(self, fgraph): 1123 """ 1124 If this local optimization wants to add some requirements to the 1125 fgraph, this is the place to do it. 
1126 1127 """ 1128 # Added by default 1129 # fgraph.attach_feature(toolbox.ReplaceValidate()) 1130 pass 1131 1132 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 1133 print("%s%s id=%i" % ( 1134 (' ' * level), self.__class__.__name__, id(self)), file=stream) 1135 1136 1137class LocalMetaOptimizer(LocalOptimizer): 1138 """ 1139 Base class for meta-optimizers that try a set of LocalOptimizers 1140 to replace a node and choose the one that executes the fastest. 1141 1142 If the error LocalMetaOptimizerSkipAssertionError is raised during 1143 compilation, we will skip that function compilation and not print 1144 the error. 1145 1146 """ 1147 1148 def __init__(self): 1149 self.verbose = config.metaopt.verbose 1150 self.track_dict = defaultdict(lambda: []) 1151 self.tag_dict = defaultdict(lambda: []) 1152 self._tracks = [] 1153 self.optimizers = [] 1154 1155 def register(self, optimizer, tag_list): 1156 self.optimizers.append(optimizer) 1157 for c in optimizer.tracks(): 1158 self.track_dict[c].append(optimizer) 1159 self._tracks.append(c) 1160 for tag in tag_list: 1161 self.tag_dict[tag].append(optimizer) 1162 1163 def tracks(self): 1164 return self._tracks 1165 1166 def transform(self, node): 1167 # safety check: depending on registration, tracks may have been ignored 1168 if self._tracks is not None: 1169 if not isinstance(node.op, tuple(self._tracks)): 1170 return 1171 # first, we need to provide dummy values for all inputs 1172 # to the node that are not shared variables anyway 1173 givens = {} 1174 missing = set() 1175 for input in node.inputs: 1176 if isinstance(input, theano.compile.SharedVariable): 1177 pass 1178 elif hasattr(input.tag, 'test_value'): 1179 givens[input] = theano.shared( 1180 input.type.filter(input.tag.test_value), 1181 input.name, 1182 broadcastable=input.broadcastable, 1183 borrow=True) 1184 else: 1185 missing.add(input) 1186 if missing: 1187 givens.update(self.provide_inputs(node, missing)) 1188 
missing.difference_update(givens.keys()) 1189 # ensure we have data for all input variables that need it 1190 if missing: 1191 if self.verbose > 0: 1192 print(("%s cannot meta-optimize %s, " 1193 "%d of %d input shapes unknown" % 1194 (self.__class__.__name__, node, len(missing), node.nin))) 1195 return 1196 # now we can apply the different optimizations in turn, 1197 # compile the resulting subgraphs and time their execution 1198 if self.verbose > 1: 1199 print(("%s meta-optimizing %s (%d choices):" % 1200 (self.__class__.__name__, node, len(self.get_opts(node))))) 1201 timings = [] 1202 for opt in self.get_opts(node): 1203 outputs = opt.transform(node) 1204 if outputs: 1205 try: 1206 fn = theano.function([], outputs, givens=givens, 1207 on_unused_input='ignore') 1208 fn.trust_input = True 1209 timing = min(self.time_call(fn) for _ in range(2)) 1210 except LocalMetaOptimizerSkipAssertionError: 1211 continue 1212 except Exception as e: 1213 if self.verbose > 0: 1214 print("* %s: exception" % opt, e) 1215 continue 1216 else: 1217 if self.verbose > 1: 1218 print("* %s: %.5g sec" % (opt, timing)) 1219 timings.append((timing, outputs, opt)) 1220 else: 1221 if self.verbose > 0: 1222 print("* %s: not applicable" % opt) 1223 # finally, we choose the fastest one 1224 if timings: 1225 timings.sort() 1226 if self.verbose > 1: 1227 print("= %s" % timings[0][2]) 1228 return timings[0][1] 1229 return 1230 1231 def provide_inputs(self, node, inputs): 1232 """ 1233 If implemented, returns a dictionary mapping all symbolic variables 1234 in ``inputs`` to SharedVariable instances of suitable dummy values. 1235 The ``node`` can be inspected to infer required input shapes. 

        """
        raise NotImplementedError()

    def get_opts(self, node):
        """
        Can be overridden to change the way opts are selected.
        By default, returns the optimizers registered for the exact
        class of ``node.op``.
        """
        return self.track_dict[type(node.op)]

    def time_call(self, fn):
        # Time a single call of the compiled function, in seconds.
        start = time.time()
        fn()
        return time.time() - start


class FromFunctionLocalOptimizer(LocalOptimizer):
    """
    A LocalOptimizer built from a plain transform function.

    `fn` becomes the `transform` method, `tracks` is what `tracks()`
    returns (a list of op classes/instances, or None to apply to all
    nodes), and `requirements` is a sequence of callables applied to
    the fgraph by `add_requirements`.

    """
    def __init__(self, fn, tracks=None, requirements=()):
        self.transform = fn
        self._tracks = tracks
        self.requirements = requirements

    def add_requirements(self, fgraph):
        for req in self.requirements:
            req(fgraph)

    def tracks(self):
        return self._tracks

    def __str__(self):
        return getattr(self, '__name__',
                       '<FromFunctionLocalOptimizer instance>')

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        print("%s%s id=%i" % (
            ' ' * level,
            str(self.transform),
            id(self)), file=stream)


def local_optimizer(tracks, inplace=False, requirements=()):
    def decorator(f):
        """
        Wrap the decorated function `f` in a FromFunctionLocalOptimizer
        that tracks the given ops and carries the given requirements.

        """
        if tracks is not None:
            if len(tracks) == 0:
                raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
            for t in tracks:
                if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)):
                    raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__)
        req = requirements
        if inplace:
            dh_handler = dh.DestroyHandler
            # Inplace optimizations need the DestroyHandler feature
            # attached to the fgraph to validate destructive operations.
            req = tuple(requirements) + (
                lambda fgraph:
                fgraph.attach_feature(dh_handler()),)
        rval = FromFunctionLocalOptimizer(f, tracks, req)
        rval.__name__ = f.__name__
        return rval
    return decorator


class LocalOptGroup(LocalOptimizer):
    """Takes a list of LocalOptimizer and applies them to the node.

    Parameters
    ----------
    optimizers :
        The list of optimizers to be applied to a node.
    reentrant : bool (Default True)
        Keyword only argument. Reentrant information. Some global
        optimizers, like NavigatorOptimizer, can use this value to
        determine whether to ignore new nodes during a pass on the
        nodes. Sometimes, ignore_newtrees is not reentrant.
    apply_all_opts : bool (Default False)
        If False, it will return the new node after the first optimizer
        that applied. Otherwise, it will start again with the new node
        until no new optimization applies.

    """

    def __init__(self, *optimizers, **kwargs):
        if len(optimizers) == 1 and isinstance(optimizers[0], list):
            # This happens when created by LocalGroupDB.
            optimizers = tuple(optimizers[0])
        self.opts = optimizers
        assert isinstance(self.opts, tuple)

        # Reentrant if any member is reentrant; retains_inputs only if
        # every member retains its inputs.
        self.reentrant = any(getattr(opt, 'reentrant', True)
                             for opt in optimizers)
        self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
                                  for opt in optimizers)

        self.apply_all_opts = kwargs.pop('apply_all_opts', False)
        self.profile = kwargs.pop('profile', False)
        # Maps an op class/instance (or None for "applies to any op")
        # to the list of member optimizers tracking it.
        self.track_map = defaultdict(lambda: [])
        assert len(kwargs) == 0
        if self.profile:
            self.time_opts = {}
            self.process_count = {}
            self.applied_true = {}
            self.node_created = {}

        for o in self.opts:
            if self.profile:
                self.time_opts.setdefault(o, 0)
                self.process_count.setdefault(o, 0)
                self.applied_true.setdefault(o, 0)
                self.node_created.setdefault(o, 0)
            tracks = o.tracks()
            if tracks is None:
                self.track_map[None].append(o)
            else:
                for c in tracks:
                    self.track_map[c].append(o)

    def __str__(self):
        return getattr(self, '__name__',
                       ('LocalOptGroup(%s)' %
                        ','.join([str(o) for o in self.opts])))

    def tracks(self):
        # Union of the tracks of all member optimizers.
        t = []
        for l in self.opts:
            tt = l.tracks()
            if
tt: 1368 t.extend(tt) 1369 return t 1370 1371 def transform(self, node): 1372 if len(self.opts) == 0: 1373 return 1374 fgraph = node.fgraph 1375 repl = None 1376 while True: 1377 opts = self.track_map[type(node.op)] + self.track_map[node.op] + self.track_map[None] 1378 new_repl = None 1379 for opt in opts: 1380 opt_start = time.time() 1381 new_repl = opt.transform(node) 1382 opt_finish = time.time() 1383 if self.profile: 1384 self.time_opts[opt] += opt_start - opt_finish 1385 self.process_count[opt] += 1 1386 if not new_repl: 1387 continue 1388 if isinstance(new_repl, (tuple, list)): 1389 new_vars = new_repl 1390 else: # It must be a dict 1391 new_vars = list(new_repl.values()) 1392 if self.profile: 1393 self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars)) 1394 self.applied_true[opt] += 1 1395 break # break from the for loop over optimization. 1396 if not new_repl: # No optimization applied in the last iteration 1397 return repl 1398 # only 1 iteration 1399 if not self.apply_all_opts: 1400 return new_repl 1401 if not new_vars[0].owner: 1402 # We are at the start of the graph. 
1403 return new_repl 1404 if len(new_repl) > 1: 1405 s = set([v.owner for v in new_repl]) 1406 assert len(s) == 1 1407 repl = new_repl 1408 node = new_vars[0].owner 1409 1410 @staticmethod 1411 def print_profile(stream, prof, level=0): 1412 (time_opts, process_count, applied_true, node_created, profile) = prof 1413 1414 if not profile: 1415 return 1416 1417 blanc = (' ' * int(level)) 1418 print(blanc, "LocalOptGroup", file=stream) 1419 print(blanc, "---------------------", file=stream) 1420 count_opt = [] 1421 not_used = [] 1422 not_used_time = 0 1423 for o, count in iteritems(process_count): 1424 if count > 0: 1425 count_opt.append((time_opts[o], applied_true[o], count, o, node_created[o])) 1426 else: 1427 not_used.append((time_opts[o], o)) 1428 not_used_time += time_opts[o] 1429 if count_opt: 1430 print(blanc, 1431 ' time taken - times applied - times tried - name - node_created:', 1432 file=stream) 1433 count_opt.sort() 1434 for (t, a_t, count, o, n_c) in count_opt[::-1]: 1435 print(blanc, ' %.3fs - %d - %d - %s - %d' % ( 1436 t, a_t, count, o, n_c), file=stream) 1437 print(blanc, ' %.3fs - in %d optimization that were not used (display those with runtime greater than 0)' % ( 1438 not_used_time, len(not_used)), file=stream) 1439 not_used.sort(key=lambda nu: (nu[0], str(nu[1]))) 1440 for (t, o) in not_used[::-1]: 1441 if t > 0: 1442 # Skip opt that have 0 times, they probably wasn't even tried. 
1443 print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream) 1444 else: 1445 print(blanc, " The Optimizer wasn't successful ", file=stream) 1446 1447 print(file=stream) 1448 1449 def merge_profile(prof1, prof2): 1450 raise NotImplementedError 1451 1452 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 1453 print("%s%s id=%i" % ( 1454 (' ' * level), self.__class__.__name__, id(self)), file=stream) 1455 if depth != 0: 1456 depth -= 1 1457 for lopt in self.opts: 1458 lopt.print_summary(stream, level=(level + 2), depth=depth) 1459 1460 def add_requirements(self, fgraph): 1461 for opt in self.opts: 1462 opt.add_requirements(fgraph) 1463 1464 1465class GraphToGPULocalOptGroup(LocalOptGroup): 1466 """This is the equivalent of LocalOptGroup for GraphToGPU. 1467 1468 The main different is the function signature of the local 1469 optimizer that use the GraphToGPU signature and not the normal 1470 LocalOptimizer signature. 1471 1472 apply_all_opts=True is not supported 1473 1474 """ 1475 def __init__(self, *optimizers, **kwargs): 1476 super(GraphToGPULocalOptGroup, self).__init__(*optimizers, **kwargs) 1477 assert self.apply_all_opts is False 1478 1479 def transform(self, op, context_name, inputs, outputs): 1480 if len(self.opts) == 0: 1481 return 1482 fgraph = outputs[0].fgraph 1483 opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None] 1484 for opt in opts: 1485 opt_start = time.time() 1486 new_repl = opt.transform(op, context_name, inputs, outputs) 1487 opt_finish = time.time() 1488 if self.profile: 1489 self.time_opts[opt] += opt_start - opt_finish 1490 self.process_count[opt] += 1 1491 if not new_repl: 1492 continue 1493 if self.profile: 1494 self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl)) 1495 self.applied_true[opt] += 1 1496 1497 return new_repl 1498 1499 1500class OpSub(LocalOptimizer): 1501 """ 1502 1503 Replaces the application of a certain op by the application of 1504 another op that takes the same inputs as 
what they are replacing. 1505 1506 Parameters 1507 ---------- 1508 op1, op2 1509 op1.make_node and op2.make_node must take the same number of 1510 inputs and have the same number of outputs. 1511 1512 Examples 1513 -------- 1514 OpSub(add, sub) ==> 1515 add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) 1516 1517 """ 1518 1519 # an OpSub does not apply to the nodes it produces 1520 reentrant = False 1521 # all the inputs of the original node are transferred to the outputs 1522 retains_inputs = True 1523 1524 def __init__(self, op1, op2, transfer_tags=True): 1525 self.op1 = op1 1526 self.op2 = op2 1527 self.transfer_tags = transfer_tags 1528 1529 def op_key(self): 1530 return self.op1 1531 1532 def tracks(self): 1533 return [self.op1] 1534 1535 def transform(self, node): 1536 if node.op != self.op1: 1537 return False 1538 repl = self.op2.make_node(*node.inputs) 1539 if self.transfer_tags: 1540 repl.tag = copy.copy(node.tag) 1541 for output, new_output in zip(node.outputs, repl.outputs): 1542 new_output.tag = copy.copy(output.tag) 1543 return repl.outputs 1544 1545 def __str__(self): 1546 return "%s -> %s" % (self.op1, self.op2) 1547 1548 1549class OpRemove(LocalOptimizer): 1550 """ 1551 1552 Removes all applications of an op by transferring each of its 1553 outputs to the corresponding input. 
1554 1555 """ 1556 1557 reentrant = False # no nodes are added at all 1558 1559 def __init__(self, op): 1560 self.op = op 1561 1562 def op_key(self): 1563 return self.op 1564 1565 def tracks(self): 1566 return [self.op] 1567 1568 def transform(self, node): 1569 if node.op != self.op: 1570 return False 1571 return node.inputs 1572 1573 def __str__(self): 1574 return "%s(x) -> x" % (self.op) 1575 1576 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 1577 print("%s%s(%s) id=%i" % ( 1578 ' ' * level, 1579 self.__class__.__name__, 1580 str(self.op), 1581 id(self)), file=stream) 1582 1583 1584class PatternSub(LocalOptimizer): 1585 """ 1586 1587 @todo update 1588 1589 Replaces all occurrences of the input pattern by the output pattern: 1590 1591 input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...) 1592 input_pattern ::= dict(pattern = <input_pattern>, 1593 constraint = <constraint>) 1594 sub_pattern ::= input_pattern 1595 sub_pattern ::= string 1596 sub_pattern ::= a Constant instance 1597 sub_pattern ::= int 1598 sub_pattern ::= float 1599 constraint ::= lambda fgraph, expr: additional matching condition 1600 1601 output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...) 1602 output_pattern ::= string 1603 output_pattern ::= int 1604 output_pattern ::= float 1605 1606 Each string in the input pattern is a variable that will be set to 1607 whatever expression is found in its place. If the same string is 1608 used more than once, the same expression must be found in those 1609 places. If a string used in the input pattern is used in the 1610 output pattern, the matching expression will be inserted in its 1611 place. The input pattern cannot just be a string but the output 1612 pattern can. 1613 1614 If you put a constant variable in the input pattern, there will be a 1615 match iff a constant variable with the same value and the same type 1616 is found in its place. 1617 1618 You can add a constraint to the match by using the dict(...) 
form 1619 described above with a 'constraint' key. The constraint must be a 1620 function that takes the fgraph and the current Variable that we are 1621 trying to match and returns True or False according to an 1622 arbitrary criterion. 1623 1624 The constructor creates a PatternSub that replaces occurrences of 1625 in_pattern by occurrences of out_pattern. 1626 1627 Parameters 1628 ---------- 1629 in_pattern 1630 The input pattern that we want to replace. 1631 out_pattern 1632 The replacement pattern. 1633 allow_multiple_clients : bool 1634 If False, the pattern matching will fail if one of the subpatterns has 1635 more than one client. 1636 skip_identities_fn : TODO 1637 name 1638 Allows to override this optimizer name. 1639 pdb : bool 1640 If True, we invoke pdb when the first node in the pattern matches. 1641 tracks : optional 1642 The values that self.tracks() will return. Useful to speed up 1643 optimization sometimes. 1644 get_nodes : optional 1645 If you provide `tracks`, you must provide this parameter. It must be a 1646 function that takes the tracked node and returns a list of nodes on 1647 which we will try this optimizer. 1648 1649 Notes 1650 ----- 1651 `tracks` and `get_nodes` can be used to make this optimizer track a less 1652 frequent Op, so this will make this optimizer tried less frequently. 
1653 1654 Examples 1655 -------- 1656 PatternSub((add, 'x', 'y'), (add, 'y', 'x')) 1657 PatternSub((multiply, 'x', 'x'), (square, 'x')) 1658 PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') 1659 PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) 1660 PatternSub((boggle, {'pattern': 'x', 1661 'constraint': lambda expr: expr.type == scrabble}), 1662 (scrabble, 'x')) 1663 """ 1664 1665 def __init__(self, in_pattern, out_pattern, 1666 allow_multiple_clients=False, 1667 skip_identities_fn=None, name=None, pdb=False, 1668 tracks=(), get_nodes=None, 1669 values_eq_approx=None): 1670 self.in_pattern = in_pattern 1671 self.out_pattern = out_pattern 1672 self.values_eq_approx = values_eq_approx 1673 if isinstance(in_pattern, (list, tuple)): 1674 self.op = self.in_pattern[0] 1675 elif isinstance(in_pattern, dict): 1676 self.op = self.in_pattern['pattern'][0] 1677 else: 1678 raise TypeError("The pattern to search for must start with " 1679 "a specific Op instance.") 1680 self.__doc__ = (self.__class__.__doc__ + 1681 "\n\nThis instance does: " + 1682 str(self) + "\n") 1683 self.allow_multiple_clients = allow_multiple_clients 1684 self.skip_identities_fn = skip_identities_fn 1685 if name: 1686 self.__name__ = name 1687 self.pdb = pdb 1688 self._tracks = tracks 1689 self.get_nodes = get_nodes 1690 if tracks != (): 1691 assert get_nodes 1692 1693 def op_key(self): 1694 return self.op 1695 1696 def tracks(self): 1697 if self._tracks != (): 1698 return self._tracks 1699 return [self.op] 1700 1701 def transform(self, node, get_nodes=True): 1702 """ 1703 Checks if the graph from node corresponds to in_pattern. If it does, 1704 constructs out_pattern and performs the replacement. 
1705 1706 """ 1707 if get_nodes and self.get_nodes is not None: 1708 for real_node in self.get_nodes(node): 1709 if real_node == "output": 1710 continue 1711 ret = self.transform(real_node, get_nodes=False) 1712 if ret is not False and ret is not None: 1713 assert len(real_node.outputs) == len(ret) 1714 if self.values_eq_approx: 1715 ret.tag.values_eq_approx = self.values_eq_approx 1716 return dict(izip(real_node.outputs, ret)) 1717 1718 if node.op != self.op: 1719 return False 1720 # TODO: if we remove pdb, do this speed things up? 1721 1722 def match(pattern, expr, u, allow_multiple_clients=False, pdb=False): 1723 # TODO move outside match 1724 def retry_with_equiv(): 1725 if not self.skip_identities_fn: 1726 return False 1727 expr_equiv = self.skip_identities_fn(expr) 1728 if expr_equiv is None: 1729 return False 1730 # TODO: Not sure how to handle multiple_clients flag 1731 # print 'retrying match', pattern, expr_equiv 1732 return match(pattern, expr_equiv, u, 1733 allow_multiple_clients=allow_multiple_clients) 1734 1735 if isinstance(pattern, (list, tuple)): 1736 if expr.owner is None: 1737 return False 1738 if (not (expr.owner.op == pattern[0]) or 1739 (not allow_multiple_clients and len(expr.clients) > 1)): 1740 return retry_with_equiv() 1741 if len(pattern) - 1 != len(expr.owner.inputs): 1742 return retry_with_equiv() 1743 for p, v in zip(pattern[1:], expr.owner.inputs): 1744 u = match(p, v, u, self.allow_multiple_clients) 1745 if not u: 1746 return False 1747 elif isinstance(pattern, dict): 1748 try: 1749 real_pattern = pattern['pattern'] 1750 except KeyError: 1751 raise KeyError( 1752 "Malformed pattern: %s (expected key 'pattern')" 1753 % pattern) 1754 constraint = pattern.get('constraint', lambda expr: True) 1755 if constraint(expr): 1756 return match(real_pattern, expr, u, 1757 pattern.get('allow_multiple_clients', 1758 allow_multiple_clients)) 1759 else: 1760 return retry_with_equiv() 1761 elif isinstance(pattern, string_types): 1762 v = 
unify.Var(pattern) 1763 if u[v] is not v and u[v] is not expr: 1764 return retry_with_equiv() 1765 else: 1766 u = u.merge(expr, v) 1767 elif (isinstance(pattern, (integer_types, float)) and 1768 isinstance(expr, graph.Constant)): 1769 if np.all(theano.tensor.constant(pattern).value == expr.value): 1770 return u 1771 else: 1772 return retry_with_equiv() 1773 elif (isinstance(pattern, graph.Constant) and 1774 isinstance(expr, graph.Constant) and 1775 pattern.equals(expr)): 1776 return u 1777 else: 1778 return retry_with_equiv() 1779 if pdb: 1780 import pdb 1781 pdb.set_trace() 1782 return u 1783 1784 u = match(self.in_pattern, node.out, unify.Unification(), True, 1785 self.pdb) 1786 if u: 1787 def build(pattern, u): 1788 if isinstance(pattern, (list, tuple)): 1789 args = [build(p, u) for p in pattern[1:]] 1790 return pattern[0](*args) 1791 elif isinstance(pattern, string_types): 1792 return u[unify.Var(pattern)] 1793 elif isinstance(pattern, (integer_types, float)): 1794 return pattern 1795 else: 1796 return pattern.clone() 1797 p = self.out_pattern 1798 ret = build(p, u) 1799 if self.values_eq_approx: 1800 ret.tag.values_eq_approx = self.values_eq_approx 1801 return [ret] 1802 else: 1803 return False 1804 1805 def __str__(self): 1806 if getattr(self, '__name__', None): 1807 return self.__name__ 1808 1809 def pattern_to_str(pattern): 1810 if isinstance(pattern, (list, tuple)): 1811 return "%s(%s)" % ( 1812 str(pattern[0]), 1813 ", ".join([pattern_to_str(p) for p in pattern[1:]])) 1814 elif isinstance(pattern, dict): 1815 return "%s subject to %s" % ( 1816 pattern_to_str(pattern['pattern']), 1817 str(pattern.get('constraint', 'no conditions'))) 1818 else: 1819 return str(pattern) 1820 return "%s -> %s" % ( 1821 pattern_to_str(self.in_pattern), 1822 pattern_to_str(self.out_pattern)) 1823 1824 def __repr__(self): 1825 return str(self) 1826 1827 def print_summary(self, stream=sys.stdout, level=0, depth=-1): 1828 name = getattr(self, '__name__', getattr(self, 'name', 
None)) 1829 print("%s%s %s(%s, %s) id=%i" % ( 1830 ' ' * level, 1831 self.__class__.__name__, 1832 name, 1833 str(self.in_pattern), 1834 str(self.out_pattern), 1835 id(self)), file=stream) 1836 1837 1838################## 1839# Navigators # 1840################## 1841 1842# Use the following classes to apply LocalOptimizers 1843 1844class Updater: 1845 def __init__(self, importer, pruner, chin, name=None): 1846 self.importer = importer 1847 self.pruner = pruner 1848 self.chin = chin 1849 self.name = name 1850 1851 def __str__(self): 1852 return "Updater{%s}" % str(self.name) 1853 1854 def on_import(self, fgraph, node, reason): 1855 if self.importer: 1856 self.importer(node) 1857 1858 def on_prune(self, fgraph, node, reason): 1859 if self.pruner: 1860 self.pruner(node) 1861 1862 def on_change_input(self, fgraph, node, i, r, new_r, reason): 1863 if self.chin: 1864 self.chin(node, i, r, new_r, reason) 1865 1866 def on_detach(self, fgraph): 1867 # To allow pickling this object 1868 self.importer = None 1869 self.pruner = None 1870 self.chin = None 1871 1872 1873class NavigatorOptimizer(Optimizer): 1874 """ 1875 Abstract class. 1876 1877 Parameters 1878 ---------- 1879 local_opt 1880 A LocalOptimizer to apply over a FunctionGraph (or None is Ok too). 1881 ignore_newtrees 1882 - True: new subgraphs returned by an optimization is not a 1883 candidate for optimization. 1884 - False: new subgraphs returned by an optimization is a candidate 1885 for optimization. 1886 - 'auto': let the local_opt set this parameter via its 'reentrant' 1887 attribute. 1888 failure_callback 1889 A function that takes (exception, navigator, [(old, new), 1890 (old,new),...]) and we call it if there's an exception. 1891 1892 If the trouble is from local_opt.transform(), the new variables 1893 will be 'None'. 1894 1895 If the trouble is from validation (the new types don't match for 1896 example) then the new variables will be the ones created by 1897 transform(). 
1898 1899 If this parameter is None, then exceptions are not caught here 1900 (raised normally). 1901 1902 """ 1903 @staticmethod 1904 def warn(exc, nav, repl_pairs, local_opt, node): 1905 """ 1906 Failure_callback for NavigatorOptimizer: print traceback. 1907 1908 """ 1909 if config.on_opt_error != 'ignore': 1910 _logger.error("Optimization failure due to: %s" % str(local_opt)) 1911 _logger.error("node: %s" % str(node)) 1912 _logger.error("TRACEBACK:") 1913 _logger.error(traceback.format_exc()) 1914 if config.on_opt_error == 'pdb': 1915 pdb.post_mortem(sys.exc_info()[2]) 1916 elif isinstance(exc, AssertionError) or config.on_opt_error == 'raise': 1917 # We always crash on AssertionError because something may be 1918 # seriously wrong if such an exception is raised. 1919 raise exc 1920 1921 @staticmethod 1922 def warn_inplace(exc, nav, repl_pairs, local_opt, node): 1923 """ 1924 Failure_callback for NavigatorOptimizer. 1925 1926 Ignore InconsistencyErrors, print traceback. 1927 1928 If error during replacement repl_pairs is set. Otherwise None. 1929 1930 """ 1931 if isinstance(exc, InconsistencyError): 1932 return 1933 return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node) 1934 1935 @staticmethod 1936 def warn_ignore(exc, nav, repl_pairs, local_opt, node): 1937 """ 1938 Failure_callback for NavigatorOptimizer: ignore all errors. 1939 1940 """ 1941 pass 1942 1943 def __init__(self, local_opt, ignore_newtrees='auto', 1944 failure_callback=None): 1945 self.local_opt = local_opt 1946 if ignore_newtrees == 'auto': 1947 self.ignore_newtrees = not getattr(local_opt, 'reentrant', True) 1948 else: 1949 self.ignore_newtrees = ignore_newtrees 1950 self.failure_callback = failure_callback 1951 1952 def attach_updater(self, fgraph, importer, pruner, chin=None, name=None): 1953 """ 1954 Install some FunctionGraph listeners to help the navigator deal with 1955 the ignore_trees-related functionality. 
1956 1957 Parameters 1958 ---------- 1959 importer 1960 Function that will be called whenever optimizations add stuff 1961 to the graph. 1962 pruner 1963 Function to be called when optimizations remove stuff 1964 from the graph. 1965 chin 1966 "on change input" called whenever a node's inputs change. 1967 name 1968 name of the Updater to attach. 1969 1970 Returns 1971 ------- 1972 object 1973 The FunctionGraph plugin that handles the three tasks. 1974 Keep this around so that you can detach later! 1975 1976 """ 1977 if self.ignore_newtrees: 1978 importer = None 1979 1980 if importer is None and pruner is None: 1981 return None 1982 1983 u = Updater(importer, pruner, chin, name=name) 1984 fgraph.attach_feature(u) 1985 return u 1986 1987 def detach_updater(self, fgraph, u): 1988 """ 1989 Undo the work of attach_updater. 1990 1991 Parameters 1992 ---------- 1993 u 1994 A return-value of attach_updater. 1995 1996 Returns 1997 ------- 1998 None 1999 2000 """ 2001 if u is not None: 2002 fgraph.remove_feature(u) 2003 2004 def process_node(self, fgraph, node, lopt=None): 2005 """ 2006 This function will use `lopt` to `transform` the `node`. The 2007 `transform` method will return either False or a list of Variables 2008 that are intended to replace `node.outputs`. 2009 2010 If the fgraph accepts the replacement, then the optimization is 2011 successful, and this function returns True. 2012 2013 If there are no replacement candidates or the fgraph rejects the 2014 replacements, this function returns False. 2015 2016 Parameters 2017 ---------- 2018 fgraph 2019 A FunctionGraph. 2020 node 2021 An Apply instance in `fgraph` 2022 lopt 2023 A LocalOptimizer instance that may have a better idea for 2024 how to compute node's outputs. 2025 2026 Returns 2027 ------- 2028 bool 2029 True iff the `node`'s outputs were replaced in the `fgraph`. 
2030 2031 """ 2032 lopt = lopt or self.local_opt 2033 try: 2034 replacements = lopt.transform(node) 2035 except Exception as e: 2036 if self.failure_callback is not None: 2037 self.failure_callback(e, self, 2038 [(x, None) for x in node.outputs], 2039 lopt, node) 2040 return False 2041 else: 2042 raise 2043 if replacements is False or replacements is None: 2044 return False 2045 old_vars = node.outputs 2046 remove = [] 2047 if isinstance(replacements, dict): 2048 if "remove" in replacements: 2049 remove = replacements.pop("remove") 2050 old_vars = list(replacements.keys()) 2051 replacements = list(replacements.values()) 2052 elif not isinstance(replacements, (tuple, list)): 2053 raise TypeError('Optimizer %s gave wrong type of replacement. ' 2054 'Expected list or tuple. Got %s' % ( 2055 lopt, replacements)) 2056 if len(old_vars) != len(replacements): 2057 raise ValueError('Optimizer %s gave wrong number of replacements' 2058 % lopt) 2059 # None in the replacement mean that this variable isn't used 2060 # and we want to remove it 2061 for r, rnew in zip(old_vars, replacements): 2062 if rnew is None and len(r.clients) > 0: 2063 raise ValueError("A local optimizer tried to remove a Variable that is used") 2064 # If an output would be replaced by itself, no need to perform 2065 # the replacement 2066 repl_pairs = [(r, rnew) for r, rnew in zip(old_vars, replacements) 2067 if rnew is not r and rnew is not None] 2068 2069 if len(repl_pairs) == 0: 2070 return False 2071 try: 2072 fgraph.replace_all_validate_remove(repl_pairs, 2073 reason=lopt, 2074 remove=remove) 2075 return True 2076 except Exception as e: 2077 # This means the replacements were rejected by the fgraph. 2078 # 2079 # This is not supposed to happen. The default failure_callback 2080 # will print a traceback as a warning. 
            if self.failure_callback is not None:
                self.failure_callback(e, self, repl_pairs, lopt, node)
                return False
            else:
                raise

    def add_requirements(self, fgraph):
        # Attach whatever features the wrapped local optimizer needs.
        super(NavigatorOptimizer, self).add_requirements(fgraph)
        # Added by default
        # fgraph.attach_feature(toolbox.ReplaceValidate())
        if self.local_opt:
            self.local_opt.add_requirements(fgraph)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        # One line for this optimizer, then recurse into the local optimizer
        # until the requested depth is exhausted (depth == -1 means no limit).
        print("%s%s (%i)" % (
            (' ' * level), self.__class__.__name__, id(self)), file=stream)
        if depth != 0:
            self.local_opt.print_summary(stream, level=(level + 2),
                                         depth=(depth - 1))


class TopoOptimizer(NavigatorOptimizer):
    """
    TopoOptimizer has one local optimizer. It tries to apply to each node, in topological order (or reverse).
    Each time the local optimizer applies, the node gets replaced, and the topooptimizer moves on to the next one.

    """

    def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
                 failure_callback=None):
        if order not in ['out_to_in', 'in_to_out']:
            raise ValueError("order must be 'out_to_in' or 'in_to_out'")
        self.order = order
        NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                    failure_callback)

    def apply(self, fgraph, start_from=None):
        """Apply the local optimizer to every node reachable from start_from.

        Returns a profile tuple consumed by print_profile.
        """
        if start_from is None:
            start_from = fgraph.outputs
        callback_before = fgraph.execute_callbacks_time
        nb_nodes_start = len(fgraph.apply_nodes)
        t0 = time.time()
        q = deque(graph.io_toposort(fgraph.inputs, start_from))
        io_t = time.time() - t0

        def importer(node):
            # Nodes imported into the fgraph during a replacement are queued
            # so they get visited too (unless it is the node being processed).
            if node is not current_node:
                q.append(node)

        u = self.attach_updater(fgraph, importer, None,
                                name=getattr(self, 'name', None))
        nb = 0
        try:
            t0 = time.time()
            while q:
                # 'out_to_in' pops from the far end of the toposort,
                # 'in_to_out' from the near end.
                if self.order == 'out_to_in':
                    node = q.pop()
                else:
                    node = q.popleft()
                # Skip nodes removed from the graph by earlier replacements.
                if node not in fgraph.apply_nodes:
                    continue
                current_node = node
                nb += self.process_node(fgraph, node)
            loop_t = time.time() - t0
        finally:
            self.detach_updater(fgraph, u)

        callback_time = fgraph.execute_callbacks_time - callback_before
        nb_nodes_end = len(fgraph.apply_nodes)
        return (self, nb, nb_nodes_start, nb_nodes_end,
                io_t, loop_t, callback_time, self.local_opt)

    @staticmethod
    def print_profile(stream, prof, level=0):
        # Pretty-print the profile tuple returned by apply().
        blanc = ('    ' * level)
        if prof is None:  # Happen as merge_profile() isn't implemented
            print(blanc, "TopoOptimizer merge_profile not implemented",
                  file=stream)
            return

        (opt, nb, nb_nodes_start, nb_nodes_end,
         io_t, loop_t, callback_time, lopt) = prof

        print(blanc, "TopoOptimizer ",
              getattr(opt, "name", getattr(opt, "__name__", "")), file=stream)

        print(blanc, "  nb_node (start, end, changed)", (
            nb_nodes_start, nb_nodes_end, nb), file=stream)
        print(blanc, "  init io_toposort", io_t, file=stream)
        print(blanc, "  loop time", loop_t, file=stream)
        print(blanc, "  callback_time", callback_time, file=stream)
        if isinstance(lopt, LocalOptGroup):
            if lopt.profile:
                lopt.print_profile(stream, (lopt.time_opts,
                                            lopt.process_count,
                                            lopt.applied_true,
                                            lopt.node_created,
                                            lopt.profile),
                                   level=level + 1)

    def __str__(self):
        return getattr(self, '__name__',
                       '<TopoOptimizer instance>')


def out2in(*local_opts, **kwargs):
    """
    Uses the TopoOptimizer from the output nodes to input nodes of the graph.
    """
    name = (kwargs and kwargs.pop('name', None))
    if len(local_opts) > 1:
        # Don't wrap it uselessly if there is only 1 optimization.
2193 local_opts = LocalOptGroup(*local_opts) 2194 else: 2195 local_opts, = local_opts 2196 if not name: 2197 name = local_opts.__name__ 2198 ret = TopoOptimizer(local_opts, 2199 order='out_to_in', 2200 failure_callback=TopoOptimizer.warn_inplace, 2201 **kwargs) 2202 if name: 2203 ret.__name__ = name 2204 return ret 2205 2206 2207def in2out(*local_opts, **kwargs): 2208 """ 2209 Uses the TopoOptimizer from the input nodes to output nodes of the graph. 2210 """ 2211 name = (kwargs and kwargs.pop('name', None)) 2212 if len(local_opts) > 1: 2213 # Don't wrap it uselessly if their is only 1 optimization. 2214 local_opts = LocalOptGroup(*local_opts) 2215 else: 2216 local_opts, = local_opts 2217 if not name: 2218 name = local_opts.__name__ 2219 ret = TopoOptimizer(local_opts, 2220 order='in_to_out', 2221 failure_callback=TopoOptimizer.warn_inplace, 2222 **kwargs) 2223 if name: 2224 ret.__name__ = name 2225 return ret 2226 2227 2228class OpKeyOptimizer(NavigatorOptimizer): 2229 """ 2230 WRITEME 2231 2232 """ 2233 2234 def __init__(self, local_opt, ignore_newtrees=False, 2235 failure_callback=None): 2236 if not hasattr(local_opt, 'op_key'): 2237 raise TypeError("LocalOptimizer for OpKeyOptimizer must have " 2238 "an 'op_key' method.") 2239 NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, 2240 failure_callback) 2241 2242 def apply(self, fgraph): 2243 op = self.local_opt.op_key() 2244 if isinstance(op, (list, tuple)): 2245 q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) 2246 else: 2247 q = list(fgraph.get_nodes(op)) 2248 2249 def importer(node): 2250 if node is not current_node: 2251 if node.op == op: 2252 q.append(node) 2253 2254 u = self.attach_updater(fgraph, importer, None, 2255 name=getattr(self, 'name', None)) 2256 try: 2257 while q: 2258 node = q.pop() 2259 if node not in fgraph.apply_nodes: 2260 continue 2261 current_node = node 2262 self.process_node(fgraph, node) 2263 finally: 2264 self.detach_updater(fgraph, u) 2265 2266 def add_requirements(self, 
fgraph): 2267 """ 2268 Requires the following features: 2269 - NodeFinder 2270 - ReplaceValidate(Added by default) 2271 2272 """ 2273 super(OpKeyOptimizer, self).add_requirements(fgraph) 2274 fgraph.attach_feature(toolbox.NodeFinder()) 2275 2276 2277class ChangeTracker: 2278 def __init__(self): 2279 self.changed = False 2280 self.nb_imported = 0 2281 2282 def on_import(self, fgraph, node, reason): 2283 self.nb_imported += 1 2284 self.changed = True 2285 2286 def on_change_input(self, fgraph, node, i, r, new_r, reason): 2287 self.changed = True 2288 2289 def reset(self): 2290 self.changed = False 2291 2292 def on_attach(self, fgraph): 2293 fgraph.change_tracker = self 2294 2295 def on_detach(self, fgraph): 2296 del fgraph.change_tracker 2297 2298 2299def merge_dict(d1, d2): 2300 """ 2301 merge 2 dicts by adding the values. 2302 """ 2303 d = d1.copy() 2304 for k, v in iteritems(d2): 2305 if k in d: 2306 d[k] += v 2307 else: 2308 d[k] = v 2309 return d 2310 2311 2312class EquilibriumOptimizer(NavigatorOptimizer): 2313 """ 2314 Apply optimizations until equilibrium point. 2315 2316 Parameters 2317 ---------- 2318 optimizers : list or set 2319 Local or global optimizations to apply until equilibrium. 2320 The global optimizer will be run at the start of each iteration before 2321 the local optimizer. 2322 max_use_ratio : int or float 2323 Each optimizer can be applied at most (size of graph * this number) 2324 times. 2325 ignore_newtrees 2326 See EquilibriumDB ignore_newtrees parameter definition. 2327 final_optimizers 2328 Global optimizers that will be run after each iteration. 2329 cleanup_optimizers 2330 Global optimizers that apply a list of pre determined optimization. 2331 They must not traverse the graph as they are called very frequently. 2332 The MergeOptimizer is one example of optimization that respect this. 2333 They are applied after all global optimizer, then when one local optimizer is applied, then after all final optimizer. 
    """

    def __init__(self,
                 optimizers,
                 failure_callback=None,
                 ignore_newtrees=True,
                 tracks_on_change_inputs=False,
                 max_use_ratio=None,
                 final_optimizers=None,
                 cleanup_optimizers=None):
        super(EquilibriumOptimizer, self).__init__(
            None,
            ignore_newtrees=ignore_newtrees,
            failure_callback=failure_callback)
        self.local_optimizers_map = OrderedDict()
        self.local_optimizers_all = []
        self.global_optimizers = []
        self.final_optimizers = []
        self.cleanup_optimizers = []
        self.tracks_on_change_inputs = tracks_on_change_inputs
        # Partition the optimizers: LocalOptimizers are indexed by what they
        # track (op class or op instance) when tracks() is informative,
        # otherwise they apply to all nodes; anything else is global.
        for opt in optimizers:
            if isinstance(opt, LocalOptimizer):
                if opt.tracks() is None:
                    self.local_optimizers_all.append(opt)
                else:
                    for c in opt.tracks():
                        self.local_optimizers_map.setdefault(c, []).append(opt)
            else:
                self.global_optimizers.append(opt)
        if final_optimizers:
            self.final_optimizers = final_optimizers
        if cleanup_optimizers:
            self.cleanup_optimizers = cleanup_optimizers
        self.max_use_ratio = max_use_ratio
        assert self.max_use_ratio is not None, (
            'max_use_ratio has to be a number')

    def get_local_optimizers(self):
        # Yield every local optimizer exactly once (an optimizer may be
        # registered under several tracked ops).
        for opt in self.local_optimizers_all:
            yield opt
        # if repeat is not a problem we can drop the set
        s = set()
        for lopt in itervalues(self.local_optimizers_map):
            for opt in lopt:
                if opt not in s:
                    yield opt
                    s.add(opt)

    def add_requirements(self, fgraph):
        # Forward the requirement registration to every managed optimizer.
        super(EquilibriumOptimizer, self).add_requirements(fgraph)
        for opt in self.get_local_optimizers():
            opt.add_requirements(fgraph)
        for opt in self.global_optimizers:
            opt.add_requirements(fgraph)
        for opt in self.final_optimizers:
            opt.add_requirements(fgraph)
        for opt in self.cleanup_optimizers:
            opt.add_requirements(fgraph)

    def apply(self, fgraph, start_from=None):
        # Track graph changes so we know when equilibrium is reached.
        change_tracker = ChangeTracker()
        fgraph.attach_feature(change_tracker)
        if start_from is None:
            start_from = fgraph.outputs
        else:
            for node in start_from:
                assert node in fgraph.outputs

        changed = True
        max_use_abort = False
        opt_name = None
        global_process_count = {}
        start_nb_nodes = len(fgraph.apply_nodes)
        max_nb_nodes = len(fgraph.apply_nodes)
        # Per-optimizer application cap, recomputed as the graph grows.
        max_use = max_nb_nodes * self.max_use_ratio

        # Profiling accumulators (one entry per outer-loop iteration, or one
        # entry per optimizer for the dicts).
        loop_timing = []
        loop_process_count = []
        global_opt_timing = []
        time_opts = {}
        io_toposort_timing = []
        nb_nodes = []
        node_created = {}
        global_sub_profs = []
        final_sub_profs = []
        cleanup_sub_profs = []
        for opt in (self.global_optimizers +
                    list(self.get_local_optimizers()) +
                    self.final_optimizers +
                    self.cleanup_optimizers):
            global_process_count.setdefault(opt, 0)
            time_opts.setdefault(opt, 0)
            node_created.setdefault(opt, 0)

        def apply_cleanup(profs_dict):
            # Run every cleanup optimizer once; returns True if any of them
            # changed the graph.  Note: reads/writes the enclosing scope's
            # `process_count` of the current outer iteration.
            changed = False
            for copt in self.cleanup_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = copt.apply(fgraph)
                time_opts[copt] += time.time() - t_opt
                profs_dict[copt].append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(copt, 0)
                    process_count[copt] += 1
                    global_process_count[copt] += 1
                    changed = True
                    node_created[copt] += change_tracker.nb_imported - nb
            return changed

        # Outer loop: keep going until a full pass makes no change, or an
        # optimizer exceeds its use cap.
        while changed and not max_use_abort:
            process_count = {}
            t0 = time.time()
            changed = False
            iter_cleanup_sub_profs = {}
            for copt in self.cleanup_optimizers:
                iter_cleanup_sub_profs[copt] = []

            # apply global optimizers
            sub_profs = []
            for gopt in self.global_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = gopt.apply(fgraph)
                time_opts[gopt] += time.time() - t_opt
                sub_profs.append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(gopt, 0)
                    process_count[gopt] += 1
                    global_process_count[gopt] += 1
                    changed = True
                    node_created[gopt] += change_tracker.nb_imported - nb
                    if global_process_count[gopt] > max_use:
                        max_use_abort = True
                        opt_name = (getattr(gopt, "name", None) or
                                    getattr(gopt, "__name__", ""))
            global_sub_profs.append(sub_profs)

            global_opt_timing.append(float(time.time() - t0))

            # apply clean up as global opt can have done changes that
            # request that
            changed |= apply_cleanup(iter_cleanup_sub_profs)

            # apply local optimizer
            topo_t0 = time.time()
            q = deque(graph.io_toposort(fgraph.inputs, start_from))
            io_toposort_timing.append(time.time() - topo_t0)

            nb_nodes.append(len(q))
            max_nb_nodes = max(max_nb_nodes, len(q))
            max_use = max_nb_nodes * self.max_use_ratio

            def importer(node):
                # Revisit nodes created during replacements.
                if node is not current_node:
                    q.append(node)

            chin = None
            if self.tracks_on_change_inputs:
                def chin(node, i, r, new_r, reason):
                    # Also revisit clients whose inputs were rewired
                    # (node may be the string 'output' for fgraph outputs).
                    if node is not current_node and not isinstance(node, str):
                        q.append(node)
            u = self.attach_updater(fgraph, importer, None,
                                    chin=chin,
                                    name=getattr(self, 'name', None))
            try:
                while q:
                    node = q.pop()
                    if node not in fgraph.apply_nodes:
                        continue
                    current_node = node
                    # Try the untracked optimizers plus those registered for
                    # this node's op class or op instance.
                    for lopt in (self.local_optimizers_all +
                                 self.local_optimizers_map.get(type(node.op), []) +
                                 self.local_optimizers_map.get(node.op, [])):
                        nb = change_tracker.nb_imported
                        t_opt = time.time()
                        lopt_change = self.process_node(fgraph, node, lopt)
                        time_opts[lopt] += time.time() - t_opt
                        if not lopt_change:
                            continue
                        process_count.setdefault(lopt, 0)
                        process_count[lopt] += 1
                        global_process_count[lopt] += 1
                        changed = True
                        node_created[lopt] += change_tracker.nb_imported - nb
                        changed |= apply_cleanup(iter_cleanup_sub_profs)
                        if global_process_count[lopt] > max_use:
                            max_use_abort = True
                            opt_name = (getattr(lopt, "name", None) or
                                        getattr(lopt, "__name__", ""))
                        if node not in fgraph.apply_nodes:
                            # go to next node
                            break
            finally:
                self.detach_updater(fgraph, u)

            # Apply final optimizers
            sub_profs = []
            t_before_final_opt = time.time()
            for gopt in self.final_optimizers:
                change_tracker.reset()
                nb = change_tracker.nb_imported
                t_opt = time.time()
                sub_prof = gopt.apply(fgraph)
                time_opts[gopt] += time.time() - t_opt
                sub_profs.append(sub_prof)
                if change_tracker.changed:
                    process_count.setdefault(gopt, 0)
                    process_count[gopt] += 1
                    global_process_count[gopt] += 1
                    changed = True
                    node_created[gopt] += change_tracker.nb_imported - nb
                    if global_process_count[gopt] > max_use:
                        max_use_abort = True
                        opt_name = (getattr(gopt, "name", None) or
                                    getattr(gopt, "__name__", ""))
            final_sub_profs.append(sub_profs)

            global_opt_timing[-1] += time.time() - t_before_final_opt
            # apply clean up as final opt can have done changes that
            # request that
            changed |= apply_cleanup(iter_cleanup_sub_profs)
            # merge clean up profiles during that iteration.
            c_sub_profs = []
            for copt, sub_profs in iteritems(iter_cleanup_sub_profs):
                sub_prof = sub_profs[0]
                for s_p in sub_profs[1:]:
                    sub_prof = copt.merge_profile(sub_prof, s_p)
                c_sub_profs.append(sub_prof)
            cleanup_sub_profs.append(c_sub_profs)

            loop_process_count.append(process_count)
            loop_timing.append(float(time.time() - t0))

        end_nb_nodes = len(fgraph.apply_nodes)

        if max_use_abort:
            msg = ("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
                   ". You can safely raise the current threshold of " +
                   "%f with the theano flag 'optdb.max_use_ratio'." %
                   config.optdb.max_use_ratio)
            if theano.config.on_opt_error == 'raise':
                raise AssertionError(msg)
            else:
                _logger.error(msg)
        fgraph.remove_feature(change_tracker)
        # All per-iteration profile lists must stay in lockstep.
        assert len(loop_process_count) == len(loop_timing)
        assert len(loop_process_count) == len(global_opt_timing)
        assert len(loop_process_count) == len(nb_nodes)
        assert len(loop_process_count) == len(io_toposort_timing)
        assert len(loop_process_count) == len(global_sub_profs)
        assert len(loop_process_count) == len(final_sub_profs)
        assert len(loop_process_count) == len(cleanup_sub_profs)
        return (self, loop_timing, loop_process_count,
                (start_nb_nodes, end_nb_nodes, max_nb_nodes),
                global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
                node_created, global_sub_profs, final_sub_profs,
                cleanup_sub_profs)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        name = getattr(self, 'name', None)
        print("%s%s %s id=%i" % (
            (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
        if depth != 0:
            for lopt in self.get_local_optimizers():
                lopt.print_summary(stream, level=(level + 2),
                                   depth=(depth - 1))

    @staticmethod
    def print_profile(stream, prof, level=0):
        # Unpack the tuple produced by apply().
        (opt, loop_timing, loop_process_count,
         (start_nb_nodes, end_nb_nodes, max_nb_nodes),
         global_opt_timing, nb_nodes, time_opts, io_toposort_timing,
         node_created, global_sub_profs, final_sub_profs,
         cleanup_sub_profs) = prof

        blanc = ('    ' * level)
        print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
        print(blanc, getattr(opt, "name",
                             getattr(opt, "__name__", "")), file=stream)
        print(blanc, "  time %.3fs for %d passes" % (
            sum(loop_timing), len(loop_timing)), file=stream)
        print(blanc, "  nb nodes (start, end, max) %d %d %d" % (
            start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
        print(blanc, "  time io_toposort %.3fs" % sum(
            io_toposort_timing), file=stream)
        s = sum([time_opts[o] for o in opt.get_local_optimizers()])
        print(blanc, "  time in local optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.global_optimizers])
        print(blanc, "  time in global optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.final_optimizers])
        print(blanc, "  time in final optimizers %.3fs" % s, file=stream)
        s = sum([time_opts[o] for o in opt.cleanup_optimizers])
        print(blanc, "  time in cleanup optimizers %.3fs" % s, file=stream)
        # One line per outer-loop iteration, with the 5 most-applied opts.
        for i in range(len(loop_timing)):
            lopt = ""
            if loop_process_count[i]:
                d = list(reversed(sorted(iteritems(loop_process_count[i]),
                                         key=lambda a: a[1])))
                lopt = " ".join([str((str(k), v)) for k, v
                                 in d[:5]])
                if len(d) > 5:
                    lopt += " ..."
            print(blanc, ('  %2d - %.3fs %d (%.3fs in global opts, '
                          '%.3fs io_toposort) - %d nodes - %s' % (
                              i, loop_timing[i],
                              sum(loop_process_count[i].values()),
                              global_opt_timing[i],
                              io_toposort_timing[i], nb_nodes[i],
                              lopt)), file=stream)

        count_opt = []
        not_used = []
        not_used_time = 0
        process_count = {}
        for o in (opt.global_optimizers +
                  list(opt.get_local_optimizers()) +
                  list(opt.final_optimizers) +
                  list(opt.cleanup_optimizers)):
            process_count.setdefault(o, 0)
        for count in loop_process_count:
            for o, v in iteritems(count):
                process_count[o] += v
        for o, count in iteritems(process_count):
            if count > 0:
                count_opt.append((time_opts[o], count,
                                  node_created[o], o))
            else:
                not_used.append((time_opts[o], o))
                not_used_time += time_opts[o]

        if count_opt:
            print(blanc,
                  '  times - times applied - nb node created - name:',
                  file=stream)
            count_opt.sort()
            for (t, count, n_created, o) in count_opt[::-1]:
                print(blanc, '  %.3fs - %d - %d - %s' % (
                    t, count, n_created, o), file=stream)
        print(blanc, '  %.3fs - in %d optimization that were not used (display only those with a runtime > 0)' % (
            not_used_time, len(not_used)), file=stream)
        not_used.sort(key=lambda nu: (nu[0], str(nu[1])))
        for (t, o) in not_used[::-1]:
            if t > 0:
                # Skip opt that have 0 times, they probably wasn't even tried.
                print(blanc + "  ", '  %.3fs - %s' % (t, o), file=stream)
        print(file=stream)
        # Only recurse into optimizers that override Optimizer.print_profile.
        gf_opts = [o for o in (opt.global_optimizers +
                               list(opt.final_optimizers) +
                               list(opt.cleanup_optimizers))
                   if o.print_profile.__code__ is not
                   Optimizer.print_profile.__code__]
        if not gf_opts:
            return
        print(blanc, "Global, final and clean up optimizers", file=stream)
        for i in range(len(loop_timing)):
            print(blanc, "Iter %d" % i, file=stream)
            for o, prof in zip(opt.global_optimizers, global_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    print(blanc, "merge not implemented for ", o)
            for o, prof in zip(opt.final_optimizers, final_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    print(blanc, "merge not implemented for ", o)
            for o, prof in zip(opt.cleanup_optimizers, cleanup_sub_profs[i]):
                try:
                    o.print_profile(stream, prof, level + 2)
                except NotImplementedError:
                    print(blanc, "merge not implemented for ", o)

    @staticmethod
    def merge_profile(prof1, prof2):
        # Merge two apply() profiles elementwise; prof layout reminder:
        # (opt, loop_timing, loop_process_count, max_nb_nodes,
        #  global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
        local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
            prof2[0].get_local_optimizers())
        global_optimizers = OrderedSet(prof1[0].global_optimizers).union(
            prof2[0].global_optimizers)
        final_optimizers = list(OrderedSet(prof1[0].final_optimizers).union(
            prof2[0].final_optimizers))
        cleanup_optimizers = list(OrderedSet(prof1[0].cleanup_optimizers).union(
            prof2[0].cleanup_optimizers))
        new_opt = EquilibriumOptimizer(
            local_optimizers.union(global_optimizers),
            max_use_ratio=1,
            final_optimizers=final_optimizers,
            cleanup_optimizers=cleanup_optimizers)

        def add_append_list(l1, l2):
            # Elementwise sum of two lists of numbers, keeping the tail of
            # the longer one.
            l = copy.copy(l1)
            for idx, nb in enumerate(l2):
                if idx < len(l):
                    l[idx] += nb
                else:
                    l.append(nb)
            return l

        loop_timing = add_append_list(prof1[1], prof2[1])

        loop_process_count = list(prof1[2])
        global_sub_profs = []
        final_sub_profs = []
        cleanup_sub_profs = []

        # Merge the iterations both profiles have in common.
        for i in range(min(len(loop_process_count), len(prof2[2]))):
            process_count = loop_process_count[i]
            for process, count in iteritems(prof2[2][i]):
                if process in process_count:
                    process_count[process] += count
                else:
                    process_count[process] = count

            def merge(opts, attr, idx):
                # Merge the per-optimizer sub-profiles at iteration i for
                # the optimizer list stored under `attr` (prof index `idx`).
                tmp = []
                for opt in opts:
                    o1 = getattr(prof1[0], attr)
                    o2 = getattr(prof2[0], attr)
                    if opt in o1 and opt in o2:
                        p1 = prof1[idx][i][o1.index(opt)]
                        p2 = prof2[idx][i][o2.index(opt)]
                        m = None
                        if hasattr(opt, 'merge_profile'):
                            m = opt.merge_profile(p1, p2)
                    elif opt in o1:
                        m = prof1[idx][i][o1.index(opt)]
                    else:
                        m = prof2[idx][i][o2.index(opt)]
                    tmp.append(m)
                return tmp
            global_sub_profs.append(merge(global_optimizers, 'global_optimizers', 9))
            final_sub_profs.append(merge(final_optimizers, 'final_optimizers', 10))
            cleanup_sub_profs.append(merge(cleanup_optimizers, 'cleanup_optimizers', 11))

        # Add the iteration done by only one of the profile.
        loop_process_count.extend(prof1[2][len(loop_process_count):])
        global_sub_profs.extend(prof1[9][len(global_sub_profs):])
        final_sub_profs.extend(prof1[10][len(final_sub_profs):])
        cleanup_sub_profs.extend(prof1[11][len(cleanup_sub_profs):])

        # NOTE(review): the prof2 tails below slice with
        # len(loop_process_count), which was already taken fully from prof1,
        # and prof2's extra iterations are never appended to
        # loop_process_count itself -- looks asymmetric with the prof1
        # handling above; confirm this is intended.
        global_sub_profs.extend(prof2[9][len(loop_process_count):])
        final_sub_profs.extend(prof2[10][len(loop_process_count):])
        cleanup_sub_profs.extend(prof2[11][len(loop_process_count):])

        max_nb_nodes = max(prof1[3], prof2[3])

        global_opt_timing = add_append_list(prof1[4], prof2[4])

        nb_nodes = add_append_list(prof1[5], prof2[5])

        time_opts = merge_dict(prof1[6], prof2[6])
        io_toposort_timing = add_append_list(prof1[7], prof2[7])
        assert (len(loop_timing) == len(global_opt_timing) ==
                len(global_sub_profs) ==
                len(io_toposort_timing) == len(nb_nodes))
        assert len(loop_timing) == max(len(prof1[1]), len(prof2[1]))

        node_created = merge_dict(prof1[8], prof2[8])
        return (new_opt,
                loop_timing,
                loop_process_count,
                max_nb_nodes,
                global_opt_timing,
                nb_nodes,
                time_opts,
                io_toposort_timing,
                node_created,
                global_sub_profs,
                final_sub_profs,
                cleanup_sub_profs)

#################
#   Utilities   #
#################


def _check_chain(r, chain):
    """
    Walk `chain` (alternating Op/Op-class/None markers and input indices,
    as built by check_chain) down from variable `r`, returning False as
    soon as an owner/op does not match.

    """
    chain = list(reversed(chain))
    while chain:
        elem = chain.pop()
        if elem is None:
            # None means "r must have no owner" (i.e. be a graph input).
            if r.owner is not None:
                return False
        elif r.owner is None:
            return False
        elif isinstance(elem, op.Op):
            if not r.owner.op == elem:
                return False
        else:
            # elem is expected to be an Op subclass; anything else
            # (TypeError from issubclass) fails the match.
            try:
                if (issubclass(elem, op.Op) and
                        not isinstance(r.owner.op, elem)):
                    return False
            except TypeError:
                return False
        if chain:
            # Descend into the input whose index is the next chain element.
            r = r.owner.inputs[chain.pop()]
    # print 'check_chain', _check_chain.n_calls
    # _check_chain.n_calls += 1

    # The return value will be used as a Boolean, but some Variables cannot
    # be used as Booleans (the results of comparisons, for instance)
    return (r is not None)
# _check_chain.n_calls = 0


def check_chain(r, *chain):
    """
    Return True iff the ownership chain above `r` matches the given
    sequence of ops (each op is matched against input 0 of the previous
    node; see _check_chain).

    """
    if isinstance(r, graph.Apply):
        r = r.outputs[0]
    return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))


def pre_greedy_local_optimizer(list_optimizations, out):
    """
    This function traverses the computation graph described by all
    ``node`` in the graph before the variable out but that are not in the
    fgraph. It applies each of the local_optimizations on the traversed graph.

    Its main use is to apply locally constant folding when generating
    the graph of the indices of a subtensor.

    We should not apply optimizations on node that are in fgraph.
    So we don't optimize node that have an attribute fgraph.

    Notes
    -----
    This doesn't do an equilibrium... So if there is optimization
    like local_upcast_elemwise_constant_inputs in the list, that
    adds additional node to the inputs of the node, it can
    be needed to call this function multiple times.
    """
    def local_recursive_function(list_opt, out, optimized_vars, depth):
        # Depth-first walk: optimize the inputs of out.owner first, then
        # try each local optimizer on the node itself.
        # `optimized_vars` memoizes variable -> optimized replacement.
        if not getattr(out, 'owner', None):
            return [out], optimized_vars
        node = out.owner

        # Nodes already in an fgraph must not be optimized here.
        if hasattr(node, 'fgraph'):
            return node.outputs, optimized_vars
        for idx, inp in enumerate(node.inputs):
            if inp in optimized_vars:
                nw_in = optimized_vars[inp]
            else:
                if inp.owner:
                    outs, optimized_vars = local_recursive_function(
                        list_opt,
                        inp,
                        optimized_vars,
                        depth + 1)
                    for k, v in zip(inp.owner.outputs, outs):
                        optimized_vars[k] = v
                    nw_in = outs[inp.owner.outputs.index(inp)]

                else:
                    nw_in = inp
                    optimized_vars[inp] = inp
            # In-place rewiring of the (not-yet-in-an-fgraph) node.
            node.inputs[idx] = nw_in

        results = node.outputs
        for opt in list_opt:
            ret = opt.transform(node)
            if ret is not False and ret is not None:
                assert len(ret) == len(node.outputs), opt
                for k, v in zip(node.outputs, ret):
                    optimized_vars[k] = v
                results = ret
                if ret[0].owner:
                    # NOTE(review): this reassigns node to out.owner, which
                    # is unchanged -- ret[0].owner looks like the intended
                    # target; confirm before touching.
                    node = out.owner
                else:
                    break
        return results, optimized_vars
    if out.owner:
        out_index = out.owner.outputs.index(out)
    else:
        out_index = 0
    final_outs, optimized_nodes = local_recursive_function(
        list_optimizations, out, {}, 0)
    return final_outs[out_index]


def copy_stack_trace(from_var, to_var):
    """
    Copies the stack trace from one or more tensor variables to
    one or more tensor variables and returns the destination variables.

    Parameters
    ----------
    from_var
        Tensor variable or list of tensor variables to copy stack traces from.
    to_var
        Tensor variable or list of tensor variables to copy stack traces to.

    Notes
    -----
    The stacktrace is assumed to be of the form of a list of lists
    of tuples. Each tuple contains the filename, line number, function name
    and so on. Each list of tuples contains the tuples belonging to a
    particular variable.

    """

    # Store stack traces from from_var
    tr = []
    if type(from_var) is list:
        # If from_var is a list, store concatenated stack traces
        for v in from_var:
            tr += getattr(v.tag, 'trace', [])

    else:
        # If from_var is not a list, it must be a single tensor variable,
        # so just store that particular stack trace
        tr = getattr(from_var.tag, 'trace', [])

    if tr and isinstance(tr[0], tuple):
        # There was one single stack trace, we encapsulate it in a list
        tr = [tr]

    # Copy over stack traces to to_var
    if type(to_var) is list:
        # Copy over stack traces from from_var to each variable in
        # to_var, including the stack_trace of the to_var before
        for v in to_var:
            v.tag.trace = getattr(v.tag, 'trace', []) + tr
    else:
        # Copy over stack traces from from_var to each variable to
        # to_var, including the stack_trace of the to_var before
        to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
    return to_var


@contextlib.contextmanager
def inherit_stack_trace(from_var):
    """
    Contextmanager that copies the stack trace from one or more variable nodes to all
    variable nodes constructed in the body. new_nodes is the list of all the newly created
    variable nodes inside an optimization that is managed by graph.nodes_constructed().

    Parameters
    ----------
    from_var
        Variable node or a list of variable nodes to copy stack traces from.

    """
    with graph.nodes_constructed() as new_nodes:
        yield
    copy_stack_trace(from_var, new_nodes)


def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
    """
    This function checks if the outputs of specific ops of a compiled graph
    have a stack.

    Parameters
    ----------
    f_or_fgraph: theano.compile.function_module.Function or
                 theano.gof.fg.FunctionGraph
        The compiled function or the function graph to be analysed.
    ops_to_check: it can be of four different types:
        - classes or instances inheriting from theano.gof.Op
        - tuple/list of classes or instances inheriting from theano.gof.Op
        - string
        - function returning a boolean and taking as input an instance of
          theano.gof.Op.
        - if ops_to_check is a string, it should be either 'last' or 'all'.
          'last' will check only the last op of the graph while 'all' will
          check all the ops of the graph.
        - if ops_to_check is an op or a tuple/list of ops, the function will
          check that all the outputs of their occurrences in the graph have a
          stack trace.
        - if ops_to_check is a function, it should take as input a
          theano.gof.Op and return a boolean indicating if the input op should
          be checked or not.
    bug_print: string belonging to {'raise', 'warn', 'ignore'}
        You can specify the behaviour of the function when the specified
        ops_to_check are not in the graph of f_or_fgraph: it can either raise
        an exception, write a warning or simply ignore it.

    Returns
    -------
    boolean
        True if the outputs of the specified ops have a stack, False otherwise.

    """
    # Resolve the fgraph whatever form the caller gave us.
    if isinstance(f_or_fgraph, theano.compile.function_module.Function):
        fgraph = f_or_fgraph.maker.fgraph
    elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
        fgraph = f_or_fgraph
    else:
        raise ValueError('The type of f_or_fgraph is not supported')

    # Normalize a single Op instance/class to a one-element tuple.
    if (isinstance(ops_to_check, theano.gof.Op) or
            (inspect.isclass(ops_to_check) and
             issubclass(ops_to_check, theano.gof.Op))):
        ops_to_check = (ops_to_check,)

    # if ops_to_check is a string
    if isinstance(ops_to_check, string_types):
        if ops_to_check == 'last':
            apply_nodes_to_check = [fgraph.outputs[i].owner for i in range(
                len(fgraph.outputs))]
        elif ops_to_check == 'all':
            apply_nodes_to_check = fgraph.apply_nodes
        else:
            raise ValueError('The string ops_to_check is not recognised')

    # if ops_to_check is a list/tuple of ops
    elif isinstance(ops_to_check, (tuple, list)):
        # Separate classes from instances in ops_to_check
        # NOTE(review): op_instances is collected but never read below
        # (the first list-comp matches instances via `node.op in
        # ops_to_check`); looks like dead code -- confirm.
        op_instances = []
        op_classes = []
        for obj in ops_to_check:
            if isinstance(obj, theano.gof.Op):
                op_instances.append(obj)
            else:
                op_classes.append(obj)
        op_classes = tuple(op_classes)

        apply_nodes_to_check = (
            [node for node in fgraph.apply_nodes if node.op in ops_to_check] +
            # Also match Elemwise-style wrappers through their scalar_op.
            [node for node in fgraph.apply_nodes
             if isinstance(node.op, op_classes) or
             (hasattr(node.op, 'scalar_op') and
              isinstance(node.op.scalar_op, op_classes))])

    # if ops_to_check is a function
    elif hasattr(ops_to_check, '__call__'):
        apply_nodes_to_check = [node for node in fgraph.apply_nodes
                                if ops_to_check(node)]

    else:
        raise ValueError('ops_to_check does not have the right type')

    if not apply_nodes_to_check:
        msg = 'Provided op instances/classes are not in the graph or the ' \
              'graph is empty'
        if bug_print == 'warn':
            warnings.warn(msg)
        elif bug_print == 'raise':
            raise Exception(msg)
        elif bug_print == 'ignore':
            pass
        else:
            raise ValueError('The string bug_print is not recognised')

    # Fail as soon as any checked output lacks a non-empty trace.
    for node in apply_nodes_to_check:
        for output in node.outputs:
            if (not hasattr(output.tag, 'trace') or not output.tag.trace):
                return False

    return True


class CheckStrackTraceFeature(object):
    # NOTE(review): the class name keeps its historical misspelling
    # ("Strack"); renaming would break external references.
    def on_import(self, fgraph, node, reason):
        # In optdb we only register the CheckStackTraceOptimization when
        # theano.config.check_stack_trace is not off but we also double check here.
        if theano.config.check_stack_trace != 'off' and not check_stack_trace(fgraph, 'all'):
            if theano.config.check_stack_trace == 'raise':
                raise AssertionError(
                    'Empty stack trace! The optimization that inserted this variable is ' + str(reason))
            elif theano.config.check_stack_trace in ['log', 'warn']:
                # Stamp a placeholder trace on every traceless output so the
                # check does not fire again for the same variables.
                # NOTE(review): the concatenated messages below lack a
                # separating space ("that'+'inserted" and "is'+reason) --
                # candidate wording fix.
                apply_nodes_to_check = fgraph.apply_nodes
                for node in apply_nodes_to_check:
                    for output in node.outputs:
                        if not hasattr(output.tag, 'trace') or not output.tag.trace:
                            output.tag.trace = [[('', 0, 'Empty stack trace! The optimization that' +
                                                  'inserted this variable is ' + str(reason), '')]]
                if theano.config.check_stack_trace == 'warn':
                    warnings.warn(
                        'Empty stack trace! The optimization that inserted this variable is' + str(reason))


class CheckStackTraceOptimization(Optimizer):
    """Optimizer that serves to attach a CheckStrackTraceFeature to the fgraph."""

    def add_requirements(self, fgraph):
        # NOTE(review): attach_feature does not set an attribute named
        # 'CheckStrackTraceFeature' on the fgraph, so this hasattr guard
        # looks ineffective against double-attachment -- confirm.
        if not hasattr(fgraph, 'CheckStrackTraceFeature'):
            fgraph.attach_feature(CheckStrackTraceFeature())

    def apply(self, fgraph):
        pass