1 /*
2 *Copyright (c) 2018 Intel Corporation.
3 *
4 *Permission is hereby granted, free of charge, to any person obtaining a copy
5 *of this software and associated documentation files (the "Software"), to deal
6 *in the Software without restriction, including without limitation the rights
7 *to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 *copies of the Software, and to permit persons to whom the Software is
9 *furnished to do so, subject to the following conditions:
10 *
11 *The above copyright notice and this permission notice shall be included in
12 *all copies or substantial portions of the Software.
13 *
14 *THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 *IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 *FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 *AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 *LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 *OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 *THE SOFTWARE.
21 *
22 */
23
24
25 #include <Python.h>
26 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
27 #include "mdarray.h"
28 #include <immintrin.h>
29 #include <mkl_vml_functions.h>
30 #include "ideep_pin_singletons.hpp"
31 // #include "dlcp_py.h"
32
33 namespace implementation {
34
// Pin virtual table

// Anchor mdarray's vtable in this translation unit.
mdarray:: ~mdarray() = default;

// Python type object of the SWIG 'reorder_buffer' wrapper; resolved
// once in g_init() and used by getbuffer() to build exports.
static PyObject *PyType_reorder_buffer = nullptr;

// SWIG type descriptor for mdarray, needed when wrapping C++ results
// back into Python objects.
static swig_type_info *SwigTy_mdarray = nullptr;
//static swig_type_info *SwigTy_engine = nullptr;
// Python type object for mdarray, used for exact type checks.
static PyObject *PyType_mdarray = nullptr;
44
45 // get mdarray from PyObject
get_mdarray_from_PyObject(PyObject * self)46 static inline mdarray *get_mdarray_from_PyObject(PyObject *self) {
47 void *oprd_self;
48 int res = SWIG_ConvertPtr(self, &oprd_self, nullptr, 0);
49 if (!SWIG_IsOK(res)) {
50 // PyErr_SetString(PyExc_ValueError, "Error self PyObject");
51 return NULL;
52 }
53 return (reinterpret_cast<py_handle *>(oprd_self))->get();
54 }
55
56 //check whether mdarray support this operation
is_mdarray_supported(PyObject * self,PyObject * o)57 static inline bool is_mdarray_supported(PyObject *self, PyObject *o) {
58 // get self mdarray
59 mdarray *self_mdarray = get_mdarray_from_PyObject(self);
60 if (!self_mdarray)
61 return false;
62
63 // o is ndarray
64 // if size not equal, mean array broadcast
65 if (reinterpret_cast<PyTypeObject *>(o->ob_type) == &PyArray_Type) {
66 if ((size_t)PyArray_SIZE(reinterpret_cast<PyArrayObject *>(o))
67 != (size_t)self_mdarray->get_nelems() ||
68 !PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o))) {
69 return false;
70 }
71 return true;
72 }
73
74 // o is mdarray
75 if (reinterpret_cast<PyTypeObject *>(o->ob_type)
76 == reinterpret_cast<PyTypeObject *>(PyType_mdarray)) {
77 // if o is mdarray, try to get mdarray
78 mdarray *o_mdarray = get_mdarray_from_PyObject(o);
79 if (!o_mdarray)
80 return false;
81
82 // not support different size's mdarray's operations
83 if (o_mdarray->get_nelems() != self_mdarray->get_nelems())
84 return false;
85
86 return true;
87 }
88
89 return false;
90 }
91
queryPyTypeObject(const char * name)92 PyObject *queryPyTypeObject(const char *name) {
93 swig_type_info *info = SWIG_TypeQuery(name);
94 if (info != nullptr) {
95 SwigPyClientData *cd
96 = (SwigPyClientData *)info->clientdata;
97 return reinterpret_cast<PyObject *>(cd->pytype);
98 }
99
100 throw mkldnn::error(mkldnn_invalid_arguments
101 , "Failed to find reorderer object");
102 }
103
// Module-scope initialization: resolve and cache the SWIG/Python type
// objects used throughout this file, then initialize numpy's C API.
// Brought to global scope so the lookups happen once, not per call.
#if PY_VERSION_HEX >= 0x03000000
// Python 3 module init reports failure via an int return (0 = success);
// Python 2 init returns void — hence the signature split below.
int g_init() {
#else
void g_init() {
#endif
  PyType_reorder_buffer = queryPyTypeObject("_p_reorder_buffer");
  SwigTy_mdarray = SWIG_TypeQuery("_p_mdarray");
  PyType_mdarray = queryPyTypeObject("_p_mdarray");
  //SwigTy_engine = SWIG_TypeQuery("_p_mkldnn__engine");

#if PY_VERSION_HEX < 0x03000000
  // Python 2 only: mdarray must advertise the "new style" buffer
  // protocol or the Py_buffer views built in getbuffer() cannot work.
  if ((reinterpret_cast<PyTypeObject *>(PyType_mdarray)->tp_flags
      & Py_TPFLAGS_HAVE_NEWBUFFER) != Py_TPFLAGS_HAVE_NEWBUFFER)
    throw mkldnn::error(mkldnn_invalid_arguments
        , "Python2 should have new buffer flag on!");
#endif

  // XXX: I don't quite understand it, and its repercussions :)
  SwigPyObject_stype = SWIG_MangledTypeQuery("_p_SwigPyObject");

  if (SwigPyObject_stype == nullptr)
    throw mkldnn::error(mkldnn_invalid_arguments
        , "Failed to find SwigPyObject object");

  // Initiate static variables imported from numpy include
  import_array();

  // dlCompression::init();

#if PY_VERSION_HEX >= 0x03000000
  return 0;
#else
  return;
#endif
}
140
141 //FIXME: macro SWIG_as_voidptr is copied from mdarray_wrap.cpp
142 #define SWIG_as_voidptr(a) const_cast< void * >(static_cast< const void * >(a))
143
144 #if 0
145 // Pickle
146 PyObject *mdarray::__getstate__() const {
147 auto md = desc();
148 void *raw_data = get_data_handle();
149 int ndims = md.data.ndims;
150 mkldnn::memory::dims dims;
151 mkldnn::memory::data_type dtype = static_cast<mkldnn::memory::data_type>(md.data.data_type);
152 mkldnn::memory::format format = static_cast<mkldnn::memory::format>(md.data.format);
153 static mkldnn::engine engine = get_engine();
154
155 PyObject *py_dims = PyTuple_New(ndims);
156 for (int i = 0; i < ndims; i++) {
157 PyObject *py_dim = PyLong_FromLong(md.data.dims[i]);
158 PyTuple_SetItem(py_dims, i, py_dim);
159 }
160
161 PyObject *py_dtype = PyLong_FromLong((long)dtype);
162 PyObject *py_format = PyLong_FromLong((long)format);
163 PyObject *py_engine = PyLong_FromVoidPtr((void *)&engine);
164 PyObject *py_rdata = PyLong_FromVoidPtr((void *)raw_data);
165
166 PyObject *state = PyTuple_New(5);
167 PyTuple_SetItem(state, 0, py_dims);
168 PyTuple_SetItem(state, 1, py_dtype);
169 PyTuple_SetItem(state, 2, py_format);
170 PyTuple_SetItem(state, 3, py_engine);
171 PyTuple_SetItem(state, 4, py_rdata);
172
173 return state;
174 }
175
176 // Unpickle.
177 void mdarray::__setstate__(PyObject *state) {
178 return;
179 }
180 #endif
181
182 PyObject *mdarray::py_mdarray_from(PyObject *o) const {
183 PyObject *argList = Py_BuildValue("(O)", o);
184
185 if (argList == nullptr) {
186 PyErr_SetString(PyExc_SystemError, "Can not create argument list");
187 return nullptr;
188 }
189
190 o = PyObject_CallObject(PyType_mdarray, argList);
191
192 Py_DECREF(argList);
193
194 if (o == nullptr) {
195 PyErr_SetString(PyExc_BufferError, "Cannot create mdarray from input");
196 return nullptr;
197 }
198
199 return o;
200 }
201
202 using sum = ideep::sum;
203 using tensor = ideep::tensor;
204 using reorder = ideep::reorder;
205 using sum_array = ideep::sum_array;
206 using err_num_t = sum_array::err_num_t;
207 using scratch_allocator = ideep::utils::scratch_allocator;
208 using descriptor = ideep::tensor::descriptor;
209
210 void mdarray::axpby(tensor &dst, float a, const tensor &x, float b, const tensor &y) {
211 sum::compute<scratch_allocator, _IDEEP4PY_WEB_OPT_>(
212 {(float)a, (float)b}, {x, y}, dst);
213 return;
214 }
215
// Returns a brand-new mdarray holding a * this + b * o. `o` may be a
// numpy ndarray (wrapped into a temporary mdarray first) or an
// mdarray; nullptr is returned with a Python exception set when `o`
// cannot be converted to a SWIG pointer.
PyObject *mdarray::axpby(float a, float b, PyObject *o) {
  /// Resource manager, for GCC do not accept lambda
  struct py_decref {
    void operator () (PyObject *p) const {
      Py_DECREF(p);
    }
  };

  std::unique_ptr<PyObject, py_decref> op(nullptr);

  /// Create mdarray from buffer provider
  if (reinterpret_cast<PyTypeObject *>(o->ob_type) == &PyArray_Type) {
    o = py_mdarray_from(o);
    // Own the temporary wrapper so it is released on every exit path.
    op.reset(o);
  }

  void *oprd2;
  int res = SWIG_ConvertPtr(o, &oprd2, nullptr, 0);

  if (!SWIG_IsOK(res)) {
    PyErr_SetString(PyExc_ValueError, "Wrong operand object in add wrapper");
    return nullptr;
  }

  auto x = (reinterpret_cast<py_handle *>(oprd2))->get();
  // cache dst tensor for performance
  // The destination is laid out with x's (i.e. o's) descriptor.
  tensor dst;
  dst.init<scratch_allocator, tensor>(x->get_descriptor());
  py_handle *output = new py_handle(new mdarray(dst));

  /// Switch position for format consistency: the static axpby computes
  /// dst = scale1 * arg1 + scale2 * arg2, so passing (b, x, a, this)
  /// yields b * o + a * this with o's operand listed first.
  axpby(*output->get(), b, *x, a, *this);

  // SWIG_POINTER_OWN transfers ownership of `output` to Python.
  PyObject *resultobj = SWIG_Python_NewPointerObj(nullptr
      , SWIG_as_voidptr(output), SwigTy_mdarray, SWIG_POINTER_OWN | 0 );

  return resultobj;
}
254
// In-place update: this = a * this + b * o, then return `self` with an
// extra reference, per the CPython in-place number-protocol contract.
// `o` may be a numpy ndarray (wrapped first) or an mdarray.
PyObject *mdarray::inplace_axpby(float a, PyObject *self, float b, PyObject *o) {
  // Resource manager, for GCC do not accept lambda
  struct py_decref {
    void operator () (PyObject *p) const {
      Py_DECREF(p);
    }
  };

  std::unique_ptr<PyObject, py_decref> op(nullptr);

  // Create mdarray from buffer provider
  if (reinterpret_cast<PyTypeObject *>(o->ob_type) == &PyArray_Type) {
    o = py_mdarray_from(o);
    // Own the temporary wrapper so it is released on every exit path.
    op.reset(o);
  }

  void *oprd2;
  int res = SWIG_ConvertPtr(o, &oprd2, nullptr, 0);

  if (!SWIG_IsOK(res)) {
    PyErr_SetString(PyExc_ValueError, "Wrong operand object in add wrapper");
    return nullptr;
  }

  auto y = (reinterpret_cast<py_handle *>(oprd2))->get();
  // Writing into *this is safe here: the sum primitive is invoked with
  // *this as both an input and the destination.
  axpby(*this, a, *this, b, *y);
  // Return the (incref'd) receiver itself.
  Py_INCREF(self);

  return self;
}
285
// Copy the contents of `o` (numpy ndarray or mdarray) into this
// mdarray. The shapes must match exactly in rank and per-dimension
// extent; the source is reordered into this array's layout when the
// descriptors differ. Sets a Python error (and returns) on bad input;
// throws mkldnn::error on shape mismatch.
void mdarray::set(PyObject *o) {
  // Resource manager, for GCC do not accept lambda
  struct py_decref {
    void operator () (PyObject *p) const {
      Py_DECREF(p);
    }
  };

  std::unique_ptr<PyObject, py_decref> op(nullptr);

  // Create mdarray from buffer provider
  if (reinterpret_cast<PyTypeObject *>(o->ob_type) == &PyArray_Type) {
    o = py_mdarray_from(o);
    op.reset(o);
  }

  void *oprd2;
  int res = SWIG_ConvertPtr(o, &oprd2, nullptr, 0);

  if (!SWIG_IsOK(res)) {
    PyErr_SetString(PyExc_ValueError, "Wrong operand object in add wrapper");
    return;
  }

  auto in = *(reinterpret_cast<py_handle *>(oprd2))->get();
  auto dims = get_dims();
  auto in_dims = in.get_dims();
  // Reject any rank or per-dimension mismatch before touching memory.
  if (dims.size() != in_dims.size())
    throw error(mkldnn_invalid_arguments, "mdarray set: Inconsistent ndims");
  for (size_t d = 0; d < dims.size(); d++) {
    if (dims[d] != in_dims[d])
      throw error(mkldnn_invalid_arguments, "mdarray set: Inconsistent dims");
  }

  // Reorder the source into this array's layout when needed, then copy
  // the raw bytes into our buffer.
  tensor in_ = in;
  if (in.get_descriptor() != get_descriptor()) {
    in_.init(get_descriptor());
    reorder::compute(in, in_);
  }

  memcpy(get_data_handle(), in_.get_data_handle(), get_size());
  return;
}
329
330 PyObject *mdarray::m_Add(PyObject *self, PyObject *o) {
331 // Array Broadcast
332 if (!is_mdarray_supported(self, o)) {
333 return m_Add_map_impl(self, o);
334 } else if (PyArray_Check(o) &&
335 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
336 // Make compatibility with Non-C-Contiguous array.
337 PyObject *_o = o;
338 #if PY_VERSION_HEX < 0x03000000
339 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
340 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
341 #endif
342 PyObject *ret = m_Add_map_impl(self, _o);
343 #if PY_VERSION_HEX < 0x03000000
344 Py_DECREF(_o);
345 #endif
346 return ret;
347 } else {
348 return axpby(1.0f, 1.0f, o);
349 }
350 }
351
352 PyObject *mdarray::m_Subtract(PyObject *self, PyObject *o) {
353 // Array Broadcast
354 if (!is_mdarray_supported(self, o)) {
355 return m_Subtract_map_impl(self, o);
356 } else if (PyArray_Check(o) &&
357 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
358 PyObject *_o = o;
359 #if PY_VERSION_HEX < 0x03000000
360 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
361 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
362 #endif
363 PyObject *ret = m_Subtract_map_impl(self, _o);
364 #if PY_VERSION_HEX < 0x03000000
365 Py_DECREF(_o);
366 #endif
367 return ret;
368 } else {
369 return axpby(1.0f, -1.0f, o);
370 }
371 }
372
373 PyObject *mdarray::m_InPlaceAdd(PyObject *self, PyObject *o) {
374 // Array Broadcast
375 if (!is_mdarray_supported(self, o)) {
376 return m_InPlaceAdd_map_impl(self, o);
377 } else if (PyArray_Check(o) &&
378 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
379 PyObject *_o = o;
380 #if PY_VERSION_HEX < 0x03000000
381 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
382 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
383 #endif
384 PyObject *ret = m_InPlaceAdd_map_impl(self, _o);
385 #if PY_VERSION_HEX < 0x03000000
386 Py_DECREF(_o);
387 #endif
388 return ret;
389 } else {
390 return inplace_axpby(1.0f, self, 1.0f, o);
391 }
392 }
393
394 PyObject *mdarray::m_InPlaceSubtract(PyObject *self, PyObject *o) {
395 // Array Broadcast
396 if (!is_mdarray_supported(self, o)) {
397 return m_InPlaceSubtract_map_impl(self, o);
398 } else if (PyArray_Check(o) &&
399 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
400 PyObject *_o = o;
401 #if PY_VERSION_HEX < 0x03000000
402 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
403 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
404 #endif
405 PyObject *ret = m_InPlaceSubtract_map_impl(self, _o);
406 #if PY_VERSION_HEX < 0x03000000
407 Py_DECREF(_o);
408 #endif
409 return ret;
410 } else {
411 return inplace_axpby(1.0f, self, -1.0f, o);
412 }
413 }
414
// Element-wise product: o[i] = a[i] * b[i] for i in [0, size).
template <typename T>
void plain_mult(const T *a, const T *b, T *o, int size) {
  for (const T *stop = a + size; a != stop; ++a, ++b, ++o)
    *o = *a * *b;
}
420
// Element-wise quotient: o[i] = a[i] / b[i] for i in [0, size).
// Each element depends only on its own index, so iteration order is
// irrelevant — walk backwards for variety.
template <typename T>
void plain_div(const T *a, const T *b, T *o, int size) {
  int i = size;
  while (i-- > 0)
    o[i] = a[i] / b[i];
}
426
// Operation selector for m_mult_div().
enum {mmult, mdiv};
// Element-wise multiply/divide between this mdarray and `o`, which may
// be a numpy ndarray, an mdarray, or a numeric scalar.
//   mult_or_div : mmult or mdiv
//   inplace     : write the result into this mdarray and return self
// Returns a new (or incref'd) object, or nullptr with a Python error
// set on bad operands.
PyObject *mdarray::m_mult_div(PyObject *self, PyObject *o, int mult_or_div, bool inplace) {
  // Resource manager, for GCC do not accept lambda
  struct py_decref {
    void operator () (PyObject *p) const {
      Py_DECREF(p);
    }
  };

  std::unique_ptr<PyObject, py_decref> op(nullptr);

  enum mult_type_t { MULT_UNKNOWN, MULT_ELTWISE, MULT_SCALAR };

  // Classify the right-hand operand.
  PyTypeObject *oprd2_type = reinterpret_cast<PyTypeObject *>(o->ob_type);
  int mult_type = static_cast<int>(MULT_UNKNOWN);
  if (oprd2_type == &PyArray_Type) {
    mult_type = MULT_ELTWISE;
    // Wrap the ndarray in a temporary mdarray; `op` releases it on exit.
    o = py_mdarray_from(o);
    op.reset(o);
  } else if (PyObject_HasAttrString(o, "is_mdarray")) {
    mult_type = MULT_ELTWISE;
  } else if (PyFloat_Check(o) || PyInt_Check(o) || PyNumber_Check(o)) {
    mult_type = MULT_SCALAR;
  }

  PyObject *resultobj = nullptr;

  switch (static_cast<enum mult_type_t>(mult_type)) {
  case MULT_ELTWISE: {
    void *oprd2;
    int res = SWIG_ConvertPtr(o, &oprd2, nullptr, 0);
    if (!SWIG_IsOK(res)) {
      PyErr_SetString(PyExc_ValueError, "Error oprd2 %matrix element multiply");
      break;
    }

    auto oprd1_mdarr = this;
    auto oprd2_mdarr = (reinterpret_cast<py_handle *>(oprd2))->get();

    if (oprd1_mdarr->get_nelems() != oprd2_mdarr->get_nelems()) {
      PyErr_SetString(PyExc_SystemError, "Abnormal matrix size %matrix element multiply");
      break;
    }

    // Bring operand 2 into operand 1's layout before the flat
    // element-wise loops below.
    auto oprd2_internal_m = *oprd2_mdarr;
    if (oprd2_mdarr->get_descriptor() != oprd1_mdarr->get_descriptor()) {
      oprd2_internal_m.init(oprd1_mdarr->get_descriptor());
      reorder::compute(*oprd2_mdarr, oprd2_internal_m);
    }

    assert(oprd1_mdarr->ndims() == 2 || oprd1_mdarr->ndims() == 4);

    // Result buffer: a fresh mdarray, or operand 1 itself when inplace.
    mdarray *res_mdarr;
    if (!inplace) {
      res_mdarr = new mdarray(
          oprd1_mdarr->get_dims(), oprd1_mdarr->get_data_type());
    } else {
      res_mdarr = oprd1_mdarr;
    }

    data_type_t res_dtype = oprd1_mdarr->get_data_type();
    assert(data_type_t::f32 == res_dtype ||
           data_type_t::s32 == res_dtype ||
           data_type_t::s16 == res_dtype ||
           data_type_t::s8 == res_dtype ||
           data_type_t::u8 == res_dtype );
    assert(mmult == mult_or_div ||
           mdiv == mult_or_div);
    // Per-dtype dispatch: f32 multiply uses MKL's vectorized vsMul,
    // everything else runs the plain element-wise loops.
    if (data_type_t::f32 == res_dtype) {
      switch (mult_or_div) {
      case mmult:
        vsMul(oprd1_mdarr->get_nelems(),
            reinterpret_cast<const float *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const float *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<float *>(res_mdarr->get_data_handle()));
        break;

      case mdiv:
        plain_div(reinterpret_cast<const float *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const float *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<float *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;
      }
    } else if (data_type_t::s32 == res_dtype) {
      switch (mult_or_div) {
      case mmult:
        plain_mult(reinterpret_cast<const int *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;

      case mdiv:
        plain_div(reinterpret_cast<const int *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;
      }
    } else if (data_type_t::s16 == res_dtype) {
      switch (mult_or_div) {
      case mmult:
        plain_mult(reinterpret_cast<const int16_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int16_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int16_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;

      case mdiv:
        plain_div(reinterpret_cast<const int16_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int16_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int16_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;
      }
    } else if (data_type_t::s8 == res_dtype) {
      switch (mult_or_div) {
      case mmult:
        plain_mult(reinterpret_cast<const int8_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int8_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int8_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;

      case mdiv:
        plain_div(reinterpret_cast<const int8_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const int8_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<int8_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;
      }
    } else if (data_type_t::u8 == res_dtype) {
      switch (mult_or_div) {
      case mmult:
        plain_mult(reinterpret_cast<const uint8_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const uint8_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<uint8_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;

      case mdiv:
        plain_div(reinterpret_cast<const uint8_t *>(oprd1_mdarr->get_data_handle()),
            reinterpret_cast<const uint8_t *>(oprd2_internal_m.get_data_handle()),
            reinterpret_cast<uint8_t *>(res_mdarr->get_data_handle()),
            static_cast<int>(oprd1_mdarr->get_nelems()));
        break;
      }
    }

    if (!inplace) {
      // SWIG_POINTER_OWN hands the new py_handle over to Python.
      auto res_py_handle = new py_handle(res_mdarr);
      resultobj = SWIG_Python_NewPointerObj(nullptr,
          SWIG_as_voidptr(res_py_handle),
          SwigTy_mdarray,
          SWIG_POINTER_OWN | 0);
    } else {
      // In-place protocol: return the incref'd receiver.
      resultobj = self;
      Py_INCREF(self);
    }

    break;
  }

  case MULT_SCALAR: {
    float a = PyInt_Check(o) ?
        static_cast<float>(PyInt_AsLong(o)) :
        PyFloat_AsDouble(o),
        b = 0.0;

    // Division by scalar is implemented as multiplication by 1/a.
    a = (mmult == mult_or_div) ? a : (1 / a);

    // Scale b of the second operand is 0, so axpby reduces to a * this.
    if (!inplace) {
      resultobj = axpby(a, b, self);
    } else {
      resultobj = inplace_axpby(a, self, b, self);;
    }
    break;
  }

  case MULT_UNKNOWN:
  default:
    PyErr_SetString(PyExc_SystemError, "Abnormal type % matrix * scalar");
    break;
  }

  return resultobj;
}
614
615 PyObject *mdarray::m_Multiply(PyObject *self, PyObject *o) {
616 if (!is_mdarray_supported(self, o)) {
617 return m_Multiply_map_impl(self, o);
618 } else if (PyArray_Check(o) &&
619 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
620 PyObject *_o = o;
621 #if PY_VERSION_HEX < 0x03000000
622 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
623 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
624 #endif
625 PyObject *ret = m_Multiply_map_impl(self, _o);
626 #if PY_VERSION_HEX < 0x03000000
627 Py_DECREF(_o);
628 #endif
629 return ret;
630 } else {
631 return m_mult_div(self, o, mmult, false);
632 }
633 }
634
635 PyObject *mdarray::m_InPlaceMultiply(PyObject *self, PyObject *o) {
636 if (!is_mdarray_supported(self, o)) {
637 return m_InPlaceMultiply_map_impl(self, o);
638 } else if (PyArray_Check(o) &&
639 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
640 PyObject *_o = o;
641 #if PY_VERSION_HEX < 0x03000000
642 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
643 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
644 #endif
645 PyObject *ret = m_InPlaceMultiply_map_impl(self, _o);
646 #if PY_VERSION_HEX < 0x03000000
647 Py_DECREF(_o);
648 #endif
649 return ret;
650 } else {
651 return m_mult_div(self, o, mmult, true);
652 }
653 }
654
655 PyObject *mdarray::m_Divide(PyObject *self, PyObject *o) {
656 if (!is_mdarray_supported(self, o)) {
657 return m_Divide_map_impl(self, o);
658 } else if (PyArray_Check(o) &&
659 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
660 PyObject *_o = o;
661 #if PY_VERSION_HEX < 0x03000000
662 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
663 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
664 #endif
665 PyObject *ret = m_Divide_map_impl(self, _o);
666 #if PY_VERSION_HEX < 0x03000000
667 Py_DECREF(_o);
668 #endif
669 return ret;
670 } else {
671 return m_mult_div(self, o, mdiv, false);
672 }
673 }
674
675 PyObject *mdarray::m_InPlaceDivide(PyObject *self, PyObject *o) {
676 if (!is_mdarray_supported(self, o)) {
677 return m_InPlaceDivide_map_impl(self, o);
678 } else if (PyArray_Check(o) &&
679 !PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject *>(o))) {
680 PyObject *_o = o;
681 #if PY_VERSION_HEX < 0x03000000
682 _o = reinterpret_cast<PyObject *>(PyArray_ContiguousFromAny(
683 o, PyArray_ISFLOAT(reinterpret_cast<PyArrayObject *>(o)) ? NPY_FLOAT : NPY_INT, 0, 0));
684 #endif
685 PyObject *ret = m_InPlaceDivide_map_impl(self, _o);
686 #if PY_VERSION_HEX < 0x03000000
687 Py_DECREF(_o);
688 #endif
689 return ret;
690 } else {
691 return m_mult_div(self, o, mdiv, true);
692 }
693 }
694
695 int mdarray::build_view(Py_buffer *view, int flags, const reorderer &reorder) {
696 view->buf = reorder.data();
697 view->itemsize = get_view_itemsize();
698 view->readonly = 0;
699 view->internal = nullptr;
700 view->len = get_size();
701
702 if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
703 view->format = const_cast<char *>(get_view_format());
704 } else {
705 view->format = nullptr;
706 }
707
708 if ((flags & PyBUF_ND) == PyBUF_ND) {
709 view->ndim = ndims();
710 view->shape = const_cast<Py_ssize_t *>(get_view_shape());
711 } else {
712 view->ndim = 0;
713 view->shape = nullptr;
714 }
715
716 if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
717 view->strides = const_cast<Py_ssize_t *>(get_view_strides(view->itemsize));
718 } else {
719 view->strides = nullptr;
720 }
721
722 view->suboffsets = nullptr;
723
724 return 0;
725 }
726
// bf_getbuffer slot: export this mdarray through the buffer protocol.
// A SWIG-wrapped 'reorderer' converts internal blocked formats to a
// public layout and owns the exported memory; it is stored in
// view->obj so PyBuffer_Release drops it when the consumer is done.
// Returns 0 on success, -1 with a Python error set on failure.
int mdarray::getbuffer(PyObject *self, Py_buffer *view, int flags) {
  // Only C-contiguous exports are supported.
  if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS) {
    PyErr_SetString(PyExc_ValueError, "carray is not Fortran contiguous");
    return -1;
  }

  if (view == nullptr) {
    PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
    return -1;
  }

  // g_init() must have cached the reorder_buffer type already.
  if (PyType_reorder_buffer == nullptr) {
    PyErr_SetString(PyExc_NameError, "name 'reorderer' is not defined");
    return -1;
  }

  PyObject *rbobj = nullptr;
  reorderer *rb = nullptr;
  // Check entity or view array
  if (view_.get()) {
    // Share view(rb) from view array
    rbobj = view_->obj;

    if (!rbobj) {
      PyErr_SetString(PyExc_RuntimeError, "No buffer management entity in buffer view");
      return -1;
    }

    // Check current view from ndarray or mdarray
    // We don't know what is the exporting object from source array
    // If the exporting object is dropped out from mdarray, obj is rb
    if (PyObject_IsInstance(rbobj, PyType_reorder_buffer)) {
      int res = SWIG_ConvertPtr(rbobj, reinterpret_cast<void **>(&rb), nullptr, 0);
      if (!SWIG_IsOK(res)) {
        PyErr_SetString(PyExc_RuntimeError, "Can't get C++ object from python object");
        return -1;
      }
    }

    // Increase reference for new view
    // cur_view->rbobj ref_cnt = n
    // new_view->rbobj ref_cnt = n + 1
    Py_INCREF(rbobj);
  }

  if (!rb) {
    // Create view(rb) from entity array
    // reorderer type object: call reorder_buffer(self) to build one.
    PyObject *argList = Py_BuildValue("(O)", self);
    if (argList == nullptr) {
      return -1;
    }

    rbobj = PyObject_CallObject(PyType_reorder_buffer, argList);
    Py_DECREF(argList);

    if (rbobj == nullptr) {
      return -1;
    }

    int res = SWIG_ConvertPtr(rbobj, reinterpret_cast<void **>(&rb), nullptr, 0);
    if (!SWIG_IsOK(res)) {
      PyErr_SetString(PyExc_RuntimeError, "Can't get C++ object from python object");
      return -1;
    }

    // non_trivial() presumably means the internal format differs from
    // the public one, so the reorderer holds a converted copy; rebind
    // this mdarray to that buffer — TODO confirm against reorderer.
    if (rb->non_trivial()) {
      rb->fire(*this);

      buff_ = rb->data_;
      init({get_dims(), get_data_type(),
           descriptor::public_compatible_format(get_descriptor())},
           reinterpret_cast<void *>(rb->data()));
      // mdarray in internal format has no view
      // view_.reset();
    }
  }

  // FIXED: cannot copy directly
  // In some case, operations on mdarray just make impacts on tensor,
  // not the mdarray(view), e.g. `reshape`. So it is necessary to
  // create a rb to build a new view for consumer.
  // memcpy((void *)view, (void *)view_.get(), sizeof(Py_buffer));
  if (build_view(view, flags, *rb)) {
    PyErr_SetString(PyExc_RuntimeError, "Can't build Py_buffer!");
    Py_DECREF(rbobj);
    return -1;
  }

  // Stolen reference
  // PyBuffer_Release helps to decrease its reference
  view->obj = rbobj;

  // avoiding AVX-SSE Transition Penalties
  _mm256_zeroupper();
  return 0;
}
824
825 PyObject *mdarray::getattro(PyObject *self, PyObject *name) {
826 // XXX: Recursive alarm !!! XXX
827 PyObject *surrogate = PyArray_FromAny(self, nullptr, 0, 0
828 , NPY_ARRAY_ELEMENTSTRIDES, nullptr);
829
830 if (surrogate == nullptr)
831 return nullptr;
832
833 // Watch the reference count of surrogate if more compicated
834 // looking up method involved
835 PyObject * attr = PyObject_GetAttr(surrogate, name);
836
837 // The surrogate will be destroyed after attribute is done
838 Py_DECREF(surrogate);
839
840 if (attr == nullptr && PyErr_ExceptionMatches(PyExc_AttributeError)) {
841 PyErr_Clear();
842
843 // Switch to our exception message if things gone wrong
844 PyTypeObject *tp = Py_TYPE(self);
845 PyErr_Format(PyExc_AttributeError
846 , "mdarray '%.50s' object has no attribute '%p'", tp->tp_name, name);
847 }
848
849 return attr;
850 }
851
852 Py_ssize_t mdarray::mp_length(PyObject *self) {
853 PyObject *surrogate = PyArray_FromAny(self, nullptr, 0, 0
854 , NPY_ARRAY_ELEMENTSTRIDES, nullptr);
855
856 if (surrogate == nullptr)
857 return -1;
858
859 Py_ssize_t len = PyMapping_Length(surrogate);
860 Py_DECREF(surrogate);
861
862 // TODO: Exception localize
863 return len;
864 }
865
866 PyObject *mdarray::mp_subscript(PyObject *self, PyObject *op) {
867 PyObject *surrogate = PyArray_FromAny(self, nullptr, 0, 0
868 , NPY_ARRAY_ELEMENTSTRIDES, nullptr);
869
870 if (surrogate == nullptr)
871 return nullptr;
872
873 PyObject *ret = PyObject_GetItem(surrogate, op);
874 Py_DECREF(surrogate);
875
876 // TODO: Exception localize
877 return ret;
878 }
879
880 int mdarray::mp_ass_subscript(PyObject *self, PyObject *ind, PyObject *op) {
881 PyObject *surrogate = PyArray_FromAny(self, nullptr, 0, 0
882 , NPY_ARRAY_ELEMENTSTRIDES, nullptr);
883
884 int ret;
885
886 if (surrogate == nullptr)
887 return -1;
888
889 if (op == nullptr)
890 ret = PyObject_DelItem(surrogate, ind);
891 else
892 ret = PyObject_SetItem(surrogate, ind, op);
893
894 Py_DECREF(surrogate);
895
896 // TODO: Exception localize
897 return ret;
898 }
899
900 PyObject *mdarray::flat() {
901 long int dims[1] = {static_cast<long int>(this->get_nelems())};
902
903 int typenum = NPY_NOTYPE;
904 switch(get_data_type()) {
905 case data_type_t::f32:
906 typenum = NPY_FLOAT32;
907 break;
908 case data_type_t::s32:
909 typenum = NPY_INT;
910 break;
911 case data_type_t::s16:
912 typenum = NPY_INT16;
913 break;
914 case data_type_t::s8:
915 typenum = NPY_INT8;
916 break;
917 case data_type_t::u8:
918 typenum = NPY_UINT8;
919 break;
920 default:
921 PyErr_SetString(PyExc_ValueError, "Bad mdarray data_type");
922 break;
923 }
924
925 PyObject *plain_arr = PyArray_SimpleNewFromData(1, dims, typenum, this->get_data_handle());
926 if (!plain_arr)
927 PyErr_SetString(PyExc_ValueError, "Can't create plain array with format from mdarray");
928
929 return plain_arr;
930 }
931
932 PyObject *mdarray::reshape(py_handle *self, std::vector<int> dims)
933 {
934 if (dims.size() != 4 && dims.size() != 2) {
935 PyErr_SetString(PyExc_ValueError,"Only support reshape to 2 dimension");
936 return nullptr;
937 }
938 int idx_unknown = -1;
939 size_t size = 1;
940 for (unsigned int i = 0; i < dims.size(); i++) {
941 if (dims[i] < 0) {
942 if (idx_unknown == -1) {
943 idx_unknown = i;
944 } else {
945 PyErr_SetString(PyExc_ValueError,"Only support 1 unkown dimension");
946 return nullptr;
947 }
948 } else {
949 size *= dims[i];
950 }
951 }
952 if (idx_unknown == -1) {
953 if (size != (size_t)this->get_nelems()) {
954 PyErr_SetString(PyExc_ValueError,"Wrong dimension to reshape");
955 return nullptr;
956 }
957 } else if (this->get_nelems() % size) {
958 PyErr_SetString(PyExc_ValueError,"Wrong dimension to reshape");
959 return nullptr;
960 } else {
961 dims[idx_unknown] = this->get_nelems() / size;
962 }
963
964 // Same bahavior as numpy ndarray
965 // Share memory between src and dst array
966 auto o = new mdarray(*this);
967 o->_reshape(dims);
968
969 if (!is_public_format()) {
970 buff_ = o->get_tensor_buffer();
971 o->set_shared_buff(buff_);
972
973 // update src mdarray
974 set_shared_buff(buff_);
975 set_descriptor({get_dims(), get_data_type()});
976 set_data_handle(o->get_data_handle());
977 set_tensor_buffer(buff_);
978
979 // mdarray becomes entity, free view
980 if (view_.get()) {
981 view_.reset();
982 o->view_.reset();
983 }
984 }
985 PyObject *resultobj = SWIG_Python_NewPointerObj(nullptr,
986 SWIG_as_voidptr(new py_handle(o)), SwigTy_mdarray, SWIG_POINTER_OWN | 0);
987 return resultobj;
988 }
989
990 PyObject *mdarray::sum(std::vector<int> axis, bool keepdims)
991 {
992 err_num_t e;
993
994 auto tensor = sum_array::compute(*this, axis, e);
995 if (e != err_num_t::NOERR)
996 return nullptr;
997
998 if (keepdims) {
999 std::vector<int> expected_shape;
1000 for (int v = 0; v < this->ndims(); v++)
1001 expected_shape.push_back(this->get_dims()[v]);
1002
1003 for (unsigned a = 0; a < axis.size(); a++)
1004 expected_shape[axis[a]] = 1;
1005
1006 tensor.reshape(expected_shape);
1007 }
1008
1009 auto output = new py_handle(new mdarray(tensor));
1010 auto resultobj = SWIG_Python_NewPointerObj(nullptr,
1011 SWIG_as_voidptr(output), SwigTy_mdarray,
1012 SWIG_POINTER_OWN | 0);
1013 return resultobj;
1014 }
1015
1016 bool mdarray::is_mdarray(PyObject *o)
1017 {
1018 return (reinterpret_cast<PyTypeObject *>(o->ob_type)
1019 == reinterpret_cast<PyTypeObject *>(PyType_mdarray));
1020 }
1021
1022 }
1023