1from __future__ import absolute_import, print_function, division
2#
# Slice type and the MakeSlice Op; None type and the NoneConst constant.
4#
5
6import numpy as np
7
8import theano
9from theano.gof import Apply, Constant, Generic, Op, Type, hashtype
10from theano.gradient import DisconnectedType
11
12
def as_int_none_variable(x):
    """Turn ``x`` into a variable usable as a start/stop/step of a slice.

    Parameters
    ----------
    x
        ``None``, a variable equal to ``NoneConst``, or anything accepted by
        ``theano.tensor.as_tensor_variable`` that yields an integer scalar.

    Returns
    -------
    ``NoneConst`` when ``x`` is ``None``; ``x`` itself when it already equals
    ``NoneConst``; otherwise a 0-d integer tensor variable.

    Raises
    ------
    TypeError
        If the 0-d tensor built from ``x`` does not have an integer dtype.
    """
    if x is None:
        return NoneConst
    if NoneConst.equals(x):
        return x
    scalar = theano.tensor.as_tensor_variable(x, ndim=0)
    if scalar.type.dtype in theano.tensor.integer_dtypes:
        return scalar
    raise TypeError('index must be integers')
22
23
class MakeSlice(Op):
    """Op whose output is a Python ``slice`` built from start/stop/step.

    Each input is either an integer scalar variable or ``NoneConst`` (see
    ``as_int_none_variable``); the output has type ``slicetype``.
    """

    __props__ = ()

    def make_node(self, slc, stop=None, step=None):
        """Build the Apply node for this Op.

        ``slc`` is either a Python ``slice`` (in which case ``stop`` and
        ``step`` must be left as ``None``) or the start component. Accepting
        the node's own three inputs here allows optimizations to recreate
        this Op elsewhere in the graph.
        """
        if not isinstance(slc, slice):
            components = [slc, stop, step]
        else:
            assert stop is None
            assert step is None
            components = [slc.start, slc.stop, slc.step]
        node_inputs = [as_int_none_variable(c) for c in components]
        return Apply(self, node_inputs, [slicetype()])

    def perform(self, node, inp, out_):
        # Single output: store the concrete slice in its storage cell.
        (storage,) = out_
        storage[0] = slice(*inp)

    def grad(self, inputs, grads):
        # None of the slice components receives a gradient.
        return [DisconnectedType()() for _ in inputs]

make_slice = MakeSlice()
50
51
class SliceType(Type):
    """Theano Type whose runtime values are Python ``slice`` objects."""

    def filter(self, x, strict=False, allow_downcast=None):
        """Return ``x`` unchanged if it is a slice, else raise TypeError.

        ``strict`` and ``allow_downcast`` are accepted for interface
        compatibility but have no effect here.
        """
        if not isinstance(x, slice):
            raise TypeError('Expected a slice!')
        return x

    def __str__(self):
        return "slice"

    def __eq__(self, other):
        # All instances of this exact class are interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        return hashtype(self)

    @staticmethod
    def may_share_memory(a, b):
        # A slice is only considered to "share memory" with itself.
        if not isinstance(a, slice):
            return False
        return a is b

slicetype = SliceType()
75
76
class SliceConstant(Constant):
    """Constant whose data is a Python ``slice``.

    NumPy ndarrays are not hashable, so any 0-d integer ndarray found in the
    slice's start, stop or step is converted to a Python ``int`` before the
    data is stored. ``signature`` and ``__str__`` are derived from the three
    slice fields.
    """

    def __init__(self, type, data, name=None):
        assert isinstance(data, slice)
        # Numpy ndarray aren't hashable, so get rid of them. Each field is
        # converted independently: more than one of start/stop/step may be a
        # 0-d ndarray. (The previous if/elif chain converted only the first,
        # and its ``step`` branch called int() on ``stop`` instead of
        # ``step``.)
        fields = (data.start, data.stop, data.step)
        if any(isinstance(field, np.ndarray) for field in fields):
            data = slice(*[self._to_hashable_index(field)
                           for field in fields])
        Constant.__init__(self, type, data, name)

    @staticmethod
    def _to_hashable_index(value):
        """Return ``value`` as a Python int if it is a 0-d integer ndarray.

        Any value that is not an ndarray is returned unchanged.
        """
        if isinstance(value, np.ndarray):
            assert value.ndim == 0
            assert str(value.dtype) in theano.tensor.integer_dtypes
            return int(value)
        return value

    def signature(self):
        return (SliceConstant, self.data.start, self.data.stop, self.data.step)

    def __str__(self):
        return "%s{%s, %s, %s}" % (self.__class__.__name__,
                                   self.data.start,
                                   self.data.stop,
                                   self.data.step)
SliceType.Constant = SliceConstant
104
105
class NoneTypeT(Generic):
    """
    Type whose only accepted value is ``None``.

    It inherits from Generic so that C code generation works for it.

    """

    def filter(self, x, strict=False, allow_downcast=None):
        """Return ``x`` if it is ``None``; otherwise raise TypeError."""
        if x is not None:
            raise TypeError('Expected None!')
        return x

    @staticmethod
    def may_share_memory(a, b):
        # Python's None is a singleton, so in the DebugMode sense two None
        # values are never reported as sharing memory.
        return False

none_type_t = NoneTypeT()
125
# NoneConst is a single Variable instance, so it may appear only once per
# fgraph: call NoneConst.clone() before inserting it into a Theano graph.
# To test whether a variable is NoneConst, use NoneConst.equals(x).
NoneConst = Constant(none_type_t,
                     None,
                     name='NoneConst')
130