1"""
2This file contains auxiliary Ops, used during the compilation phase and Ops
3building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that
4help make new Ops more rapidly.
5
6"""
from __future__ import absolute_import, print_function, division
from collections import OrderedDict

import copy
import six.moves.cPickle as pickle
import warnings

import theano
from theano import gof
from six import iteritems, integer_types
from six.moves import xrange


import numpy as np


def register_view_op_c_code(type, code, version=()):
    """
    Tell ViewOp how to generate C code for a Theano Type.

    Parameters
    ----------
    type : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Returns a view for the Theano type 'type'. Use %(iname)s and %(oname)s
        for the input and output C variable names respectively.
    version
        A number indicating the version of the code, for the cache.

    """
    ViewOp.c_code_and_version[type] = (code, version)


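# A registration sketch (MyType and the C snippet are illustrative, not a
# call made in this file): a backend whose storage is a Python object
# reachable through %(iname)s could register view code that just swaps
# references and fixes the refcounts.
#
#   register_view_op_c_code(
#       MyType,
#       """
#       Py_XDECREF(%(oname)s);
#       %(oname)s = %(iname)s;
#       Py_XINCREF(%(oname)s);
#       """,
#       version=(1,))

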
class ViewOp(gof.Op):
    """
    Returns an inplace view of the input. Used internally by Theano.

    """

    view_map = {0: [0]}
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}
    __props__ = ()
    _f16_ok = True

    def make_node(self, x):
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, inp, out):
        x, = inp
        z, = out
        z[0] = x

    def __str__(self):
        return '%s' % self.__class__.__name__

    def c_code(self, node, nodename, inp, out, sub):
        iname, = inp
        oname, = out
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            return code % locals()

        # Else, no C code
        return super(ViewOp, self).c_code(node, nodename, inp, out, sub)

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, v) in sorted(iteritems(self.c_code_and_version),
                                key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for ViewOp, but it has no "
                              "version. You should add a 'version' keyword "
                              "arg when calling register_view_op_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        return tuple(version)

    def infer_shape(self, node, input_shapes):
        return input_shapes

    def grad(self, args, g_outs):
        return g_outs

view_op = ViewOp()
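
# A minimal usage sketch (the variable names are hypothetical): calling the
# op builds a graph node whose output aliases its input at runtime, since
# perform() just forwards the input storage.
#
#   x = theano.tensor.vector('x')
#   y = view_op(x)   # y.owner.op is view_op; y shares x's storage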


class OutputGuard(ViewOp):
    """
    This op is used only internally by Theano.

    Only the AddDestroyHandler optimizer tries to insert it in the graph.

    This Op is declared as destructive even though it does not actually
    destroy anything: it returns a view. This is used to prevent destruction
    of the output variables of a Theano function.

    There is a mechanism in Theano that should already prevent this, but
    OutputGuard adds a safeguard: an optimization run before the
    add_destroy_handler phase could bypass that mechanism by making in-place
    optimizations.

    TODO: find a current full explanation.

    """
    destroy_map = {0: [0]}

    check_input = False

_output_guard = OutputGuard()


def register_deep_copy_op_c_code(typ, code, version=()):
    """
    Tell DeepCopyOp how to generate C code for a Theano Type.

    Parameters
    ----------
    typ : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Deep copies the Theano type 'typ'. Use %(iname)s and %(oname)s for the
        input and output C variable names respectively.
    version
        A number indicating the version of the code, for the cache.

    """
    DeepCopyOp.c_code_and_version[typ] = (code, version)


class DeepCopyOp(gof.Op):
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ()
    _f16_ok = True

    def __init__(self):
        pass

    def make_node(self, x):
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, args, outs):
        if hasattr(args[0], 'copy'):
            # When args[0] is a 0-dimensional ndarray, deep-copying it
            # returns a numpy.dtype and not an ndarray. So when the
            # argument has a 'copy' method we use it instead, as it does
            # not have this problem.
            outs[0][0] = args[0].copy()
        else:
            outs[0][0] = copy.deepcopy(args[0])

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, v) in sorted(iteritems(self.c_code_and_version),
                                key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for DeepCopyOp, but it has "
                              "no version. You should add a 'version' keyword"
                              " arg when calling "
                              "register_deep_copy_op_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        if version:
            version.append(1)
        return tuple(version)

    def c_code(self, node, name, inames, onames, sub):
        iname, = inames
        oname, = onames
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            return code % locals()

        # Else, no C code
        return super(DeepCopyOp, self).c_code(node, name, inames, onames, sub)


deep_copy_op = DeepCopyOp()
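
# Usage sketch (hypothetical names, assuming floatX accepts float64 inputs):
# the compiled function returns storage distinct from its input.
#
#   x = theano.tensor.dvector('x')
#   f = theano.function([x], deep_copy_op(x))
#   a = np.zeros(3)
#   b = f(a)   # b == a elementwise, but b is not a view on a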


def register_shape_c_code(type, code, version=()):
    """
    Tell the Shape Op how to generate C code for a Theano Type.

    Parameters
    ----------
    type : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Returns a vector representing the shape for the Theano type 'type'.
        Use %(iname)s and %(oname)s for the input and output C variable names
        respectively.
    version
        A number indicating the version of the code, for the cache.

    """
    Shape.c_code_and_version[type] = (code, version)


class Shape(gof.Op):
    """
    L{Op} to return the shape of a tensor.

    Notes
    -----
    Non-differentiable.

    """

    _f16_ok = True

    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ()

    def make_node(self, x):
        # Must work for all types that have a shape attribute.
        # Otherwise, this will fail at execution time.
        if not isinstance(x, theano.Variable):
            x = theano.tensor.as_tensor_variable(x)
        return gof.Apply(self, [x], [theano.tensor.lvector()])

    def perform(self, node, inp, out_):
        x, = inp
        out, = out_
        out[0] = theano._asarray(x.shape, dtype='int64')

    def infer_shape(self, node, in_shapes):
        return [[len(in_shapes[0])]]

    def connection_pattern(self, node):
        # The grad returns the gradient with respect to the
        # elements of a tensor variable. The elements of the
        # tensor variable do not participate in the computation
        # of its shape, so they are not really part of the graph.
        return [[False]]

    def grad(self, inp, grads):
        # The grad returns the gradient with respect to the
        # elements of a tensor variable. The elements of the
        # tensor variable do not participate in the computation
        # of its shape, so they are not really part of the graph.
        return [theano.gradient.DisconnectedType()()]

    def R_op(self, inputs, eval_points):
        return [None]

    def c_code(self, node, name, inames, onames, sub):
        iname, = inames
        oname, = onames
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            return code % locals()

        # Else, no C code
        return super(Shape, self).c_code(node, name, inames, onames, sub)

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, v) in sorted(iteritems(self.c_code_and_version),
                                key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for Shape, but it has no "
                              "version. You should add a 'version' keyword "
                              "arg when calling register_shape_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        if version:
            version.append(1)

        return tuple(version)


shape = Shape()
_shape = shape  # was used in the past, now use shape directly.
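
# Usage sketch (hypothetical names): the op returns the shape as a symbolic
# int64 vector.
#
#   x = theano.tensor.dmatrix('x')
#   s = shape(x)
#   s.eval({x: np.ones((2, 3))})   # -> array([2, 3])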


class Shape_i(gof.Op):
    """
    L{Op} to return one dimension of the shape of a tensor.

    Notes
    -----
    Non-differentiable.

    """

    _f16_ok = True

    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False

    __props__ = ("i",)

    def __init__(self, i):
        # As i will be used in the hash and ndarrays are not hashable,
        # we convert it to an int, which is.
        if isinstance(i, np.ndarray):
            assert i.dtype in theano.tensor.integer_dtypes
        assert i == int(i)
        i = int(i)
        self.i = i

    # NB:
    # 1) params_type is defined as a property to avoid a loop in Python
    #    imports, caused by importing theano.scalar below when params_type
    #    is defined directly in the class body.
    # 2) We wrap the scalar in a ParamsType (instead of directly using the
    #    scalar as the op param) to avoid Theano converting the scalar param
    #    to a constant that would later be hardcoded as a literal in the
    #    C code, making us lose all the advantages of using params.
    @property
    def params_type(self):
        return gof.ParamsType(i=theano.scalar.basic.int64)

    def __str__(self):
        return '%s{%i}' % (self.__class__.__name__, self.i)

    def make_node(self, x):
        # x could be one of a number of types;
        # the only thing we require is that the variable have a .ndim,
        # and that the value have a .shape.
        if not isinstance(x, theano.Variable):
            raise TypeError('x must be Variable with ndim attribute', x)
        if x.ndim <= self.i:
            raise TypeError('x has too few dimensions for Shape_i',
                            (x, self.i))
        return theano.Apply(self, [x], [theano.tensor.lscalar()])

    def perform(self, node, inp, out_, params):
        x, = inp
        out, = out_
        if out[0] is None:
            out[0] = theano._asarray(x.shape[self.i], dtype='int64')
        else:
            out[0][...] = x.shape[self.i]

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, ci, v) in sorted(iteritems(self.c_code_and_version),
                                    key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for Shape_i, but it has "
                              "no version. You should add a 'version' keyword "
                              "arg when calling register_shape_i_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        if version:
            version.append(2)

        return tuple(version)

    def c_code(self, node, name, inames, onames, sub):
        iname, = inames
        oname, = onames
        fail = sub['fail']
        # i is then 'params->i', not just 'params'.
        i = sub['params'] + '->i'

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, check_input, version = self.c_code_and_version[itype]
            return (check_input + code) % locals()

        # Else, no C code
        return super(Shape_i, self).c_code(node, name, inames, onames, sub)

    def infer_shape(self, node, input_shapes):
        return [()]

    def connection_pattern(self, node):
        # The grad returns the gradient with respect to the
        # elements of a tensor variable. The elements of the
        # tensor variable do not participate in the computation
        # of its shape, so they are not really part of the graph.
        return [[False]]

    def grad(self, inp, grads):
        return [theano.gradient.grad_not_implemented(
                op=self, x_pos=0, x=inp[0],
                comment=("No gradient for the shape of a matrix "
                         "is implemented."))]
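
# Usage sketch (hypothetical names): Shape_i(1) extracts the second dimension
# of its input as a symbolic int64 scalar.
#
#   x = theano.tensor.dmatrix('x')
#   d1 = Shape_i(1)(x)   # same value as x.shape[1], as an lscalar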


def shape_i(var, i, fgraph=None):
    """
    Equivalent to var.shape[i], but applies the shape feature optimization
    when possible.

    This is useful in optimizations that need the shape. It removes the
    need for the shape_feature optimization that would later convert the
    expression, so it speeds up optimization and avoids Equilibrium
    max-iteration problems.

    Parameters
    ----------
    var
        The variable we want to take the shape of.
    i
        The index of the shape dimension we want.
    fgraph : optional
        If var.fgraph does not exist, the fgraph that has the shape_feature;
        var is introduced into it to get the optimized shape.

    """
    if fgraph is None and hasattr(var, 'fgraph'):
        fgraph = var.fgraph
    if fgraph and hasattr(fgraph, 'shape_feature'):
        shape_feature = fgraph.shape_feature
        shape_of = shape_feature.shape_of

        def recur(node):
            if node.outputs[0] not in shape_of:
                for inp in node.inputs:
                    if inp.owner:
                        recur(inp.owner)
                # If the output var isn't marked as being in the graph,
                # we need to add it to the ShapeFeature.
                shape_feature.on_import(fgraph, node,
                                        'gof.ops.shape_i')
        if var not in shape_of:
            recur(var.owner)
        return shape_of[var][i]

    # If we are not able to use the shape feature, we should not put
    # Shape_i in the graph. Otherwise, the shape feature optimization
    # won't get applied.
    return var.shape[i]
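
# Sketch of the fallback path (hypothetical names): on a free-standing
# variable without an fgraph, shape_i just builds var.shape[i].
#
#   x = theano.tensor.dmatrix('x')
#   d0 = shape_i(x, 0)   # equivalent to x.shape[0] here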


def shape_i_op(i):
    key = i
    if key not in shape_i_op.cache:
        shape_i_op.cache[key] = Shape_i(i)
    return shape_i_op.cache[key]
shape_i_op.cache = {}
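
# shape_i_op memoizes one Shape_i instance per dimension index, so repeated
# calls hand back the identical Op. A sketch of the invariant:
#
#   assert shape_i_op(0) is shape_i_op(0)
#   assert shape_i_op(1).i == 1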


def register_shape_i_c_code(typ, code, check_input, version=()):
    """
    Tell Shape_i how to generate C code for a Theano Type.

    Parameters
    ----------
    typ : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Gets the shape of dimension %(i)s for the Theano type 'typ'.
        Use %(iname)s and %(oname)s for the input and output C variable names
        respectively.
    check_input : C code
        Validates the input; it is prepended to `code` before substitution.
    version
        A number indicating the version of the code, for the cache.

    """
    Shape_i.c_code_and_version[typ] = (code, check_input, version)


# List of Theano Types to which one can add an extra dimension and that
# Scan can deal with.
expandable_types = ()


def load_back(mod, name):
    __import__(mod)
    import sys
    module = sys.modules[mod]
    obj = getattr(module, name)
    return obj


class FromFunctionOp(gof.Op):
    """
    Build a basic Theano Op around a function.

    Since the resulting Op is very basic and is missing most of the
    optional functionality, some optimizations may not apply. If you
    want to help, you can supply an infer_shape function that computes
    the output shapes given the input shapes.

    Also, the gradient is undefined in the resulting Op, and Theano will
    raise an error if you attempt to get the gradient of a graph
    containing this Op.

    """

    def __init__(self, fn, itypes, otypes, infer_shape):
        self.__fn = fn
        self.itypes = itypes
        self.otypes = otypes
        self.__infer_shape = infer_shape
        if self.__infer_shape is not None:
            self.infer_shape = self._infer_shape

    def __eq__(self, other):
        return (type(self) == type(other) and
                self.__fn == other.__fn)

    def __hash__(self):
        return hash(type(self)) ^ hash(self.__fn)

    def __str__(self):
        return 'FromFunctionOp{%s}' % self.__fn.__name__

    def perform(self, node, inputs, outputs):
        outs = self.__fn(*inputs)
        if not isinstance(outs, (list, tuple)):
            outs = (outs,)
        assert len(outs) == len(outputs)
        for i in range(len(outs)):
            outputs[i][0] = outs[i]

    def __reduce__(self):
        mod = self.__fn.__module__
        name = self.__fn.__name__
        try:
            obj = load_back(mod, name)
        except (ImportError, KeyError, AttributeError):
            raise pickle.PicklingError(
                "Can't pickle as_op(), not found as %s.%s" %
                (mod, name))
        else:
            if obj is not self:
                raise pickle.PicklingError(
                    "Can't pickle as_op(), not the object "
                    "at %s.%s" % (mod, name))
        return load_back, (mod, name)

    def _infer_shape(self, node, input_shapes):
        return self.__infer_shape(node, input_shapes)


def as_op(itypes, otypes, infer_shape=None):
    """
    Decorator that converts a function into a basic Theano Op that will call
    the supplied function as its implementation.

    It takes an optional infer_shape parameter that should be a callable
    with this signature:

        def infer_shape(node, input_shapes):
            ...
            return output_shapes

    Here `input_shapes` and `output_shapes` are lists of tuples that
    represent the shape of the corresponding inputs/outputs.

    This should not be used when performance is a concern, since the very
    basic nature of the resulting Op may interfere with certain graph
    optimizations.

    Examples
    --------
    @as_op(itypes=[theano.tensor.fmatrix, theano.tensor.fmatrix],
           otypes=[theano.tensor.fmatrix])
    def numpy_dot(a, b):
        return numpy.dot(a, b)
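
    As a sketch of the optional infer_shape argument (the names here are
    illustrative, not part of the API):

    def infer_shape_numpy_dot(node, input_shapes):
        ashp, bshp = input_shapes
        return [ashp[:-1] + bshp[-1:]]

    @as_op(itypes=[theano.tensor.fmatrix, theano.tensor.fmatrix],
           otypes=[theano.tensor.fmatrix],
           infer_shape=infer_shape_numpy_dot)
    def numpy_dot(a, b):
        return numpy.dot(a, b)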

    """
    if not isinstance(itypes, (list, tuple)):
        itypes = [itypes]
    if any(not isinstance(t, theano.Type) for t in itypes):
        raise TypeError("itypes has to be a list of Theano types")
    if not isinstance(otypes, (list, tuple)):
        otypes = [otypes]
    if any(not isinstance(t, theano.Type) for t in otypes):
        raise TypeError("otypes has to be a list of Theano types")

    # Make sure they are lists and not tuples.
    itypes = list(itypes)
    otypes = list(otypes)

    if infer_shape is not None and not callable(infer_shape):
        raise TypeError("infer_shape needs to be a callable")

    def make_op(fn):
        return FromFunctionOp(fn, itypes, otypes, infer_shape)
    return make_op
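
# Compiling the decorated function from the docstring example above is then
# an ordinary Theano workflow (sketch; numpy_dot is the hypothetical Op
# defined in that example):
#
#   x = theano.tensor.fmatrix()
#   y = theano.tensor.fmatrix()
#   f = theano.function([x, y], numpy_dot(x, y))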


def register_rebroadcast_c_code(typ, code, version=()):
    """
    Tell Rebroadcast how to generate C code for a Theano Type.

    Parameters
    ----------
    typ : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Checks whether dimension %(axis)s has shape 1 for the Theano type
        'typ'. Use %(iname)s and %(oname)s for the input and output C
        variable names respectively, and %(axis)s for the axis that we
        need to check. This code is repeated for each axis to check.
    version
        A number indicating the version of the code, for the cache.

    """
    Rebroadcast.c_code_and_version[typ] = (code, version)


class Rebroadcast(gof.Op):
    """
    Change the input's broadcastable fields in some predetermined way.

    See Also
    --------
    unbroadcast <theano.tensor.unbroadcast>
    addbroadcast <theano.tensor.addbroadcast>
    patternbroadcast <theano.tensor.patternbroadcast>

    Notes
    -----
    Works inplace and works for CudaNdarrayType.

    Examples
    --------
    `Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
    axis 0 and not broadcastable in axis 1.

    """

    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ("axis",)

    def __init__(self, *axis):
        # Sort them to make sure we merge all possible cases.
        items = sorted(axis)
        self.axis = OrderedDict(items)
        for axis, broad in iteritems(self.axis):
            if not isinstance(axis, (np.integer, integer_types)):
                raise TypeError("Rebroadcast needs integer axes. "
                                "Got {}".format(axis))

            if not isinstance(broad, (np.bool_, bool)):
                raise TypeError("Rebroadcast needs bool for new broadcast "
                                "pattern. Got {}".format(broad))

    def __hash__(self):
        # Need a special __hash__ as dicts aren't hashable;
        # there is no ambiguity because each item key is unique.
        items = sorted(iteritems(self.axis))
        return hash((type(self), tuple(items)))

    def __str__(self):
        if len(self.axis) == 0:
            broadcast_pattern = []
        else:
            broadcast_pattern = ['?' for i
                                 in xrange(1 + max(self.axis.keys()))]
        for k, v in iteritems(self.axis):
            broadcast_pattern[k] = str(int(v))
        return '%s{%s}' % (self.__class__.__name__,
                           ','.join(broadcast_pattern))

    def make_node(self, x):
        if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
            raise ValueError('Trying to rebroadcast non-existent dimension')
        t = x.type.clone(
            broadcastable=[self.axis.get(i, b)
                           for i, b in enumerate(x.type.broadcastable)])
        return gof.Apply(self, [x], [t()])

    def perform(self, node, inp, out_):
        x, = inp
        out, = out_
        for axis, value in iteritems(self.axis):
            if value and x.shape[axis] != 1:
                raise ValueError('Dimension %s in Rebroadcast\'s input was'
                                 ' supposed to be 1 (got %s instead)' %
                                 (axis, x.shape[axis]))
        out[0] = x

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        # Restore the broadcasting pattern of the input.
        return Rebroadcast(*[(axis, x.type.broadcastable[axis])
                             for axis, value in iteritems(self.axis)])(gz),

    def infer_shape(self, node, ishapes):
        assert len(ishapes) == 1
        l = []
        one = theano.tensor.basic.constant(1)
        for ax in xrange(len(ishapes[0])):
            if self.axis.get(ax, False):
                l.append(one)
            else:
                l.append(ishapes[0][ax])

        return [tuple(l)]

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
            return [None]
        return self(*eval_points, **dict(return_list=True))

    def c_code(self, node, nodename, inp, out, sub):
        iname, = inp
        oname, = out
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            final_code = ""
            for axis, value in iteritems(self.axis):
                if value:
                    final_code += code % locals()
            return final_code + """
            Py_XDECREF(%(oname)s);
            %(oname)s = %(iname)s;
            Py_XINCREF(%(oname)s);
            """ % locals()
        return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub)

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, v) in sorted(iteritems(self.c_code_and_version),
                                key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for Rebroadcast, but it "
                              "has no version. You should add a 'version' "
                              "keyword arg when calling "
                              "register_rebroadcast_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        if version:
            version.append(1)
        return tuple(version)
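
# Usage sketch (hypothetical names): make axis 0 of a matrix broadcastable.
#
#   x = theano.tensor.dmatrix('x')
#   y = Rebroadcast((0, True))(x)   # y.broadcastable == (True, False)
#   # At runtime, perform() raises ValueError unless x.shape[0] == 1.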


def register_specify_shape_c_code(typ, code, version=(),
                                  c_support_code_apply=None):
    """
    Tell SpecifyShape how to generate C code for a Theano Type.

    Parameters
    ----------
    typ : Theano type
        It must be the Theano class itself and not an instance of the class.
    code : C code
        Checks the shape and returns a view for the Theano type 'typ'.
        Use %(iname)s and %(oname)s for the input and output C variable names
        respectively. %(shape)s is the vector representing the expected shape
        of %(iname)s; the code must check that its length is right.
    version
        A number indicating the version of the code, for the cache.
    c_support_code_apply
        Extra support code.

    """
    SpecifyShape.c_code_and_version[typ] = (code, version,
                                            c_support_code_apply)


class SpecifyShape(gof.Op):
    """
    L{Op} that puts into the graph the user-provided shape.

    In the case where this op stays in the final graph, we assert the shape.
    For this, the output of this op must be used in the graph. This is not
    the case most of the time if we only take the shape of the output.
    Maybe there are other optimizations that will mess with this.

    Notes
    -----
    Maybe in the future we will never do the assert!

    We currently don't support specifying partial shape information.

    TODO : test this op with sparse. Do C code for them too.

    """

    view_map = {0: [0]}
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}
    __props__ = ()
    _f16_ok = True

    def make_node(self, x, shape):
        if not isinstance(x, gof.Variable):
            x = theano.tensor.as_tensor_variable(x)
        shape = theano.tensor.as_tensor_variable(shape)
        assert shape.ndim == 1
        assert shape.dtype in theano.tensor.integer_dtypes
        if isinstance(shape, theano.tensor.TensorConstant):
            assert shape.data.size == x.ndim
        return gof.Apply(self, [x, shape], [x.type()])

    def perform(self, node, inp, out_):
        x, shape = inp
        out, = out_
        assert x.ndim == shape.size
        assert np.all(x.shape == shape), ("got shape", x.shape,
                                          "expected", shape)
        out[0] = x

    def infer_shape(self, node, shapes):
        xshape, sshape = shapes
        new_shape = []
        for dim in xrange(node.inputs[0].ndim):
            try:
                s = theano.tensor.get_scalar_constant_value(
                    node.inputs[1][dim])
                s = theano.tensor.as_tensor_variable(s)
                new_shape.append(s)
            except theano.tensor.NotScalarConstantError:
                new_shape.append(node.inputs[1][dim])

        assert len(new_shape) == len(xshape)
        return [new_shape]

    def connection_pattern(self, node):
        return [[True], [False]]

    def grad(self, inp, grads):
        x, s = inp
        gz, = grads
        # Should we put a SpecifyShape on gz? We think so, but we do not do
        # it for now, as we would first need an optimization that removes
        # that op from the graph so it does not block other optimizations.
        # That would be: return [specify_shape(gz, s),
        #                        theano.gradient.DisconnectedType()()]
        return [gz, theano.gradient.DisconnectedType()()]

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
            # It means that this op sits on top of a non-differentiable
            # path.
            return [None]
        return self.make_node(eval_points[0], *inputs[1:]).outputs

    def c_support_code_apply(self, node, name):
        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            _, _, support_code = self.c_code_and_version[itype]
            if support_code:
                return support_code
        return super(SpecifyShape, self).c_support_code_apply(node, name)

    def c_code(self, node, name, inames, onames, sub):
        iname, shape = inames
        oname, = onames
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version, _ = self.c_code_and_version[itype]
            return code % locals()

        return super(SpecifyShape, self).c_code(node, name, inames,
                                                onames, sub)

    def c_code_cache_version(self):
        version = []
        # If any of the C code is unversioned, we have to return ().
        # Else, we return a list of (type name, version) pairs.
        for t, (c, v, _) in sorted(iteritems(self.c_code_and_version),
                                   key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for SpecifyShape, but it "
                              "has no version. You should add a 'version' "
                              "keyword arg when calling "
                              "register_specify_shape_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        return tuple(version)


specify_shape = SpecifyShape()
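
# Usage sketch (hypothetical names): assert at runtime that x has shape
# (2, 3).
#
#   x = theano.tensor.dmatrix('x')
#   y = specify_shape(x, (2, 3))
#   # Evaluating y with an input of any other shape raises AssertionError.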