1""" 2Node classes (`Apply`, `Variable`) and expression graph algorithms. 3""" 4from __future__ import absolute_import, print_function, division 5 6from collections import deque 7import contextlib 8from copy import copy 9from itertools import count 10 11import warnings 12 13import theano 14from theano import config 15from theano.gof import utils 16from six import string_types, integer_types, iteritems 17from theano.misc.ordered_set import OrderedSet 18 19__docformat__ = "restructuredtext en" 20 21# Lazy imports to avoid circular dependencies. 22is_same_graph_with_merge = None 23equal_computations = None 24 25NoParams = object() 26 27 28class Node(utils.object2): 29 """ 30 A Node in a theano graph. 31 32 Graphs contain two kinds of Nodes -- Variable and Apply. 33 Edges in the graph are not explicitly represented. 34 Instead each Node keeps track of its parents via 35 Variable.owner / Apply.inputs and its children 36 via Variable.clients / Apply.outputs. 37 38 """ 39 40 def get_parents(self): 41 """ 42 Return a list of the parents of this node. 43 Should return a copy--i.e., modifying the return 44 value should not modify the graph structure. 45 46 """ 47 raise NotImplementedError() 48 49 50class Apply(Node): 51 """ 52 An :term:`Apply` instance is a node in an expression graph which represents 53 the application of an `Op` to some input `Variable` nodes, producing some 54 output `Variable` nodes. 55 56 This class is typically instantiated by an Op's make_node() function, which 57 is typically called by that Op's __call__() function. 58 59 An Apply instance serves as a simple structure with three important 60 attributes: 61 62 - :literal:`inputs` : a list of `Variable` nodes that represent the 63 arguments of the expression, 64 65 - :literal:`outputs` : a list of `Variable` nodes that represent the 66 variable of the expression, and 67 68 - :literal:`op` : an `Op` instance that determines the nature of the 69 expression being applied. 70 71 The driver `compile.function` uses Apply's inputs attribute together with 72 Variable's owner attribute to search the expression graph and determine 73 which inputs are necessary to compute the function's outputs. 74 75 A `Linker` uses the Apply instance's `op` field to compute the variables. 76 77 Comparing with the Python language, an `Apply` instance is theano's version 78 of a function call (or expression instance) whereas `Op` is theano's version 79 of a function definition. 80 81 Parameters 82 ---------- 83 op : `Op` instance 84 inputs : list of Variable instances 85 outputs : list of Variable instances 86 87 Notes 88 ----- 89 The owner field of each output in the outputs list will be set to self. 90 91 If an output element has an owner that is neither None nor self, then a 92 ValueError exception will be raised. 
93 94 """ 95 96 def __init__(self, op, inputs, outputs): 97 self.op = op 98 self.inputs = [] 99 self.tag = utils.scratchpad() 100 101 if not isinstance(inputs, (list, tuple)): 102 raise TypeError("The inputs of an Apply must be a list or tuple") 103 104 if not isinstance(outputs, (list, tuple)): 105 raise TypeError("The output of an Apply must be a list or tuple") 106 107 # filter inputs to make sure each element is a Variable 108 for input in inputs: 109 if isinstance(input, Variable): 110 self.inputs.append(input) 111 else: 112 raise TypeError("The 'inputs' argument to Apply must contain Variable instances, not %s" % input) 113 self.outputs = [] 114 # filter outputs to make sure each element is a Variable 115 for i, output in enumerate(outputs): 116 if isinstance(output, Variable): 117 if output.owner is None: 118 output.owner = self 119 output.index = i 120 elif output.owner is not self or output.index != i: 121 raise ValueError("All output variables passed to Apply must belong to it.") 122 self.outputs.append(output) 123 else: 124 raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output) 125 126 def run_params(self): 127 """ 128 Returns the params for the node, or NoParams if no params is set. 129 130 """ 131 try: 132 return self.op.get_params(self) 133 except theano.gof.utils.MethodNotDefined: 134 return NoParams 135 136 def __getstate__(self): 137 d = self.__dict__ 138 # ufunc don't pickle/unpickle well 139 if hasattr(self.tag, 'ufunc'): 140 d = copy(self.__dict__) 141 t = d["tag"] 142 del t.ufunc 143 d["tag"] = t 144 return d 145 146 def default_output(self): 147 """ 148 Returns the default output for this node. 149 150 Returns 151 ------- 152 Variable instance 153 An element of self.outputs, typically self.outputs[0]. 154 155 Notes 156 ----- 157 May raise AttributeError self.op.default_output is out of range, or if 158 there are multiple outputs and self.op.default_output does not exist. 159 160 """ 161 do = getattr(self.op, 'default_output', None) 162 if do is None: 163 if len(self.outputs) == 1: 164 return self.outputs[0] 165 else: 166 raise AttributeError( 167 "%s.default_output should be an output index." % self.op) 168 elif not isinstance(do, integer_types): 169 raise AttributeError("%s.default_output should be an int or long" % 170 self.op) 171 elif do < 0 or do >= len(self.outputs): 172 raise AttributeError("%s.default_output is out of range." % 173 self.op) 174 return self.outputs[do] 175 176 out = property(default_output, 177 doc="alias for self.default_output()") 178 """ 179 Alias for self.default_output(). 180 181 """ 182 183 def __str__(self): 184 return op_as_string(self.inputs, self) 185 186 def __repr__(self): 187 return str(self) 188 189 def __asapply__(self): 190 return self 191 192 def clone(self): 193 """ 194 Duplicate this Apply instance with inputs = self.inputs. 195 196 Returns 197 ------- 198 object 199 A new Apply instance (or subclass instance) with new outputs. 200 201 Notes 202 ----- 203 Tags are copied from self to the returned instance. 204 205 """ 206 cp = self.__class__(self.op, self.inputs, 207 [output.clone() for output in self.outputs]) 208 cp.tag = copy(self.tag) 209 return cp 210 211 def clone_with_new_inputs(self, inputs, strict=True): 212 """ 213 Duplicate this Apply instance in a new graph. 214 215 Parameters 216 ---------- 217 inputs 218 List of Variable instances to use as inputs. 


class Variable(Node):
    """
    A :term:`Variable` is a node in an expression graph that represents a
    variable.

    The inputs and outputs of every `Apply` (theano.gof.Apply) are `Variable`
    instances. The input and output arguments to create a `function` are also
    `Variable` instances. A `Variable` is like a strongly-typed variable in
    some other languages; each `Variable` contains a reference to a `Type`
    instance that defines the kind of value the `Variable` can take in a
    computation.

    A `Variable` is a container for four important attributes:

    - :literal:`type` a `Type` instance defining the kind of value this
      `Variable` can have,

    - :literal:`owner` either None (for graph roots) or the `Apply` instance
      of which `self` is an output,

    - :literal:`index` the integer such that :literal:`owner.outputs[index]
      is this_variable` (ignored if `owner` is None),

    - :literal:`name` a string to use in pretty-printing and debugging.

    There are a few kinds of Variables to be aware of: A Variable which is
    the output of a symbolic computation has a reference to the Apply
    instance to which it belongs (property: owner) and the position of
    itself in the owner's output list (property: index).

    - `Variable` (this base type) is typically the output of a symbolic
      computation.

    - `Constant` (a subclass) which adds a default and un-replaceable
      :literal:`value`, and requires that owner is None.

    - `TensorVariable` subclass of Variable that represents a numpy.ndarray
      object.

    - `TensorSharedVariable` Shared version of TensorVariable.

    - `SparseVariable` subclass of Variable that represents
      a scipy.sparse.{csc,csr}_matrix object.

    - `GpuArrayVariable` subclass of Variable that represents our object on
      the GPU that is a subset of numpy.ndarray.

    - `RandomVariable`.

    A Variable which is the output of a symbolic computation will have an
    owner not equal to None.

    Using the Variables' owner field and the Apply nodes' inputs fields, one
    can navigate a graph from an output all the way to the inputs. The
    opposite direction is not possible until a FunctionGraph has annotated
    the Variables with the clients field, i.e., before the compilation
    process has begun a Variable does not know which Apply nodes take it
    as input.

    Parameters
    ----------
    type : a Type instance
        The type governs the kind of data that can be associated with this
        variable.
    owner : None or Apply instance
        The Apply instance which computes the value for this variable.
    index : None or int
        The position of this Variable in owner.outputs.
    name : None or str
        A string for pretty-printing and debugging.

    Examples
    --------

    .. code-block:: python

        import theano
        from theano import tensor

        a = tensor.constant(1.5)        # declare a symbolic constant
        b = tensor.fscalar()            # declare a symbolic floating-point scalar

        c = a + b                       # create a simple expression

        f = theano.function([b], [c])   # this works because a has a value associated with it already

        assert 4.0 == f(2.5)            # bind 2.5 to an internal copy of b and evaluate an internal c

        theano.function([a], [c])       # compilation error because b (required by c) is undefined

        theano.function([a, b], [c])    # compilation error because a is constant, it can't be an input

        d = tensor.value(1.5)           # create a value similar to the constant 'a'
        e = d + b
        theano.function([d, b], [e])    # this works.  d's default value of 1.5 is ignored.

    The python variables :literal:`a, b, c` all refer to instances of type
    `Variable`. The `Variable` referred to by `a` is also an instance of
    `Constant`.

    `compile.function` uses each `Apply` instance's `inputs` attribute
    together with each Variable's `owner` field to determine which inputs
    are necessary to compute the function's outputs.

    """

    # __slots__ = ['type', 'owner', 'index', 'name']
    __count__ = count(0)

    def __init__(self, type, owner=None, index=None, name=None):
        super(Variable, self).__init__()

        self.tag = utils.scratchpad()
        self.type = type
        if owner is not None and not isinstance(owner, Apply):
            raise TypeError("owner must be an Apply instance", owner)
        self.owner = owner
        if index is not None and not isinstance(index, integer_types):
            raise TypeError("index must be an int", index)
        self.index = index
        if name is not None and not isinstance(name, string_types):
            raise TypeError("name must be a string", name)
        self.name = name
        self.auto_name = 'auto_' + str(next(self.__count__))

        Variable.notify_construction_observers(self)

    def __str__(self):
        """Return a str representation of the Variable.

        """
        if self.name is not None:
            return self.name
        if self.owner is not None:
            op = self.owner.op
            if self.index == op.default_output:
                return str(self.owner.op) + ".out"
            else:
                return str(self.owner.op) + "." + str(self.index)
        else:
            return "<%s>" % str(self.type)

    def __repr_test_value__(self):
        """Return a repr of the test value.

        Return a printable representation of the test value. It can be
        overridden by classes with a non-printable test_value to provide a
        suitable representation of the test_value.
        """
        return repr(theano.gof.op.get_test_value(self))

    def __repr__(self, firstPass=True):
        """Return a repr of the Variable.

        Return a printable name or description of the Variable. If
        config.print_test_value is True it will also print the test_value,
        if any.
        """
        to_print = [str(self)]
        if config.print_test_value and firstPass:
            try:
                to_print.append(self.__repr_test_value__())
            except AttributeError:
                pass
        return '\n'.join(to_print)

    def clone(self):
        """
        Return a new Variable like self.

        Returns
        -------
        Variable instance
            A new Variable instance (or subclass instance) with no owner or
            index.

        Notes
        -----
        Tags are copied to the returned instance.

        Name is copied to the returned instance.

        """
        # return copy(self)
        cp = self.__class__(self.type, None, None, self.name)
        cp.tag = copy(self.tag)
        return cp

    def __lt__(self, other):
        raise NotImplementedError('Subclasses of Variable must provide __lt__',
                                  self.__class__.__name__)

    def __le__(self, other):
        raise NotImplementedError('Subclasses of Variable must provide __le__',
                                  self.__class__.__name__)

    def __gt__(self, other):
        raise NotImplementedError('Subclasses of Variable must provide __gt__',
                                  self.__class__.__name__)

    def __ge__(self, other):
        raise NotImplementedError('Subclasses of Variable must provide __ge__',
                                  self.__class__.__name__)

    def get_parents(self):
        if self.owner is not None:
            return [self.owner]
        return []

    def eval(self, inputs_to_values=None):
        """
        Evaluates this variable.

        Parameters
        ----------
        inputs_to_values
            A dictionary mapping theano Variables to values.

        Examples
        --------

        >>> import numpy as np
        >>> import theano.tensor as T
        >>> x = T.dscalar('x')
        >>> y = T.dscalar('y')
        >>> z = x + y
        >>> np.allclose(z.eval({x : 16.3, y : 12.1}), 28.4)
        True

        We passed :func:`eval` a dictionary mapping symbolic theano
        variables to the values to substitute for them, and it returned
        the numerical value of the expression.

        Notes
        -----

        `eval` will be slow the first time you call it on a variable --
        it needs to call :func:`function` to compile the expression behind
        the scenes. Subsequent calls to :func:`eval` on that same variable
        will be fast, because the variable caches the compiled function.

        This way of computing has more overhead than a normal Theano
        function, so don't use it too much in real scripts.
        """

        if inputs_to_values is None:
            inputs_to_values = {}

        if not hasattr(self, '_fn_cache'):
            self._fn_cache = dict()

        inputs = tuple(sorted(inputs_to_values.keys(), key=id))
        if inputs not in self._fn_cache:
            self._fn_cache[inputs] = theano.function(inputs, self)
        args = [inputs_to_values[param] for param in inputs]

        rval = self._fn_cache[inputs](*args)

        return rval

    def __getstate__(self):
        d = self.__dict__.copy()
        d.pop("_fn_cache", None)
        if (not config.pickle_test_value) \
                and (hasattr(self.tag, 'test_value')):
            if not type(config).pickle_test_value.is_default:
                warnings.warn("pickle_test_value is not the default value "
                              "(True).\n"
                              "Test value of variable %s(%s) will not be "
                              "dumped." % (d['auto_name'], d['name']))
            t = copy(d["tag"])
            del t.test_value
            d["tag"] = t
        return d

    # Refer to the doc of nodes_constructed.
    construction_observers = []

    @classmethod
    def append_construction_observer(cls, observer):
        cls.construction_observers.append(observer)

    @classmethod
    def remove_construction_observer(cls, observer):
        cls.construction_observers.remove(observer)

    @classmethod
    def notify_construction_observers(cls, instance):
        for observer in cls.construction_observers:
            observer(instance)
512 """ 513 514 if inputs_to_values is None: 515 inputs_to_values = {} 516 517 if not hasattr(self, '_fn_cache'): 518 self._fn_cache = dict() 519 520 inputs = tuple(sorted(inputs_to_values.keys(), key=id)) 521 if inputs not in self._fn_cache: 522 self._fn_cache[inputs] = theano.function(inputs, self) 523 args = [inputs_to_values[param] for param in inputs] 524 525 rval = self._fn_cache[inputs](*args) 526 527 return rval 528 529 def __getstate__(self): 530 d = self.__dict__.copy() 531 d.pop("_fn_cache", None) 532 if (not config.pickle_test_value) \ 533 and (hasattr(self.tag, 'test_value')): 534 if not type(config).pickle_test_value.is_default: 535 warnings.warn("pickle_test_value is not defaut value (True).\n" 536 "Test value of variable %s(%s) will not be dumped." % (d['auto_name'], d['name'])) 537 t = copy(d["tag"]) 538 del t.test_value 539 d["tag"] = t 540 return d 541 542 # refer to doc in nodes_constructed. 543 construction_observers = [] 544 545 @classmethod 546 def append_construction_observer(cls, observer): 547 cls.construction_observers.append(observer) 548 549 @classmethod 550 def remove_construction_observer(cls, observer): 551 cls.construction_observers.remove(observer) 552 553 @classmethod 554 def notify_construction_observers(cls, instance): 555 for observer in cls.construction_observers: 556 observer(instance) 557 558 559class Constant(Variable): 560 """ 561 A :term:`Constant` is a `Variable` with a `value` field that cannot be 562 changed at runtime. 563 564 Constant nodes make eligible numerous optimizations: constant inlining in 565 C code, constant folding, etc. 566 567 Notes 568 ----- 569 The data field is filtered by what is provided in the constructor for the 570 Constant's type field. 571 572 WRITEME 573 574 """ 575 576 # __slots__ = ['data'] 577 def __init__(self, type, data, name=None): 578 Variable.__init__(self, type, None, None, name) 579 self.data = type.filter(data) 580 utils.add_tag_trace(self) 581 582 def equals(self, other): 583 # this does what __eq__ should do, but Variable and Apply should always be hashable by id 584 return isinstance(other, Constant) and self.signature() == other.signature() 585 586 def signature(self): 587 return (self.type, self.data) 588 589 def merge_signature(self): 590 return self.signature() 591 592 def __str__(self): 593 if self.name is not None: 594 return self.name 595 else: 596 name = str(self.data) 597 if len(name) > 20: 598 name = name[:10] + '...' + name[-10:] 599 return 'Constant{%s}' % name 600 601 def clone(self): 602 """ 603 We clone this object, but we don't clone the data to lower memory 604 requirement. We suppose that the data will never change. 605 606 """ 607 cp = self.__class__(self.type, self.data, self.name) 608 cp.tag = copy(self.tag) 609 return cp 610 611 def __set_owner(self, value): 612 """ 613 WRITEME 614 615 Raises 616 ------ 617 ValueError 618 If `value` is not `None`. 619 620 """ 621 if value is not None: 622 raise ValueError("Constant instances cannot have an owner.") 623 624 owner = property(lambda self: None, __set_owner) 625 value = property(lambda self: self.data, doc='read-only data access method') 626 627 # index is not defined, because the `owner` attribute must necessarily be None 628 629 630def stack_search(start, expand, mode='bfs', build_inv=False): 631 """ 632 Search through a graph, either breadth- or depth-first. 633 634 Parameters 635 ---------- 636 start : deque 637 Search from these nodes. 


def ancestors(variable_list, blockers=None):
    """
    Return the variables that contribute to those in variable_list
    (inclusive).

    Parameters
    ----------
    variable_list : list of `Variable` instances
        Output `Variable` instances from which to search backward through
        owners.

    Returns
    -------
    list of `Variable` instances
        All input nodes, in the order found by a left-recursive depth-first
        search started at the nodes in `variable_list`.

    """
    def expand(r):
        if r.owner and (not blockers or r not in blockers):
            return reversed(r.owner.inputs)
    dfs_variables = stack_search(deque(variable_list), expand, 'dfs')
    return dfs_variables


def inputs(variable_list, blockers=None):
    """
    Return the inputs required to compute the given Variables.

    Parameters
    ----------
    variable_list : list of `Variable` instances
        Output `Variable` instances from which to search backward through
        owners.

    Returns
    -------
    list of `Variable` instances
        Input nodes with no owner, in the order found by a left-recursive
        depth-first search started at the nodes in `variable_list`.

    """
    vlist = ancestors(variable_list, blockers)
    rval = [r for r in vlist if r.owner is None]
    return rval


def variables_and_orphans(i, o):
    """
    Extract the list of variables between the i and o nodes via a dfs
    traversal, and choose the orphans among them.

    Parameters
    ----------
    i : list
        Input variables.
    o : list
        Output variables.

    """
    def expand(r):
        if r.owner and r not in i:
            l = list(r.owner.inputs) + list(r.owner.outputs)
            l.reverse()
            return l
    variables = stack_search(deque(o), expand, 'dfs')
    orphans = [r for r in variables if r.owner is None and r not in i]
    return variables, orphans


def ops(i, o):
    """
    Set of Ops contained within the subgraph between i and o.

    Parameters
    ----------
    i : list
        Input variables.
    o : list
        Output variables.

    Returns
    -------
    object
        The set of ops that are contained within the subgraph that lies
        between i and o, including the owners of the variables in o and
        intermediary ops between i and o, but not the owners of the
        variables in i.

    """
    ops = set()
    variables, orphans = variables_and_orphans(i, o)
    for r in variables:
        if r not in i and r not in orphans:
            if r.owner is not None:
                ops.add(r.owner)
    return ops


def variables(i, o):
    """
    Extract the list of variables between input and output nodes via a dfs
    traversal.

    Parameters
    ----------
    i : list
        Input variables.
    o : list
        Output variables.

    Returns
    -------
    object
        The set of Variables that are involved in the subgraph that lies
        between i and o. This includes i, o, orphans(i, o) and all values
        of all intermediary steps from i to o.

    """
    return variables_and_orphans(i, o)[0]


def orphans(i, o):
    """
    Extract the list of variables between input and output nodes via a dfs
    traversal, and return the orphans among them.

    Parameters
    ----------
    i : list
        Input Variables.
    o : list
        Output Variables.

    Returns
    -------
    object
        The set of Variables which one or more Variables in o depend on,
        but which are neither in i nor in the subgraph that lies between
        i and o.

    Examples
    --------
    orphans([x], [(x+y).out]) => [y]

    """
    return variables_and_orphans(i, o)[1]
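
# Example (illustrative), assuming theano.tensor is available:
#
#     import theano.tensor as tt
#     x, y = tt.dscalars('x', 'y')
#     s = x + y
#     z = s * y
#     inputs([z])          # -> [x, y]  (ownerless roots of the graph)
#     orphans([s], [z])    # -> [y]    (needed by z, but not between s and z)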


def clone(i, o, copy_inputs=True, copy_orphans=None):
    """Copies the subgraph contained between i and o.

    Parameters
    ----------
    i : list
        Input Variables.
    o : list
        Output Variables.
    copy_inputs : bool
        If True, the inputs will be copied (defaults to True).
    copy_orphans:
        When None, use the copy_inputs value.
        When True, new orphan nodes are created.
        When False, the original orphan nodes are reused in the new graph.

    Returns
    -------
    object
        The inputs and outputs of that copy.

    Note
    ----

    A constant in the ``i`` list is not an orphan, so it will be copied
    according to the ``copy_inputs`` parameter. Otherwise it will be copied
    according to the ``copy_orphans`` parameter.

    """
    if copy_orphans is None:
        copy_orphans = copy_inputs
    equiv = clone_get_equiv(i, o, copy_inputs, copy_orphans)
    return [equiv[input] for input in i], [equiv[output] for output in o]
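
# Example (illustrative): clone returns fresh inputs and outputs forming an
# independent copy of the subgraph.
#
#     import theano.tensor as tt
#     x = tt.dscalar('x')
#     y = x ** 2
#     (x2,), (y2,) = clone([x], [y])
#     assert x2 is not x and y2 is not y   # new nodes by default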
891 892 """ 893 if memo is None: 894 memo = {} 895 896 # clone the inputs if necessary 897 for input in inputs: 898 if copy_inputs: 899 cpy = input.clone() 900 cpy.owner = None 901 cpy.index = None 902 memo.setdefault(input, cpy) 903 else: 904 memo.setdefault(input, input) 905 906 # go through the inputs -> outputs graph cloning as we go 907 for apply in io_toposort(inputs, outputs): 908 for input in apply.inputs: 909 if input not in memo: 910 if copy_orphans: 911 cpy = input.clone() 912 memo[input] = cpy 913 else: 914 memo[input] = input 915 916 new_apply = apply.clone_with_new_inputs([memo[i] for i in apply.inputs]) 917 memo.setdefault(apply, new_apply) 918 for output, new_output in zip(apply.outputs, new_apply.outputs): 919 memo.setdefault(output, new_output) 920 921 # finish up by cloning any remaining outputs (it can happen) 922 for output in outputs: 923 if output not in memo: 924 memo[output] = output.clone() 925 926 return memo 927 928 929def general_toposort(outputs, deps, debug_print=False, 930 compute_deps_cache=None, deps_cache=None, 931 clients=None): 932 """ 933 WRITEME 934 935 Parameters 936 ---------- 937 deps 938 A python function that takes a node as input and returns its dependence. 939 compute_deps_cache : optional 940 If provided deps_cache should also be provided. This is a function like 941 deps, but that also cache its results in a dict passed as deps_cache. 942 deps_cache : dict 943 Must be used with compute_deps_cache. 944 clients : dict 945 If a dict is passed it will be filled with a mapping of node 946 -> clients for each node in the subgraph. 947 948 Notes 949 ----- 950 deps(i) should behave like a pure function (no funny business with 951 internal state). 952 953 deps(i) will be cached by this function (to be fast). 954 955 The order of the return value list is determined by the order of nodes 956 returned by the deps() function. 957 958 deps should be provided or can be None and the caller provides 959 compute_deps_cache and deps_cache. The second option removes a Python 960 function call, and allows for more specialized code, so it can be 961 faster. 
962 963 """ 964 if compute_deps_cache is None: 965 deps_cache = {} 966 967 def compute_deps_cache(io): 968 if io not in deps_cache: 969 d = deps(io) 970 if d: 971 if not isinstance(d, (list, OrderedSet)): 972 raise TypeError( 973 "Non-deterministic collections here make" 974 " toposort non-deterministic.") 975 deps_cache[io] = list(d) 976 else: 977 deps_cache[io] = d 978 return d 979 else: 980 return deps_cache[io] 981 assert deps_cache is not None 982 983 assert isinstance(outputs, (tuple, list, deque)) 984 985 reachable, _clients = stack_search(deque(outputs), compute_deps_cache, 986 'dfs', True) 987 if clients is not None: 988 clients.update(_clients) 989 sources = deque([r for r in reachable if not deps_cache.get(r, None)]) 990 991 rset = set() 992 rlist = [] 993 while sources: 994 node = sources.popleft() 995 if node not in rset: 996 rlist.append(node) 997 rset.add(node) 998 for client in _clients.get(node, []): 999 d = [a for a in deps_cache[client] if a is not node] 1000 deps_cache[client] = d 1001 if not d: 1002 sources.append(client) 1003 1004 if len(rlist) != len(reachable): 1005 if debug_print: 1006 print('') 1007 print(reachable) 1008 print(rlist) 1009 raise ValueError('graph contains cycles') 1010 1011 return rlist 1012 1013 1014def io_toposort(inputs, outputs, orderings=None, clients=None): 1015 """ 1016 Perform topological sort from input and output nodes 1017 1018 Parameters 1019 ---------- 1020 inputs : list or tuple of Variable instances 1021 outputs : list or tuple of Apply instances 1022 orderings : dict 1023 Key: Apply instance. Value: list of Apply instance. 1024 It is important that the value be a container with a deterministic 1025 iteration order. No sets allowed! 1026 clients : dict 1027 If a dict is provided it will be filled with mappings of 1028 node->clients for each node in the subgraph that is sorted 1029 1030 """ 1031 if not orderings and clients is None: # ordering can be None or empty dict 1032 # Specialized function that is faster when more then ~10 nodes 1033 # when no ordering. 1034 1035 # Do a new stack implementation with the vm algo. 1036 # This will change the order returned. 1037 computed = set(inputs) 1038 todo = [o.owner for o in reversed(outputs) if o.owner] 1039 order = [] 1040 while todo: 1041 cur = todo.pop() 1042 # We suppose that all outputs are always computed 1043 if cur.outputs[0] in computed: 1044 continue 1045 if all([i in computed or i.owner is None for i in cur.inputs]): 1046 computed.update(cur.outputs) 1047 order.append(cur) 1048 else: 1049 todo.append(cur) 1050 todo.extend(i.owner for i in cur.inputs if i.owner) 1051 return order 1052 1053 compute_deps = None 1054 compute_deps_cache = None 1055 iset = set(inputs) 1056 deps_cache = {} 1057 1058 if not orderings: # ordering can be None or empty dict 1059 # Specialized function that is faster when no ordering. 1060 # Also include the cache in the function itself for speed up. 


def io_toposort(inputs, outputs, orderings=None, clients=None):
    """
    Perform a topological sort from input and output nodes.

    Parameters
    ----------
    inputs : list or tuple of Variable instances
    outputs : list or tuple of Variable instances
    orderings : dict
        Key: Apply instance. Value: list of Apply instances.
        It is important that the value be a container with a deterministic
        iteration order. No sets allowed!
    clients : dict
        If a dict is provided, it will be filled with mappings of
        node -> clients for each node in the subgraph that is sorted.

    """
    if not orderings and clients is None:  # ordering can be None or empty dict
        # Specialized function that is faster when there are more than ~10
        # nodes and no ordering.

        # Do a new stack implementation with the vm algo.
        # This will change the order returned.
        computed = set(inputs)
        todo = [o.owner for o in reversed(outputs) if o.owner]
        order = []
        while todo:
            cur = todo.pop()
            # We suppose that all outputs are always computed.
            if cur.outputs[0] in computed:
                continue
            if all([i in computed or i.owner is None for i in cur.inputs]):
                computed.update(cur.outputs)
                order.append(cur)
            else:
                todo.append(cur)
                todo.extend(i.owner for i in cur.inputs if i.owner)
        return order

    compute_deps = None
    compute_deps_cache = None
    iset = set(inputs)
    deps_cache = {}

    if not orderings:  # ordering can be None or empty dict
        # Specialized function that is faster when there is no ordering.
        # Also include the cache in the function itself for speed up.

        def compute_deps_cache(obj):
            if obj in deps_cache:
                return deps_cache[obj]
            rval = []
            if obj not in iset:
                if isinstance(obj, Variable):
                    if obj.owner:
                        rval = [obj.owner]
                elif isinstance(obj, Apply):
                    rval = list(obj.inputs)
                if rval:
                    if not isinstance(rval, (list, OrderedSet)):
                        raise TypeError(
                            "Non-deterministic collections here make"
                            " toposort non-deterministic.")
                    deps_cache[obj] = list(rval)
                else:
                    deps_cache[obj] = rval
            else:
                deps_cache[obj] = rval
            return rval
    else:

        # The inputs are used only here, in the function that decides what
        # 'predecessors' to explore.
        def compute_deps(obj):
            rval = []
            if obj not in iset:
                if isinstance(obj, Variable):
                    if obj.owner:
                        rval = [obj.owner]
                elif isinstance(obj, Apply):
                    rval = list(obj.inputs)
                rval.extend(orderings.get(obj, []))
            else:
                assert not orderings.get(obj, None)
            return rval

    topo = general_toposort(outputs, deps=compute_deps,
                            compute_deps_cache=compute_deps_cache,
                            deps_cache=deps_cache, clients=clients)
    return [o for o in topo if isinstance(o, Apply)]
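
# Example (illustrative): ordering the Apply nodes needed to compute an
# output, assuming theano.tensor is available:
#
#     import theano.tensor as tt
#     x, y = tt.dscalars('x', 'y')
#     z = tt.exp(x + y)
#     [node.op for node in io_toposort([x, y], [z])]
#     # -> the add node's op first, then the exp node's op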


default_leaf_formatter = str


def default_node_formatter(op, argstrings):
    return "%s(%s)" % (op.op, ", ".join(argstrings))


def io_connection_pattern(inputs, outputs):
    """
    Return the connection pattern of a subgraph defined by the given
    inputs and outputs.

    """
    inner_nodes = io_toposort(inputs, outputs)

    # Initialize 'connect_pattern_by_var' by establishing each input as
    # connected only to itself.
    connect_pattern_by_var = {}
    nb_inputs = len(inputs)

    for i in range(nb_inputs):
        input = inputs[i]
        inp_connection_pattern = [i == j for j in range(nb_inputs)]
        connect_pattern_by_var[input] = inp_connection_pattern

    # Iterate through the nodes used to produce the outputs from the
    # inputs and, for every node, infer their connection pattern to
    # every input from the connection patterns of their parents.
    for n in inner_nodes:

        # Get the connection pattern of the inner node's op. If the op
        # does not define a connection_pattern method, assume that
        # every node output is connected to every node input.
        try:
            op_connection_pattern = n.op.connection_pattern(n)
        except AttributeError:
            op_connection_pattern = ([[True] * len(n.outputs)] *
                                     len(n.inputs))

        # For every output of the inner node, figure out which inputs it
        # is connected to by combining the connection pattern of the inner
        # node and the connection patterns of the inner node's inputs.
        for out_idx in range(len(n.outputs)):
            out = n.outputs[out_idx]
            out_connection_pattern = [False] * nb_inputs

            for inp_idx in range(len(n.inputs)):
                inp = n.inputs[inp_idx]

                if inp in connect_pattern_by_var:
                    inp_connection_pattern = connect_pattern_by_var[inp]

                    # If the node output is connected to the node input, it
                    # means it is connected to every inner input that the
                    # node input is connected to.
                    if op_connection_pattern[inp_idx][out_idx]:
                        out_connection_pattern = [
                            out_connection_pattern[i] or
                            inp_connection_pattern[i]
                            for i in range(nb_inputs)]

            # Store the connection pattern of the node output.
            connect_pattern_by_var[out] = out_connection_pattern

    # Obtain the global connection pattern by combining the
    # connection patterns of the individual outputs.
    global_connection_pattern = [[] for o in range(len(inputs))]
    for out in outputs:
        out_connection_pattern = connect_pattern_by_var.get(out)
        if out_connection_pattern is None:
            # The output is completely isolated from the inputs.
            out_connection_pattern = [False] * len(inputs)
        for i in range(len(inputs)):
            global_connection_pattern[i].append(out_connection_pattern[i])

    return global_connection_pattern
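
# Example (illustrative): the result has one row per input and one column
# per output, assuming theano.tensor is available:
#
#     import theano.tensor as tt
#     x, y = tt.dscalars('x', 'y')
#     io_connection_pattern([x, y], [x + 1, y * 2])
#     # -> [[True, False], [False, True]]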


def is_same_graph(var1, var2, givens=None, debug=False):
    """
    Return True iff Variables `var1` and `var2` perform the same computation.

    By 'performing the same computation', we mean that they must share the
    same graph, so that for instance this function will return False when
    comparing (x * (y * z)) with ((x * y) * z).

    The current implementation is not efficient since, when possible, it
    verifies equality by calling two different functions that are expected
    to return the same output. The goal is to verify this assumption, to
    eventually get rid of one of them in the future.

    Parameters
    ----------
    var1
        The first Variable to compare.
    var2
        The second Variable to compare.
    givens
        Similar to the `givens` argument of `theano.function`, it can be
        used to perform substitutions in the computational graph of `var1`
        and `var2`. This argument is associated to neither `var1` nor
        `var2`: substitutions may affect both graphs if the substituted
        variable is present in both.
    debug : bool
        If True, then an exception is raised when we are in a situation
        where the `equal_computations` implementation cannot be called.
        This parameter is intended to be used in tests only, to make sure
        we properly test both implementations.

    Examples
    --------

    ====== ====== ====== ======
    var1   var2   givens output
    ====== ====== ====== ======
    x + 1  x + 1  {}     True
    x + 1  y + 1  {}     False
    x + 1  y + 1  {x: y} True
    ====== ====== ====== ======

    """
    # Lazy import.
    if givens is None:
        givens = {}
    global equal_computations, is_same_graph_with_merge
    if equal_computations is None:
        from theano.gof.opt import is_same_graph_with_merge
        from theano.scan_module.scan_utils import equal_computations
    # Convert `givens` to dictionary.
    if not isinstance(givens, dict):
        givens = dict(givens)
    # Get result from the merge-based function.
    rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
    # Get result from the function `equal_computations` from scan_utils.

    use_equal_computations = True
    if givens:
        # We need to build the `in_xs` and `in_ys` lists. To do this, we
        # need to be able to tell whether a variable belongs to the
        # computational graph of `var1` or `var2`.
        # The typical case we want to handle is when `to_replace` belongs to
        # one of these graphs, and `replace_by` belongs to the other one. In
        # other situations, the current implementation of
        # `equal_computations` is probably not appropriate, so we do not
        # call it.
        ok = True
        in_xs = []
        in_ys = []
        # Compute the sets of all variables found in each computational
        # graph.
        inputs_var = list(map(inputs, ([var1], [var2])))
        all_vars = [set(variables(v_i, v_o))
                    for v_i, v_o in ((inputs_var[0], [var1]),
                                     (inputs_var[1], [var2]))]

        def in_var(x, k):
            # Return True iff `x` is in the computation graph of variable
            # `vark`.
            return x in all_vars[k - 1]

        for to_replace, replace_by in iteritems(givens):
            # Map a substitution variable to the computational graphs it
            # belongs to.
            inside = dict((v, [in_var(v, k) for k in (1, 2)])
                          for v in (to_replace, replace_by))
            if (inside[to_replace][0] and not inside[to_replace][1] and
                    inside[replace_by][1] and not inside[replace_by][0]):
                # Substitute variable in `var1` by one from `var2`.
                in_xs.append(to_replace)
                in_ys.append(replace_by)
            elif (inside[to_replace][1] and not inside[to_replace][0] and
                    inside[replace_by][0] and not inside[replace_by][1]):
                # Substitute variable in `var2` by one from `var1`.
                in_xs.append(replace_by)
                in_ys.append(to_replace)
            else:
                ok = False
                break
        if not ok:
            # We cannot directly use `equal_computations`.
            if debug:
                raise AssertionError(
                    'When `debug` is True we want to make sure we are also '
                    'using the `equal_computations` implementation')
            use_equal_computations = False
    else:
        in_xs = None
        in_ys = None
    if use_equal_computations:
        rval2 = equal_computations(xs=[var1], ys=[var2],
                                   in_xs=in_xs, in_ys=in_ys)
        assert rval2 == rval1
    return rval1


def op_as_string(i, op,
                 leaf_formatter=default_leaf_formatter,
                 node_formatter=default_node_formatter):
    """
    Return a string representation of the application of `op`, describing
    the subgraph between the leaves `i` and `op`'s inputs.
    """
    strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
    return node_formatter(op, strs)


def as_string(i, o,
              leaf_formatter=default_leaf_formatter,
              node_formatter=default_node_formatter):
    """
    Return a string representation of the subgraph between i and o.

    Parameters
    ----------
    i : list
        Input `Variable` s.
    o : list
        Output `Variable` s.
    leaf_formatter : callable
        Takes a `Variable` and returns a string to describe it.
    node_formatter : callable
        Takes an `Op` and the list of strings corresponding to its arguments
        and returns a string to describe it.

    Returns
    -------
    list of str
        Returns a string representation of the subgraph between i and o,
        one string per output. If the same op is used by several other ops,
        the first occurrence will be marked as :literal:`*n -> description`
        and all subsequent occurrences will be marked as :literal:`*n`,
        where n is an id number (ids are attributed in an unspecified order
        and only exist for viewing convenience).

    """
    i = set(i)

    orph = orphans(i, o)

    multi = set()
    seen = set()
    for output in o:
        op = output.owner
        if op in seen:
            multi.add(op)
        else:
            seen.add(op)
    for op in ops(i, o):
        for input in op.inputs:
            op2 = input.owner
            if input in i or input in orph or op2 is None:
                continue
            if op2 in seen:
                multi.add(op2)
            else:
                seen.add(input.owner)
    multi = [x for x in multi]
    done = set()

    def multi_index(x):
        return multi.index(x) + 1

    def describe(r):
        if r.owner is not None and r not in i and r not in orph:
            op = r.owner
            idx = op.outputs.index(r)
            if len(op.outputs) == 1:
                idxs = ""
            else:
                idxs = "::%i" % idx
            if op in done:
                return "*%i%s" % (multi_index(op), idxs)
            else:
                done.add(op)
                s = node_formatter(op, [describe(input)
                                        for input in op.inputs])
                if op in multi:
                    return "*%i -> %s" % (multi_index(op), s)
                else:
                    return s
        else:
            return leaf_formatter(r)

    return [describe(output) for output in o]
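
# Example (illustrative): shared subexpressions are numbered, assuming
# theano.tensor is available:
#
#     import theano.tensor as tt
#     x, y = tt.dscalars('x', 'y')
#     s = x + y
#     as_string([x, y], [s * s])
#     # -> something like
#     # ['Elemwise{mul,no_inplace}(*1 -> Elemwise{add,no_inplace}(x, y), *1)']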
1335 1336 """ 1337 i = set(i) 1338 1339 orph = orphans(i, o) 1340 1341 multi = set() 1342 seen = set() 1343 for output in o: 1344 op = output.owner 1345 if op in seen: 1346 multi.add(op) 1347 else: 1348 seen.add(op) 1349 for op in ops(i, o): 1350 for input in op.inputs: 1351 op2 = input.owner 1352 if input in i or input in orph or op2 is None: 1353 continue 1354 if op2 in seen: 1355 multi.add(op2) 1356 else: 1357 seen.add(input.owner) 1358 multi = [x for x in multi] 1359 done = set() 1360 1361 def multi_index(x): 1362 return multi.index(x) + 1 1363 1364 def describe(r): 1365 if r.owner is not None and r not in i and r not in orph: 1366 op = r.owner 1367 idx = op.outputs.index(r) 1368 if len(op.outputs) == 1: 1369 idxs = "" 1370 else: 1371 idxs = "::%i" % idx 1372 if op in done: 1373 return "*%i%s" % (multi_index(op), idxs) 1374 else: 1375 done.add(op) 1376 s = node_formatter(op, [describe(input) for input in op.inputs]) 1377 if op in multi: 1378 return "*%i -> %s" % (multi_index(op), s) 1379 else: 1380 return s 1381 else: 1382 return leaf_formatter(r) 1383 1384 return [describe(output) for output in o] 1385 1386 1387def view_roots(r): 1388 """ 1389 Utility function that returns the leaves of a search through 1390 consecutive view_map()s. 1391 1392 WRITEME 1393 1394 """ 1395 owner = r.owner 1396 if owner is not None: 1397 try: 1398 view_map = owner.op.view_map 1399 view_map = dict((owner.outputs[o], i) 1400 for o, i in iteritems(view_map)) 1401 except AttributeError: 1402 return [r] 1403 if r in view_map: 1404 answer = [] 1405 for i in view_map[r]: 1406 answer += view_roots(owner.inputs[i]) 1407 return answer 1408 else: 1409 return [r] 1410 else: 1411 return [r] 1412 1413 1414def list_of_nodes(inputs, outputs): 1415 """ 1416 Return the apply nodes of the graph between inputs and outputs. 1417 1418 """ 1419 return stack_search( 1420 deque([o.owner for o in outputs]), 1421 lambda o: [inp.owner for inp in o.inputs 1422 if inp.owner and 1423 not any(i in inp.owner.outputs for i in inputs)]) 1424 1425 1426def is_in_ancestors(l_node, f_node): 1427 r""" 1428 Goes up in the graph and returns True if the apply node f_node is found. 1429 1430 Use a stack implementation as the vm algo. 1431 We suppose all nodes are not lazy 1432 (i.e. for IfElse we suppose all inputs are computed) 1433 """ 1434 computed = set() 1435 todo = [l_node] 1436 while todo: 1437 cur = todo.pop() 1438 if cur.outputs[0] in computed: 1439 continue 1440 if all([i in computed or i.owner is None for i in cur.inputs]): 1441 computed.update(cur.outputs) 1442 if cur is f_node: 1443 return True 1444 else: 1445 todo.append(cur) 1446 todo.extend(i.owner for i in cur.inputs if i.owner) 1447 return False 1448 1449 1450@contextlib.contextmanager 1451def nodes_constructed(): 1452 """ 1453 A contextmanager that is used in inherit_stack_trace and keeps track 1454 of all the newly created varaible nodes inside an optimization. A list 1455 of new_nodes is instantiated but will be filled in a lazy manner (when 1456 Variable.notify_construction_observers is called). 1457 1458 1459 `observer` is the entity that updates the new_nodes list. 1460 construction_observers is a list inside Variable class and contains 1461 a list of observer functions. The observer functions inside 1462 construction_observers are only called when a variable node is 1463 instantiated (where Variable.notify_construction_observers is called). 1464 When the observer function is called, a new variable node is added to 1465 the new_nodes list. 


@contextlib.contextmanager
def nodes_constructed():
    """
    A contextmanager that is used in inherit_stack_trace and keeps track
    of all the newly created variable nodes inside an optimization. A list
    of new_nodes is instantiated but will be filled in a lazy manner (when
    Variable.notify_construction_observers is called).

    `observer` is the entity that updates the new_nodes list.
    construction_observers is a list inside the Variable class and contains
    the observer functions. The observer functions inside
    construction_observers are only called when a variable node is
    instantiated (that is, when Variable.notify_construction_observers is
    called). When an observer function is called, a new variable node is
    added to the new_nodes list.

    Parameters
    ----------
    new_nodes
        A list of all the variable nodes that are created inside the
        optimization.

    yields
        new_nodes list.
    """
    new_nodes = []

    def observer(node):
        new_nodes.append(node)
    Variable.append_construction_observer(observer)
    yield new_nodes
    Variable.remove_construction_observer(observer)
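
# Example (illustrative): recording the Variables created inside a block,
# assuming theano.tensor is available:
#
#     import theano.tensor as tt
#     x = tt.dscalar('x')
#     with nodes_constructed() as new_nodes:
#         y = tt.exp(x) + 1
#     # new_nodes now holds every Variable instantiated while building y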