1"""
2Node classes (`Apply`, `Variable`) and expression graph algorithms.
3"""
4from __future__ import absolute_import, print_function, division
5
6from collections import deque
7import contextlib
8from copy import copy
9from itertools import count
10
11import warnings
12
13import theano
14from theano import config
15from theano.gof import utils
16from six import string_types, integer_types, iteritems
17from theano.misc.ordered_set import OrderedSet
18
19__docformat__ = "restructuredtext en"
20
21# Lazy imports to avoid circular dependencies.
22is_same_graph_with_merge = None
23equal_computations = None
24
25NoParams = object()
26
27
28class Node(utils.object2):
29    """
30    A Node in a theano graph.
31
32    Graphs contain two kinds of Nodes -- Variable and Apply.
33    Edges in the graph are not explicitly represented.
34    Instead each Node keeps track of its parents via
35    Variable.owner / Apply.inputs and its children
36    via Variable.clients / Apply.outputs.
37
38    """
39
40    def get_parents(self):
41        """
42        Return a list of the parents of this node.
43        Should return a copy--i.e., modifying the return
44        value should not modify the graph structure.
45
46        """
47        raise NotImplementedError()
48
49
50class Apply(Node):
51    """
52    An :term:`Apply` instance is a node in an expression graph which represents
53    the application of an `Op` to some input `Variable` nodes, producing some
54    output `Variable` nodes.
55
56    This class is typically instantiated by an Op's make_node() function, which
57    is typically called by that Op's __call__() function.
58
59    An Apply instance serves as a simple structure with three important
60    attributes:
61
62    - :literal:`inputs` :  a list of `Variable` nodes that represent the
63      arguments of the expression,
64
65    - :literal:`outputs` : a list of `Variable` nodes that represent the
66      variable of the expression, and
67
68    - :literal:`op` : an `Op` instance that determines the nature of the
69      expression being applied.
70
71    The driver `compile.function` uses Apply's inputs attribute together with
72    Variable's owner attribute to search the expression graph and determine
73    which inputs are necessary to compute the function's outputs.
74
75    A `Linker` uses the Apply instance's `op` field to compute the variables.
76
77    Comparing with the Python language, an `Apply` instance is theano's version
78    of a function call (or expression instance) whereas `Op` is theano's version
79    of a function definition.
80
81    Parameters
82    ----------
83    op : `Op` instance
84    inputs : list of Variable instances
85    outputs : list of Variable instances
86
87    Notes
88    -----
89    The owner field of each output in the outputs list will be set to self.
90
91    If an output element has an owner that is neither None nor self, then a
92    ValueError exception will be raised.
93
94    """
95
96    def __init__(self, op, inputs, outputs):
97        self.op = op
98        self.inputs = []
99        self.tag = utils.scratchpad()
100
101        if not isinstance(inputs, (list, tuple)):
102            raise TypeError("The inputs of an Apply must be a list or tuple")
103
104        if not isinstance(outputs, (list, tuple)):
105            raise TypeError("The output of an Apply must be a list or tuple")
106
107        # filter inputs to make sure each element is a Variable
108        for input in inputs:
109            if isinstance(input, Variable):
110                self.inputs.append(input)
111            else:
112                raise TypeError("The 'inputs' argument to Apply must contain Variable instances, not %s" % input)
113        self.outputs = []
114        # filter outputs to make sure each element is a Variable
115        for i, output in enumerate(outputs):
116            if isinstance(output, Variable):
117                if output.owner is None:
118                    output.owner = self
119                    output.index = i
120                elif output.owner is not self or output.index != i:
121                    raise ValueError("All output variables passed to Apply must belong to it.")
122                self.outputs.append(output)
123            else:
124                raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output)
125
126    def run_params(self):
127        """
128        Returns the params for the node, or NoParams if no params is set.
129
130        """
131        try:
132            return self.op.get_params(self)
133        except theano.gof.utils.MethodNotDefined:
134            return NoParams
135
136    def __getstate__(self):
137        d = self.__dict__
138        # ufunc don't pickle/unpickle well
139        if hasattr(self.tag, 'ufunc'):
140            d = copy(self.__dict__)
141            t = d["tag"]
142            del t.ufunc
143            d["tag"] = t
144        return d
145
146    def default_output(self):
147        """
148        Returns the default output for this node.
149
150        Returns
151        -------
152        Variable instance
153            An element of self.outputs, typically self.outputs[0].
154
155        Notes
156        -----
157        May raise AttributeError self.op.default_output is out of range, or if
158        there are multiple outputs and self.op.default_output does not exist.
159
160        """
161        do = getattr(self.op, 'default_output', None)
162        if do is None:
163            if len(self.outputs) == 1:
164                return self.outputs[0]
165            else:
166                raise AttributeError(
167                    "%s.default_output should be an output index." % self.op)
168        elif not isinstance(do, integer_types):
169            raise AttributeError("%s.default_output should be an int or long" %
170                                 self.op)
171        elif do < 0 or do >= len(self.outputs):
172            raise AttributeError("%s.default_output is out of range." %
173                                 self.op)
174        return self.outputs[do]
175
176    out = property(default_output,
177                   doc="alias for self.default_output()")
178    """
179    Alias for self.default_output().
180
181    """
182
183    def __str__(self):
184        return op_as_string(self.inputs, self)
185
186    def __repr__(self):
187        return str(self)
188
189    def __asapply__(self):
190        return self
191
192    def clone(self):
193        """
194        Duplicate this Apply instance with inputs = self.inputs.
195
196        Returns
197        -------
198        object
199            A new Apply instance (or subclass instance) with new outputs.
200
201        Notes
202        -----
203        Tags are copied from self to the returned instance.
204
205        """
206        cp = self.__class__(self.op, self.inputs,
207                            [output.clone() for output in self.outputs])
208        cp.tag = copy(self.tag)
209        return cp
210
211    def clone_with_new_inputs(self, inputs, strict=True):
212        """
213        Duplicate this Apply instance in a new graph.
214
215        Parameters
216        ----------
217        inputs
218            List of Variable instances to use as inputs.
219        strict : bool
220            If True, the type fields of all the inputs must be equal
221            to the current ones (or compatible, for instance Tensor /
222            GpuArray of the same dtype and broadcastable patterns,
223            in which case they will be converted into current Type), and
224            returned outputs are guaranteed to have the same types as
225            self.outputs.  If False, then there's no guarantee that the
226            clone's outputs will have the same types as self.outputs,
227            and cloning may not even be possible (it depends on the Op).
228
229        Returns
230        -------
231        object
232            An Apply instance with the same op but different outputs.
233
234        """
235        assert isinstance(inputs, (list, tuple))
236        remake_node = False
237        new_inputs = inputs[:]
238        for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
239            if not curr.type == new.type:
240                if strict:
241                    # If compatible, casts new into curr.type
242                    new_inputs[i] = curr.type.filter_variable(new)
243                else:
244                    remake_node = True
245        if remake_node:
246            new_node = self.op.make_node(*new_inputs)
247            new_node.tag = copy(self.tag).__update__(new_node.tag)
248        else:
249            new_node = self.clone()
250            new_node.inputs = new_inputs
251        return new_node
252
253    def get_parents(self):
254        return list(self.inputs)
255
256    # convenience properties
257    nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)')
258    """
259    Property: Number of inputs.
260
261    """
262    nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)')
263    """
264    Property: Number of outputs.
265
266    """
267    params_type = property(lambda self: self.op.params_type, doc='type to use for the params')
268
269
270class Variable(Node):
271    """
272    A :term:`Variable` is a node in an expression graph that represents a
273    variable.
274
275    The inputs and outputs of every `Apply` (theano.gof.Apply) are `Variable`
276    instances. The input and output arguments to create a `function` are also
277    `Variable` instances. A `Variable` is like a strongly-typed variable in
278    some other languages; each `Variable` contains a reference to a `Type`
279    instance that defines the kind of value the `Variable` can take in a
280    computation.
281
282    A `Variable` is a container for four important attributes:
283
284    - :literal:`type` a `Type` instance defining the kind of value this
285      `Variable` can have,
286
287    - :literal:`owner` either None (for graph roots) or the `Apply` instance
288      of which `self` is an output,
289
290    - :literal:`index` the integer such that :literal:`owner.outputs[index] is
291      this_variable` (ignored if `owner` is None),
292
293    - :literal:`name` a string to use in pretty-printing and debugging.
294
295    There are a few kinds of Variables to be aware of: A Variable which is the
296    output of a symbolic computation has a reference to the Apply instance to
297    which it belongs (property: owner) and the position of itself in the owner's
298    output list (property: index).
299
300    - `Variable` (this base type) is typically the output of a symbolic
301      computation.
302
303    - `Constant` (a subclass) which adds a default and un-replaceable
304      :literal:`value`, and requires that owner is None.
305
306    - `TensorVariable` subclass of Variable that represents a numpy.ndarray
307       object.
308
309    - `TensorSharedVariable` Shared version of TensorVariable.
310
311    - `SparseVariable` subclass of Variable that represents
312      a scipy.sparse.{csc,csr}_matrix object.
313
314    - `GpuArrayVariable` subclass of Variable that represents our object on
315      the GPU that is a subset of numpy.ndarray.
316
317    - `RandomVariable`.
318
319    A Variable which is the output of a symbolic computation will have an owner
320    not equal to None.
321
322    Using the Variables' owner field and the Apply nodes' inputs fields, one can
323    navigate a graph from an output all the way to the inputs. The opposite
324    direction is not possible until a FunctionGraph has annotated the Variables
325    with the clients field, ie, before the compilation process has begun a
326    Variable does not know which Apply nodes take it as input.
327
328    Parameters
329    ----------
330    type : a Type instance
331        The type governs the kind of data that can be associated with this
332        variable.
333    owner : None or Apply instance
334        The Apply instance which computes the value for this variable.
335    index : None or int
336        The position of this Variable in owner.outputs.
337    name : None or str
338        A string for pretty-printing and debugging.
339
340    Examples
341    --------
342
343    .. code-block:: python
344
345        import theano
346        from theano import tensor
347
348        a = tensor.constant(1.5)        # declare a symbolic constant
349        b = tensor.fscalar()            # declare a symbolic floating-point scalar
350
351        c = a + b                       # create a simple expression
352
353        f = theano.function([b], [c])   # this works because a has a value associated with it already
354
355        assert 4.0 == f(2.5)            # bind 2.5 to an internal copy of b and evaluate an internal c
356
357        theano.function([a], [c])       # compilation error because b (required by c) is undefined
358
359        theano.function([a,b], [c])     # compilation error because a is constant, it can't be an input
360
361        d = tensor.value(1.5)           # create a value similar to the constant 'a'
362        e = d + b
363        theano.function([d,b], [e])     # this works.  d's default value of 1.5 is ignored.
364
365    The python variables :literal:`a,b,c` all refer to instances of type
366    `Variable`. The `Variable` referred to by `a` is also an instance of
367    `Constant`.
368
369    `compile.function` uses each `Apply` instance's `inputs` attribute together
370    with each Variable's `owner` field to determine which inputs are necessary
371    to compute the function's outputs.
372
373    """
374
375    # __slots__ = ['type', 'owner', 'index', 'name']
376    __count__ = count(0)
377
378    def __init__(self, type, owner=None, index=None, name=None):
379        super(Variable, self).__init__()
380
381        self.tag = utils.scratchpad()
382        self.type = type
383        if owner is not None and not isinstance(owner, Apply):
384            raise TypeError("owner must be an Apply instance", owner)
385        self.owner = owner
386        if index is not None and not isinstance(index, integer_types):
387            raise TypeError("index must be an int", index)
388        self.index = index
389        if name is not None and not isinstance(name, string_types):
390            raise TypeError("name must be a string", name)
391        self.name = name
392        self.auto_name = 'auto_' + str(next(self.__count__))
393
394        Variable.notify_construction_observers(self)
395
396    def __str__(self):
397        """Return a str representation of the Variable.
398
399        """
400        if self.name is not None:
401            return self.name
402        if self.owner is not None:
403            op = self.owner.op
404            if self.index == op.default_output:
405                return str(self.owner.op) + ".out"
406            else:
407                return str(self.owner.op) + "." + str(self.index)
408        else:
409            return "<%s>" % str(self.type)
410
411    def __repr_test_value__(self):
412        """Return a repr of the test value.
413
414        Return a printable representation of the test value. It can be
415        overridden by classes with non printable test_value to provide a
416        suitable representation of the test_value.
417        """
418        return repr(theano.gof.op.get_test_value(self))
419
420    def __repr__(self, firstPass=True):
421        """Return a repr of the Variable.
422
423        Return a printable name or description of the Variable. If
424        config.print_test_value is True it will also print the test_value if
425        any.
426        """
427        to_print = [str(self)]
428        if config.print_test_value and firstPass:
429            try:
430                to_print.append(self.__repr_test_value__())
431            except AttributeError:
432                pass
433        return '\n'.join(to_print)
434
435    def clone(self):
436        """
437        Return a new Variable like self.
438
439        Returns
440        -------
441        Variable instance
442            A new Variable instance (or subclass instance) with no owner or
443            index.
444
445        Notes
446        -----
447        Tags are copied to the returned instance.
448
449        Name is copied to the returned instance.
450
451        """
452        # return copy(self)
453        cp = self.__class__(self.type, None, None, self.name)
454        cp.tag = copy(self.tag)
455        return cp
456
457    def __lt__(self, other):
458        raise NotImplementedError('Subclasses of Variable must provide __lt__',
459                                  self.__class__.__name__)
460
461    def __le__(self, other):
462        raise NotImplementedError('Subclasses of Variable must provide __le__',
463                                  self.__class__.__name__)
464
465    def __gt__(self, other):
466        raise NotImplementedError('Subclasses of Variable must provide __gt__',
467                                  self.__class__.__name__)
468
469    def __ge__(self, other):
470        raise NotImplementedError('Subclasses of Variable must provide __ge__',
471                                  self.__class__.__name__)
472
473    def get_parents(self):
474        if self.owner is not None:
475            return [self.owner]
476        return []
477
478    def eval(self, inputs_to_values=None):
479        """
480        Evaluates this variable.
481
482        Parameters
483        ----------
484        inputs_to_values
485            A dictionary mapping theano Variables to values.
486
487        Examples
488        --------
489
490        >>> import numpy as np
491        >>> import theano.tensor as T
492        >>> x = T.dscalar('x')
493        >>> y = T.dscalar('y')
494        >>> z = x + y
495        >>> np.allclose(z.eval({x : 16.3, y : 12.1}), 28.4)
496        True
497
498        We passed :func:`eval` a dictionary mapping symbolic theano
499        variables to the values to substitute for them, and it returned
500        the numerical value of the expression.
501
502        Notes
503        -----
504
505        `eval` will be slow the first time you call it on a variable --
506        it needs to call :func:`function` to compile the expression behind
507        the scenes. Subsequent calls to :func:`eval` on that same variable
508        will be fast, because the variable caches the compiled function.
509
510        This way of computing has more overhead than a normal Theano
511        function, so don't use it too much in real scripts.
512        """
513
514        if inputs_to_values is None:
515            inputs_to_values = {}
516
517        if not hasattr(self, '_fn_cache'):
518            self._fn_cache = dict()
519
520        inputs = tuple(sorted(inputs_to_values.keys(), key=id))
521        if inputs not in self._fn_cache:
522            self._fn_cache[inputs] = theano.function(inputs, self)
523        args = [inputs_to_values[param] for param in inputs]
524
525        rval = self._fn_cache[inputs](*args)
526
527        return rval
528
529    def __getstate__(self):
530        d = self.__dict__.copy()
531        d.pop("_fn_cache", None)
532        if (not config.pickle_test_value) \
533                and (hasattr(self.tag, 'test_value')):
534            if not type(config).pickle_test_value.is_default:
535                warnings.warn("pickle_test_value is not defaut value (True).\n"
536                              "Test value of variable %s(%s) will not be dumped." % (d['auto_name'], d['name']))
537            t = copy(d["tag"])
538            del t.test_value
539            d["tag"] = t
540        return d
541
542    #  refer to doc in nodes_constructed.
543    construction_observers = []
544
545    @classmethod
546    def append_construction_observer(cls, observer):
547        cls.construction_observers.append(observer)
548
549    @classmethod
550    def remove_construction_observer(cls, observer):
551        cls.construction_observers.remove(observer)
552
553    @classmethod
554    def notify_construction_observers(cls, instance):
555        for observer in cls.construction_observers:
556            observer(instance)
557
558
559class Constant(Variable):
560    """
561    A :term:`Constant` is a `Variable` with a `value` field that cannot be
562    changed at runtime.
563
564    Constant nodes make eligible numerous optimizations: constant inlining in
565    C code, constant folding, etc.
566
567    Notes
568    -----
569    The data field is filtered by what is provided in the constructor for the
570    Constant's type field.
571
572    WRITEME
573
574    """
575
576    # __slots__ = ['data']
577    def __init__(self, type, data, name=None):
578        Variable.__init__(self, type, None, None, name)
579        self.data = type.filter(data)
580        utils.add_tag_trace(self)
581
582    def equals(self, other):
583        # this does what __eq__ should do, but Variable and Apply should always be hashable by id
584        return isinstance(other, Constant) and self.signature() == other.signature()
585
586    def signature(self):
587        return (self.type, self.data)
588
589    def merge_signature(self):
590        return self.signature()
591
592    def __str__(self):
593        if self.name is not None:
594            return self.name
595        else:
596            name = str(self.data)
597            if len(name) > 20:
598                name = name[:10] + '...' + name[-10:]
599            return 'Constant{%s}' % name
600
601    def clone(self):
602        """
603        We clone this object, but we don't clone the data to lower memory
604        requirement. We suppose that the data will never change.
605
606        """
607        cp = self.__class__(self.type, self.data, self.name)
608        cp.tag = copy(self.tag)
609        return cp
610
611    def __set_owner(self, value):
612        """
613        WRITEME
614
615        Raises
616        ------
617        ValueError
618            If `value` is not `None`.
619
620        """
621        if value is not None:
622            raise ValueError("Constant instances cannot have an owner.")
623
624    owner = property(lambda self: None, __set_owner)
625    value = property(lambda self: self.data, doc='read-only data access method')
626
627    # index is not defined, because the `owner` attribute must necessarily be None
628
629
630def stack_search(start, expand, mode='bfs', build_inv=False):
631    """
632    Search through a graph, either breadth- or depth-first.
633
634    Parameters
635    ----------
636    start : deque
637        Search from these nodes.
638    expand : callable
639        When we get to a node, add expand(node) to the list of nodes to visit.
640        This function should return a list, or None.
641    mode : string
642        'bfs' or 'dfs' for breath first search or depth first search.
643
644    Returns
645    -------
646    list of `Variable` or `Apply` instances (depends on `expend`)
647        The list of nodes in order of traversal.
648
649    Notes
650    -----
651    A node will appear at most once in the return value, even if it
652    appears multiple times in the start parameter.
653
654    :postcondition: every element of start is transferred to the returned list.
655    :postcondition: start is empty.
656
657    """
658
659    if mode not in ('bfs', 'dfs'):
660        raise ValueError('mode should be bfs or dfs', mode)
661    rval_set = set()
662    rval_list = list()
663    if mode == 'bfs':
664        start_pop = start.popleft
665    else:
666        start_pop = start.pop
667    expand_inv = {}  # var: clients
668    while start:
669        l = start_pop()
670        if id(l) not in rval_set:
671            rval_list.append(l)
672            rval_set.add(id(l))
673            expand_l = expand(l)
674            if expand_l:
675                if build_inv:
676                    for r in expand_l:
677                        expand_inv.setdefault(r, []).append(l)
678                start.extend(expand_l)
679    assert len(rval_list) == len(rval_set)
680    if build_inv:
681        return rval_list, expand_inv
682    return rval_list
683
684
685def ancestors(variable_list, blockers=None):
686    """
687    Return the variables that contribute to those in variable_list (inclusive).
688
689    Parameters
690    ----------
691    variable_list : list of `Variable` instances
692        Output `Variable` instances from which to search backward through
693        owners.
694
695    Returns
696    -------
697    list of `Variable` instances
698        All input nodes, in the order found by a left-recursive depth-first
699        search started at the nodes in `variable_list`.
700
701    """
702    def expand(r):
703        if r.owner and (not blockers or r not in blockers):
704            return reversed(r.owner.inputs)
705    dfs_variables = stack_search(deque(variable_list), expand, 'dfs')
706    return dfs_variables
707
708
709def inputs(variable_list, blockers=None):
710    """
711    Return the inputs required to compute the given Variables.
712
713    Parameters
714    ----------
715    variable_list : list of `Variable` instances
716        Output `Variable` instances from which to search backward through
717        owners.
718
719    Returns
720    -------
721    list of `Variable` instances
722        Input nodes with no owner, in the order found by a left-recursive
723        depth-first search started at the nodes in `variable_list`.
724
725    """
726    vlist = ancestors(variable_list, blockers)
727    rval = [r for r in vlist if r.owner is None]
728    return rval
729
730
731def variables_and_orphans(i, o):
732    """
733    Extract list of variables between i and o nodes via
734    dfs traversal and chooses the orphans among them
735
736    Parameters
737    ----------
738    i : list
739         Input variables.
740    o : list
741         Output variables.
742
743    """
744    def expand(r):
745        if r.owner and r not in i:
746            l = list(r.owner.inputs) + list(r.owner.outputs)
747            l.reverse()
748            return l
749    variables = stack_search(deque(o), expand, 'dfs')
750    orphans = [r for r in variables if r.owner is None and r not in i]
751    return variables, orphans
752
753
754def ops(i, o):
755    """
756    Set of Ops contained within the subgraph between i and o
757
758    Parameters
759    ----------
760    i : list
761        Input variables.
762    o : list
763        Output variables.
764
765    Returns
766    -------
767    object
768        The set of ops that are contained within the subgraph that lies
769        between i and o, including the owners of the variables in o and
770        intermediary ops between i and o, but not the owners of the variables
771        in i.
772
773    """
774    ops = set()
775    variables, orphans = variables_and_orphans(i, o)
776    for r in variables:
777        if r not in i and r not in orphans:
778            if r.owner is not None:
779                ops.add(r.owner)
780    return ops
781
782
783def variables(i, o):
784    """
785    Extracts list of variables within input and output nodes via dfs travesal
786
787    Parameters
788    ----------
789    i : list
790        Input variables.
791    o : list
792        Output variables.
793
794    Returns
795    -------
796    object
797        The set of Variables that are involved in the subgraph that lies
798        between i and o. This includes i, o, orphans(i, o) and all values of
799        all intermediary steps from i to o.
800
801    """
802    return variables_and_orphans(i, o)[0]
803
804
805def orphans(i, o):
806    """
807    Extracts list of variables within input and output nodes
808    via dfs travesal and returns the orphans among them
809
810    Parameters
811    ----------
812    i : list
813        Input Variables.
814    o : list
815        Output Variables.
816
817    Returns
818    -------
819    object
820        The set of Variables which one or more Variables in o depend on but are
821        neither in i nor in the subgraph that lies between i and o.
822
823    Examples
824    --------
825    orphans([x], [(x+y).out]) => [y]
826
827    """
828    return variables_and_orphans(i, o)[1]
829
830
831def clone(i, o, copy_inputs=True, copy_orphans=None):
832    """Copies the subgraph contained between i and o.
833
834    Parameters
835    ----------
836    i : list
837        Input Variables.
838    o : list
839        Output Variables.
840    copy_inputs : bool
841        If True, the inputs will be copied (defaults to True).
842    copy_orphans:
843        When None, use the copy_inputs value,
844        When True, new orphans nodes are created.
845        When False, original orphans nodes are reused in the new graph.
846
847    Returns
848    -------
849    object
850        The inputs and outputs of that copy.
851
852    Note
853    ----
854
855    A constant, if in the ``i`` list is not an orpha. So it will be
856    copied depending of the ``copy_inputs`` parameter. Otherwise it
857    will be copied depending of the ``copy_orphans`` parameter.
858
859    """
860    if copy_orphans is None:
861        copy_orphans = copy_inputs
862    equiv = clone_get_equiv(i, o, copy_inputs, copy_orphans)
863    return [equiv[input] for input in i], [equiv[output] for output in o]
864
865
866def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True,
867                    memo=None):
868    """
869    Return a dictionary that maps from Variable and Apply nodes in the
870    original graph to a new node (a clone) in a new graph.
871
872    This function works by recursively cloning inputs... rebuilding a directed
873    graph from the inputs up to eventually building new outputs.
874
875    Parameters
876    ----------
877    inputs : a list of Variables
878    outputs : a list of Variables
879    copy_inputs : bool
880        True means to create the cloned graph from new input
881        nodes (the bottom of a feed-upward graph).
882        False means to clone a graph that is rooted at the original input
883        nodes.
884    copy_orphans:
885        When True, new constant nodes are created. When False, original
886        constant nodes are reused in the new graph.
887    memo : None or dict
888        Optionally start with a partly-filled dictionary for the return value.
889        If a dictionary is passed, this function will work in-place on that
890        dictionary and return it.
891
892    """
893    if memo is None:
894        memo = {}
895
896    # clone the inputs if necessary
897    for input in inputs:
898        if copy_inputs:
899            cpy = input.clone()
900            cpy.owner = None
901            cpy.index = None
902            memo.setdefault(input, cpy)
903        else:
904            memo.setdefault(input, input)
905
906    # go through the inputs -> outputs graph cloning as we go
907    for apply in io_toposort(inputs, outputs):
908        for input in apply.inputs:
909            if input not in memo:
910                if copy_orphans:
911                    cpy = input.clone()
912                    memo[input] = cpy
913                else:
914                    memo[input] = input
915
916        new_apply = apply.clone_with_new_inputs([memo[i] for i in apply.inputs])
917        memo.setdefault(apply, new_apply)
918        for output, new_output in zip(apply.outputs, new_apply.outputs):
919            memo.setdefault(output, new_output)
920
921    # finish up by cloning any remaining outputs (it can happen)
922    for output in outputs:
923        if output not in memo:
924            memo[output] = output.clone()
925
926    return memo
927
928
929def general_toposort(outputs, deps, debug_print=False,
930                     compute_deps_cache=None, deps_cache=None,
931                     clients=None):
932    """
933    WRITEME
934
935    Parameters
936    ----------
937    deps
938        A python function that takes a node as input and returns its dependence.
939    compute_deps_cache : optional
940        If provided deps_cache should also be provided. This is a function like
941        deps, but that also cache its results in a dict passed as deps_cache.
942    deps_cache : dict
943        Must be used with compute_deps_cache.
944    clients : dict
945        If a dict is passed it will be filled with a mapping of node
946        -> clients for each node in the subgraph.
947
948    Notes
949    -----
950        deps(i) should behave like a pure function (no funny business with
951        internal state).
952
953        deps(i) will be cached by this function (to be fast).
954
955        The order of the return value list is determined by the order of nodes
956        returned by the deps() function.
957
958        deps should be provided or can be None and the caller provides
959        compute_deps_cache and deps_cache. The second option removes a Python
960        function call, and allows for more specialized code, so it can be
961        faster.
962
963    """
964    if compute_deps_cache is None:
965        deps_cache = {}
966
967        def compute_deps_cache(io):
968            if io not in deps_cache:
969                d = deps(io)
970                if d:
971                    if not isinstance(d, (list, OrderedSet)):
972                        raise TypeError(
973                            "Non-deterministic collections here make"
974                            " toposort non-deterministic.")
975                    deps_cache[io] = list(d)
976                else:
977                    deps_cache[io] = d
978                return d
979            else:
980                return deps_cache[io]
981    assert deps_cache is not None
982
983    assert isinstance(outputs, (tuple, list, deque))
984
985    reachable, _clients = stack_search(deque(outputs), compute_deps_cache,
986                                       'dfs', True)
987    if clients is not None:
988        clients.update(_clients)
989    sources = deque([r for r in reachable if not deps_cache.get(r, None)])
990
991    rset = set()
992    rlist = []
993    while sources:
994        node = sources.popleft()
995        if node not in rset:
996            rlist.append(node)
997            rset.add(node)
998            for client in _clients.get(node, []):
999                d = [a for a in deps_cache[client] if a is not node]
1000                deps_cache[client] = d
1001                if not d:
1002                    sources.append(client)
1003
1004    if len(rlist) != len(reachable):
1005        if debug_print:
1006            print('')
1007            print(reachable)
1008            print(rlist)
1009        raise ValueError('graph contains cycles')
1010
1011    return rlist
1012
1013
1014def io_toposort(inputs, outputs, orderings=None, clients=None):
1015    """
1016    Perform topological sort from input and output nodes
1017
1018    Parameters
1019    ----------
1020    inputs : list or tuple of Variable instances
1021    outputs : list or tuple of Apply instances
1022    orderings : dict
1023        Key: Apply instance. Value: list of Apply instance.
1024        It is important that the value be a container with a deterministic
1025        iteration order. No sets allowed!
1026    clients : dict
1027        If a dict is provided it will be filled with mappings of
1028        node->clients for each node in the subgraph that is sorted
1029
1030    """
1031    if not orderings and clients is None:  # ordering can be None or empty dict
1032        # Specialized function that is faster when more then ~10 nodes
1033        # when no ordering.
1034
1035        # Do a new stack implementation with the vm algo.
1036        # This will change the order returned.
1037        computed = set(inputs)
1038        todo = [o.owner for o in reversed(outputs) if o.owner]
1039        order = []
1040        while todo:
1041            cur = todo.pop()
1042            # We suppose that all outputs are always computed
1043            if cur.outputs[0] in computed:
1044                continue
1045            if all([i in computed or i.owner is None for i in cur.inputs]):
1046                computed.update(cur.outputs)
1047                order.append(cur)
1048            else:
1049                todo.append(cur)
1050                todo.extend(i.owner for i in cur.inputs if i.owner)
1051        return order
1052
1053    compute_deps = None
1054    compute_deps_cache = None
1055    iset = set(inputs)
1056    deps_cache = {}
1057
1058    if not orderings:  # ordering can be None or empty dict
1059        # Specialized function that is faster when no ordering.
1060        # Also include the cache in the function itself for speed up.
1061
1062        def compute_deps_cache(obj):
1063            if obj in deps_cache:
1064                return deps_cache[obj]
1065            rval = []
1066            if obj not in iset:
1067                if isinstance(obj, Variable):
1068                    if obj.owner:
1069                        rval = [obj.owner]
1070                elif isinstance(obj, Apply):
1071                    rval = list(obj.inputs)
1072                if rval:
1073                    if not isinstance(rval, (list, OrderedSet)):
1074                        raise TypeError(
1075                            "Non-deterministic collections here make"
1076                            " toposort non-deterministic.")
1077                    deps_cache[obj] = list(rval)
1078                else:
1079                    deps_cache[obj] = rval
1080            else:
1081                deps_cache[obj] = rval
1082            return rval
1083    else:
1084
1085        # the inputs are used only here in the function that decides what
1086        # 'predecessors' to explore
1087        def compute_deps(obj):
1088            rval = []
1089            if obj not in iset:
1090                if isinstance(obj, Variable):
1091                    if obj.owner:
1092                        rval = [obj.owner]
1093                elif isinstance(obj, Apply):
1094                    rval = list(obj.inputs)
1095                rval.extend(orderings.get(obj, []))
1096            else:
1097                assert not orderings.get(obj, None)
1098            return rval
1099
1100    topo = general_toposort(outputs, deps=compute_deps,
1101                            compute_deps_cache=compute_deps_cache,
1102                            deps_cache=deps_cache, clients=clients)
1103    return [o for o in topo if isinstance(o, Apply)]
1104
1105
1106default_leaf_formatter = str
1107
1108
1109def default_node_formatter(op, argstrings):
1110    return "%s(%s)" % (op.op, ", ".join(argstrings))
1111
1112
1113def io_connection_pattern(inputs, outputs):
1114    """
1115    Returns the connection pattern of a subgraph defined by given
1116    inputs and outputs.
1117
1118    """
1119    inner_nodes = io_toposort(inputs, outputs)
1120
1121    # Initialize 'connect_pattern_by_var' by establishing each input as
1122    # connected only to itself
1123    connect_pattern_by_var = {}
1124    nb_inputs = len(inputs)
1125
1126    for i in range(nb_inputs):
1127        input = inputs[i]
1128        inp_connection_pattern = [i == j for j in range(nb_inputs)]
1129        connect_pattern_by_var[input] = inp_connection_pattern
1130
1131    # Iterate through the nodes used to produce the outputs from the
1132    # inputs and, for every node, infer their connection pattern to
1133    # every input from the connection patterns of their parents.
1134    for n in inner_nodes:
1135
1136        # Get the connection pattern of the inner node's op. If the op
1137        # does not define a connection_pattern method, assume that
1138        # every node output is connected to every node input
1139        try:
1140            op_connection_pattern = n.op.connection_pattern(n)
1141        except AttributeError:
1142            op_connection_pattern = ([[True] * len(n.outputs)] *
1143                                     len(n.inputs))
1144
1145        # For every output of the inner node, figure out which inputs it
1146        # is connected to by combining the connection pattern of the inner
1147        # node and the connection patterns of the inner node's inputs.
1148        for out_idx in range(len(n.outputs)):
1149            out = n.outputs[out_idx]
1150            out_connection_pattern = [False] * nb_inputs
1151
1152            for inp_idx in range(len(n.inputs)):
1153                inp = n.inputs[inp_idx]
1154
1155                if inp in connect_pattern_by_var:
1156                    inp_connection_pattern = connect_pattern_by_var[inp]
1157
1158                    # If the node output is connected to the node input, it
1159                    # means it is connected to every inner input that the
1160                    # node inputs is connected to
1161                    if op_connection_pattern[inp_idx][out_idx]:
1162                        out_connection_pattern = [out_connection_pattern[i] or
1163                                                  inp_connection_pattern[i]
1164                                                  for i in range(nb_inputs)]
1165
1166            # Store the connection pattern of the node output
1167            connect_pattern_by_var[out] = out_connection_pattern
1168
1169    # Obtain the global connection pattern by combining the
1170    # connnection patterns of the individual outputs
1171    global_connection_pattern = [[] for o in range(len(inputs))]
1172    for out in outputs:
1173        out_connection_pattern = connect_pattern_by_var.get(out)
1174        if out_connection_pattern is None:
1175            # the output is completely isolated from inputs
1176            out_connection_pattern = [False] * len(inputs)
1177        for i in range(len(inputs)):
1178            global_connection_pattern[i].append(out_connection_pattern[i])
1179
1180    return global_connection_pattern
1181
1182
1183def is_same_graph(var1, var2, givens=None, debug=False):
1184    """
1185    Return True iff Variables `var1` and `var2` perform the same computation.
1186
1187    By 'performing the same computation', we mean that they must share the same
1188    graph, so that for instance this function will return False when comparing
1189    (x * (y * z)) with ((x * y) * z).
1190
1191    The current implementation is not efficient since, when possible, it
1192    verifies equality by calling two different functions that are expected to
1193    return the same output. The goal is to verify this assumption, to
1194    eventually get rid of one of them in the future.
1195
1196    Parameters
1197    ----------
1198    var1
1199        The first Variable to compare.
1200    var2
1201        The second Variable to compare.
1202    givens
1203        Similar to the `givens` argument of `theano.function`, it can be used
1204        to perform substitutions in the computational graph of `var1` and
1205        `var2`. This argument is associated to neither `var1` nor `var2`:
1206        substitutions may affect both graphs if the substituted variable
1207        is present in both.
1208    debug : bool
1209        If True, then an exception is raised when we are in a situation where
1210        the `equal_computations` implementation cannot be called.
1211        This parameter is intended to be used in tests only, to make sure we
1212        properly test both implementations.
1213
1214    Examples
1215    --------
1216
1217        ======  ======  ======  ======
1218        var1    var2    givens  output
1219        ======  ======  ======  ======
1220        x + 1   x + 1   {}      True
1221        x + 1   y + 1   {}      False
1222        x + 1   y + 1   {x: y}  True
1223        ======  ======  ======  ======
1224
1225    """
1226    # Lazy import.
1227    if givens is None:
1228        givens = {}
1229    global equal_computations, is_same_graph_with_merge
1230    if equal_computations is None:
1231        from theano.gof.opt import is_same_graph_with_merge
1232        from theano.scan_module.scan_utils import equal_computations
1233    # Convert `givens` to dictionary.
1234    if not isinstance(givens, dict):
1235        givens = dict(givens)
1236    # Get result from the merge-based function.
1237    rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
1238    # Get result from the function `equal_computations` from scan_utils.
1239
1240    use_equal_computations = True
1241    if givens:
1242        # We need to build the `in_xs` and `in_ys` lists. To do this, we need
1243        # to be able to tell whether a variable belongs to the computational
1244        # graph of `var1` or `var2`.
1245        # The typical case we want to handle is when `to_replace` belongs to
1246        # one of these graphs, and `replace_by` belongs to the other one. In
1247        # other situations, the current implementation of `equal_computations`
1248        # is probably not appropriate, so we do not call it.
1249        ok = True
1250        in_xs = []
1251        in_ys = []
1252        # Compute the sets of all variables found in each computational graph.
1253        inputs_var = list(map(inputs, ([var1], [var2])))
1254        all_vars = [set(variables(v_i, v_o))
1255                    for v_i, v_o in ((inputs_var[0], [var1]),
1256                                     (inputs_var[1], [var2]))]
1257
1258        def in_var(x, k):
1259            # Return True iff `x` is in computation graph of variable `vark`.
1260            return x in all_vars[k - 1]
1261
1262        for to_replace, replace_by in iteritems(givens):
1263            # Map a substitution variable to the computational graphs it
1264            # belongs to.
1265            inside = dict((v, [in_var(v, k) for k in (1, 2)])
1266                          for v in (to_replace, replace_by))
1267            if (inside[to_replace][0] and not inside[to_replace][1] and
1268                    inside[replace_by][1] and not inside[replace_by][0]):
1269                # Substitute variable in `var1` by one from `var2`.
1270                in_xs.append(to_replace)
1271                in_ys.append(replace_by)
1272            elif (inside[to_replace][1] and not inside[to_replace][0] and
1273                  inside[replace_by][0] and not inside[replace_by][1]):
1274                # Substitute variable in `var2` by one from `var1`.
1275                in_xs.append(replace_by)
1276                in_ys.append(to_replace)
1277            else:
1278                ok = False
1279                break
1280        if not ok:
1281            # We cannot directly use `equal_computations`.
1282            if debug:
1283                raise AssertionError(
1284                    'When `debug` is True we want to make sure we are also '
1285                    'using the `equal_computations` implementation')
1286            use_equal_computations = False
1287    else:
1288        in_xs = None
1289        in_ys = None
1290    if use_equal_computations:
1291        rval2 = equal_computations(xs=[var1], ys=[var2],
1292                                   in_xs=in_xs, in_ys=in_ys)
1293        assert rval2 == rval1
1294    return rval1
1295
1296
1297def op_as_string(i, op,
1298                 leaf_formatter=default_leaf_formatter,
1299                 node_formatter=default_node_formatter):
1300    """
1301    Op to return a string representation of the subgraph
1302    between i and o
1303    """
1304    strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
1305    return node_formatter(op, strs)
1306
1307
1308def as_string(i, o,
1309              leaf_formatter=default_leaf_formatter,
1310              node_formatter=default_node_formatter):
1311    """
1312    Returns a string representation of the subgraph between i and o
1313
1314    Parameters
1315    ----------
1316    i : list
1317        Input `Variable` s.
1318    o : list
1319        Output `Variable` s.
1320    leaf_formatter : callable
1321        Takes a `Variable`  and returns a string to describe it.
1322    node_formatter : callable
1323        Takes an `Op`  and the list of strings corresponding to its arguments
1324        and returns a string to describe it.
1325
1326    Returns
1327    -------
1328    str
1329        Returns a string representation of the subgraph between i and o. If the
1330        same op is used by several other ops, the first occurrence will be
1331        marked as :literal:`*n -> description` and all subsequent occurrences
1332        will be marked as :literal:`*n`, where n is an id number (ids are
1333        attributed in an unspecified order and only exist for viewing
1334        convenience).
1335
1336    """
1337    i = set(i)
1338
1339    orph = orphans(i, o)
1340
1341    multi = set()
1342    seen = set()
1343    for output in o:
1344        op = output.owner
1345        if op in seen:
1346            multi.add(op)
1347        else:
1348            seen.add(op)
1349    for op in ops(i, o):
1350        for input in op.inputs:
1351            op2 = input.owner
1352            if input in i or input in orph or op2 is None:
1353                continue
1354            if op2 in seen:
1355                multi.add(op2)
1356            else:
1357                seen.add(input.owner)
1358    multi = [x for x in multi]
1359    done = set()
1360
1361    def multi_index(x):
1362        return multi.index(x) + 1
1363
1364    def describe(r):
1365        if r.owner is not None and r not in i and r not in orph:
1366            op = r.owner
1367            idx = op.outputs.index(r)
1368            if len(op.outputs) == 1:
1369                idxs = ""
1370            else:
1371                idxs = "::%i" % idx
1372            if op in done:
1373                return "*%i%s" % (multi_index(op), idxs)
1374            else:
1375                done.add(op)
1376                s = node_formatter(op, [describe(input) for input in op.inputs])
1377                if op in multi:
1378                    return "*%i -> %s" % (multi_index(op), s)
1379                else:
1380                    return s
1381        else:
1382            return leaf_formatter(r)
1383
1384    return [describe(output) for output in o]
1385
1386
1387def view_roots(r):
1388    """
1389    Utility function that returns the leaves of a search through
1390    consecutive view_map()s.
1391
1392    WRITEME
1393
1394    """
1395    owner = r.owner
1396    if owner is not None:
1397        try:
1398            view_map = owner.op.view_map
1399            view_map = dict((owner.outputs[o], i)
1400                            for o, i in iteritems(view_map))
1401        except AttributeError:
1402            return [r]
1403        if r in view_map:
1404            answer = []
1405            for i in view_map[r]:
1406                answer += view_roots(owner.inputs[i])
1407            return answer
1408        else:
1409            return [r]
1410    else:
1411        return [r]
1412
1413
1414def list_of_nodes(inputs, outputs):
1415    """
1416    Return the apply nodes of the graph between inputs and outputs.
1417
1418    """
1419    return stack_search(
1420        deque([o.owner for o in outputs]),
1421        lambda o: [inp.owner for inp in o.inputs
1422                   if inp.owner and
1423                   not any(i in inp.owner.outputs for i in inputs)])
1424
1425
1426def is_in_ancestors(l_node, f_node):
1427    r"""
1428    Goes up in the graph and returns True if the apply node f_node is found.
1429
1430    Use a stack implementation as the vm algo.
1431    We suppose all nodes are not lazy
1432    (i.e. for IfElse we suppose all inputs are computed)
1433    """
1434    computed = set()
1435    todo = [l_node]
1436    while todo:
1437        cur = todo.pop()
1438        if cur.outputs[0] in computed:
1439            continue
1440        if all([i in computed or i.owner is None for i in cur.inputs]):
1441            computed.update(cur.outputs)
1442            if cur is f_node:
1443                return True
1444        else:
1445            todo.append(cur)
1446            todo.extend(i.owner for i in cur.inputs if i.owner)
1447    return False
1448
1449
1450@contextlib.contextmanager
1451def nodes_constructed():
1452    """
1453    A contextmanager that is used in inherit_stack_trace and keeps track
1454    of all the newly created varaible nodes inside an optimization. A list
1455    of new_nodes is instantiated but will be filled in a lazy manner (when
1456    Variable.notify_construction_observers is called).
1457
1458
1459    `observer` is the entity that updates the new_nodes list.
1460    construction_observers is a list inside Variable class and contains
1461    a list of observer functions. The observer functions inside
1462    construction_observers are only called when a variable node is
1463    instantiated (where Variable.notify_construction_observers is called).
1464    When the observer function is called, a new variable node is added to
1465    the new_nodes list.
1466
1467
1468    Parameters
1469    ----------
1470    new_nodes
1471        A list of all the variable nodes that are created inside the optimization.
1472
1473    yields
1474        new_nodes list.
1475    """
1476    new_nodes = []
1477
1478    def observer(node):
1479        new_nodes.append(node)
1480    Variable.append_construction_observer(observer)
1481    yield new_nodes
1482    Variable.remove_construction_observer(observer)
1483