1 /*
2 MNN python module
3 PYMNN_EXPR_API: MNN.expr, MNN.nn
4 PYMNN_TRAIN_API: MNN.nn.compress, MNN.nn.loss, MNN.data, MNN.optim
5 */
6 #include "MNNPyBridge.h"
7 #include "common.h"
8 #include "util.h"
9
// Thread-local-storage keys consulted when global_new_python_flag > 0.
// tls_key is the slot read/written by getTLSData()/setTLSData() for the
// per-thread MNN_TLSData; tls_key_2 is the slot used by set_rh_tls_data()
// (only built with PYMNN_EXPR_API && PYMNN_USE_ALINNPYTHON).
// NOTE(review): both start at 0 here; presumably they are allocated during
// module initialization (not visible in this chunk) — confirm.
static int tls_key = 0;
static int tls_key_2 = 0;
12
13 #ifdef PYMNN_EXPR_API
14 #ifdef PYMNN_USE_ALINNPYTHON
15 #include "pybind_private/pybind11.h"
16 #include "pybind_private/stl.h"
17 #include "pybind_private/operators.h"
18 #else
19 #include "pybind11/pybind11.h"
20 #include "pybind11/stl.h"
21 #include "pybind11/operators.h"
22 #endif // PYMNN_USE_ALINNPYTHON
23 #endif // PYMNN_EXPR_API
24
25 #include <MNN/Interpreter.hpp>
26 #include <MNN/ImageProcess.hpp>
27 #ifdef PYMNN_EXPR_API
28 namespace py = pybind11;
29 #include <MNN/expr/Expr.hpp>
30 #include <MNN/expr/ExprCreator.hpp>
31 #include <MNN/expr/Executor.hpp>
32 //#include <MNN/expr/ExecutorScope.hpp>
33 #include <MNN/expr/Module.hpp>
34 using namespace MNN::Express;
35 #endif // PYMNN_EXPR_API
36
37 #ifdef BUILD_OPTYPE
38 #include "MNN_generated.h"
39 #endif // BUILD_OPTYPE
40
41 #ifdef PYMNN_TRAIN_API
42 #include "NN.hpp"
43 #include "OpGrad.hpp"
44 #include "ParameterOptimizer.hpp"
45 #include "SGD.hpp"
46 #include "ADAM.hpp"
47 #include "Dataset.hpp"
48 #include "DataLoader.hpp"
49 #include "Loss.hpp"
50 #include "Transformer.hpp"
51 #include "PipelineModule.hpp"
52 #include "cpp/ConvertToFullQuant.hpp"
53 using namespace MNN::Train;
54 #endif // PYMNN_TRAIN_API
55
56 #include <mutex>
57 #include <unordered_map>
58
59 using namespace MNN;
60
61 using namespace std;
62
63 struct MNN_TLSData {
64 PyObject *PyMNNHalideTypeInt = NULL;
65 PyObject *PyMNNHalideTypeInt64 = NULL;
66 PyObject *PyMNNHalideTypeFloat = NULL;
67 PyObject *PyMNNHalideTypeDouble = NULL;
68 PyObject *PyMNNHalideTypeUint8 = NULL;
69 PyObject *PyMNNHalideTypeString = NULL;
70 std::unordered_map<std::string, Interpreter *> *interpreterMap = NULL;
71 std::unordered_map<std::string, Session *> *sessionCacheMap = NULL;
72 };
73 static MNN_TLSData* old_python_data = NULL;
getTLSData()74 static MNN_TLSData * getTLSData() {
75 if(global_new_python_flag > 0) {
76 return static_cast<MNN_TLSData*>(PyThread_get_key_value(tls_key));
77 }else{
78 return old_python_data;
79 }
80 }
setTLSData(MNN_TLSData * tlsData)81 static void setTLSData(MNN_TLSData* tlsData) {
82 if(global_new_python_flag > 0) {
83 PyThread_set_key_value(tls_key, tlsData);
84 } else {
85 old_python_data = tlsData;
86 }
87 }
88
89 #if defined(PYMNN_EXPR_API) && defined(PYMNN_USE_ALINNPYTHON)
set_rh_tls_data(py::detail::rh_tls * rh_tls)90 static void set_rh_tls_data(py::detail::rh_tls* rh_tls) {
91 if(global_new_python_flag > 0) {
92 PyThread_set_key_value(tls_key_2, rh_tls);
93 } else {
94 py::detail::old_rh_tls_data = rh_tls;
95 }
96 }
97 #endif
98
99 #ifdef PYMNN_EXPR_API
100 namespace py = pybind11;
101 #endif
importName(const char * name,const char * symbol)102 static PyObject *importName(const char *name, const char *symbol)
103 {
104 PyObject *u_name, *module;
105 u_name = PyUnicode_FromString(name);
106 module = PyImport_Import(u_name);
107 if (!module) {
108 return NULL;
109 }
110 Py_DECREF(u_name);
111 return PyObject_GetAttrString(module, symbol);
112 }
113
// Python object layouts for the wrapper types. PyObject_HEAD must remain the
// first member of each struct so CPython can treat these as PyObject*.

// MNN.Interpreter: wraps a native Interpreter.
typedef struct {
    PyObject_HEAD
    std::string *modelPath;     // model path; createSession copies this pointer into the Sessions it returns
    Interpreter *interpreter;   // native interpreter used by all PyMNNInterpreter_* methods
} PyMNNInterpreter;

// MNN.Session: wraps a native Session created by (or cached for) an Interpreter.
typedef struct {
    PyObject_HEAD
    std::string *modelPath;     // pointer shared with the creating PyMNNInterpreter (not a copy)
    Session *session;
} PyMNNSession;

// MNN.Tensor: wraps a native Tensor.
typedef struct {
    PyObject_HEAD
    Tensor *tensor;
    int owner;                  // NOTE(review): ownership flag; presumably non-zero means this wrapper frees `tensor` — confirm in PyMNNTensor_dealloc
} PyMNNTensor;

// MNN.CVImageProcess: wraps a native CV::ImageProcess.
typedef struct {
    PyObject_HEAD
    CV::ImageProcess *imageProcess;
} PyMNNCVImageProcess;

// MNN CVMatrix: wraps a native CV::Matrix.
typedef struct {
    PyObject_HEAD
    CV::Matrix *matrix;
} PyMNNCVMatrix;

// MNN.OpInfo: borrowed view of the OperatorInfo handed to
// runSessionWithCallBackInfo callbacks (the pointer is not owned).
typedef struct {
    PyObject_HEAD
    const OperatorInfo *opInfo;
}PyMNNOpInfo;
httInt()146 halide_type_t* httInt() {
147 static halide_type_t httInt = halide_type_of<int>();
148 return &httInt;
149 }
150
httInt64()151 halide_type_t* httInt64() {
152 static halide_type_t httInt64 = halide_type_of<int64_t>();
153 return &httInt64;
154 }
155
httFloat()156 halide_type_t* httFloat() {
157 static halide_type_t httFloat = halide_type_of<float>();
158 return &httFloat;
159 }
160
httDouble()161 halide_type_t* httDouble() {
162 static halide_type_t httDouble = halide_type_of<double>();
163 return &httDouble;
164 }
165
httUint8()166 halide_type_t* httUint8() {
167 static halide_type_t httUint8 = halide_type_of<uint8_t>();
168 return &httUint8;
169 }
170
httString()171 halide_type_t* httString() {
172 static halide_type_t httString = halide_type_t(halide_type_handle, sizeof(void*)*8);
173 return &httString;
174 }
175
176 /// MNN NetInstance Type
// Forward declarations of the MNN.Interpreter implementations. They are referenced
// by PyMNNInterpreter_methods / PyMNNInterpreterType below and defined further down.
static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_resizeSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_resizeTensor(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSessionWithCallBack(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSessionWithCallBackInfo(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionInput(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionOutput(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionInputAll(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionOutputAll(PyMNNInterpreter *self, PyObject *args);
#ifndef PYMNN_USE_ALINNPYTHON
// setCacheFile is compiled out in the PYMNN_USE_ALINNPYTHON build.
static PyObject* PyMNNInterpreter_setCacheFile(PyMNNInterpreter *self, PyObject *args);
#endif
static PyObject* PyMNNInterpreter_cache(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_removeCache(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter *self, PyObject *args);
// Type slots: tp_new, tp_init and tp_dealloc of PyMNNInterpreterType.
static PyObject* PyMNNInterpreter_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static int PyMNNInterpreter_init(PyMNNInterpreter *self, PyObject *args, PyObject *kwds);
static void PyMNNInterpreter_dealloc(PyMNNInterpreter *);
196
// Python-visible methods of MNN.Interpreter. Every entry is METH_VARARGS, so each
// implementation receives (self, args-tuple). The {NULL} entry terminates the table.
static PyMethodDef PyMNNInterpreter_methods[] = {
    {"createSession", (PyCFunction)PyMNNInterpreter_createSession, METH_VARARGS, "create session"},
#ifndef PYMNN_USE_ALINNPYTHON
    {"setCacheFile", (PyCFunction)PyMNNInterpreter_setCacheFile, METH_VARARGS, "set cache file for create session"},
#endif
    {"resizeSession", (PyCFunction)PyMNNInterpreter_resizeSession, METH_VARARGS, "resize session"},
    {"runSession", (PyCFunction)PyMNNInterpreter_runSession, METH_VARARGS, "run session"},
    {"runSessionWithCallBack", (PyCFunction)PyMNNInterpreter_runSessionWithCallBack, METH_VARARGS, "run session with callback"},
    {"runSessionWithCallBackInfo", (PyCFunction)PyMNNInterpreter_runSessionWithCallBackInfo, METH_VARARGS, "run session with callback info"},
    {"getSessionOutput", (PyCFunction)PyMNNInterpreter_getSessionOutput, METH_VARARGS, "get session output"},
    {"getSessionInput", (PyCFunction)PyMNNInterpreter_getSessionInput, METH_VARARGS, "get session input"},
    {"getSessionOutputAll", (PyCFunction)PyMNNInterpreter_getSessionOutputAll, METH_VARARGS, "get session output all"},
    {"getSessionInputAll", (PyCFunction)PyMNNInterpreter_getSessionInputAll, METH_VARARGS, "get session input all"},
    {"resizeTensor", (PyCFunction)PyMNNInterpreter_resizeTensor, METH_VARARGS, "resize tensor"},
    {"cache", (PyCFunction)PyMNNInterpreter_cache, METH_VARARGS, "cache current net instance"},
    {"removeCache", (PyCFunction)PyMNNInterpreter_removeCache, METH_VARARGS, "remove cache with given path"},
    {"updateSessionToModel", (PyCFunction)PyMNNInterpreter_updateSessionToModel, METH_VARARGS, "updateSessionToModel"},
    {NULL}  /* Sentinel */
};
216
// Type object for MNN.Interpreter. The initializer is positional, so entries must
// stay in PyTypeObject slot order. Objects are built by PyMNNInterpreter_new +
// PyMNNInterpreter_init and destroyed by PyMNNInterpreter_dealloc.
// NOTE(review): the /*tp_print*/ and /*tp_compare*/ labels are Python 2 slot names;
// Python 3 uses those positions for other slots (harmless while they are 0).
static PyTypeObject PyMNNInterpreterType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Interpreter",                   /*tp_name*/
    sizeof(PyMNNInterpreter),                  /*tp_basicsize*/
    0,                                    /*tp_itemsize*/
    (destructor)PyMNNInterpreter_dealloc,   /*tp_dealloc*/
    0,                                    /*tp_print*/
    0,                                    /*tp_getattr*/
    0,                                    /*tp_setattr*/
    0,                                    /*tp_compare*/
    0,                                    /*tp_repr*/
    0,                                    /*tp_as_number*/
    0,                                    /*tp_as_sequence*/
    0,                                    /*tp_as_mapping*/
    0,                                    /*tp_hash */
    0,                                    /*tp_call*/
    0,                                    /*tp_str*/
    0,                                    /*tp_getattro*/
    0,                                    /*tp_setattro*/
    0,                                    /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Interpreter objects",                    /* tp_doc */
    0,                                    /* tp_traverse */
    0,                                    /* tp_clear */
    0,                                    /* tp_richcompare */
    0,                                    /* tp_weaklistoffset */
    0,                                    /* tp_iter */
    0,                                    /* tp_iternext */
    PyMNNInterpreter_methods,                                   /* tp_methods */
    0,                      /* tp_members */
    0,                    /* tp_getset */
    0,                                    /* tp_base */
    0,                                    /* tp_dict */
    0,                                    /* tp_descr_get */
    0,                                    /* tp_descr_set */
    0,                                    /* tp_dictoffset */
    (initproc)PyMNNInterpreter_init,               /* tp_init */
    0,                                    /* tp_alloc */
    PyMNNInterpreter_new,                              /* tp_new */
};
257
258 /// MNN Session Type
// Forward declarations for MNN.Session (presumably defined further down this file).
static PyObject* PyMNNSession_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNSession_dealloc(PyMNNSession *);
static PyObject* PyMNNSession_cache(PyMNNSession *self, PyObject *args);
static PyObject* PyMNNSession_removeCache(PyMNNSession *self, PyObject *args);

// Python-visible methods of MNN.Session (both METH_VARARGS); {NULL} ends the table.
static PyMethodDef PyMNNSession_methods[] = {
    {"cache", (PyCFunction)PyMNNSession_cache, METH_VARARGS, "cache current session instance"},
    {"removeCache", (PyCFunction)PyMNNSession_removeCache, METH_VARARGS, "remove session cache with given path"},
    {NULL}  /* Sentinel */
};
269
// Type object for MNN.Session (positional initializer in PyTypeObject slot order).
// There is no tp_init: PyMNNSession_new is the only construction hook, and
// PyMNNInterpreter_createSession fills in the session/modelPath fields afterwards.
// NOTE(review): /*tp_print*/ and /*tp_compare*/ are Python 2 slot labels.
static PyTypeObject PyMNNSessionType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Session",                   /*tp_name*/
    sizeof(PyMNNSession),                  /*tp_basicsize*/
    0,                                    /*tp_itemsize*/
    (destructor)PyMNNSession_dealloc,   /*tp_dealloc*/
    0,                                    /*tp_print*/
    0,                                    /*tp_getattr*/
    0,                                    /*tp_setattr*/
    0,                                    /*tp_compare*/
    0,                                    /*tp_repr*/
    0,                                    /*tp_as_number*/
    0,                                    /*tp_as_sequence*/
    0,                                    /*tp_as_mapping*/
    0,                                    /*tp_hash */
    0,                                    /*tp_call*/
    0,                                    /*tp_str*/
    0,                                    /*tp_getattro*/
    0,                                    /*tp_setattro*/
    0,                                    /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Session objects",                    /* tp_doc */
    0,                                    /* tp_traverse */
    0,                                    /* tp_clear */
    0,                                    /* tp_richcompare */
    0,                                    /* tp_weaklistoffset */
    0,                                    /* tp_iter */
    0,                                    /* tp_iternext */
    PyMNNSession_methods,                                   /* tp_methods */
    0,                      /* tp_members */
    0,                    /* tp_getset */
    0,                                    /* tp_base */
    0,                                    /* tp_dict */
    0,                                    /* tp_descr_get */
    0,                                    /* tp_descr_set */
    0,                                    /* tp_dictoffset */
    0,               /* tp_init */
    0,                                    /* tp_alloc */
    PyMNNSession_new,                              /* tp_new */
};
310
311 /// MNN Tensor Type
// Forward declarations for MNN.Tensor (presumably defined further down this file).
static PyObject* PyMNNTensor_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNTensor_dealloc(PyMNNTensor *);
static int PyMNNTensor_init(PyMNNTensor *self, PyObject *args, PyObject *kwds);
#ifdef PYMNN_NUMPY_USABLE
// numpy interop is only compiled when PYMNN_NUMPY_USABLE is defined.
static PyObject* PyMNNTensor_fromNumpy(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getNumpyData(PyMNNTensor *self, PyObject *args);
#endif
static PyObject* PyMNNTensor_printTensorData(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getShape(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getDataType(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getDimensionType(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getData(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getHost(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_copyFrom(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_copyToHostTensor(PyMNNTensor *self, PyObject *args);

// Python-visible methods of MNN.Tensor. METH_NOARGS entries are called with
// args == NULL. "copyFrom" and "copyFromHostTensor" are two names for the same
// implementation. {NULL} terminates the table.
static PyMethodDef PyMNNTensor_methods[] = {
#ifdef PYMNN_NUMPY_USABLE
    {"fromNumpy", (PyCFunction)PyMNNTensor_fromNumpy, METH_VARARGS, "copy data from numpy"},
    {"getNumpyData", (PyCFunction)PyMNNTensor_getNumpyData, METH_NOARGS, "get tensor data (numpy)"},
#endif
    {"printTensorData", (PyCFunction)PyMNNTensor_printTensorData, METH_NOARGS, "print tensor data"},
    {"getShape", (PyCFunction)PyMNNTensor_getShape, METH_NOARGS, "get tensor shape"},
    {"getDataType", (PyCFunction)PyMNNTensor_getDataType, METH_NOARGS, "get tensor data type"},
    {"getData", (PyCFunction)PyMNNTensor_getData, METH_NOARGS, "get tensor data (tuple)"},
    {"getHost", (PyCFunction)PyMNNTensor_getHost, METH_NOARGS, "get tensor host"},
    {"getDimensionType", (PyCFunction)PyMNNTensor_getDimensionType, METH_NOARGS, "get dimension data"},
    {"copyFrom", (PyCFunction)PyMNNTensor_copyFrom, METH_VARARGS, "copy data from host tensor"},
    {"copyFromHostTensor", (PyCFunction)PyMNNTensor_copyFrom, METH_VARARGS, "copy data from host tensor"},
    {"copyToHostTensor", (PyCFunction)PyMNNTensor_copyToHostTensor, METH_VARARGS, "copy data to host tensor"},
    {NULL}  /* Sentinel */
};
344
// Type object for MNN.Tensor (positional initializer in PyTypeObject slot order).
// Objects are built by PyMNNTensor_new + PyMNNTensor_init and destroyed by
// PyMNNTensor_dealloc.
// NOTE(review): /*tp_print*/ and /*tp_compare*/ are Python 2 slot labels.
static PyTypeObject PyMNNTensorType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Tensor",                   /*tp_name*/
    sizeof(PyMNNTensor),                  /*tp_basicsize*/
    0,                                    /*tp_itemsize*/
    (destructor)PyMNNTensor_dealloc,   /*tp_dealloc*/
    0,                                    /*tp_print*/
    0,                                    /*tp_getattr*/
    0,                                    /*tp_setattr*/
    0,                                    /*tp_compare*/
    0,                                    /*tp_repr*/
    0,                                    /*tp_as_number*/
    0,                                    /*tp_as_sequence*/
    0,                                    /*tp_as_mapping*/
    0,                                    /*tp_hash */
    0,                                    /*tp_call*/
    0,                                    /*tp_str*/
    0,                                    /*tp_getattro*/
    0,                                    /*tp_setattro*/
    0,                                    /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Tensor objects",                    /* tp_doc */
    0,                                    /* tp_traverse */
    0,                                    /* tp_clear */
    0,                                    /* tp_richcompare */
    0,                                    /* tp_weaklistoffset */
    0,                                    /* tp_iter */
    0,                                    /* tp_iternext */
    PyMNNTensor_methods,                                   /* tp_methods */
    0,                      /* tp_members */
    0,                    /* tp_getset */
    0,                                    /* tp_base */
    0,                                    /* tp_dict */
    0,                                    /* tp_descr_get */
    0,                                    /* tp_descr_set */
    0,                                    /* tp_dictoffset */
    (initproc)PyMNNTensor_init,               /* tp_init */
    0,                                    /* tp_alloc */
    PyMNNTensor_new,                              /* tp_new */
};
385
386 /// MNN ImageProcess Type
// Forward declarations for MNN.CVImageProcess (presumably defined further down).
static PyObject* PyMNNCVImageProcess_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNCVImageProcess_dealloc(PyMNNCVImageProcess *);
static int PyMNNCVImageProcess_init(PyMNNCVImageProcess *self, PyObject *args, PyObject *kwds);
static PyObject* PyMNNCVImageProcess_setMatrix(PyMNNCVImageProcess *self, PyObject *args);
static PyObject* PyMNNCVImageProcess_convert(PyMNNCVImageProcess *self, PyObject *args);
static PyObject* PyMNNCVImageProcess_createImageTensor(PyMNNCVImageProcess *self, PyObject *args);

// Python-visible methods of MNN.CVImageProcess (all METH_VARARGS); {NULL} ends the table.
static PyMethodDef PyMNNCVImageProcess_methods[] = {
    {"setMatrix", (PyCFunction)PyMNNCVImageProcess_setMatrix, METH_VARARGS, "ImageProcess setMatrix"},
    {"convert", (PyCFunction)PyMNNCVImageProcess_convert, METH_VARARGS, "ImageProcess convert"},
    {"createImageTensor", (PyCFunction)PyMNNCVImageProcess_createImageTensor, METH_VARARGS, "ImageProcess create Image Tensor"},
    {NULL}  /* Sentinel */
};
400
// Type object for MNN.CVImageProcess (positional initializer in PyTypeObject slot
// order). Objects are built by PyMNNCVImageProcess_new + PyMNNCVImageProcess_init
// and destroyed by PyMNNCVImageProcess_dealloc.
// NOTE(review): /*tp_print*/ and /*tp_compare*/ are Python 2 slot labels.
static PyTypeObject PyMNNCVImageProcessType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.CVImageProcess",                   /*tp_name*/
    sizeof(PyMNNCVImageProcess),                  /*tp_basicsize*/
    0,                                    /*tp_itemsize*/
    (destructor)PyMNNCVImageProcess_dealloc,   /*tp_dealloc*/
    0,                                    /*tp_print*/
    0,                                    /*tp_getattr*/
    0,                                    /*tp_setattr*/
    0,                                    /*tp_compare*/
    0,                                    /*tp_repr*/
    0,                                    /*tp_as_number*/
    0,                                    /*tp_as_sequence*/
    0,                                    /*tp_as_mapping*/
    0,                                    /*tp_hash */
    0,                                    /*tp_call*/
    0,                                    /*tp_str*/
    0,                                    /*tp_getattro*/
    0,                                    /*tp_setattro*/
    0,                                    /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN CVImageProcess objects",                    /* tp_doc */
    0,                                    /* tp_traverse */
    0,                                    /* tp_clear */
    0,                                    /* tp_richcompare */
    0,                                    /* tp_weaklistoffset */
    0,                                    /* tp_iter */
    0,                                    /* tp_iternext */
    PyMNNCVImageProcess_methods,                                   /* tp_methods */
    0,                      /* tp_members */
    0,                    /* tp_getset */
    0,                                    /* tp_base */
    0,                                    /* tp_dict */
    0,                                    /* tp_descr_get */
    0,                                    /* tp_descr_set */
    0,                                    /* tp_dictoffset */
    (initproc)PyMNNCVImageProcess_init,               /* tp_init */
    0,                                    /* tp_alloc */
    PyMNNCVImageProcess_new,                              /* tp_new */
};
441
442 /// MNN CVMatrix Type
// Forward declarations for the CVMatrix wrapper (presumably defined further down).
static PyObject* PyMNNCVMatrix_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNCVMatrix_dealloc(PyMNNCVMatrix *);
/// scale
static PyObject* PyMNNCVMatrix_setScale(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preScale(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postScale(PyMNNCVMatrix *, PyObject *args);
/// rotate
static PyObject* PyMNNCVMatrix_setRotate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preRotate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postRotate(PyMNNCVMatrix *, PyObject *args);
/// translate
static PyObject* PyMNNCVMatrix_setTranslate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preTranslate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postTranslate(PyMNNCVMatrix *, PyObject *args);

// invert takes only `self` (no argument tuple).
static PyObject* PyMNNCVMatrix_invert(PyMNNCVMatrix *);
459
460 static PyMethodDef PyMNNCVMatrix_methods[] = {
461 {"setScale", (PyCFunction)PyMNNCVMatrix_setScale, METH_VARARGS, "MNNCVMatrix setScale"},
462 {"preScale", (PyCFunction)PyMNNCVMatrix_preScale, METH_VARARGS, "MNNCVMatrix preScale"},
463 {"postScale", (PyCFunction)PyMNNCVMatrix_postScale, METH_VARARGS, "MNNCVMatrix postScale"},
464
465 {"setRotate", (PyCFunction)PyMNNCVMatrix_setRotate, METH_VARARGS, "MNNCVMatrix setRotate"},
466 {"preRotate", (PyCFunction)PyMNNCVMatrix_preRotate, METH_VARARGS, "MNNCVMatrix preRotate"},
467 {"postRotate", (PyCFunction)PyMNNCVMatrix_postRotate, METH_VARARGS, "MNNCVMatrix postRotate"},
468
469 {"setTranslate", (PyCFunction)PyMNNCVMatrix_setTranslate, METH_VARARGS, "MNNCVMatrix setTranslate"},
470 {"preTranslate", (PyCFunction)PyMNNCVMatrix_preTranslate, METH_VARARGS, "MNNCVMatrix preTranslate"},
471 {"postTranslate", (PyCFunction)PyMNNCVMatrix_postTranslate, METH_VARARGS, "MNNCVMatrix postTranslate"},
472
473 {"invert", (PyCFunction)PyMNNCVMatrix_invert, METH_VARARGS, "MNNCVMatrix invert"},
474 {NULL} /* Sentinel */
475 };
476
477 static PyTypeObject PyMNNCVMatrixType = {
478 PyVarObject_HEAD_INIT(NULL, 0)
479 "MNN.CVImageProcess", /*tp_name*/
480 sizeof(PyMNNCVMatrix), /*tp_basicsize*/
481 0, /*tp_itemsize*/
482 (destructor)PyMNNCVMatrix_dealloc, /*tp_dealloc*/
483 0, /*tp_print*/
484 0, /*tp_getattr*/
485 0, /*tp_setattr*/
486 0, /*tp_compare*/
487 0, /*tp_repr*/
488 0, /*tp_as_number*/
489 0, /*tp_as_sequence*/
490 0, /*tp_as_mapping*/
491 0, /*tp_hash */
492 0, /*tp_call*/
493 0, /*tp_str*/
494 0, /*tp_getattro*/
495 0, /*tp_setattro*/
496 0, /*tp_as_buffer*/
497 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
498 "MNN CVMatrix objects", /* tp_doc */
499 0, /* tp_traverse */
500 0, /* tp_clear */
501 0, /* tp_richcompare */
502 0, /* tp_weaklistoffset */
503 0, /* tp_iter */
504 0, /* tp_iternext */
505 PyMNNCVMatrix_methods, /* tp_methods */
506 0, /* tp_members */
507 0, /* tp_getset */
508 0, /* tp_base */
509 0, /* tp_dict */
510 0, /* tp_descr_get */
511 0, /* tp_descr_set */
512 0, /* tp_dictoffset */
513 0, /* tp_init */
514 0, /* tp_alloc */
515 PyMNNCVMatrix_new, /* tp_new */
516 };
517
/// MNN NetInstance implementation
// Accessors for the cached net (Interpreter) and Session instances
520
interpreterMap()521 std::unordered_map<std::string, Interpreter *> *interpreterMap() {
522 // static std::unordered_map<std::string, Interpreter *> *interpreterMap = nullptr; // <path, instance>
523 // static std::once_flag flag;
524 // std::call_once(flag, [](){interpreterMap = new std::unordered_map<std::string, Interpreter *>();});
525 struct MNN_TLSData *tlsData = getTLSData();
526 if (tlsData == nullptr) {
527 return nullptr;
528 }
529 return tlsData->interpreterMap;
530 }
531
sessionCacheMap()532 std::unordered_map<std::string, Session *> *sessionCacheMap() {
533 struct MNN_TLSData *tlsData = getTLSData();
534 if (tlsData == nullptr) {
535 return nullptr;
536 }
537 return tlsData->sessionCacheMap;
538 }
539
540 namespace ec {
getVectorByKey(PyObject * dict,const char * key,std::vector<std::string> & result)541 int getVectorByKey(PyObject* dict, const char *key, std::vector<std::string>& result){
542 PyObject *saveTensors = PyDict_GetItemString(dict, key);
543 int count = 0;
544 if (saveTensors) {
545 if (!PyTuple_Check(saveTensors)) {
546 PyErr_SetString(PyExc_Exception,
547 "PyMNNInterpreter_createSession: saveTensors must be a tuple");
548 return -1;
549 }
550
551 size_t saveTensorsCount = PyTuple_Size(saveTensors);
552 for (size_t i = 0; i < saveTensorsCount; i++) {
553 PyObject *tensorNameItem = PyTuple_GetItem(saveTensors, i);
554 if (!checkString(tensorNameItem)) {
555 PyErr_SetString(PyExc_Exception,
556 "PyMNNInterpreter_createSession: saveTensors's member must be string");
557 return -1;
558 }
559
560
561 result.push_back(object2String(tensorNameItem));
562 count++;
563 }
564 }
565 return count;
566 }
567 }
568
PyMNNInterpreter_createSession(PyMNNInterpreter * self,PyObject * args)569 static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject *args) {
570 PyMNNInterpreter* instance = (PyMNNInterpreter *)self;
571 PyObject* dict = NULL;
572 if (!PyArg_ParseTuple(args, "|O", &dict)) {
573 return NULL;
574 }
575
576 PyObject *f = importName("MNN", "Session");
577 if (!f || !PyCallable_Check(f)) {
578 PyErr_SetString(PyExc_Exception,
579 "PyMNNInterpreter_createSession: MNN.Session not found");
580 return NULL;
581 }
582
583 // create a new session
584 PyMNNSession *session = (PyMNNSession *)PyObject_Call(f, PyTuple_New(0), NULL);
585 if (!session) {
586 PyErr_SetString(PyExc_Exception,
587 "PyMNNInterpreter_createSession: MNN.Session instance create failed");
588 return NULL;
589 }
590
591 if (self->modelPath && (*sessionCacheMap())[*self->modelPath]) {
592 session->modelPath = self->modelPath;
593 session->session = (*sessionCacheMap())[*self->modelPath];
594 return (PyObject *)session;
595 }
596
597 ScheduleConfig config;
598 BackendConfig backendConfig;
599 config.backendConfig = &backendConfig;
600 if (dict) {
601 PyObject *backend = PyDict_GetItemString(dict, "backend");
602 config.type = MNN_FORWARD_CPU;
603 if (backend) {
604 auto backend_name = object2String(backend);
605 // Avoid misusing backend not supported by the bridge and corresponding MNN library on python level,
606 // then user will ask for right version bridge library to us, same like MNN.expr.Backend.* python enum
607 std::unordered_map<std::string, MNNForwardType> backend_map = {
608 // Don't care whether MNN library support corresponding backend, all backend type are usable by user,
609 // which make MNN.whl setup.py easy
610 {"CPU", MNN_FORWARD_CPU},
611 {"OPENCL", MNN_FORWARD_OPENCL},
612 {"OPENGL", MNN_FORWARD_OPENGL},
613 {"VULKAN", MNN_FORWARD_VULKAN},
614 {"METAL", MNN_FORWARD_METAL},
615 {"TRT", MNN_FORWARD_USER_1},
616 {"CUDA", MNN_FORWARD_CUDA},
617 {"HIAI", MNN_FORWARD_USER_0}
618 };
619 auto iter = backend_map.find(backend_name);
620 if (iter == backend_map.end()) {
621 // backend not support, issue on python level when development
622 PyErr_SetString(PyExc_Exception,
623 "PyMNNInterpreter_createSession: backend not support");
624 return NULL;
625 }
626 config.type = iter->second;
627 }
628 if(config.type == MNN_FORWARD_CPU) {
629 PyObject *numThread = PyDict_GetItemString(dict, "numThread");
630 if (numThread) {
631 if (!PyLong_Check(numThread)) {
632 PyErr_SetString(PyExc_Exception,
633 "PyMNNInterpreter_createSession: numThread must be a integer");
634 return NULL;
635 }
636 config.numThread = (int)PyLong_AsLong(numThread);
637 }
638 }
639
640 {
641 //precision
642 PyObject *obj = PyDict_GetItemString(dict, "precision");
643 if (obj) {
644 auto obj_name = object2String(obj);
645 if (!obj_name.compare("low")) {
646 MNN_PRINT("MNN use low precision\n");
647 backendConfig.precision = MNN::BackendConfig::Precision_Low;
648 }
649 }
650 }
651
652 if (-1 == ec::getVectorByKey(dict, "saveTensors", config.saveTensors)
653 || -1 == ec::getVectorByKey(dict, "inputPaths", config.path.inputs)
654 || -1 == ec::getVectorByKey(dict, "outputPaths", config.path.outputs)){
655 return NULL;
656 }
657
658 }
659
660 Session *s = instance->interpreter->createSession(config);
661 if (!s) {
662 PyErr_SetString(PyExc_Exception,
663 "PyMNNInterpreter_createSession: NetInstance createSession failed");
664 return NULL;
665 }
666
667 session->session = s;
668 session->modelPath = instance->modelPath;
669
670 return (PyObject *)session;
671 }
672
PyMNNInterpreter_resizeSession(PyMNNInterpreter * self,PyObject * args)673 static PyObject* PyMNNInterpreter_resizeSession(PyMNNInterpreter *self, PyObject *args) {
674 PyMNNSession* session = NULL;
675 if (!PyArg_ParseTuple(args, "O", &session)) {
676 return NULL;
677 }
678
679 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
680 PyErr_SetString(PyExc_Exception,
681 "PyMNNInterpreter_resizeSession: First argument is not a MNN.Session instance");
682 return NULL;
683 }
684
685 self->interpreter->resizeSession(session->session);
686 Py_RETURN_TRUE;
687 }
688
PyMNNInterpreter_resizeTensor(PyMNNInterpreter * self,PyObject * args)689 static PyObject* PyMNNInterpreter_resizeTensor(PyMNNInterpreter *self, PyObject *args) {
690 PyMNNTensor* tensor = NULL;
691 PyObject* shape = NULL;
692 if (!PyArg_ParseTuple(args, "OO", &tensor, &shape)) {
693 return NULL;
694 }
695
696 if (!PyObject_TypeCheck(tensor, PyType_FindTLSType(&PyMNNTensorType))) {
697 PyErr_SetString(PyExc_Exception,
698 "PyMNNInterpreter_resizeTensor: First argument is not a MNN.Tensor instance");
699 return NULL;
700 }
701
702 if (!PyTuple_Check(shape)) {
703 PyErr_SetString(PyExc_Exception,
704 "PyMNNInterpreter_resizeTensor: Second argument is not a tuple");
705 return NULL;
706 }
707
708 size_t shapeSize = PyTuple_Size(shape);
709
710 std::vector<int> vShape;
711 for (size_t i = 0; i < shapeSize; i++) {
712 int shapeItem = (int)PyLong_AsLong(PyTuple_GetItem(shape, i));
713 vShape.push_back(shapeItem);
714 }
715
716 self->interpreter->resizeTensor(tensor->tensor, vShape);
717 Py_RETURN_NONE;
718 }
719 #ifndef PYMNN_USE_ALINNPYTHON
PyMNNInterpreter_setCacheFile(PyMNNInterpreter * self,PyObject * args)720 static PyObject* PyMNNInterpreter_setCacheFile(PyMNNInterpreter *self, PyObject *args) {
721 char *path = NULL;
722 if (!PyArg_ParseTuple(args, "s", &path)) {
723 PyErr_SetString(PyExc_Exception,
724 "PyMNNInterpreter_setCacheFile: Not string input");
725 return NULL;
726 }
727 Py_BEGIN_ALLOW_THREADS
728 self->interpreter->setCacheFile(path);
729 Py_END_ALLOW_THREADS
730 Py_RETURN_NONE;
731 }
732 #endif
733
PyMNNInterpreter_runSession(PyMNNInterpreter * self,PyObject * args)734 static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {
735 PyMNNSession* session = NULL;
736 if (!args) {
737 PyErr_SetString(PyExc_Exception,
738 "PyMNNInterpreter_runSession: No argument passed, expect 1");
739 return NULL;
740 }
741
742 if (!PyArg_ParseTuple(args, "O", &session)) {
743 return NULL;
744 }
745
746 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
747 PyErr_SetString(PyExc_Exception,
748 "PyMNNInterpreter_runSession: First argument is not a MNN.Session instance");
749 return NULL;
750 }
751 ErrorCode r = NO_ERROR;
752 Py_BEGIN_ALLOW_THREADS
753 r = self->interpreter->runSession(session->session);
754 Py_END_ALLOW_THREADS
755 return PyLong_FromLong(r);
756 }
PyMNNInterpreter_runSessionWithCallBack(PyMNNInterpreter * self,PyObject * args)757 static PyObject* PyMNNInterpreter_runSessionWithCallBack(PyMNNInterpreter *self, PyObject *args) {
758 PyMNNSession* session = NULL;
759 PyObject *beginCallback = NULL;
760 PyObject *endCallback = NULL;
761 if (!args) {
762 PyErr_SetString(PyExc_Exception,
763 "PyMNNInterpreter_runSessionWithCallBack: No argument passed, expect 1 or 3");
764 return NULL;
765 }
766
767 if (!PyArg_ParseTuple(args, "O|OO", &session, &beginCallback, &endCallback)) {
768 return NULL;
769 }
770
771 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
772 PyErr_SetString(PyExc_Exception,
773 "PyMNNInterpreter_runSessionWithCallBack: First argument is not a AliNN.Session instance");
774 return NULL;
775 }
776
777 TensorCallBack begin = [beginCallback](const std::vector<Tensor*>& tensors, const std::string& name){
778
779 if (!beginCallback || !PyCallable_Check(beginCallback)) {
780
781 return true;
782 }
783
784 PyObject *f = importName("MNN", "Tensor");
785 if (!f || !PyCallable_Check(f)) {
786 PyErr_SetString(PyExc_Exception,
787 "PyMNNInterpreter_runSessionWithCallBack: MNN.Tensor not found");
788 return true;
789 }
790
791 PyObject *args = PyTuple_New(2);
792 size_t size_tensors = tensors.size();
793 PyObject *weTensorData = PyTuple_New(size_tensors);
794 for (size_t i = 0; i < size_tensors; i++) {
795 // create a new tensor
796 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
797 if (!tensor) {
798 PyErr_SetString(PyExc_Exception,
799 "PyMNNInterpreter_runSessionWithCallBack: create Tensor failed");
800 return true;
801 }
802 tensor->tensor = tensors[i];
803 PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
804 }
805 //printf("begincallback name=%s\n",name.c_str());
806 PyObject *weStringData = char2Object(name.c_str());
807 PyTuple_SetItem(args, 0, weTensorData);
808 PyTuple_SetItem(args, 1, weStringData);
809 bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(beginCallback, args, NULL)));
810 Py_XDECREF(args);//del all the C++ created python api parameters
811 return ret;
812 };
813 TensorCallBack end = [endCallback](const std::vector<Tensor*>& tensors, const std::string& name){
814 if (!endCallback || !PyCallable_Check(endCallback)) {
815 return true;
816 }
817 PyObject *f = importName("MNN", "Tensor");
818 if (!f || !PyCallable_Check(f)) {
819 PyErr_SetString(PyExc_Exception,
820 "PyMNNInterpreter_runSessionWithCallBack: MNN.Tensor not found");
821 return true;
822 }
823 PyObject *args = PyTuple_New(2);
824 size_t size_tensors = tensors.size();
825 PyObject *weTensorData = PyTuple_New(size_tensors);
826 for (size_t i = 0; i < size_tensors; i++) {
827 // create a new tensor
828 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
829 if (!tensor) {
830 PyErr_SetString(PyExc_Exception,
831 "PyMNNInterpreter_runSessionWithCallBack: create Tensor failed");
832 return true;
833 }
834 tensor->tensor = tensors[i];
835 PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
836 }
837 PyObject *weStringData = char2Object(name.c_str());
838 PyTuple_SetItem(args, 0, weTensorData);
839 PyTuple_SetItem(args, 1, weStringData);
840 bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(endCallback, args, NULL)));
841 Py_XDECREF(args);//del all the C++ created python api parameters
842 return ret;
843 };
844
845 ErrorCode r = NO_ERROR;
846 //Py_BEGIN_ALLOW_THREADS
847 r = self->interpreter->runSessionWithCallBack(session->session, begin, end);
848 //Py_END_ALLOW_THREADS
849 return PyLong_FromLong(r);
850 }
851
PyMNNInterpreter_runSessionWithCallBackInfo(PyMNNInterpreter * self,PyObject * args)852 static PyObject* PyMNNInterpreter_runSessionWithCallBackInfo(PyMNNInterpreter *self, PyObject *args) {
853 PyMNNSession* session = NULL;
854 PyObject *beginCallback = NULL;
855 PyObject *endCallback = NULL;
856 if (!args) {
857 PyErr_SetString(PyExc_Exception,
858 "PyMNNInterpreter_runSessionWithCallBackInfo: No argument passed, expect 1 or 3");
859 return NULL;
860 }
861
862 if (!PyArg_ParseTuple(args, "O|OO", &session, &beginCallback, &endCallback)) {
863 return NULL;
864 }
865
866 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
867 PyErr_SetString(PyExc_Exception,
868 "PyMNNInterpreter_runSessionWithCallBackInfo: First argument is not a AliNN.Session instance");
869 return NULL;
870 }
871
872 TensorCallBackWithInfo begin = [beginCallback](const std::vector<Tensor*>& tensors, const OperatorInfo* info){
873
874 if (!beginCallback || !PyCallable_Check(beginCallback)) {
875
876 return true;
877 }
878
879 PyObject *ftensor = importName("MNN", "Tensor");
880 PyObject *finfo = importName("MNN", "OpInfo");
881 if (!ftensor || !PyCallable_Check(ftensor)) {
882 PyErr_SetString(PyExc_Exception,
883 "PyMNNInterpreter_runSessionWithCallBackINfo: MNN.Tensor not found");
884 return true;
885 }
886 if (!finfo || !PyCallable_Check(finfo)) {
887 PyErr_SetString(PyExc_Exception,
888 "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.OpInfo not found");
889 return true;
890 }
891
892 PyObject *args = PyTuple_New(2);
893 size_t size_tensors = tensors.size();
894 PyObject *weTensorData = PyTuple_New(size_tensors);
895 for (size_t i = 0; i < size_tensors; i++) {
896 // create a new tensor
897 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(ftensor, PyTuple_New(0), NULL);
898 if (!tensor) {
899 PyErr_SetString(PyExc_Exception,
900 "PyMNNInterpreter_runSessionWithCallBackInfo: create Tensor failed");
901 return true;
902 }
903 tensor->tensor = tensors[i];
904 PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
905 }
906 //printf("begincallback name=%s\n",name.c_str());
907 PyMNNOpInfo *pyinfo = (PyMNNOpInfo *)PyObject_Call(finfo,PyTuple_New(0), NULL);
908 if(!pyinfo){
909 PyErr_SetString(PyExc_Exception,
910 "PyMNNInterpreter_runSessionWithCallBackInfo: create OpInfo failed");
911 return true;
912 }
913 pyinfo->opInfo = info;
914 PyTuple_SetItem(args, 0, weTensorData);
915 PyTuple_SetItem(args, 1, (PyObject *)pyinfo);
916 bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(beginCallback, args, NULL)));
917 Py_XDECREF(args);//del all the C++ created python api parameters
918 return ret;
919 };
920 TensorCallBackWithInfo end = [endCallback](const std::vector<Tensor*>& tensors, const OperatorInfo* info){
921 if (!endCallback || !PyCallable_Check(endCallback)) {
922 return true;
923 }
924 PyObject *ftensor = importName("MNN", "Tensor");
925 PyObject *finfo = importName("MNN", "OpInfo");
926 if (!ftensor || !PyCallable_Check(ftensor)) {
927 PyErr_SetString(PyExc_Exception,
928 "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.Tensor not found");
929 return true;
930 }
931 if (!finfo || !PyCallable_Check(finfo)) {
932 PyErr_SetString(PyExc_Exception,
933 "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.OpInfo not found");
934 return true;
935 }
936 PyObject *args = PyTuple_New(2);
937 size_t size_tensors = tensors.size();
938 PyObject *weTensorData = PyTuple_New(size_tensors);
939 for (size_t i = 0; i < size_tensors; i++) {
940 // create a new tensor
941 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(ftensor, PyTuple_New(0), NULL);
942 if (!tensor) {
943 PyErr_SetString(PyExc_Exception,
944 "PyMNNInterpreter_runSessionWithCallBackInfo: create Tensor failed");
945 return true;
946 }
947 tensor->tensor = tensors[i];
948 PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
949 }
950 PyMNNOpInfo *pyinfo = (PyMNNOpInfo *)PyObject_Call(finfo,PyTuple_New(0), NULL);
951 if(!pyinfo){
952 PyErr_SetString(PyExc_Exception,
953 "PyMNNInterpreter_runSessionWithCallBackInfo: create OpInfo failed");
954 return true;
955 }
956 pyinfo->opInfo = info;
957 PyTuple_SetItem(args, 0, weTensorData);
958 PyTuple_SetItem(args, 1, (PyObject *)pyinfo);
959 bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(endCallback, args, NULL)));
960 Py_XDECREF(args);//del all the C++ created python api parameters
961 return ret;
962 };
963
964 ErrorCode r = NO_ERROR;
965 //Py_BEGIN_ALLOW_THREADS
966 r = self->interpreter->runSessionWithCallBackInfo(session->session, begin, end);
967 //Py_END_ALLOW_THREADS
968 return PyLong_FromLong(r);
969 }
970
971
PyMNNInterpreter_getSessionOutput(PyMNNInterpreter * self,PyObject * args)972 static PyObject* PyMNNInterpreter_getSessionOutput(PyMNNInterpreter *self, PyObject *args) {
973 PyMNNSession* session = NULL;
974 char* name = NULL;
975 if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
976 return NULL;
977 }
978
979 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
980 PyErr_SetString(PyExc_Exception,
981 "PyMNNInterpreter_getSessionOutput: First argument is not a MNN.Session instance");
982 return NULL;
983 }
984
985 Tensor *t = self->interpreter->getSessionOutput(session->session, name);
986 if (!t) {
987 PyErr_SetString(PyExc_Exception,
988 "PyMNNInterpreter_getSessionOutput: Get output failed");
989 return NULL;
990 }
991
992 PyObject *f = importName("MNN", "Tensor");
993 if (!f || !PyCallable_Check(f)) {
994 PyErr_SetString(PyExc_Exception,
995 "PyMNNInterpreter_getSessionOutput: MNN.Tensor not found");
996 return NULL;
997 }
998
999 // create a new tensor
1000 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
1001 if (!tensor) {
1002 PyErr_SetString(PyExc_Exception,
1003 "PyMNNInterpreter_createSession: MNN.Session instance create failed");
1004 return NULL;
1005 }
1006
1007 tensor->tensor = t;
1008 return (PyObject *)tensor;
1009 }
1010
PyMNNInterpreter_getSessionInput(PyMNNInterpreter * self,PyObject * args)1011 static PyObject* PyMNNInterpreter_getSessionInput(PyMNNInterpreter *self, PyObject *args) {
1012 PyMNNSession* session = NULL;
1013 char* name = NULL;
1014 if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
1015 return NULL;
1016 }
1017
1018 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
1019 PyErr_SetString(PyExc_Exception,
1020 "PyMNNInterpreter_getSessionInput: First argument is not a MNN.Session instance");
1021 return NULL;
1022 }
1023
1024 Tensor *t = self->interpreter->getSessionInput(session->session, name);
1025 if (!t) {
1026 PyErr_SetString(PyExc_Exception,
1027 "PyMNNInterpreter_getSessionInput: Get input failed");
1028 return NULL;
1029 }
1030
1031 PyObject *f = importName("MNN", "Tensor");
1032 if (!f || !PyCallable_Check(f)) {
1033 PyErr_SetString(PyExc_Exception,
1034 "PyMNNInterpreter_getSessionInput: MNN.Tensor not found");
1035 return NULL;
1036 }
1037
1038 // create a new tensor
1039 PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
1040 if (!tensor) {
1041 PyErr_SetString(PyExc_Exception,
1042 "PyMNNInterpreter_createSession: MNN.Session instance create failed");
1043 return NULL;
1044 }
1045
1046 tensor->tensor = t;
1047 return (PyObject *)tensor;
1048 }
1049
PyMNNInterpreter_getSessionOutputAll(PyMNNInterpreter * self,PyObject * args)1050 static PyObject* PyMNNInterpreter_getSessionOutputAll(PyMNNInterpreter *self, PyObject *args) {
1051 PyMNNSession* session = NULL;
1052 if (!PyArg_ParseTuple(args, "O", &session)) {
1053 return NULL;
1054 }
1055 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
1056 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: First argument is not a MNN.Session instance");
1057 return NULL;
1058 }
1059 PyObject *f = importName("MNN", "Tensor");
1060 if (!f || !PyCallable_Check(f)) {
1061 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: MNN.Tensor not found");
1062 return NULL;
1063 }
1064 auto map = self->interpreter->getSessionOutputAll(session->session);
1065 PyObject* output = PyDict_New();
1066 for (auto it=map.begin(); it!=map.end(); ++it) {
1067 PyObject *tensor = PyObject_Call(f, PyTuple_New(0), NULL);
1068 if (!tensor) {
1069 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: MNN.Tensor instance create failed");
1070 return NULL;
1071 }
1072 ((PyMNNTensor*)tensor)->tensor = it->second;
1073 PyDict_SetItem(output, char2Object(it->first.c_str()), tensor);
1074 }
1075 return output;
1076 }
1077
PyMNNInterpreter_getSessionInputAll(PyMNNInterpreter * self,PyObject * args)1078 static PyObject* PyMNNInterpreter_getSessionInputAll(PyMNNInterpreter *self, PyObject *args) {
1079 PyMNNSession* session = NULL;
1080 if (!PyArg_ParseTuple(args, "O", &session)) {
1081 return NULL;
1082 }
1083 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
1084 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: First argument is not a MNN.Session instance");
1085 return NULL;
1086 }
1087 PyObject *f = importName("MNN", "Tensor");
1088 if (!f || !PyCallable_Check(f)) {
1089 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: MNN.Tensor not found");
1090 return NULL;
1091 }
1092 auto map = self->interpreter->getSessionInputAll(session->session);
1093 PyObject* output = PyDict_New();
1094 for (auto it=map.begin(); it!=map.end(); ++it) {
1095 PyObject *tensor = PyObject_Call(f, PyTuple_New(0), NULL);
1096 if (!tensor) {
1097 PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: MNN.Tensor instance create failed");
1098 return NULL;
1099 }
1100 ((PyMNNTensor*)tensor)->tensor = it->second;
1101 PyDict_SetItem(output, char2Object(it->first.c_str()), tensor);
1102 }
1103 return output;
1104 }
1105
PyMNNInterpreter_new(struct _typeobject * type,PyObject * args,PyObject * kwds)1106 PyObject* PyMNNInterpreter_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1107 PyMNNInterpreter* self = (PyMNNInterpreter *)type->tp_alloc(type, 0);
1108 return (PyObject*)self;
1109 }
1110
PyMNNInterpreter_init(PyMNNInterpreter * self,PyObject * args,PyObject * kwds)1111 static int PyMNNInterpreter_init(PyMNNInterpreter *self, PyObject *args, PyObject *kwds) {
1112 char *path = NULL;
1113 if (!PyArg_ParseTuple(args, "s", &path)) {
1114 PyErr_SetString(PyExc_Exception,
1115 "PyMNNInterpreter_new: PyArg_ParseTuple failed");
1116 return -1;
1117 }
1118 auto converted_path = convertBytesEncodeIfNeed(path);
1119 self->modelPath = new std::string(converted_path.data());
1120 if (!self->modelPath) {
1121 PyErr_SetString(PyExc_Exception,
1122 "PyMNNInterpreter_new: create modelPath string failed");
1123 return -1;
1124 }
1125
1126 if ((*interpreterMap())[*self->modelPath]) {
1127 self->interpreter = (*interpreterMap())[*self->modelPath];
1128 } else {
1129 self->interpreter = Interpreter::createFromFile(path);
1130 }
1131 if (!self->interpreter) {
1132 PyErr_SetString(PyExc_Exception,
1133 "PyMNNInterpreter_new: NetInstance::createFromFile failed");
1134 return -1;
1135 }
1136
1137 return 0;
1138 }
1139
PyMNNInterpreter_cache(PyMNNInterpreter * self,PyObject * args)1140 static PyObject* PyMNNInterpreter_cache(PyMNNInterpreter *self, PyObject *args) {
1141 if (self->modelPath && !(*interpreterMap())[*self->modelPath]) {
1142 (*interpreterMap())[*self->modelPath] = self->interpreter;
1143 }
1144 Py_RETURN_NONE;
1145 }
1146
PyMNNInterpreter_removeCache(PyMNNInterpreter * self,PyObject * args)1147 static PyObject* PyMNNInterpreter_removeCache(PyMNNInterpreter *self, PyObject *args) {
1148 if (!self->modelPath) {
1149 Py_RETURN_NONE;
1150 }
1151 Interpreter* net = (*interpreterMap())[*self->modelPath];
1152 if (net) {
1153 interpreterMap()->erase(*self->modelPath);
1154 //delete net;
1155 }
1156 Py_RETURN_NONE;
1157 }
1158
PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter * self,PyObject * args)1159 static PyObject* PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter *self, PyObject *args) {
1160 PyMNNSession* session = NULL;
1161 char* name = NULL;
1162 if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
1163 return NULL;
1164 }
1165
1166 if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
1167 PyErr_SetString(PyExc_Exception,
1168 "PyMNNInterpreter_updateSessionToModel: First argument is not a MNN.Session instance");
1169 return NULL;
1170 }
1171
1172 self->interpreter->updateSessionToModel(session->session);
1173 if(name){
1174 auto modelBuffer = self->interpreter->getModelBuffer();
1175 ofstream output(name);
1176 output.write((const char*)modelBuffer.first, modelBuffer.second);
1177 }
1178 Py_RETURN_NONE;
1179 }
1180
PyMNNInterpreter_dealloc(PyMNNInterpreter * self)1181 static void PyMNNInterpreter_dealloc(PyMNNInterpreter *self) {
1182 if (!self->modelPath) {
1183 return;
1184 }
1185 Interpreter* net = (*interpreterMap())[*self->modelPath];
1186 // 如果对象不存在缓存中, 则释放实例
1187 if (!net && self->interpreter) {
1188 delete self->interpreter;
1189 self->interpreter = NULL;
1190 }
1191 delete self->modelPath;
1192 Py_TYPE(self)->tp_free((PyObject*)self);
1193 }
1194
1195 /// MNN Session implementation
PyMNNSession_new(struct _typeobject * type,PyObject * args,PyObject * kwds)1196 static PyObject* PyMNNSession_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1197 PyMNNSession* self = (PyMNNSession *)type->tp_alloc(type, 0);
1198 return (PyObject*)self;
1199 }
1200
PyMNNSession_dealloc(PyMNNSession * self)1201 static void PyMNNSession_dealloc(PyMNNSession *self) {
1202 self->session = NULL;
1203 Py_TYPE(self)->tp_free((PyObject*)self);
1204 }
1205
1206 // cache session
PyMNNSession_cache(PyMNNSession * self,PyObject * args)1207 static PyObject* PyMNNSession_cache(PyMNNSession *self, PyObject *args) {
1208 if (!self->modelPath) {
1209 Py_RETURN_NONE;
1210 }
1211 if (!(*sessionCacheMap())[*self->modelPath]) {
1212 (*sessionCacheMap())[*self->modelPath] = self->session;
1213 }
1214 Py_RETURN_NONE;
1215 }
1216
PyMNNSession_removeCache(PyMNNSession * self,PyObject * args)1217 static PyObject* PyMNNSession_removeCache(PyMNNSession *self, PyObject *args) {
1218 if (!self->modelPath) {
1219 Py_RETURN_NONE;
1220 }
1221 Session* s = (*sessionCacheMap())[*self->modelPath];
1222 if (s) {
1223 sessionCacheMap()->erase(*self->modelPath);
1224 }
1225 Py_RETURN_NONE;
1226 }
1227
1228 /// MNN Tensor implementation
PyMNNTensor_new(struct _typeobject * type,PyObject * args,PyObject * kwds)1229 static PyObject* PyMNNTensor_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1230 PyMNNTensor* self = (PyMNNTensor *)type->tp_alloc(type, 0);
1231 return (PyObject*)self;
1232 }
1233
PyMNNTensor_dealloc(PyMNNTensor * self)1234 static void PyMNNTensor_dealloc(PyMNNTensor *self) {
1235 if (self->owner) {
1236 if (self->tensor->host<void *>()) {
1237 free(self->tensor->host<void *>());
1238 }
1239 delete self->tensor;
1240 }
1241 Py_TYPE(self)->tp_free((PyObject*)self);
1242 }
1243
PyMNNTensor_init(PyMNNTensor * self,PyObject * args,PyObject * kwds)1244 static int PyMNNTensor_init(PyMNNTensor *self, PyObject *args, PyObject *kwds) {
1245 if (!PyTuple_Size(args)) {
1246 return 0;
1247 }
1248
1249 PyObject *shape, *dataType, *data;
1250 long dimensionType;
1251 if (!PyArg_ParseTuple(args, "OOOl", &shape, &dataType, &data, &dimensionType)) {
1252 return -1;
1253 }
1254
1255 size_t shapeSize = PyTuple_Size(shape);
1256
1257 std::vector<int> vShape;
1258 size_t dataSize = 1;
1259 for (size_t i = 0; i<shapeSize; i++) {
1260 int shapeItem = (int)PyLong_AsLong(PyTuple_GetItem(shape, i));
1261 vShape.push_back(shapeItem);
1262 dataSize *= shapeItem;
1263 }
1264 bool isNumpy = false;
1265 void *pData = NULL;
1266 if(PyTuple_Check(data)) {
1267 if(dataSize != PyTuple_Size(data)) {
1268 PyErr_SetString(PyExc_Exception,
1269 "PyMNNTensor_init: Tensor Dim not match");
1270 return -1;
1271 }
1272 }
1273 #ifdef PYMNN_NUMPY_USABLE
1274 else {
1275 if(PyArray_Check(data)) {
1276 isNumpy = true;
1277 if(dataSize != PyArray_Size(data)) {
1278 PyErr_SetString(PyExc_Exception, "PyMNNTensor_init: numpy array size does not match shape requirement");
1279 return -1;
1280 }
1281 }
1282 else {
1283 PyErr_SetString(PyExc_Exception, "PyMNNTensor_init: data is not tuple/numpy");
1284 return -1;
1285 }
1286 }
1287 #endif
1288 halide_type_t htt;
1289 struct MNN_TLSData *tlsData = getTLSData();
1290 if (dataType == tlsData->PyMNNHalideTypeInt) {
1291 htt = halide_type_of<int32_t>();
1292 }
1293 else if(dataType == tlsData->PyMNNHalideTypeFloat) {
1294 htt = halide_type_of<float>();
1295 }
1296 else if(dataType == tlsData->PyMNNHalideTypeDouble) {
1297 htt = halide_type_of<float>();
1298 }
1299 else if(dataType == tlsData->PyMNNHalideTypeUint8) {
1300 htt = halide_type_of<uint8_t>();
1301 }
1302 else if(dataType == tlsData->PyMNNHalideTypeInt64) {
1303 htt = halide_type_of<int64_t>();
1304 }
1305 else if(dataType == tlsData->PyMNNHalideTypeString) {
1306 htt = *httString();
1307 }
1308 else {
1309 PyErr_SetString(PyExc_Exception,"PyMNNTensor_create: unsupported data type");
1310 return -1;
1311 }
1312 DType dtype = htype2dtype(htt);
1313 if(!isNumpy) {
1314 int itemsize = getitemsize(dtype);
1315 pData = malloc(dataSize * itemsize);
1316 if(NULL == pData) {
1317 PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: malloc failed");
1318 return -1;
1319 }
1320 if (dataType == tlsData->PyMNNHalideTypeInt) {
1321 for (size_t i = 0; i < dataSize; i++) {
1322 ((int *)pData)[i] = (int)PyLong_AsLong(PyTuple_GetItem(data, i));
1323 }
1324 } else if (dataType == tlsData->PyMNNHalideTypeFloat) {
1325 for (size_t i = 0; i < dataSize; i++) {
1326 ((float *)pData)[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(data, i));
1327 }
1328 } else if (dataType == tlsData->PyMNNHalideTypeDouble) {
1329 for (size_t i = 0; i < dataSize; i++) {
1330 ((double *)pData)[i] = PyFloat_AsDouble(PyTuple_GetItem(data, i));
1331 }
1332 } else if (dataType == tlsData->PyMNNHalideTypeUint8) {
1333 for (size_t i = 0; i < dataSize; i++) {
1334 ((uint8_t *)pData)[i] = (uint8_t)PyLong_AsLong(PyTuple_GetItem(data, i));
1335 }
1336 } else if (dataType == tlsData->PyMNNHalideTypeInt64) {
1337 for (size_t i = 0; i < dataSize; i++) {
1338 ((int64_t *)pData)[i] = (int64_t)PyLong_AsLong(PyTuple_GetItem(data, i));
1339 }
1340 } else if (dataType == tlsData->PyMNNHalideTypeString) {
1341 for (size_t i = 0; i < dataSize; i++) {
1342 char *item = (char *)object2String(PyTuple_GetItem(data, i)).c_str();
1343 ((char **)pData)[i] = item;
1344 }
1345 }
1346 }
1347 #ifdef PYMNN_NUMPY_USABLE
1348 else {
1349 int npy_type = PyArray_TYPE(data);
1350 int itemsize = getitemsize(dtype, npy_type);
1351 pData = malloc(dataSize * itemsize);
1352 if(NULL == pData) {
1353 PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: malloc failed");
1354 return -1;
1355 }
1356 PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)data);
1357 auto tmpBuffer = PyArray_DATA(data_cont);
1358 if(NULL == tmpBuffer) {
1359 PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: ndarry failed to get buffer data");
1360 return -1;
1361 }
1362 memcpy(pData, tmpBuffer, dataSize * itemsize);
1363 Py_XDECREF(data_cont);
1364 }
1365 #endif
1366 Tensor *tensor = Tensor::create(vShape
1367 , htt
1368 , pData
1369 , (Tensor::DimensionType)dimensionType
1370 );
1371 if (!tensor) {
1372 PyErr_SetString(PyExc_Exception,
1373 "PyMNNTensor_create: Tensor create failed");
1374 return -1;
1375 }
1376 self->tensor = tensor;
1377 self->owner = 1;
1378 return 0;
1379 }
1380 #ifdef PYMNN_NUMPY_USABLE
PyMNNTensor_fromNumpy(PyMNNTensor * self,PyObject * args)1381 static PyObject* PyMNNTensor_fromNumpy(PyMNNTensor *self, PyObject *args) {
1382 PyObject *data;
1383 if (!PyArg_ParseTuple(args, "O", &data)) {
1384 return NULL;
1385 }
1386 if (!PyArray_Check(data)) {
1387 PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: input is not a numpy");
1388 }
1389 if (self->owner){
1390 if(self->tensor->size() != PyArray_Size(data)) {
1391 PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: tensor/numpy size does not match each other");
1392 return NULL;
1393 }
1394 DType dtype = htype2dtype(self->tensor->getType());
1395 int npy_type = PyArray_TYPE(data);
1396 int itemsize = getitemsize(dtype, npy_type);
1397 PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)data);
1398 auto tmpBuffer = PyArray_DATA(data_cont);
1399 if(NULL == tmpBuffer) {
1400 PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: ndarry failed to get buffer data");
1401 return NULL;
1402 }
1403 memcpy(self->tensor->host<void *>(), tmpBuffer, self->tensor->size() * itemsize);
1404 Py_XDECREF(data_cont);
1405 }
1406 Py_RETURN_NONE;
1407 }
1408 #endif
PyMNNTensor_printTensorData(PyMNNTensor * self,PyObject * args)1409 static PyObject* PyMNNTensor_printTensorData(PyMNNTensor *self, PyObject *args) {
1410 if (self->tensor) {
1411 // Do nothing
1412 }
1413 Py_RETURN_NONE;
1414 }
1415
PyMNNTensor_getHost(PyMNNTensor * self,PyObject * args)1416 static PyObject* PyMNNTensor_getHost(PyMNNTensor *self, PyObject *args) {
1417 if (self->tensor) {
1418 return PyCapsule_New(self->tensor->host<void *>(), NULL, NULL);
1419 }
1420 Py_RETURN_NONE;
1421 }
1422
PyMNNTensor_getDataType(PyMNNTensor * self,PyObject * args)1423 static PyObject* PyMNNTensor_getDataType(PyMNNTensor *self, PyObject *args) {
1424 if (self->tensor) {
1425 halide_type_t t = self->tensor->getType();
1426 PyObject *type;
1427 struct MNN_TLSData *tlsData =getTLSData();
1428 if (t == *httInt()) {
1429 type = tlsData->PyMNNHalideTypeInt;
1430 } else if (t == *httUint8()) {
1431 type = tlsData->PyMNNHalideTypeUint8;
1432 } else if (t == *httInt64()) {
1433 type = tlsData->PyMNNHalideTypeInt64;
1434 } else if (t == *httFloat()) {
1435 type = tlsData->PyMNNHalideTypeFloat;
1436 } else if (t == *httDouble()) {
1437 type = tlsData->PyMNNHalideTypeDouble;
1438 } else if (t == *httString()) {
1439 type = tlsData->PyMNNHalideTypeString;
1440 } else {
1441 Py_RETURN_NONE;
1442 }
1443 Py_XINCREF(type);
1444 return type;
1445 }
1446 Py_RETURN_NONE;
1447 }
1448
PyMNNTensor_getData(PyMNNTensor * self,PyObject * args)1449 static PyObject* PyMNNTensor_getData(PyMNNTensor *self, PyObject *args) {
1450 if (self->tensor) {
1451 halide_type_t t = self->tensor->getType();
1452 size_t size = self->tensor->elementSize();
1453 PyObject *outputData = PyTuple_New(size);
1454 if (t == *httInt()) {
1455 auto data = self->tensor->host<int32_t>();
1456 for (size_t i = 0; i < size; i++) {
1457 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1458 }
1459 } else if (t == *httUint8()) {
1460 auto data = self->tensor->host<uint8_t>();
1461 for (size_t i = 0; i < size; i++) {
1462 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1463 }
1464 } else if (t == *httInt64()) {
1465 auto data = self->tensor->host<int64_t>();
1466 for (size_t i = 0; i < size; i++) {
1467 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1468 }
1469 } else if (t == *httFloat()) {
1470 auto data = self->tensor->host<float>();
1471 for (size_t i = 0; i < size; i++) {
1472 PyTuple_SetItem(outputData, i, PyFloat_FromDouble(data[i]));
1473 }
1474 } else if (t == *httDouble()) {
1475 auto data = self->tensor->host<double>();
1476 for (size_t i = 0; i < size; i++) {
1477 PyTuple_SetItem(outputData, i, PyFloat_FromDouble(data[i]));
1478 }
1479 } else if (t == *httString()) {
1480 auto data = self->tensor->host<char *>();
1481 for (size_t i = 0; i < size; i++) {
1482 char *dataItem = data[i];
1483 PyTuple_SetItem(outputData, i, char2Object(dataItem?dataItem:""));
1484 }
1485 } else {
1486 Py_RETURN_NONE;
1487 }
1488 return outputData;
1489 }
1490 Py_RETURN_NONE;
1491 }
1492
1493 #ifdef PYMNN_NUMPY_USABLE
PyMNNTensor_getNumpyData(PyMNNTensor * self,PyObject * args)1494 static PyObject* PyMNNTensor_getNumpyData(PyMNNTensor *self, PyObject *args) {
1495 if (self->tensor) {
1496 halide_type_t t = self->tensor->getType();
1497 std::vector<npy_intp> npy_dims;
1498 for(const auto dim : self->tensor->shape()) {
1499 npy_dims.push_back(dim);
1500 }
1501 PyObject* obj;
1502 if (t == *httInt()) {
1503 auto data = self->tensor->host<int32_t>();
1504 obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT32, data);
1505 } else if (t == *httUint8()) {
1506 auto data = self->tensor->host<uint8_t>();
1507 obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_UINT8, data);
1508 } else if (t == *httInt64()) {
1509 auto data = self->tensor->host<int64_t>();
1510 obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT64, data);
1511 } else if (t == *httFloat()) {
1512 auto data = self->tensor->host<float>();
1513 obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_FLOAT, data);
1514 } else if (t == *httDouble()) {
1515 auto data = self->tensor->host<double>();
1516 obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_DOUBLE, data);
1517 } else {
1518 PyErr_SetString(PyExc_Exception, "tensor can not be read as numpy");
1519 Py_RETURN_NONE;
1520 }
1521 return obj;
1522 }
1523 Py_RETURN_NONE;
1524 }
1525 #endif
1526
PyMNNTensor_getDimensionType(PyMNNTensor * self,PyObject * args)1527 static PyObject* PyMNNTensor_getDimensionType(PyMNNTensor *self, PyObject *args) {
1528 if (self->tensor) {
1529 return PyLong_FromLong(self->tensor->getDimensionType());
1530 }
1531 Py_RETURN_NONE;
1532 }
1533
PyMNNTensor_copyFrom(PyMNNTensor * self,PyObject * args)1534 static PyObject* PyMNNTensor_copyFrom(PyMNNTensor *self, PyObject *args) {
1535 PyMNNTensor *fromTensor = NULL;
1536 if (!PyArg_ParseTuple(args, "O", &fromTensor)) {
1537 return NULL;
1538 }
1539
1540 if (!fromTensor->tensor || !self->tensor) {
1541 PyErr_SetString(PyExc_Exception,
1542 "PyMNNTensor_copyFrom: source or destination tensor is null");
1543 }
1544
1545 bool r = self->tensor->copyFromHostTensor(fromTensor->tensor);
1546 if (!r) {
1547 Py_RETURN_FALSE;
1548 }
1549 Py_RETURN_TRUE;
1550 }
1551
PyMNNTensor_copyToHostTensor(PyMNNTensor * self,PyObject * args)1552 static PyObject* PyMNNTensor_copyToHostTensor(PyMNNTensor *self, PyObject *args) {
1553 PyMNNTensor *toTensor = NULL;
1554 if (!PyArg_ParseTuple(args, "O", &toTensor)) {
1555 return NULL;
1556 }
1557
1558 if (!toTensor->tensor || !self->tensor) {
1559 PyErr_SetString(PyExc_Exception,
1560 "PyMNNTensor_copyTo: source or destination tensor is null");
1561 }
1562
1563 bool r = self->tensor->copyToHostTensor(toTensor->tensor);
1564 if (!r) {
1565 Py_RETURN_FALSE;
1566 }
1567 Py_RETURN_TRUE;
1568 }
1569
PyMNNTensor_getShape(PyMNNTensor * self,PyObject * args)1570 static PyObject* PyMNNTensor_getShape(PyMNNTensor *self, PyObject *args) {
1571 if (self->tensor) {
1572 PyObject *shape = PyTuple_New(self->tensor->shape().size());
1573 for (size_t i = 0; i < self->tensor->shape().size(); i++) {
1574 PyTuple_SetItem(shape, i, PyLong_FromLong(self->tensor->shape()[i]));
1575 }
1576 return shape;
1577 }
1578 Py_RETURN_NONE;
1579 }
1580
1581 /// MNN ImageProcess implementation
PyMNNCVImageProcess_new(struct _typeobject * type,PyObject * args,PyObject * kwds)1582 static PyObject* PyMNNCVImageProcess_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1583 PyMNNCVImageProcess* self = (PyMNNCVImageProcess *)type->tp_alloc(type, 0);
1584 return (PyObject*)self;
1585 }
1586
PyMNNCVImageProcess_dealloc(PyMNNCVImageProcess * self)1587 static void PyMNNCVImageProcess_dealloc(PyMNNCVImageProcess *self) {
1588 delete self->imageProcess;
1589 Py_TYPE(self)->tp_free((PyObject*)self);
1590 }
1591
1592
PyMNNCVImageProcess_init(PyMNNCVImageProcess * self,PyObject * args,PyObject * kwds)1593 static int PyMNNCVImageProcess_init(PyMNNCVImageProcess *self, PyObject *args, PyObject *kwds) {
1594 PyObject *config = NULL, *destinationTensor = NULL;
1595 if (!PyArg_ParseTuple(args, "O|O", &config, &destinationTensor)) {
1596 return -1;
1597 }
1598
1599 Tensor *t = NULL;
1600 if (destinationTensor
1601 && PyObject_TypeCheck(destinationTensor, PyType_FindTLSType(&PyMNNTensorType))) {
1602 t = ((PyMNNTensor *)destinationTensor)->tensor;
1603 }
1604
1605 CV::ImageProcess::Config c;
1606 if (PyDict_Check(config)) {
1607 PyObject *filterType = PyDict_GetItemString(config, "filterType");
1608 if (filterType && PyLong_Check(filterType)) {
1609 c.filterType = (CV::Filter)PyLong_AsLong(filterType);
1610 }
1611
1612 PyObject *sourceFormat = PyDict_GetItemString(config, "sourceFormat");
1613 if (sourceFormat && PyLong_Check(sourceFormat)) {
1614 c.sourceFormat = (CV::ImageFormat)PyLong_AsLong(sourceFormat);
1615 }
1616
1617 PyObject *destFormat = PyDict_GetItemString(config, "destFormat");
1618 if (destFormat && PyLong_Check(destFormat)) {
1619 c.destFormat = (CV::ImageFormat)PyLong_AsLong(destFormat);
1620 }
1621
1622 PyObject *wrap = PyDict_GetItemString(config, "wrap");
1623 if (wrap && PyLong_Check(wrap)) {
1624 c.wrap = (CV::Wrap)PyLong_AsLong(wrap);
1625 }
1626
1627 PyObject *mean = PyDict_GetItemString(config, "mean");
1628 if (mean) {
1629 if (!PyTuple_Check(mean) || PyTuple_Size(mean) != 4) {
1630 PyErr_SetString(PyExc_Exception,
1631 "PyMNNCVImageProcess_init: mean must be a tuple with 4 elements");
1632 return -1;
1633 }
1634 for (int i = 0; i < 4; i++) {
1635 c.mean[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(mean, i));
1636 }
1637 }
1638
1639 PyObject *normal = PyDict_GetItemString(config, "normal");
1640 if (normal) {
1641 if (!PyTuple_Check(normal) || PyTuple_Size(normal) != 4) {
1642 PyErr_SetString(PyExc_Exception,
1643 "PyMNNCVImageProcess_init: normal must be a tuple with 4 elements");
1644 return -1;
1645 }
1646 for (int i = 0; i < 4; i++) {
1647 c.normal[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(normal, i));
1648 }
1649 }
1650 }
1651
1652 CV::ImageProcess *imageProcess = CV::ImageProcess::create(c, t);
1653 if (!imageProcess) {
1654 PyErr_SetString(PyExc_Exception,
1655 "PyMNNCVImageProcess_init: ImageProcess create failed");
1656 return -1;
1657 }
1658
1659 self->imageProcess = imageProcess;
1660 return 0;
1661 }
1662
PyMNNCVImageProcess_setMatrix(PyMNNCVImageProcess * self,PyObject * args)1663 static PyObject* PyMNNCVImageProcess_setMatrix(PyMNNCVImageProcess *self, PyObject *args) {
1664 PyObject *matrix;
1665 if (!PyArg_ParseTuple(args, "O", &matrix)) {
1666 return NULL;
1667 }
1668
1669 if (!PyObject_TypeCheck(matrix, PyType_FindTLSType(&PyMNNCVMatrixType))) {
1670 PyErr_SetString(PyExc_Exception,
1671 "PyMNNCVImageProcess_setMatrix: argument is not a matrix");
1672 return NULL;
1673 }
1674
1675 self->imageProcess->setMatrix(*((PyMNNCVMatrix *)matrix)->matrix);
1676 Py_RETURN_NONE;
1677 }
1678
PyMNNCVImageProcess_convert(PyMNNCVImageProcess * self,PyObject * args)1679 static PyObject* PyMNNCVImageProcess_convert(PyMNNCVImageProcess *self, PyObject *args) {
1680 PyObject *source, *dest;
1681 int iw, ih, stride;
1682 if (!PyArg_ParseTuple(args, "OiiiO", &source, &iw, &ih, &stride, &dest)) {
1683 return NULL;
1684 }
1685
1686 if (!PyObject_TypeCheck(dest, PyType_FindTLSType(&PyMNNTensorType))) {
1687 PyErr_SetString(PyExc_Exception,
1688 "PyMNNCVImageProcess_convert: argument 4 is not a MNNTensor");
1689 return NULL;
1690 }
1691
1692 if (PyCapsule_CheckExact(source)) {
1693 // Capsule Pointer
1694 ErrorCode ret = self->imageProcess->convert((const uint8_t *)PyCapsule_GetPointer(source, NULL),
1695 iw, ih, stride,
1696 ((PyMNNTensor *)dest)->tensor);
1697 return PyLong_FromLong(ret);
1698 } else if (PyTuple_Check(source)) {
1699 // Tuple Data
1700 size_t size = PyTuple_Size(source);
1701
1702 void *pData = malloc(size * sizeof(uint8_t));
1703 for (size_t i = 0; i < size; i++) {
1704 ((uint8_t *)pData)[i] = (uint8_t)PyLong_AsLong(PyTuple_GetItem(source, i));
1705 }
1706
1707 ErrorCode ret = self->imageProcess->convert((const uint8_t *)pData,
1708 iw, ih, stride,
1709 ((PyMNNTensor *)dest)->tensor);
1710
1711 free(pData);
1712
1713 return PyLong_FromLong(ret);
1714 }
1715 #ifdef PYMNN_NUMPY_USABLE
1716 else if(PyArray_Check(source)) {
1717 // Array Data
1718 int npy_type = PyArray_TYPE(source);
1719 if(npy_type != NPY_UINT8) {
1720 PyErr_SetString(PyExc_Exception,
1721 "PyMNNCVImageProcess_convert: only numpy.uint8 is supported for numpy");
1722 return NULL;
1723 }
1724 int64_t total_length = 1;
1725 for (size_t i = 0; i < ((PyMNNTensor *)dest)->tensor->shape().size(); i++) {
1726 total_length *= ((PyMNNTensor *)dest)->tensor->shape()[i];
1727 }
1728 if(PyArray_Size(source) < total_length) //as input may contain stride, so we can only do basic check
1729 {
1730 PyErr_SetString(PyExc_Exception,
1731 "PyMNNCVImageProcess_convert: data length does not match tensor size");
1732 return NULL;
1733 }
1734 PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)source);
1735 auto tmpBuffer = PyArray_DATA(data_cont);
1736 if(NULL == tmpBuffer) {
1737 PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: ndarry failed to get buffer data");
1738 return NULL;
1739 }
1740 ErrorCode ret = self->imageProcess->convert((const uint8_t *)tmpBuffer,
1741 iw, ih, stride,
1742 ((PyMNNTensor *)dest)->tensor);
1743 Py_XDECREF(data_cont);
1744 return PyLong_FromLong(ret);
1745 }
1746 #endif
1747
1748 PyErr_SetString(PyExc_Exception, "PyMNNCVImageProcess_convert: argument 0 is not a capsule or tuple or numpy");
1749
1750 return NULL;
1751 }
1752
1753
PyMNNCVImageProcess_createImageTensor(PyMNNCVImageProcess * self,PyObject * args)1754 static PyObject* PyMNNCVImageProcess_createImageTensor(PyMNNCVImageProcess *self, PyObject *args) {
1755
1756 PyObject *dataType;
1757 int width, height, bpp;
1758 PyObject *data;
1759
1760 if (!PyArg_ParseTuple(args, "OiiiO", &dataType, &width, &height, &bpp, &data)) {
1761 return NULL;
1762 }
1763
1764
1765 // if (nullptr != data && !PyCapsule_CheckExact(data)) {
1766 // PyErr_SetString(PyExc_Exception,
1767 // "PyMNNCVImageProcess_createImageTensor: argument 4 is not a capsule");
1768 // return NULL;
1769 // }
1770
1771 std::vector<int> vShape = {1, height, width, bpp};
1772
1773 halide_type_t htt;
1774 struct MNN_TLSData *tlsData = getTLSData();
1775 if (dataType == tlsData->PyMNNHalideTypeInt) {
1776 htt = halide_type_of<int32_t>();
1777 } else if (dataType == tlsData->PyMNNHalideTypeFloat) {
1778 htt = halide_type_of<float>();
1779 } else if (dataType == tlsData->PyMNNHalideTypeDouble) {
1780 htt = halide_type_of<double>();
1781 } else if (dataType == tlsData->PyMNNHalideTypeUint8) {
1782 htt = halide_type_of<uint8_t>();
1783 } else if (dataType == tlsData->PyMNNHalideTypeInt64) {
1784 htt = halide_type_of<int64_t>();
1785 } else if (dataType == tlsData->PyMNNHalideTypeString) {
1786 htt = *httString();
1787 }
1788
1789 Tensor *tensor = Tensor::create(vShape, htt);
1790 // Tensor *tensor = Tensor::create(vShape, htt, PyCapsule_GetPointer(data, NULL));TODO
1791 if (!tensor) {
1792 PyErr_SetString(PyExc_Exception,
1793 "PyMNNCVImageProcess_createImageTensor: Tensor create failed");
1794 return NULL;
1795 }
1796
1797 PyObject *f = importName("MNN", "Tensor");
1798 if (!f || !PyCallable_Check(f)) {
1799 PyErr_SetString(PyExc_Exception,
1800 "PyMNNCVImageProcess_createImageTensor: MNN.Tensor not found");
1801 return NULL;
1802 }
1803
1804 PyMNNTensor *t = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
1805 if (!t) {
1806 PyErr_SetString(PyExc_Exception,
1807 "PyMNNCVImageProcess_createImageTensor: create image tensor failed");
1808 return NULL;
1809 }
1810
1811 t->tensor = tensor;
1812 t->owner = 1;
1813 return (PyObject *)t;
1814 }
1815
1816 /// MNN CVMatrix implementation
PyMNNCVMatrix_new(struct _typeobject * type,PyObject * args,PyObject * kwds)1817 static PyObject* PyMNNCVMatrix_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1818 PyMNNCVMatrix* self;
1819 self = (PyMNNCVMatrix *)type->tp_alloc(type, 0);
1820 self->matrix = new CV::Matrix();
1821 return (PyObject*)self;
1822 }
1823
PyMNNCVMatrix_dealloc(PyMNNCVMatrix * self)1824 static void PyMNNCVMatrix_dealloc(PyMNNCVMatrix *self) {
1825 delete self->matrix;
1826 Py_TYPE(self)->tp_free((PyObject*)self);
1827 }
1828
1829 // type: 0 set; 1 pre; 2 post
_PyMNNCVMatrix_Rotate(PyMNNCVMatrix * self,PyObject * args,int type)1830 static PyObject* _PyMNNCVMatrix_Rotate(PyMNNCVMatrix *self, PyObject *args, int type) {
1831 float degrees, px = 0.0, py = 0.0;
1832 size_t argsCount = PyTuple_Size(args);
1833 if (argsCount == 1) {
1834 if (!PyArg_ParseTuple(args, "f", °rees)) {
1835 PyErr_SetString(PyExc_Exception,
1836 "PyMNNCVMatrix_Rotate: PyArg_ParseTuple failed");
1837 return NULL;
1838 }
1839 } else if (argsCount == 3) {
1840 if (!PyArg_ParseTuple(args, "fff", °rees, &px, &py)) {
1841 PyErr_SetString(PyExc_Exception,
1842 "PyMNNCVMatrix_Rotate: PyArg_ParseTuple failed");
1843 return NULL;
1844 }
1845 } else {
1846 PyErr_SetString(PyExc_Exception,
1847 "PyMNNCVMatrix_Rotate: argument count error (should be 1 or 3)");
1848 return NULL;
1849 }
1850
1851 if (argsCount == 1) {
1852 switch (type) {
1853 case 0:
1854 self->matrix->setRotate(degrees);
1855 break;
1856 case 1:
1857 self->matrix->preRotate(degrees);
1858 break;
1859 case 2:
1860 self->matrix->postRotate(degrees);
1861 break;
1862 default:
1863 break;
1864 }
1865
1866 } else if (argsCount == 3) {
1867 switch (type) {
1868 case 0:
1869 self->matrix->setRotate(degrees, px, py);
1870 break;
1871 case 1:
1872 self->matrix->preRotate(degrees, px, py);
1873 break;
1874 case 2:
1875 self->matrix->postRotate(degrees, px, py);
1876 break;
1877 default:
1878 break;
1879 }
1880 }
1881 Py_RETURN_NONE;
1882 }
1883 // set
PyMNNCVMatrix_setRotate(PyMNNCVMatrix * self,PyObject * args)1884 static PyObject* PyMNNCVMatrix_setRotate(PyMNNCVMatrix *self, PyObject *args) {
1885 return _PyMNNCVMatrix_Rotate(self, args, 0);
1886 }
1887 // pre
PyMNNCVMatrix_preRotate(PyMNNCVMatrix * self,PyObject * args)1888 static PyObject* PyMNNCVMatrix_preRotate(PyMNNCVMatrix *self, PyObject *args) {
1889 return _PyMNNCVMatrix_Rotate(self, args, 1);
1890 }
1891 // post
PyMNNCVMatrix_postRotate(PyMNNCVMatrix * self,PyObject * args)1892 static PyObject* PyMNNCVMatrix_postRotate(PyMNNCVMatrix *self, PyObject *args) {
1893 return _PyMNNCVMatrix_Rotate(self, args, 2);
1894 }
1895
_PyMNNCVMatrix_Scale(PyMNNCVMatrix * self,PyObject * args,int type)1896 static PyObject* _PyMNNCVMatrix_Scale(PyMNNCVMatrix *self, PyObject *args, int type) {
1897 float sx, sy, px = 0.0, py = 0.0;
1898 size_t argsCount = PyTuple_Size(args);
1899 if (argsCount == 2) {
1900 if (!PyArg_ParseTuple(args, "ff", &sx, &sy)) {
1901 PyErr_SetString(PyExc_Exception,
1902 "PyMNNCVMatrix_Scale: PyArg_ParseTuple failed");
1903 return NULL;
1904 }
1905 } else if (argsCount == 4) {
1906 if (!PyArg_ParseTuple(args, "ffff", &sx, &sy, &px, &py)) {
1907 PyErr_SetString(PyExc_Exception,
1908 "PyMNNCVMatrix_Scale: PyArg_ParseTuple failed");
1909 return NULL;
1910 }
1911 } else {
1912 PyErr_SetString(PyExc_Exception,
1913 "PyMNNCVMatrix_Scale: argument count error (should be 2 or 4)");
1914 return NULL;
1915 }
1916
1917 if (argsCount == 2) {
1918 switch (type) {
1919 case 0:
1920 self->matrix->setScale(sx, sy);
1921 break;
1922 case 1:
1923 self->matrix->preScale(sx, sy);
1924 break;
1925 case 2:
1926 self->matrix->postScale(sx, sy);
1927 break;
1928 default:
1929 break;
1930 }
1931 } else if (argsCount == 4) {
1932 switch (type) {
1933 case 0:
1934 self->matrix->setScale(sx, sy, px, py);
1935 break;
1936 case 1:
1937 self->matrix->preScale(sx, sy, px, py);
1938 break;
1939 case 2:
1940 self->matrix->postScale(sx, sy, px, py);
1941 break;
1942 default:
1943 break;
1944 }
1945 }
1946 Py_RETURN_NONE;
1947 }
PyMNNCVMatrix_setScale(PyMNNCVMatrix * self,PyObject * args)1948 static PyObject* PyMNNCVMatrix_setScale(PyMNNCVMatrix *self, PyObject *args) {
1949 return _PyMNNCVMatrix_Scale(self, args, 0);
1950 }
PyMNNCVMatrix_preScale(PyMNNCVMatrix * self,PyObject * args)1951 static PyObject* PyMNNCVMatrix_preScale(PyMNNCVMatrix *self, PyObject *args) {
1952 return _PyMNNCVMatrix_Scale(self, args, 1);
1953 }
PyMNNCVMatrix_postScale(PyMNNCVMatrix * self,PyObject * args)1954 static PyObject* PyMNNCVMatrix_postScale(PyMNNCVMatrix *self, PyObject *args) {
1955 return _PyMNNCVMatrix_Scale(self, args, 2);
1956 }
1957
_PyMNNCVMatrix_Translate(PyMNNCVMatrix * self,PyObject * args,int type)1958 static PyObject* _PyMNNCVMatrix_Translate(PyMNNCVMatrix *self, PyObject *args, int type) {
1959 float dx = 0.0, dy = 0.0;
1960 size_t argsCount = PyTuple_Size(args);
1961 if (argsCount == 2) {
1962 if (!PyArg_ParseTuple(args, "ff", &dx, &dy)) {
1963 PyErr_SetString(PyExc_Exception,
1964 "PyMNNCVMatrix_postScale: PyArg_ParseTuple failed");
1965 return NULL;
1966 }
1967 } else {
1968 PyErr_SetString(PyExc_Exception,
1969 "PyMNNCVMatrix_postScale: argument count error (should be 2 or 4)");
1970 return NULL;
1971 }
1972
1973 switch (type) {
1974 case 0:
1975 self->matrix->setTranslate(dy, dy);
1976 break;
1977 case 1:
1978 self->matrix->preTranslate(dy, dy);
1979 break;
1980 case 2:
1981 self->matrix->postTranslate(dy, dy);
1982 break;
1983 default:
1984 break;
1985 }
1986 Py_RETURN_NONE;
1987 }
PyMNNCVMatrix_setTranslate(PyMNNCVMatrix * self,PyObject * args)1988 static PyObject* PyMNNCVMatrix_setTranslate(PyMNNCVMatrix *self, PyObject *args) {
1989 return _PyMNNCVMatrix_Translate(self, args, 0);
1990 }
PyMNNCVMatrix_preTranslate(PyMNNCVMatrix * self,PyObject * args)1991 static PyObject* PyMNNCVMatrix_preTranslate(PyMNNCVMatrix *self, PyObject *args) {
1992 return _PyMNNCVMatrix_Translate(self, args, 1);
1993 }
PyMNNCVMatrix_postTranslate(PyMNNCVMatrix * self,PyObject * args)1994 static PyObject* PyMNNCVMatrix_postTranslate(PyMNNCVMatrix *self, PyObject *args) {
1995 return _PyMNNCVMatrix_Translate(self, args, 2);
1996 }
1997
PyMNNCVMatrix_invert(PyMNNCVMatrix * self)1998 static PyObject* PyMNNCVMatrix_invert(PyMNNCVMatrix *self) {
1999
2000 self->matrix->invert(self->matrix);
2001 Py_RETURN_NONE;
2002 }
2003 static PyObject* PyMNNOpInfo_getName(PyMNNOpInfo *self, PyObject *args);
2004 static PyObject* PyMNNOpInfo_getType(PyMNNOpInfo *self, PyObject *args);
2005
2006 static void PyMNNOpInfo_dealloc(PyMNNOpInfo *self);
2007 static PyObject* PyMNNOpInfo_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
2008 static int PyMNNOpInfo_init(PyMNNOpInfo *info, PyObject *args, PyObject *kwds);
2009
2010 static PyMethodDef PyMNNOpInfo_methods[] = {
2011 {"getName", (PyCFunction)PyMNNOpInfo_getName, METH_VARARGS, "get op name"},
2012 {"getType", (PyCFunction)PyMNNOpInfo_getType, METH_VARARGS, "get op type"},
2013 {NULL} /* Sentinel */
2014 };
PyMNNOpInfo_getName(PyMNNOpInfo * self,PyObject * args)2015 static PyObject* PyMNNOpInfo_getName(PyMNNOpInfo *self, PyObject *args) {
2016 PyObject *name = NULL;
2017 if (self->opInfo) {
2018 name = char2Object(self->opInfo->name().c_str());
2019 }
2020 return name;
2021 }
PyMNNOpInfo_getType(PyMNNOpInfo * self,PyObject * args)2022 static PyObject* PyMNNOpInfo_getType(PyMNNOpInfo *self, PyObject *args) {
2023 PyObject *type = NULL;
2024 if (self->opInfo) {
2025 type = char2Object(self->opInfo->type().c_str());
2026 }
2027 return type;
2028 }
/// Python type object for MNN.OpInfo. Its only methods (getName/getType) read
/// the wrapped op's name and type.
/// Slots are filled positionally, in CPython's PyTypeObject field order.
static PyTypeObject PyMNNOpInfoType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.OpInfo",                             /*tp_name*/
    sizeof(PyMNNOpInfo),                      /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNOpInfo_dealloc,          /*tp_dealloc*/
    0,                                        /*tp_print (tp_vectorcall_offset on Python 3.8+)*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare (tp_as_async on Python 3.5+)*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags: subclassable from Python*/
    "MNN OpInfo objects",                     /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNOpInfo_methods,                      /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    (initproc)PyMNNOpInfo_init,               /* tp_init */
    0,                                        /* tp_alloc: left 0 so PyType_Ready inherits the default allocator */
    PyMNNOpInfo_new,                          /* tp_new */
};
PyMNNOpInfo_new(struct _typeobject * type,PyObject * args,PyObject * kwds)2069 static PyObject* PyMNNOpInfo_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
2070 PyMNNOpInfo* self = (PyMNNOpInfo *)type->tp_alloc(type, 0);
2071 return (PyObject*)self;
2072 }
PyMNNOpInfo_init(PyMNNOpInfo * info,PyObject * args,PyObject * kwds)2073 static int PyMNNOpInfo_init(PyMNNOpInfo *info, PyObject *args, PyObject *kwds) {
2074 return 0;
2075 }
2076
PyMNNOpInfo_dealloc(PyMNNOpInfo * self)2077 static void PyMNNOpInfo_dealloc(PyMNNOpInfo *self) {
2078 Py_TYPE(self)->tp_free((PyObject*)self);
2079 }
2080 /// module init
2081 static PyMethodDef module_methods[] = {
2082 {NULL, NULL, 0, NULL}
2083 };
2084
/// Return the sum of `a` and `b`.
int add(int a, int b) {
    const int sum = b + a;
    return sum;
}
2088
// Extension-module name selection:
//   _MOD_NAME : bare identifier used to build the init-function name
//               (_mnncengine or MNN).
//   MOD_NAME  : the same name as a string literal ("_mnncengine" or "MNN").
// Only an ALINNPYTHON build without the expr API is named "MNN"; every other
// configuration is "_mnncengine".
// NOTE(review): these tests use `#if`, while the rest of the file checks the
// same macros with `#ifdef`. `#if` would fail to compile if either macro were
// defined with an empty value.
#if PYMNN_USE_ALINNPYTHON
#if PYMNN_EXPR_API
#define _MOD_NAME _mnncengine
#else
#define _MOD_NAME MNN
#endif
#else
#define _MOD_NAME _mnncengine
#endif
// Two-level stringification, so that _MOD_NAME is macro-expanded before `#`
// is applied.
#define _STRINGIFY(str) #str
#define STRINGIFY(macro) _STRINGIFY(macro)
#define MOD_NAME STRINGIFY(_MOD_NAME)
2103
#if PY_MAJOR_VERSION >= 3
// Python 3 module definition, passed to PyModule_Create by the init function.
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    MOD_NAME,        /* m_name */
    "MNNEngine",     /* m_doc */
    -1,              /* m_size: -1 = state lives in globals; sub-interpreters unsupported */
    module_methods,  /* m_methods */
    NULL,            /* m_reload (named m_slots since Python 3.5) */
    NULL,            /* m_traverse */
    NULL,            /* m_clear */
    NULL,            /* m_free */
};
#define MOD_INIT_FUNC_NAME(name) PyInit_##name
#else
#define MOD_INIT_FUNC_NAME(name) init##name
#endif
// MOD_INIT_FUNC expands to PyInit_<_MOD_NAME> (Python 3) or init<_MOD_NAME>
// (Python 2). The extra _MOD_INIT_FUNC layer forces _MOD_NAME to be expanded
// before `##` pastes it.
#define _MOD_INIT_FUNC(macro) MOD_INIT_FUNC_NAME(macro)
#define MOD_INIT_FUNC _MOD_INIT_FUNC(_MOD_NAME)
2123
// In PYMNN_USE_ALINNPYTHON builds, the module init function uses this flag
// with std::call_once. That ensures the thread-local-storage keys (tls_key,
// tls_key_2) are created at most once per process, even if init runs again.
static std::once_flag mLoadFlag1;
2125
MOD_INIT_FUNC(void)2126 PyMODINIT_FUNC MOD_INIT_FUNC(void) {
2127 #if PY_MAJOR_VERSION >= 3
2128 #define ERROR_RETURN return NULL;
2129 #else
2130 #define ERROR_RETURN return;
2131 #endif
2132
2133 #ifdef PYMNN_USE_ALINNPYTHON
2134 std::call_once(mLoadFlag1, [&](){
2135 if (global_new_python_flag > 0) {
2136 tls_key = PyThread_create_key();
2137 tls_key_2 = PyThread_create_key();
2138 }
2139 });
2140 #endif
2141
2142 if (PyType_Ready(&PyMNNInterpreterType) < 0) {
2143 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNInterpreterType failed");
2144 ERROR_RETURN
2145 }
2146 if (PyType_Ready(&PyMNNSessionType) < 0) {
2147 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNSessionType failed");
2148 ERROR_RETURN
2149 }
2150 if (PyType_Ready(&PyMNNTensorType) < 0) {
2151 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNTensorType failed");
2152 ERROR_RETURN
2153 }
2154 if (PyType_Ready(&PyMNNCVImageProcessType) < 0) {
2155 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVImageProcessType failed");
2156 ERROR_RETURN
2157 }
2158 if (PyType_Ready(&PyMNNCVMatrixType) < 0) {
2159 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVMatrixType failed");
2160 ERROR_RETURN
2161 }
2162 if (PyType_Ready(&PyMNNOpInfoType) < 0) {
2163 PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNOpInfoType failed");
2164 ERROR_RETURN
2165 }
2166 #if PY_MAJOR_VERSION >= 3
2167 PyObject *m = PyModule_Create(&moduledef);
2168 #else
2169 PyObject *m = Py_InitModule3(MOD_NAME, module_methods, "MNN Module");
2170 #endif
2171 // module import failed!
2172 if (!m) {
2173 PyErr_SetString(PyExc_Exception, "initMNN: import MNN failed");
2174 ERROR_RETURN
2175 }
2176 #ifdef PYMNN_NUMPY_USABLE
2177 if(_import_array() < 0) {
2178 PyErr_SetString(PyExc_Exception, "initMNN: init numpy failed");
2179 ERROR_RETURN
2180 }
2181 #endif
2182
2183 PyModule_AddObject(m, "Interpreter", (PyObject*)PyType_FindTLSType(&PyMNNInterpreterType));
2184 PyModule_AddObject(m, "Session", (PyObject*)PyType_FindTLSType(&PyMNNSessionType));
2185 PyModule_AddObject(m, "Tensor", (PyObject*)PyType_FindTLSType(&PyMNNTensorType));
2186 PyModule_AddObject(m, "CVImageProcess", (PyObject*)PyType_FindTLSType(&PyMNNCVImageProcessType));
2187 PyModule_AddObject(m, "CVMatrix", (PyObject*)PyType_FindTLSType(&PyMNNCVMatrixType));
2188 PyModule_AddObject(m, "OpInfo", (PyObject*)PyType_FindTLSType(&PyMNNOpInfoType));
2189
2190 // Tensor::DimensionType
2191 PyObject *DimensionType_Tensorflow = PyLong_FromLong(Tensor::TENSORFLOW);
2192 PyObject *DimensionType_Caffe = PyLong_FromLong(Tensor::CAFFE);
2193 PyObject *DimensionType_Caffe_C4 = PyLong_FromLong(Tensor::CAFFE_C4);
2194 PyModule_AddObject(m, "Tensor_DimensionType_Tensorflow", DimensionType_Tensorflow);
2195 PyModule_AddObject(m, "Tensor_DimensionType_Caffe", DimensionType_Caffe);
2196 PyModule_AddObject(m, "Tensor_DimensionType_Caffe_C4", DimensionType_Caffe_C4);
2197
2198 struct MNN_TLSData *tlsData = static_cast<MNN_TLSData *>(malloc(sizeof(MNN_TLSData)));
2199 setTLSData(tlsData);
2200 tlsData->interpreterMap = new std::unordered_map<std::string, Interpreter *>();
2201 tlsData->sessionCacheMap = new std::unordered_map<std::string, Session *>();
2202
2203 // halide_type
2204 tlsData->PyMNNHalideTypeInt = PyCapsule_New(httInt(), NULL, NULL);
2205 tlsData->PyMNNHalideTypeInt64 = PyCapsule_New(httInt64(), NULL, NULL);
2206 tlsData->PyMNNHalideTypeFloat = PyCapsule_New(httFloat(), NULL, NULL);
2207 tlsData->PyMNNHalideTypeDouble = PyCapsule_New(httDouble(), NULL, NULL);
2208 tlsData->PyMNNHalideTypeUint8 = PyCapsule_New(httUint8(), NULL, NULL);
2209 tlsData->PyMNNHalideTypeString = PyCapsule_New(httString(), NULL, NULL);
2210
2211 #if defined(PYMNN_USE_ALINNPYTHON) && defined(PYMNN_EXPR_API)
2212 struct py::detail::rh_tls *rh_tls = static_cast<py::detail::rh_tls *>(malloc(sizeof(py::detail::rh_tls)));
2213 if(nullptr == rh_tls) {
2214 throw runtime_error("rh_tls malloc fail");
2215 }
2216 set_rh_tls_data(rh_tls);
2217 rh_tls->internals_pp = nullptr;
2218 rh_tls->locals = new py::detail::type_map<py::detail::type_info *>;
2219 #endif
2220
2221 PyModule_AddObject(m, "Halide_Type_Int", tlsData->PyMNNHalideTypeInt);
2222 PyModule_AddObject(m, "Halide_Type_Int64", tlsData->PyMNNHalideTypeInt64);
2223 PyModule_AddObject(m, "Halide_Type_Float", tlsData->PyMNNHalideTypeFloat);
2224 PyModule_AddObject(m, "Halide_Type_Double", tlsData->PyMNNHalideTypeDouble);
2225 PyModule_AddObject(m, "Halide_Type_Uint8", tlsData->PyMNNHalideTypeUint8);
2226 PyModule_AddObject(m, "Halide_Type_String", tlsData->PyMNNHalideTypeString);
2227
2228 // CV
2229 // ImageFormat
2230 PyObject *CV_ImageFormat_RGBA = PyLong_FromLong(CV::RGBA);
2231 PyObject *CV_ImageFormat_RGB = PyLong_FromLong(CV::RGB);
2232 PyObject *CV_ImageFormat_BGR = PyLong_FromLong(CV::BGR);
2233 PyObject *CV_ImageFormat_GRAY = PyLong_FromLong(CV::GRAY);
2234 PyObject *CV_ImageFormat_BGRA = PyLong_FromLong(CV::BGRA);
2235 PyObject *CV_ImageFormat_YUV_NV21 = PyLong_FromLong(CV::YUV_NV21);
2236 PyModule_AddObject(m, "CV_ImageFormat_RGBA", CV_ImageFormat_RGBA);
2237 PyModule_AddObject(m, "CV_ImageFormat_RGB", CV_ImageFormat_RGB);
2238 PyModule_AddObject(m, "CV_ImageFormat_BGR", CV_ImageFormat_BGR);
2239 PyModule_AddObject(m, "CV_ImageFormat_GRAY", CV_ImageFormat_GRAY);
2240 PyModule_AddObject(m, "CV_ImageFormat_BGRA", CV_ImageFormat_BGRA);
2241 PyModule_AddObject(m, "CV_ImageFormat_YUV_NV21", CV_ImageFormat_YUV_NV21);
2242 // Filter
2243 PyObject *CV_Filter_NEAREST = PyLong_FromLong(CV::NEAREST);
2244 PyObject *CV_Filter_BILINEAL = PyLong_FromLong(CV::BILINEAR);
2245 PyObject *CV_Filter_BICUBIC = PyLong_FromLong(CV::BICUBIC);
2246 PyModule_AddObject(m, "CV_Filter_NEAREST", CV_Filter_NEAREST);
2247 PyModule_AddObject(m, "CV_Filter_BILINEAL", CV_Filter_BILINEAL);
2248 PyModule_AddObject(m, "CV_Filter_BICUBIC", CV_Filter_BICUBIC);
2249 // wrap
2250 PyObject *CV_Wrap_CLAMP_TO_EDGE = PyLong_FromLong(CV::CLAMP_TO_EDGE);
2251 PyObject *CV_Wrap_ZERO = PyLong_FromLong(CV::ZERO);
2252 PyObject *CV_Wrap_REPEAT = PyLong_FromLong(CV::REPEAT);
2253 PyModule_AddObject(m, "CV_Wrap_CLAMP_TO_EDGE", CV_Wrap_CLAMP_TO_EDGE);
2254 PyModule_AddObject(m, "CV_Wrap_ZERO", CV_Wrap_ZERO);
2255 PyModule_AddObject(m, "CV_Wrap_REPEAT", CV_Wrap_REPEAT);
2256
2257 // static variable initialize
2258 interpreterMap();
2259 sessionCacheMap();
2260
2261 #ifdef PYMNN_EXPR_API
2262 auto py_module = py::reinterpret_borrow<py::module>(m);
2263 INTS default_shape = {};
2264 auto expr_module = py_module.def_submodule("_expr");
2265 py::enum_<Dimensionformat> (expr_module, "data_format")
2266 .value("NHWC", NHWC)
2267 .value("NC4HW4", NC4HW4)
2268 .value("NCHW", NCHW)
2269 .export_values();
2270 py::enum_<DType> (expr_module, "dtype")
2271 .value("float", DType_FLOAT)
2272 .value("double", DType_DOUBLE)
2273 .value("int", DType_INT32)
2274 .value("int64", DType_INT64)
2275 .value("uint8", DType_UINT8)
2276 .export_values();
2277
2278 py::enum_<PaddingMode> (expr_module, "Padding_Mode")
2279 .value("CAFFE", CAFFE)
2280 .value("VALID", VALID)
2281 .value("SAME", SAME)
2282 .export_values();
2283 py::enum_<MNN::Express::PadValueMode> (expr_module, "PadValue_Mode")
2284 .value("CONSTANT", CONSTANT)
2285 .value("REFLECT", REFLECT)
2286 .value("SYMMETRIC", SYMMETRIC)
2287 .export_values();
2288 py::enum_<PoolingMode> (expr_module, "Pooling_Mode")
2289 .value("MAXPOOL", MAXPOOL)
2290 .value("AVEPOOL", AVEPOOL)
2291 .export_values();
2292 py::enum_<InterpolationMethod> (expr_module, "Interp_Method")
2293 .value("BILINEAR", BILINEAR)
2294 .value("NEAREST", NEAREST)
2295 .export_values();
2296 py::class_<VARP>(expr_module, "Var")
2297 .def_property_readonly("shape",
2298 [](VARP *self){
2299 auto info = (*self)->getInfo();
2300 if(nullptr == info) {
2301 throw std::runtime_error("unable to get variable info");
2302 }
2303 return info->dim;
2304 })
2305 .def_property_readonly("valid",
2306 [](VARP *self){
2307 auto info = (*self)->getInfo();
2308 if(nullptr == info) {
2309 return false;
2310 }
2311 return true;
2312 })
2313 .def_property_readonly("data_format",
2314 [](VARP *self){
2315 auto info = (*self)->getInfo();
2316 if(nullptr == info)
2317 throw std::runtime_error("unable to get variable info");
2318 return info->order;
2319 })
2320 .def_property_readonly("dtype",
2321 [](VARP *self){
2322 auto info = (*self)->getInfo();
2323 if(nullptr == info)
2324 throw std::runtime_error("unable to get variable info");
2325 return htype2dtype(info->type);
2326 })
2327 .def_property_readonly("size",
2328 [](VARP *self){
2329 auto info = (*self)->getInfo();
2330 if(nullptr == info) {
2331 throw std::runtime_error("unable to get variable info");
2332 }
2333 return info->size;
2334 })
2335 .def_property("name",
2336 [](VARP *self){
2337 auto name = (*self)->name();
2338 return name;
2339 },
2340 [] (VARP* self, std::string name) {
2341 (*self)->setName(name);
2342 })
2343 #ifdef BUILD_OPTYPE
2344 .def_property_readonly("op_type",
2345 [](VARP *self){
2346 auto op = (*self)->expr().first->get();
2347 if (nullptr == op) {
2348 switch ((*self)->expr().first->inputType()) {
2349 case VARP::INPUT:
2350 return std::string("Input");
2351 case VARP::CONSTANT:
2352 return std::string("Const");
2353 case VARP::TRAINABLE:
2354 return std::string("Trainable");
2355 }
2356 }
2357
2358 auto type = op->type();
2359 if (type == OpType_BinaryOp) {
2360 return std::string(MNN::EnumNameBinaryOpOperation((BinaryOpOperation)op->main_as_BinaryOp()->opType()));
2361 }
2362 if (type == OpType_UnaryOp) {
2363 return std::string(MNN::EnumNameUnaryOpOperation((UnaryOpOperation)op->main_as_UnaryOp()->opType()));
2364 }
2365 return std::string(MNN::EnumNameOpType(type));
2366 })
2367 #endif
2368 .def_property_readonly("inputs",
2369 [] (VARP* self) {
2370 return (*self)->expr().first->inputs();
2371 })
2372 .def("fix_as_placeholder",
2373 [] (VARP* self) {
2374 (*self).fix(VARP::INPUT);
2375 })
2376
2377 .def("fix_as_const",
2378 [] (VARP* self) {
2379 (*self).fix(VARP::CONSTANT);
2380 })
2381 .def("fix_as_trainable",
2382 [] (VARP* self) {
2383 (*self).fix(VARP::TRAINABLE);
2384 })
2385 .def("close",
2386 [] (VARP* self) {
2387 (*self)->input(VARP(nullptr));
2388 })
2389 .def("copy_from",
2390 [] (VARP* self, VARP source) {
2391 bool res = (*self)->input(source);
2392 if (!res) {
2393 throw std::runtime_error("Copy from souce Error");
2394 }
2395 })
2396 .def("set_inputs",
2397 [] (VARP* self, std::vector<VARP> source) {
2398 if (source.empty()) {
2399 throw std::runtime_error("Empty source");
2400 }
2401 auto expr = (*self)->expr();
2402 auto newExpr = Expr::create(expr.first->extra(), std::move(source), expr.first->outputSize());
2403 Expr::replace(expr.first, newExpr);
2404 })
2405 .def("replace",
2406 [] (VARP* self, VARP source) {
2407 Variable::replace(*self, source);
2408 })
2409 .def("reorder",
2410 [] (VARP* self, Dimensionformat order) {
2411 auto newInput = _ChangeInputFormat(*self, order);
2412 (*self) = newInput;
2413 })
2414 .def("resize",
2415 [] (VARP* self, const std::vector<int>& shape) {
2416 (*self)->resize(shape);
2417 })
2418 #ifdef PYMNN_NUMPY_USABLE
2419 .def("read",
2420 [](VARP *self){
2421 auto info = (*self)->getInfo();
2422 if(nullptr == info)
2423 throw std::runtime_error("unable to get variable info");
2424 auto dtype = htype2dtype(info->type);
2425 auto shape = info->dim;
2426 int64_t total_length = info->size;
2427 auto readptr = [self](DType dtype, INTS shape, int64_t total_length) {
2428 void *dataPtr = (void *) (*self)->readMap<void>();
2429 std::vector<npy_intp> npy_dims;
2430 for(const auto dim: shape) {
2431 npy_dims.push_back(dim);
2432 }
2433
2434 switch(dtype) {
2435 case DType_FLOAT:
2436 return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_FLOAT, dataPtr);
2437 case DType_DOUBLE:
2438 return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_DOUBLE, dataPtr);
2439 case DType_INT32:
2440 return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT32, dataPtr);
2441 case DType_INT64:
2442 return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT64, dataPtr);
2443 case DType_UINT8:
2444 return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_UINT8, dataPtr);
2445 default:
2446 throw std::runtime_error("does not support this dtype");
2447 }
2448 if (nullptr == dataPtr) {
2449 throw std::runtime_error("call to readMap meet a error");
2450 }
2451 };
2452 auto data = readptr(dtype, shape, total_length);
2453 (*self)->unMap();
2454 return py::reinterpret_steal<py::object>(data);
2455 })
2456 #endif
2457 .def("read_as_tuple",
2458 [](VARP *self){
2459 auto info = (*self)->getInfo();
2460 if(nullptr == info)
2461 throw std::runtime_error("unable to get variable info");
2462 auto dtype = htype2dtype(info->type);
2463 auto shape = info->dim;
2464 size_t total_length = info->size;
2465 auto readptr = [self](DType dtype, INTS shape, size_t total_length) {
2466 void *dataPtr = (void *) (*self)->readMap<void>();
2467 auto obj = PyTuple_New(total_length);
2468 if(DType_FLOAT == dtype) {
2469 auto data = (float*)dataPtr;
2470 for(size_t i = 0; i < total_length; i++) {
2471 PyTuple_SetItem(obj, i, PyFloat_FromDouble(data[i]));
2472 }
2473 } else if(DType_INT32 == dtype) {
2474 auto data = (int32_t*)dataPtr;
2475 for(size_t i = 0; i < total_length; i++) {
2476 PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2477 }
2478 } else if(DType_UINT8 == dtype) {
2479 auto data = (uint8_t*)dataPtr;
2480 for(size_t i = 0; i < total_length; i++) {
2481 PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2482 }
2483 } else if(DType_INT8 == dtype) {
2484 auto data = (int8_t*)dataPtr;
2485 for(size_t i = 0; i < total_length; i++) {
2486 PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2487 }
2488 } else {
2489 throw std::runtime_error("Don't support data type");
2490 }
2491 return obj;
2492 };
2493 auto data = readptr(dtype, shape, total_length);
2494 (*self)->unMap();
2495 return py::reinterpret_steal<py::object>(data);
2496 })
2497 .def("write",
2498 [](VARP *self, py::object data) {
2499 auto info = (*self)->getInfo();
2500 if(nullptr == info) {
2501 throw std::runtime_error("unable to get variable info");
2502 }
2503 auto dtype = htype2dtype(info->type);
2504 auto shape = info->dim;
2505 int64_t total_length = info->size;
2506 PyObject *obj = data.ptr();
2507 auto write = [self](PyObject *obj, DType dtype, int64_t total_length) {
2508 #ifdef PYMNN_NUMPY_USABLE
2509 if(PyArray_Check(obj)) {
2510 //numpy support
2511 if(total_length != PyArray_Size(obj)) {
2512 throw std::runtime_error("data size does not match each other");
2513 }
2514 int npy_type = PyArray_TYPE(obj);
2515 int itemsize = getitemsize(dtype, npy_type);
2516 PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
2517 auto tmpBuffer = PyArray_DATA(obj_cont);
2518 if(NULL == tmpBuffer) {
2519 throw std::runtime_error("numpy failed to get buffer");
2520 }
2521 auto data = (*self)->writeMap<void>();
2522 if (nullptr == data) {
2523 throw std::runtime_error("call to writeMap meet a error");
2524 }
2525 memcpy(data, tmpBuffer, total_length * itemsize);
2526 Py_XDECREF(obj_cont);
2527 return;
2528 }
2529 #endif
2530 INTS shapeData = getshape(obj);
2531 int64_t totalLengthData = 1;
2532 INTS stride;
2533 for (size_t i = 0; i < shapeData.size(); i++) {
2534 totalLengthData *= shapeData[i];
2535 }
2536 int totalStride = 1;
2537 for (int i = shapeData.size() - 1; i >= 0; i--) {
2538 if(i + 1 < shapeData.size()) {
2539 totalStride *= shapeData[i+1];
2540 }
2541 stride.push_back(totalStride);
2542 }
2543 std::reverse(stride.begin(), stride.end());
2544 if(totalLengthData != total_length) {
2545 throw std::runtime_error("data size does not match each other");
2546 }
2547 if(DType_FLOAT == dtype) {
2548 auto data = (*self)->writeMap<float>();
2549 if (nullptr == data) {
2550 throw std::runtime_error("call to writeMap meet a error");
2551 }
2552 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(float));
2553 }
2554 else if(DType_INT32 == dtype) {
2555 auto data = (*self)->writeMap<int>();
2556 if (nullptr == data) {
2557 throw std::runtime_error("call to writeMap meet a error");
2558 }
2559 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int));
2560 }
2561 else if(DType_UINT8 == dtype) {
2562 auto data = (*self)->writeMap<uint8_t>();
2563 if (nullptr == data) {
2564 throw std::runtime_error("call to writeMap meet a error");
2565 }
2566 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(uint8_t));
2567 }
2568 else if(DType_INT8 == dtype) {
2569 auto data = (*self)->writeMap<uint8_t>();
2570 if (nullptr == data) {
2571 throw std::runtime_error("call to writeMap meet a error");
2572 }
2573 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int8_t));
2574 }
2575 };
2576 write(obj, dtype, total_length);
2577 (*self)->unMap();
2578 });
2579 // Load And Save
2580 expr_module.def("load_as_list",
2581 [](std::string fileName) {
2582 auto variable = Variable::load(fileName.c_str());
2583 return variable;
2584 });
2585 expr_module.def("save",
2586 [](const std::vector<VARP>& vars, std::string fileName, bool forInference = true) {
2587 std::vector<VARP> newVars;
2588 for (auto v : vars) {
2589 if (v.get() != nullptr) {
2590 newVars.emplace_back(v);
2591 }
2592 }
2593 #ifdef PYMNN_TRAIN_API
2594 if (forInference) {
2595 Transformer::turnModelToInfer()->onExecute(newVars);
2596 }
2597 #endif
2598 Variable::save(newVars, fileName.c_str());
2599 #ifdef PYMNN_TRAIN_API
2600 ConvertToFullQuant::convert(fileName);
2601 #endif
2602 }, py::arg("variables"), py::arg("file_name"), py::arg("for_inference") = true);
2603 expr_module.def("load_as_dict",
2604 [](std::string fileName) {
2605 auto variable = Variable::loadMap(fileName.c_str());
2606 return variable;
2607 });
2608 expr_module.def("get_inputs_and_outputs", &Variable::getInputAndOutput);
2609 // Executor
2610 expr_module.def("gc", [](bool full) {
2611 auto exe = Executor::getGlobalExecutor();
2612 if (full) {
2613 exe->gc(Executor::FULL);
2614 } else {
2615 exe->gc(Executor::PART);
2616 }
2617 });
2618
// Every backend type is exposed whether or not this MNN build actually supports it,
// which keeps the MNN wheel's setup.py simple.
// Note: "TRT" and "HIAI" map onto the generic MNN_FORWARD_USER_1 / MNN_FORWARD_USER_0 slots.
py::enum_<MNNForwardType>(expr_module, "Backend")
    .value("CPU", MNN_FORWARD_CPU)
    .value("OPENCL", MNN_FORWARD_OPENCL)
    .value("OPENGL", MNN_FORWARD_OPENGL)
    .value("VULKAN", MNN_FORWARD_VULKAN)
    .value("METAL", MNN_FORWARD_METAL)
    .value("TRT", MNN_FORWARD_USER_1)
    .value("CUDA", MNN_FORWARD_CUDA)
    .value("HIAI", MNN_FORWARD_USER_0)
    .export_values();

// Short aliases for the BackendConfig enums; also used by set_config and
// load_module_from_file further down in this function.
using MemoryMode = BackendConfig::MemoryMode;
using PowerMode = BackendConfig::PowerMode;
using PrecisionMode = BackendConfig::PrecisionMode;
// Each enum is exported as Normal/High/Low on MNN.expr.
py::enum_<MemoryMode>(expr_module, "MemoryMode")
    .value("Normal", MemoryMode::Memory_Normal)
    .value("High", MemoryMode::Memory_High)
    .value("Low", MemoryMode::Memory_Low)
    .export_values();
py::enum_<PowerMode>(expr_module, "PowerMode")
    .value("Normal", PowerMode::Power_Normal)
    .value("High", PowerMode::Power_High)
    .value("Low", PowerMode::Power_Low)
    .export_values();
py::enum_<PrecisionMode>(expr_module, "PrecisionMode")
    .value("Normal", PrecisionMode::Precision_Normal)
    .value("High", PrecisionMode::Precision_High)
    .value("Low", PrecisionMode::Precision_Low)
    .export_values();
2650 expr_module.def("set_config",
2651 [](MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode, PrecisionMode precision_mode, int thread_num) {
2652 if (thread_num < 1 || thread_num > 8) {
2653 PyErr_SetString(PyExc_Exception, "thread_num should bigger than 0 and less than 9");
2654 }
2655 thread_num = std::max(std::min(thread_num, 8), 1);
2656 //auto exe = ExecutorScope::Current();
2657 auto exe = Executor::getGlobalExecutor();
2658 BackendConfig config;
2659 config.memory = memory_mode;
2660 config.power = power_mode;
2661 config.precision = precision_mode;
2662 exe->setGlobalExecutorConfig(backend, config, thread_num);
2663 },
2664 py::arg("backend")=MNN_FORWARD_CPU, py::arg("memory_mode")=MemoryMode::Memory_Normal,
2665 py::arg("power_mode")=PowerMode::Power_Normal, py::arg("precision_mode")=PrecisionMode::Precision_Normal,
2666 py::arg("thread_num")=1);
2667
//Begin of Math OPS
//Unary OPS
// Element-wise unary ops, bound directly to the corresponding Express::_Xxx creator.
expr_module.def("sign", &Express::_Sign);
expr_module.def("abs", &Express::_Abs);
expr_module.def("negative", &Express::_Negative);
expr_module.def("floor", &Express::_Floor);
expr_module.def("ceil", &Express::_Ceil);
expr_module.def("square", &Express::_Square);
expr_module.def("sqrt", &Express::_Sqrt);
expr_module.def("rsqrt", &Express::_Rsqrt);
expr_module.def("exp", &Express::_Exp);
expr_module.def("log", &Express::_Log);
expr_module.def("sin", &Express::_Sin);
expr_module.def("cos", &Express::_Cos);
expr_module.def("tan", &Express::_Tan);
expr_module.def("asin", &Express::_Asin);
expr_module.def("acos", &Express::_Acos);
expr_module.def("atan", &Express::_Atan);
expr_module.def("reciprocal", &Express::_Reciprocal);
expr_module.def("log1p", &Express::_Log1p);
expr_module.def("tanh", &Express::_Tanh);
expr_module.def("sigmoid", &Express::_Sigmoid);
//Binary OPS
// Two-operand ops, bound directly to the corresponding Express::_Xxx creator.
expr_module.def("add", &Express::_Add);
expr_module.def("subtract", &Express::_Subtract);
expr_module.def("multiply", &Express::_Multiply);
expr_module.def("divide", &Express::_Divide);
expr_module.def("pow", &Express::_Pow);
expr_module.def("minimum", &Express::_Minimum);
expr_module.def("maximum", &Express::_Maximum);
expr_module.def("bias_add", &Express::_BiasAdd);
expr_module.def("greater", &Express::_Greater);
expr_module.def("greater_equal", &Express::_GreaterEqual);
expr_module.def("less", &Express::_Less);
expr_module.def("floordiv", &Express::_FloorDiv);
expr_module.def("squared_difference", &Express::_SquaredDifference);
expr_module.def("equal", &Express::_Equal);
expr_module.def("not_equal", &Express::_NotEqual);
expr_module.def("less_equal", &Express::_LessEqual);
expr_module.def("floormod", &Express::_FloorMod);
//Reduce OPS
// reduce_xxx(input, axis=default_shape, keep_dims=False) forwards to _ReduceXxx.
// default_shape is defined earlier in this function. It is presumably an empty list,
// meaning "reduce over all axes" — TODO confirm against the _ReduceXxx implementations.
expr_module.def("reduce_sum",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceSum(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_mean",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceMean(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_max",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceMax(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_min",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceMin(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_prod",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceProd(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_any",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceAny(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
expr_module.def("reduce_all",
    [](VARP input, INTS axis, bool keep_dims) {
        return _ReduceAll(input, axis, keep_dims);
    }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
//Eltwise OPS
expr_module.def("eltwise_prod", &Express::_Prod);
expr_module.def("eltwise_sum", &Express::_Sum);
expr_module.def("eltwise_max", &Express::_Max);
expr_module.def("eltwise_sub", &Express::_Sub);
//Other OPS
// cast(x, dtype): the Python-side DType is translated with dtype2htype before _Cast.
expr_module.def("cast",
    [](VARP x, DType dtype) {
        return _Cast(x, dtype2htype(dtype));
    });
// NOTE: the keyword names "tranposeA"/"tranposeB" are misspelled, but they are part of
// the public Python API; renaming them would break keyword callers.
expr_module.def("matmul", &Express::_MatMul, py::arg("a"), py::arg("b"), py::arg("tranposeA")=false, py::arg("tranposeB")=false);
expr_module.def("normalize", &Express::_Normalize);
expr_module.def("argmax",
    [](VARP input, int axis) {
        return _ArgMax(input, axis);
    }, py::arg("input"), py::arg("axis")=0);
expr_module.def("unravel_index", &Express::_UnravelIndex, py::arg("indices"), py::arg("dims"));
expr_module.def("scatter_nd", &Express::_ScatterNd, py::arg("indices"), py::arg("updates"), py::arg("shape"));
// one_hot: depth/on_value/off_value are wrapped into scalar constants
// (_Scalar<int>/_Scalar<float>) before calling _OneHot.
expr_module.def("one_hot",
    [](VARP indices, int depth, float onValue, float offValue, int axis) {
        return _OneHot(indices, _Scalar<int>(depth), _Scalar<float>(onValue), _Scalar<float>(offValue), axis);
    },py::arg("indices"), py::arg("depth"), py::arg("on_value")=1, py::arg("off_value")=0, py::arg("axis")=-1);
expr_module.def("broadcast_to", &Express::_BroadcastTo, py::arg("input"), py::arg("shape"));
//End of Math OPS
2761
//Begin of NN OPS
// placeholder(shape=default_shape, data_format=NCHW, dtype=FLOAT): creates an input
// variable via _Input; dtype is translated with dtype2htype.
expr_module.def("placeholder",
    [](INTS shape,Dimensionformat data_format, DType dtype)->VARP{
        return _Input(shape, data_format, dtype2htype(dtype));
    },
    py::arg("shape")=default_shape,
    py::arg("data_format")=NCHW,
    py::arg("dtype")=DType_FLOAT);
// clone(source, deep_copy=False) forwards to _Clone.
expr_module.def("clone",
    [](VARP source, bool deepCopy) {
        return _Clone(source, deepCopy);
    }, py::arg("source"), py::arg("deep_copy")=false);
// Default argument values shared by several bindings below (pool/conv pads, squeeze axes).
INTS default_pads = {0, 0};
INTS default_axis = {};
2776 expr_module.def("const",
2777 [](py::object value, INTS shape, Dimensionformat data_format, DType dtype) {
2778 int64_t total_length = 1;
2779 for(size_t i = 0; i < shape.size(); i++) {
2780 if (data_format == NC4HW4 && 1 == i)
2781 {
2782 #ifndef ROUND_UP
2783 #define ROUND_UP(x, y) (((x) + (y) - (1)) / (y) * (y))
2784 #endif
2785 total_length *= ROUND_UP(shape[i], 4);
2786 }
2787 else
2788 {
2789 total_length *= shape[i];
2790 }
2791 }
2792 PyObject *obj = value.ptr();
2793 auto write = [](PyObject *obj, DType dtype, int64_t total_length) {
2794 #ifdef PYMNN_NUMPY_USABLE
2795 if(PyArray_Check(obj)) {
2796 //numpy support
2797 if(total_length != PyArray_Size(obj)) {
2798 throw std::runtime_error("data size does not match each other");
2799 }
2800 int npy_type = PyArray_TYPE(obj);
2801 int itemsize = getitemsize(dtype, npy_type);
2802 PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
2803 auto tmpBuffer = PyArray_DATA(obj_cont);
2804 if(NULL == tmpBuffer) {
2805 throw std::runtime_error("numpy failed to get buffer");
2806 }
2807 auto data = malloc(total_length * itemsize);
2808 if (nullptr == data) {
2809 throw std::runtime_error("call to writeMap meet a error");
2810 }
2811 memcpy(data, tmpBuffer, total_length * itemsize);
2812 Py_XDECREF(obj_cont);
2813 return data;
2814 }
2815 #endif
2816 INTS shapeData = getshape(obj);
2817 int64_t totalLengthData = 1;
2818 INTS stride;
2819 for(size_t i = 0; i < shapeData.size(); i++) {
2820 totalLengthData *= shapeData[i];
2821 }
2822 int totalStride = 1;
2823 for (int i = shapeData.size() - 1; i >= 0; i--) {
2824 if (i + 1 < shapeData.size()) {
2825 totalStride *= shapeData[i+1];
2826 }
2827 stride.push_back(totalStride);
2828 }
2829 std::reverse(stride.begin(), stride.end());
2830 if(totalLengthData != total_length) {
2831 throw std::runtime_error("data size does not match each other");
2832 }
2833 void *data = nullptr;
2834 if(DType_FLOAT == dtype) {
2835 data = malloc(total_length * sizeof(float));
2836 if (nullptr == data) {
2837 throw std::runtime_error("not enough memory");
2838 }
2839 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(float));
2840 }
2841 else if(DType_INT32 == dtype) {
2842 data = malloc(total_length * sizeof(int));
2843 if (nullptr == data) {
2844 throw std::runtime_error("not enough memory");
2845 }
2846 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int));
2847 }
2848 else if(DType_UINT8 == dtype) {
2849 data = malloc(total_length * sizeof(uint8_t));
2850 if (nullptr == data) {
2851 throw std::runtime_error("not enough memory");
2852 }
2853 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(uint8_t));
2854 }
2855 else if(DType_INT8 == dtype) {
2856 data = malloc(total_length * sizeof(int8_t));
2857 if (nullptr == data) {
2858 throw std::runtime_error("not enough memory");
2859 }
2860 recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int8_t));
2861 }
2862 return data;
2863 };
2864 auto data = write(obj, dtype, total_length);
2865 VARP ret = nullptr;
2866 if(data) {
2867 ret = _Const((const void*)data, shape, data_format, dtype2htype(dtype));
2868 free(data);
2869 }
2870 return ret;
2871 },py::arg("value_list"), py::arg("shape"), py::arg("data_format")=NCHW, py::arg("dtype")=DType::DType_FLOAT);
// Default stride/dilation used by conv2d and conv2d_transpose.
INTS default_stride = {1, 1};
INTS default_dialate = {1, 1};
// conv2d: the Python argument order (input, weight, bias, stride, padding, dilate, group,
// padding_mode) is rearranged to _Conv(weight, bias, input, padding_mode, stride, dilate,
// group, padding).
expr_module.def("conv2d",
    [](VARP input, VARP weight, VARP bias, INTS stride, INTS padding, INTS dilate, int group, PaddingMode padding_mode) {
        return _Conv(weight, bias, input, padding_mode, stride, dilate, group, padding);
    },py::arg("input"), py::arg("weight"), py::arg("bias"),
    py::arg("stride")=default_stride,
    py::arg("padding")=default_pads,
    py::arg("dilate")=default_dialate,
    py::arg("group")=1,
    py::arg("padding_mode")=VALID);
// conv2d_transpose: same argument mapping as conv2d, but forwards to _Deconv.
expr_module.def("conv2d_transpose",
    [](VARP input, VARP weight, VARP bias, INTS stride, INTS padding, INTS dilate, int group, PaddingMode padding_mode) {
        return _Deconv(weight, bias, input, padding_mode, stride, dilate, group, padding);
    },py::arg("input"), py::arg("weight"), py::arg("bias"),
    py::arg("stride")=default_stride,
    py::arg("padding")=default_pads,
    py::arg("dilate")=default_dialate,
    py::arg("group")=1,
    py::arg("padding_mode")=VALID);
expr_module.def("max_pool",
    [](VARP x, INTS kernel, INTS stride, PaddingMode pad, INTS pads) {
        return _MaxPool(x, kernel, stride, pad, pads);
    }, py::arg("input"), py::arg("kernel"), py::arg("stride"),
    py::arg("padding_mode")=VALID,
    py::arg("pads")=default_pads);
expr_module.def("avg_pool",
    [](VARP x, INTS kernel, INTS stride, PaddingMode pad, INTS pads) {
        return _AvePool(x, kernel, stride, pad, pads);
    }, py::arg("input"), py::arg("kernel"), py::arg("stride"),
    py::arg("padding_mode")=VALID,
    py::arg("pads")=default_pads);
// reshape is registered twice: first with a static INTS shape, then with a VARP shape.
// pybind11 tries overloads in registration order, so keep this order stable.
expr_module.def("reshape",
    [](VARP x, INTS shape, Dimensionformat original_format) {
        return _Reshape(x, shape, original_format);
    }, py::arg("x"), py::arg("shape"), py::arg("original_format")=NCHW);
expr_module.def("reshape",
    [](VARP x, VARP shape) {
        return _Reshape(x, shape);
    });
expr_module.def("scale", &Express::_Scale, py::arg("x"), py::arg("channels"), py::arg("scales"), py::arg("bias"));
expr_module.def("relu",
    [](VARP x, float slope) {
        return _Relu(x, slope);
    }, py::arg("x"), py::arg("slope")=0.0f);
expr_module.def("relu6", &Express::_Relu6, py::arg("x"), py::arg("min") = 0.0f, py::arg("max") = 6.0f);
expr_module.def("prelu", &Express::_PRelu, py::arg("x"), py::arg("slopes"));
expr_module.def("softmax",
    [](VARP logits, int axis) {
        return _Softmax(logits, axis);
    }, py::arg("logits"), py::arg("axis")=-1);
expr_module.def("softplus", &Express::_Softplus, py::arg("features"));
expr_module.def("softsign", &Express::_Softsign, py::arg("features"));
expr_module.def("slice", &Express::_Slice, py::arg("input"), py::arg("starts"), py::arg("sizes"));
expr_module.def("split", &Express::_Split, py::arg("input"), py::arg("size_splits"), py::arg("axis"));
expr_module.def("strided_slice", &Express::_StridedSlice, py::arg("input"), py::arg("begin"), py::arg("end"),
    py::arg("strides"), py::arg("begin_mask"), py::arg("end_mask"), py::arg("ellipsis_mask"), py::arg("new_axis_mask"), py::arg("shrink_axis_mask"));
expr_module.def("concat", &Express::_Concat, py::arg("values"), py::arg("axis"));
expr_module.def("convert", &Express::_Convert, py::arg("input"), py::arg("format"));
// transpose is registered twice: first with a static INTS perm, then with a VARP perm.
// pybind11 tries overloads in registration order, so keep this order stable.
expr_module.def("transpose",
    [](VARP x, INTS perm) {
        return _Transpose(x, perm);
    }, py::arg("x"), py::arg("perm"));
expr_module.def("transpose",
    [](VARP x, VARP perm) {
        return _Transpose(x, perm);
    });
expr_module.def("channel_shuffle", &Express::_ChannelShuffle);
// change_inputformat is not exposed because it only applies to static graphs.
//expr_module.def("change_inputformat", &Express::_ChangeInputFormat);
//
expr_module.def("reverse_sequence", &Express::_ReverseSequence, py::arg("x"), py::arg("y"), py::arg("batch_dim"), py::arg("seq_dim"));
expr_module.def("crop", &Express::_Crop, py::arg("images"), py::arg("size"), py::arg("axis"), py::arg("offset"));
expr_module.def("resize", &Express::_Resize, py::arg("images"), py::arg("x_scale"), py::arg("y_scale"));
expr_module.def("pad",
    [](VARP x, VARP paddings, MNN::Express::PadValueMode mode) {
        return Express::_Pad(x, paddings, mode);
    }, py::arg("x"), py::arg("paddings"), py::arg("mode")=CONSTANT);
// expand_dims is registered twice (int axis, then VARP axis). Like transpose, the
// registration order decides which overload pybind11 tries first.
expr_module.def("expand_dims",
    [](VARP input, int axis) {
        return _ExpandDims(input, axis);
    });
expr_module.def("expand_dims",
    [](VARP input, VARP axis) {
        return _ExpandDims(input, axis);
    });
// shape(input): forwards to Express::_Shape with its second argument fixed to false.
expr_module.def("shape",
    [](VARP input) {
        return Express::_Shape(input, false);
    }, py::arg("input"));
expr_module.def("stack",
    [](VARPS values, int axis) {
        return _Stack(values, axis);
    }, py::arg("values"), py::arg("axis")=0);
expr_module.def("crop_and_resize",
    [](VARP image, VARP boxes, VARP box_ind, VARP crop_size, InterpolationMethod method, float extrapolation_value) {
        return _CropAndResize(image, boxes, box_ind, crop_size, method, extrapolation_value);
    }, py::arg("image"), py::arg("boxes"), py::arg("box_ind"), py::arg("crop_size"),
    py::arg("method")=BILINEAR, py::arg("extrapolation_value")=0.0f);
expr_module.def("fill", &Express::_Fill, py::arg("dims"), py::arg("value"));
expr_module.def("tile", &Express::_Tile, py::arg("input"), py::arg("multiples"));
expr_module.def("gather", &Express::_Gather, py::arg("params"), py::arg("indices"));
expr_module.def("select", &Express::_Select);
2975
// gather_v2 is not exposed: only axis == 0 is currently supported, which is the same
// as gather.
/*
expr_module.def("gather_v2",
    [](VARP params, VARP indices, VARP axis = nullptr) {
        return _GatherV2(params, indices, axis);
    }, py::arg("params"), py::arg("indices"), py::arg("axis")=nullptr);
*/

// squeeze / unsqueeze default to default_axis (an empty INTS list).
expr_module.def("squeeze",
    [](VARP input, INTS axis) {
        return _Squeeze(input, axis);
    }, py::arg("input"), py::arg("axis")=default_axis);
expr_module.def("unsqueeze",
    [](VARP input, INTS axis) {
        return _Unsqueeze(input, axis);
    }, py::arg("input"), py::arg("axis")=default_axis);
expr_module.def("batch_to_space_nd", &Express::_BatchToSpaceND, py::arg("input"), py::arg("block_shape"), py::arg("crops"));
expr_module.def("gather_nd", &Express::_GatherND, py::arg("params"), py::arg("indices"));
expr_module.def("selu", &Express::_Selu, py::arg("features"), py::arg("scale"), py::arg("alpha"));
expr_module.def("size", &Express::_Size, py::arg("input"));
expr_module.def("elu",
    [](VARP features, float alpha) {
        return _Elu(features, alpha);
    }, py::arg("features"), py::arg("alpha")=1.0);
expr_module.def("matrix_band_part", &Express::_MatrixBandPart, py::arg("input"), py::arg("num_lower"), py::arg("num_upper"));
expr_module.def("moments", &Express::_Moments, py::arg("x"), py::arg("axes"), py::arg("shift"), py::arg("keep_dims"));
expr_module.def("setdiff1d", &Express::_SetDiff1D, py::arg("x"), py::arg("y"));
expr_module.def("space_to_depth", &Express::_SpaceToDepth, py::arg("input"), py::arg("block_size"));
expr_module.def("space_to_batch_nd", &Express::_SpaceToBatchND, py::arg("input"), py::arg("block_shape"), py::arg("paddings"));
expr_module.def("zeros_like", &Express::_ZerosLike, py::arg("input"));
expr_module.def("unstack",
    [](VARP value, int axis) {
        return _Unstack(value, axis);
    }, py::arg("value"), py::arg("axis")=0);
expr_module.def("rank", &Express::_Rank, py::arg("input"));
expr_module.def("range", &Express::_Range, py::arg("start"), py::arg("limit"), py::arg("delta"));
expr_module.def("depth_to_space", &Express::_DepthToSpace, py::arg("input"), py::arg("block_size"));
// NOTE(review): use_regular_nms has a default, but centersize_encoding after it does not.
// The default can therefore only take effect when centersize_encoding is passed by keyword.
expr_module.def("detection_post_process", &Express::_DetectionPostProcess,
    py::arg("encode_boxes"), py::arg("class_predictions"), py::arg("anchors"),
    py::arg("num_classes"), py::arg("max_detections"), py::arg("max_class_per_detection"),
    py::arg("detections_per_class"), py::arg("nms_threshold"), py::arg("iou_threshold"),
    py::arg("use_regular_nms")=false, py::arg("centersize_encoding"));
//End of NN OPS
// MNN.cv submodule: exposes CV::ImageFormat values as the "Format" enum.
auto cv_module = py_module.def_submodule("cv");
py::enum_<CV::ImageFormat>(cv_module, "Format")
    .value("RGBA", CV::RGBA)
    .value("RGB", CV::RGB)
    .value("GRAY", CV::GRAY)
    .value("BGR", CV::BGR)
    .value("YUV_NV21", CV::YUV_NV21)
    .value("YUV_NV12", CV::YUV_NV12)
    .export_values();
3028
auto nn_module = py_module.def_submodule("_nn");

// Trampoline so Python subclasses of _Module can implement the forward pass.
// Module::onForward is dispatched to the Python-side `forward` method.
// PYBIND11_OVERLOAD_PURE raises if the Python subclass does not define it.
class PyModule : public Module {
public:
    using Module::Module;
    // Re-exported so it can be bound as _register_submodules below
    // (presumably not public on Module itself — confirm).
    using Module::registerModel;

    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        PYBIND11_OVERLOAD_PURE(std::vector<Express::VARP>, Module, forward, inputs);
    }
};
// _Module: base class for Python-defined and loaded modules, held by std::shared_ptr.
// __call__ and forward are each registered twice (Module::forward, then Module::onForward).
// pybind11 tries overloads in registration order, so keep the order stable.
py::class_<Module, PyModule, std::shared_ptr<Module>>(nn_module, "_Module")
    .def(py::init())
    .def("__call__", &Module::forward)
    .def("__call__", &Module::onForward)
    .def("forward", &Module::forward)
    .def("forward", &Module::onForward)
    .def_property_readonly("name", &Module::name) // TODO: too ugly, find way to fix it
    .def("set_name", &Module::setName)
    .def_property_readonly("is_training", &Module::getIsTraining)
    .def("train", &Module::setIsTraining, py::arg("is_training") = true)
    .def_property_readonly("parameters", &Module::parameters)
    .def("load_parameters", &Module::loadParameters)
    .def("clear_cache", &Module::clearCache)
    .def("_register_submodules", &PyModule::registerModel)
    .def("_add_parameter", &Module::addParameter);
3055
// load_module(inputs, outputs, fortrain): extracts a Module spanning the graph from
// `inputs` to `outputs`. It uses NN::extract when the training API is compiled in and
// Module::extract otherwise.
nn_module.def("load_module", [](vector<VARP> inputs, vector<VARP> outputs, bool fortrain){
#ifdef PYMNN_TRAIN_API
    return NN::extract(inputs, outputs, fortrain);
#else
    return Module::extract(inputs, outputs, fortrain);
#endif
});
3063 nn_module.def("load_module_from_file", [](const vector<string>& inputs, const vector<string>& outputs,
3064 const char* file_name, bool dynamic, bool shape_mutable, bool rearrange,
3065 MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode,
3066 PrecisionMode precision_mode, int thread_num) -> Module* {
3067 BackendConfig backend_config;
3068 backend_config.memory = memory_mode;
3069 backend_config.power = power_mode;
3070 backend_config.precision = precision_mode;
3071
3072 Module::BackendInfo backend_info;
3073 backend_info.type = backend;
3074 backend_info.config = &backend_config;
3075
3076 Module::Config config;
3077 config.dynamic = dynamic;
3078 config.shapeMutable = shape_mutable;
3079 config.rearrange = rearrange;
3080 config.backend = &backend_info;
3081
3082 auto converted_file_name = convertBytesEncodeIfNeed(file_name);
3083 auto m_ptr = Module::load(inputs, outputs, converted_file_name.data(), &config);
3084 if (m_ptr == nullptr) {
3085 std::string mnn_errno = "load_module_from_file failed ";
3086 mnn_errno = mnn_errno + std::string(file_name);
3087 PyErr_SetString(PyExc_Exception, mnn_errno.c_str());
3088 }
3089 return m_ptr;
3090 });
3091
3092 #ifdef PYMNN_TRAIN_API
// CNN
// conv(in_channels, out_channels, kernel_size, stride=[1,1], padding=[0,0], dilation=[1,1],
//      depthwise=False, bias=True, padding_mode=VALID): builds an NN::Conv module.
// An empty stride/padding/dilation list leaves that NN::ConvOption field at its default.
nn_module.def("conv", [](int in_channel, int out_channel, INTS kernel_size, INTS stride, INTS padding,
                         INTS dilation, bool depthwise, bool bias, PaddingMode padding_mode) {
    NN::ConvOption option;
    option.channel = {in_channel, out_channel};
    option.kernelSize = kernel_size;
    if (!stride.empty()) {
        option.stride = stride;
    }
    option.padMode = padding_mode;
    if (!padding.empty()) {
        option.pads = padding;
    }
    if (!dilation.empty()) {
        option.dilate = dilation;
    }
    option.depthwise = depthwise;
    return NN::Conv(std::move(option), bias);
},
    py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_size"),
    py::arg("stride") = std::vector<int>({1, 1}),
    py::arg("padding") = std::vector<int>({0, 0}),
    py::arg("dilation") = std::vector<int>({1, 1}),
    py::arg("depthwise") = false,
    py::arg("bias") = true,
    py::arg("padding_mode") = PaddingMode::VALID
);
// linear(in_channels, out_channels, bias=True): builds an NN::Linear module.
nn_module.def("linear", [](int in_channel, int out_channel, bool bias) {
    return NN::Linear(in_channel, out_channel, bias);
},
    py::arg("in_channels"),
    py::arg("out_channels"),
    py::arg("bias") = true
);
nn_module.def("batch_norm", &NN::BatchNorm, py::arg("channels"), py::arg("dims") = 4, py::arg("momentum") = 0.99, py::arg("epsilon") = 1e-5);
nn_module.def("dropout", &NN::Dropout, py::arg("dropout_ratio"));
3129
auto optim_module = py_module.def_submodule("_optim");

{
    py::enum_<ParameterOptimizer::RegularizationMethod>(optim_module, "Regularization_Method")
        .value("L1", ParameterOptimizer::RegularizationMethod::L1)
        .value("L2", ParameterOptimizer::RegularizationMethod::L2)
        .value("L1L2", ParameterOptimizer::RegularizationMethod::L1L2)
        .export_values();

    // _Optimizer: Python view of a ParameterOptimizer created by SGD(...) or ADAM(...).
    // NOTE(review): the property accessors use unchecked C-style downcasts. Every
    // optimizer is treated as SGD, which presumably relies on ADAM deriving from SGD
    // (confirm). momentum2/eps cast to ADAM, so touching them on an SGD-created
    // optimizer is undefined behavior. A checked cast with a Python error would be safer.
    py::class_<ParameterOptimizer>(optim_module, "_Optimizer")
        .def_property("learning_rate", [](ParameterOptimizer* self) {
                return ((SGD*)self)->currentLearningRate();
            },
            [](ParameterOptimizer* self, float lr) {
                ((SGD*)self)->setLearningRate(lr);
            }
        )
        .def_property("momentum", [](ParameterOptimizer* self) {
                return ((SGD*)self)->getMomentum();
            },
            [](ParameterOptimizer* self, float m) {
                ((SGD*)self)->setMomentum(m);
            }
        )
        // ADAM-only.
        .def_property("momentum2", [](ParameterOptimizer* self) {
                return ((ADAM*)self)->getMomentum2();
            },
            [](ParameterOptimizer* self, float m) {
                ((ADAM*)self)->setMomentum2(m);
            }
        )
        .def_property("weight_decay", [](ParameterOptimizer* self) {
                return ((SGD*)self)->getWeightDecay();
            },
            [](ParameterOptimizer* self, float decay) {
                ((SGD*)self)->setWeightDecay(decay);
            }
        )
        // ADAM-only.
        .def_property("eps", [](ParameterOptimizer* self) {
                return ((ADAM*)self)->getEps();
            },
            [](ParameterOptimizer* self, float eps) {
                ((ADAM*)self)->setEps(eps);
            }
        )
        .def_property("regularization_method", [](ParameterOptimizer* self) {
                return ((SGD*)self)->getRegularizationMethod();
            },
            [](ParameterOptimizer* self, ParameterOptimizer::RegularizationMethod method) {
                ((SGD*)self)->setRegularizationMethod(method);
            }
        )
        // step(loss): forwards to ParameterOptimizer::step.
        .def("step", [](ParameterOptimizer* self, Express::VARP loss) {
            return self->step(loss);
        })
        ;

    // Factory functions returning a configured optimizer for `module`.
    optim_module.def("SGD", &ParameterOptimizer::createSGD,
        py::arg("module"),
        py::arg("learning_rate"), py::arg("momentum") = 0.9, py::arg("weight_decay") = 0,
        py::arg("regularization_method") = ParameterOptimizer::RegularizationMethod::L2);
    optim_module.def("ADAM", &ParameterOptimizer::createADAM,
        py::arg("module"),
        py::arg("learning_rate") = 1e-3, py::arg("momentum") = 0.9, py::arg("momentum2") = 0.999,
        py::arg("weight_decay") = 0.0, py::arg("eps") = 1e-8,
        py::arg("regularization_method") = ParameterOptimizer::RegularizationMethod::L2);
}
3197
3198
3199 {
3200 class PyDataset : public Dataset {
3201 public:
3202 using Dataset::Dataset;
3203
3204 virtual Example get(size_t index) override {
3205 PYBIND11_OVERLOAD_PURE(Example, Dataset, __getitem__, index);
3206 }
3207 virtual size_t size() override {
3208 PYBIND11_OVERLOAD_PURE(size_t, Dataset, __len__);
3209 }
3210 };
3211
3212 auto data_module = py_module.def_submodule("_data");
3213 py::class_<Dataset, PyDataset, std::shared_ptr<Dataset>>(data_module, "Dataset")
3214 .def(py::init())
3215 .def("__getitem__", &Dataset::get, py::arg("index"))
3216 .def("__len__", &Dataset::size)
3217 ;
3218
3219 py::class_<DataLoader>(data_module, "DataLoader")
3220 .def(py::init([](std::shared_ptr<Dataset> dataset, const int batchsize, const bool shuffle, const int numWorkers) {
3221 bool stack = true;
3222 //TODO:hardcode numworkers as 0, as to enable workers, we need gil, in private pybind, gil is removed.
3223 return DataLoader::makeDataLoader(dataset, batchsize, stack, shuffle, 0);
3224 }), py::arg("dataset"), py::arg("batch_size"), py::arg("shuffle") = true, py::arg("num_workers") = 0)
3225 .def_property_readonly("iter_number", &DataLoader::iterNumber)
3226 .def_property_readonly("size", &DataLoader::size)
3227 .def("reset", &DataLoader::reset)
3228 .def("next", [](DataLoader* self) {
3229 return self->next()[0]; // since we always stack
3230 })
3231 ;
3232 }
3233
3234 {
3235 // Loss
3236 auto loss_module = nn_module.def_submodule("loss");
3237 loss_module.def("cross_entropy", _CrossEntropy, py::arg("predicts"), py::arg("one_hot_targets"));
3238 loss_module.def("kl", _KLDivergence, py::arg("predicts"), py::arg("one_hot_targets"));
3239 loss_module.def("mse", _MSE, py::arg("predicts"), py::arg("one_hot_targets"));
3240 loss_module.def("mae", _MAE, py::arg("predicts"), py::arg("one_hot_targets"));
3241 loss_module.def("hinge", _Hinge, py::arg("predicts"), py::arg("one_hot_targets"));
3242 }
3243
3244 {
3245 auto compress_module = nn_module.def_submodule("compress");
3246 py::enum_<NN::FeatureScaleStatMethod>(compress_module, "Feature_Scale_Method")
3247 .value("PER_TENSOR", NN::PerTensor)
3248 .value("PER_CHANNEL", NN::PerChannel)
3249 .export_values();
3250 py::enum_<NN::ScaleUpdateMethod>(compress_module, "Scale_Update_Method")
3251 .value("MAXIMUM", NN::Maximum)
3252 .value("MOVING_AVERAGE", NN::MovingAverage)
3253 .export_values();
3254 compress_module.def("train_quant", &NN::turnQuantize,
3255 py::arg("module"),
3256 py::arg("quant_bits") = 8,
3257 py::arg("feature_scale_method") = NN::FeatureScaleStatMethod::PerTensor,
3258 py::arg("scale_update_method") = NN::ScaleUpdateMethod::MovingAverage);
3259 }
3260 // End of Train
3261 #endif
3262 #endif
3263 #if PY_MAJOR_VERSION >= 3
3264 return m;
3265 #else
3266 return;
3267 #endif
3268 }
3269
// MNNPyBridge invokes loadMNN from a static initializer on Windows / Linux / Mac / Android.
// NOTE(review): Apple's <TargetConditionals.h> defines TARGET_OS_IOS on macOS too (as 0).
// If that header reaches this translation unit, `!defined(TARGET_OS_IOS)` also excludes
// macOS, contrary to the comment above — confirm.
#if defined(PYMNN_USE_ALINNPYTHON) && !defined(TARGET_OS_IOS)
// Guarantees the module init function is appended to the inittab at most once.
static std::once_flag mLoadFlag2;
// Declared (extern "C" PYMNN_PUBLIC) in MNNPyBridge
void loadMNN() {
    auto appendInittab = []() {
        WeImport_AppendInittab(MOD_NAME, MOD_INIT_FUNC);
    };
    std::call_once(mLoadFlag2, appendInittab);
}
// Calls loadMNN() when this translation unit's static variables are initialized.
static auto registerMNN = [] {
    loadMNN();
    return true;
}();
#endif
3284