from __future__ import absolute_import, print_function, division
#
# Slice type and Op. None Type and NoneConst.
#

import numpy as np

import theano
from theano.gof import Apply, Constant, Generic, Op, Type, hashtype
from theano.gradient import DisconnectedType


def as_int_none_variable(x):
    """
    Convert a slice component into a variable usable as a MakeSlice input.

    `None` is mapped to `NoneConst`, and a variable equal to `NoneConst` is
    returned unchanged. Anything else is converted with
    `theano.tensor.as_tensor_variable(x, ndim=0)`.

    Raises
    ------
    TypeError
        If the converted variable's dtype is not one of
        `theano.tensor.integer_dtypes`.

    """
    if x is None:
        return NoneConst
    elif NoneConst.equals(x):
        return x
    x = theano.tensor.as_tensor_variable(x, ndim=0)
    if x.type.dtype not in theano.tensor.integer_dtypes:
        raise TypeError('index must be integers')
    return x


class MakeSlice(Op):
    """
    Op whose `perform` builds a Python `slice` from its node inputs.

    """

    __props__ = ()

    def make_node(self, slc, stop=None, step=None):
        """
        Build the Apply node.

        `slc` is either a Python `slice` (in which case `stop` and `step`
        must be left to None) or the start component, with `stop` and
        `step` given separately.

        """
        # We need to accept and handle in make_node inputs the node
        # inputs to allow redoing a new op elsewhere in the graph by
        # optimization.
        if isinstance(slc, slice):
            assert stop is None
            assert step is None
            inp = [slc.start, slc.stop, slc.step]
        else:
            inp = [slc, stop, step]
        return Apply(self,
                     [as_int_none_variable(i) for i in inp],
                     [slicetype()])

    def perform(self, node, inp, out_):
        """Store `slice(*inp)` in the single output storage cell."""
        out, = out_
        out[0] = slice(*inp)

    def grad(self, inputs, grads):
        """Return one `DisconnectedType` variable per input."""
        return [DisconnectedType()() for _ in inputs]

make_slice = MakeSlice()


class SliceType(Type):
    """
    Type whose only valid values are Python `slice` objects.

    """

    def filter(self, x, strict=False, allow_downcast=None):
        """Return `x` unchanged if it is a `slice`, else raise TypeError."""
        if isinstance(x, slice):
            return x
        else:
            raise TypeError('Expected a slice!')

    def __str__(self):
        return "slice"

    def __eq__(self, other):
        # Instances carry no state here: two instances are equal exactly
        # when they have the same class.
        return type(self) == type(other)

    def __hash__(self):
        return hashtype(self)

    @staticmethod
    def may_share_memory(a, b):
        # Two slice values are only reported as sharing memory when they
        # are the very same slice object.
        return isinstance(a, slice) and a is b

slicetype = SliceType()


def _slice_component_to_int(value):
    """
    Return `int(value)` if `value` is a numpy ndarray, else `value` as is.

    An ndarray component must be 0-dimensional and have one of the dtypes
    in `theano.tensor.integer_dtypes`; this is checked with assertions.

    """
    if isinstance(value, np.ndarray):
        assert value.ndim == 0
        assert str(value.dtype) in theano.tensor.integer_dtypes
        return int(value)
    return value


class SliceConstant(Constant):
    """
    Constant holding a Python `slice`.

    """

    def __init__(self, type, data, name=None):
        """
        `data` must be a `slice`. Any of its start/stop/step that is a
        0-d integer numpy ndarray is replaced by the equivalent Python int.

        """
        assert isinstance(data, slice)
        # Numpy ndarray aren't hashable, so get rid of them.
        # Each component is handled independently: more than one of them
        # may be an ndarray, and each must be replaced by its own value.
        components = (data.start, data.stop, data.step)
        if any(isinstance(c, np.ndarray) for c in components):
            data = slice(*[_slice_component_to_int(c) for c in components])
        Constant.__init__(self, type, data, name)

    def signature(self):
        return (SliceConstant, self.data.start, self.data.stop, self.data.step)

    def __str__(self):
        return "%s{%s, %s, %s}" % (self.__class__.__name__,
                                   self.data.start,
                                   self.data.stop,
                                   self.data.step)
SliceType.Constant = SliceConstant


class NoneTypeT(Generic):
    """
    Type whose only valid value is `None`.

    Inherit from Generic to have c code working.

    """

    def filter(self, x, strict=False, allow_downcast=None):
        """Return `x` unchanged if it is `None`, else raise TypeError."""
        if x is None:
            return x
        else:
            raise TypeError('Expected None!')

    @staticmethod
    def may_share_memory(a, b):
        # None values never share memory, in the sense of DebugMode.
        # Python None is a singleton.
        return False

none_type_t = NoneTypeT()

# This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in a Theano graph.
# Use NoneConst.equals(x) to check if two variable are NoneConst.
NoneConst = Constant(none_type_t, None, name='NoneConst')