1from libc cimport stdlib
2cimport cpython.buffer
3
4import sys
5
# (name, PyBUF_* flag) pairs checked by MockBuffer.__getbuffer__.  A name is
# recorded in received_flags when every bit of its flag is present in the
# requested flags; entries are recorded in the order listed here.
available_flags = (
    ('FORMAT',       cpython.buffer.PyBUF_FORMAT),
    ('INDIRECT',     cpython.buffer.PyBUF_INDIRECT),
    ('ND',           cpython.buffer.PyBUF_ND),
    ('STRIDES',      cpython.buffer.PyBUF_STRIDES),
    ('C_CONTIGUOUS', cpython.buffer.PyBUF_C_CONTIGUOUS),
    ('F_CONTIGUOUS', cpython.buffer.PyBUF_F_CONTIGUOUS),
    ('WRITABLE',     cpython.buffer.PyBUF_WRITABLE),
)
15
cdef class MockBuffer:
    """Base class for test objects that export data via the buffer protocol.

    The constructor copies ``data`` into malloc'ed C memory using the
    subclass hooks ``write`` (store one item), ``get_itemsize`` and
    ``get_default_format``.  A flat list yields a simple/strided buffer; a
    list of lists yields an indirect buffer (arrays of pointers described by
    suboffsets).  When ``label`` is truthy, every acquire/release is printed
    and appended to ``log``.
    """
    cdef object format, offset
    cdef void* buffer
    cdef Py_ssize_t len, itemsize
    cdef Py_ssize_t* strides
    cdef Py_ssize_t* shape
    cdef Py_ssize_t* suboffsets
    cdef object label, log
    cdef int ndim
    cdef bint writable

    # received_flags: names from available_flags matched by the last request.
    # release_ok: set to False if a release sees a foreign suboffsets pointer.
    cdef readonly object received_flags, release_ok
    # fail: when truthy, __getbuffer__ raises ValueError.
    cdef public object fail

    def __init__(self, label, data, shape=None, strides=None, format=None, writable=True, offset=0):
        """Copy ``data`` into a freshly allocated C buffer.

        label    -- printed and logged on acquire/release when truthy.
        data     -- list of items, or nested lists for an indirect buffer.
        shape    -- defaults to ``(len(data),)``; replaced by the nested-list
                    shape when ``data`` is nested.
        strides  -- given in items and scaled by itemsize here; defaults to
                    C order for ``shape``.  Replaced for indirect buffers.
        format   -- str (encoded as ASCII) or bytes; defaults to the
                    subclass format.
        writable -- when false, writable buffer requests raise BufferError.
        offset   -- item offset added to the exported data pointer.
        """
        # It is important not to store references to data after the constructor
        # as refcounting is checked on object buffers.
        cdef Py_ssize_t x, s, cumprod, itemsize
        self.label = label
        self.release_ok = True
        self.log = u""
        self.offset = offset
        self.itemsize = itemsize = self.get_itemsize()
        self.writable = writable
        if format is None: format = self.get_default_format()
        if shape is None: shape = (len(data),)
        if strides is None:
            # C-order strides, in items.
            strides = []
            cumprod = 1
            for s in shape[::-1]:
                strides.append(cumprod)
                cumprod *= s
            strides.reverse()
        strides = [x * itemsize for x in strides]
        suboffsets = [-1] * len(shape)
        # Probe the nesting depth by following first elements; each nested
        # list level contributes one dimension.
        datashape = [len(data)]
        p = data
        while True:
            p = p[0]
            if isinstance(p, list): datashape.append(len(p))
            else: break
        if len(datashape) > 1:
            # indirect access
            self.ndim = <int>len(datashape)
            shape = datashape
            self.buffer = self.create_indirect_buffer(data, shape)
            # Every level except the last is dereferenced (suboffset 0).
            suboffsets = [0] * self.ndim
            suboffsets[-1] = -1
            strides = [sizeof(void*)] * self.ndim
            strides[-1] = itemsize
            self.suboffsets = self.list_to_sizebuf(suboffsets)
        else:
            # strided and/or simple access
            self.buffer = self.create_buffer(data)
            self.ndim = <int>len(shape)
            self.suboffsets = NULL

        try:
            format = format.encode('ASCII')
        except AttributeError:
            pass
        self.format = format
        self.len = len(data) * itemsize

        self.strides = self.list_to_sizebuf(strides)
        self.shape = self.list_to_sizebuf(shape)

    def __dealloc__(self):
        # The data buffer is released first because freeing an indirect
        # buffer needs self.shape to walk its pointer tree.
        if self.suboffsets != NULL:
            # Indirect buffer: shape is only stored at the end of __init__;
            # if construction stopped before that, the tree cannot be walked.
            if self.shape != NULL:
                self.free_indirect_buffer(self.buffer, 0)
            stdlib.free(self.suboffsets)
        else:
            stdlib.free(self.buffer)
        stdlib.free(self.strides)
        stdlib.free(self.shape)

    cdef void* create_buffer(self, data) except NULL:
        """Allocate ``len(data) * itemsize`` bytes and ``write`` each item."""
        cdef size_t n = <size_t>(len(data) * self.itemsize)
        cdef char* buf = <char*>stdlib.malloc(n)
        if buf == NULL:
            raise MemoryError
        cdef char* it = buf
        for value in data:
            self.write(it, value)
            it += self.itemsize
        return buf

    cdef void* create_indirect_buffer(self, data, shape) except NULL:
        """Recursively build a pointer tree for nested-list ``data``.

        Each non-final level is a malloc'ed array of ``len(data)`` child
        pointers; the final level is a flat item buffer from create_buffer.
        Raises MemoryError when an allocation fails.
        """
        cdef size_t n = 0
        cdef void** buf
        assert shape[0] == len(data), (shape[0], len(data))
        if len(shape) == 1:
            return self.create_buffer(data)
        else:
            shape = shape[1:]
            n = <size_t>len(data) * sizeof(void*)
            buf = <void**>stdlib.malloc(n)
            if buf == NULL:
                # Must raise: returning NULL with no exception set would
                # surface as SystemError under the ``except NULL`` contract.
                raise MemoryError

            # NOTE(review): if a nested call raises, ``buf`` and the subtrees
            # built so far are leaked.
            for idx, subdata in enumerate(data):
                buf[idx] = self.create_indirect_buffer(subdata, shape)

            return buf

    cdef void free_indirect_buffer(self, void* buf, int dim):
        """Free the pointer tree built by create_indirect_buffer.

        Levels ``0 .. ndim-2`` are arrays of ``shape[dim]`` child pointers;
        level ``ndim-1`` is an item buffer allocated by create_buffer.
        """
        cdef Py_ssize_t i
        cdef void** children
        if buf == NULL:
            return
        if dim < self.ndim - 1:
            children = <void**>buf
            for i in range(self.shape[dim]):
                self.free_indirect_buffer(children[i], dim + 1)
        stdlib.free(buf)

    cdef Py_ssize_t* list_to_sizebuf(self, l) except? NULL:
        """Copy a sequence of integers into a new malloc'ed Py_ssize_t array.

        An empty sequence may yield NULL (malloc(0)), which is not an error.
        """
        cdef Py_ssize_t i, x
        cdef size_t n = <size_t>len(l) * sizeof(Py_ssize_t)
        cdef Py_ssize_t* buf = <Py_ssize_t*>stdlib.malloc(n)
        if buf == NULL and n > 0:
            raise MemoryError
        try:
            for i, x in enumerate(l):
                buf[i] = x
        except:
            # Item conversion failed; don't leak the allocation.
            stdlib.free(buf)
            raise
        return buf

    def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
        if self.fail:
            raise ValueError("Failing on purpose")

        # Record which known flag combinations are fully contained in flags.
        self.received_flags = []
        cdef int value
        for name, value in available_flags:
            if (value & flags) == value:
                self.received_flags.append(name)

        if flags & cpython.buffer.PyBUF_WRITABLE and not self.writable:
            raise BufferError(f"Writable buffer requested from read-only mock: {' | '.join(self.received_flags)}")

        # offset is in items; Py_ssize_t arithmetic avoids int truncation.
        buffer.buf = <void*>(<char*>self.buffer + (<Py_ssize_t>self.offset * self.itemsize))
        buffer.obj = self
        # NOTE(review): len is len(data) * itemsize and ignores offset.
        buffer.len = self.len
        buffer.readonly = not self.writable
        buffer.format = <char*>self.format
        buffer.ndim = self.ndim
        buffer.shape = self.shape
        buffer.strides = self.strides
        buffer.suboffsets = self.suboffsets
        buffer.itemsize = self.itemsize
        buffer.internal = NULL
        if self.label:
            msg = f"acquired {self.label}"
            print(msg)
            self.log += msg + u"\n"

    def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
        # Sanity check: the released view must carry our suboffsets pointer.
        if buffer.suboffsets != self.suboffsets:
            self.release_ok = False
        if self.label:
            msg = f"released {self.label}"
            print(msg)
            self.log += msg + u"\n"

    def printlog(self):
        """Print the acquire/release log without its trailing newline."""
        print(self.log[:-1])

    def resetlog(self):
        """Clear the acquire/release log."""
        self.log = u""

    # Subclass hooks: store one item, report item size, report default format.
    cdef int write(self, char* buf, object value) except -1: raise Exception()
    cdef get_itemsize(self):
        print(f"ERROR, not subclassed: {self.__class__}")
    cdef get_default_format(self):
        print(f"ERROR, not subclassed {self.__class__}")
177
cdef class CharMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``char`` values."""

    cdef int write(self, char* buf, object value) except -1:
        buf[0] = <char>value
        return 0

    cdef get_itemsize(self):
        return sizeof(char)

    cdef get_default_format(self):
        return b"@b"
184
cdef class IntMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``int`` values."""

    cdef int write(self, char* buf, object value) except -1:
        cdef int* item = <int*>buf
        item[0] = <int>value
        return 0

    cdef get_itemsize(self):
        return sizeof(int)

    cdef get_default_format(self):
        return b"@i"
191
cdef class UnsignedIntMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``unsigned int`` values."""

    cdef int write(self, char* buf, object value) except -1:
        cdef unsigned int* item = <unsigned int*>buf
        item[0] = <unsigned int>value
        return 0

    cdef get_itemsize(self):
        return sizeof(unsigned int)

    cdef get_default_format(self):
        return b"@I"
198
cdef class ShortMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``short`` values."""

    cdef int write(self, char* buf, object value) except -1:
        cdef short* item = <short*>buf
        item[0] = <short>value
        return 0

    cdef get_itemsize(self):
        return sizeof(short)

    cdef get_default_format(self):
        # Deliberately has no byte-order/alignment prefix.
        return b"h"
205
cdef class UnsignedShortMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``unsigned short`` values."""

    cdef int write(self, char* buf, object value) except -1:
        cdef unsigned short* item = <unsigned short*>buf
        item[0] = <unsigned short>value
        return 0

    cdef get_itemsize(self):
        return sizeof(unsigned short)

    cdef get_default_format(self):
        # Deliberately uses an explicit repeat count of 1.
        return b"@1H"
212
cdef class FloatMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``float`` values."""

    cdef int write(self, char* buf, object value) except -1:
        # Convert via double first, then narrow to float.
        cdef double wide = <double>value
        cdef float* item = <float*>buf
        item[0] = <float>wide
        return 0

    cdef get_itemsize(self):
        return sizeof(float)

    cdef get_default_format(self):
        return b"f"
219
cdef class DoubleMockBuffer(MockBuffer):
    """Mock buffer whose items are C ``double`` values."""

    cdef int write(self, char* buf, object value) except -1:
        cdef double* item = <double*>buf
        item[0] = <double>value
        return 0

    cdef get_itemsize(self):
        return sizeof(double)

    cdef get_default_format(self):
        return b"d"
226
cdef extern from *:
    # The C name of this "function" is literally the cast "(void*)", so a
    # call addr_of_pyobject(obj) expands to "(void*)(obj)": it yields the raw
    # object address and does not change the object's reference count.
    void* addr_of_pyobject "(void*)"(object)
229
cdef class ObjectMockBuffer(MockBuffer):
    """Mock buffer whose items are raw ``PyObject*`` pointers."""

    cdef int write(self, char* buf, object value) except -1:
        # Stores the bare address: no reference is taken, so the buffer
        # never owns the objects it points at.
        cdef void** slot = <void**>buf
        slot[0] = addr_of_pyobject(value)
        return 0

    cdef get_itemsize(self):
        return sizeof(void*)

    cdef get_default_format(self):
        return b"@O"
237
cdef class IntStridedMockBuffer(IntMockBuffer):
    # Same storage as IntMockBuffer.  __cythonbufferdefaults__ is a
    # Cython-specific hook supplying default buffer options (here
    # mode="strided") when this type is used in a typed buffer declaration.
    cdef __cythonbufferdefaults__ = {"mode" : "strided"}
240
cdef class ErrorBuffer:
    """Buffer-protocol object whose acquire and release always raise.

    Both operations raise a plain ``Exception`` whose message names the
    action and the object's label.
    """
    cdef object label

    def __init__(self, label):
        self.label = label

    def __getbuffer__(ErrorBuffer self, Py_buffer* buffer, int flags):
        message = f"acquiring {self.label}"
        raise Exception(message)

    def __releasebuffer__(ErrorBuffer self, Py_buffer* buffer):
        message = f"releasing {self.label}"
        raise Exception(message)
252
253#
254# Structs
255#
cdef struct MyStruct:
    # Member order and types define the memory layout; do not reorder.
    signed char a, b
    long long int c
    int d, e
262
cdef struct SmallStruct:
    # Two ints; embedded twice in NestedStruct.
    int a, b
266
cdef struct NestedStruct:
    # Two embedded SmallStruct members followed by a plain int.
    SmallStruct x, y
    int z
271
# Packed struct: no padding is inserted between the char and the int.
cdef packed struct PackedStruct:
    signed char a
    int b
275
# Normally aligned struct that embeds a packed sub-struct; member order and
# types define the layout.
cdef struct NestedPackedStruct:
    signed char a
    int b
    PackedStruct sub
    int c
281
cdef class MyStructMockBuffer(MockBuffer):
    """Mock buffer of ``MyStruct`` items; each value unpacks to (a, b, c, d, e)."""

    cdef int write(self, char* buf, object value) except -1:
        cdef MyStruct* item = <MyStruct*>buf
        item.a, item.b, item.c, item.d, item.e = value
        return 0

    cdef get_itemsize(self):
        return sizeof(MyStruct)

    cdef get_default_format(self):
        return b"2cq2i"
291
cdef class NestedStructMockBuffer(MockBuffer):
    """Mock buffer of ``NestedStruct`` items; values unpack to (x.a, x.b, y.a, y.b, z)."""

    cdef int write(self, char* buf, object value) except -1:
        cdef NestedStruct* item = <NestedStruct*>buf
        item.x.a, item.x.b, item.y.a, item.y.b, item.z = value
        return 0

    cdef get_itemsize(self):
        return sizeof(NestedStruct)

    cdef get_default_format(self):
        return b"2T{ii}i"
301
cdef class PackedStructMockBuffer(MockBuffer):
    """Mock buffer of ``PackedStruct`` items; each value unpacks to (a, b)."""

    cdef int write(self, char* buf, object value) except -1:
        cdef PackedStruct* item = <PackedStruct*>buf
        item.a, item.b = value
        return 0

    cdef get_itemsize(self):
        return sizeof(PackedStruct)

    cdef get_default_format(self):
        return b"^ci"
311
cdef class NestedPackedStructMockBuffer(MockBuffer):
    """Mock buffer of ``NestedPackedStruct`` items; values unpack to (a, b, sub.a, sub.b, c)."""

    cdef int write(self, char* buf, object value) except -1:
        cdef NestedPackedStruct* item = <NestedPackedStruct*>buf
        item.a, item.b, item.sub.a, item.sub.b, item.c = value
        return 0

    cdef get_itemsize(self):
        return sizeof(NestedPackedStruct)

    cdef get_default_format(self):
        return b"ci^ci@i"
321
cdef struct LongComplex:
    # Real and imaginary parts, both long double.
    long double real, imag
325
cdef class LongComplexMockBuffer(MockBuffer):
    """Mock buffer of ``LongComplex`` items; each value unpacks to (real, imag)."""

    cdef int write(self, char* buf, object value) except -1:
        cdef LongComplex* item = <LongComplex*>buf
        item.real, item.imag = value
        return 0

    cdef get_itemsize(self):
        return sizeof(LongComplex)

    cdef get_default_format(self):
        return b"Zg"
335
336
def print_offsets(*args, size, newline=True):
    """Write each value in *args* floor-divided by *size*, space-separated.

    Output goes to ``sys.stdout.write``; a trailing newline is added only
    when *newline* is true.
    """
    pieces = [str(offset // size) for offset in args]
    line = ' '.join(pieces)
    if newline:
        line += '\n'
    sys.stdout.write(line)
339
def print_int_offsets(*args, newline=True):
    """Print byte offsets in *args* as counts of C ``int`` elements."""
    int_size = sizeof(int)
    print_offsets(*args, newline=newline, size=int_size)
342
343
# 5 x 3 x 4 nested list holding 0..59 in C (row-major) order.
shape_5_3_4_list = [
    [list(range(12 * k + 4 * j, 12 * k + 4 * j + 4)) for j in range(3)]
    for k in range(5)
]
347
# Item strides of a C-ordered 9 x 14 x 21 array.
stride2 = 21
stride1 = 14 * stride2
# 9 x 14 x 21 nested list holding 0..2645 in C (row-major) order.
shape_9_14_21_list = [
    [list(range(k * stride1 + j * stride2, k * stride1 + j * stride2 + 21))
     for j in range(14)]
    for k in range(9)
]
353