1""" 2Driver of graph construction, optimization, and linking. 3 4""" 5from __future__ import absolute_import, print_function, division 6 7import copy 8from six import string_types, iteritems, iterkeys 9from six.moves import xrange 10import six.moves.copyreg as copyreg 11import six.moves.cPickle as pickle 12from itertools import chain 13import time 14import warnings 15import numpy as np 16 17import theano 18from theano import config, gof 19from theano.compat import izip 20from theano.gof import graph 21import theano.compile.profiling 22from theano.compile.io import ( 23 In, SymbolicInput, SymbolicOutput) 24from theano.compile.ops import deep_copy_op, view_op 25from theano.gof.graph import is_same_graph 26from theano.gof.op import ops_with_inner_function 27 28import logging 29_logger = logging.getLogger('theano.compile.function_module') 30 31__docformat__ = "restructuredtext en" 32 33 34class UnusedInputError(Exception): 35 """ 36 A symbolic input passed to function is not needed. 37 38 """ 39 40 pass 41 42 43def alias_root(v): 44 """ 45 Return the variable to which v is aliased by view_maps and destroy_maps. 46 47 """ 48 if v.owner is None: 49 return v 50 vmap = getattr(v.owner.op, 'view_map', {}) 51 dmap = getattr(v.owner.op, 'destroy_map', {}) 52 outpos = v.owner.outputs.index(v) 53 v_views = vmap.get(outpos, []) + dmap.get(outpos, []) 54 if len(v_views) > 1: 55 raise NotImplementedError( 56 str(v) + " is a view/destroyed version of more then one inputs. " 57 "Currently, we only support the case where an output is a view or " 58 "a destroyed version of one input.") 59 elif v_views: 60 return alias_root(v.owner.inputs[v_views[0]]) 61 else: 62 return v 63 64 65def view_tree_set(v, treeset): 66 """ 67 Add to `treeset` all variables that are views of v, given that v is 68 not a view. 
69 70 """ 71 treeset.add(v) 72 for cl, v_input_pos_to_cl in v.clients: 73 if cl == 'output': 74 continue 75 vmap = getattr(cl.op, 'view_map', {}) 76 dmap = getattr(cl.op, 'destroy_map', {}) 77 for opos, iposlist in chain(iteritems(vmap), iteritems(dmap)): 78 if v_input_pos_to_cl in iposlist: 79 if cl.outputs[opos] not in treeset: 80 view_tree_set(cl.outputs[opos], treeset) 81 82 83def infer_reuse_pattern(fgraph, outputs_to_disown): 84 """ 85 Given an fgraph and a list of variables, returns the list or set 86 of all variables which may share the same underlying data storage 87 as any of the specified variables. Used internally by function, 88 FunctionMaker. 89 90 This list (or set) is also referred to as no_recycling sometimes, 91 especially by linker code. 92 93 """ 94 rval = set() 95 for o in outputs_to_disown: 96 view_tree_set(alias_root(o), rval) 97 # remove from rval all of the inputs, constants, values. 98 rval = set(r for r in rval if r.owner is not None) 99 100 return rval 101 102 103def fgraph_updated_vars(fgraph, expanded_inputs): 104 """ 105 Reconstruct the full "updates" dictionary, mapping from FunctionGraph input 106 variables to the fgraph outputs that will replace their values. 107 108 Returns 109 ------- 110 dict variable -> variable 111 112 """ 113 updated_vars = {} 114 potential_values = list(fgraph.outputs) # copy the list 115 if len(expanded_inputs) != len(fgraph.inputs): 116 raise ValueError('expanded_inputs must match len(fgraph.inputs)') 117 for e_input, ivar in reversed(list(zip(expanded_inputs, fgraph.inputs))): 118 if e_input.update is not None: 119 updated_vars[ivar] = potential_values.pop() 120 return updated_vars 121 122 123class Supervisor: 124 """ 125 Listener for FunctionGraph events which makes sure that no 126 operation overwrites the contents of protected Variables. The 127 outputs of the FunctionGraph are protected by default. 
128 129 """ 130 131 def __init__(self, protected): 132 self.protected = list(protected) 133 134 def validate(self, fgraph): 135 if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'): 136 if fgraph.has_destroyers(self.protected): 137 raise gof.InconsistencyError("Trying to destroy a protected" 138 "Variable.") 139 return True 140 if not hasattr(fgraph, 'destroyers'): 141 return True 142 for r in self.protected + list(fgraph.outputs): 143 if fgraph.destroyers(r): 144 raise gof.InconsistencyError("Trying to destroy a protected" 145 "Variable.", r) 146 147 148def std_fgraph(input_specs, output_specs, accept_inplace=False): 149 """ 150 Makes an FunctionGraph corresponding to the input specs and the output 151 specs. Any SymbolicInput in the input_specs, if its update field 152 is not None, will add an output to the FunctionGraph corresponding to that 153 update. The return value is the FunctionGraph as well as a list of 154 SymbolicOutput instances corresponding to the updates. 155 156 If accept_inplace is False, the graph will be checked for inplace 157 operations and an exception will be raised if it has any. If 158 accept_inplace is True, a DestroyHandler will be added to the FunctionGraph 159 if there are any inplace operations. 160 161 The returned FunctionGraph is a clone of the graph between the provided 162 inputs and outputs. 163 164 """ 165 orig_inputs = [spec.variable for spec in input_specs] 166 167 # Extract the updates and the mapping between update outputs and 168 # the updated inputs. 
169 updates = [] 170 update_mapping = {} 171 out_idx = len(output_specs) 172 for inp_idx in range(len(input_specs)): 173 if input_specs[inp_idx].update: 174 updates.append(input_specs[inp_idx].update) 175 update_mapping[out_idx] = inp_idx 176 out_idx += 1 177 178 orig_outputs = [spec.variable for spec in output_specs] + updates 179 180 fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs, 181 update_mapping=update_mapping) 182 183 for node in fgraph.apply_nodes: 184 if getattr(node.op, 'destroy_map', None): 185 if not accept_inplace: 186 raise TypeError("Graph must not contain inplace operations", 187 node, node.op) 188 else: 189 fgraph.attach_feature(gof.DestroyHandler()) 190 break 191 192 # We need to protect all immutable inputs from inplace operations. 193 fgraph.attach_feature( 194 Supervisor(input 195 for spec, input in zip(input_specs, fgraph.inputs) 196 if not (spec.mutable or 197 (hasattr(fgraph, 'destroyers') and 198 fgraph.has_destroyers([input]))))) 199 200 # If named nodes are replaced, keep the name 201 for feature in std_fgraph.features: 202 fgraph.attach_feature(feature()) 203 return fgraph, list(map(SymbolicOutput, updates)) 204 205 206std_fgraph.features = [gof.toolbox.PreserveVariableAttributes] 207 208 209class AliasedMemoryError(Exception): 210 """ 211 Memory is aliased that should not be. 212 213 """ 214 pass 215 216 217### 218# Function 219### 220 221# unique id object used as a placeholder for duplicate entries 222DUPLICATE = ['DUPLICATE'] 223 224 225class Function(object): 226 """ 227 Type of the functions returned by theano.function or 228 theano.FunctionMaker.create. 229 230 `Function` is the callable object that does computation. It has the storage 231 of inputs and outputs, performs the packing and unpacking of inputs and 232 return values. It implements the square-bracket indexing so that you can 233 look up the value of a symbolic node. 234 235 Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}. 
    When a function is copied, this instance is duplicated. Contrast with
    self.maker (instance of `FunctionMaker`) that is shared between copies.
    The meaning of copying a function is that the containers and their current
    values will all be duplicated. This requires that mutable inputs be
    copied, whereas immutable inputs may be shared between copies.

    A Function instance is hashable, on the basis of its memory
    address (its id).

    A Function instance is only equal to itself.

    A Function instance may be serialized using the `pickle` or
    `cPickle` modules. This will save all default inputs, the graph,
    and WRITEME to the pickle file.

    A Function instance has a ``trust_input`` field that defaults to
    False. When True, we don't do extra check of the input to give
    better error message. In some case, python code will still return
    the good results if you pass a python or numpy scalar instead of a
    numpy tensor. C code should raise an error if you pass an object
    of the wrong type.

    Attributes
    ----------
    finder
    inv_finder

    """

    pickle_aliased_memory_strategy = 'warn'
    """
    How to deal with pickling finding aliased storage.

    Meaningful settings are: 'ignore', 'warn', 'raise'.

    If the value is 'warn', then a message will be printed to stderr
    if aliased storage is detected during pickle.dump.

    If the value is 'raise', then an AliasedMemoryError will be raised
    if aliased storage is detected during pickle.dump.

    """

    input_storage = None
    """
    List of Container instances.

    """

    output_storage = None
    """
    List of Container instances.

    """

    indices = None
    """
    List of (SymbolicInput, indices, [SymbolicInput,...]),
    one tuple for each input.

    The first tuple element is the SymbolicInput object for the corresponding
    function input.

    The second and third tuple elements are used only by Kits, which
    are deprecated.

    """

    defaults = None
    """
    List of 3-tuples, one 3-tuple for each input.

    Tuple element 0: Bool: Is this input required at each function call?
    Tuple element 1: Bool: Should this inputs value be reverted after
        each call?
    Tuple element 2: Any: The value associated with this input.

    """

    unpack_single = None
    """
    Bool: for outputs lists of length 1, should the 0'th element be
    returned directly?

    """

    return_none = None
    """
    Bool: whether the function should return None or not.

    """

    maker = None
    """
    FunctionMaker instance.

    """

    fn = None
    """
    A function that evaluates the graph. Typically a linker's make_thunk method
    created this function.

    """

    finder = None
    """
    Dictionary mapping several kinds of things to containers.

    We set an entry in finder for:

    - the index of the input

    - the variable instance the input is based on

    - the name of the input

    All entries map to the container or to DUPLICATE if an ambiguity
    is detected.

    """

    inv_finder = None
    """
    Dict. Reverse lookup of `finder`.
    It maps container -> SymbolicInput

    """

    def __init__(self, fn, input_storage, output_storage, indices, outputs,
                 defaults, unpack_single, return_none, output_keys, maker,
                 name=None):
        """
        Build a Function around an already-compiled thunk `fn`, wiring up
        input/output storage, default values, and the `finder`/`inv_finder`
        lookup tables used by the square-bracket access API.
        """
        self.fn = fn
        self.input_storage = input_storage
        self.output_storage = output_storage
        self.indices = indices
        self.outputs = outputs
        self.defaults = defaults
        self.unpack_single = unpack_single
        self.return_none = return_none
        self.maker = maker
        self.profile = None  # reassigned in FunctionMaker.create
        self.trust_input = False  # If True, we don't check the input parameter
        self.name = name
        self.nodes_with_inner_function = []
        self.output_keys = output_keys

        # See if we have any mutable / borrow inputs
        # TODO: this only need to be set if there is more then 1 input
        self._check_for_aliased_inputs = False
        for i in maker.inputs:
            # If the input is a shared variable, the memory region is
            # under Theano control and so we don't need to check if it
            # is aliased as we never do that.
            if (isinstance(i, In) and not i.shared and
                (getattr(i, 'borrow', False) or
                 getattr(i, 'mutable', False))):
                self._check_for_aliased_inputs = True
                break

        # We will be popping stuff off this `containers` object. It is a copy.
        containers = list(self.input_storage)
        finder = {}
        inv_finder = {}

        def distribute(indices, cs, value):
            # Forward `value` to the containers `cs` via the input's own
            # distribute hook, counting each as provided.
            input.distribute(value, indices, cs)
            for c in cs:
                c.provided += 1

        # Store the list of names of named inputs.
        named_inputs = []
        # Count the number of un-named inputs.
        n_unnamed_inputs = 0

        # Initialize the storage
        # this loop works by modifying the elements (as variable c) of
        # self.input_storage inplace.
        for i, ((input, indices, sinputs), (required, refeed, value)) in \
                enumerate(zip(self.indices, defaults)):
            if indices is None:
                # containers is being used as a stack. Here we pop off
                # the next one.
                c = containers[0]
                c.strict = getattr(input, 'strict', False)
                c.allow_downcast = getattr(input, 'allow_downcast', None)

                if value is not None:
                    # Always initialize the storage.
                    if isinstance(value, gof.Container):
                        # There is no point in obtaining the current value
                        # stored in the container, since the container is
                        # shared.
                        # For safety, we make sure 'refeed' is False, since
                        # there is no need to refeed the default value.
                        assert not refeed
                    else:
                        c.value = value
                c.required = required
                c.implicit = input.implicit
                # this is a count of how many times the input has been
                # provided (reinitialized to 0 on __call__)
                c.provided = 0
                # The same container is reachable by position, by variable,
                # and (if unambiguous) by name.
                finder[i] = c
                finder[input.variable] = c
                if input.name not in finder:
                    finder[input.name] = c
                else:
                    finder[input.name] = DUPLICATE
                if input.name is None:
                    n_unnamed_inputs += 1
                else:
                    named_inputs.append(input.name)
                inv_finder[c] = input
                containers[:1] = []

        self.finder = finder
        self.inv_finder = inv_finder

        # this class is important in overriding the square-bracket notation:
        #     fn.value[x]
        # self reference is available via the closure on the class
        class ValueAttribute(object):
            def __getitem__(self, item):
                try:
                    s = finder[item]
                except KeyError:
                    raise TypeError("Unknown input or state: %s" % str(item))
                if s is DUPLICATE:
                    raise TypeError("Ambiguous name: %s - please check the "
                                    "names of the inputs of your function "
                                    "for duplicates." % str(item))
                if isinstance(s, gof.Container):
                    return s.value
                else:
                    raise NotImplementedError

            def __setitem__(self, item, value):
                try:
                    s = finder[item]
                except KeyError:
                    # Print informative error message.
                    msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
                    raise TypeError("Unknown input or state: %s. %s" %
                                    (str(item), msg))
                if s is DUPLICATE:
                    raise TypeError("Ambiguous name: %s - please check the "
                                    "names of the inputs of your function "
                                    "for duplicates." % str(item))
                if isinstance(s, gof.Container):
                    s.value = value
                    s.provided += 1
                else:
                    s(value)

            def __contains__(self, item):
                return finder.__contains__(item)

        # this class is important in overriding the square-bracket notation:
        #     fn.container[x]
        # self reference is available via the closure on the class
        class ContainerAttribute(object):
            def __getitem__(self, item):
                return finder[item]

            def __contains__(self, item):
                return finder.__contains__(item)
            # You cannot set the container

        self._value = ValueAttribute()
        self._container = ContainerAttribute()

        # Compute self.n_returned_outputs.
        # This is used only when fn.need_update_inputs is False
        # because we're using one of the VM objects and it is
        # putting updates back into the input containers all by itself.
        assert len(self.maker.expanded_inputs) == len(self.input_storage)
        self.n_returned_outputs = len(self.output_storage)
        for input in self.maker.expanded_inputs:
            if input.update is not None:
                self.n_returned_outputs -= 1

        # Remember ops (e.g. Scan/OpFromGraph-like) that own an inner
        # function so free() can release their storage too.
        for node in self.maker.fgraph.apply_nodes:
            if node.op in ops_with_inner_function:
                self.nodes_with_inner_function.append(node.op)

    def __contains__(self, item):
        return self.value.__contains__(item)

    def __getitem__(self, item):
        return self.value[item]

    def __setitem__(self, item, value):
        self.value[item] = value

    def __copy__(self):
        """
        Copy a function. Copied function have separate intermediate
        storages and output storages with original function
        """
        return self.copy()

    def copy(self, share_memory=False, swap=None, delete_updates=False,
             name=None, profile=None):
        """
        Copy this function. Copied function will have separated maker and
        fgraph with original function. User can choose whether to separate
        storage by changing the share_memory arguments.

        Parameters
        ----------
        share_memory : boolean
            When True, two function share intermediate storages(storages except input and
            output storages). Otherwise two functions will only share partial
            storages and same maker. If two functions share memory and
            allow_gc=False, this will increase executing speed and save memory.

        swap : dict
            Dictionary that map old SharedVariables to new
            SharedVariables. Default is None.
            NOTE: The shared variable swap in only done in the new returned
            function, not in the user graph.

        delete_updates : boolean
            If True, Copied function will not have updates.
        name : string
            If provided, will be the name of the new
            Function. Otherwise, it will be old + " copy"

        profile :
            as theano.function profile parameter

        Returns
        -------
        theano.Function
            Copied theano.Function
        """
        # helper function
        def checkSV(sv_ori, sv_rpl):
            """
            Assert two SharedVariable follow some restrictions:
                1. same type
                2. same shape or dim?
            """
            SharedVariable = theano.tensor.sharedvar.SharedVariable
            assert isinstance(sv_ori, SharedVariable), (
                "Key of swap should be SharedVariable, given:", sv_ori,
                " type", type(sv_ori))
            assert isinstance(sv_rpl, SharedVariable), (
                "Value of swap should be SharedVariable, given:", sv_rpl,
                "type", type(sv_ori))
            assert sv_ori.type == sv_rpl.type, (
                "Type of given SharedVariable conflicts with original one",
                "Type of given SharedVariable:", sv_rpl.type,
                "Type of original SharedVariable:", sv_ori.type)

        maker = self.maker

        # Copy Ins and their storage.
        # so that they have different storage as their value
        ins = [copy.copy(input) for input in maker.inputs]

        # Delete update output in fgraph and updates In instances if needed
        if delete_updates:
            # The first len(maker.outputs) variables are original variables.
            # The rest are the updates.
            out_vars = maker.fgraph.outputs[:len(maker.outputs)]
        else:
            out_vars = maker.fgraph.outputs

        # Init new fgraph using copied variables and get memo
        # memo: a dict that map old variables to new variables
        memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
        fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
                                      [memo[o] for o in out_vars],
                                      clone=False)

        # Re initialize Outs and swap update and variable in Ins
        # By doing this, we can pass FunctionMaker._check_unused_inputs()
        outs = list(map(SymbolicOutput, fg_cpy.outputs[:len(maker.outputs)]))
        for out_ori, out_cpy in zip(maker.outputs, outs):
            out_cpy.borrow = out_ori.borrow

        # swap SharedVariable
        if swap is not None:
            exist_svs = [i.variable for i in maker.inputs]

            # Check if given ShareVariables exist
            for sv in iterkeys(swap):
                if sv not in exist_svs:
                    raise ValueError("SharedVariable: %s not found" %
                                     (sv.name))

            # Swap SharedVariable in fgraph and In instances
            for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
                # Variables in maker.inputs are defined by user, therefore we
                # use them to make comparison and do the mapping.
                # Otherwise we don't touch them.
                var = maker.inputs[index].variable

                if var in swap:
                    swap_sv = swap[var]
                    checkSV(i.variable, swap_sv)

                    # swap variable and value of In instances
                    i.variable = swap_sv
                    i.value = swap_sv.container

                    # In the fgraph we use the cloned SharedVariable
                    swap_sv = swap_sv.clone()

                    # Swap SharedVariable in fgraph
                    # if inputs was replaced, change self.inputs
                    fg_cpy.inputs[index] = swap_sv
                    fg_cpy.replace(in_v, swap_sv, reason="Swap SV")

        # Delete update if needed
        update_i = len(outs)
        for i, in_var in zip(ins, fg_cpy.inputs):
            i.variable = in_var
            if not delete_updates and i.update is not None:
                i.update = fg_cpy.outputs[update_i]
                update_i += 1
            else:
                i.update = None

        # Construct new storage_map that map new variable to old storage,
        # so that the ensuing function shares storage with the original one
        storage_map = self.fn.storage_map
        new_storage_map = {}
        # TODO: We could share the output storage, but we must make sure
        # 2 different function call won't override each other values. This
        # is already done elsewhere, so to reuse it the user would need to
        # use Out(var, borrow=True) and maybe the mutable=True flag too.
        # But to be safe for now as it isn't documented and we aren't sure
        # it is well tested, we don't share the part of the storage_map.
        if share_memory:
            i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
            for key in storage_map.keys():
                if key not in i_o_vars:
                    new_storage_map[memo[key]] = storage_map[key]

        if not name and self.name:
            name = self.name + " copy"

        input_storage = [i.value for i in ins]
        # reinitialize new maker and create new function
        if profile is None:
            profile = config.profile or config.print_global_stats
        # profile -> True or False
        if profile is True:
            if name:
                message = name
            else:
                # NOTE(review): `profile` is the bool True in this branch, so
                # `profile.message` raises AttributeError whenever `name` is
                # falsy. `self.profile.message` was likely intended — confirm
                # before relying on this path.
                message = str(profile.message) + " copy"
            profile = theano.compile.profiling.ProfileStats(message=message)
        # profile -> object
        elif type(profile) == str:
            profile = theano.compile.profiling.ProfileStats(message=profile)

        f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
                                mode=maker.mode, profile=profile,
                                # When removing updates containing variables
                                # not used in the output function, copy
                                # generates an unused implicit input.
                                # We ignore the resulting errors,
                                # but could change it to 'warn' if this might
                                # cause problems.
                                on_unused_input='ignore',
                                function_builder=maker.function_builder,
                                # As this is an optimized graph, it
                                # can contain inplace. DebugMode check
                                # that.
                                accept_inplace=True,
                                ).create(input_storage,
                                         storage_map=new_storage_map)

        for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
                                            self.input_storage,
                                            f_cpy.input_storage):

            # Share immutable ShareVariable and constant input's storage
            swapped = swap is not None and in_ori.variable in swap

            # Using the original storage if SharedVariable will not be updated
            # and is not swapped
            if not in_ori.mutable and not swapped:
                cpy.data = ori.data
                in_cpy.value = in_ori.value

            # Reconstruct Function.finder which map Variable defined by user
            # to container, to make Function.value and Function.data work well.
729 # Replace variable in new maker.inputs by the original ones. 730 # So that user can swap SharedVariable in a swapped function 731 container = f_cpy.finder.pop(in_cpy.variable) 732 if not swapped: 733 f_cpy.finder[in_ori.variable] = container 734 in_cpy.vairable = in_ori.variable 735 else: 736 f_cpy.finder[swap[in_ori.variable]] = container 737 in_cpy.variable = swap[in_ori.variable] 738 739 f_cpy.name = name 740 f_cpy.maker.fgraph.name = name 741 return f_cpy 742 743 def __call__(self, *args, **kwargs): 744 """ 745 Evaluates value of a function on given arguments. 746 747 Parameters 748 ---------- 749 args : list 750 List of inputs to the function. All inputs are required, even when 751 some of them are not necessary to calculate requested subset of 752 outputs. 753 754 kwargs : dict 755 The function inputs can be passed as keyword argument. For this, use 756 the name of the input or the input instance as the key. 757 758 Keyword argument ``output_subset`` is a list of either indices of the 759 function's outputs or the keys belonging to the `output_keys` dict 760 and represent outputs that are requested to be calculated. Regardless 761 of the presence of ``output_subset``, the updates are always calculated 762 and processed. To disable the updates, you should use the ``copy`` 763 method with ``delete_updates=True``. 764 765 Returns 766 ------- 767 list 768 List of outputs on indices/keys from ``output_subset`` or all of them, 769 if ``output_subset`` is not passed. 
770 """ 771 def restore_defaults(): 772 for i, (required, refeed, value) in enumerate(self.defaults): 773 if refeed: 774 if isinstance(value, gof.Container): 775 value = value.storage[0] 776 self[i] = value 777 profile = self.profile 778 t0 = time.time() 779 780 output_subset = kwargs.pop('output_subset', None) 781 if output_subset is not None and self.output_keys is not None: 782 output_subset =\ 783 [self.output_keys.index(key) for key in output_subset] 784 785 # Reinitialize each container's 'provided' counter 786 if self.trust_input: 787 i = 0 788 for arg in args: 789 s = self.input_storage[i] 790 s.storage[0] = arg 791 i += 1 792 else: 793 for c in self.input_storage: 794 c.provided = 0 795 796 if len(args) + len(kwargs) > len(self.input_storage): 797 raise TypeError("Too many parameter passed to theano function") 798 799 # Set positional arguments 800 i = 0 801 for arg in args: 802 # TODO: provide a Param option for skipping the filter if we 803 # really want speed. 804 s = self.input_storage[i] 805 # see this emails for a discuation about None as input 806 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 807 if arg is None: 808 s.storage[0] = arg 809 else: 810 try: 811 s.storage[0] = s.type.filter( 812 arg, strict=s.strict, 813 allow_downcast=s.allow_downcast) 814 815 except Exception as e: 816 function_name = "theano function" 817 argument_name = "argument" 818 if self.name: 819 function_name += ' with name "' + self.name + '"' 820 if hasattr(arg, 'name') and arg.name: 821 argument_name += ' with name "' + arg.name + '"' 822 where = theano.gof.utils.get_variable_trace_string( 823 self.maker.inputs[i].variable) 824 if len(e.args) == 1: 825 e.args = ("Bad input " + argument_name + " to " + 826 function_name + " at index %d (0-based). %s" 827 % (i, where) + e.args[0],) 828 else: 829 e.args = ("Bad input " + argument_name + " to " + 830 function_name + " at index %d (0-based). 
%s" 831 % (i, where),) + e.args 832 restore_defaults() 833 raise 834 s.provided += 1 835 i += 1 836 837 # Set keyword arguments 838 if kwargs: # for speed, skip the iteritems for empty kwargs 839 for k, arg in iteritems(kwargs): 840 self[k] = arg 841 842 if (not self.trust_input and 843 # The getattr is only needed for old pickle 844 getattr(self, '_check_for_aliased_inputs', True)): 845 # Collect aliased inputs among the storage space 846 args_share_memory = [] 847 for i in xrange(len(self.input_storage)): 848 i_var = self.maker.inputs[i].variable 849 i_val = self.input_storage[i].storage[0] 850 if hasattr(i_var.type, 'may_share_memory'): 851 is_aliased = False 852 for j in xrange(len(args_share_memory)): 853 854 group_j = izip( 855 [self.maker.inputs[k].variable for k 856 in args_share_memory[j]], 857 [self.input_storage[k].storage[0] for k 858 in args_share_memory[j]]) 859 if any([(var.type is i_var.type and 860 var.type.may_share_memory(val, i_val)) 861 for (var, val) in group_j]): 862 863 is_aliased = True 864 args_share_memory[j].append(i) 865 break 866 867 if not is_aliased: 868 args_share_memory.append([i]) 869 870 # Check for groups of more than one argument that share memory 871 for group in args_share_memory: 872 if len(group) > 1: 873 # copy all but the first 874 for j in group[1:]: 875 self.input_storage[j].storage[0] = copy.copy( 876 self.input_storage[j].storage[0]) 877 878 # Check if inputs are missing, or if inputs were set more than once, or 879 # if we tried to provide inputs that are supposed to be implicit. 
880 if not self.trust_input: 881 for c in self.input_storage: 882 if c.required and not c.provided: 883 restore_defaults() 884 raise TypeError("Missing required input: %s" % 885 getattr(self.inv_finder[c], 'variable', 886 self.inv_finder[c])) 887 if c.provided > 1: 888 restore_defaults() 889 raise TypeError("Multiple values for input: %s" % 890 getattr(self.inv_finder[c], 'variable', 891 self.inv_finder[c])) 892 if c.implicit and c.provided > 0: 893 restore_defaults() 894 raise TypeError( 895 'Tried to provide value for implicit input: %s' 896 % getattr(self.inv_finder[c], 'variable', 897 self.inv_finder[c])) 898 899 # Do the actual work 900 t0_fn = time.time() 901 try: 902 outputs =\ 903 self.fn() if output_subset is None else\ 904 self.fn(output_subset=output_subset) 905 except Exception: 906 restore_defaults() 907 if hasattr(self.fn, 'position_of_error'): 908 # this is a new vm-provided function or c linker 909 # they need this because the exception manipulation 910 # done by raise_with_op is not implemented in C. 911 thunk = None 912 if hasattr(self.fn, 'thunks'): 913 thunk = self.fn.thunks[self.fn.position_of_error] 914 gof.link.raise_with_op( 915 node=self.fn.nodes[self.fn.position_of_error], 916 thunk=thunk, 917 storage_map=getattr(self.fn, 'storage_map', None)) 918 else: 919 # old-style linkers raise their own exceptions 920 raise 921 922 dt_fn = time.time() - t0_fn 923 self.maker.mode.fn_time += dt_fn 924 if profile: 925 profile.vm_call_time += dt_fn 926 927 # Retrieve the values that were computed 928 if outputs is None: 929 outputs = [x.data for x in self.output_storage] 930 assert len(outputs) == len(self.output_storage) 931 932 # Remove internal references to required inputs. 933 # These cannot be re-used anyway. 
934 for c in self.input_storage: 935 if c.required: 936 c.storage[0] = None 937 938 # if we are allowing garbage collection, remove the 939 # output reference from the internal storage cells 940 if getattr(self.fn, 'allow_gc', False): 941 assert len(self.output_storage) == len(self.maker.fgraph.outputs) 942 for o_container, o_variable in zip(self.output_storage, 943 self.maker.fgraph.outputs): 944 if o_variable.owner is not None: 945 # this node is the variable of computation 946 # WARNING: This circumvents the 'readonly' attribute in x 947 o_container.storage[0] = None 948 949 if getattr(self.fn, 'need_update_inputs', True): 950 # Update the inputs that have an update function 951 for input, storage in reversed(list(zip(self.maker.expanded_inputs, 952 self.input_storage))): 953 if input.update is not None: 954 storage.data = outputs.pop() 955 else: 956 outputs = outputs[:self.n_returned_outputs] 957 958 # Put default values back in the storage 959 restore_defaults() 960 # 961 # NOTE: This logic needs to be replicated in 962 # scan. 
963 # grep for 'PROFILE_CODE' 964 # 965 966 dt_call = time.time() - t0 967 theano.compile.profiling.total_fct_exec_time += dt_call 968 self.maker.mode.call_time += dt_call 969 if profile: 970 profile.fct_callcount += 1 971 profile.fct_call_time += dt_call 972 if hasattr(self.fn, 'update_profile'): 973 self.fn.update_profile(profile) 974 if profile.ignore_first_call: 975 profile.reset() 976 profile.ignore_first_call = False 977 if self.return_none: 978 return None 979 elif self.unpack_single and len(outputs) == 1 and\ 980 output_subset is None: 981 return outputs[0] 982 else: 983 984 if self.output_keys is not None: 985 986 assert len(self.output_keys) == len(outputs) 987 988 if output_subset is None: 989 return dict(izip(self.output_keys, outputs)) 990 else: 991 return dict((self.output_keys[index], outputs[index]) 992 for index in output_subset) 993 994 if output_subset is None: 995 return outputs 996 else: 997 return [outputs[i] for i in output_subset] 998 999 value = property( 1000 lambda self: self._value, 1001 None, # this property itself is not settable 1002 doc="dictionary-like access to the values associated with Variables") 1003 container = property( 1004 lambda self: self._container, 1005 None, # this property itself is not settable 1006 doc=("dictionary-like access to the containers associated with " 1007 "Variables")) 1008 1009 def free(self): 1010 """ 1011 When allow_gc = False, clear the Variables in storage_map 1012 """ 1013 # 1.no allow_gc return False 1014 # 2.has allow_gc, if allow_gc is False, return True 1015 if not getattr(self.fn, 'allow_gc', True): 1016 for key in self.fn.storage_map: 1017 if not isinstance(key, theano.gof.Constant): 1018 self.fn.storage_map[key][0] = None 1019 1020 for node in self.nodes_with_inner_function: 1021 ops_with_inner_function[node.op].free() 1022 1023 def get_shared(self): 1024 """ 1025 Return the shared variable read or updated by by this function. 
        """
        return [i.variable for i in self.maker.inputs if i.implicit]

    def sync_shared(self):
        # Wait for pending GPU transfers on updated shared-variable storage
        # (pygpu arrays) before returning control to the caller.
        if (hasattr(theano, "gpuarray") and
                theano.gpuarray.pygpu_activated):
            import pygpu
            for i in self.maker.fgraph.update_mapping.values():
                inp = self.input_storage[i]
                if isinstance(inp.data, pygpu.gpuarray.GpuArray):
                    inp.data.sync()


# pickling/deepcopy support for Function
def _pickle_Function(f):
    """
    Reduce helper registered with copyreg: pickle a Function as its
    maker plus the per-input storage containers and their data.
    """
    # copy of the input storage list
    ins = list(f.input_storage)
    input_storage = []

    # Keep one storage container per top-level input declared in f.indices.
    for (input, indices, inputs), (required, refeed, default) in \
            zip(f.indices, f.defaults):
        input_storage.append(ins[0])
        del ins[0]

    inputs_data = [x.data for x in f.input_storage]

    # HACK to detect aliased storage.
    # This is here because aliased relationships are not [currently]
    # preserved across the pickle operation
    if not (f.pickle_aliased_memory_strategy == 'ignore'):
        all_data = input_storage + inputs_data
        for i, d_i in enumerate(all_data):
            for j, d_j in enumerate(all_data):
                if ((i < j) and isinstance(d_i, np.ndarray) and
                        isinstance(d_j, np.ndarray)):
                    if np.may_share_memory(d_i, d_j):
                        if f.pickle_aliased_memory_strategy == 'warn':
                            _logger.warning('aliased relationship between '
                                            'Function arguments %s, %s '
                                            'will not be preserved by '
                                            'un-pickling operation' %
                                            (str(d_i), str(d_j)))
                        else:
                            raise AliasedMemoryError(d_i, d_j)
    # The user can override trust_input. Our doc tells that. We should
    # not do that anymore and make sure the Maker has all the
    # information needed.
    rval = (_constructor_Function,
            (f.maker, input_storage, inputs_data, f.trust_input))
    return rval


def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
    """
    Rebuild a Function from its pickled pieces (inverse of
    _pickle_Function). Returns None when unpickling of functions is
    disabled by the theano.config.unpickle_function flag.
    """
    if not theano.config.unpickle_function:
        return None

    f = maker.create(input_storage, trustme=True)
    assert len(f.input_storage) == len(inputs_data)
    for container, x in zip(f.input_storage, inputs_data):
        assert (container.data is x) or \
            (isinstance(x, np.ndarray) and (container.data == x).all()) or \
            (container.data == x)
    f.trust_input = trust_input
    return f

copyreg.pickle(Function, _pickle_Function)


###
# FunctionMaker
###
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
    """
    Insert deepcopy in the fgraph to break aliasing of outputs.
    """
    # This loop was inserted to remove aliasing between outputs when
    # they all evaluate to the same value. Originally it was OK for
    # outputs to be aliased, but some of the outputs can be shared
    # variables, and is not good for shared variables to be
    # aliased. It might be possible to optimize this by making sure
    # there is no aliasing only between shared variables.

    # If some outputs are constant, we add deep copy to respect the
    # memory contract

    # We don't insert deep copy when the output.borrow is True for all
    # concerned outputs.

    assert len(wrapped_inputs) == len(fgraph.inputs)
    assert len(wrapped_outputs) == len(fgraph.outputs)
    reason = "insert_deepcopy"
    # fgraph inputs that have an associated update: aliasing with these is
    # acceptable (case (b) below).
    updated_fgraph_inputs = set([fgraph_i for i, fgraph_i in
                                 zip(wrapped_inputs, fgraph.inputs)
                                 if getattr(i, 'update', False)])

    # We can't use fgraph.inputs as this doesn't include Constant Values.
    all_graph_inputs = gof.graph.inputs(fgraph.outputs)
    has_destroyers_attr = hasattr(fgraph, 'has_destroyers')

    for i in xrange(len(fgraph.outputs)):
        # All variables that may share storage with output i.
        views_of_output_i = set()
        view_tree_set(alias_root(fgraph.outputs[i]), views_of_output_i)
        copied = False
        # do not allow outputs to be aliased
        for j in xrange(i + 1, len(fgraph.outputs)):
            # We could skip the deep copy if both outputs have borrow==True
            # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
            if fgraph.outputs[j] in views_of_output_i:
                if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
                    fgraph.change_input('output', i,
                                        view_op(fgraph.outputs[i]),
                                        reason=reason)
                else:
                    fgraph.change_input('output', i,
                                        deep_copy_op(fgraph.outputs[i]),
                                        reason=reason)
                copied = True
                break

        if not copied:
            for input_j in all_graph_inputs:
                # do not allow outputs to be aliased to an input (j), unless
                # a) that j'th input has been 'destroyed' by
                #    e.g. in-place computations
                # b) that j'th input is a shared variable that is also
                #    being updated
                if input_j in updated_fgraph_inputs:
                    continue
                if input_j in views_of_output_i and not (has_destroyers_attr and fgraph.has_destroyers([input_j])):
                    # We don't put deep_copy_op if the input and the
                    # output have borrow==True
                    if input_j in fgraph.inputs:
                        j = fgraph.inputs.index(input_j)
                        if (wrapped_outputs[i].borrow and
                                wrapped_inputs[j].borrow):
                            fgraph.change_input('output', i,
                                                view_op(fgraph.outputs[i]),
                                                reason="insert_deepcopy")
                            break
                        else:
                            fgraph.change_input(
                                'output', i,
                                deep_copy_op(fgraph.outputs[i]),
                                reason="insert_deepcopy")
                            break
                    elif wrapped_outputs[i].borrow:
                        fgraph.change_input('output', i,
                                            view_op(fgraph.outputs[i]),
                                            reason="insert_deepcopy")
                        break
                    else:
                        fgraph.change_input('output', i,
                                            deep_copy_op(fgraph.outputs[i]),
                                            reason="insert_deepcopy")
                        break

# Sentinel used by Function to mark inputs that have no default value.
NODEFAULT = ['NODEFAULT']


class FunctionMaker(object):
    """
    `FunctionMaker` is the class to `create` `Function` instances.

    This class has the fgraph, the optimizer, and the linker. When
    copying a `Function`, there is no need to duplicate the
    `FunctionMaker` instance. Deepcopy still copies both, which can
    result in re-compilation.

    Parameters
    ----------
    inputs : list of SymbolicInput instances
    outputs : list of SymbolicOutput instances
        Outputs may also be a single Variable (not a list), in which case the
        functions produced by FunctionMaker will return their output value
        directly.
    mode : Mode instance
        Telling FunctionMaker how to optimize and link. None means to use the
        `config.mode`.
    accept_inplace : bool
        True iff it is acceptable to have inplace operations in the graph from
        the inputs to the outputs.
    on_unused_input : {'raise', 'warn', 'ignore', None}
        What to do if a variable in the 'inputs' list is not used in the graph.
        Possible values are:
        - 'raise': raise an error
        - 'warn': log a warning
        - 'ignore': do not do anything
        - None: Use the value in the Theano flags on_unused_input.
    name : str
        An optional name for this function. If used, the profile mode will
        print the time spent in this function.

    """

    @staticmethod
    def wrap_in(input):
        """Normalize one input specification to a SymbolicInput."""
        if isinstance(input, (SymbolicInput)):
            return input
        elif isinstance(input, gof.Variable):
            # r -> SymbolicInput(variable=r)
            return SymbolicInput(input)
        elif isinstance(input, (list, tuple)):
            # (r, u) -> SymbolicInput(variable=r, update=u)
            if len(input) == 2:
                return SymbolicInput(input[0], update=input[1])
            else:
                raise TypeError("Expected two elements in the list or tuple.",
                                input)
        else:
            raise TypeError("Unknown input type: %s (%s), expected Variable "
                            "instance", type(input), input)

    @staticmethod
    def expand_in(sinput, rinputs):
        # For SymbolicInputKits, this extracts a list of SymbolicInput
        # instances and corresponding indices such that these
        # SymbolicInputs are representative of some of the Variable
        # instances in inputs. For SymbolicInput, this returns None
        # as the list of indices and a list with just the
        # SymbolicInput.
        # if isinstance(sinput, SymbolicInputKit):
        #     return sinput.complete(rinputs)
        # elif isinstance(sinput, SymbolicInput):
        if isinstance(sinput, SymbolicInput):
            return [None, [sinput]]

    @staticmethod
    def wrap_out(output):
        """Normalize one output specification to a SymbolicOutput."""
        if isinstance(output, SymbolicOutput):
            return output
        elif isinstance(output, gof.Variable):
            return SymbolicOutput(output)
        else:
            raise TypeError("Unknown output type: %s (%s)", type(output),
                            output)

    def optimize_graph_with_cache(self, optimizer, inputs, outputs):
        """
        Optimize self.fgraph, reusing a previously optimized graph from an
        on-disk cache (compiledir/optimized_graphs.pkl) when an equivalent
        graph is found there.
        """
        # This function is not finished
        from theano.gof.compilelock import get_lock, release_lock
        import os.path

        graph_db_file = os.path.join(theano.config.compiledir,
                                     'optimized_graphs.pkl')

        # the inputs, outputs, and size of the graph to be optimized
        inputs_new = [inp.variable for inp in inputs]
        outputs_new = [out.variable for out in outputs]
        size_new = len(self.fgraph.apply_nodes)
        get_lock()
        # Beginning of cache optimizations.
        # Could be refactored in different functions.

        def load_graph_db():
            # Load (creating if needed) the on-disk dict mapping
            # unoptimized graphs to their optimized counterparts.
            if os.path.isfile(graph_db_file):
                print('graph_db already exists')
            else:
                # create graph_db
                with open(graph_db_file, 'wb') as f:
                    print('create new graph_db in %s' % graph_db_file)
            # load the graph_db dictionary
            # NOTE(review): `tmp` is only bound inside the try; if open()
            # itself raises, the finally clause would hit a NameError.
            try:
                with open(graph_db_file, 'rb') as f:
                    # Temporary hack to allow
                    # theano.scan_module.tests.test_scan.T_Scan to
                    # finish. Should be changed in definitive version.
                    tmp = theano.config.unpickle_function
                    theano.config.unpickle_function = False
                    graph_db = pickle.load(f)
                print('graph_db loaded and it is not empty')
            except EOFError as e:
                # the file has nothing in it
                print(e)
                print('graph_db loaded and it is empty')
                graph_db = {}
            finally:
                theano.config.unpickle_function = tmp

            return graph_db

        def find_same_graph_in_db(graph_db):
            # If found_graph_in_db is None, then need to optimize.
            # Otherwise, return the graph found.
            found_graph_in_db = None
            # The sole purpose of this loop is to set 'need_optimize' by
            # going through graph_db, looking for graph that has the same
            # computation performed.
            for graph_old, graph_optimized in iteritems(graph_db):
                inputs_old = graph_old.inputs
                outputs_old = graph_old.outputs
                size_old = len(graph_old.apply_nodes)
                # Some heuristics to check if the same graphs have
                # already been optimized before.
                if len(inputs_new) != len(inputs_old):
                    # If the inputs are of different size,
                    # two graphs are for sure different
                    print('need to optimize, because input size is different')
                    continue
                elif len(outputs_new) != len(outputs_old):
                    # If the outputs are of different size,
                    # two graphs are for sure different
                    print('need to optimize, because output size is different')
                    continue
                elif not all(input_new.type == input_old.type
                             for input_new, input_old in
                             zip(inputs_new, inputs_old)):
                    print('need to optimize, because inputs are of different '
                          'types')
                    continue
                elif not all(output_new.type == output_old.type
                             for output_new, output_old in
                             zip(outputs_new, outputs_old)):
                    print('need to optimize, because outputs are of different '
                          'types')
                    continue
                elif not size_old == size_new:
                    print('need to optimize, because numbers of nodes in graph'
                          ' are different')
                    continue
                else:
                    flags = []
                    for i, (output_new, output_old) in enumerate(
                            zip(outputs_new, outputs_old)):
                        print('loop through outputs node for both graphs')
                        graph_old.variables = set(gof.graph.variables(
                            graph_old.inputs, graph_old.outputs))

                        # using clone allowed to avoid a lot of errors
                        # deep copy seemed to have.
                        f2 = graph_old.clone(check_integrity=False)
                        t1 = output_new
                        t2 = f2.outputs[i]

                        # Used to remove the "already used by another graph"
                        # error
                        def removeAllFgraph(remove):
                            if hasattr(remove, 'fgraph'):
                                del remove.fgraph
                            if hasattr(remove, 'owner'):
                                if remove.owner is None:
                                    pass
                                else:
                                    if hasattr(remove.owner, 'fgraph'):
                                        del remove.owner.fgraph
                                    if hasattr(remove.owner, 'inputs'):
                                        remove.owner.inputs = [removeAllFgraph(
                                            i) for i in remove.owner.inputs]
                                        for o in remove.owner.outputs:
                                            if hasattr(o, 'fgraph'):
                                                del o.fgraph
                            return remove

                        t2 = removeAllFgraph(t2)

                        givens = dict(izip(gof.graph.inputs([t1]),
                                           gof.graph.inputs([t2])))

                        temp = dict(izip(gof.graph.inputs([t1]),
                                         gof.graph.inputs([t2])))

                        # hack to remove inconsistent entries in givens;
                        # it seems to work but the source of the
                        # inconsistency could be worth investigating.
                        for key, value in iteritems(temp):
                            if key.type != value.type:
                                del givens[key]

                        flag = is_same_graph(t1, t2, givens=givens)

                        flags.append(flag)

                    is_same = all(flags)
                    if is_same:
                        # found the match
                        print('found a match, no need to optimize')
                        found_graph_in_db = graph_optimized
                        break
            return found_graph_in_db

        graph_db = load_graph_db()
        print('loaded graph_db from %s, size=%d' % (graph_db_file,
                                                    len(graph_db)))
        found_graph = find_same_graph_in_db(graph_db)
        if found_graph:
            self.fgraph = found_graph
            optimizer_profile = None
        else:
            # this is a brand new graph, optimize it, save it to graph_db
            print('graph not found in graph_db, optimizing the graph')
            self.fgraph.variables = set(gof.graph.variables(
                self.fgraph.inputs, self.fgraph.outputs))
            # check_integrity parameters was added to ignore
            # "excess cached variables" errors. Works that way
            # but once again the error could be worth
            # investigating.
            before_opt = self.fgraph.clone(check_integrity=False)
            optimizer_profile = optimizer(self.fgraph)
            graph_db.update({before_opt: self.fgraph})
            with open(graph_db_file, 'wb') as f:
                pickle.dump(graph_db, f, -1)
                print('new graph saved into graph_db')
        # NOTE(review): release_lock() is not in a try/finally — an
        # exception above would leave the compile lock held.
        release_lock()
        return optimizer_profile

    def __init__(self, inputs, outputs,
                 mode=None, accept_inplace=False, function_builder=Function,
                 profile=None, on_unused_input=None, fgraph=None,
                 output_keys=None, name=None):
        # Save the provided mode, not the instantiated mode.
        # The instantiated mode doesn't pickle and if we unpickle a Theano
        # function and it gets re-compiled, we want the current optimizer to
        # be used, not the optimizer when it was saved.
        self.mode = mode
        mode = theano.compile.mode.get_mode(mode)

        # Assert old way of working isn't used
        if getattr(mode, 'profile', None):
            raise TypeError(
                "profile passed via 'mode'. This isn't supported anymore")
        self.profile = profile
        if profile:
            # This is very important:
            # 1) We preload the cache here to avoid having its timing
            #    included in the optimization that compiles functions.
            # 2) Do not refresh the cache here by default. It causes
            #    too much execution time during testing as we compile
            #    many more functions than the number of compiled c
            #    modules.
            theano.gof.cc.get_module_cache().refresh()
        # Handle the case where inputs and/or outputs is a single
        # Variable (not in a list)
        unpack_single = False
        return_none = False
        if outputs is None:
            return_none = True
            outputs = []
        if not isinstance(outputs, (list, tuple)):
            unpack_single = True
            outputs = [outputs]
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Wrap them in In or Out instances if needed.
        inputs = [self.wrap_in(i) for i in inputs]
        outputs = [self.wrap_out(o) for o in outputs]
        _inputs = gof.graph.inputs([o.variable for o in outputs] +
                                   [i.update for i in inputs
                                    if getattr(i, 'update', False)])

        # Check if some input variables are unused
        self._check_unused_inputs(inputs, outputs, on_unused_input)

        # Make a list of (SymbolicInput|SymblicInputKits, indices,
        # [SymbolicInput,...]), one tuple for each input. (See
        # Function.indices for more details)
        indices = [[input] + self.expand_in(input, _inputs)
                   for input in inputs]

        if fgraph is None:
            need_opt = True
            # make the fgraph (copies the graph, creates NEW INPUT AND
            # OUTPUT VARIABLES)
            fgraph, additional_outputs = std_fgraph(inputs, outputs,
                                                    accept_inplace)
            fgraph.profile = profile
        else:
            # fgraph is already an optimized one
            need_opt = False
            updates = [spec.update for spec in inputs if spec.update]
            additional_outputs = list(map(SymbolicOutput, updates))

        self.fgraph = fgraph

        # Fetch the optimizer and linker
        optimizer, linker = mode.optimizer, copy.copy(mode.linker)
        if need_opt:
            compute_test_value_orig = theano.config.compute_test_value
            limit_orig = theano.config.traceback.limit
            # Why we add stack on node when it get done in output var?
            try:
                # optimize the fgraph
                theano.config.compute_test_value = \
                    theano.config.compute_test_value_opt
                theano.config.traceback.limit = theano.config.traceback.compile_limit
                start_optimizer = time.time()

                # In case there is an error during optimization.
                optimizer_profile = None
                opt_time = None

                # now optimize the graph
                if theano.config.cache_optimizations:
                    optimizer_profile = self.optimize_graph_with_cache(
                        optimizer, inputs, outputs)
                else:
                    optimizer_profile = optimizer(fgraph)

                end_optimizer = time.time()
                opt_time = end_optimizer - start_optimizer
                _logger.debug('Optimizing took %f seconds', opt_time)

                # Add deep copy to respect the memory interface
                insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
            finally:
                theano.config.compute_test_value = compute_test_value_orig
                theano.config.traceback.limit = limit_orig

                # If the optimizer got interrupted
                if opt_time is None:
                    end_optimizer = time.time()
                    opt_time = end_optimizer - start_optimizer
                theano.compile.profiling.total_graph_opt_time += opt_time
                if profile:
                    if (optimizer_profile is None and
                            hasattr(optimizer, 'pre_profile')):
                        optimizer_profile = optimizer.pre_profile
                    profile.optimizer_time += opt_time
                    if theano.config.profile_optimizer:
                        profile.optimizer_profile = (optimizer,
                                                     optimizer_profile)
        # If False, it means the profile for that function was explicitly
        # disabled
        elif theano.config.profile_optimizer and profile is not False:
            warnings.warn((
                "config.profile_optimizer requires config.profile to "
                " be set to True as well"), stacklevel=3)

        # initialize the linker
        if not hasattr(linker, 'accept'):
            raise ValueError("'linker' parameter of FunctionMaker should be "
                             "a Linker with an accept method or one of %s" %
                             list(theano.compile.mode
                                  .predefined_linkers.keys()))

        # the 'no_borrow' outputs are the ones for which that we can't
        # return the internal storage pointer.
        assert len(fgraph.outputs) == len(outputs + additional_outputs)
        no_borrow = [output for output, spec in
                     zip(fgraph.outputs, outputs + additional_outputs)
                     if not spec.borrow]
        if no_borrow:
            self.linker = linker.accept(
                fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow),
                profile=profile)
        else:
            self.linker = linker.accept(fgraph, profile=profile)

        if hasattr(linker, 'accept_var_updates'):
            # hacky thing so VMLinker knows about updates
            self.linker.accept_var_updates(
                fgraph_updated_vars(fgraph, inputs))
        fgraph.name = name
        self.indices = indices
        self.inputs = inputs
        self.expanded_inputs = inputs
        self.outputs = outputs
        self.unpack_single = unpack_single
        self.return_none = return_none
        self.accept_inplace = accept_inplace
        self.function_builder = function_builder
        self.on_unused_input = on_unused_input  # Used for the pickling/copy
        self.output_keys = output_keys
        self.name = name

        # required: inputs with no default value that the caller must supply.
        # refeed: inputs whose default value is restored after each call.
        self.required = [(i.value is None) for i in self.inputs]
        self.refeed = [
            (i.value is not None and
             not isinstance(i.value, gof.Container) and
             i.update is None)
            for i in self.inputs]

    def _check_unused_inputs(self, inputs, outputs, on_unused_input):
        """
        Warn or raise (per on_unused_input) for inputs that do not
        contribute to any output or update.
        """
        if on_unused_input is None:
            on_unused_input = theano.config.on_unused_input

        if on_unused_input == 'ignore':
            return

        # There should be two categories of variables in inputs:
        #  - variables that have to be provided (used_inputs)
        #  - shared variables that will be updated
        used_inputs = gof.graph.ancestors(
            ([o.variable for o in outputs] +
             [i.update for i in inputs if getattr(i, 'update', False)]),
            blockers=[i.variable for i in inputs])

        msg = ("theano.function was asked to create a function computing "
               "outputs given certain inputs, but the provided input "
               "variable at index %i is not part of the computational graph "
               "needed to compute the outputs: %s.\n%s")
        warn_msg = ("To make this warning into an error, you can pass the "
                    "parameter on_unused_input='raise' to theano.function. "
                    "To disable it completely, use on_unused_input='ignore'.")
        err_msg = ("To make this error into a warning, you can pass the "
                   "parameter on_unused_input='warn' to theano.function. "
                   "To disable it completely, use on_unused_input='ignore'.")

        for i in inputs:
            if ((i.variable not in used_inputs) and (i.update is None)):
                if on_unused_input == 'warn':
                    warnings.warn(msg % (inputs.index(i), i.variable,
                                         warn_msg), stacklevel=6)
                elif on_unused_input == 'raise':
                    raise UnusedInputError(msg % (inputs.index(i),
                                                  i.variable, err_msg))
                else:
                    raise ValueError("Invalid value for keyword "
                                     "on_unused_input of theano.function: "
                                     "'%s'.\nValid values are 'raise', "
                                     "'warn', and 'ignore'." % on_unused_input)

    def create(self, input_storage=None, trustme=False, storage_map=None):
        """
        Create a function.

        Parameters
        ----------
        input_storage
            A list matching the inputs list and providing default values if the
            default for an input is None, then that input is a required input.
            For an input with an update, the default acts as initialization.
        trustme
            Disables some exceptions, used internally.

        """

        if input_storage is None:
            input_storage = [None] * len(self.inputs)
        # list of independent one-element lists, will be passed to the linker
        input_storage_lists = []
        defaults = []

        # The following loop is to fill in the input_storage_lists and
        # defaults lists.
        assert len(self.indices) == len(input_storage)
        for i, ((input, indices, subinputs), input_storage_i) in \
                enumerate(zip(self.indices, input_storage)):

            # Replace any default value given as a variable by its
            # container. Note that this makes sense only in the
            # context of shared variables, but for now we avoid
            # dealing directly with them to avoid dependency on the
            # shared variables work-in-progress repository.
            if isinstance(input_storage_i, gof.Variable):
                input_storage_i = input_storage_i.container

            if isinstance(input_storage_i, gof.Container):
                # If the default is a gof.Container, this means we want to
                # share the same storage. This is done by appending
                # input_storage_i.storage to input_storage_lists.
                if indices is not None:
                    raise TypeError("Cannot take a Container instance as "
                                    "default for a SymbolicInputKit.")
                input_storage_lists.append(input_storage_i.storage)

                storage = input_storage[i].storage[0]

            else:
                # Normal case: one new, independent storage unit
                input_storage_lists.append([input_storage_i])

                storage = input_storage_i

            required = self.required[i]
            refeed = self.refeed[i]
            # sanity check-- if an input is required it should not
            # need to be refed
            assert not (required and refeed)

            # shared variables need neither be input by the user nor refed
            if input.shared:
                assert not required
                assert not refeed
                storage = None

            # if an input is required, it never need be refed
            if required:
                storage = None

            # make sure that we only store a value if we actually need it
            if storage is not None:
                assert refeed or not required

            defaults.append((required, refeed, storage))

        # Get a function instance
        start_linker = time.time()
        start_import_time = theano.gof.cmodule.import_time
        limit_orig = theano.config.traceback.limit
        try:
            theano.config.traceback.limit = theano.config.traceback.compile_limit
            _fn, _i, _o = self.linker.make_thunk(
                input_storage=input_storage_lists, storage_map=storage_map)
        finally:
            theano.config.traceback.limit = limit_orig

        end_linker = time.time()

        linker_time = end_linker - start_linker
        theano.compile.profiling.total_time_linker += linker_time
        _logger.debug('Linker took %f seconds', linker_time)
        if self.profile:
            self.profile.linker_time += linker_time
            _fn.time_thunks = self.profile.flag_time_thunks
            import_time = theano.gof.cmodule.import_time - start_import_time
            self.profile.import_time += import_time

        fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
                                   defaults, self.unpack_single,
                                   self.return_none, self.output_keys, self,
                                   name=self.name)

        fn.profile = self.profile
        return fn


def _constructor_FunctionMaker(kwargs):
    # Needed for old pickle
    # Old pickles have at least the problem that output_keys were not saved.
    if theano.config.unpickle_function:
        if theano.config.reoptimize_unpickled_function:
            del kwargs['fgraph']
        return FunctionMaker(**kwargs)
    else:
        return None

# Registered equality checkers tried (most recently registered first) before
# falling back to ==.
__checkers = []


def check_equal(x, y):
    """Compare x and y using the registered checkers, falling back to ==."""
    for checker in __checkers:
        try:
            return checker(x, y)
        except Exception:
            continue
    return x == y


def register_checker(checker):
    """Register an equality checker; most recently registered wins."""
    __checkers.insert(0, checker)


def orig_function(inputs, outputs, mode=None, accept_inplace=False,
                  name=None, profile=None, on_unused_input=None,
                  output_keys=None):
    """
    Return a Function that will calculate the outputs from the inputs.

    Parameters
    ----------
    inputs : list of `SymbolicInput` or `In` instances
    outputs : a SymbolicOutput or a list of `SymbolicOutput` or `Out` instances
        The return value of the returned function will match the format of this
        argument (either the value itself or a list of one or more return
        values).
    mode : descriptive string or Mode instance
        Default of None means to use `config.mode` (see below for descriptive
        string list).
    name : str
        An optional name for this function. If used, the profile mode will
        print the time spent in this function.
    accept_inplace : bool
        True iff the graph can contain inplace operations prior to the
        optimization phase (default is False).
    profile : None or ProfileStats instance
    on_unused_input : {'raise', 'warn', 'ignore', None}
        What to do if a variable in the 'inputs' list is not used in the graph.
    output_keys :
        If the outputs were provided to theano.function as a list, then
        output_keys is None. Otherwise, if outputs were provided as a dict,
        output_keys is the sorted list of keys from the outputs.

    Notes
    -----
    Currently, the library provides the following mode strings:

    - FAST_RUN (default) (optimize without too much time)

    - FAST_COMPILE (minimal optimization)

    - DebugMode: verify many internal conditions that are normally assumed
      (slow)

    """

    # Every element of the input list will be upgraded to an `In` instance if
    # necessary, using the rules implemented by the `convert_function_input`
    # function.

    # Similarly, every element of the output list will be upgraded to an `Out`
    # instance if necessary:

    t1 = time.time()
    mode = theano.compile.mode.get_mode(mode)

    inputs = list(map(convert_function_input, inputs))
    if outputs is not None:
        if isinstance(outputs, (list, tuple)):
            outputs = list(map(FunctionMaker.wrap_out, outputs))
        else:
            outputs = FunctionMaker.wrap_out(outputs)

    defaults = [getattr(input, 'value', None) for input in inputs]

    if isinstance(mode, (list, tuple)):  # "mode comparison" semantics
        raise Exception("We do not support the passing of multiple modes")
    fn = None
    try:
        Maker = getattr(mode, 'function_maker', FunctionMaker)
        m = Maker(inputs,
                  outputs,
                  mode,
                  accept_inplace=accept_inplace,
                  profile=profile,
                  on_unused_input=on_unused_input,
                  output_keys=output_keys,
                  name=name)
        with theano.change_flags(compute_test_value="off"):
            fn = m.create(defaults)
    finally:
        # Account for compilation time even when Maker/create raised.
        t2 = time.time()
        if fn and profile:
            profile.compile_time += t2 - t1
            # TODO: append
            profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)

    return fn


def convert_function_input(input):
    """
    Upgrade an input shortcut to an In instance.

    The rules for upgrading are as follows:

    - a `Variable` instance r will be upgraded like `In`(r)

    - a tuple (name, r) will be `In`(r, name=name)

    - a tuple (r, val) will be `In`(r, value=value, autoname=True)

    - a tuple ((r,up), val) will be
      `In`(r, value=value, update=up, autoname=True)

    - a tuple (name, r, val) will be `In`(r, name=name, value=value)

    - a tuple (name, (r,up), val) will be
      `In`(r, name=name, value=val, update=up, autoname=True)

    """
    if isinstance(input, SymbolicInput):
        return input
    elif isinstance(input, gof.Constant):
        raise TypeError('A Constant instance is not a legal function input',
                        input)
    elif isinstance(input, gof.Variable):
        return In(input)
    elif isinstance(input, (list, tuple)):
        orig = input
        if not input:
            raise TypeError("Nonsensical input specification: %s" % input)
        # An optional leading string is the input's name.
        if isinstance(input[0], string_types):
            name = input[0]
            input = input[1:]
        else:
            name = None
        if isinstance(input[0], (list, tuple)):
            # ((r, up), val) form
            if len(input[0]) != 2 or len(input) != 2:
                raise TypeError("Invalid input syntax: %s (check "
                                "documentation or use an In instance)" % orig)
            (variable, update), value = input
        elif isinstance(input[0], gof.Variable):
            # (r,) or (r, val) form
            if len(input) == 1:
                variable, update, value = input[0], None, None
            elif len(input) == 2:
                (variable, value), update = input, None
            else:
                raise TypeError("Invalid input syntax: %s (check "
                                "documentation or use an In instance)" % orig)
        elif isinstance(input[0], SymbolicInput):
            if len(input) == 1:
                return input[0]
            elif len(input) == 2:
                input, value = input
                if name is not None:
                    input.name = name
                input.value = value
                return input
        else:
            raise TypeError("The input specification is not valid: %s" % input)

        if not isinstance(variable, gof.Variable):
            raise TypeError("Unknown input type: %s, expected Variable "
                            "instance" % type(variable), variable)
        if update is not None and not isinstance(update, gof.Variable):
            raise TypeError("Unknown update type: %s, expected Variable "
                            "instance" % type(update), update)
        if (value is not None and
                isinstance(value, (gof.Variable, SymbolicInput))):
            raise TypeError("The value for input %s should not be a Variable "
                            "or SymbolicInput instance (got: %s)" %
                            (variable, value))

        return In(variable, name=name, value=value, update=update)
    else:
        raise TypeError("Unknown input type: %s, expected Variable instance" %
                        type(input), input)


def get_info_on_inputs(named_inputs, n_unnamed_inputs):
    """
    Return a human-readable description of named and un-named inputs.

    """
    n_named_inputs = len(named_inputs)

    def get_plural(n):
        # Suffix for pluralizing words in the generated message.
        if n > 1:
            return 's'
        else:
            return ''

    if n_named_inputs == 0:
        if n_unnamed_inputs == 0:
            msg = 'The function is supposed to have no input.'
        else:
            if n_unnamed_inputs == 1:
                msg = ("The function has a single input variable which has no "
                       "name, and thus cannot be assigned through a keyword"
                       " argument (use 'name=...' in a Variable's "
                       "constructor to give it a name).")
            else:
                # Use plural.
                msg = ("The function has %s inputs, but none of them is named,"
                       " and thus they cannot be assigned through keyword "
                       "arguments (use 'name=...' in a Variable's "
                       "constructor to give it a name)." % n_unnamed_inputs)
    else:
        if n_unnamed_inputs == 0:
            msg = ("The function has %s named input%s (%s)." %
                   (n_named_inputs, get_plural(n_named_inputs),
                    ', '.join(named_inputs)))
        else:
            msg = ("The function has %s named input%s (%s), and %s unnamed "
                   "input%s which thus cannot be accessed through keyword "
                   "argument%s (use 'name=...' in a variable's constructor "
                   "to give it a name)." %
                   (n_named_inputs, get_plural(n_named_inputs),
                    ', '.join(named_inputs), n_unnamed_inputs,
                    get_plural(n_unnamed_inputs),
                    get_plural(n_unnamed_inputs)))
    return msg