1"""
2Driver of graph construction, optimization, and linking.
3
4"""
5from __future__ import absolute_import, print_function, division
6
7import copy
8from six import string_types, iteritems, iterkeys
9from six.moves import xrange
10import six.moves.copyreg as copyreg
11import six.moves.cPickle as pickle
12from itertools import chain
13import time
14import warnings
15import numpy as np
16
17import theano
18from theano import config, gof
19from theano.compat import izip
20from theano.gof import graph
21import theano.compile.profiling
22from theano.compile.io import (
23    In, SymbolicInput, SymbolicOutput)
24from theano.compile.ops import deep_copy_op, view_op
25from theano.gof.graph import is_same_graph
26from theano.gof.op import ops_with_inner_function
27
28import logging
29_logger = logging.getLogger('theano.compile.function_module')
30
31__docformat__ = "restructuredtext en"
32
33
34class UnusedInputError(Exception):
35    """
36    A symbolic input passed to function is not needed.
37
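    For example (a sketch; by default ``on_unused_input`` is 'raise')::

        import theano
        import theano.tensor as T

        x, y = T.dscalars('x', 'y')
        theano.function([x, y], x + 1)  # raises UnusedInputError: y is unused
        theano.function([x, y], x + 1, on_unused_input='ignore')  # compiles
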
38    """
39
40    pass
41
42
43def alias_root(v):
44    """
45    Return the variable to which v is aliased by view_maps and destroy_maps.
46
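    For example (a sketch), a variable with no owner is its own root::

        >>> import theano.tensor as T
        >>> x = T.vector('x')
        >>> alias_root(x) is x
        True
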
47    """
48    if v.owner is None:
49        return v
50    vmap = getattr(v.owner.op, 'view_map', {})
51    dmap = getattr(v.owner.op, 'destroy_map', {})
52    outpos = v.owner.outputs.index(v)
53    v_views = vmap.get(outpos, []) + dmap.get(outpos, [])
54    if len(v_views) > 1:
55        raise NotImplementedError(
56            str(v) + " is a view/destroyed version of more then one inputs. "
57            "Currently, we only support the case where an output is a view or "
58            "a destroyed version of one input.")
59    elif v_views:
60        return alias_root(v.owner.inputs[v_views[0]])
61    else:
62        return v
63

def view_tree_set(v, treeset):
    """
    Add to `treeset` all variables that are views of v, given that v is
    not a view.

    """
    treeset.add(v)
    for cl, v_input_pos_to_cl in v.clients:
        if cl == 'output':
            continue
        vmap = getattr(cl.op, 'view_map', {})
        dmap = getattr(cl.op, 'destroy_map', {})
        for opos, iposlist in chain(iteritems(vmap), iteritems(dmap)):
            if v_input_pos_to_cl in iposlist:
                if cl.outputs[opos] not in treeset:
                    view_tree_set(cl.outputs[opos], treeset)


def infer_reuse_pattern(fgraph, outputs_to_disown):
    """
    Given an fgraph and a list of variables, returns the set of all
    variables which may share the same underlying data storage as any of
    the specified variables. Used internally by function and
    FunctionMaker.

    This set is also referred to as no_recycling in some places,
    especially in linker code.

    """
    rval = set()
    for o in outputs_to_disown:
        view_tree_set(alias_root(o), rval)
    # Remove from rval all of the inputs, constants, and values.
    rval = set(r for r in rval if r.owner is not None)

    return rval


def fgraph_updated_vars(fgraph, expanded_inputs):
    """
    Reconstruct the full "updates" dictionary, mapping from FunctionGraph input
    variables to the fgraph outputs that will replace their values.

    Returns
    -------
    dict variable -> variable

    """
    updated_vars = {}
    potential_values = list(fgraph.outputs)  # copy the list
    if len(expanded_inputs) != len(fgraph.inputs):
        raise ValueError('len(expanded_inputs) must match len(fgraph.inputs)')
    for e_input, ivar in reversed(list(zip(expanded_inputs, fgraph.inputs))):
        if e_input.update is not None:
            updated_vars[ivar] = potential_values.pop()
    return updated_vars


class Supervisor(object):
    """
    Listener for FunctionGraph events which makes sure that no
    operation overwrites the contents of protected Variables. The
    outputs of the FunctionGraph are protected by default.

    """

    def __init__(self, protected):
        self.protected = list(protected)

    def validate(self, fgraph):
        if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
            if fgraph.has_destroyers(self.protected):
                raise gof.InconsistencyError("Trying to destroy a protected "
                                             "Variable.")
            return True
        if not hasattr(fgraph, 'destroyers'):
            return True
        for r in self.protected + list(fgraph.outputs):
            if fgraph.destroyers(r):
                raise gof.InconsistencyError("Trying to destroy a protected "
                                             "Variable.", r)


def std_fgraph(input_specs, output_specs, accept_inplace=False):
    """
    Makes a FunctionGraph corresponding to the input specs and the output
    specs. Any SymbolicInput in the input_specs whose update field is not
    None will add an output to the FunctionGraph corresponding to that
    update. The return value is the FunctionGraph as well as a list of
    SymbolicOutput instances corresponding to the updates.

    If accept_inplace is False, the graph will be checked for inplace
    operations and an exception will be raised if it has any. If
    accept_inplace is True, a DestroyHandler will be added to the FunctionGraph
    if there are any inplace operations.

    The returned FunctionGraph is a clone of the graph between the provided
    inputs and outputs.

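    A rough usage sketch (the ``In``/``Out`` wrappers and the update
    expression are illustrative)::

        import theano.tensor as T
        from theano.compile.io import In, Out

        x = T.dscalar('x')
        s = T.dscalar('s')
        fg, updates = std_fgraph([In(x), In(s, update=s + x)], [Out(x * 2)])
        # fg has two outputs: x * 2 plus the update expression s + x
        assert len(fg.outputs) == 2 and len(updates) == 1
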
164    """
165    orig_inputs = [spec.variable for spec in input_specs]
166
167    # Extract the updates and the mapping between update outputs and
168    # the updated inputs.
169    updates = []
170    update_mapping = {}
171    out_idx = len(output_specs)
172    for inp_idx in range(len(input_specs)):
173        if input_specs[inp_idx].update:
174            updates.append(input_specs[inp_idx].update)
175            update_mapping[out_idx] = inp_idx
176            out_idx += 1
177
178    orig_outputs = [spec.variable for spec in output_specs] + updates
179
180    fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
181                                  update_mapping=update_mapping)
182
183    for node in fgraph.apply_nodes:
184        if getattr(node.op, 'destroy_map', None):
185            if not accept_inplace:
186                raise TypeError("Graph must not contain inplace operations",
187                                node, node.op)
188            else:
189                fgraph.attach_feature(gof.DestroyHandler())
190                break
191
192    # We need to protect all immutable inputs from inplace operations.
193    fgraph.attach_feature(
194        Supervisor(input
195                   for spec, input in zip(input_specs, fgraph.inputs)
196                   if not (spec.mutable or
197                           (hasattr(fgraph, 'destroyers') and
198                            fgraph.has_destroyers([input])))))
199
200    # If named nodes are replaced, keep the name
201    for feature in std_fgraph.features:
202        fgraph.attach_feature(feature())
203    return fgraph, list(map(SymbolicOutput, updates))
204
205
206std_fgraph.features = [gof.toolbox.PreserveVariableAttributes]
207
208
209class AliasedMemoryError(Exception):
210    """
211    Memory is aliased that should not be.
212
213    """
214    pass
215
216
217###
218# Function
219###
220
221# unique id object used as a placeholder for duplicate entries
222DUPLICATE = ['DUPLICATE']
223
224
225class Function(object):
226    """
227    Type of the functions returned by theano.function or
228    theano.FunctionMaker.create.
229
230    `Function` is the callable object that does computation.  It has the storage
231    of inputs and outputs, performs the packing and unpacking of inputs and
232    return values. It implements the square-bracket indexing so that you can
233    look up the value of a symbolic node.
234
235    Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}.
236    When a function is copied, this instance is duplicated. Contrast with
237    self.maker (instance of `FunctionMaker`) that is shared between copies.
238    The meaning of copying a function is that the containers and their current
239    values will all be duplicated. This requires that mutable inputs be
240    copied, whereas immutable inputs may be shared between copies.
241
242    A Function instance is hashable, on the basis of its memory
243    address (its id).
244
245    A Function instance is only equal to itself.
246
247    A Function instance may be serialized using the `pickle` or
248    `cPickle` modules.  This will save all default inputs, the graph,
249    and WRITEME to the pickle file.
250
251    A Function instance have a ``trust_input`` field that default to
252    False. When True, we don't do extra check of the input to give
253    better error message. In some case, python code will still return
254    the good results if you pass a python or numpy scalar instead of a
255    numpy tensor.  C code should raise an error if you pass an object
256    of the wrong type.
257
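    For example (a sketch; with ``trust_input`` the caller is responsible
    for passing exactly the expected types)::

        import numpy as np
        import theano
        import theano.tensor as T

        x = T.dscalar('x')
        f = theano.function([x], x * 2)
        f.trust_input = True
        f(np.float64(3.0))  # must already be exactly the expected type
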
    Attributes
    ----------
    finder
    inv_finder

    """

    pickle_aliased_memory_strategy = 'warn'
    """
    How to deal with pickling finding aliased storage.

    Meaningful settings are: 'ignore', 'warn', 'raise'.

    If the value is 'warn', then a message will be printed to stderr
    if aliased storage is detected during pickle.dump.

    If the value is 'raise', then an AliasedMemoryError will be raised
    if aliased storage is detected during pickle.dump.

    """

    input_storage = None
    """
    List of Container instances.

    """

    output_storage = None
    """
    List of Container instances.

    """

    indices = None
    """
    List of (SymbolicInput, indices, [SymbolicInput,...]),
    one tuple for each input.

    The first tuple element is the SymbolicInput object for the corresponding
    function input.

    The second and third tuple elements are used only by Kits, which
    are deprecated.

    """

    defaults = None
    """
    List of 3-tuples, one 3-tuple for each input.

    Tuple element 0: Bool: Is this input required at each function call?
    Tuple element 1: Bool: Should this input's value be reverted after
        each call?
    Tuple element 2: Any: The value associated with this input.

    """

    unpack_single = None
    """
    Bool: for output lists of length 1, should the 0'th element be
    returned directly?

    """

    return_none = None
    """
    Bool: whether the function should return None or not.

    """

    maker = None
    """
    FunctionMaker instance.

    """

    fn = None
    """
    A function that evaluates the graph. Typically created by a linker's
    make_thunk method.

    """

    finder = None
    """
    Dictionary mapping several kinds of things to containers.

    We set an entry in finder for:

    - the index of the input

    - the variable instance the input is based on

    - the name of the input

    All entries map to the container or to DUPLICATE if an ambiguity
    is detected.

    """

    inv_finder = None
    """
    Dict. Reverse lookup of `finder`.

    It maps container -> SymbolicInput

    """

    def __init__(self, fn, input_storage, output_storage, indices, outputs,
                 defaults, unpack_single, return_none, output_keys, maker,
                 name=None):
        self.fn = fn
        self.input_storage = input_storage
        self.output_storage = output_storage
        self.indices = indices
        self.outputs = outputs
        self.defaults = defaults
        self.unpack_single = unpack_single
        self.return_none = return_none
        self.maker = maker
        self.profile = None  # reassigned in FunctionMaker.create
        self.trust_input = False  # If True, we don't check the input parameters
        self.name = name
        self.nodes_with_inner_function = []
        self.output_keys = output_keys

        # See if we have any mutable / borrow inputs
        # TODO: this only needs to be set if there is more than one input
        self._check_for_aliased_inputs = False
        for i in maker.inputs:
            # If the input is a shared variable, the memory region is
            # under Theano control and so we don't need to check if it
            # is aliased as we never do that.
            if (isinstance(i, In) and not i.shared and
                (getattr(i, 'borrow', False) or
                 getattr(i, 'mutable', False))):
                self._check_for_aliased_inputs = True
                break

        # We will be popping stuff off this `containers` object.  It is a copy.
        containers = list(self.input_storage)
        finder = {}
        inv_finder = {}

        def distribute(indices, cs, value):
            input.distribute(value, indices, cs)
            for c in cs:
                c.provided += 1

        # Store the list of names of named inputs.
        named_inputs = []
        # Count the number of un-named inputs.
        n_unnamed_inputs = 0

        # Initialize the storage.
        # This loop works by modifying the elements (as variable c) of
        # self.input_storage inplace.
        for i, ((input, indices, sinputs), (required, refeed, value)) in \
                enumerate(zip(self.indices, defaults)):
            if indices is None:
                # containers is being used as a stack. Here we pop off
                # the next one.
                c = containers[0]
                c.strict = getattr(input, 'strict', False)
                c.allow_downcast = getattr(input, 'allow_downcast', None)

                if value is not None:
                    # Always initialize the storage.
                    if isinstance(value, gof.Container):
                        # There is no point in obtaining the current value
                        # stored in the container, since the container is
                        # shared.
                        # For safety, we make sure 'refeed' is False, since
                        # there is no need to refeed the default value.
                        assert not refeed
                    else:
                        c.value = value
                c.required = required
                c.implicit = input.implicit
                # this is a count of how many times the input has been
                # provided (reinitialized to 0 on __call__)
                c.provided = 0
                finder[i] = c
                finder[input.variable] = c
                if input.name not in finder:
                    finder[input.name] = c
                else:
                    finder[input.name] = DUPLICATE
                if input.name is None:
                    n_unnamed_inputs += 1
                else:
                    named_inputs.append(input.name)
                inv_finder[c] = input
                containers[:1] = []

        self.finder = finder
        self.inv_finder = inv_finder

        # this class is important in overriding the square-bracket notation:
        #     fn.value[x]
        # self reference is available via the closure on the class
        class ValueAttribute(object):
            def __getitem__(self, item):
                try:
                    s = finder[item]
                except KeyError:
                    raise TypeError("Unknown input or state: %s" % str(item))
                if s is DUPLICATE:
                    raise TypeError("Ambiguous name: %s - please check the "
                                    "names of the inputs of your function "
                                    "for duplicates." % str(item))
                if isinstance(s, gof.Container):
                    return s.value
                else:
                    raise NotImplementedError

            def __setitem__(self, item, value):
                try:
                    s = finder[item]
                except KeyError:
                    # Print informative error message.
                    msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
                    raise TypeError("Unknown input or state: %s. %s" %
                                    (str(item), msg))
                if s is DUPLICATE:
                    raise TypeError("Ambiguous name: %s - please check the "
                                    "names of the inputs of your function "
                                    "for duplicates." % str(item))
                if isinstance(s, gof.Container):
                    s.value = value
                    s.provided += 1
                else:
                    s(value)

            def __contains__(self, item):
                return finder.__contains__(item)

        # this class is important in overriding the square-bracket notation:
        #     fn.container[x]
        # self reference is available via the closure on the class
        class ContainerAttribute(object):
            def __getitem__(self, item):
                return finder[item]

            def __contains__(self, item):
                return finder.__contains__(item)
            # You cannot set the container

        self._value = ValueAttribute()
        self._container = ContainerAttribute()

        # Compute self.n_returned_outputs.
        # This is used only when fn.need_update_inputs is False
        # because we're using one of the VM objects and it is
        # putting updates back into the input containers all by itself.
        assert len(self.maker.expanded_inputs) == len(self.input_storage)
        self.n_returned_outputs = len(self.output_storage)
        for input in self.maker.expanded_inputs:
            if input.update is not None:
                self.n_returned_outputs -= 1

        for node in self.maker.fgraph.apply_nodes:
            if node.op in ops_with_inner_function:
                self.nodes_with_inner_function.append(node.op)

    def __contains__(self, item):
        return self.value.__contains__(item)

    def __getitem__(self, item):
        return self.value[item]

    def __setitem__(self, item, value):
        self.value[item] = value

    def __copy__(self):
        """
        Copy a function. The copied function has intermediate and output
        storage separate from the original function's.
        """
        return self.copy()

    def copy(self, share_memory=False, swap=None, delete_updates=False,
             name=None, profile=None):
        """
        Copy this function. The copied function will have a maker and
        fgraph separate from the original function's. The user can choose
        whether to separate storage through the share_memory argument.

        Parameters
        ----------
        share_memory : boolean
            When True, the two functions share intermediate storage
            (i.e. all storage except input and output storage). Otherwise
            the two functions only share partial storage and the same
            maker. If the two functions share memory and allow_gc=False,
            this will increase execution speed and save memory.

        swap : dict
            Dictionary that maps old SharedVariables to new
            SharedVariables. Default is None.
            NOTE: The shared variable swap is only done in the new returned
            function, not in the user graph.

        delete_updates : boolean
            If True, the copied function will not have updates.
        name : string
            If provided, will be the name of the new
            Function. Otherwise, it will be the old name + " copy".

        profile :
            as theano.function's profile parameter

        Returns
        -------
        theano.Function
            Copied theano.Function
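
        For example (a sketch), swapping a shared variable in the copy
        leaves the original function untouched::

            import numpy as np
            import theano
            import theano.tensor as T

            s = theano.shared(np.float64(1.0))
            s2 = theano.shared(np.float64(5.0))
            x = T.dscalar('x')
            f = theano.function([x], x + s)
            g = f.copy(swap={s: s2}, name='g')
            assert f(1.0) == 2.0  # still reads s
            assert g(1.0) == 6.0  # reads s2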
573        """
574        # helper function
575        def checkSV(sv_ori, sv_rpl):
576            """
577            Assert two SharedVariable follow some restirctions:
578                1. same type
579                2. same shape or dim?
580            """
581            SharedVariable = theano.tensor.sharedvar.SharedVariable
582            assert isinstance(sv_ori, SharedVariable), (
583                "Key of swap should be SharedVariable, given:", sv_ori,
584                " type", type(sv_ori))
585            assert isinstance(sv_rpl, SharedVariable), (
586                "Value of swap should be SharedVariable, given:", sv_rpl,
587                "type", type(sv_ori))
588            assert sv_ori.type == sv_rpl.type, (
589                "Type of given SharedVariable conflicts with original one",
590                "Type of given SharedVariable:", sv_rpl.type,
591                "Type of original SharedVariable:", sv_ori.type)
592
593        maker = self.maker
594
595        # Copy Ins and their storage.
596        # so that they have different storage as their value
597        ins = [copy.copy(input) for input in maker.inputs]
598
599        # Delete update output in fgraph and updates In instances if needed
600        if delete_updates:
601            # The first len(maker.outputs) variables are original variables.
602            # The rest are the updates.
603            out_vars = maker.fgraph.outputs[:len(maker.outputs)]
604        else:
605            out_vars = maker.fgraph.outputs
606
607        # Init new fgraph using copied variables and get memo
608        # memo: a dict that map old variables to new variables
609        memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
610        fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
611                                      [memo[o] for o in out_vars],
612                                      clone=False)
613
614        # Re initialize Outs and swap update and variable in Ins
615        # By doing this, we can pass FunctionMaker._check_unused_inputs()
616        outs = list(map(SymbolicOutput, fg_cpy.outputs[:len(maker.outputs)]))
617        for out_ori, out_cpy in zip(maker.outputs, outs):
618            out_cpy.borrow = out_ori.borrow
619
620        # swap SharedVariable
621        if swap is not None:
622            exist_svs = [i.variable for i in maker.inputs]
623
624            # Check if given ShareVariables exist
625            for sv in iterkeys(swap):
626                if sv not in exist_svs:
627                    raise ValueError("SharedVariable: %s not found" %
628                                     (sv.name))
629
630            # Swap SharedVariable in fgraph and In instances
631            for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
632                # Variables in maker.inputs are defined by user, therefore we
633                # use them to make comparison and do the mapping.
634                # Otherwise we don't touch them.
635                var = maker.inputs[index].variable
636
637                if var in swap:
638                    swap_sv = swap[var]
639                    checkSV(i.variable, swap_sv)
640
641                    # swap variable and value of In instances
642                    i.variable = swap_sv
643                    i.value = swap_sv.container
644
645                    # In the fgraph we use the cloned SharedVariable
646                    swap_sv = swap_sv.clone()
647
648                    # Swap SharedVariable in fgraph
649                    # if inputs was replaced, change self.inputs
650                    fg_cpy.inputs[index] = swap_sv
651                    fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
652
653        # Delete update if needed
654        update_i = len(outs)
655        for i, in_var in zip(ins, fg_cpy.inputs):
656            i.variable = in_var
657            if not delete_updates and i.update is not None:
658                i.update = fg_cpy.outputs[update_i]
659                update_i += 1
660            else:
661                i.update = None
662
663        # Construct new storage_map that map new variable to old storage,
664        # so that the ensuing function shares storage with the original one
665        storage_map = self.fn.storage_map
666        new_storage_map = {}
667        # TODO: We could share the output storage, but we must make sure
668        # 2 different function call won't override each other values. This
669        # is already done elsewhere, so to reuse it the user would need to
670        # use Out(var, borrow=True) and maybe the mutable=True flag too.
671        # But to be safe for now as it isn't documented and we aren't sure
672        # it is well tested, we don't share the part of the storage_map.
673        if share_memory:
674            i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
675            for key in storage_map.keys():
676                if key not in i_o_vars:
677                    new_storage_map[memo[key]] = storage_map[key]
678
679        if not name and self.name:
680            name = self.name + " copy"
681
682        input_storage = [i.value for i in ins]
683        # reinitialize new maker and create new function
684        if profile is None:
685            profile = config.profile or config.print_global_stats
686            # profile -> True or False
687        if profile is True:
688            if name:
689                message = name
690            else:
691                message = str(profile.message) + " copy"
692            profile = theano.compile.profiling.ProfileStats(message=message)
693            # profile -> object
694        elif type(profile) == str:
695            profile = theano.compile.profiling.ProfileStats(message=profile)
696
697        f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
698                                mode=maker.mode, profile=profile,
699                                # When removing updates containing variables
700                                # not used in the output function, copy
701                                # generates an unused implicit input.
702                                # We ignore the resulting errors,
703                                # but could change it to 'warn' if this might
704                                # cause problems.
705                                on_unused_input='ignore',
706                                function_builder=maker.function_builder,
707                                # As this is an optimized graph, it
708                                # can contain inplace. DebugMode check
709                                # that.
710                                accept_inplace=True,
711                                ).create(input_storage,
712                                         storage_map=new_storage_map)
713
714        for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
715                                            self.input_storage,
716                                            f_cpy.input_storage):
717
718            # Share immutable ShareVariable and constant input's storage
719            swapped = swap is not None and in_ori.variable in swap
720
721            # Using the original storage if SharedVariable will not be updated
722            # and is not swapped
723            if not in_ori.mutable and not swapped:
724                cpy.data = ori.data
725                in_cpy.value = in_ori.value
726
727            # Reconstruct Function.finder which map Variable defined by user
728            # to container, to make Function.value and Function.data work well.
729            # Replace variable in new maker.inputs by the original ones.
730            # So that user can swap SharedVariable in a swapped function
731            container = f_cpy.finder.pop(in_cpy.variable)
732            if not swapped:
733                f_cpy.finder[in_ori.variable] = container
734                in_cpy.vairable = in_ori.variable
735            else:
736                f_cpy.finder[swap[in_ori.variable]] = container
737                in_cpy.variable = swap[in_ori.variable]
738
739        f_cpy.name = name
740        f_cpy.maker.fgraph.name = name
741        return f_cpy
742
743    def __call__(self, *args, **kwargs):
744        """
745        Evaluates value of a function on given arguments.
746
747        Parameters
748        ----------
749        args : list
750            List of inputs to the function. All inputs are required, even when
751            some of them are not necessary to calculate requested subset of
752            outputs.
753
754        kwargs : dict
755            The function inputs can be passed as keyword argument. For this, use
756            the name of the input or the input instance as the key.
757
758            Keyword argument ``output_subset`` is a list of either indices of the
759            function's outputs or the keys belonging to the `output_keys` dict
760            and represent outputs that are requested to be calculated. Regardless
761            of the presence of ``output_subset``, the updates are always calculated
762            and processed. To disable the updates, you should use the ``copy``
763            method with ``delete_updates=True``.
764
765        Returns
766        -------
767        list
768            List of outputs on indices/keys from ``output_subset`` or all of them,
769            if ``output_subset`` is not passed.
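
        For example (a sketch of ``output_subset``)::

            import theano
            import theano.tensor as T

            x = T.dscalar('x')
            f = theano.function([x], [x + 1, x * 2])
            f(3.0)                     # [array(4.0), array(6.0)]
            f(3.0, output_subset=[1])  # [array(6.0)]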
770        """
771        def restore_defaults():
772            for i, (required, refeed, value) in enumerate(self.defaults):
773                if refeed:
774                    if isinstance(value, gof.Container):
775                        value = value.storage[0]
776                    self[i] = value
777        profile = self.profile
778        t0 = time.time()
779
780        output_subset = kwargs.pop('output_subset', None)
781        if output_subset is not None and self.output_keys is not None:
782            output_subset =\
783                [self.output_keys.index(key) for key in output_subset]
784
785        # Reinitialize each container's 'provided' counter
786        if self.trust_input:
787            i = 0
788            for arg in args:
789                s = self.input_storage[i]
790                s.storage[0] = arg
791                i += 1
792        else:
793            for c in self.input_storage:
794                c.provided = 0
795
796            if len(args) + len(kwargs) > len(self.input_storage):
797                raise TypeError("Too many parameter passed to theano function")
798
799            # Set positional arguments
800            i = 0
801            for arg in args:
802                # TODO: provide a Param option for skipping the filter if we
803                #      really want speed.
804                s = self.input_storage[i]
805                # see this emails for a discuation about None as input
806                # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
807                if arg is None:
808                    s.storage[0] = arg
809                else:
810                    try:
811                        s.storage[0] = s.type.filter(
812                            arg, strict=s.strict,
813                            allow_downcast=s.allow_downcast)
814
815                    except Exception as e:
816                        function_name = "theano function"
817                        argument_name = "argument"
818                        if self.name:
819                            function_name += ' with name "' + self.name + '"'
820                        if hasattr(arg, 'name') and arg.name:
821                            argument_name += ' with name "' + arg.name + '"'
822                        where = theano.gof.utils.get_variable_trace_string(
823                            self.maker.inputs[i].variable)
824                        if len(e.args) == 1:
825                            e.args = ("Bad input " + argument_name + " to " +
826                                      function_name + " at index %d (0-based). %s"
827                                      % (i, where) + e.args[0],)
828                        else:
829                            e.args = ("Bad input " + argument_name + " to " +
830                                      function_name + " at index %d (0-based). %s"
831                                      % (i, where),) + e.args
832                        restore_defaults()
833                        raise
834                s.provided += 1
835                i += 1
836
837        # Set keyword arguments
838        if kwargs:  # for speed, skip the iteritems for empty kwargs
839            for k, arg in iteritems(kwargs):
840                self[k] = arg
841
842        if (not self.trust_input and
843            # The getattr is only needed for old pickle
844                getattr(self, '_check_for_aliased_inputs', True)):
845            # Collect aliased inputs among the storage space
846            args_share_memory = []
847            for i in xrange(len(self.input_storage)):
848                i_var = self.maker.inputs[i].variable
849                i_val = self.input_storage[i].storage[0]
850                if hasattr(i_var.type, 'may_share_memory'):
851                    is_aliased = False
852                    for j in xrange(len(args_share_memory)):
853
854                        group_j = izip(
855                            [self.maker.inputs[k].variable for k
856                             in args_share_memory[j]],
857                            [self.input_storage[k].storage[0] for k
858                             in args_share_memory[j]])
859                        if any([(var.type is i_var.type and
860                                 var.type.may_share_memory(val, i_val))
861                                for (var, val) in group_j]):
862
863                            is_aliased = True
864                            args_share_memory[j].append(i)
865                            break
866
867                    if not is_aliased:
868                        args_share_memory.append([i])
869
870            # Check for groups of more than one argument that share memory
871            for group in args_share_memory:
872                if len(group) > 1:
873                    # copy all but the first
874                    for j in group[1:]:
875                        self.input_storage[j].storage[0] = copy.copy(
876                            self.input_storage[j].storage[0])
877
878        # Check if inputs are missing, or if inputs were set more than once, or
879        # if we tried to provide inputs that are supposed to be implicit.
880        if not self.trust_input:
881            for c in self.input_storage:
882                if c.required and not c.provided:
883                    restore_defaults()
884                    raise TypeError("Missing required input: %s" %
885                                    getattr(self.inv_finder[c], 'variable',
886                                            self.inv_finder[c]))
887                if c.provided > 1:
888                    restore_defaults()
889                    raise TypeError("Multiple values for input: %s" %
890                                    getattr(self.inv_finder[c], 'variable',
891                                            self.inv_finder[c]))
892                if c.implicit and c.provided > 0:
893                    restore_defaults()
894                    raise TypeError(
895                        'Tried to provide value for implicit input: %s'
896                        % getattr(self.inv_finder[c], 'variable',
897                                  self.inv_finder[c]))
898
899        # Do the actual work
900        t0_fn = time.time()
901        try:
902            outputs =\
903                self.fn() if output_subset is None else\
904                self.fn(output_subset=output_subset)
905        except Exception:
906            restore_defaults()
907            if hasattr(self.fn, 'position_of_error'):
908                # this is a new vm-provided function or c linker
909                # they need this because the exception manipulation
910                # done by raise_with_op is not implemented in C.
911                thunk = None
912                if hasattr(self.fn, 'thunks'):
913                    thunk = self.fn.thunks[self.fn.position_of_error]
914                gof.link.raise_with_op(
915                    node=self.fn.nodes[self.fn.position_of_error],
916                    thunk=thunk,
917                    storage_map=getattr(self.fn, 'storage_map', None))
918            else:
919                # old-style linkers raise their own exceptions
920                raise
921
922        dt_fn = time.time() - t0_fn
923        self.maker.mode.fn_time += dt_fn
924        if profile:
925            profile.vm_call_time += dt_fn
926
927        # Retrieve the values that were computed
928        if outputs is None:
929            outputs = [x.data for x in self.output_storage]
930        assert len(outputs) == len(self.output_storage)
931
932        # Remove internal references to required inputs.
933        # These cannot be re-used anyway.
934        for c in self.input_storage:
935            if c.required:
936                c.storage[0] = None
937
938        # if we are allowing garbage collection, remove the
939        # output reference from the internal storage cells
940        if getattr(self.fn, 'allow_gc', False):
941            assert len(self.output_storage) == len(self.maker.fgraph.outputs)
942            for o_container, o_variable in zip(self.output_storage,
943                                               self.maker.fgraph.outputs):
944                if o_variable.owner is not None:
945                    # this node is the variable of computation
946                    # WARNING: This circumvents the 'readonly' attribute in x
947                    o_container.storage[0] = None
948
949        if getattr(self.fn, 'need_update_inputs', True):
950            # Update the inputs that have an update function
951            for input, storage in reversed(list(zip(self.maker.expanded_inputs,
952                                                    self.input_storage))):
953                if input.update is not None:
954                    storage.data = outputs.pop()
955        else:
956            outputs = outputs[:self.n_returned_outputs]
957
958        # Put default values back in the storage
959        restore_defaults()
960        #
961        # NOTE: This logic needs to be replicated in
962        #       scan.
963        #       grep for 'PROFILE_CODE'
964        #
965
966        dt_call = time.time() - t0
967        theano.compile.profiling.total_fct_exec_time += dt_call
968        self.maker.mode.call_time += dt_call
969        if profile:
970            profile.fct_callcount += 1
971            profile.fct_call_time += dt_call
972            if hasattr(self.fn, 'update_profile'):
973                self.fn.update_profile(profile)
974            if profile.ignore_first_call:
975                profile.reset()
976                profile.ignore_first_call = False
977        if self.return_none:
978            return None
979        elif self.unpack_single and len(outputs) == 1 and\
980                output_subset is None:
981            return outputs[0]
982        else:
983
984            if self.output_keys is not None:
985
986                assert len(self.output_keys) == len(outputs)
987
988                if output_subset is None:
989                    return dict(izip(self.output_keys, outputs))
990                else:
991                    return dict((self.output_keys[index], outputs[index])
992                                for index in output_subset)
993
994            if output_subset is None:
995                return outputs
996            else:
997                return [outputs[i] for i in output_subset]
998
999    value = property(
1000        lambda self: self._value,
1001        None,  # this property itself is not settable
1002        doc="dictionary-like access to the values associated with Variables")
1003    container = property(
1004        lambda self: self._container,
1005        None,  # this property itself is not settable
1006        doc=("dictionary-like access to the containers associated with "
1007             "Variables"))
1008
1009    def free(self):
1010        """
1011        When allow_gc = False, clear the Variables in storage_map
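
        A usage sketch (the 'cvm_nogc' linker keeps internal storage alive
        until freed; names are illustrative)::

            import numpy as np
            import theano
            import theano.tensor as T

            x = T.dvector('x')
            f = theano.function([x], T.exp(x),
                                mode=theano.Mode(linker='cvm_nogc'))
            f(np.ones(10))
            f.free()  # release the internal storage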
1012        """
1013        # 1.no allow_gc return False
1014        # 2.has allow_gc, if allow_gc is False, return True
1015        if not getattr(self.fn, 'allow_gc', True):
1016            for key in self.fn.storage_map:
1017                if not isinstance(key, theano.gof.Constant):
1018                    self.fn.storage_map[key][0] = None
1019
1020            for node in self.nodes_with_inner_function:
1021                ops_with_inner_function[node.op].free()
1022
1023    def get_shared(self):
1024        """
1025        Return the shared variable read or updated by by this function.
1026        """
1027        return [i.variable for i in self.maker.inputs if i.implicit]
1028
1029    def sync_shared(self):
1030        if (hasattr(theano, "gpuarray") and
1031                theano.gpuarray.pygpu_activated):
1032            import pygpu
1033            for i in self.maker.fgraph.update_mapping.values():
1034                inp = self.input_storage[i]
1035                if isinstance(inp.data, pygpu.gpuarray.GpuArray):
1036                    inp.data.sync()
1037
1038
1039# pickling/deepcopy support for Function
1040def _pickle_Function(f):
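    # A usage sketch (illustrative): compiled Functions pickle like
    # ordinary Python objects, provided config.unpickle_function is
    # enabled when loading, e.g.
    #
    #     import six.moves.cPickle as pickle
    #     s = pickle.dumps(f)   # 'f' is a compiled theano.Function
    #     g = pickle.loads(s)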
    # copy of the input storage list
    ins = list(f.input_storage)
    input_storage = []

    for (input, indices, inputs), (required, refeed, default) in \
            zip(f.indices, f.defaults):
        input_storage.append(ins[0])
        del ins[0]

    inputs_data = [x.data for x in f.input_storage]

    # HACK to detect aliased storage.
    # This is here because aliased relationships are not [currently]
    # preserved across the pickle operation
    if not (f.pickle_aliased_memory_strategy == 'ignore'):
        all_data = input_storage + inputs_data
        for i, d_i in enumerate(all_data):
            for j, d_j in enumerate(all_data):
                if ((i < j) and isinstance(d_i, np.ndarray) and
                        isinstance(d_j, np.ndarray)):
                    if np.may_share_memory(d_i, d_j):
                        if f.pickle_aliased_memory_strategy == 'warn':
                            _logger.warning('aliased relationship between '
                                            'Function arguments %s, %s '
                                            'will not be preserved by '
                                            'un-pickling operation' %
                                            (str(d_i), str(d_j)))
                        else:
                            raise AliasedMemoryError(d_i, d_j)
    # The user can override trust_input, and our documentation says so.
    # We should not do that anymore, and instead make sure the Maker has
    # all the information needed.
    rval = (_constructor_Function,
            (f.maker, input_storage, inputs_data, f.trust_input))
    return rval


def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
    if not theano.config.unpickle_function:
        return None

    f = maker.create(input_storage, trustme=True)
    assert len(f.input_storage) == len(inputs_data)
    for container, x in zip(f.input_storage, inputs_data):
        assert (container.data is x) or \
            (isinstance(x, np.ndarray) and (container.data == x).all()) or \
            (container.data == x)
    f.trust_input = trust_input
    return f

copyreg.pickle(Function, _pickle_Function)


###
# FunctionMaker
###
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
    """
    Insert deepcopy in the fgraph to break aliasing of outputs.
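
    For example (a sketch), two outputs that would otherwise alias the same
    storage come back as distinct arrays::

        import numpy as np
        import theano
        import theano.tensor as T

        v = T.dvector('v')
        f = theano.function([v], [v * 2, v * 2])
        r1, r2 = f(np.arange(3.0))
        assert r1 is not r2  # the deep copy broke the aliasing
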
1100    """
1101    # This loop was inserted to remove aliasing between outputs when
1102    # they all evaluate to the same value. Originally it was OK for
1103    # outputs to be aliased, but some of the outputs can be shared
1104    # variables, and is not good for shared variables to be
1105    # aliased. It might be possible to optimize this by making sure
1106    # there is no aliasing only between shared variables.
1107
1108    # If some outputs are constant, we add deep copy to respect the
1109    # memory contract
1110
1111    # We don't insert deep copy when the output.borrow is True for all
1112    # concerned outputs.
1113
1114    assert len(wrapped_inputs) == len(fgraph.inputs)
1115    assert len(wrapped_outputs) == len(fgraph.outputs)
1116    reason = "insert_deepcopy"
1117    updated_fgraph_inputs = set([fgraph_i for i, fgraph_i in
1118                                zip(wrapped_inputs, fgraph.inputs)
1119                                if getattr(i, 'update', False)])
1120
1121    # We can't use fgraph.inputs as this don't include Constant Value.
1122    all_graph_inputs = gof.graph.inputs(fgraph.outputs)
1123    has_destroyers_attr = hasattr(fgraph, 'has_destroyers')
1124
1125    for i in xrange(len(fgraph.outputs)):
1126        views_of_output_i = set()
1127        view_tree_set(alias_root(fgraph.outputs[i]), views_of_output_i)
1128        copied = False
1129        # do not allow outputs to be aliased
1130        for j in xrange(i + 1, len(fgraph.outputs)):
1131            # We could don't put deep copy if both outputs have borrow==True
1132            # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
1133            if fgraph.outputs[j] in views_of_output_i:
1134                if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
1135                    fgraph.change_input('output', i,
1136                                        view_op(fgraph.outputs[i]),
1137                                        reason=reason)
1138                else:
1139                    fgraph.change_input('output', i,
1140                                        deep_copy_op(fgraph.outputs[i]),
1141                                        reason=reason)
1142                copied = True
1143                break
1144
1145        if not copied:
1146            for input_j in all_graph_inputs:
1147                # do not allow outputs to be aliased to an inputs (j), unless
1148                # a) that j'th input has been 'destroyed' by
1149                #    e.g. in-place computations
1150                # b) that j'th input is a shared variable that is also
1151                #    being updated
1152                if input_j in updated_fgraph_inputs:
1153                    continue
1154                if input_j in views_of_output_i and not (has_destroyers_attr and fgraph.has_destroyers([input_j])):
1155                    # We don't put deep_copy_op if the input and the
1156                    # output have borrow==True
1157                    if input_j in fgraph.inputs:
1158                        j = fgraph.inputs.index(input_j)
1159                        if (wrapped_outputs[i].borrow and
1160                                wrapped_inputs[j].borrow):
1161                            fgraph.change_input('output', i,
1162                                                view_op(fgraph.outputs[i]),
1163                                                reason="insert_deepcopy")
1164                            break
1165                        else:
1166                            fgraph.change_input(
1167                                'output', i,
1168                                deep_copy_op(fgraph.outputs[i]),
1169                                reason="insert_deepcopy")
1170                            break
1171                    elif wrapped_outputs[i].borrow:
1172                        fgraph.change_input('output', i,
1173                                            view_op(fgraph.outputs[i]),
1174                                            reason="insert_deepcopy")
1175                        break
1176                    else:
1177                        fgraph.change_input('output', i,
1178                                            deep_copy_op(fgraph.outputs[i]),
1179                                            reason="insert_deepcopy")
1180                        break
1181
1182NODEFAULT = ['NODEFAULT']
1183
1184
1185class FunctionMaker(object):
1186    """
1187    `FunctionMaker` is the class to `create` `Function` instances.
1188
1189    This class has the fgraph, the optimizer, and the linker. When
1190    copying a `Function`, there is no need to duplicate the
1191    `FunctionMaker` instance. Deepcopy still copies both, which can
1192    variable in re-compilation.
1193
1194    Parameters
1195    ----------
1196    inputs : list of SymbolicInput instances
1197    outputs : list of SymbolicOutput instances
1198        Outputs may also be a single Variable (not a list), in which case the
1199        functions produced by FunctionMaker will return their output value
1200        directly.
1201    mode : Mode instance
1202        Telling FunctionMaker how to optimize and link. None means to use the
1203        `config.mode`.
1204    accept_inplace : bool
1205        True iff it is acceptable to have inplace operations in the graph from
1206        the inputs to the outputs.
1207    on_unused_input : {'raise', 'warn', 'ignore', None}
1208        What to do if a variable in the 'inputs' list is not used in the graph.
1209        Possible values are:
1210        - 'raise': raise an error
1211        - 'warn': log a warning
1212        - 'ignore': do not do anything
1213        - None: Use the value in the Theano flags on_unused_input.
1214    name : str
1215        An optional name for this function. If used, the profile mode will
1216        print the time spent in this function.
1217
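    A rough usage sketch (``theano.function`` is the usual entry point;
    the initial-storage list below is illustrative)::

        import theano.tensor as T

        x = T.dscalar('x')
        maker = FunctionMaker([x], [x ** 2], mode='FAST_RUN')
        f = maker.create([None])  # no default value, so x is required
        f(3.0)
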
1218    """
1219
1220    @staticmethod
1221    def wrap_in(input):
1222        if isinstance(input, (SymbolicInput)):
1223            return input
1224        elif isinstance(input, gof.Variable):
1225            # r -> SymbolicInput(variable=r)
1226            return SymbolicInput(input)
1227        elif isinstance(input, (list, tuple)):
1228            # (r, u) -> SymbolicInput(variable=r, update=u)
1229            if len(input) == 2:
1230                return SymbolicInput(input[0], update=input[1])
1231            else:
1232                raise TypeError("Expected two elements in the list or tuple.",
1233                                input)
1234        else:
1235            raise TypeError("Unknown input type: %s (%s), expected Variable "
1236                            "instance", type(input), input)
1237
1238    @staticmethod
1239    def expand_in(sinput, rinputs):
1240        # For SymbolicInputKits, this extracts a list of SymbolicInput
1241        # instances and corresponding indices such that these
1242        # SymbolicInputs are representative of some of the Variable
1243        # instances in inputs.  For SymbolicInput, this returns None
1244        # as the list of indices and a list with just the
1245        # SymbolicInput.
1246        # if isinstance(sinput, SymbolicInputKit):
1247        #    return sinput.complete(rinputs)
1248        # elif isinstance(sinput, SymbolicInput):
1249        if isinstance(sinput, SymbolicInput):
1250            return [None, [sinput]]
1251
1252    @staticmethod
1253    def wrap_out(output):
1254        if isinstance(output, SymbolicOutput):
1255            return output
1256        elif isinstance(output, gof.Variable):
1257            return SymbolicOutput(output)
1258        else:
1259            raise TypeError("Unknown output type: %s (%s)", type(output),
1260                            output)
1261
1262    def optimize_graph_with_cache(self, optimizer, inputs, outputs):
1263        # This function is not finished
1264        from theano.gof.compilelock import get_lock, release_lock
1265        import os.path
1266
1267        graph_db_file = os.path.join(theano.config.compiledir,
1268                                     'optimized_graphs.pkl')
1269
1270        # the inputs, outputs, and size of the graph to be optimized
1271        inputs_new = [inp.variable for inp in inputs]
1272        outputs_new = [out.variable for out in outputs]
1273        size_new = len(self.fgraph.apply_nodes)
1274        get_lock()
1275        # Beginning of cache optimizations.
1276        # Could be refactored in different functions.
1277
1278        def load_graph_db():
1279            if os.path.isfile(graph_db_file):
1280                print('graph_db already exists')
1281            else:
1282                # create graph_db
1283                with open(graph_db_file, 'wb') as f:
1284                    print('create new graph_db in %s' % graph_db_file)
1285            # load the graph_db dictionary
1286            try:
1287                with open(graph_db_file, 'rb') as f:
1288                    # Temporary hack to allow
1289                    # theano.scan_module.tests.test_scan.T_Scan to
1290                    # finish. Should be changed in definitive version.
1291                    tmp = theano.config.unpickle_function
1292                    theano.config.unpickle_function = False
1293                    graph_db = pickle.load(f)
1294                print('graph_db loaded and it is not empty')
1295            except EOFError as e:
1296                # the file has nothing in it
1297                print(e)
1298                print('graph_db loaded and it is empty')
1299                graph_db = {}
1300            finally:
1301                theano.config.unpickle_function = tmp
1302
1303            return graph_db
1304
1305        def find_same_graph_in_db(graph_db):
1306            # If found_graph_in_db is None, then need to optimize.
1307            # Otherwise, return the graph found.
1308            found_graph_in_db = None
1309            # The sole purpose of this loop is to set 'need_optimize' by
1310            # going through graph_db, looking for graph that has the same
1311            # computation performed.
1312            for graph_old, graph_optimized in iteritems(graph_db):
1313                inputs_old = graph_old.inputs
1314                outputs_old = graph_old.outputs
1315                size_old = len(graph_old.apply_nodes)
                # Some heuristics to check whether the same graph has
                # already been optimized before.
                if len(inputs_new) != len(inputs_old):
                    # If the numbers of inputs differ, the two graphs
                    # are certainly different.
                    print('need to optimize, because input size is different')
                    continue
                elif len(outputs_new) != len(outputs_old):
                    # If the numbers of outputs differ, the two graphs
                    # are certainly different.
                    print('need to optimize, because output size is different')
                    continue
                elif not all(input_new.type == input_old.type
                             for input_new, input_old in
                             zip(inputs_new, inputs_old)):
                    print('need to optimize, because inputs are of different '
                          'types')
                    continue
                elif not all(output_new.type == output_old.type
                             for output_new, output_old in
                             zip(outputs_new, outputs_old)):
                    print('need to optimize, because outputs are of different '
                          'types')
                    continue
                elif size_old != size_new:
                    print('need to optimize, because numbers of nodes in graph'
                          ' are different')
                    continue
                else:
                    flags = []
                    for i, (output_new, output_old) in enumerate(
                            zip(outputs_new, outputs_old)):
                        print('loop through the outputs of both graphs')
                        graph_old.variables = set(gof.graph.variables(
                            graph_old.inputs, graph_old.outputs))

                        # Using clone avoids a lot of errors that deep
                        # copy seemed to have.
                        f2 = graph_old.clone(check_integrity=False)
                        t1 = output_new
                        t2 = f2.outputs[i]

                        # Used to remove the "already used by another
                        # graph" error.
                        def removeAllFgraph(remove):
                            if hasattr(remove, 'fgraph'):
                                del remove.fgraph
                            if getattr(remove, 'owner', None) is not None:
                                if hasattr(remove.owner, 'fgraph'):
                                    del remove.owner.fgraph
                                if hasattr(remove.owner, 'inputs'):
                                    remove.owner.inputs = [removeAllFgraph(
                                        i) for i in remove.owner.inputs]
                                    for o in remove.owner.outputs:
                                        if hasattr(o, 'fgraph'):
                                            del o.fgraph
                            return remove

                        t2 = removeAllFgraph(t2)

                        givens = dict(izip(gof.graph.inputs([t1]),
                                           gof.graph.inputs([t2])))

                        # Hack to remove inconsistent entries in givens.
                        # It seems to work, but the source of the
                        # inconsistency could be worth investigating.
                        # Iterate over a snapshot so that entries can be
                        # deleted from givens inside the loop.
                        for key, value in list(iteritems(givens)):
                            if key.type != value.type:
                                del givens[key]

                        flag = is_same_graph(t1, t2, givens=givens)

                        flags.append(flag)

                    is_same = all(flags)
                    if is_same:
                        # found the match
                        print('found a match, no need to optimize')
                        found_graph_in_db = graph_optimized
                        break
            return found_graph_in_db

        graph_db = load_graph_db()
        print('loaded graph_db from %s, size=%d' % (graph_db_file,
                                                    len(graph_db)))
        found_graph = find_same_graph_in_db(graph_db)
        if found_graph:
            self.fgraph = found_graph
            optimizer_profile = None
        else:
            # this is a brand new graph, optimize it, save it to graph_db
            print('graph not found in graph_db, optimizing the graph')
            self.fgraph.variables = set(gof.graph.variables(
                self.fgraph.inputs, self.fgraph.outputs))
            # The check_integrity parameter was added to ignore
            # "excess cached variables" errors. It works that way,
            # but once again the error could be worth investigating.
            before_opt = self.fgraph.clone(check_integrity=False)
            optimizer_profile = optimizer(self.fgraph)
            graph_db.update({before_opt: self.fgraph})
            with open(graph_db_file, 'wb') as f:
                pickle.dump(graph_db, f, -1)
            print('new graph saved into graph_db')
        release_lock()
        return optimizer_profile

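    # To exercise this optimized-graph cache (an experimental feature),
    # one can enable the ``cache_optimizations`` flag before any function
    # gets compiled, e.g. (a sketch):
    #
    #     THEANO_FLAGS='cache_optimizations=True' python my_script.py
    #
    # or, from Python:
    #
    #     theano.config.cache_optimizations = True
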
    def __init__(self, inputs, outputs,
                 mode=None, accept_inplace=False, function_builder=Function,
                 profile=None, on_unused_input=None, fgraph=None,
                 output_keys=None, name=None):
        # Save the provided mode, not the instantiated mode.
        # The instantiated mode does not pickle, and if we unpickle a Theano
        # function and it gets re-compiled, we want the current optimizer to
        # be used, not the optimizer that was active when it was saved.
        self.mode = mode
        mode = theano.compile.mode.get_mode(mode)

        # Assert the old way of working isn't used
        if getattr(mode, 'profile', None):
            raise TypeError(
                "profile passed via 'mode'. This isn't supported anymore")
        self.profile = profile
        if profile:
            # This is very important:
            # 1) We preload the cache here so that its timing does not
            #    get included in the optimization and compilation of
            #    the function.
            # 2) Do not refresh the cache here by default. It causes
            #    too much execution time during testing, as we compile
            #    many more functions than there are compiled C modules.
            theano.gof.cc.get_module_cache().refresh()
        # Handle the case where inputs and/or outputs is a single
        # Variable (not in a list)
        unpack_single = False
        return_none = False
        if outputs is None:
            return_none = True
            outputs = []
        if not isinstance(outputs, (list, tuple)):
            unpack_single = True
            outputs = [outputs]
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Wrap them in In or Out instances if needed.
        inputs = [self.wrap_in(i) for i in inputs]
        outputs = [self.wrap_out(o) for o in outputs]
        _inputs = gof.graph.inputs([o.variable for o in outputs] +
                                   [i.update for i in inputs
                                    if getattr(i, 'update', False)])

        # Check if some input variables are unused
        self._check_unused_inputs(inputs, outputs, on_unused_input)

        # Make a list of (SymbolicInput|SymbolicInputKit, indices,
        # [SymbolicInput,...]), one tuple for each input. (See
        # Function.indices for more details.)
        indices = [[input] + self.expand_in(input, _inputs)
                   for input in inputs]

        if fgraph is None:
            need_opt = True
            # make the fgraph (copies the graph, creates NEW INPUT AND
            # OUTPUT VARIABLES)
            fgraph, additional_outputs = std_fgraph(inputs, outputs,
                                                    accept_inplace)
            fgraph.profile = profile
        else:
            # fgraph is already an optimized one
            need_opt = False
            updates = [spec.update for spec in inputs if spec.update]
            additional_outputs = list(map(SymbolicOutput, updates))

        self.fgraph = fgraph

        # Fetch the optimizer and linker
        optimizer, linker = mode.optimizer, copy.copy(mode.linker)
        if need_opt:
            compute_test_value_orig = theano.config.compute_test_value
            limit_orig = theano.config.traceback.limit
            # Why do we add a stack trace on nodes when it is already
            # done on the output variables?
            try:
                # optimize the fgraph
                theano.config.compute_test_value = \
                    theano.config.compute_test_value_opt
                theano.config.traceback.limit = theano.config.traceback.compile_limit
                start_optimizer = time.time()

                # In case there is an error during optimization.
                optimizer_profile = None
                opt_time = None

                # now optimize the graph
                if theano.config.cache_optimizations:
                    optimizer_profile = self.optimize_graph_with_cache(
                        optimizer, inputs, outputs)
                else:
                    optimizer_profile = optimizer(fgraph)

                end_optimizer = time.time()
                opt_time = end_optimizer - start_optimizer
                _logger.debug('Optimizing took %f seconds', opt_time)

                # Add deep copy to respect the memory interface
                insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
            finally:
                theano.config.compute_test_value = compute_test_value_orig
                theano.config.traceback.limit = limit_orig

                # If the optimizer got interrupted
                if opt_time is None:
                    end_optimizer = time.time()
                    opt_time = end_optimizer - start_optimizer
                theano.compile.profiling.total_graph_opt_time += opt_time
                if profile:
                    if (optimizer_profile is None and
                            hasattr(optimizer, 'pre_profile')):
                        optimizer_profile = optimizer.pre_profile
                    profile.optimizer_time += opt_time
                    if theano.config.profile_optimizer:
                        profile.optimizer_profile = (optimizer,
                                                     optimizer_profile)
                # If profile is False, it means profiling for that
                # function was explicitly disabled.
                elif theano.config.profile_optimizer and profile is not False:
                    warnings.warn((
                        "config.profile_optimizer requires config.profile to "
                        "be set to True as well"), stacklevel=3)

        # initialize the linker
        if not hasattr(linker, 'accept'):
            raise ValueError("'linker' parameter of FunctionMaker should be "
                             "a Linker with an accept method or one of %s" %
                             list(theano.compile.mode
                                  .predefined_linkers.keys()))

        # The 'no_borrow' outputs are the ones for which we can't
        # return the internal storage pointer.
        assert len(fgraph.outputs) == len(outputs + additional_outputs)
        no_borrow = [output for output, spec in
                     zip(fgraph.outputs, outputs + additional_outputs)
                     if not spec.borrow]
        if no_borrow:
            self.linker = linker.accept(
                fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow),
                profile=profile)
        else:
            self.linker = linker.accept(fgraph, profile=profile)

        if hasattr(linker, 'accept_var_updates'):
            # hacky thing so VMLinker knows about updates
            self.linker.accept_var_updates(
                fgraph_updated_vars(fgraph, inputs))
        fgraph.name = name
        self.indices = indices
        self.inputs = inputs
        self.expanded_inputs = inputs
        self.outputs = outputs
        self.unpack_single = unpack_single
        self.return_none = return_none
        self.accept_inplace = accept_inplace
        self.function_builder = function_builder
        self.on_unused_input = on_unused_input  # Used for the pickling/copy
        self.output_keys = output_keys
        self.name = name

        self.required = [(i.value is None) for i in self.inputs]
        self.refeed = [
            (i.value is not None and
             not isinstance(i.value, gof.Container) and
             i.update is None)
            for i in self.inputs]

    def _check_unused_inputs(self, inputs, outputs, on_unused_input):
        if on_unused_input is None:
            on_unused_input = theano.config.on_unused_input

        if on_unused_input == 'ignore':
            return

        # There should be two categories of variables in inputs:
        #  - variables that have to be provided (used_inputs)
        #  - shared variables that will be updated
        used_inputs = gof.graph.ancestors(
            ([o.variable for o in outputs] +
             [i.update for i in inputs if getattr(i, 'update', False)]),
            blockers=[i.variable for i in inputs])

        msg = ("theano.function was asked to create a function computing "
               "outputs given certain inputs, but the provided input "
               "variable at index %i is not part of the computational graph "
               "needed to compute the outputs: %s.\n%s")
        warn_msg = ("To make this warning into an error, you can pass the "
                    "parameter on_unused_input='raise' to theano.function. "
                    "To disable it completely, use on_unused_input='ignore'.")
        err_msg = ("To make this error into a warning, you can pass the "
                   "parameter on_unused_input='warn' to theano.function. "
                   "To disable it completely, use on_unused_input='ignore'.")

        for i in inputs:
            if ((i.variable not in used_inputs) and (i.update is None)):
                if on_unused_input == 'warn':
                    warnings.warn(msg % (inputs.index(i), i.variable,
                                         warn_msg), stacklevel=6)
                elif on_unused_input == 'raise':
                    raise UnusedInputError(msg % (inputs.index(i),
                                                  i.variable, err_msg))
                else:
                    raise ValueError("Invalid value for keyword "
                                     "on_unused_input of theano.function: "
                                     "'%s'.\nValid values are 'raise', "
                                     "'warn', and 'ignore'." % on_unused_input)

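    # A sketch of the user-visible behavior enforced above (the variable
    # names are illustrative, not from this module):
    #
    #     x, y = theano.tensor.dscalars('x', 'y')
    #     theano.function([x, y], 2 * x, on_unused_input='warn')
    #     # ... warns that y is unused.
    #     theano.function([x, y], 2 * x, on_unused_input='raise')
    #     # ... raises UnusedInputError.
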
    def create(self, input_storage=None, trustme=False, storage_map=None):
        """
        Create a function.

        Parameters
        ----------
        input_storage
            A list matching the inputs list and providing default values.
            If the default for an input is None, then that input is a
            required input. For an input with an update, the default acts
            as initialization.
        trustme
            Disables some exceptions, used internally.
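
        Examples
        --------
        A minimal sketch of the usual flow (normally driven by
        `orig_function`; the variable `x` is illustrative):

        >>> import theano.tensor as T
        >>> x = T.dscalar('x')
        >>> maker = FunctionMaker([In(x)], [2 * x], mode='FAST_COMPILE')
        >>> fn = maker.create([None])  # one required input, no default
        >>> fn(21.0)
        [array(42.0)]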

        """

        if input_storage is None:
            input_storage = [None] * len(self.inputs)
        # list of independent one-element lists, will be passed to the linker
        input_storage_lists = []
        defaults = []

        # The following loop is to fill in the input_storage_lists and
        # defaults lists.
        assert len(self.indices) == len(input_storage)
        for i, ((input, indices, subinputs), input_storage_i) in \
                enumerate(zip(self.indices, input_storage)):

            # Replace any default value given as a variable by its
            # container.  Note that this makes sense only in the
            # context of shared variables, but for now we avoid
            # dealing directly with them to avoid dependency on the
            # shared variables work-in-progress repository.
            if isinstance(input_storage_i, gof.Variable):
                input_storage_i = input_storage_i.container

            if isinstance(input_storage_i, gof.Container):
                # If the default is a gof.Container, this means we want to
                # share the same storage. This is done by appending
                # input_storage_i.storage to input_storage_lists.
                if indices is not None:
                    raise TypeError("Cannot take a Container instance as "
                                    "default for a SymbolicInputKit.")
                input_storage_lists.append(input_storage_i.storage)

                storage = input_storage[i].storage[0]

            else:
                # Normal case: one new, independent storage unit
                input_storage_lists.append([input_storage_i])

                storage = input_storage_i

            required = self.required[i]
            refeed = self.refeed[i]
            # Sanity check: if an input is required, it should not
            # need to be refed.
            assert not (required and refeed)

            # Shared variables need neither be provided by the user
            # nor refed.
            if input.shared:
                assert not required
                assert not refeed
                storage = None

            # If an input is required, it never needs to be refed.
            if required:
                storage = None

            # Make sure that we only store a value if we actually need it.
            if storage is not None:
                assert refeed or not required

            defaults.append((required, refeed, storage))
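            # For example (a sketch, not taken from a real trace): an
            # In(x, value=0.0) entry yields (False, True, 0.0), while a
            # plain required input yields (True, False, None).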

        # Get a function instance
        start_linker = time.time()
        start_import_time = theano.gof.cmodule.import_time
        limit_orig = theano.config.traceback.limit
        try:
            theano.config.traceback.limit = theano.config.traceback.compile_limit
            _fn, _i, _o = self.linker.make_thunk(
                input_storage=input_storage_lists, storage_map=storage_map)
        finally:
            theano.config.traceback.limit = limit_orig

        end_linker = time.time()

        linker_time = end_linker - start_linker
        theano.compile.profiling.total_time_linker += linker_time
        _logger.debug('Linker took %f seconds', linker_time)
        if self.profile:
            self.profile.linker_time += linker_time
            _fn.time_thunks = self.profile.flag_time_thunks
            import_time = theano.gof.cmodule.import_time - start_import_time
            self.profile.import_time += import_time

        fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
                                   defaults, self.unpack_single,
                                   self.return_none, self.output_keys, self,
                                   name=self.name)

        fn.profile = self.profile
        return fn


def _constructor_FunctionMaker(kwargs):
    # Needed for old pickles.
    # Old pickles have at least the problem that output_keys was
    # not saved.
    if theano.config.unpickle_function:
        if theano.config.reoptimize_unpickled_function:
            del kwargs['fgraph']
        return FunctionMaker(**kwargs)
    else:
        return None

__checkers = []


def check_equal(x, y):
    # Try the registered checkers in order; the first one that does not
    # raise decides the result. Fall back to the == operator.
    for checker in __checkers:
        try:
            return checker(x, y)
        except Exception:
            continue
    return x == y


def register_checker(checker):
    __checkers.insert(0, checker)
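
# A hypothetical custom checker (a sketch, not registered by default):
# it would make check_equal compare ndarrays element-wise instead of
# relying on `==`:
#
#     def _ndarray_checker(x, y):
#         if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
#             return np.array_equal(x, y)
#         raise TypeError("not ndarrays")
#
#     register_checker(_ndarray_checker)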


def orig_function(inputs, outputs, mode=None, accept_inplace=False,
                  name=None, profile=None, on_unused_input=None,
                  output_keys=None):
    """
    Return a Function that will calculate the outputs from the inputs.

    Parameters
    ----------
    inputs : list of `SymbolicInput` or `In` instances
    outputs : a SymbolicOutput or a list of `SymbolicOutput` or `Out` instances
        The return value of the returned function will match the format of this
        argument (either the value itself or a list of one or more return
        values).
    mode : descriptive string or Mode instance
        Default of None means to use `config.mode` (see below for descriptive
        string list).
    name : str
        An optional name for this function. If used, the profile mode will
        print the time spent in this function.
    accept_inplace : bool
        True iff the graph can contain inplace operations prior to the
        optimization phase (default is False).
    profile : None or ProfileStats instance
    on_unused_input : {'raise', 'warn', 'ignore', None}
        What to do if a variable in the 'inputs' list is not used in the graph.
    output_keys :
        If the outputs were provided to theano.function as a list, then
        output_keys is None. Otherwise, if outputs were provided as a dict,
        output_keys is the sorted list of keys from the outputs.

    Notes
    -----
    Currently, the library provides the following mode strings:

    - FAST_RUN (default) (optimize without too much time)

    - FAST_COMPILE (minimal optimization)

    - DebugMode: verify many internal conditions that are normally assumed
      (slow)

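    Examples
    --------
    A minimal sketch (`theano.function` wraps this for everyday use):

    >>> import theano.tensor as T
    >>> x, y = T.dscalars('x', 'y')
    >>> f = orig_function([x, y], x + y)
    >>> f(1.0, 2.0)
    array(3.0)
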
    """

    # Every element of the input list will be upgraded to an `In` instance if
    # necessary, using the rules implemented by the `convert_function_input`
    # function.

    # Similarly, every element of the output list will be upgraded to an `Out`
    # instance if necessary:

    t1 = time.time()
    mode = theano.compile.mode.get_mode(mode)

    inputs = list(map(convert_function_input, inputs))
    if outputs is not None:
        if isinstance(outputs, (list, tuple)):
            outputs = list(map(FunctionMaker.wrap_out, outputs))
        else:
            outputs = FunctionMaker.wrap_out(outputs)

    defaults = [getattr(input, 'value', None) for input in inputs]

    if isinstance(mode, (list, tuple)):  # "mode comparison" semantics
        raise Exception("We do not support the passing of multiple modes")
    fn = None
    try:
        Maker = getattr(mode, 'function_maker', FunctionMaker)
        m = Maker(inputs,
                  outputs,
                  mode,
                  accept_inplace=accept_inplace,
                  profile=profile,
                  on_unused_input=on_unused_input,
                  output_keys=output_keys,
                  name=name)
        with theano.change_flags(compute_test_value="off"):
            fn = m.create(defaults)
    finally:
        t2 = time.time()
        if fn and profile:
            profile.compile_time += t2 - t1
            # TODO: append
            profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)

    return fn


def convert_function_input(input):
    """
    Upgrade an input shortcut to an `In` instance.

    The rules for upgrading are as follows:

    - a `Variable` instance r will be upgraded like `In`(r)

    - a tuple (name, r) will be `In`(r, name=name)

    - a tuple (r, val) will be `In`(r, value=val, autoname=True)

    - a tuple ((r, up), val) will be
      `In`(r, value=val, update=up, autoname=True)

    - a tuple (name, r, val) will be `In`(r, name=name, value=val)

    - a tuple (name, (r, up), val) will be
      `In`(r, name=name, value=val, update=up, autoname=True)

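    Examples
    --------
    A small illustration of the shortcuts (`r` is a hypothetical scalar):

    >>> import theano.tensor as T
    >>> r = T.dscalar('r')
    >>> i = convert_function_input((r, 1.0))   # -> In(r, value=1.0)
    >>> i.value
    1.0
    >>> j = convert_function_input(('radius', r, 1.0))
    >>> j.name
    'radius'
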
    """
    if isinstance(input, SymbolicInput):
        return input
    elif isinstance(input, gof.Constant):
        raise TypeError('A Constant instance is not a legal function input',
                        input)
    elif isinstance(input, gof.Variable):
        return In(input)
    elif isinstance(input, (list, tuple)):
        orig = input
        if not input:
            raise TypeError("Nonsensical input specification: %s" % input)
        if isinstance(input[0], string_types):
            name = input[0]
            input = input[1:]
        else:
            name = None
        if isinstance(input[0], (list, tuple)):
            if len(input[0]) != 2 or len(input) != 2:
                raise TypeError("Invalid input syntax: %s (check "
                                "documentation or use an In instance)" % orig)
            (variable, update), value = input
        elif isinstance(input[0], gof.Variable):
            if len(input) == 1:
                variable, update, value = input[0], None, None
            elif len(input) == 2:
                (variable, value), update = input, None
            else:
                raise TypeError("Invalid input syntax: %s (check "
                                "documentation or use an In instance)" % orig)
        elif isinstance(input[0], SymbolicInput):
            if len(input) == 1:
                return input[0]
            elif len(input) == 2:
                input, value = input
                if name is not None:
                    input.name = name
                input.value = value
                return input
        else:
            raise TypeError("The input specification is not valid: %s" % input)

        if not isinstance(variable, gof.Variable):
            raise TypeError("Unknown input type: %s, expected Variable "
                            "instance" % type(variable), variable)
        if update is not None and not isinstance(update, gof.Variable):
            raise TypeError("Unknown update type: %s, expected Variable "
                            "instance" % type(update), update)
        if (value is not None and
                isinstance(value, (gof.Variable, SymbolicInput))):
            raise TypeError("The value for input %s should not be a Variable "
                            "or SymbolicInput instance (got: %s)" %
                            (variable, value))

        return In(variable, name=name, value=value, update=update)
    else:
        raise TypeError("Unknown input type: %s, expected Variable instance" %
                        type(input), input)


def get_info_on_inputs(named_inputs, n_unnamed_inputs):
    """
    Return a human-readable description of named and unnamed inputs.

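    Example (a sketch of the expected message):

    >>> get_info_on_inputs(['a', 'b'], 0)
    'The function has 2 named inputs (a, b).'
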
    """
    n_named_inputs = len(named_inputs)

    def get_plural(n):
        if n > 1:
            return 's'
        else:
            return ''

    if n_named_inputs == 0:
        if n_unnamed_inputs == 0:
            msg = 'The function is supposed to have no input.'
        else:
            if n_unnamed_inputs == 1:
                msg = ("The function has a single input variable which has no "
                       "name, and thus cannot be assigned through a keyword"
                       " argument (use 'name=...' in a Variable's "
                       "constructor to give it a name).")
            else:
                # Use plural.
                msg = ("The function has %s inputs, but none of them is named,"
                       " and thus they cannot be assigned through keyword "
                       "arguments (use 'name=...' in a Variable's "
                       "constructor to give it a name)." % n_unnamed_inputs)
    else:
        if n_unnamed_inputs == 0:
            msg = ("The function has %s named input%s (%s)." %
                   (n_named_inputs, get_plural(n_named_inputs),
                    ', '.join(named_inputs)))
        else:
            msg = ("The function has %s named input%s (%s), and %s unnamed "
                   "input%s which thus cannot be accessed through keyword "
                   "argument%s (use 'name=...' in a variable's constructor "
                   "to give it a name)." %
                   (n_named_inputs, get_plural(n_named_inputs),
                    ', '.join(named_inputs), n_unnamed_inputs,
                    get_plural(n_unnamed_inputs),
                    get_plural(n_unnamed_inputs)))
    return msg