1"""
2Provides `DebugMode`, an evaluation mode for debugging theano internals.
3
4TODO: add support for IfElse Op, LazyLinker, PureOp, etc.
5
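
Typical usage (an illustrative sketch; the mode name below follows the
standard Theano API, everything else in the snippet is just an example):

.. code-block:: python

    import theano
    import theano.tensor as T

    x = T.dvector('x')
    # Compiling in DebugMode makes every apply node run its Python and C
    # implementations (when both exist) and cross-check their results.
    f = theano.function([x], 2 * x, mode='DebugMode')
    f([1.0, 2.0])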
6"""
7from __future__ import absolute_import, print_function, division
8
9import copy
10import sys
11import gc
12import logging
13from itertools import chain, product as itertools_product
14from theano.compat import izip
15
16import numpy as np
17
18import theano
19from theano import gof, config
20from theano.compat import get_unbound_function
21from six import iteritems, itervalues
22from six.moves import StringIO, xrange
23from theano.gof import (graph, utils, link, ops_with_inner_function)
24from theano.gof.link import raise_with_op
25from theano.compile.function_module import (
26    FunctionMaker, Function, infer_reuse_pattern,
27    std_fgraph)
28from theano.compile.mode import Mode, register_mode
29from theano.compile.ops import OutputGuard, _output_guard
30from theano import change_flags
31
32
33__docformat__ = "restructuredtext en"
34_logger = logging.getLogger("theano.compile.debugmode")
35
36
37# Filter to avoid duplicating optimization warnings
38class NoDuplicateOptWarningFilter(logging.Filter):
39    prev_msgs = set([])
40
41    def filter(self, record):
42        msg = record.getMessage()
43        if msg.startswith('Optimization Warning: '):
44            if msg in self.prev_msgs:
45                return False
46            else:
47                self.prev_msgs.add(msg)
48                return True
49        return True
50
51_logger.addFilter(NoDuplicateOptWarningFilter())
52
53
54########################
55#
56# Exceptions
57#
58########################
59class DebugModeError(Exception):
60    """
61    Generic Exception raised to indicate an internal theano problem.
62
63    """
64
65    pass
66
67
class BadThunkOutput(DebugModeError):
    """
    Exception: Calling the same Op twice gives inconsistent outputs.

    It can be raised, for instance, if an Op's c_code and perform
    implementations do not agree, or if one of these methods does not give
    the same result when called twice with the same inputs (but different
    memory layouts for the output).

    """

    r = None
    """
    The `Variable` instance for which conflicting values were computed.

    """

    thunk1 = ''
    val1 = None
    """
    The value computed by `thunk1`.

    """

    thunk2 = ''
    val2 = None
    """
    The value computed by `thunk2`.

    """

    def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()):
        super(BadThunkOutput, self).__init__()
        self.r = r
        self.thunk1 = thunk1
        self.val1 = val1
        self.thunk2 = thunk2
        self.val2 = val2
        self.inputs_val = inputs_val

    def offending_op(self):
        """
        Return the Op class whose c_code and perform implementations
        didn't match.

        """
        return type(self.r.owner.op)

    def __str__(self):
        return self.str_diagnostic()

    def str_diagnostic(self):
        """
        Return a pretty multiline string representing the cause of
        the exception.

        """
        sio = StringIO()
        print("BadThunkOutput", file=sio)
        print("  Apply   :", self.r.owner, file=sio)
        print("  op      :", self.offending_op(), file=sio)
        print("  Outputs Type:", self.r.type, file=sio)
        print("  Outputs Shape:", getattr(self.val1, 'shape', None), file=sio)
        print("  Outputs Strides:", getattr(self.val1, 'strides', None),
              file=sio)
        print("  Inputs Type :", [i.type for i in self.r.owner.inputs],
              file=sio)
        print("  Inputs Shape:", [getattr(val, 'shape', None)
                                  for val in self.inputs_val], file=sio)
        print("  Inputs Strides:", [getattr(val, 'strides', None)
                                    for val in self.inputs_val], file=sio)
        scalar_values = []
        for ipt in self.inputs_val:
            if getattr(ipt, "size", -1) <= 10:
                scalar_values.append(ipt)
            else:
                scalar_values.append("not shown")
        print("  Inputs values: %s" % scalar_values, file=sio)
        print("  Bad Variable:", self.r, file=sio)
        print("  thunk1  :", self.thunk1, file=sio)
        print("  thunk2  :", self.thunk2, file=sio)

        # Don't import it at the top of the file to prevent circular import.
        import theano.tests.unittest_tools as utt
        print(utt.str_diagnostic(self.val1, self.val2, None, None), file=sio)
        ret = sio.getvalue()
        return ret

class BadOptimization(DebugModeError, theano.gof.toolbox.BadOptimization):
    pass


class BadDestroyMap(DebugModeError):
    """
    Exception: Some perform() or c_code() modified an input that
    wasn't in the destroy_map.

    """
    def __init__(self, node, idx, old_val, new_val, perform):
        super(BadDestroyMap, self).__init__()
        self.node = node
        self.idx = idx
        self.old_val = old_val
        self.new_val = new_val
        self.perform = perform

    def __str__(self):
        sio = StringIO()
        print("  node:", self.node, file=sio)
        print("  perform:", self.perform, file=sio)
        print("  node.inputs:", [(str(i), id(i))
                                 for i in self.node.inputs], file=sio)
        print("  destroy_map:", getattr(self.node.op,
                                        'destroy_map', {}), file=sio)
        print("  changed input idx:", self.idx, file=sio)
        print("  changed input type:", self.node.inputs[self.idx].type,
              file=sio)
        print("  repr (old val):", repr(self.old_val), file=sio)
        print("  repr (new val):", repr(self.new_val), file=sio)
        try:
            npy_old_val = np.asarray(self.old_val)
            npy_new_val = np.asarray(self.new_val)
            print("  value dtype (new <space> old):", npy_new_val.dtype,
                  npy_old_val.dtype, file=sio)
            print("  value shape (new <space> old):", npy_new_val.shape,
                  npy_old_val.shape, file=sio)
            print("  value min (new <space> old):", npy_new_val.min(),
                  npy_old_val.min(), file=sio)
            print("  value max (new <space> old):", npy_new_val.max(),
                  npy_old_val.max(), file=sio)
            delta = npy_new_val - npy_old_val
            print("  value min (new-old):", delta.min(), file=sio)
            print("  value max (new-old):", delta.max(), file=sio)
            print("  value argmin (new-old):",
                  np.unravel_index(delta.argmin(), npy_new_val.shape),
                  file=sio)
            print("  value argmax (new-old):",
                  np.unravel_index(delta.argmax(), npy_new_val.shape),
                  file=sio)
            print("  location of first 10 mismatches:",
                  np.transpose(np.nonzero(delta))[:10], file=sio)
            print("", file=sio)
        except Exception as e:
            print("(Numpy-hints failed with: %s)" % str(e), file=sio)
        print("  Hint: this can also be caused by a deficient "
              "values_eq_approx() or __eq__() implementation "
              "[which compared input values]", file=sio)
        return sio.getvalue()


class BadViewMap(DebugModeError):
    """
    Exception: Some perform() or c_code() created a memory alias
    that wasn't in the view_map.

    """

    def __init__(self, node, output_idx, out_storage,
                 in_alias_idx=None, out_alias_idx=None):
        super(BadViewMap, self).__init__()
        self.node = node
        self.output_idx = output_idx
        self.out_storage = out_storage
        self.in_alias_idx = in_alias_idx
        self.out_alias_idx = out_alias_idx

    def __str__(self):
        sio = StringIO()
        print("  node:", self.node, file=sio)
        print("  node.inputs:", [(str(i), id(i))
                                 for i in self.node.inputs], file=sio)
        print("  node.outputs:", [(str(i), id(i))
                                  for i in self.node.outputs], file=sio)
        print("  view_map:", getattr(self.node.op, 'view_map', {}), file=sio)
        print("  destroy_map:", getattr(self.node.op,
                                        'destroy_map', {}), file=sio)
        print("  aliased output:", self.output_idx, file=sio)
        print("  aliased output storage:", self.out_storage, file=sio)
        if self.in_alias_idx:
            print("  aliased to inputs:", self.in_alias_idx, file=sio)
        if self.out_alias_idx:
            print("  aliased to outputs:", self.out_alias_idx, file=sio)
        return sio.getvalue()


class StochasticOrder(DebugModeError):
    """
    Exception: Repeated Optimizations of the same graph do not give
    identical results.

    The most common cause is that an Optimization iterates over some
    objects in a memory-address-dependent order (such as id() or
    object.hash()).  If you see this error and you think it is related
    to optimizations within Theano, email theano-dev with the message
    attached to this exception.

    """
    pass

class InvalidValueError(DebugModeError):
    """
    Exception: some Op produced an output value that is inconsistent with
    the Type of that output.

    Note: If there is only one parameter and it is a string, then we
    will use it as the error message. This is needed when we catch,
    extend, and reraise an error.
    """

    def __init__(self, r, v=None, client_node=None, hint='none',
                 specific_hint='none'):
        super(InvalidValueError, self).__init__()
        self.r = r
        self.v = v
        self.client_node = client_node
        self.hint = hint
        self.specific_hint = specific_hint

        # To allow extending the error message of an existing error.
        self.full_err = None
        if isinstance(r, str):
            assert (v is None and
                    client_node is None and
                    hint == 'none' and
                    specific_hint == 'none')
            self.full_err = r
    def __str__(self):
        # We have a pre-made message
        if getattr(self, 'full_err', None) is not None:
            return self.full_err

        r, v = self.r, self.v
        type_r = r.type
        type_v = type(v)
        v_val = str(v)[0:100]
        v_dtype = 'N/A'
        v_shape = 'N/A'
        v_min = 'N/A'
        v_max = 'N/A'
        v_isfinite = 'N/A'
        try:
            v_shape = v.shape
            v_dtype = v.dtype
            v_min = v.min()
            v_max = v.max()
            v_isfinite = np.all(np.isfinite(v))
        except Exception:
            pass
        client_node = self.client_node
        hint = self.hint
        specific_hint = self.specific_hint
        context = debugprint(r, prefix='  ', depth=12,
                             file=StringIO()).getvalue()
        return """InvalidValueError
        type(variable) = %(type_r)s
        variable       = %(r)s
        type(value)    = %(type_v)s
        dtype(value)   = %(v_dtype)s
        shape(value)   = %(v_shape)s
        value          = %(v_val)s
        min(value)     = %(v_min)s
        max(value)     = %(v_max)s
        isfinite       = %(v_isfinite)s
        client_node    = %(client_node)s
        hint           = %(hint)s
        specific_hint  = %(specific_hint)s
        context        = ...\n%(context)s
        """ % locals()

########################
#
# Private Functions
#
########################


def char_from_number(number):
    """
    Converts number to string by rendering it in base 26 using
    capital letters as digits.
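
    For illustration (doctest-style; the values assume the base-26
    rendering described above, with 'A' playing the role of 0):

    >>> char_from_number(0)
    'A'
    >>> char_from_number(1)
    'B'
    >>> char_from_number(27)
    'BB'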
352    """
353
354    base = 26
355
356    rval = ""
357
358    if number == 0:
359        rval = 'A'
360
361    while number != 0:
362        remainder = number % base
363        new_char = chr(ord('A') + remainder)
364        rval = new_char + rval
365        number //= base
366
367    return rval
368
369
def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
               file=sys.stdout, print_destroy_map=False,
               print_view_map=False, order=None, ids='CHAR',
               stop_on_name=False, prefix_child=None,
               scan_ops=None, profile=None,
               scan_inner_to_outer_inputs=None, smap=None,
               used_ids=None, print_clients=False):
    """
    Print the graph leading to `r` to the given depth.

    Parameters
    ----------
    r
        Variable instance.
    prefix
        Prefix to each line (typically some number of spaces).
    depth
        Maximum recursion depth (Default -1 for unlimited).
    done
        Internal. Used to pass information when recursing.
        Dict of Apply instances that have already been printed and their
        associated printed ids.
    print_type
        Whether to print the Variable type after the other infos.
    file
        File-like object to which to print.
    print_destroy_map
        Whether to print the op destroy_map after other info.
    print_view_map
        Whether to print the op view_map after other info.
    order
        If not empty, will print the index in the toposort.
    ids
        How do we print the identifier of the variable:
        id - print the python id value,
        int - print a sequential integer,
        CHAR - print capital characters,
        "" - don't print an identifier.
    stop_on_name
        When True, if a node in the graph has a name, we don't print anything
        below it.
    scan_ops
        Scan ops in the graph will be added inside this list for later
        printing purposes.
    scan_inner_to_outer_inputs
        A dictionary mapping a scan op's inner function inputs to the scan op
        inputs (outer inputs), for printing purposes.
    smap
        None or the storage_map when printing a Theano function.
    used_ids
        Internal. Used to pass information when recursing.
        It is a dict from obj to the id used for it; that id may not have
        been printed itself, but at least a reference to it was printed.
    print_clients
        If True, we will print the clients of nodes when they have more than
        one client.
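
    Examples
    --------
    Illustrative sketch only (the output shape depends on the graph and the
    options above; the identifiers vary from run to run):

    >>> import theano.tensor as T
    >>> x = T.dvector('x')
    >>> _ = debugprint(x + 1)  # doctest: +SKIP
    Elemwise{add,no_inplace} [id A] ''
     |x [id B]
     |DimShuffle{x} [id C] ''
       |TensorConstant{1} [id D]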
425    """
426    if depth == 0:
427        return
428
429    if order is None:
430        order = []
431
432    if done is None:
433        done = dict()
434
435    if scan_ops is None:
436        scan_ops = []
437
438    if print_type:
439        type_str = ' <%s>' % r.type
440    else:
441        type_str = ''
442
443    if prefix_child is None:
444        prefix_child = prefix
445
446    if used_ids is None:
447        used_ids = dict()
448
449    def get_id_str(obj, get_printed=True):
450        if obj in used_ids:
451            id_str = used_ids[obj]
452        elif obj == 'output':
453            id_str = 'output'
454        elif ids == "id":
455            id_str = "[id %s]" % str(id(r))
456        elif ids == "int":
457            id_str = "[id %s]" % str(len(used_ids))
458        elif ids == "CHAR":
459            id_str = "[id %s]" % char_from_number(len(used_ids))
460        elif ids == "":
461            id_str = ""
462        if get_printed:
463            done[obj] = id_str
464        used_ids[obj] = id_str
465        return id_str
466
467    if hasattr(r.owner, 'op'):
468        # this variable is the output of computation,
469        # so just print out the apply
470        a = r.owner
471
472        r_name = getattr(r, 'name', '')
473        # normally if the name isn't set, it'll be None, so
474        # r_name is None here
475        if r_name is None:
476            r_name = ''
477
478        if print_destroy_map:
479            destroy_map_str = str(getattr(r.owner.op, 'destroy_map', ''))
480        else:
481            destroy_map_str = ''
482
483        if print_view_map:
484            view_map_str = str(getattr(r.owner.op, 'view_map', ''))
485        else:
486            view_map_str = ''
487        if destroy_map_str and destroy_map_str != '{}':
488            destroy_map_str = 'd=' + destroy_map_str
489        if view_map_str and view_map_str != '{}':
490            view_map_str = 'v=' + view_map_str
491
492        o = ''
493        if order:
494            o = str(order.index(r.owner))
495
496        already_printed = a in done  # get_id_str put it in the dict
497        id_str = get_id_str(a)
498
499        if len(a.outputs) == 1:
500            idx = ""
501        else:
502            idx = ".%i" % a.outputs.index(r)
503        data = ""
504        if smap:
505            data = " " + str(smap.get(a.outputs[0], ''))
506        clients = ''
507        if print_clients and len(getattr(r, 'clients', [])) > 1:
508            def get_index(c):
509                try:
510                    return order.index(c)
511                except ValueError:
512                    return ""
513            clients = " clients:" + str([(get_id_str(c, False), get_index(c))
514                                         for c, i in r.clients])
515        if profile is None or a not in profile.apply_time:
516            print('%s%s%s %s%s \'%s\' %s %s %s%s%s' % (prefix, a.op,
517                                                       idx,
518                                                       id_str, type_str,
519                                                       r_name,
520                                                       destroy_map_str,
521                                                       view_map_str,
522                                                       o, data, clients), file=file)
523        else:
524            op_time = profile.apply_time[a]
525            op_time_percent = (op_time / profile.fct_call_time) * 100
526            tot_time_dict = profile.compute_total_times()
527            tot_time = tot_time_dict[a]
528            tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
529
530            if len(a.outputs) == 1:
531                idx = ""
532            else:
533                idx = ".%i" % a.outputs.index(r)
534            print("%s%s%s %s%s '%s' %s %s %s%s%s --> "
535                  "%8.2es %4.1f%% %8.2es %4.1f%%"
536                  % (prefix, a.op,
537                     idx,
538                     id_str, type_str,
539                     r_name,
540                     destroy_map_str,
541                     view_map_str,
542                     o, data, clients,
543                     op_time,
544                     op_time_percent,
545                     tot_time,
546                     tot_time_percent), file=file)
547
548        if not already_printed:
549            if (not stop_on_name or
550                    not (hasattr(r, 'name') and r.name is not None)):
551                new_prefix = prefix_child + ' |'
552                new_prefix_child = prefix_child + ' |'
553
554                for idx, i in enumerate(a.inputs):
555                    if idx == len(a.inputs) - 1:
556                        new_prefix_child = prefix_child + '  '
557
558                    if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
559                        if isinstance(i.owner.op,
560                                      theano.scan_module.scan_op.Scan):
561                            scan_ops.append(i)
562
563                    debugprint(
564                        i, new_prefix, depth=depth - 1, done=done,
565                        print_type=print_type, file=file, order=order,
566                        ids=ids, stop_on_name=stop_on_name,
567                        prefix_child=new_prefix_child, scan_ops=scan_ops,
568                        profile=profile,
569                        scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
570                        smap=smap, used_ids=used_ids, print_clients=print_clients)
571    else:
572        if scan_inner_to_outer_inputs is not None and\
573           r in scan_inner_to_outer_inputs:
574
575            id_str = get_id_str(r)
576            outer_r = scan_inner_to_outer_inputs[r]
577
578            if hasattr(outer_r.owner, 'op'):
579                outer_id_str = get_id_str(outer_r.owner)
580            else:
581                outer_id_str = get_id_str(outer_r)
582            print('%s%s %s%s -> %s' % (prefix, r, id_str, type_str,
583                                       outer_id_str), file=file)
584        else:
585            # this is an input variable
586            data = ""
587            if smap:
588                data = " " + str(smap.get(r, ''))
589            id_str = get_id_str(r)
590            print('%s%s %s%s%s' % (prefix, r, id_str,
591                                   type_str, data),
592                  file=file)
593
594    return file
595

def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
    """
    Create a FunctionGraph for debugging.

    Parameters
    ----------
    input_specs: WRITEME
        fgraph inputs.
    output_specs: WRITEME
        fgraph outputs.
    accept_inplace : bool
        Are inplace ops permitted in the original graph?

    Returns
    -------
    FunctionGraph
        A new FunctionGraph with a cloned graph, with debugging `Feature`
        instances already installed.

    """
    equivalence_tracker = _VariableEquivalenceTracker()
    fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
    fgraph.attach_feature(equivalence_tracker)
    return fgraph, updates, equivalence_tracker

class DataDestroyed():
    # This is a singleton class. We put it in the storage_map when a
    # variable's value has been destroyed, to prevent reusing that bad
    # value for it.
    pass

data_destroyed = DataDestroyed()


def check_eq(var, val1, val2):
    if hasattr(var.tag, 'values_eq_approx'):
        return var.tag.values_eq_approx(val1, val2)
    else:
        return var.type.values_eq_approx(val1, val2)


def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
                  clobber_dr_vals=True,
                  perform=None, warn_input_not_reused=True):
    """
    Raise BadDestroyMap if necessary, update dr_vals.

    Returns a list of output variables that actually worked inplace
    (their value is aliased to the value of at least one input).

    It modifies the storage_map to remove node.inputs variables that have
    been destroyed.
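
    For reference, this is the Theano convention the check relies on
    (illustrative, not specific to this function): an Op that overwrites
    its input 0 in order to produce its output 0 declares::

        destroy_map = {0: [0]}
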
651    """
652    destroyed_idx_list = []
653    destroy_map = getattr(node.op, 'destroy_map', {})
654    for o_pos, i_pos_list in iteritems(destroy_map):
655        destroyed_idx_list.extend(i_pos_list)
656    destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
657
658    actually_inplace_outputs = []
659    dmap = getattr(node.op, 'destroy_map', {})
660    for oo, ii in iteritems(dmap):
661        var = node.outputs[oo]
662        out_var = storage_map[var][0]
663        in_var = storage_map[node.inputs[ii[0]]][0]
664        if (hasattr(var.type, 'may_share_memory') and
665                var.type.may_share_memory(out_var, in_var)):
666            actually_inplace_outputs.append(node.outputs[oo])
667
668        if warn_input_not_reused and destroyed_res_list:
669            if isinstance(node.op, OutputGuard):
670                # The point of OutputGuard is to be declared as destructive
671                # while not destroying anything
672                continue
673            if out_var is not in_var:
674                _logger.warning("Optimization Warning: input idx %d marked "
675                                "as destroyed was not changed for node '%s'",
676                                ii[0], str(node))
677
678    vmap = getattr(node.op, 'view_map', {})
679    for oo, ii in iteritems(vmap):
680        var = node.outputs[oo]
681        out_var = storage_map[var][0]
682        in_var = storage_map[node.inputs[ii[0]]][0]
683        may_share = (hasattr(var.type, 'may_share_memory') and
684                     var.type.may_share_memory(out_var, in_var))
685        if may_share:
686            actually_inplace_outputs.append(node.outputs[oo])
687
688        if warn_input_not_reused:
689            # We don't try to optimize simple scalar and empty ndarray,
690            # as this is not worth our time. This happen at least in
691            # Subtensor when the output is a scalar But this depend on
692            # the version of numpy!
693            if getattr(out_var, 'size', 2) <= 1:
694                continue
695            if isinstance(node.op, OutputGuard):
696                # This class is not in the final graph.
697                continue
698            if not may_share:
699                _logger.warning("Optimization Warning: input idx %d marked "
700                                "as viewed but new memory allocated by node "
701                                "'%s'", ii[0], str(node))
702
703    for r_idx, r in enumerate(node.inputs):
704        if not r.type.values_eq(r_vals[r], storage_map[r][0]):
705            # some input node 'r' got changed by running the node
706            # this may or may not be ok...
707            if r in destroyed_res_list:
708                # ok, we expected r to be destroyed
709                if node in active_nodes:
710                    if dr_vals.get(r, (0, node))[1] is not node:
711                        # bad: there should only be one active node
712                        # that destroys any variable
713                        raise Exception('failure in topological ordering')
714                    if clobber_dr_vals:
715                        # no copy, this is the last use of this variable
716                        dr_vals[r] = (storage_map[r][0], node)
717                    # make sure that dr_vals[r] doens't get used again
718                    storage_map[r][0] = data_destroyed
719            else:
720                raise BadDestroyMap(node, r_idx, r_vals[r],
721                                    storage_map[r][0], perform)
722
723    return actually_inplace_outputs
724

def _check_viewmap(node, storage_map):
    """
    This function raises a BadViewMap exception when it detects the
    following:
    - Output node storages aliased to input storage, with no declaration
      in view_map (see the declaration example below).
    - If not aliased to an input, check if two outputs are aliased together
      and used subsequently in the graph.
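
    For reference, the Theano convention being checked is (illustrative,
    not specific to this function): an Op whose output 0 is a view of its
    input 0 declares::

        view_map = {0: [0]}
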
735    """
736
737    for oi, onode in enumerate(node.outputs):
738
739        good_alias, bad_alias = {}, {}
740        outstorage = storage_map[onode][0]
741
742        # first find out which input it aliases
743        view_map = getattr(node.op, 'view_map', {})
744        destroy_map = getattr(node.op, 'destroy_map', {})
745
746        # In theory, theano's view_map only allows for 1 output to
747        # alias 1 input. Checking for multiple aliases just in
748        # case...
749
750        for ii, inode in enumerate(node.inputs):
751            in_storage = storage_map[inode][0]
752            if in_storage is data_destroyed:
753                # If the input have been destroyed, it can't be a
754                # view. So no need to check. Also, we don't have the
755                # original value, we we wouldn't be able to do this
756                # useless check.
757                continue
758            if (hasattr(inode.type, 'may_share_memory') and
759                    inode.type.may_share_memory(outstorage, in_storage)):
760
761                nodeid = id(inode)
762                bad_alias[nodeid] = ii
763
764                # check that the aliasing was declared in [view|destroy]_map
765                if ([ii] == view_map.get(oi, None) or
766                        [ii] == destroy_map.get(oi, None)):
767
768                    good_alias[nodeid] = bad_alias.pop(nodeid)
769
770        # TODO: make sure this is correct
771        # According to OB, duplicate inputs are rejected on build graph time
772        # if they cause problems. So if they are here it should be ok.
773        for key, val in iteritems(good_alias):
774            bad_alias.pop(key, None)
775        if bad_alias:
776            raise BadViewMap(node, oi, outstorage, list(bad_alias.values()))
777
778        # if its not aliased to input, check output->output aliasing
779        if not good_alias and _is_used_in_graph(onode):
780            for other_oi, other_onode in enumerate(node.outputs):
781                if other_oi == oi:
782                    continue
783
784                other_storage = storage_map[other_onode][0]
785                # check to see if we share memory with this other output
786                # this is not a problem if the node is not actually used
787                if (_is_used_in_graph(other_onode) and
788                    hasattr(other_onode.type, 'may_share_memory') and
789                    other_onode.type.may_share_memory(outstorage,
790                                                      other_storage)):
791                    raise BadViewMap(node, oi, outstorage,
792                                     out_alias_idx=other_oi)
793
794
795def _is_used_in_graph(var):
796    """
797
798    Returns
799    -------
800    bool
801        True if `var` is used by another node in the graph.
802
803    """
804    return not(var.clients == [('output', 1)] or var.clients == [])
805

def _check_strides_match(a, b, warn_err, op):
    """

    Parameters
    ----------
    warn_err
        If 0, no warning, if 1 warning, if 2 error.

    """
    if warn_err == 0:
        return

    try:
        strides_eq = a.strides == b.strides
    except Exception:
        return  # no strides

    if not strides_eq:
        e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides,
                                          b.strides, str(op)))
        if warn_err == 2:
            raise e
        else:
            print('WARNING:', e, file=sys.stderr)

def _lessbroken_deepcopy(a):
    """

    Parameters
    ----------
    a
        Any object

    Returns
    -------
    object
        A copy of `a` that shares no internal storage with the original
        (a deep copy). This function handles numpy arrays specially, because
        copy.deepcopy() called on a 0-d array will return a numpy scalar,
        not an array.
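
    For illustration (doctest-style; only the behaviour of this helper is
    shown, since what ``copy.deepcopy`` does with 0-d arrays depends on the
    numpy version):

    >>> import numpy as np
    >>> a = np.array(1.5)
    >>> b = _lessbroken_deepcopy(a)
    >>> type(b) is np.ndarray and b is not a
    True
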
849    """
850    # this exists because copy.deepcopy on numpy arrays is broken
851    # This logic is also in link.py
852    from theano.gof.type import _cdata_type
853    if type(a) in (np.ndarray, np.memmap):
854        rval = a.copy(order='K')
855    elif type(a) is _cdata_type:
856        # This is not copyable (and should be used for constant data).
857        rval = a
858    else:
859        rval = copy.deepcopy(a)
860
861    assert type(rval) == type(a), (type(rval), type(a))
862    if isinstance(rval, np.ndarray):
863        assert rval.dtype == a.dtype
864    return rval
865

def _find_bad_optimizations0(order, reasons, r_vals):
    """
    Use a simple algorithm to find broken optimizations.

    This algorithm is simple to understand, but sometimes when there's
    a problem it identifies the wrong optimization as the culprit.
    The problem stems from the fact that results are not evaluated in
    chronological order (looking at when they were introduced to the
    graph).

    """
    # iterate over variables looking for values that don't match the
    # values of the variables they replaced.  This is the sign of a
    # broken optimization.
    for i, node in enumerate(order):
        for new_r in node.outputs:
            for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
                # check if the value for new_r doesn't match the value for r
                new_r_val = r_vals[new_r]
                r_val = r_vals[r]
                assert r.type == new_r.type

                if hasattr(new_r.tag, 'values_eq_approx'):
                    check = new_r.tag.values_eq_approx(r_val, new_r_val)
                elif hasattr(new_r, 'values_eq_approx'):
                    # This way will be deprecated later, but not right now
                    check = new_r.values_eq_approx(r_val, new_r_val)
                else:
                    check = r.type.values_eq_approx(r_val, new_r_val)
                if not check:
                    raise BadOptimization(old_r=r,
                                          new_r=new_r,
                                          old_r_val=r_val,
                                          new_r_val=new_r_val,
                                          reason=reason,
                                          old_graph=old_graph_str,
                                          new_graph=new_graph_str)


def _find_bad_optimizations1(order, reasons, r_vals):
    # iterate over variables looking for values that don't match the
    # values of the variables they replaced.  This is the sign of a
    # broken optimization.

    # identify sets of variables that are supposed to be equivalent
    equivalence_sets = {}
    program_position = {}  # node -> order idx

    for i, node in enumerate(order):
        program_position[node] = i
        for new_r in node.outputs:
            equivalence_sets.setdefault(new_r, set([new_r]))
            for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
                equivalence_sets[new_r].update(equivalence_sets.setdefault(
                    r, set([r])))
                for er in equivalence_sets[r]:
                    equivalence_sets[er] = equivalence_sets[new_r]

    # identify equivalence sets that are broken
    equivalence_sets_broken = {}  # id(set) -> Bool
    there_is_a_problem = False
    for r, r_equiv in iteritems(equivalence_sets):
        if id(r_equiv) not in equivalence_sets_broken:
            equivalence_sets_broken[id(r_equiv)] = False
            # loop over the variables in the set comparing them to be
            # equal enough
            re0 = None
            for re in r_equiv:
                if re0:
                    new_r_val = r_vals[re]
                    r_val = r_vals[re0]
                    assert re.type == re0.type
                    if not re.type.values_eq_approx(r_val, new_r_val):
                        equivalence_sets_broken[id(r_equiv)] = True
                        there_is_a_problem = True
                re0 = re

    if there_is_a_problem:
        # which broken equivalence set has the earliest-occurring element?
        first_broken_set = None
        for i, node in enumerate(order):
            for r in node.outputs:
                r_equiv = equivalence_sets[r]
                if equivalence_sets_broken[id(r_equiv)]:
                    first_broken_set = r_equiv
        # TODO finish this to produce good diagnostic information
        print(first_broken_set)
        raise Exception('broken')


def _find_bad_optimizations2(order, reasons, r_vals):
    """
    Use a simple algorithm to find broken optimizations.

    This algorithm is simple to understand, but sometimes when there's
    a problem it identifies the wrong optimization as the culprit.
    The problem stems from the fact that results are not evaluated in
    chronological order (looking at when they were introduced to the
    graph).

    """

    checked_variables = set()

    def check_variable_norec(new_r):
        """
        Verify that `new_r` has the same value as the results it replaces.

        """
        for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
            new_r_val = r_vals[new_r]
            r_val = r_vals[r]

            if (r.type != new_r.type) or (not r.type.values_eq_approx(
                    r_val, new_r_val)):
                raise BadOptimization(old_r=r,
                                      new_r=new_r,
                                      old_r_val=r_val,
                                      new_r_val=new_r_val,
                                      reason=reason,
                                      old_graph=old_graph_str,
                                      new_graph=new_graph_str)

    def check_variable(r):
        if r in checked_variables:
            return
        checked_variables.add(r)

        # (recursively) first check all the variables that could make
        # r look bad:
        list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]]
        if r.owner is not None:
            list_of_vars += r.owner.inputs

        for var_that_could_make_r_look_bad in list_of_vars:
            check_variable(var_that_could_make_r_look_bad)

        check_variable_norec(r)

    # iterate over variables looking for values that don't match the
    # values of the variables they replaced.  This is the sign of a
    # broken optimization.
    for i, node in enumerate(order):
        for new_r in node.outputs:
            check_variable(new_r)

_find_bad_optimizations = _find_bad_optimizations0


def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
                           storage_map, r_vals, dr_vals, perform,
                           active_order_set, inplace_outs, init_outputs):
    """
    Preallocate outputs in different memory layouts.

    """

    # To avoid circular imports
    from theano.tensor import TensorType
    from theano.gpuarray import GpuArrayType
    try:
        import pygpu
    except ImportError:
        pass

    # TODO: Sparse? Scalar does not really make sense.

    # Do not preallocate memory for outputs that actually work inplace
    considered_outputs = []
    for r in node.outputs:
        if r not in inplace_outs:
            considered_outputs.append(r)

    # Output storage that was initially present in the storage_map
    if 'initial' in prealloc_modes or 'ALL' in prealloc_modes:
        initial_outputs = {}
        for r in considered_outputs:
            if r in init_outputs:
                initial_outputs[r] = init_outputs[r]

        if initial_outputs:
            yield ('initial', initial_outputs)

    # reuse_output: use a copy of the same storage returned the first time
    # TODO: optimization warning if the storage in reuse_outputs
    # is not reused
    if 'previous' in prealloc_modes or 'ALL' in prealloc_modes:
        reuse_outputs = {}
        for r in considered_outputs:
            # We want to reuse the exact same memory buffer,
            # so we keep the copy in r_vals
            new_r = _lessbroken_deepcopy(r_vals[r])
            reuse_outputs[r] = r_vals[r]
            r_vals[r] = new_r
            # Sometimes, outputs can be aliased together.
            # I'm not sure why it is legitimate, but there are tests about it.
            # So, we cannot fill r_vals[r] with def_val yet, we have to wait
            # until all output values are deepcopied.

        for r in considered_outputs:
            # There is no risk of overwriting inputs, since r does not work
            # inplace.
            if isinstance(r.type, (TensorType, GpuArrayType)):
                reuse_outputs[r][...] = np.asarray(
                    def_val).astype(r.type.dtype)

        if reuse_outputs:
            yield ('previous', reuse_outputs)
        # clear memory that is not needed any more
        del reuse_outputs

    # c_cont_output: use a C-contiguous array
    # (for TensorType, else None)
    if 'c_contiguous' in prealloc_modes or 'ALL' in prealloc_modes:
        c_cont_outputs = {}
        for r in considered_outputs:
            if isinstance(r.type, (TensorType, GpuArrayType)):
                # Build a C-contiguous buffer
                new_buf = r.type.value_zeros(r_vals[r].shape)
                assert new_buf.flags["C_CONTIGUOUS"]
                new_buf[...] = np.asarray(def_val).astype(r.type.dtype)

                c_cont_outputs[r] = new_buf

        if len(c_cont_outputs):
            yield ('c_contiguous', c_cont_outputs)
            del c_cont_outputs

    # f_cont_output: use a Fortran-contiguous ndarray
    # (for TensorType only)
    if 'f_contiguous' in prealloc_modes or 'ALL' in prealloc_modes:
        f_cont_outputs = {}
        for r in considered_outputs:
            if isinstance(r.type, (TensorType, GpuArrayType)):
                new_buf = np.zeros(
                    shape=r_vals[r].shape,
                    dtype=r_vals[r].dtype,
                    order='F')
                new_buf[...] = def_val
                if isinstance(r.type, GpuArrayType):
                    new_buf = pygpu.array(new_buf)

                f_cont_outputs[r] = new_buf

        if len(f_cont_outputs):
            yield ('f_contiguous', f_cont_outputs)
            del f_cont_outputs


    # We assume that the different outputs of the same Op will behave
    # independently, and there is no need to test over all combinations
    # of outputs (the time taken is prohibitive).
    # When all outputs on a certain dimension are broadcastable, the Op
    # can assume that the shape is 1 on that dimension, and stride testing
    # is less relevant.
    # Dimensions should be aligned by the innermost index, so we iterate
    # from the end of shapes.
    if ('strided' in prealloc_modes or
            'wrong_size' in prealloc_modes or
            'ALL' in prealloc_modes):
        max_ndim = 0
        rev_out_broadcastable = []
        for r in considered_outputs:
            if isinstance(r.type, (TensorType, GpuArrayType)):
                if max_ndim < r.ndim:
                    rev_out_broadcastable += [True] * (r.ndim - max_ndim)
                    max_ndim = r.ndim
                assert len(rev_out_broadcastable) == max_ndim

                for i, b in enumerate(r.broadcastable[::-1]):
                    rev_out_broadcastable[i] = rev_out_broadcastable[i] and b
        out_broadcastable = rev_out_broadcastable[::-1]

    if 'strided' in prealloc_modes or 'ALL' in prealloc_modes:
        check_ndim = config.DebugMode.check_preallocated_output_ndim
        # Initial allocation
        init_strided = {}
        for r in considered_outputs:
            if isinstance(r.type, (TensorType, GpuArrayType)):
                # Create a buffer twice as large in every dimension,
                # except if broadcastable, or for dimensions above
                # config.DebugMode.check_preallocated_output_ndim
                buf_shape = []
                for s, b in zip(r_vals[r].shape, r.broadcastable):
                    if b or ((r.ndim - len(buf_shape)) > check_ndim):
                        buf_shape.append(s)
                    else:
                        buf_shape.append(s * 2)
                new_buf = r.type.value_zeros(buf_shape)
                new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
                init_strided[r] = new_buf

        # The number of combinations is exponential in the number of
        # dimensions, and some ops can have tens of outputs. To prevent
        # tests from lasting days, we use the same strides for all
        # dimensions but the last check_ndim ones.
        # Moreover, to avoid memory problems, we do not test with strides
        # 2 and -2 on those dimensions.
        step_signs_list = []
        for b in out_broadcastable[-check_ndim:]:
            if b:
                step_signs_list.append((1,))
            else:
                step_signs_list.append((-1, 1))

        # Use the same step on all dimensions before the last check_ndim.
        if all(out_broadcastable[:-check_ndim]):
            step_signs_list = [(1,)] + step_signs_list
        else:
            step_signs_list = [(-1, 1)] + step_signs_list

        for step_signs in itertools_product(*step_signs_list):
            for step_size in (1, 2):
                strided = {}

                # First, the dimensions above check_ndim, then the other ones
                # Do not test with 2 or -2 for dimensions above check_ndim
                steps = [step_signs[0]] * len(out_broadcastable[:-check_ndim])
                steps += [s * step_size for s in step_signs[1:]]

                name = 'strided%s' % str(tuple(steps))
                for r in considered_outputs:
                    if r in init_strided:
                        strides = []
                        shapes = []
                        for i, size in enumerate(r_vals[r].shape):
                            shapes.append(slice(None, size, None))
                            strides.append(slice(None, None, steps[i]))

                        r_buf = init_strided[r]

                        if r_buf.ndim > 0:
                            r_buf = r_buf[tuple(strides)][tuple(shapes)]
                        assert r_buf.shape == r_vals[r].shape

                        r_buf[...] = np.asarray(def_val).astype(r_buf.dtype)
                        strided[r] = r_buf

                if strided:
                    yield (name, strided)
                del strided

    if 'wrong_size' in prealloc_modes or 'ALL' in prealloc_modes:
        # For each dimension, try size-1, size, size+1
        for dim, b in enumerate(out_broadcastable):
            if b:
                # The shape has to be 1
                continue

            shape_diff = [0] * max_ndim
            for diff in (-1, 1):
                shape_diff[dim] = diff

                wrong_size = {}
                name = 'wrong_size%s' % str(tuple(shape_diff))

                for r in considered_outputs:
                    if isinstance(r.type, (TensorType, GpuArrayType)):
                        r_shape_diff = shape_diff[:r.ndim]
                        out_shape = [max((s + sd), 0)
                                     for s, sd in zip(r_vals[r].shape,
                                                      r_shape_diff)]
                        new_buf = r.type.value_zeros(out_shape)
                        new_buf[...] = np.asarray(
                            def_val).astype(r.type.dtype)
                        wrong_size[r] = new_buf

                if wrong_size:
                    yield (name, wrong_size)
                del wrong_size


def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
                               storage_map, r_vals, dr_vals, perform,
                               active_order_set, inplace_outs, init_outputs):
    """
    Try to apply thunk() on different output storages.

    """

    # If node has an inner compiled Theano function with mode DebugMode,
    # disable memory checks in that mode, since they were already run.
    try:
        changed_inner_mode = False
        if type(getattr(node, 'op', None)) in ops_with_inner_function:
            fn_attr_name = ops_with_inner_function[type(node.op)]
            fn = getattr(node.op, fn_attr_name, None)
            if (not fn or
                    not hasattr(fn, 'maker') or
                    not hasattr(fn.maker, 'mode')):
                _logger.warn('Expected theano function not found in %s.%s',
                             node.op, fn_attr_name)
            else:
                if isinstance(fn.maker.mode, DebugMode):
                    backup_mode = fn.maker.mode
                    new_mode = copy.copy(backup_mode)
                    # Deactivate as many checks as possible
                    new_mode.check_py_code = False
                    new_mode.check_isfinite = False
                    new_mode.require_matching_strides = 0
                    new_mode.check_preallocated_output = []
                    new_mode.stability_patience = 1
                    fn.maker.mode = new_mode
                    changed_inner_mode = True
                    _logger.info('changing inner mode')

        # Set of inputs that are marked as destroyed or viewed
        aliased_inputs = set()
        dmap = getattr(node.op, 'destroy_map', {})
        vmap = getattr(node.op, 'view_map', {})
        for i, r in enumerate(node.inputs):
            if any(i in v for v in chain(itervalues(dmap), itervalues(vmap))):
                aliased_inputs.add(r)

        _logger.debug('starting preallocated output checking')
        for (name, out_map) in _get_preallocated_maps(
                node, thunk, prealloc_modes, def_val, storage_map, r_vals,
                dr_vals, perform, active_order_set, inplace_outs,
                init_outputs):
            _logger.debug('  name = %s', name)

            thunk_name = '%s with %s output' % (perform, name)

            if not out_map:
                # Map is empty, there is no need to execute thunk() again
                _logger.warn('%s: out_map is empty', name)
                continue

            # Copy the inputs over, if they were marked as destroyed or viewed
            # (we will destroy the output at some point so it can destroy
            # the input)
            for r in aliased_inputs:
                storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])

            # Get the appropriate output storages
            # (no copy)
            for r in node.outputs:
                storage_map[r][0] = out_map.get(r, None)

            thunk()

            # Check outputs
            for r in node.outputs:
                if not r.type.is_valid_value(storage_map[r][0]):
                    raise InvalidValueError(
                        r, storage_map[r][0],
                        hint=thunk_name,
                        specific_hint=r.type.value_validity_msg(
                            storage_map[r][0]))

            _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
                          clobber_dr_vals=False,
                          perform=thunk_name,
                          warn_input_not_reused=False)

            _check_viewmap(node, storage_map)

            for r in node.outputs:
                if not check_eq(r, r_vals[r], storage_map[r][0]):
                    # TODO: indicate it is not a C/Py problem
                    inputs_val = [storage_map[inp][0] for inp in
                                  r.owner.inputs]
                    raise BadThunkOutput(r,
                                         thunk1='Reference value',
                                         val1=r_vals[r],
                                         thunk2=thunk_name,
                                         val2=storage_map[r][0],
                                         inputs_val=inputs_val)

            # Clear storage_map
            for r in node.outputs:
                storage_map[r][0] = None

        _logger.debug('finished preallocated output checking')
    finally:
        if changed_inner_mode:
            _logger.info('changing mode back')
            fn.maker.mode = backup_mode


class _FunctionGraphEvent(object):
    """
    A record of an event in the life of a FunctionGraph.

    The __eq__ function is important here, as it is the basis for
    comparing optimization runs.

    """

    kind = ""
    """
    One of 'import', 'change', 'prune'.

    """

    node = None
    """
    Either 'output' or an Apply instance.

    """

    op = None
    """Either 'output' or an Op instance"""

    idx = None
    """
    Change events involve a position index of the input variable.

    """

    reason = None
    """
    Change events sometimes have a reason.

    """

    def __init__(self, kind, node, idx=None, reason=None):
        self.kind = kind
        if node == 'output':
            self.node = 'output'
            self.op = 'output'
        else:
            self.node = node
            self.op = node.op
        self.idx = idx
        self.reason = str(reason)

    def __str__(self):
        if self.kind == 'change':
            if (self.op != 'output'):
                msg = str(len(self.node.inputs))
            else:
                msg = ''

            return ' '.join(['change',
                             self.reason,
                             str(self.op),
                             str(self.idx),
                             msg])
        else:
            return str(self.__dict__)

    def __eq__(self, other):
        rval = type(self) == type(other)
        if rval:
            # nodes are not compared because this comparison is
            # supposed to be true for corresponding events that happen
            # in different FunctionGraph instances (different graphs)
            for attr in ['kind', 'op', 'idx', 'reason']:
                rval = rval and getattr(self, attr) == getattr(other, attr)
        return rval

    def __ne__(self, other):
        return not (self == other)

1421
1422class _VariableEquivalenceTracker(object):
1423    """
1424    A FunctionGraph Feature that keeps tabs on an FunctionGraph and
1425    tries to detect problems.
1426
1427    """
1428
1429    fgraph = None
1430    """WRITEME"""
1431
1432    equiv = None
1433    """WRITEME"""
1434
1435    active_nodes = None
1436    """WRITEME"""
1437
1438    inactive_nodes = None
1439    """WRITEME"""
1440
1441    all_variables_ever = None
1442    """WRITEME"""
1443
1444    reasons = None
1445    """WRITEME"""
1446
1447    replaced_by = None
1448    """WRITEME"""
1449
1450    event_list = None
1451    """WRITEME"""
1452
1453    def __init__(self):
1454        self.fgraph = None
1455
1456    def on_attach(self, fgraph):
1457        assert self.fgraph is None
1458        self.equiv = {}
1459        self.active_nodes = set()
1460        self.inactive_nodes = set()
1461        self.fgraph = fgraph
1462        self.all_variables_ever = []
1463        self.reasons = {}
1464        self.replaced_by = {}
1465        self.event_list = []
1466        for node in fgraph.toposort():
1467            self.on_import(fgraph, node, "on_attach")
1468
1469    def on_detach(self, fgraph):
1470        assert fgraph is self.fgraph
1471        self.fgraph = None
1472
1473    def on_prune(self, fgraph, node, reason):
1474        self.event_list.append(_FunctionGraphEvent('prune', node,
1475                                                   reason=str(reason)))
1476        assert node in self.active_nodes
1477        assert node not in self.inactive_nodes
1478        self.active_nodes.remove(node)
1479        self.inactive_nodes.add(node)
1480
1481    def on_import(self, fgraph, node, reason):
1482        self.event_list.append(_FunctionGraphEvent('import', node,
1483                                                   reason=str(reason)))
1484
1485        assert node not in self.active_nodes
1486        self.active_nodes.add(node)
1487
1488        if node in self.inactive_nodes:
1489            self.inactive_nodes.remove(node)
1490            for r in node.outputs:
1491                assert r in self.equiv
1492        else:
1493            for r in node.outputs:
1494                assert r not in self.equiv
1495                self.equiv[r] = set([r])
1496                self.all_variables_ever.append(r)
1497                self.reasons.setdefault(r, [])
1498                self.replaced_by.setdefault(r, [])
1499            for r in node.inputs:
1500                self.reasons.setdefault(r, [])
1501                self.replaced_by.setdefault(r, [])
1502
1503    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1504        reason = str(reason)
1505        self.event_list.append(_FunctionGraphEvent('change', node,
1506                                                   reason=reason, idx=i))
1507
1508        self.reasons.setdefault(new_r, [])
1509        self.replaced_by.setdefault(new_r, [])
1510
1511        append_reason = True
1512        for tup in self.reasons[new_r]:
1513            if tup[0] == reason and tup[1] is r:
1514                append_reason = False
1515
1516        if append_reason:
1517            # N.B. compute the debugprint now, because future
1518            # optimizations will change the graph
1519            done = dict()
1520            used_ids = dict()
1521            self.reasons[new_r].append(
1522                (reason,
1523                 r,
1524                 debugprint(r, prefix='  ', depth=6,
1525                            file=StringIO(), done=done,
1526                            print_type=True,
1527                            used_ids=used_ids).getvalue(),
1528                 debugprint(new_r, prefix='  ', depth=6,
1529                            file=StringIO(), done=done,
1530                            print_type=True,
1531                            used_ids=used_ids).getvalue()))
1532            self.replaced_by[r].append((reason, new_r))
1533
1534        if r in self.equiv:
1535            r_set = self.equiv[r]
1536        else:
1537            r_set = self.equiv.setdefault(r, set([r]))
1538            self.all_variables_ever.append(r)
1539
1540        if new_r in self.equiv:
1541            new_r_set = self.equiv[new_r]
1542        else:
1543            new_r_set = self.equiv.setdefault(new_r, set([new_r]))
1544            self.all_variables_ever.append(new_r)
1545
1546        assert new_r in new_r_set
1547        assert r in r_set
1548
        # Merge the two equivalence sets: fold new_r's set into r's set,
        # then point every member of new_r's old set at the merged set.
1551        r_set.update(new_r_set)
1552        for like_new_r in new_r_set:
1553            self.equiv[like_new_r] = r_set
1554            assert like_new_r in r_set
1555
1556        assert self.equiv[r] is r_set
1557        assert self.equiv[new_r] is r_set
1558
1559    def printstuff(self):
1560        for key in self.equiv:
1561            print(key)
1562            for e in self.equiv[key]:
1563                print('  ', e)
1564
1565
# List of the default versions of Op.make_thunk.
# This is needed to know whether the user overrode it.
1568default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
1569
1570
# Debug mode cheats and initializes the linker in a different way in
# its maker, so we cheat some more by having a linker that satisfies
# the external requirements of a mode's .linker attribute:
# 1) it's a class instance
# 2) it has a .clone() method
1576class _DummyLinker(object):
1577    # This is not a real linker anyway
1578    def clone(self, allow_gc=None):
1579        return self
1580
1581
1582class _Linker(gof.link.LocalLinker):
1583    """
1584    Special debugging linker.
1585
1586    """
1587
1588    def __init__(self, maker, schedule=None):
1589        super(gof.LocalLinker, self).__init__()
1590        self.fgraph = None
1591        self.maker = maker
1592        if schedule:
1593            self.schedule = schedule
1594
1595    def accept(self, fgraph, no_recycling=None, profile=None):
1596        if no_recycling is None:
1597            no_recycling = []
1598        if self.fgraph is not None and self.fgraph is not fgraph:
1599            assert type(self) is _Linker
1600            return type(self)(maker=self.maker).accept(
1601                fgraph, no_recycling, profile)
1602        self.fgraph = fgraph
1603        self.no_recycling = no_recycling
1604        return self
1605
1606    def make_all(self, profiler=None, input_storage=None,
1607                 output_storage=None, storage_map=None):
        # Can't import at top level because of a circular import.
        # TODO: don't use this ugly hack to set filter_checks_isfinite.
        from theano.tensor import TensorType  # to set filter_checks_isfinite
1612
1613        fgraph = self.fgraph
1614        input_storage_ = input_storage
1615        output_storage_ = output_storage
1616
1617        # Compute a topological ordering that IGNORES the destroy_map
1618        # of destructive Ops.  This will be OK, because every thunk is
1619        # evaluated on a copy of its input.
1620        fgraph_equiv = fgraph.equivalence_tracker
1621        order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
1622        del fgraph_equiv
1623        order_outputs.reverse()
1624        order = graph.io_toposort(fgraph.inputs, order_outputs)
1625
1626        # an ordering of just the active nodes
1627        active_order = self.schedule(fgraph)
1628        active_order_set = set(active_order)
1629
1630        # Disable no_recycling, in order to be able to use
1631        # check_preallocated_output even on the output of the function.
1632        # no_recycling in individual thunks does not really matter, since
1633        # the function's outputs will always be freshly allocated.
1634        no_recycling = []
1635
1636        input_storage, output_storage, storage_map = link.map_storage(
1637            fgraph, order, input_storage_, output_storage_, storage_map)
1638
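        # For every node we build up to two thunks: a Python thunk
        # (perform) and a C thunk (c_code).  The debug function f()
        # below runs whichever are available and cross-checks their
        # outputs against each other.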
1639        thunks_py = []  # python thunks
1640        thunks_c = []  # c thunks
1641
1642        for node in order:
1643            compute_map = {}
1644            for k in node.inputs:
1645                compute_map[k] = [True]
1646            for k in node.outputs:
1647                compute_map[k] = [False]
1648
1649            # Some Ops define a make_thunk with the expectation that
1650            # it will be called before the C code is compiled, because
1651            # the compilation of some dependency is triggered there.
1652            thunk_other = None
1653
1654            if (get_unbound_function(node.op.make_thunk) not in
1655                    default_make_thunk):
1656                thunk = node.op.make_thunk(node,
1657                                           storage_map,
1658                                           compute_map,
1659                                           no_recycling)
1660                thunk.inputs = [storage_map[v] for v in node.inputs]
1661                thunk.outputs = [storage_map[v] for v in node.outputs]
1662                thunk_other = thunk
1663
1664            debug = hasattr(node.op, 'debug_perform')
1665
1666            try:
1667                if not self.maker.mode.check_c_code or debug:
1668                    raise utils.MethodNotDefined()
1669                # Ops that do not inherit from gof.op.Op don't have certain
1670                # methods defined that the CLinker expects (Scan is an
1671                # example, ifelse is another of such classes that inherit
1672                # directly from PureOp)
1673                if not isinstance(node.op, gof.op.Op):
1674                    raise utils.MethodNotDefined()
1675
1676                node.op.prepare_node(node, storage_map, compute_map, 'c')
1677                thunk = node.op.make_c_thunk(node, storage_map, compute_map,
1678                                             no_recycling)
1679                thunks_c.append(thunk)
1680            except (NotImplementedError, utils.MethodNotDefined):
1681                thunks_c.append(None)
1682
            # Pure ops don't really have a perform (or their perform just
            # raises a NotImplementedError), so in those cases we
            # consider that we don't have a Python implementation
1686            if (((self.maker.mode.check_py_code or thunks_c[-1] is None) and
1687                 node.op.perform.__code__ != gof.op.PureOp.perform.__code__) or
1688                    debug):
1689                node.op.prepare_node(node, storage_map, compute_map, 'py')
1690                thunk = node.op.make_py_thunk(node, storage_map, compute_map,
1691                                              no_recycling, debug=debug)
1692                thunks_py.append(thunk)
1693            else:
1694                thunks_py.append(None)
1695
1696            if not self.maker.mode.check_c_code and thunks_py[-1] is None:
1697                _logger.warn("Op %s doesn't have a perform, "
1698                             "forcing check of the C code" % node.op)
1699                node.op.prepare_node(node, storage_map, compute_map, 'c')
1700                thunk = node.op.make_c_thunk(node, storage_map, compute_map,
1701                                             no_recycling)
1702                thunks_c[-1] = thunk
1703
1704            # If the op defined its own make_thunk, use the generated thunk
1705            if thunk_other is not None:
1706                if thunks_py[-1] is None:
1707                    thunks_py[-1] = thunk_other
1708                elif thunks_c[-1] is None:
1709                    thunks_c[-1] = thunk_other
1710                else:
1711                    _logger.warn("We won't check the perform function "
1712                                 "of node '%s' but we will check its "
1713                                 "make_thunk function" % node)
1714                    thunks_py[-1] = thunk_other
1715
1716        # Use self.no_recycling (that was passed in accept()) to always
1717        # use new memory storage when it is needed, in particular for the
1718        # function's outputs. no_recycling_map will be used in f() below.
1719        if self.no_recycling is True:
1720            no_recycling_map = list(storage_map.values())
1721            no_recycling_map = utils.difference(no_recycling_map,
1722                                                input_storage)
1723        else:
1724            no_recycling_map = [storage_map[r] for r in self.no_recycling
1725                                if r not in fgraph.inputs]
1726
1727        # Precompute some things for storage pre-allocation
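        # def_val is handed to _check_preallocated_output below as the
        # value used to fill newly pre-allocated output buffers.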
1728        def_val = int(config.unittests.rseed)
1729
1730        #####
1731        # This is the function that runs when you evaluate the graph
1732        #####
1733        def f():
1734            ####
1735            # Note: `f` ignores the compute_map and evaluates the nodes in
1736            # topological order. In some sense, this is ok, and can be used
1737            # for now.
1738            #####
1739            _logger.debug("starting a DebugMode call")
1740            _logger.debug("self.maker.mode.check_preallocated_output: %s",
1741                          self.maker.mode.check_preallocated_output)
1742            for x in no_recycling_map:
1743                x[0] = None
1744
1745            # nest all this in try-finally to put storage *back* into
1746            # storage_map when an exception is raised
1747            original_storage_map_keys = [r for r in storage_map
1748                                         if r.owner is None]
1749
1750            try:
1751                # r_vals are the true values associated with each
1752                # variable in the graph they should not change during
1753                # the evaluation of this function, even when the graph
1754                # has destructive ops in it
1755                #
1756                # This dictionary is used to populate the storage_map
1757                # as necessary
1758                r_vals = {}
1759
1760                # dr_vals are the values taken by variables after
1761                # being destroyed
1762                dr_vals = {}
1763                assert len(thunks_py) == len(order)
1764
1765                # transfer the initial values from the storage_map to
1766                # the r_vals
1767                _logger.debug("DEBUGMODE: transfer initial values")
1768                # r_vals_initialized keeps track of the values that have
1769                # actually been transferred from storage_map to r_vals
1770                r_vals_initialized = []
1771                for r in storage_map:
1772                    if (r.owner is None):
1773                        if not r.type.is_valid_value(storage_map[r][0]):
1774                            # None may be a valid input value (for instance,
1775                            # for a Generic object). We only want to raise
1776                            # an error if it is not valid.
1777                            if (storage_map[r][0] is None):
1778                                raise InvalidValueError(
1779                                    r, storage_map[r][0],
1780                                    hint=("Graph Input '%s' is missing" %
1781                                          str(r)))
1782                            raise InvalidValueError(
1783                                r, storage_map[r][0],
1784                                hint=("Graph Input '%s' has invalid value "
1785                                      "%s" % (r, storage_map[r][0])))
1786                        r_vals[r] = storage_map[r][0]
1787                        storage_map[r][0] = None
1788                        r_vals_initialized.append(r)
1789
1790                # store preallocated outputs in another map, and test
1791                # the thunks on them as output storages.
1792                init_outputs = {}
1793                for r in storage_map:
1794                    if r in fgraph.outputs:
1795                        if storage_map[r][0] is not None:
1796                            init_outputs[r] = storage_map[r][0]
1797                            storage_map[r][0] = None
1798
1799                #####
1800                #  Precondition: the storage map is empty, transferred
1801                #  completely to r_vals
1802                #####
1803                for r, s in iteritems(storage_map):
1804                    if s[0] is not None:
1805                        print(r, s)
1806                    assert s[0] is None
1807
1808                # try:
1809                # compute the value of all variables
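                # For each node: run the Python thunk (if any) first and
                # record its outputs in r_vals as the reference values,
                # then run the C thunk on fresh copies of the inputs and
                # compare its outputs to the reference with check_eq
                # (raising BadThunkOutput on mismatch).  Optionally, each
                # thunk is also re-run against pre-allocated output
                # buffers via _check_preallocated_output.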
1810                for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py,
1811                                                                  thunks_c,
1812                                                                  order)):
1813                    _logger.debug("%i - starting node %i %s", i, i, node)
1814
1815                    # put a copy of each input into the storage_map
1816                    # also, check that inputs have valid values
1817                    for r in node.inputs:
1818                        assert isinstance(r, gof.Variable)
1819                        assert r in r_vals
1820                        storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
1821                        if not r.type.is_valid_value(storage_map[r][0]):
1822                            raise InvalidValueError(r, storage_map[r][0],
1823                                                    client_node=node)
1824
1825                    # On the first call to thunk_py(), its output
1826                    # storage will be None
1827                    if thunk_py:
1828                        _logger.debug("%i - running thunk_py with None as "
1829                                      "output storage", i)
1830                        try:
1831                            thunk_py()
1832                        except (utils.MethodNotDefined, NotImplementedError):
1833                            # shouldn't have put it into the list in
1834                            # the first place
1835                            thunk_py = None
1836                            thunks_py[i] = None
1837                        except Exception as e:
                            # I think that only one optimization can
                            # insert a given apply node. If that is not true,
                            # we would need to loop over all node outputs,
                            # but this would make the output uglier.
1842                            reason = fgraph.equivalence_tracker.reasons[
1843                                node.outputs[0]]
1844                            if not reason:
1845                                raise
1846                            opt = str(reason[0][0])
1847                            msg = (
1848                                "An optimization (probably %s) inserted an "
                                "apply node that raised an error." % opt +
                                "\nThe information we have about this "
                                "optimization is:" + str(reason[0][1]) +
1852                                "\n" + reason[0][2] +
1853                                "\n\nThe original exception: \n" + str(e))
1854                            new_e = e.__class__(msg)
1855                            exc_type, exc_value, exc_trace = sys.exc_info()
1856                            exc_value = new_e
1857                            raise_with_op(node, thunk_c,
1858                                          (exc_type, exc_value, exc_trace))
1859
1860                    if thunk_py:
1861                        # check output values for type-correctness
1862                        for r in node.outputs:
1863                            if not r.type.is_valid_value(storage_map[r][0]):
1864                                hint2 = r.type.value_validity_msg(
1865                                    storage_map[r][0])
1866                                raise InvalidValueError(r, storage_map[r][0],
1867                                                        hint='perform output',
1868                                                        specific_hint=hint2)
1869                        warn_inp = config.DebugMode.warn_input_not_reused
1870                        py_inplace_outs = _check_inputs(
1871                            node, storage_map, r_vals, dr_vals,
1872                            active_order_set,
1873                            clobber_dr_vals=True, perform='py',
1874                            warn_input_not_reused=warn_inp)
1875                        _check_viewmap(node, storage_map)
1876
1877                        # Retrieve each output from the storage_map.
1878                        # The return values of this first run will be
1879                        # the reference ones
1880                        for r in node.outputs:
1881                            assert r not in r_vals
1882                            r_vals[r] = storage_map[r][0]
1883                            # clear the storage_map of outputs for the thunk_c
1884                            storage_map[r][0] = None
1885
1886                        if self.maker.mode.check_preallocated_output:
1887                            prealloc_modes = \
1888                                self.maker.mode.check_preallocated_output
1889                            _logger.debug(
1890                                '%i - calling _check_preallocated_output '
1891                                'with thunk_py', i)
1892                            _check_preallocated_output(
1893                                node=node,
1894                                thunk=thunk_py,
1895                                prealloc_modes=prealloc_modes,
1896                                def_val=def_val,
1897                                storage_map=storage_map,
1898                                r_vals=r_vals,
1899                                dr_vals=dr_vals,
1900                                perform='py',
1901                                active_order_set=active_order_set,
1902                                inplace_outs=py_inplace_outs,
1903                                init_outputs=init_outputs)
1904
1905                        sys.stdout.flush()
1906
1907                    if thunk_c:
1908
1909                        clobber = True
1910                        if thunk_py:
1911                            dmap = getattr(node.op, 'destroy_map', {})
1912                            vmap = getattr(node.op, 'view_map', {})
                            # Use `j` here so we don't clobber the outer
                            # loop index `i`, which is still used for
                            # logging below.
                            for j, r in enumerate(node.inputs):
1915                                # if thunk_py ran, and we still got
1916                                # this far, it means that the
1917                                # destroy_map of the Op (and view_map)
1918                                # are accurate so we can assume that
1919                                # inputs not marked as destroyed have
1920                                # in fact not been destroyed.
1921                                # Therefore... we only need to
1922                                # overwrite inputs that *have* been
1923                                # marked as destroyed.  Inputs marked
                                # as viewed are unsafe too, because the
1925                                # corresponding output can be
1926                                # destroyed.
                                if any(j in v for v in chain(dmap.values(),
1928                                                             vmap.values())):
1929                                    storage_map[r][0] = _lessbroken_deepcopy(
1930                                        r_vals[r])
1931
1932                            clobber = False
1933
1934                        _logger.debug("%i - running thunk_c", i)
1935                        # First time, with None in output_storage
1936                        try:
1937                            thunk_c()
1938                        except Exception as e:
                            # I think that only one optimization can
                            # insert a given apply node. If that is not true,
                            # we would need to loop over all node outputs,
                            # but this would make the output uglier.
1943                            reason = fgraph.equivalence_tracker.reasons[
1944                                node.outputs[0]]
1945                            if not reason:
1946                                raise
1947                            opt = str(reason[0][0])
1948                            msg = (
1949                                "An optimization (probably %s) inserted "
                                "an apply node that raised an error." % opt +
                                "\nThe information we have about this "
                                "optimization is:" + str(reason[0][1]) +
1953                                "\n" + reason[0][2] +
1954                                "\n\nThe original exception: \n" + str(e))
1955                            new_e = e.__class__(msg)
1956                            exc_type, exc_value, exc_trace = sys.exc_info()
1957                            exc_value = new_e
1958                            raise_with_op(node, thunk_c,
1959                                          (exc_type, exc_value, exc_trace))
1960
1961                        for r in node.outputs:
1962                            # check output values for type-correctness
1963                            if not r.type.is_valid_value(storage_map[r][0]):
1964                                raise InvalidValueError(r, storage_map[r][0],
1965                                                        hint='c output')
1966
1967                            if thunk_py:
1968                                # because we put it in during the
1969                                # thunk_py branch
1970                                assert r in r_vals
1971                                # check for stride correctness (may
1972                                # raise exception)
1973                                _check_strides_match(
1974                                    r_vals[r], storage_map[r][0],
1975                                    self.maker.mode.require_matching_strides,
1976                                    node.op)
1977
1978                        warn_inp = config.DebugMode.warn_input_not_reused
1979                        c_inplace_outs = _check_inputs(
1980                            node, storage_map, r_vals,
1981                            dr_vals, active_order_set,
1982                            clobber_dr_vals=clobber, perform='c',
1983                            warn_input_not_reused=warn_inp)
1984
1985                        _check_viewmap(node, storage_map)
1986
1987                        # Check with Python result
1988                        for r in node.outputs:
1989                            if r in r_vals:
1990                                # compares the version from thunk_py
1991                                # (in r_vals) to the version produced
1992                                # by thunk_c (in storage_map)
1993                                if not check_eq(r, r_vals[r],
1994                                                storage_map[r][0]):
1995                                    inputs_val = [storage_map[inp][0]
1996                                                  for inp in r.owner.inputs]
1997                                    raise BadThunkOutput(
1998                                        r, thunk1='perform', val1=r_vals[r],
1999                                        thunk2='c_code',
2000                                        val2=storage_map[r][0],
2001                                        inputs_val=inputs_val)
2002                            else:
2003                                # retrieve each output from the storage_map
2004                                r_vals[r] = storage_map[r][0]
2005                            # clear the storage_map for the thunk_c
2006                            storage_map[r][0] = None
2007
2008                        if self.maker.mode.check_preallocated_output:
2009                            prealloc_modes = \
2010                                self.maker.mode.check_preallocated_output
2011
2012                            def thunk():
2013                                try:
2014                                    thunk_c()
2015                                except Exception:
2016                                    raise_with_op(node, thunk_c)
2017                            _logger.debug(
2018                                '%i - calling _check_preallocated_output '
2019                                'with thunk_c', i)
2020                            _check_preallocated_output(
2021                                node=node,
2022                                thunk=thunk,
2023                                prealloc_modes=prealloc_modes,
2024                                def_val=def_val,
2025                                storage_map=storage_map,
2026                                r_vals=r_vals,
2027                                dr_vals=dr_vals,
2028                                perform='c code',
2029                                active_order_set=active_order_set,
2030                                inplace_outs=c_inplace_outs,
2031                                init_outputs=init_outputs)
2032
2033                        sys.stdout.flush()
2034
2035                    # we're done with this thunk
2036                    # clear everything out of the storage_map
2037                    for r in node.inputs:
2038                        storage_map[r][0] = None
2039                    _logger.debug("%i - done with node", i)
2040                    for r in node.outputs:
2041                        if r not in r_vals:
2042                            idx = order.index(node)
2043                            assert thunks_py[idx] is None, node
2044                            assert thunks_c[idx] is None, node
2045                            raise Exception("No code run for %s" % node)
2046
2047                if False:
                    # This could be useful to help find refcount problems,
                    # but it is very slow and not certain to help.
2050                    gc.collect()
2051
2052                _find_bad_optimizations(order,
2053                                        fgraph.equivalence_tracker.reasons,
2054                                        r_vals)
2055
2056                #####
2057                #  Postcondition: the input and output variables are
2058                #  in the storage map, nothing more
2059                #####
2060
                # Nothing should be left in the storage_map after
                # evaluating each thunk (specifically the last one)
2063                for r, s in iteritems(storage_map):
2064                    assert type(s) is list
2065                    assert s[0] is None
2066
2067                # store our output variables to their respective storage lists
2068                for output, storage in zip(fgraph.outputs, output_storage):
2069                    storage[0] = r_vals[output]
2070
2071                # transfer all inputs back to their respective storage lists
2072                for r in r_vals:
2073                    if r.owner is None:
2074                        if r in fgraph.inputs:
2075                            assert (storage_map[r] is
2076                                    input_storage[fgraph.inputs.index(r)])
2077                        storage_map[r][0] = r_vals[r]
2078
2079                # if an input was destroyed, the destroyed value
2080                # should be returned
2081                for r in dr_vals:
2082                    assert dr_vals[r][0] is not None
2083                    if r.owner is None:
2084                        assert r in fgraph.inputs
2085                        # HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION
2086                        # TOOK PLACE
2087                        if ((type(dr_vals[r][0]) in
2088                             (np.ndarray, np.memmap)) and
2089                            (dr_vals[r][0].dtype ==
2090                             storage_map[r][0].dtype) and
2091                            (dr_vals[r][0].shape ==
2092                             storage_map[r][0].shape)):
2093                            if len(dr_vals[r][0].shape):
2094                                storage_map[r][0][:] = dr_vals[r][0]
2095                            else:
2096                                storage_map[r][0].itemset(dr_vals[r][0])
2097                        else:
2098                            storage_map[r][0] = dr_vals[r][0]
2099            except Exception:
2100                # Restore the initial state of storage_map
2101                for r in storage_map:
2102                    if r in original_storage_map_keys:
2103                        # If r was transferred to r_vals, put it back
2104                        if r in r_vals_initialized:
2105                            storage_map[r][0] = r_vals[r]
2106                    else:
2107                        # clear out any partially-computed stuff
2108                        storage_map[r][0] = None
2109                raise
2110
2111            for r in storage_map:
2112                if (r.owner is None):
2113                    if not r.type.is_valid_value(None):
2114                        assert storage_map[r][0] is not None
2115
2116            ###############
2117            # Done debugmode function call 'f'
2118            ##############
2119
2120        def run_with_tensortype_filter_check(f):
2121            def deco():
2122                # WARNING: this is a global mechanism...
2123                # so it will screw up if we are trying to use
2124                # multiple modes at once.
2125                old_filter_checks_isfinite = TensorType.filter_checks_isfinite
2126                TensorType.filter_checks_isfinite = \
2127                    self.maker.mode.check_isfinite
2128                try:
2129                    return f()
2130                finally:
2131                    # put back the filter_checks_isfinite
2132                    TensorType.filter_checks_isfinite = \
2133                        old_filter_checks_isfinite
2134            return deco
2135
2136        f = run_with_tensortype_filter_check(f)
2137        f.storage_map = storage_map
2138        f.allow_gc = True
2139        assert len(fgraph.inputs) == len(input_storage)
2140        assert len(fgraph.outputs) == len(output_storage)
2141        return (f,
2142                [link.Container(input, storage, readonly=False)
2143                 for input, storage in zip(fgraph.inputs, input_storage)],
2144                [link.Container(output, storage, readonly=True)
2145                 for output, storage in zip(fgraph.outputs, output_storage)],
2146                thunks_py, order)
2147
2148
2149_NODEFAULT = ['NODEFAULT']
2150
2151
2152class _Maker(FunctionMaker):  # inheritance buys a few helper functions
2153    """
2154    Special debugging FunctionMaker.
2155
2156    Parameters
2157    ----------
2158    inputs : list of SymbolicInput instances
2159    outputs : list of SymbolicOutput instances
2160        Outputs may also be a single Variable (not a list), in which case
2161        the functions produced by FunctionMaker will return their output
2162        value directly.
2163    accept_inplace
2164        True iff it is acceptable to have inplace operations in the graph from
2165        the inputs to the outputs.
2166    on_unused_input
2167        What to do if a variable in the 'inputs' list is not used in the
2168        graph. Possible values are 'raise', 'warn' and 'ignore'.
2169    output_keys
2170        If the outputs argument for theano.function was a list, then
2171        output_keys is None. If the outputs argument was a dict, then
2172        output_keys is a sorted list of the keys from that dict.
2173
2174    Notes
2175    -----
2176    The constructor sets TensorType.filter_checks_isfinite when
2177    `mode.check_isfinite` is True.
2178
2179    """
2180
2181    verbose = 0
2182    """
2183    Verbosity level of compile-time and run-time checks. (Default 0: silent).
2184
2185    """
2186
2187    def __init__(self, inputs, outputs, mode,
2188                 accept_inplace=False,
2189                 function_builder=Function,
2190                 profile=None,
2191                 on_unused_input=None,
2192                 fgraph=None,  # If present the optimized graph. we ignore it.
2193                 output_keys=None,
2194                 name=None):
2195        self.mode = mode
2196        self.profile = profile
2197        if profile:
            raise Exception("DebugMode does not support profiling.")
2199        optimizer = mode.optimizer
2200        # Handle the case where inputs and/or outputs is a single
2201        # Variable (not in a list)
2202        unpack_single = False
2203        return_none = False
2204        if outputs is None:
2205            return_none = True
2206            outputs = []
2207        if not isinstance(outputs, (list, tuple)):
2208            unpack_single = True
2209            outputs = [outputs]
2210        if not isinstance(inputs, (list, tuple)):
2211            inputs = [inputs]
2212
2213        # Wrap them in In or Out instances if needed.
2214        inputs = [self.wrap_in(i) for i in inputs]
2215        outputs = [self.wrap_out(o) for o in outputs]
2216
2217        _inputs = gof.graph.inputs([o.variable for o in outputs] +
2218                                   [i.update for i in inputs
2219                                    if getattr(i, 'update', False)])
2220
2221        # Check if some input variables are unused
2222        self._check_unused_inputs(inputs, outputs, on_unused_input)
2223
        # Make a list of (SymbolicInput|SymbolicInputKits, indices,
2225        # [SymbolicInput,...]), one tuple for each input. (See
2226        # Function.indices for more details)
2227        indices = [[input] + self.expand_in(input, _inputs)
2228                   for input in inputs]
2229
2230        # make the fgraph
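        # Build and optimize the graph `stability_patience` times and
        # check that the sequence of optimization events is identical
        # each time; if it is not, the optimizer is stochastic and we
        # raise StochasticOrder below.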
2231        for i in xrange(mode.stability_patience):
2232            fgraph, additional_outputs, equivalence_tracker = _optcheck_fgraph(
2233                inputs, outputs, accept_inplace)
2234            fgraph.equivalence_tracker = equivalence_tracker
2235
2236            with change_flags(compute_test_value=config.compute_test_value_opt):
2237                optimizer(fgraph)
2238
2239                theano.compile.function_module.insert_deepcopy(
2240                    fgraph, inputs, list(chain(outputs, additional_outputs)))
2241
2242            if i == 0:
2243                fgraph0 = fgraph
2244            else:
2245                li = fgraph.equivalence_tracker.event_list
2246                l0 = fgraph0.equivalence_tracker.event_list
2247                if li != l0:
2248                    infolog = StringIO()
2249                    print("WARNING: Optimization process is unstable...",
2250                          file=infolog)
2251                    print("  (HINT: Ops that the nodes point to must compare "
2252                          "equal)", file=infolog)
2253                    print("(event index)  (one event trace)  (other event "
2254                          "trace)", file=infolog)
2255                    print("-------------------------------------------------"
2256                          "----", file=infolog)
2257                    for j in xrange(max(len(li), len(l0))):
2258                        if j >= len(li):
2259                            print('trailing event in optimization 0 :', j,
2260                                  file=infolog)
2261                            print('   ', str(l0[j]), file=infolog)
2262                        elif j >= len(l0):
2263                            print('trailing event in optimization', i, ':',
2264                                  j, file=infolog)
2265                            print('   ', str(li[j]), file=infolog)
2266                        elif li[j] != l0[j]:
2267                            print('non-equal optimization events', i, ':',
2268                                  j, file=infolog)
2269                            print('   ', str(l0[j]), file=infolog)
2270                            print('   ', str(li[j]), file=infolog)
2271                        else:
2272                            pass
2273                    raise StochasticOrder(infolog.getvalue())
2274                else:
2275                    if self.verbose:
2276                        print("OPTCHECK: optimization", i,
2277                              "of", len(li), "events was stable.",
2278                              file=sys.stderr)
2279        self.fgraph = fgraph
2280        if theano.config.cycle_detection == 'regular':
2281            destroy_handler_added = False
2282            for feature in fgraph._features:
2283                if isinstance(feature, gof.DestroyHandler):
2284                    destroy_handler_added = True
2285                    break
2286            if not destroy_handler_added:
2287                fgraph.attach_feature(gof.DestroyHandler())
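            # Try to wrap each output in an OutputGuard.  If the
            # replacement validates, the output was still exposed to
            # destructive operations (it should already have been
            # protected), so we raise.  An InconsistencyError means the
            # output is already impossible to destroy, so no guard is
            # needed.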
2288            for o in fgraph.outputs:
2289                try:
2290                    with change_flags(compute_test_value=config.compute_test_value_opt):
2291                        fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
2292                    raise Exception("Output variable %s required output_guard, "
2293                                    "how was this output left unprotected against "
2294                                    "destructive operations?" % o)
2295
2296                except gof.InconsistencyError:
2297                    # This output is already impossible to destroy.
2298                    # No guard necessary
2299                    pass
2300
2301        linker = _Linker(self)
2302
        # The 'no_borrow' outputs are the ones for which we can't return
        # the internal storage pointer.
2305
2306        no_borrow = [output for output, spec in
2307                     izip(fgraph.outputs, outputs + additional_outputs)
2308                     if not spec.borrow]
2309        if no_borrow:
2310            self.linker = linker.accept(
2311                fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow))
2312        else:
2313            self.linker = linker.accept(fgraph)
2314        fgraph.name = name
2315        self.indices = indices
2316        self.inputs = inputs
2317        self.expanded_inputs = inputs
2318        self.outputs = outputs
2319        self.unpack_single = unpack_single
2320        self.return_none = return_none
2321        self.accept_inplace = accept_inplace
2322        self.function_builder = function_builder
2323        self.on_unused_input = on_unused_input  # Used for the pickling/copy
2324        self.output_keys = output_keys
2325        self.name = name
2326
2327        self.required = [(i.value is None) for i in self.inputs]
2328        self.refeed = [
2329            (i.value is not None and
2330             not isinstance(i.value, gof.Container) and
2331             i.update is None)
2332            for i in self.inputs]
2333
2334
2335########################
2336#
2337# API symbol: DebugMode
2338#
2339########################
2340
2341
2342class DebugMode(Mode):
2343    """
2344    Evaluation Mode that detects internal theano errors.
2345
2346    This mode catches several kinds of internal error:
2347
2348    - Inconsistent outputs when calling the same Op twice with the same
      inputs, for instance if the c_code and perform implementations are
2350      inconsistent, or in case of incorrect handling of output memory
2351      (see `BadThunkOutput`).
2352
2353    - A variable replacing another when their runtime values don't
      match.  This is a symptom of an incorrect optimization step, or a
2355      faulty Op implementation (raises `BadOptimization`).
2356
2357    - Stochastic optimization ordering (raises `StochasticOrder`).
2358
2359    - Incomplete `destroy_map` specification (raises `BadDestroyMap`).
2360
    - An Op that returns an illegal value not matching the output
      Variable's Type (raises `InvalidValueError`).
2363
2364    Each of these exceptions inherits from the more generic `DebugModeError`.
2365
2366    If there are no internal errors, this mode behaves like FAST_RUN
2367    or FAST_COMPILE, but takes a little longer and uses more memory.
2368
2369    Raises
2370    ------
2371    DebugModeError
2372        If there are internal errors.
2373
2374    Notes
2375    -----
2376    The work of debugging is implemented by the `_Maker`, `_Linker`,
2377    and `_VariableEquivalenceTracker` classes.
2378
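    Examples
    --------
    A minimal sketch of typical usage (the variable and function names
    below are illustrative only, not part of this module)::

        import theano
        import theano.tensor as T

        x = T.dvector('x')
        f = theano.function([x], 10 * x, mode='DEBUG_MODE')
        # or, to override some of the checks:
        # f = theano.function([x], 10 * x,
        #                     mode=DebugMode(check_c_code=False))
        f([1, 2, 3])  # every call runs the extra DebugMode checks
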
2379    """
2380
2381    stability_patience = config.DebugMode.patience
2382    """
2383    When checking for the stability of optimization, recompile the
2384    graph this many times.
2385
2386    """
2387
2388    check_c_code = config.DebugMode.check_c
2389    """
2390    Should we evaluate (and check) the `c_code` implementations?
2391
2392    """
2393
2394    check_py_code = config.DebugMode.check_py
2395    """
2396    Should we evaluate (and check) the `perform` implementations?
2397    Always checked if no `c_code`.
2398
2399    """
2400
2401    check_isfinite = config.DebugMode.check_finite
2402    """
2403    Should we check for (and complain about) NaN/Inf ndarray elements?
2404
2405    """
2406
2407    require_matching_strides = config.DebugMode.check_strides
2408    """
    Should we check for (and complain about) Ops whose Python and C
    outputs are ndarrays with different strides? (This can catch bugs,
    but is generally overly strict.) 0: no check, 1: warn, 2: raise an
    error.
2412
2413    """
2414
2415    check_preallocated_output = config.DebugMode.check_preallocated_output
2416    check_preallocated_output = check_preallocated_output.split(':')
2417    """
2418    List of strings representing ways to pre-allocate output memory in
2419    tests.  Valid values are: "previous" (previously-returned memory),
2420    "c_contiguous", "f_contiguous", "strided" (positive and negative
2421    strides), "wrong_size" (larger and smaller dimensions), and "ALL"
2422    (all of the above).
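
    For instance, ``DebugMode(check_preallocated_output=['previous',
    'c_contiguous'])`` would restrict the test to those two strategies;
    by default the list comes from splitting the
    ``DebugMode.check_preallocated_output`` config flag on ``:``.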
2423
2424    """
2425
2426    # This function will be used to create a FunctionMaker in
2427    # function_module.function
2428    def function_maker(self, i, o, m, *args, **kwargs):
2429        """
2430        Return an instance of `_Maker` which handles much of the debugging work.
2431
2432        """
2433        assert m is self
2434        return _Maker(i, o, self, *args, **kwargs)
2435
2436    def __init__(self,
2437                 optimizer='fast_run',
2438                 stability_patience=None,
2439                 check_c_code=None,
2440                 check_py_code=None,
2441                 check_isfinite=None,
2442                 check_preallocated_output=None,
2443                 require_matching_strides=None,
2444                 linker=_DummyLinker()):
2445        """
2446        If any of these arguments (except optimizer) is not None, it overrides
2447        the class default. The linker argument is not used. It is set there to
        allow Mode.requiring() and a few other functions to work with
        DebugMode too.
2449
2450        """
2451
2452        if not isinstance(linker, _DummyLinker):
2453            raise Exception("DebugMode can only use its own linker! You "
2454                            "should not provide one.", linker)
2455
2456        super(DebugMode, self).__init__(optimizer=optimizer,
2457                                        linker=linker)
2458
2459        if stability_patience is not None:
2460            self.stability_patience = stability_patience
2461
2462        if check_c_code is not None:
2463            self.check_c_code = check_c_code
2464
2465        if check_py_code is not None:
2466            self.check_py_code = check_py_code
2467
2468        if check_isfinite is not None:
2469            self.check_isfinite = check_isfinite
2470
2471        if check_preallocated_output is not None:
2472            # Copy to avoid sharing the same list across different instances
2473            self.check_preallocated_output = check_preallocated_output[:]
2474
2475        if require_matching_strides is not None:
2476            self.require_matching_strides = require_matching_strides
2477
2478        if not (self.check_c_code or self.check_py_code):
2479            raise ValueError('DebugMode has to check at least one of c and py '
2480                             'code')
2481
2482    def __str__(self):
2483        return "DebugMode(linker=%s, optimizer=%s)" % (
2484            self.provided_linker, self.provided_optimizer)
2485
2486
2487register_mode('DEBUG_MODE', DebugMode(optimizer='fast_run'))
2488