/*
    MNN python module
    PYMNN_EXPR_API: MNN.expr, MNN.nn
    PYMNN_TRAIN_API: MNN.nn.compress, MNN.nn.loss, MNN.data, MNN.optim
*/
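// Overview: this file implements the CPython extension types exposed by the
// MNN wheel (MNN.Interpreter, MNN.Session, MNN.Tensor, MNN.CVImageProcess,
// MNN.CVMatrix, MNN.OpInfo). Illustrative Python-level usage based on the
// methods registered below (the model path is a placeholder):
//
//   import MNN
//   interpreter = MNN.Interpreter("model.mnn")
//   session = interpreter.createSession({"backend": "CPU", "numThread": 4})
//   input_tensor = interpreter.getSessionInput(session)
//   interpreter.runSession(session)
//   output_tensor = interpreter.getSessionOutput(session)
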
#include "MNNPyBridge.h"
#include "common.h"
#include "util.h"

static int tls_key = 0;
static int tls_key_2 = 0;

#ifdef PYMNN_EXPR_API
#ifdef PYMNN_USE_ALINNPYTHON
#include "pybind_private/pybind11.h"
#include "pybind_private/stl.h"
#include "pybind_private/operators.h"
#else
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/operators.h"
#endif // PYMNN_USE_ALINNPYTHON
#endif // PYMNN_EXPR_API

#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>
#ifdef PYMNN_EXPR_API
namespace py = pybind11;
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Executor.hpp>
//#include <MNN/expr/ExecutorScope.hpp>
#include <MNN/expr/Module.hpp>
using namespace MNN::Express;
#endif // PYMNN_EXPR_API

#ifdef BUILD_OPTYPE
#include "MNN_generated.h"
#endif // BUILD_OPTYPE

#ifdef PYMNN_TRAIN_API
#include "NN.hpp"
#include "OpGrad.hpp"
#include "ParameterOptimizer.hpp"
#include "SGD.hpp"
#include "ADAM.hpp"
#include "Dataset.hpp"
#include "DataLoader.hpp"
#include "Loss.hpp"
#include "Transformer.hpp"
#include "PipelineModule.hpp"
#include "cpp/ConvertToFullQuant.hpp"
using namespace MNN::Train;
#endif // PYMNN_TRAIN_API

#include <mutex>
#include <unordered_map>

using namespace MNN;

using namespace std;

struct MNN_TLSData {
    PyObject *PyMNNHalideTypeInt = NULL;
    PyObject *PyMNNHalideTypeInt64 = NULL;
    PyObject *PyMNNHalideTypeFloat = NULL;
    PyObject *PyMNNHalideTypeDouble = NULL;
    PyObject *PyMNNHalideTypeUint8 = NULL;
    PyObject *PyMNNHalideTypeString = NULL;
    std::unordered_map<std::string, Interpreter *> *interpreterMap = NULL;
    std::unordered_map<std::string, Session *> *sessionCacheMap = NULL;
};
static MNN_TLSData* old_python_data = NULL;
static MNN_TLSData * getTLSData() {
    if(global_new_python_flag > 0) {
        return static_cast<MNN_TLSData*>(PyThread_get_key_value(tls_key));
    }else{
        return old_python_data;
    }
}
static void setTLSData(MNN_TLSData* tlsData) {
    if(global_new_python_flag > 0) {
        PyThread_set_key_value(tls_key, tlsData);
    } else {
        old_python_data = tlsData;
    }
}
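
// Both accessors above dispatch on global_new_python_flag: when the embedded
// AliNNPython runtime is active (flag > 0) the per-interpreter state lives in
// thread-local storage under tls_key, otherwise a single static pointer
// (old_python_data) is used for the stock CPython runtime.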

#if defined(PYMNN_EXPR_API) && defined(PYMNN_USE_ALINNPYTHON)
static void set_rh_tls_data(py::detail::rh_tls* rh_tls) {
    if(global_new_python_flag > 0) {
        PyThread_set_key_value(tls_key_2, rh_tls);
    } else {
        py::detail::old_rh_tls_data = rh_tls;
    }
}
#endif

#ifdef PYMNN_EXPR_API
namespace py = pybind11;
#endif
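// Helper: import python module `name` and return a new reference to its
// attribute `symbol` (e.g. importName("MNN", "Tensor") fetches the MNN.Tensor
// class object). Returns NULL with a Python error set on failure.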
static PyObject *importName(const char *name, const char *symbol)
{
    PyObject *u_name, *module;
    u_name = PyUnicode_FromString(name);
    module = PyImport_Import(u_name);
    Py_XDECREF(u_name);
    if (!module) {
        return NULL;
    }
    return PyObject_GetAttrString(module, symbol);
}

typedef struct {
    PyObject_HEAD
    std::string *modelPath;
    Interpreter *interpreter;
} PyMNNInterpreter;

typedef struct {
    PyObject_HEAD
    std::string *modelPath;
    Session *session;
} PyMNNSession;

typedef struct {
    PyObject_HEAD
    Tensor *tensor;
    int owner;
} PyMNNTensor;

typedef struct {
    PyObject_HEAD
    CV::ImageProcess *imageProcess;
} PyMNNCVImageProcess;

typedef struct {
    PyObject_HEAD
    CV::Matrix *matrix;
} PyMNNCVMatrix;

typedef struct {
    PyObject_HEAD
    const OperatorInfo *opInfo;
} PyMNNOpInfo;
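
// Each wrapper struct above pairs a PyObject_HEAD with a raw MNN pointer.
// PyMNNTensor::owner records whether the tensor (and its host buffer) was
// allocated by the binding itself; only owned tensors are freed in
// PyMNNTensor_dealloc further below.
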
halide_type_t* httInt() {
    static halide_type_t httInt = halide_type_of<int>();
    return &httInt;
}

halide_type_t* httInt64() {
    static halide_type_t httInt64 = halide_type_of<int64_t>();
    return &httInt64;
}

halide_type_t* httFloat() {
    static halide_type_t httFloat = halide_type_of<float>();
    return &httFloat;
}

halide_type_t* httDouble() {
    static halide_type_t httDouble = halide_type_of<double>();
    return &httDouble;
}

halide_type_t* httUint8() {
    static halide_type_t httUint8 = halide_type_of<uint8_t>();
    return &httUint8;
}

halide_type_t* httString() {
    static halide_type_t httString = halide_type_t(halide_type_handle, sizeof(void*)*8);
    return &httString;
}
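
// The htt*() helpers return pointers to function-local static halide_type_t
// values, giving stable dtype descriptors (int/int64/float/double/uint8 and a
// handle type for strings) that the rest of the bridge can reference without
// re-constructing them.
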

/// MNN NetInstance Type
static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_resizeSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_resizeTensor(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSessionWithCallBack(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_runSessionWithCallBackInfo(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionInput(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionOutput(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionInputAll(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_getSessionOutputAll(PyMNNInterpreter *self, PyObject *args);
#ifndef PYMNN_USE_ALINNPYTHON
static PyObject* PyMNNInterpreter_setCacheFile(PyMNNInterpreter *self, PyObject *args);
#endif
static PyObject* PyMNNInterpreter_cache(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_removeCache(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter *self, PyObject *args);
static PyObject* PyMNNInterpreter_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static int PyMNNInterpreter_init(PyMNNInterpreter *self, PyObject *args, PyObject *kwds);
static void PyMNNInterpreter_dealloc(PyMNNInterpreter *);

static PyMethodDef PyMNNInterpreter_methods[] = {
    {"createSession", (PyCFunction)PyMNNInterpreter_createSession, METH_VARARGS, "create session"},
#ifndef PYMNN_USE_ALINNPYTHON
    {"setCacheFile", (PyCFunction)PyMNNInterpreter_setCacheFile, METH_VARARGS, "set cache file for create session"},
#endif
    {"resizeSession", (PyCFunction)PyMNNInterpreter_resizeSession, METH_VARARGS, "resize session"},
    {"runSession", (PyCFunction)PyMNNInterpreter_runSession, METH_VARARGS, "run session"},
    {"runSessionWithCallBack", (PyCFunction)PyMNNInterpreter_runSessionWithCallBack, METH_VARARGS, "run session with callback"},
    {"runSessionWithCallBackInfo", (PyCFunction)PyMNNInterpreter_runSessionWithCallBackInfo, METH_VARARGS, "run session with callback info"},
    {"getSessionOutput", (PyCFunction)PyMNNInterpreter_getSessionOutput, METH_VARARGS, "get session output"},
    {"getSessionInput", (PyCFunction)PyMNNInterpreter_getSessionInput, METH_VARARGS, "get session input"},
    {"getSessionOutputAll", (PyCFunction)PyMNNInterpreter_getSessionOutputAll, METH_VARARGS, "get session output all"},
    {"getSessionInputAll", (PyCFunction)PyMNNInterpreter_getSessionInputAll, METH_VARARGS, "get session input all"},
    {"resizeTensor", (PyCFunction)PyMNNInterpreter_resizeTensor, METH_VARARGS, "resize tensor"},
    {"cache", (PyCFunction)PyMNNInterpreter_cache, METH_VARARGS, "cache current net instance"},
    {"removeCache", (PyCFunction)PyMNNInterpreter_removeCache, METH_VARARGS, "remove cache with given path"},
    {"updateSessionToModel", (PyCFunction)PyMNNInterpreter_updateSessionToModel, METH_VARARGS, "updateSessionToModel"},
    {NULL}  /* Sentinel */
};

static PyTypeObject PyMNNInterpreterType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Interpreter",                        /*tp_name*/
    sizeof(PyMNNInterpreter),                 /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNInterpreter_dealloc,     /*tp_dealloc*/
    0,                                        /*tp_print*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Interpreter objects",                /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNInterpreter_methods,                 /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    (initproc)PyMNNInterpreter_init,          /* tp_init */
    0,                                        /* tp_alloc */
    PyMNNInterpreter_new,                     /* tp_new */
};

/// MNN Session Type
static PyObject* PyMNNSession_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNSession_dealloc(PyMNNSession *);
static PyObject* PyMNNSession_cache(PyMNNSession *self, PyObject *args);
static PyObject* PyMNNSession_removeCache(PyMNNSession *self, PyObject *args);

static PyMethodDef PyMNNSession_methods[] = {
    {"cache", (PyCFunction)PyMNNSession_cache, METH_VARARGS, "cache current session instance"},
    {"removeCache", (PyCFunction)PyMNNSession_removeCache, METH_VARARGS, "remove session cache with given path"},
    {NULL}  /* Sentinel */
};

static PyTypeObject PyMNNSessionType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Session",                            /*tp_name*/
    sizeof(PyMNNSession),                     /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNSession_dealloc,         /*tp_dealloc*/
    0,                                        /*tp_print*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Session objects",                    /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNSession_methods,                     /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    0,                                        /* tp_init */
    0,                                        /* tp_alloc */
    PyMNNSession_new,                         /* tp_new */
};

/// MNN Tensor Type
static PyObject* PyMNNTensor_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNTensor_dealloc(PyMNNTensor *);
static int PyMNNTensor_init(PyMNNTensor *self, PyObject *args, PyObject *kwds);
#ifdef PYMNN_NUMPY_USABLE
static PyObject* PyMNNTensor_fromNumpy(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getNumpyData(PyMNNTensor *self, PyObject *args);
#endif
static PyObject* PyMNNTensor_printTensorData(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getShape(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getDataType(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getDimensionType(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getData(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_getHost(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_copyFrom(PyMNNTensor *self, PyObject *args);
static PyObject* PyMNNTensor_copyToHostTensor(PyMNNTensor *self, PyObject *args);

static PyMethodDef PyMNNTensor_methods[] = {
#ifdef PYMNN_NUMPY_USABLE
    {"fromNumpy", (PyCFunction)PyMNNTensor_fromNumpy, METH_VARARGS, "copy data from numpy"},
    {"getNumpyData", (PyCFunction)PyMNNTensor_getNumpyData, METH_NOARGS, "get tensor data (numpy)"},
#endif
    {"printTensorData", (PyCFunction)PyMNNTensor_printTensorData, METH_NOARGS, "print tensor data"},
    {"getShape", (PyCFunction)PyMNNTensor_getShape, METH_NOARGS, "get tensor shape"},
    {"getDataType", (PyCFunction)PyMNNTensor_getDataType, METH_NOARGS, "get tensor data type"},
    {"getData", (PyCFunction)PyMNNTensor_getData, METH_NOARGS, "get tensor data (tuple)"},
    {"getHost", (PyCFunction)PyMNNTensor_getHost, METH_NOARGS, "get tensor host"},
    {"getDimensionType", (PyCFunction)PyMNNTensor_getDimensionType, METH_NOARGS, "get tensor dimension type"},
    {"copyFrom", (PyCFunction)PyMNNTensor_copyFrom, METH_VARARGS, "copy data from host tensor"},
    {"copyFromHostTensor", (PyCFunction)PyMNNTensor_copyFrom, METH_VARARGS, "copy data from host tensor"},
    {"copyToHostTensor", (PyCFunction)PyMNNTensor_copyToHostTensor, METH_VARARGS, "copy data to host tensor"},
    {NULL}  /* Sentinel */
};

static PyTypeObject PyMNNTensorType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.Tensor",                             /*tp_name*/
    sizeof(PyMNNTensor),                      /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNTensor_dealloc,          /*tp_dealloc*/
    0,                                        /*tp_print*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN Tensor objects",                     /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNTensor_methods,                      /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    (initproc)PyMNNTensor_init,               /* tp_init */
    0,                                        /* tp_alloc */
    PyMNNTensor_new,                          /* tp_new */
};

/// MNN ImageProcess Type
static PyObject* PyMNNCVImageProcess_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNCVImageProcess_dealloc(PyMNNCVImageProcess *);
static int PyMNNCVImageProcess_init(PyMNNCVImageProcess *self, PyObject *args, PyObject *kwds);
static PyObject* PyMNNCVImageProcess_setMatrix(PyMNNCVImageProcess *self, PyObject *args);
static PyObject* PyMNNCVImageProcess_convert(PyMNNCVImageProcess *self, PyObject *args);
static PyObject* PyMNNCVImageProcess_createImageTensor(PyMNNCVImageProcess *self, PyObject *args);

static PyMethodDef PyMNNCVImageProcess_methods[] = {
    {"setMatrix", (PyCFunction)PyMNNCVImageProcess_setMatrix, METH_VARARGS, "ImageProcess setMatrix"},
    {"convert", (PyCFunction)PyMNNCVImageProcess_convert, METH_VARARGS, "ImageProcess convert"},
    {"createImageTensor", (PyCFunction)PyMNNCVImageProcess_createImageTensor, METH_VARARGS, "ImageProcess create Image Tensor"},
    {NULL}  /* Sentinel */
};

static PyTypeObject PyMNNCVImageProcessType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.CVImageProcess",                     /*tp_name*/
    sizeof(PyMNNCVImageProcess),              /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNCVImageProcess_dealloc,  /*tp_dealloc*/
    0,                                        /*tp_print*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN CVImageProcess objects",             /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNCVImageProcess_methods,              /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    (initproc)PyMNNCVImageProcess_init,       /* tp_init */
    0,                                        /* tp_alloc */
    PyMNNCVImageProcess_new,                  /* tp_new */
};

/// MNN CVMatrix Type
static PyObject* PyMNNCVMatrix_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
static void PyMNNCVMatrix_dealloc(PyMNNCVMatrix *);
/// scale
static PyObject* PyMNNCVMatrix_setScale(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preScale(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postScale(PyMNNCVMatrix *, PyObject *args);
/// rotate
static PyObject* PyMNNCVMatrix_setRotate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preRotate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postRotate(PyMNNCVMatrix *, PyObject *args);
/// translate
static PyObject* PyMNNCVMatrix_setTranslate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_preTranslate(PyMNNCVMatrix *, PyObject *args);
static PyObject* PyMNNCVMatrix_postTranslate(PyMNNCVMatrix *, PyObject *args);

static PyObject* PyMNNCVMatrix_invert(PyMNNCVMatrix *);

static PyMethodDef PyMNNCVMatrix_methods[] = {
    {"setScale", (PyCFunction)PyMNNCVMatrix_setScale, METH_VARARGS, "MNNCVMatrix setScale"},
    {"preScale", (PyCFunction)PyMNNCVMatrix_preScale, METH_VARARGS, "MNNCVMatrix preScale"},
    {"postScale", (PyCFunction)PyMNNCVMatrix_postScale, METH_VARARGS, "MNNCVMatrix postScale"},

    {"setRotate", (PyCFunction)PyMNNCVMatrix_setRotate, METH_VARARGS, "MNNCVMatrix setRotate"},
    {"preRotate", (PyCFunction)PyMNNCVMatrix_preRotate, METH_VARARGS, "MNNCVMatrix preRotate"},
    {"postRotate", (PyCFunction)PyMNNCVMatrix_postRotate, METH_VARARGS, "MNNCVMatrix postRotate"},

    {"setTranslate", (PyCFunction)PyMNNCVMatrix_setTranslate, METH_VARARGS, "MNNCVMatrix setTranslate"},
    {"preTranslate", (PyCFunction)PyMNNCVMatrix_preTranslate, METH_VARARGS, "MNNCVMatrix preTranslate"},
    {"postTranslate", (PyCFunction)PyMNNCVMatrix_postTranslate, METH_VARARGS, "MNNCVMatrix postTranslate"},

    {"invert", (PyCFunction)PyMNNCVMatrix_invert, METH_VARARGS, "MNNCVMatrix invert"},
    {NULL}  /* Sentinel */
};

static PyTypeObject PyMNNCVMatrixType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "MNN.CVMatrix",                           /*tp_name*/
    sizeof(PyMNNCVMatrix),                    /*tp_basicsize*/
    0,                                        /*tp_itemsize*/
    (destructor)PyMNNCVMatrix_dealloc,        /*tp_dealloc*/
    0,                                        /*tp_print*/
    0,                                        /*tp_getattr*/
    0,                                        /*tp_setattr*/
    0,                                        /*tp_compare*/
    0,                                        /*tp_repr*/
    0,                                        /*tp_as_number*/
    0,                                        /*tp_as_sequence*/
    0,                                        /*tp_as_mapping*/
    0,                                        /*tp_hash */
    0,                                        /*tp_call*/
    0,                                        /*tp_str*/
    0,                                        /*tp_getattro*/
    0,                                        /*tp_setattro*/
    0,                                        /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
    "MNN CVMatrix objects",                   /* tp_doc */
    0,                                        /* tp_traverse */
    0,                                        /* tp_clear */
    0,                                        /* tp_richcompare */
    0,                                        /* tp_weaklistoffset */
    0,                                        /* tp_iter */
    0,                                        /* tp_iternext */
    PyMNNCVMatrix_methods,                    /* tp_methods */
    0,                                        /* tp_members */
    0,                                        /* tp_getset */
    0,                                        /* tp_base */
    0,                                        /* tp_dict */
    0,                                        /* tp_descr_get */
    0,                                        /* tp_descr_set */
    0,                                        /* tp_dictoffset */
    0,                                        /* tp_init */
    0,                                        /* tp_alloc */
    PyMNNCVMatrix_new,                        /* tp_new */
};

/// MNN NetInstance implementation
// Cache of net (Interpreter) instances, keyed by model path

std::unordered_map<std::string, Interpreter *> *interpreterMap() {
//    static std::unordered_map<std::string, Interpreter *> *interpreterMap = nullptr; // <path, instance>
//    static std::once_flag flag;
//    std::call_once(flag, [](){interpreterMap = new std::unordered_map<std::string, Interpreter *>();});
    struct MNN_TLSData *tlsData = getTLSData();
    if (tlsData == nullptr) {
        return nullptr;
    }
    return tlsData->interpreterMap;
}

std::unordered_map<std::string, Session *> *sessionCacheMap() {
    struct MNN_TLSData *tlsData = getTLSData();
    if (tlsData == nullptr) {
        return nullptr;
    }
    return tlsData->sessionCacheMap;
}
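
// Both maps live in the per-thread MNN_TLSData and are keyed by model path.
// Interpreter.cache()/removeCache() and Session.cache()/removeCache() below
// insert into / erase from these maps, so creating another Interpreter or
// Session for the same model can reuse the existing native instance.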

namespace ec {
    int getVectorByKey(PyObject* dict, const char *key, std::vector<std::string>& result){
        PyObject *saveTensors = PyDict_GetItemString(dict, key);
        int count = 0;
        if (saveTensors) {
            if (!PyTuple_Check(saveTensors)) {
                PyErr_SetString(PyExc_Exception,
                                "PyMNNInterpreter_createSession: saveTensors must be a tuple");
                return -1;
            }

            size_t saveTensorsCount = PyTuple_Size(saveTensors);
            for (size_t i = 0; i < saveTensorsCount; i++) {
                PyObject *tensorNameItem = PyTuple_GetItem(saveTensors, i);
                if (!checkString(tensorNameItem)) {
                    PyErr_SetString(PyExc_Exception,
                                    "PyMNNInterpreter_createSession: saveTensors's member must be string");
                    return -1;
                }

                result.push_back(object2String(tensorNameItem));
                count++;
            }
        }
        return count;
    }
}
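
// createSession() below accepts an optional Python dict of scheduling options.
// Recognized keys (see the parsing code that follows): "backend", "numThread",
// "precision", and the tuple-valued "saveTensors", "inputPaths", "outputPaths".
// Illustrative call (tensor names are placeholders):
//   session = interpreter.createSession({
//       "backend": "CPU",
//       "numThread": 4,
//       "precision": "low",
//       "saveTensors": ("conv1_out",),
//   })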

static PyObject* PyMNNInterpreter_createSession(PyMNNInterpreter *self, PyObject *args) {
    PyMNNInterpreter* instance = (PyMNNInterpreter *)self;
    PyObject* dict = NULL;
    if (!PyArg_ParseTuple(args, "|O", &dict)) {
        return NULL;
    }

    PyObject *f = importName("MNN", "Session");
    if (!f || !PyCallable_Check(f)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_createSession: MNN.Session not found");
        return NULL;
    }

    // create a new session
    PyMNNSession *session = (PyMNNSession *)PyObject_Call(f, PyTuple_New(0), NULL);
    if (!session) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_createSession: MNN.Session instance create failed");
        return NULL;
    }

    if (self->modelPath && (*sessionCacheMap())[*self->modelPath]) {
        session->modelPath = self->modelPath;
        session->session = (*sessionCacheMap())[*self->modelPath];
        return (PyObject *)session;
    }

    ScheduleConfig config;
    BackendConfig backendConfig;
    config.backendConfig = &backendConfig;
    if (dict) {
        PyObject *backend = PyDict_GetItemString(dict, "backend");
        config.type = MNN_FORWARD_CPU;
        if (backend) {
            auto backend_name = object2String(backend);
            // Reject backend names unknown to this bridge at the Python level, so the user
            // asks for the matching bridge/MNN library version instead of silently falling
            // back (same idea as the MNN.expr.Backend.* python enum).
            std::unordered_map<std::string, MNNForwardType> backend_map = {
                // Every name maps to a forward type regardless of whether the linked MNN
                // library actually supports that backend; this keeps the MNN.whl setup.py simple.
                {"CPU", MNN_FORWARD_CPU},
                {"OPENCL", MNN_FORWARD_OPENCL},
                {"OPENGL", MNN_FORWARD_OPENGL},
                {"VULKAN", MNN_FORWARD_VULKAN},
                {"METAL", MNN_FORWARD_METAL},
                {"TRT", MNN_FORWARD_USER_1},
                {"CUDA", MNN_FORWARD_CUDA},
                {"HIAI", MNN_FORWARD_USER_0}
            };
            auto iter = backend_map.find(backend_name);
            if (iter == backend_map.end()) {
                // unknown backend name: surface the problem at the Python level during development
                PyErr_SetString(PyExc_Exception,
                                "PyMNNInterpreter_createSession: backend not supported");
                return NULL;
            }
            config.type = iter->second;
        }
        if(config.type == MNN_FORWARD_CPU) {
            PyObject *numThread = PyDict_GetItemString(dict, "numThread");
            if (numThread) {
                if (!PyLong_Check(numThread)) {
                    PyErr_SetString(PyExc_Exception,
                                    "PyMNNInterpreter_createSession: numThread must be an integer");
                    return NULL;
                }
                config.numThread = (int)PyLong_AsLong(numThread);
            }
        }

        {
            //precision
            PyObject *obj = PyDict_GetItemString(dict, "precision");
            if (obj) {
                auto obj_name = object2String(obj);
                if (!obj_name.compare("low")) {
                    MNN_PRINT("MNN use low precision\n");
                    backendConfig.precision = MNN::BackendConfig::Precision_Low;
                }
            }
        }

        if (-1 == ec::getVectorByKey(dict, "saveTensors", config.saveTensors)
            || -1 == ec::getVectorByKey(dict, "inputPaths", config.path.inputs)
            || -1 == ec::getVectorByKey(dict, "outputPaths", config.path.outputs)){
            return NULL;
        }

    }

    Session *s = instance->interpreter->createSession(config);
    if (!s) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_createSession: NetInstance createSession failed");
        return NULL;
    }

    session->session = s;
    session->modelPath = instance->modelPath;

    return (PyObject *)session;
}

static PyObject* PyMNNInterpreter_resizeSession(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    if (!PyArg_ParseTuple(args, "O", &session)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_resizeSession: First argument is not a MNN.Session instance");
        return NULL;
    }

    self->interpreter->resizeSession(session->session);
    Py_RETURN_TRUE;
}

static PyObject* PyMNNInterpreter_resizeTensor(PyMNNInterpreter *self, PyObject *args) {
    PyMNNTensor* tensor = NULL;
    PyObject* shape = NULL;
    if (!PyArg_ParseTuple(args, "OO", &tensor, &shape)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(tensor, PyType_FindTLSType(&PyMNNTensorType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_resizeTensor: First argument is not a MNN.Tensor instance");
        return NULL;
    }

    if (!PyTuple_Check(shape)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_resizeTensor: Second argument is not a tuple");
        return NULL;
    }

    size_t shapeSize = PyTuple_Size(shape);

    std::vector<int> vShape;
    for (size_t i = 0; i < shapeSize; i++) {
        int shapeItem = (int)PyLong_AsLong(PyTuple_GetItem(shape, i));
        vShape.push_back(shapeItem);
    }

    self->interpreter->resizeTensor(tensor->tensor, vShape);
    Py_RETURN_NONE;
}
#ifndef PYMNN_USE_ALINNPYTHON
static PyObject* PyMNNInterpreter_setCacheFile(PyMNNInterpreter *self, PyObject *args) {
    char *path = NULL;
    if (!PyArg_ParseTuple(args, "s", &path)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_setCacheFile: Not string input");
        return NULL;
    }
    Py_BEGIN_ALLOW_THREADS
    self->interpreter->setCacheFile(path);
    Py_END_ALLOW_THREADS
    Py_RETURN_NONE;
}
#endif

static PyObject* PyMNNInterpreter_runSession(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    if (!args) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSession: No argument passed, expect 1");
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "O", &session)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSession: First argument is not a MNN.Session instance");
        return NULL;
    }
    ErrorCode r = NO_ERROR;
    Py_BEGIN_ALLOW_THREADS
    r = self->interpreter->runSession(session->session);
    Py_END_ALLOW_THREADS
    return PyLong_FromLong(r);
}
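
// runSessionWithCallBack wraps the optional Python callables in MNN
// TensorCallBack lambdas; each callback receives (tuple of MNN.Tensor, op
// name) and its return value is converted to int and forwarded to MNN.
// Illustrative Python usage:
//   def begin(tensors, name): return True
//   def end(tensors, name): return True
//   interpreter.runSessionWithCallBack(session, begin, end)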
static PyObject* PyMNNInterpreter_runSessionWithCallBack(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    PyObject *beginCallback = NULL;
    PyObject *endCallback = NULL;
    if (!args) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBack: No argument passed, expect 1 or 3");
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "O|OO", &session, &beginCallback, &endCallback)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBack: First argument is not a MNN.Session instance");
        return NULL;
    }

    TensorCallBack begin = [beginCallback](const std::vector<Tensor*>& tensors, const std::string& name){
        if (!beginCallback || !PyCallable_Check(beginCallback)) {
            return true;
        }

        PyObject *f = importName("MNN", "Tensor");
        if (!f || !PyCallable_Check(f)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBack: MNN.Tensor not found");
            return true;
        }

        PyObject *args = PyTuple_New(2);
        size_t size_tensors = tensors.size();
        PyObject *weTensorData = PyTuple_New(size_tensors);
        for (size_t i = 0; i < size_tensors; i++) {
            // create a new tensor
            PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
            if (!tensor) {
                PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBack: create Tensor failed");
                return true;
            }
            tensor->tensor = tensors[i];
            PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
        }
        //printf("begincallback name=%s\n",name.c_str());
        PyObject *weStringData = char2Object(name.c_str());
        PyTuple_SetItem(args, 0, weTensorData);
        PyTuple_SetItem(args, 1, weStringData);
        bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(beginCallback, args, NULL)));
        Py_XDECREF(args);//del all the C++ created python api parameters
        return ret;
    };
    TensorCallBack end = [endCallback](const std::vector<Tensor*>& tensors, const std::string& name){
        if (!endCallback || !PyCallable_Check(endCallback)) {
            return true;
        }
        PyObject *f = importName("MNN", "Tensor");
        if (!f || !PyCallable_Check(f)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBack: MNN.Tensor not found");
            return true;
        }
        PyObject *args = PyTuple_New(2);
        size_t size_tensors = tensors.size();
        PyObject *weTensorData = PyTuple_New(size_tensors);
        for (size_t i = 0; i < size_tensors; i++) {
            // create a new tensor
            PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
            if (!tensor) {
                PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBack: create Tensor failed");
                return true;
            }
            tensor->tensor = tensors[i];
            PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
        }
        PyObject *weStringData = char2Object(name.c_str());
        PyTuple_SetItem(args, 0, weTensorData);
        PyTuple_SetItem(args, 1, weStringData);
        bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(endCallback, args, NULL)));
        Py_XDECREF(args);//del all the C++ created python api parameters
        return ret;
    };

    ErrorCode r = NO_ERROR;
    //Py_BEGIN_ALLOW_THREADS
    r = self->interpreter->runSessionWithCallBack(session->session, begin, end);
    //Py_END_ALLOW_THREADS
    return PyLong_FromLong(r);
}

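// runSessionWithCallBackInfo is identical in structure, except the second
// argument handed to the Python callbacks is an MNN.OpInfo object instead of
// the op name string.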
static PyObject* PyMNNInterpreter_runSessionWithCallBackInfo(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    PyObject *beginCallback = NULL;
    PyObject *endCallback = NULL;
    if (!args) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBackInfo: No argument passed, expect 1 or 3");
        return NULL;
    }

    if (!PyArg_ParseTuple(args, "O|OO", &session, &beginCallback, &endCallback)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBackInfo: First argument is not a MNN.Session instance");
        return NULL;
    }

    TensorCallBackWithInfo begin = [beginCallback](const std::vector<Tensor*>& tensors, const OperatorInfo* info){
        if (!beginCallback || !PyCallable_Check(beginCallback)) {
            return true;
        }

        PyObject *ftensor = importName("MNN", "Tensor");
        PyObject *finfo = importName("MNN", "OpInfo");
        if (!ftensor || !PyCallable_Check(ftensor)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.Tensor not found");
            return true;
        }
        if (!finfo || !PyCallable_Check(finfo)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.OpInfo not found");
            return true;
        }

        PyObject *args = PyTuple_New(2);
        size_t size_tensors = tensors.size();
        PyObject *weTensorData = PyTuple_New(size_tensors);
        for (size_t i = 0; i < size_tensors; i++) {
            // create a new tensor
            PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(ftensor, PyTuple_New(0), NULL);
            if (!tensor) {
                PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBackInfo: create Tensor failed");
                return true;
            }
            tensor->tensor = tensors[i];
            PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
        }
        PyMNNOpInfo *pyinfo = (PyMNNOpInfo *)PyObject_Call(finfo,PyTuple_New(0), NULL);
        if(!pyinfo){
            PyErr_SetString(PyExc_Exception,
                    "PyMNNInterpreter_runSessionWithCallBackInfo: create OpInfo failed");
            return true;
        }
        pyinfo->opInfo = info;
        PyTuple_SetItem(args, 0, weTensorData);
        PyTuple_SetItem(args, 1, (PyObject *)pyinfo);
        bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(beginCallback, args, NULL)));
        Py_XDECREF(args);//del all the C++ created python api parameters
        return ret;
    };
    TensorCallBackWithInfo end = [endCallback](const std::vector<Tensor*>& tensors, const OperatorInfo* info){
        if (!endCallback || !PyCallable_Check(endCallback)) {
            return true;
        }
        PyObject *ftensor = importName("MNN", "Tensor");
        PyObject *finfo = importName("MNN", "OpInfo");
        if (!ftensor || !PyCallable_Check(ftensor)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.Tensor not found");
            return true;
        }
        if (!finfo || !PyCallable_Check(finfo)) {
            PyErr_SetString(PyExc_Exception,
                            "PyMNNInterpreter_runSessionWithCallBackInfo: MNN.OpInfo not found");
            return true;
        }
        PyObject *args = PyTuple_New(2);
        size_t size_tensors = tensors.size();
        PyObject *weTensorData = PyTuple_New(size_tensors);
        for (size_t i = 0; i < size_tensors; i++) {
            // create a new tensor
            PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(ftensor, PyTuple_New(0), NULL);
            if (!tensor) {
                PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_runSessionWithCallBackInfo: create Tensor failed");
                return true;
            }
            tensor->tensor = tensors[i];
            PyTuple_SetItem(weTensorData, i, (PyObject *)tensor);
        }
        PyMNNOpInfo *pyinfo = (PyMNNOpInfo *)PyObject_Call(finfo,PyTuple_New(0), NULL);
        if(!pyinfo){
            PyErr_SetString(PyExc_Exception,
                    "PyMNNInterpreter_runSessionWithCallBackInfo: create OpInfo failed");
            return true;
        }
        pyinfo->opInfo = info;
        PyTuple_SetItem(args, 0, weTensorData);
        PyTuple_SetItem(args, 1, (PyObject *)pyinfo);
        bool ret = static_cast<bool>(PyLong_AsLong(PyObject_Call(endCallback, args, NULL)));
        Py_XDECREF(args);//del all the C++ created python api parameters
        return ret;
    };

    ErrorCode r = NO_ERROR;
    //Py_BEGIN_ALLOW_THREADS
    r = self->interpreter->runSessionWithCallBackInfo(session->session, begin, end);
    //Py_END_ALLOW_THREADS
    return PyLong_FromLong(r);
}

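// getSessionOutput/getSessionInput wrap the Tensor* returned by the
// interpreter in a new MNN.Tensor object without taking ownership
// (tensor->owner stays 0), so PyMNNTensor_dealloc will not free the
// underlying buffer owned by the session.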
static PyObject* PyMNNInterpreter_getSessionOutput(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    char* name = NULL;
    if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionOutput: First argument is not a MNN.Session instance");
        return NULL;
    }

    Tensor *t = self->interpreter->getSessionOutput(session->session, name);
    if (!t) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionOutput: Get output failed");
        return NULL;
    }

    PyObject *f = importName("MNN", "Tensor");
    if (!f || !PyCallable_Check(f)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionOutput: MNN.Tensor not found");
        return NULL;
    }

    // create a new tensor
    PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
    if (!tensor) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionOutput: MNN.Tensor instance create failed");
        return NULL;
    }

    tensor->tensor = t;
    return (PyObject *)tensor;
}

static PyObject* PyMNNInterpreter_getSessionInput(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    char* name = NULL;
    if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
        return NULL;
    }

    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionInput: First argument is not a MNN.Session instance");
        return NULL;
    }

    Tensor *t = self->interpreter->getSessionInput(session->session, name);
    if (!t) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionInput: Get input failed");
        return NULL;
    }

    PyObject *f = importName("MNN", "Tensor");
    if (!f || !PyCallable_Check(f)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionInput: MNN.Tensor not found");
        return NULL;
    }

    // create a new tensor
    PyMNNTensor *tensor = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
    if (!tensor) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_getSessionInput: MNN.Tensor instance create failed");
        return NULL;
    }

    tensor->tensor = t;
    return (PyObject *)tensor;
}

static PyObject* PyMNNInterpreter_getSessionOutputAll(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    if (!PyArg_ParseTuple(args, "O", &session)) {
        return NULL;
    }
    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: First argument is not a MNN.Session instance");
        return NULL;
    }
    PyObject *f = importName("MNN", "Tensor");
    if (!f || !PyCallable_Check(f)) {
        PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: MNN.Tensor not found");
        return NULL;
    }
    auto map = self->interpreter->getSessionOutputAll(session->session);
    PyObject* output = PyDict_New();
    for (auto it=map.begin(); it!=map.end(); ++it) {
        PyObject *tensor = PyObject_Call(f, PyTuple_New(0), NULL);
        if (!tensor) {
            PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionOutputAll: MNN.Tensor instance create failed");
            return NULL;
        }
        ((PyMNNTensor*)tensor)->tensor = it->second;
        PyDict_SetItem(output, char2Object(it->first.c_str()), tensor);
    }
    return output;
}

static PyObject* PyMNNInterpreter_getSessionInputAll(PyMNNInterpreter *self, PyObject *args) {
    PyMNNSession* session = NULL;
    if (!PyArg_ParseTuple(args, "O", &session)) {
        return NULL;
    }
    if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
        PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: First argument is not a MNN.Session instance");
        return NULL;
    }
    PyObject *f = importName("MNN", "Tensor");
    if (!f || !PyCallable_Check(f)) {
        PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: MNN.Tensor not found");
        return NULL;
    }
    auto map = self->interpreter->getSessionInputAll(session->session);
    PyObject* output = PyDict_New();
    for (auto it=map.begin(); it!=map.end(); ++it) {
        PyObject *tensor = PyObject_Call(f, PyTuple_New(0), NULL);
        if (!tensor) {
            PyErr_SetString(PyExc_Exception,"PyMNNInterpreter_getSessionInputAll: MNN.Tensor instance create failed");
            return NULL;
        }
        ((PyMNNTensor*)tensor)->tensor = it->second;
        PyDict_SetItem(output, char2Object(it->first.c_str()), tensor);
    }
    return output;
}

PyObject* PyMNNInterpreter_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
    PyMNNInterpreter* self = (PyMNNInterpreter *)type->tp_alloc(type, 0);
    return (PyObject*)self;
}

static int PyMNNInterpreter_init(PyMNNInterpreter *self, PyObject *args, PyObject *kwds) {
    char *path = NULL;
    if (!PyArg_ParseTuple(args, "s", &path)) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_new: PyArg_ParseTuple failed");
        return -1;
    }
    auto converted_path = convertBytesEncodeIfNeed(path);
    self->modelPath = new std::string(converted_path.data());
    if (!self->modelPath) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_new: create modelPath string failed");
        return -1;
    }

    if ((*interpreterMap())[*self->modelPath]) {
        self->interpreter = (*interpreterMap())[*self->modelPath];
    } else {
        self->interpreter = Interpreter::createFromFile(path);
    }
    if (!self->interpreter) {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNInterpreter_new: NetInstance::createFromFile failed");
        return -1;
    }

    return 0;
}

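// cache()/removeCache() manage the per-thread interpreterMap declared above:
// cache() registers this interpreter under its model path so later
// MNN.Interpreter(path) calls reuse it; removeCache() erases the map entry but
// leaves the native Interpreter alive (the delete is commented out).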
PyMNNInterpreter_cache(PyMNNInterpreter * self,PyObject * args)1140 static PyObject* PyMNNInterpreter_cache(PyMNNInterpreter *self, PyObject *args) {
1141     if (self->modelPath && !(*interpreterMap())[*self->modelPath]) {
1142         (*interpreterMap())[*self->modelPath] = self->interpreter;
1143     }
1144     Py_RETURN_NONE;
1145 }
1146 
PyMNNInterpreter_removeCache(PyMNNInterpreter * self,PyObject * args)1147 static PyObject* PyMNNInterpreter_removeCache(PyMNNInterpreter *self, PyObject *args) {
1148     if (!self->modelPath) {
1149         Py_RETURN_NONE;
1150     }
1151     Interpreter* net = (*interpreterMap())[*self->modelPath];
1152     if (net) {
1153         interpreterMap()->erase(*self->modelPath);
1154         //delete net;
1155     }
1156     Py_RETURN_NONE;
1157 }
1158 
PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter * self,PyObject * args)1159 static PyObject* PyMNNInterpreter_updateSessionToModel(PyMNNInterpreter *self, PyObject *args) {
1160     PyMNNSession* session = NULL;
1161     char* name = NULL;
1162     if (!PyArg_ParseTuple(args, "O|s", &session, &name)) {
1163         return NULL;
1164     }
1165 
1166     if (!PyObject_TypeCheck(session, PyType_FindTLSType(&PyMNNSessionType))) {
1167         PyErr_SetString(PyExc_Exception,
1168                         "PyMNNInterpreter_updateSessionToModel: First argument is not a MNN.Session instance");
1169         return NULL;
1170     }
1171 
1172     self->interpreter->updateSessionToModel(session->session);
1173     if(name){
1174         auto modelBuffer = self->interpreter->getModelBuffer();
1175         ofstream output(name);
1176         output.write((const char*)modelBuffer.first, modelBuffer.second);
1177     }
1178     Py_RETURN_NONE;
1179 }
1180 
PyMNNInterpreter_dealloc(PyMNNInterpreter * self)1181 static void PyMNNInterpreter_dealloc(PyMNNInterpreter *self) {
1182     if (!self->modelPath) {
1183         return;
1184     }
1185     Interpreter* net = (*interpreterMap())[*self->modelPath];
1186     // If the interpreter is not in the cache, release the instance here
1187     if (!net && self->interpreter) {
1188         delete self->interpreter;
1189         self->interpreter = NULL;
1190     }
1191     delete self->modelPath;
1192     Py_TYPE(self)->tp_free((PyObject*)self);
1193 }
1194 
1195 /// MNN Session implementation
1196 static PyObject* PyMNNSession_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1197     PyMNNSession* self = (PyMNNSession *)type->tp_alloc(type, 0);
1198     return (PyObject*)self;
1199 }
1200 
1201 static void PyMNNSession_dealloc(PyMNNSession *self) {
1202     self->session = NULL;
1203     Py_TYPE(self)->tp_free((PyObject*)self);
1204 }
1205 
1206 // cache session
1207 static PyObject* PyMNNSession_cache(PyMNNSession *self, PyObject *args) {
1208     if (!self->modelPath) {
1209         Py_RETURN_NONE;
1210     }
1211     if (!(*sessionCacheMap())[*self->modelPath]) {
1212         (*sessionCacheMap())[*self->modelPath] = self->session;
1213     }
1214     Py_RETURN_NONE;
1215 }
1216 
1217 static PyObject* PyMNNSession_removeCache(PyMNNSession *self, PyObject *args) {
1218     if (!self->modelPath) {
1219         Py_RETURN_NONE;
1220     }
1221     Session* s = (*sessionCacheMap())[*self->modelPath];
1222     if (s) {
1223         sessionCacheMap()->erase(*self->modelPath);
1224     }
1225     Py_RETURN_NONE;
1226 }
1227 
1228 /// MNN Tensor implementation
1229 static PyObject* PyMNNTensor_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1230     PyMNNTensor* self = (PyMNNTensor *)type->tp_alloc(type, 0);
1231     return (PyObject*)self;
1232 }
1233 
1234 static void PyMNNTensor_dealloc(PyMNNTensor *self) {
1235     if (self->owner) {
1236         if (self->tensor->host<void *>()) {
1237             free(self->tensor->host<void *>());
1238         }
1239         delete self->tensor;
1240     }
1241     Py_TYPE(self)->tp_free((PyObject*)self);
1242 }
1243 
1244 static int PyMNNTensor_init(PyMNNTensor *self, PyObject *args, PyObject *kwds) {
1245     if (!PyTuple_Size(args)) {
1246         return 0;
1247     }
1248 
1249     PyObject *shape, *dataType, *data;
1250     long dimensionType;
1251     if (!PyArg_ParseTuple(args, "OOOl", &shape, &dataType, &data, &dimensionType)) {
1252         return -1;
1253     }
1254 
1255     size_t shapeSize = PyTuple_Size(shape);
1256 
1257     std::vector<int> vShape;
1258     size_t dataSize = 1;
1259     for (size_t i = 0; i<shapeSize; i++) {
1260         int shapeItem = (int)PyLong_AsLong(PyTuple_GetItem(shape, i));
1261         vShape.push_back(shapeItem);
1262         dataSize *= shapeItem;
1263     }
1264     bool isNumpy = false;
1265     void *pData = NULL;
1266     if(PyTuple_Check(data)) {
1267         if(dataSize != PyTuple_Size(data)) {
1268             PyErr_SetString(PyExc_Exception,
1269                         "PyMNNTensor_init: Tensor Dim not match");
1270             return -1;
1271         }
1272     }
1273 #ifdef PYMNN_NUMPY_USABLE
1274     else {
1275         if(PyArray_Check(data)) {
1276             isNumpy = true;
1277             if(dataSize != PyArray_Size(data)) {
1278                 PyErr_SetString(PyExc_Exception, "PyMNNTensor_init: numpy array size does not match shape requirement");
1279                 return -1;
1280             }
1281         }
1282         else {
1283             PyErr_SetString(PyExc_Exception, "PyMNNTensor_init: data is not tuple/numpy");
1284             return -1;
1285         }
1286     }
1287 #endif
1288     halide_type_t htt;
1289     struct MNN_TLSData *tlsData = getTLSData();
1290     if (dataType == tlsData->PyMNNHalideTypeInt) {
1291         htt = halide_type_of<int32_t>();
1292     }
1293     else if(dataType == tlsData->PyMNNHalideTypeFloat) {
1294         htt = halide_type_of<float>();
1295     }
1296     else if(dataType == tlsData->PyMNNHalideTypeDouble) {
1297         htt = halide_type_of<double>();
1298     }
1299     else if(dataType == tlsData->PyMNNHalideTypeUint8) {
1300         htt = halide_type_of<uint8_t>();
1301     }
1302     else if(dataType == tlsData->PyMNNHalideTypeInt64) {
1303         htt = halide_type_of<int64_t>();
1304     }
1305     else if(dataType == tlsData->PyMNNHalideTypeString) {
1306         htt = *httString();
1307     }
1308     else {
1309         PyErr_SetString(PyExc_Exception,"PyMNNTensor_create: unsupported data type");
1310         return -1;
1311     }
1312     DType dtype = htype2dtype(htt);
1313     if(!isNumpy) {
1314         int itemsize = getitemsize(dtype);
1315         pData = malloc(dataSize * itemsize);
1316         if(NULL == pData) {
1317             PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: malloc failed");
1318             return -1;
1319         }
1320         if (dataType == tlsData->PyMNNHalideTypeInt) {
1321             for (size_t i = 0; i < dataSize; i++) {
1322                 ((int *)pData)[i] = (int)PyLong_AsLong(PyTuple_GetItem(data, i));
1323             }
1324         } else if (dataType == tlsData->PyMNNHalideTypeFloat) {
1325             for (size_t i = 0; i < dataSize; i++) {
1326                 ((float *)pData)[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(data, i));
1327             }
1328         } else if (dataType == tlsData->PyMNNHalideTypeDouble) {
1329             for (size_t i = 0; i < dataSize; i++) {
1330                ((double *)pData)[i] = PyFloat_AsDouble(PyTuple_GetItem(data, i));
1331             }
1332         } else if (dataType == tlsData->PyMNNHalideTypeUint8) {
1333             for (size_t i = 0; i < dataSize; i++) {
1334                ((uint8_t *)pData)[i] = (uint8_t)PyLong_AsLong(PyTuple_GetItem(data, i));
1335             }
1336         } else if (dataType == tlsData->PyMNNHalideTypeInt64) {
1337             for (size_t i = 0; i < dataSize; i++) {
1338                ((int64_t *)pData)[i] = (int64_t)PyLong_AsLong(PyTuple_GetItem(data, i));
1339             }
1340          } else if (dataType == tlsData->PyMNNHalideTypeString) {
1341             for (size_t i = 0; i < dataSize; i++) {
1342                char *item = (char *)object2String(PyTuple_GetItem(data, i)).c_str();
1343                ((char **)pData)[i] = item;
1344             }
1345         }
1346     }
1347 #ifdef PYMNN_NUMPY_USABLE
1348     else {
1349         int npy_type = PyArray_TYPE(data);
1350         int itemsize = getitemsize(dtype, npy_type);
1351         pData = malloc(dataSize * itemsize);
1352         if(NULL == pData) {
1353             PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: malloc failed");
1354             return -1;
1355         }
1356         PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)data);
1357         auto tmpBuffer = PyArray_DATA(data_cont);
1358         if(NULL == tmpBuffer) {
1359              PyErr_SetString(PyExc_Exception,"PyMNNTensor_init: ndarray failed to get buffer data");
1360              return -1;
1361         }
1362         memcpy(pData, tmpBuffer, dataSize * itemsize);
1363         Py_XDECREF(data_cont);
1364      }
1365  #endif
1366     Tensor *tensor = Tensor::create(vShape
1367                                , htt
1368                                , pData
1369                                , (Tensor::DimensionType)dimensionType
1370                                );
1371     if (!tensor) {
1372         PyErr_SetString(PyExc_Exception,
1373                         "PyMNNTensor_create: Tensor create failed");
1374         return -1;
1375     }
1376     self->tensor = tensor;
1377     self->owner = 1;
1378     return 0;
1379 }
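/*
 * Illustrative sketch (assumed example values) of the Tensor the init above builds for a
 * 2x2 float tensor created from a Python tuple in Tensorflow (NHWC) dimension order:
 *
 *     std::vector<int> shape = {2, 2};
 *     float* buffer = (float*)malloc(4 * sizeof(float));   // filled element-by-element from the tuple
 *     Tensor* t = Tensor::create(shape, halide_type_of<float>(),
 *                                buffer, Tensor::TENSORFLOW);
 */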
1380 #ifdef PYMNN_NUMPY_USABLE
1381 static PyObject* PyMNNTensor_fromNumpy(PyMNNTensor *self, PyObject *args) {
1382     PyObject *data;
1383     if (!PyArg_ParseTuple(args, "O", &data)) {
1384         return NULL;
1385     }
1386     if (!PyArray_Check(data)) {
1387         PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: input is not a numpy");
        return NULL;
1388     }
1389     if (self->owner){
1390         if(self->tensor->size() != PyArray_Size(data)) {
1391             PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: tensor/numpy size does not match each other");
1392             return NULL;
1393         }
1394         DType dtype = htype2dtype(self->tensor->getType());
1395         int npy_type = PyArray_TYPE(data);
1396         int itemsize = getitemsize(dtype, npy_type);
1397         PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)data);
1398         auto tmpBuffer = PyArray_DATA(data_cont);
1399         if(NULL == tmpBuffer) {
1400              PyErr_SetString(PyExc_Exception,"PyMNNTensor_fromNumpy: ndarray failed to get buffer data");
1401              return NULL;
1402         }
1403         memcpy(self->tensor->host<void *>(), tmpBuffer, self->tensor->size() * itemsize);
1404         Py_XDECREF(data_cont);
1405     }
1406     Py_RETURN_NONE;
1407 }
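// Note: fromNumpy copies the (contiguous) numpy buffer into the tensor's host memory; the copy
// is only performed when this wrapper owns the tensor (self->owner), as checked above.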
1408 #endif
1409 static PyObject* PyMNNTensor_printTensorData(PyMNNTensor *self, PyObject *args) {
1410     if (self->tensor) {
1411         // Do nothing
1412     }
1413     Py_RETURN_NONE;
1414 }
1415 
1416 static PyObject* PyMNNTensor_getHost(PyMNNTensor *self, PyObject *args) {
1417     if (self->tensor) {
1418         return PyCapsule_New(self->tensor->host<void *>(), NULL, NULL);
1419     }
1420     Py_RETURN_NONE;
1421 }
1422 
1423 static PyObject* PyMNNTensor_getDataType(PyMNNTensor *self, PyObject *args) {
1424     if (self->tensor) {
1425         halide_type_t t = self->tensor->getType();
1426         PyObject *type;
1427         struct MNN_TLSData *tlsData =getTLSData();
1428         if (t == *httInt()) {
1429             type = tlsData->PyMNNHalideTypeInt;
1430         } else if (t == *httUint8()) {
1431             type = tlsData->PyMNNHalideTypeUint8;
1432         } else if (t == *httInt64()) {
1433             type = tlsData->PyMNNHalideTypeInt64;
1434         } else if (t == *httFloat()) {
1435             type = tlsData->PyMNNHalideTypeFloat;
1436         } else if (t == *httDouble()) {
1437             type = tlsData->PyMNNHalideTypeDouble;
1438         } else if (t == *httString()) {
1439             type = tlsData->PyMNNHalideTypeString;
1440         } else {
1441             Py_RETURN_NONE;
1442         }
1443         Py_XINCREF(type);
1444         return type;
1445     }
1446     Py_RETURN_NONE;
1447 }
1448 
1449 static PyObject* PyMNNTensor_getData(PyMNNTensor *self, PyObject *args) {
1450     if (self->tensor) {
1451         halide_type_t t = self->tensor->getType();
1452         size_t size = self->tensor->elementSize();
1453         PyObject *outputData = PyTuple_New(size);
1454         if (t == *httInt()) {
1455             auto data = self->tensor->host<int32_t>();
1456             for (size_t i = 0; i < size; i++) {
1457                 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1458             }
1459          } else if (t == *httUint8()) {
1460             auto data = self->tensor->host<uint8_t>();
1461             for (size_t i = 0; i < size; i++) {
1462                 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1463             }
1464          } else if (t == *httInt64()) {
1465             auto data = self->tensor->host<int64_t>();
1466             for (size_t i = 0; i < size; i++) {
1467                 PyTuple_SetItem(outputData, i, PyLong_FromLong(data[i]));
1468             }
1469          } else if (t == *httFloat()) {
1470             auto data = self->tensor->host<float>();
1471             for (size_t i = 0; i < size; i++) {
1472                 PyTuple_SetItem(outputData, i, PyFloat_FromDouble(data[i]));
1473             }
1474          } else if (t == *httDouble()) {
1475             auto data = self->tensor->host<double>();
1476             for (size_t i = 0; i < size; i++) {
1477                 PyTuple_SetItem(outputData, i, PyFloat_FromDouble(data[i]));
1478             }
1479          } else if (t == *httString()) {
1480             auto data = self->tensor->host<char *>();
1481             for (size_t i = 0; i < size; i++) {
1482                 char *dataItem = data[i];
1483                 PyTuple_SetItem(outputData, i, char2Object(dataItem?dataItem:""));
1484             }
1485          } else {
1486             Py_RETURN_NONE;
1487          }
1488          return outputData;
1489     }
1490     Py_RETURN_NONE;
1491 }
1492 
1493 #ifdef PYMNN_NUMPY_USABLE
1494 static PyObject* PyMNNTensor_getNumpyData(PyMNNTensor *self, PyObject *args) {
1495     if (self->tensor) {
1496         halide_type_t t = self->tensor->getType();
1497         std::vector<npy_intp> npy_dims;
1498         for(const auto dim : self->tensor->shape()) {
1499             npy_dims.push_back(dim);
1500         }
1501         PyObject* obj;
1502         if (t == *httInt()) {
1503             auto data = self->tensor->host<int32_t>();
1504             obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT32, data);
1505         } else if (t == *httUint8()) {
1506             auto data = self->tensor->host<uint8_t>();
1507             obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_UINT8, data);
1508         } else if (t == *httInt64()) {
1509             auto data = self->tensor->host<int64_t>();
1510             obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT64, data);
1511         } else if (t == *httFloat()) {
1512             auto data = self->tensor->host<float>();
1513             obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_FLOAT, data);
1514         } else if (t == *httDouble()) {
1515             auto data = self->tensor->host<double>();
1516             obj = PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_DOUBLE, data);
1517         } else {
1518             PyErr_SetString(PyExc_Exception, "tensor can not be read as numpy");
1519             return NULL;
1520         }
1521         return obj;
1522     }
1523     Py_RETURN_NONE;
1524 }
1525 #endif
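// Note: getNumpyData wraps the tensor's host memory with PyArray_SimpleNewFromData, so the
// returned array aliases the tensor buffer instead of copying it; getData above copies every
// element into a new Python tuple.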
1526 
1527 static PyObject* PyMNNTensor_getDimensionType(PyMNNTensor *self, PyObject *args) {
1528     if (self->tensor) {
1529         return PyLong_FromLong(self->tensor->getDimensionType());
1530     }
1531     Py_RETURN_NONE;
1532 }
1533 
1534 static PyObject* PyMNNTensor_copyFrom(PyMNNTensor *self, PyObject *args) {
1535     PyMNNTensor *fromTensor = NULL;
1536     if (!PyArg_ParseTuple(args, "O", &fromTensor)) {
1537         return NULL;
1538     }
1539 
1540     if (!fromTensor->tensor || !self->tensor) {
1541         PyErr_SetString(PyExc_Exception,
1542                         "PyMNNTensor_copyFrom: source or destination tensor is null");
        return NULL;
1543     }
1544 
1545     bool r = self->tensor->copyFromHostTensor(fromTensor->tensor);
1546     if (!r) {
1547         Py_RETURN_FALSE;
1548     }
1549     Py_RETURN_TRUE;
1550 }
1551 
1552 static PyObject* PyMNNTensor_copyToHostTensor(PyMNNTensor *self, PyObject *args) {
1553     PyMNNTensor *toTensor = NULL;
1554     if (!PyArg_ParseTuple(args, "O", &toTensor)) {
1555         return NULL;
1556     }
1557 
1558     if (!toTensor->tensor || !self->tensor) {
1559         PyErr_SetString(PyExc_Exception,
1560                         "PyMNNTensor_copyTo: source or destination tensor is null");
        return NULL;
1561     }
1562 
1563     bool r = self->tensor->copyToHostTensor(toTensor->tensor);
1564     if (!r) {
1565         Py_RETURN_FALSE;
1566     }
1567     Py_RETURN_TRUE;
1568 }
1569 
1570 static PyObject* PyMNNTensor_getShape(PyMNNTensor *self, PyObject *args) {
1571     if (self->tensor) {
1572         PyObject *shape = PyTuple_New(self->tensor->shape().size());
1573         for (size_t i = 0; i < self->tensor->shape().size(); i++) {
1574             PyTuple_SetItem(shape, i, PyLong_FromLong(self->tensor->shape()[i]));
1575         }
1576         return shape;
1577     }
1578     Py_RETURN_NONE;
1579 }
1580 
1581 /// MNN ImageProcess implementation
1582 static PyObject* PyMNNCVImageProcess_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1583     PyMNNCVImageProcess* self = (PyMNNCVImageProcess *)type->tp_alloc(type, 0);
1584     return (PyObject*)self;
1585 }
1586 
1587 static void PyMNNCVImageProcess_dealloc(PyMNNCVImageProcess *self) {
1588     delete self->imageProcess;
1589     Py_TYPE(self)->tp_free((PyObject*)self);
1590 }
1591 
1592 
1593 static int PyMNNCVImageProcess_init(PyMNNCVImageProcess *self, PyObject *args, PyObject *kwds) {
1594     PyObject *config = NULL, *destinationTensor = NULL;
1595     if (!PyArg_ParseTuple(args, "O|O", &config, &destinationTensor)) {
1596         return -1;
1597     }
1598 
1599     Tensor *t = NULL;
1600     if (destinationTensor
1601         && PyObject_TypeCheck(destinationTensor, PyType_FindTLSType(&PyMNNTensorType))) {
1602         t = ((PyMNNTensor *)destinationTensor)->tensor;
1603     }
1604 
1605     CV::ImageProcess::Config c;
1606     if (PyDict_Check(config)) {
1607         PyObject *filterType = PyDict_GetItemString(config, "filterType");
1608         if (filterType && PyLong_Check(filterType)) {
1609             c.filterType = (CV::Filter)PyLong_AsLong(filterType);
1610         }
1611 
1612         PyObject *sourceFormat = PyDict_GetItemString(config, "sourceFormat");
1613         if (sourceFormat && PyLong_Check(sourceFormat)) {
1614             c.sourceFormat = (CV::ImageFormat)PyLong_AsLong(sourceFormat);
1615         }
1616 
1617         PyObject *destFormat = PyDict_GetItemString(config, "destFormat");
1618         if (destFormat && PyLong_Check(destFormat)) {
1619             c.destFormat = (CV::ImageFormat)PyLong_AsLong(destFormat);
1620         }
1621 
1622         PyObject *wrap = PyDict_GetItemString(config, "wrap");
1623         if (wrap && PyLong_Check(wrap)) {
1624             c.wrap = (CV::Wrap)PyLong_AsLong(wrap);
1625         }
1626 
1627         PyObject *mean = PyDict_GetItemString(config, "mean");
1628         if (mean) {
1629             if (!PyTuple_Check(mean) || PyTuple_Size(mean) != 4) {
1630                 PyErr_SetString(PyExc_Exception,
1631                                 "PyMNNCVImageProcess_init: mean must be a tuple with 4 elements");
1632                 return -1;
1633             }
1634             for (int i = 0; i < 4; i++) {
1635                 c.mean[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(mean, i));
1636             }
1637         }
1638 
1639         PyObject *normal = PyDict_GetItemString(config, "normal");
1640         if (normal) {
1641             if (!PyTuple_Check(normal) || PyTuple_Size(normal) != 4) {
1642                 PyErr_SetString(PyExc_Exception,
1643                                 "PyMNNCVImageProcess_init: normal must be a tuple with 4 elements");
1644                 return -1;
1645             }
1646             for (int i = 0; i < 4; i++) {
1647                 c.normal[i] = (float)PyFloat_AsDouble(PyTuple_GetItem(normal, i));
1648             }
1649         }
1650     }
1651 
1652     CV::ImageProcess *imageProcess = CV::ImageProcess::create(c, t);
1653     if (!imageProcess) {
1654         PyErr_SetString(PyExc_Exception,
1655                         "PyMNNCVImageProcess_init: ImageProcess create failed");
1656         return -1;
1657     }
1658 
1659     self->imageProcess = imageProcess;
1660     return 0;
1661 }
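/*
 * Minimal sketch (assumed example values) of the configuration the init above assembles
 * from the Python config dict before calling CV::ImageProcess::create:
 *
 *     CV::ImageProcess::Config c;
 *     c.sourceFormat = CV::RGBA;
 *     c.destFormat   = CV::RGB;
 *     c.filterType   = CV::BILINEAR;
 *     c.mean[0] = c.mean[1] = c.mean[2] = 127.5f;
 *     c.normal[0] = c.normal[1] = c.normal[2] = 1.0f / 127.5f;
 *     CV::ImageProcess* proc = CV::ImageProcess::create(c);
 */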
1662 
1663 static PyObject* PyMNNCVImageProcess_setMatrix(PyMNNCVImageProcess *self, PyObject *args) {
1664     PyObject *matrix;
1665     if (!PyArg_ParseTuple(args, "O", &matrix)) {
1666         return NULL;
1667     }
1668 
1669     if (!PyObject_TypeCheck(matrix, PyType_FindTLSType(&PyMNNCVMatrixType))) {
1670         PyErr_SetString(PyExc_Exception,
1671                         "PyMNNCVImageProcess_setMatrix: argument is not a matrix");
1672         return NULL;
1673     }
1674 
1675     self->imageProcess->setMatrix(*((PyMNNCVMatrix *)matrix)->matrix);
1676     Py_RETURN_NONE;
1677 }
1678 
1679 static PyObject* PyMNNCVImageProcess_convert(PyMNNCVImageProcess *self, PyObject *args) {
1680     PyObject *source, *dest;
1681     int iw, ih, stride;
1682     if (!PyArg_ParseTuple(args, "OiiiO", &source, &iw, &ih, &stride, &dest)) {
1683         return NULL;
1684     }
1685 
1686     if (!PyObject_TypeCheck(dest, PyType_FindTLSType(&PyMNNTensorType))) {
1687         PyErr_SetString(PyExc_Exception,
1688                         "PyMNNCVImageProcess_convert: argument 4 is not a MNNTensor");
1689         return NULL;
1690     }
1691 
1692     if (PyCapsule_CheckExact(source)) {
1693         // Capsule Pointer
1694         ErrorCode ret = self->imageProcess->convert((const uint8_t *)PyCapsule_GetPointer(source, NULL),
1695                                                     iw, ih, stride,
1696                                                     ((PyMNNTensor *)dest)->tensor);
1697         return PyLong_FromLong(ret);
1698     } else if (PyTuple_Check(source)) {
1699         // Tuple Data
1700         size_t size = PyTuple_Size(source);
1701 
1702         void *pData = malloc(size * sizeof(uint8_t));
1703         for (size_t i = 0; i < size; i++) {
1704             ((uint8_t *)pData)[i] = (uint8_t)PyLong_AsLong(PyTuple_GetItem(source, i));
1705         }
1706 
1707         ErrorCode ret = self->imageProcess->convert((const uint8_t *)pData,
1708                                                     iw, ih, stride,
1709                                                     ((PyMNNTensor *)dest)->tensor);
1710 
1711         free(pData);
1712 
1713         return PyLong_FromLong(ret);
1714     }
1715 #ifdef PYMNN_NUMPY_USABLE
1716     else if(PyArray_Check(source)) {
1717         // Array Data
1718         int npy_type = PyArray_TYPE(source);
1719         if(npy_type != NPY_UINT8) {
1720             PyErr_SetString(PyExc_Exception,
1721                         "PyMNNCVImageProcess_convert: only numpy.uint8 is supported for numpy");
1722             return NULL;
1723         }
1724         int64_t total_length = 1;
1725         for (size_t i = 0; i < ((PyMNNTensor *)dest)->tensor->shape().size(); i++) {
1726             total_length *= ((PyMNNTensor *)dest)->tensor->shape()[i];
1727         }
1728         if(PyArray_Size(source) < total_length) //as input may contain stride, so we can only do basic check
1729         {
1730             PyErr_SetString(PyExc_Exception,
1731                         "PyMNNCVImageProcess_convert: data length does not match tensor size");
1732             return NULL;
1733         }
1734         PyArrayObject *data_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)source);
1735         auto tmpBuffer = PyArray_DATA(data_cont);
1736         if(NULL == tmpBuffer) {
1737              PyErr_SetString(PyExc_Exception,"PyMNNCVImageProcess_convert: ndarray failed to get buffer data");
1738              return NULL;
1739         }
1740         ErrorCode ret = self->imageProcess->convert((const uint8_t *)tmpBuffer,
1741                                                     iw, ih, stride,
1742                                                     ((PyMNNTensor *)dest)->tensor);
1743         Py_XDECREF(data_cont);
1744         return PyLong_FromLong(ret);
1745     }
1746 #endif
1747 
1748     PyErr_SetString(PyExc_Exception, "PyMNNCVImageProcess_convert: argument 0 is not a capsule or tuple or numpy");
1749 
1750     return NULL;
1751 }
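// convert() accepts the source pixels as a PyCapsule wrapping a raw pointer, a Python tuple of
// uint8 values, or (when numpy support is compiled in) a uint8 ndarray; in every case the pixels
// are forwarded to CV::ImageProcess::convert together with the destination MNN.Tensor.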
1752 
1753 
1754 static PyObject* PyMNNCVImageProcess_createImageTensor(PyMNNCVImageProcess *self, PyObject *args) {
1755 
1756     PyObject *dataType;
1757     int width, height, bpp;
1758     PyObject *data;
1759 
1760     if (!PyArg_ParseTuple(args, "OiiiO", &dataType, &width, &height, &bpp, &data)) {
1761         return NULL;
1762     }
1763 
1764 
1765 //    if (nullptr != data && !PyCapsule_CheckExact(data)) {
1766 //        PyErr_SetString(PyExc_Exception,
1767 //                        "PyMNNCVImageProcess_createImageTensor: argument 4 is not a capsule");
1768 //        return NULL;
1769 //    }
1770 
1771     std::vector<int> vShape = {1, height, width, bpp};
1772 
1773     halide_type_t htt;
1774     struct MNN_TLSData *tlsData = getTLSData();
1775     if (dataType == tlsData->PyMNNHalideTypeInt) {
1776         htt = halide_type_of<int32_t>();
1777     } else if (dataType == tlsData->PyMNNHalideTypeFloat) {
1778         htt = halide_type_of<float>();
1779     } else if (dataType == tlsData->PyMNNHalideTypeDouble) {
1780         htt = halide_type_of<double>();
1781     } else if (dataType == tlsData->PyMNNHalideTypeUint8) {
1782         htt = halide_type_of<uint8_t>();
1783     } else if (dataType == tlsData->PyMNNHalideTypeInt64) {
1784         htt = halide_type_of<int64_t>();
1785     } else if (dataType == tlsData->PyMNNHalideTypeString) {
1786         htt = *httString();
1787     } else {
        PyErr_SetString(PyExc_Exception,
                        "PyMNNCVImageProcess_createImageTensor: unsupported data type");
        return NULL;
    }
1788 
1789     Tensor *tensor = Tensor::create(vShape, htt);
1790 //    Tensor *tensor = Tensor::create(vShape, htt, PyCapsule_GetPointer(data, NULL));TODO
1791     if (!tensor) {
1792         PyErr_SetString(PyExc_Exception,
1793                         "PyMNNCVImageProcess_createImageTensor: Tensor create failed");
1794         return NULL;
1795     }
1796 
1797     PyObject *f = importName("MNN", "Tensor");
1798     if (!f || !PyCallable_Check(f)) {
1799         PyErr_SetString(PyExc_Exception,
1800                         "PyMNNCVImageProcess_createImageTensor: MNN.Tensor not found");
1801         return NULL;
1802     }
1803 
1804     PyMNNTensor *t = (PyMNNTensor *)PyObject_Call(f, PyTuple_New(0), NULL);
1805     if (!t) {
1806         PyErr_SetString(PyExc_Exception,
1807                         "PyMNNCVImageProcess_createImageTensor: create image tensor failed");
1808         return NULL;
1809     }
1810 
1811     t->tensor = tensor;
1812     t->owner = 1;
1813     return (PyObject *)t;
1814 }
1815 
1816 /// MNN CVMatrix implementation
1817 static PyObject* PyMNNCVMatrix_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
1818     PyMNNCVMatrix* self;
1819     self = (PyMNNCVMatrix *)type->tp_alloc(type, 0);
1820     self->matrix = new CV::Matrix();
1821     return (PyObject*)self;
1822 }
1823 
1824 static void PyMNNCVMatrix_dealloc(PyMNNCVMatrix *self) {
1825     delete self->matrix;
1826     Py_TYPE(self)->tp_free((PyObject*)self);
1827 }
1828 
1829 // type: 0 set; 1 pre; 2 post
1830 static PyObject* _PyMNNCVMatrix_Rotate(PyMNNCVMatrix *self, PyObject *args, int type) {
1831     float degrees, px = 0.0, py = 0.0;
1832     size_t argsCount = PyTuple_Size(args);
1833     if (argsCount == 1) {
1834         if (!PyArg_ParseTuple(args, "f", &degrees)) {
1835             PyErr_SetString(PyExc_Exception,
1836                             "PyMNNCVMatrix_Rotate: PyArg_ParseTuple failed");
1837             return NULL;
1838         }
1839     } else if (argsCount == 3) {
1840         if (!PyArg_ParseTuple(args, "fff", &degrees, &px, &py)) {
1841             PyErr_SetString(PyExc_Exception,
1842                             "PyMNNCVMatrix_Rotate: PyArg_ParseTuple failed");
1843             return NULL;
1844         }
1845     } else {
1846         PyErr_SetString(PyExc_Exception,
1847                         "PyMNNCVMatrix_Rotate: argument count error (should be 1 or 3)");
1848         return NULL;
1849     }
1850 
1851     if (argsCount == 1) {
1852         switch (type) {
1853             case 0:
1854                 self->matrix->setRotate(degrees);
1855                 break;
1856             case 1:
1857                 self->matrix->preRotate(degrees);
1858                 break;
1859             case 2:
1860                 self->matrix->postRotate(degrees);
1861                 break;
1862             default:
1863                 break;
1864         }
1865 
1866     } else if (argsCount == 3) {
1867         switch (type) {
1868             case 0:
1869                 self->matrix->setRotate(degrees, px, py);
1870                 break;
1871             case 1:
1872                 self->matrix->preRotate(degrees, px, py);
1873                 break;
1874             case 2:
1875                 self->matrix->postRotate(degrees, px, py);
1876                 break;
1877             default:
1878                 break;
1879         }
1880     }
1881     Py_RETURN_NONE;
1882 }
1883 // set
1884 static PyObject* PyMNNCVMatrix_setRotate(PyMNNCVMatrix *self, PyObject *args) {
1885     return _PyMNNCVMatrix_Rotate(self, args, 0);
1886 }
1887 // pre
1888 static PyObject* PyMNNCVMatrix_preRotate(PyMNNCVMatrix *self, PyObject *args) {
1889     return _PyMNNCVMatrix_Rotate(self, args, 1);
1890 }
1891 // post
1892 static PyObject* PyMNNCVMatrix_postRotate(PyMNNCVMatrix *self, PyObject *args) {
1893     return _PyMNNCVMatrix_Rotate(self, args, 2);
1894 }
1895 
1896 static PyObject* _PyMNNCVMatrix_Scale(PyMNNCVMatrix *self, PyObject *args, int type) {
1897     float sx, sy, px = 0.0, py = 0.0;
1898     size_t argsCount = PyTuple_Size(args);
1899     if (argsCount == 2) {
1900         if (!PyArg_ParseTuple(args, "ff", &sx, &sy)) {
1901             PyErr_SetString(PyExc_Exception,
1902                             "PyMNNCVMatrix_Scale: PyArg_ParseTuple failed");
1903             return NULL;
1904         }
1905     } else if (argsCount == 4) {
1906         if (!PyArg_ParseTuple(args, "ffff", &sx, &sy, &px, &py)) {
1907             PyErr_SetString(PyExc_Exception,
1908                             "PyMNNCVMatrix_Scale: PyArg_ParseTuple failed");
1909             return NULL;
1910         }
1911     } else {
1912         PyErr_SetString(PyExc_Exception,
1913                         "PyMNNCVMatrix_Scale: argument count error (should be 2 or 4)");
1914         return NULL;
1915     }
1916 
1917     if (argsCount == 2) {
1918         switch (type) {
1919             case 0:
1920                 self->matrix->setScale(sx, sy);
1921                 break;
1922             case 1:
1923                 self->matrix->preScale(sx, sy);
1924                 break;
1925             case 2:
1926                 self->matrix->postScale(sx, sy);
1927                 break;
1928             default:
1929                 break;
1930         }
1931     } else if (argsCount == 4) {
1932         switch (type) {
1933             case 0:
1934                 self->matrix->setScale(sx, sy, px, py);
1935                 break;
1936             case 1:
1937                 self->matrix->preScale(sx, sy, px, py);
1938                 break;
1939             case 2:
1940                 self->matrix->postScale(sx, sy, px, py);
1941                 break;
1942             default:
1943                 break;
1944         }
1945     }
1946     Py_RETURN_NONE;
1947 }
1948 static PyObject* PyMNNCVMatrix_setScale(PyMNNCVMatrix *self, PyObject *args) {
1949     return _PyMNNCVMatrix_Scale(self, args, 0);
1950 }
1951 static PyObject* PyMNNCVMatrix_preScale(PyMNNCVMatrix *self, PyObject *args) {
1952     return _PyMNNCVMatrix_Scale(self, args, 1);
1953 }
1954 static PyObject* PyMNNCVMatrix_postScale(PyMNNCVMatrix *self, PyObject *args) {
1955     return _PyMNNCVMatrix_Scale(self, args, 2);
1956 }
1957 
1958 static PyObject* _PyMNNCVMatrix_Translate(PyMNNCVMatrix *self, PyObject *args, int type) {
1959     float dx = 0.0, dy = 0.0;
1960     size_t argsCount = PyTuple_Size(args);
1961     if (argsCount == 2) {
1962         if (!PyArg_ParseTuple(args, "ff", &dx, &dy)) {
1963             PyErr_SetString(PyExc_Exception,
1964                             "PyMNNCVMatrix_Translate: PyArg_ParseTuple failed");
1965             return NULL;
1966         }
1967     } else {
1968         PyErr_SetString(PyExc_Exception,
1969                         "PyMNNCVMatrix_Translate: argument count error (should be 2)");
1970         return NULL;
1971     }
1972 
1973     switch (type) {
1974         case 0:
1975             self->matrix->setTranslate(dx, dy);
1976             break;
1977         case 1:
1978             self->matrix->preTranslate(dx, dy);
1979             break;
1980         case 2:
1981             self->matrix->postTranslate(dx, dy);
1982             break;
1983         default:
1984             break;
1985     }
1986     Py_RETURN_NONE;
1987 }
1988 static PyObject* PyMNNCVMatrix_setTranslate(PyMNNCVMatrix *self, PyObject *args) {
1989     return _PyMNNCVMatrix_Translate(self, args, 0);
1990 }
1991 static PyObject* PyMNNCVMatrix_preTranslate(PyMNNCVMatrix *self, PyObject *args) {
1992     return _PyMNNCVMatrix_Translate(self, args, 1);
1993 }
1994 static PyObject* PyMNNCVMatrix_postTranslate(PyMNNCVMatrix *self, PyObject *args) {
1995     return _PyMNNCVMatrix_Translate(self, args, 2);
1996 }
1997 
1998 static PyObject* PyMNNCVMatrix_invert(PyMNNCVMatrix *self) {
1999 
2000     self->matrix->invert(self->matrix);
2001     Py_RETURN_NONE;
2002 }
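/*
 * Illustrative sketch (example values; assumes the usual pre-/post-concatenation semantics
 * of this Matrix API) of how the set/pre/post wrappers above compose a transform:
 *
 *     CV::Matrix m;
 *     m.setScale(0.5f, 0.5f);              // M = S
 *     m.postRotate(90.0f, 16.0f, 16.0f);   // M = R * M  (rotation applied after the scale)
 *     m.preTranslate(-8.0f, -8.0f);        // M = M * T  (translation applied before the scale)
 */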
2003 static PyObject* PyMNNOpInfo_getName(PyMNNOpInfo *self, PyObject *args);
2004 static PyObject* PyMNNOpInfo_getType(PyMNNOpInfo *self, PyObject *args);
2005 
2006 static void PyMNNOpInfo_dealloc(PyMNNOpInfo *self);
2007 static PyObject* PyMNNOpInfo_new(struct _typeobject *type, PyObject *args, PyObject *kwds);
2008 static int PyMNNOpInfo_init(PyMNNOpInfo *info, PyObject *args, PyObject *kwds);
2009 
2010 static PyMethodDef PyMNNOpInfo_methods[] = {
2011     {"getName", (PyCFunction)PyMNNOpInfo_getName, METH_VARARGS, "get op name"},
2012     {"getType", (PyCFunction)PyMNNOpInfo_getType, METH_VARARGS, "get op type"},
2013     {NULL}  /* Sentinel */
2014 };
2015 static PyObject* PyMNNOpInfo_getName(PyMNNOpInfo *self, PyObject *args) {
2016     PyObject *name = NULL;
2017     if (self->opInfo) {
2018         name = char2Object(self->opInfo->name().c_str());
2019     }
2020     return name;
2021 }
2022 static PyObject* PyMNNOpInfo_getType(PyMNNOpInfo *self, PyObject *args) {
2023     PyObject *type = NULL;
2024     if (self->opInfo) {
2025         type = char2Object(self->opInfo->type().c_str());
2026     }
2027     return type;
2028 }
2029 static PyTypeObject PyMNNOpInfoType = {
2030     PyVarObject_HEAD_INIT(NULL, 0)
2031     "MNN.OpInfo",                   /*tp_name*/
2032     sizeof(PyMNNOpInfo),                      /*tp_basicsize*/
2033     0,                                        /*tp_itemsize*/
2034     (destructor)PyMNNOpInfo_dealloc,          /*tp_dealloc*/
2035     0,                                        /*tp_print*/
2036     0,                                        /*tp_getattr*/
2037     0,                                        /*tp_setattr*/
2038     0,                                        /*tp_compare*/
2039     0,                                        /*tp_repr*/
2040     0,                                        /*tp_as_number*/
2041     0,                                        /*tp_as_sequence*/
2042     0,                                        /*tp_as_mapping*/
2043     0,                                        /*tp_hash */
2044     0,                                        /*tp_call*/
2045     0,                                        /*tp_str*/
2046     0,                                        /*tp_getattro*/
2047     0,                                        /*tp_setattro*/
2048     0,                                        /*tp_as_buffer*/
2049     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
2050     "MNN OpInfo objects",                    /* tp_doc */
2051     0,                                        /* tp_traverse */
2052     0,                                        /* tp_clear */
2053     0,                                        /* tp_richcompare */
2054     0,                                        /* tp_weaklistoffset */
2055     0,                                        /* tp_iter */
2056     0,                                        /* tp_iternext */
2057     PyMNNOpInfo_methods,                                   /* tp_methods */
2058     0,                      /* tp_members */
2059     0,                    /* tp_getset */
2060     0,                                        /* tp_base */
2061     0,                                        /* tp_dict */
2062     0,                                        /* tp_descr_get */
2063     0,                                        /* tp_descr_set */
2064     0,                                        /* tp_dictoffset */
2065     (initproc)PyMNNOpInfo_init,               /* tp_init */
2066     0,                                        /* tp_alloc */
2067     PyMNNOpInfo_new,                          /* tp_new */
2068 };
2069 static PyObject* PyMNNOpInfo_new(struct _typeobject *type, PyObject *args, PyObject *kwds) {
2070     PyMNNOpInfo* self = (PyMNNOpInfo *)type->tp_alloc(type, 0);
2071     return (PyObject*)self;
2072 }
2073 static int PyMNNOpInfo_init(PyMNNOpInfo *info, PyObject *args, PyObject *kwds) {
2074     return 0;
2075 }
2076 
2077 static void PyMNNOpInfo_dealloc(PyMNNOpInfo *self) {
2078     Py_TYPE(self)->tp_free((PyObject*)self);
2079 }
2080 /// module init
2081 static PyMethodDef module_methods[] = {
2082     {NULL, NULL, 0, NULL}
2083 };
2084 
2085 int add(int a, int b) {
2086     return a + b;
2087 }
2088 
2089 // _MOD_NAME [_mnncengine or MNN]
2090 // MOD_NAME ["_mnncengine" or "MNN"]
2091 #if PYMNN_USE_ALINNPYTHON
2092 #if PYMNN_EXPR_API
2093 #define _MOD_NAME _mnncengine
2094 #else
2095 #define _MOD_NAME MNN
2096 #endif
2097 #else
2098 #define _MOD_NAME _mnncengine
2099 #endif
2100 #define _STRINGIFY(str) #str
2101 #define STRINGIFY(macro) _STRINGIFY(macro)
2102 #define MOD_NAME STRINGIFY(_MOD_NAME)
2103 
2104 #if PY_MAJOR_VERSION >= 3
2105 static struct PyModuleDef moduledef = {
2106     PyModuleDef_HEAD_INIT,
2107     MOD_NAME,     /* m_name */
2108     "MNNEngine",  /* m_doc */
2109     -1,                  /* m_size */
2110     module_methods,    /* m_methods */
2111     NULL,                /* m_reload */
2112     NULL,                /* m_traverse */
2113     NULL,                /* m_clear */
2114     NULL,                /* m_free */
2115 };
2116 #define MOD_INIT_FUNC_NAME(name) PyInit_##name
2117 #else
2118 #define MOD_INIT_FUNC_NAME(name) init##name
2119 #endif
2120 // MOD_INIT_FUNC [PyInit_{MOD_NAME} or init{MOD_NAME}]
2121 #define _MOD_INIT_FUNC(macro) MOD_INIT_FUNC_NAME(macro)
2122 #define MOD_INIT_FUNC _MOD_INIT_FUNC(_MOD_NAME)
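// Example expansion (Python 3 build with PYMNN_USE_ALINNPYTHON undefined):
//     _MOD_NAME     -> _mnncengine
//     MOD_NAME      -> "_mnncengine"
//     MOD_INIT_FUNC -> PyInit__mnncengine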
2123 
2124 static std::once_flag mLoadFlag1;
2125 
2126 PyMODINIT_FUNC MOD_INIT_FUNC(void) {
2127 #if PY_MAJOR_VERSION >= 3
2128 #define ERROR_RETURN return NULL;
2129 #else
2130 #define ERROR_RETURN return;
2131 #endif
2132 
2133 #ifdef PYMNN_USE_ALINNPYTHON
2134     std::call_once(mLoadFlag1, [&](){
2135         if (global_new_python_flag > 0) {
2136             tls_key = PyThread_create_key();
2137             tls_key_2 = PyThread_create_key();
2138         }
2139     });
2140 #endif
2141 
2142     if (PyType_Ready(&PyMNNInterpreterType) < 0) {
2143         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNInterpreterType failed");
2144         ERROR_RETURN
2145     }
2146     if (PyType_Ready(&PyMNNSessionType) < 0) {
2147         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNSessionType failed");
2148         ERROR_RETURN
2149     }
2150     if (PyType_Ready(&PyMNNTensorType) < 0) {
2151         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNTensorType failed");
2152         ERROR_RETURN
2153     }
2154     if (PyType_Ready(&PyMNNCVImageProcessType) < 0) {
2155         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVImageProcessType failed");
2156         ERROR_RETURN
2157     }
2158     if (PyType_Ready(&PyMNNCVMatrixType) < 0) {
2159         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNCVMatrixType failed");
2160         ERROR_RETURN
2161     }
2162     if (PyType_Ready(&PyMNNOpInfoType) < 0) {
2163         PyErr_SetString(PyExc_Exception, "initMNN: PyType_Ready PyMNNOpInfoType failed");
2164         ERROR_RETURN
2165     }
2166 #if PY_MAJOR_VERSION >= 3
2167     PyObject *m = PyModule_Create(&moduledef);
2168 #else
2169     PyObject *m = Py_InitModule3(MOD_NAME, module_methods, "MNN Module");
2170 #endif
2171     // module import failed!
2172     if (!m) {
2173         PyErr_SetString(PyExc_Exception, "initMNN: import MNN failed");
2174         ERROR_RETURN
2175     }
2176 #ifdef PYMNN_NUMPY_USABLE
2177     if(_import_array() < 0) {
2178         PyErr_SetString(PyExc_Exception, "initMNN: init numpy failed");
2179         ERROR_RETURN
2180     }
2181 #endif
2182 
2183     PyModule_AddObject(m, "Interpreter", (PyObject*)PyType_FindTLSType(&PyMNNInterpreterType));
2184     PyModule_AddObject(m, "Session", (PyObject*)PyType_FindTLSType(&PyMNNSessionType));
2185     PyModule_AddObject(m, "Tensor", (PyObject*)PyType_FindTLSType(&PyMNNTensorType));
2186     PyModule_AddObject(m, "CVImageProcess", (PyObject*)PyType_FindTLSType(&PyMNNCVImageProcessType));
2187     PyModule_AddObject(m, "CVMatrix", (PyObject*)PyType_FindTLSType(&PyMNNCVMatrixType));
2188     PyModule_AddObject(m, "OpInfo", (PyObject*)PyType_FindTLSType(&PyMNNOpInfoType));
2189 
2190     // Tensor::DimensionType
2191     PyObject *DimensionType_Tensorflow = PyLong_FromLong(Tensor::TENSORFLOW);
2192     PyObject *DimensionType_Caffe = PyLong_FromLong(Tensor::CAFFE);
2193     PyObject *DimensionType_Caffe_C4 = PyLong_FromLong(Tensor::CAFFE_C4);
2194     PyModule_AddObject(m, "Tensor_DimensionType_Tensorflow", DimensionType_Tensorflow);
2195     PyModule_AddObject(m, "Tensor_DimensionType_Caffe", DimensionType_Caffe);
2196     PyModule_AddObject(m, "Tensor_DimensionType_Caffe_C4", DimensionType_Caffe_C4);
2197 
2198     struct MNN_TLSData *tlsData = static_cast<MNN_TLSData *>(malloc(sizeof(MNN_TLSData)));
2199     setTLSData(tlsData);
2200     tlsData->interpreterMap = new std::unordered_map<std::string, Interpreter *>();
2201     tlsData->sessionCacheMap = new std::unordered_map<std::string, Session *>();
2202 
2203     // halide_type
2204     tlsData->PyMNNHalideTypeInt = PyCapsule_New(httInt(), NULL, NULL);
2205     tlsData->PyMNNHalideTypeInt64 = PyCapsule_New(httInt64(), NULL, NULL);
2206     tlsData->PyMNNHalideTypeFloat = PyCapsule_New(httFloat(), NULL, NULL);
2207     tlsData->PyMNNHalideTypeDouble = PyCapsule_New(httDouble(), NULL, NULL);
2208     tlsData->PyMNNHalideTypeUint8 = PyCapsule_New(httUint8(), NULL, NULL);
2209     tlsData->PyMNNHalideTypeString = PyCapsule_New(httString(), NULL, NULL);
2210 
2211 #if defined(PYMNN_USE_ALINNPYTHON) && defined(PYMNN_EXPR_API)
2212     struct py::detail::rh_tls *rh_tls = static_cast<py::detail::rh_tls *>(malloc(sizeof(py::detail::rh_tls)));
2213     if(nullptr == rh_tls) {
2214         throw runtime_error("rh_tls malloc fail");
2215     }
2216     set_rh_tls_data(rh_tls);
2217     rh_tls->internals_pp = nullptr;
2218     rh_tls->locals = new py::detail::type_map<py::detail::type_info *>;
2219 #endif
2220 
2221     PyModule_AddObject(m, "Halide_Type_Int", tlsData->PyMNNHalideTypeInt);
2222     PyModule_AddObject(m, "Halide_Type_Int64", tlsData->PyMNNHalideTypeInt64);
2223     PyModule_AddObject(m, "Halide_Type_Float", tlsData->PyMNNHalideTypeFloat);
2224     PyModule_AddObject(m, "Halide_Type_Double", tlsData->PyMNNHalideTypeDouble);
2225     PyModule_AddObject(m, "Halide_Type_Uint8", tlsData->PyMNNHalideTypeUint8);
2226     PyModule_AddObject(m, "Halide_Type_String", tlsData->PyMNNHalideTypeString);
2227 
2228     // CV
2229     // ImageFormat
2230     PyObject *CV_ImageFormat_RGBA = PyLong_FromLong(CV::RGBA);
2231     PyObject *CV_ImageFormat_RGB = PyLong_FromLong(CV::RGB);
2232     PyObject *CV_ImageFormat_BGR = PyLong_FromLong(CV::BGR);
2233     PyObject *CV_ImageFormat_GRAY = PyLong_FromLong(CV::GRAY);
2234     PyObject *CV_ImageFormat_BGRA = PyLong_FromLong(CV::BGRA);
2235     PyObject *CV_ImageFormat_YUV_NV21 = PyLong_FromLong(CV::YUV_NV21);
2236     PyModule_AddObject(m, "CV_ImageFormat_RGBA", CV_ImageFormat_RGBA);
2237     PyModule_AddObject(m, "CV_ImageFormat_RGB", CV_ImageFormat_RGB);
2238     PyModule_AddObject(m, "CV_ImageFormat_BGR", CV_ImageFormat_BGR);
2239     PyModule_AddObject(m, "CV_ImageFormat_GRAY", CV_ImageFormat_GRAY);
2240     PyModule_AddObject(m, "CV_ImageFormat_BGRA", CV_ImageFormat_BGRA);
2241     PyModule_AddObject(m, "CV_ImageFormat_YUV_NV21", CV_ImageFormat_YUV_NV21);
2242     // Filter
2243     PyObject *CV_Filter_NEAREST = PyLong_FromLong(CV::NEAREST);
2244     PyObject *CV_Filter_BILINEAL = PyLong_FromLong(CV::BILINEAR);
2245     PyObject *CV_Filter_BICUBIC = PyLong_FromLong(CV::BICUBIC);
2246     PyModule_AddObject(m, "CV_Filter_NEAREST", CV_Filter_NEAREST);
2247     PyModule_AddObject(m, "CV_Filter_BILINEAL", CV_Filter_BILINEAL);
2248     PyModule_AddObject(m, "CV_Filter_BICUBIC", CV_Filter_BICUBIC);
2249     // wrap
2250     PyObject *CV_Wrap_CLAMP_TO_EDGE = PyLong_FromLong(CV::CLAMP_TO_EDGE);
2251     PyObject *CV_Wrap_ZERO = PyLong_FromLong(CV::ZERO);
2252     PyObject *CV_Wrap_REPEAT = PyLong_FromLong(CV::REPEAT);
2253     PyModule_AddObject(m, "CV_Wrap_CLAMP_TO_EDGE", CV_Wrap_CLAMP_TO_EDGE);
2254     PyModule_AddObject(m, "CV_Wrap_ZERO", CV_Wrap_ZERO);
2255     PyModule_AddObject(m, "CV_Wrap_REPEAT", CV_Wrap_REPEAT);
2256 
2257     // static variable initialize
2258     interpreterMap();
2259     sessionCacheMap();
2260 
2261 #ifdef PYMNN_EXPR_API
2262     auto py_module = py::reinterpret_borrow<py::module>(m);
2263     INTS default_shape = {};
2264     auto expr_module = py_module.def_submodule("_expr");
2265     py::enum_<Dimensionformat> (expr_module, "data_format")
2266         .value("NHWC", NHWC)
2267         .value("NC4HW4", NC4HW4)
2268         .value("NCHW", NCHW)
2269         .export_values();
2270     py::enum_<DType> (expr_module, "dtype")
2271         .value("float", DType_FLOAT)
2272         .value("double", DType_DOUBLE)
2273         .value("int", DType_INT32)
2274         .value("int64", DType_INT64)
2275         .value("uint8", DType_UINT8)
2276         .export_values();
2277 
2278     py::enum_<PaddingMode> (expr_module, "Padding_Mode")
2279         .value("CAFFE", CAFFE)
2280         .value("VALID", VALID)
2281         .value("SAME", SAME)
2282         .export_values();
2283     py::enum_<MNN::Express::PadValueMode> (expr_module, "PadValue_Mode")
2284         .value("CONSTANT", CONSTANT)
2285         .value("REFLECT", REFLECT)
2286         .value("SYMMETRIC", SYMMETRIC)
2287         .export_values();
2288     py::enum_<PoolingMode> (expr_module, "Pooling_Mode")
2289         .value("MAXPOOL", MAXPOOL)
2290         .value("AVEPOOL", AVEPOOL)
2291         .export_values();
2292     py::enum_<InterpolationMethod> (expr_module, "Interp_Method")
2293         .value("BILINEAR", BILINEAR)
2294         .value("NEAREST", NEAREST)
2295         .export_values();
2296     py::class_<VARP>(expr_module, "Var")
2297         .def_property_readonly("shape",
2298 	    [](VARP *self){
2299             auto info = (*self)->getInfo();
2300             if(nullptr == info) {
2301                 throw std::runtime_error("unable to get variable info");
2302             }
2303             return info->dim;
2304 	    })
2305         .def_property_readonly("valid",
2306             [](VARP *self){
2307                 auto info = (*self)->getInfo();
2308                 if(nullptr == info) {
2309                     return false;
2310                 }
2311                 return true;
2312             })
2313         .def_property_readonly("data_format",
2314             [](VARP *self){
2315                 auto info = (*self)->getInfo();
2316                 if(nullptr == info)
2317                     throw std::runtime_error("unable to get variable info");
2318                 return info->order;
2319             })
2320         .def_property_readonly("dtype",
2321             [](VARP *self){
2322                 auto info = (*self)->getInfo();
2323                 if(nullptr == info)
2324                    throw std::runtime_error("unable to get variable info");
2325                 return htype2dtype(info->type);
2326             })
2327          .def_property_readonly("size",
2328             [](VARP *self){
2329                 auto info = (*self)->getInfo();
2330                 if(nullptr == info) {
2331                    throw std::runtime_error("unable to get variable info");
2332                 }
2333                 return info->size;
2334             })
2335         .def_property("name",
2336             [](VARP *self){
2337                 auto name = (*self)->name();
2338                 return name;
2339             },
2340             [] (VARP* self, std::string name) {
2341                 (*self)->setName(name);
2342             })
2343 #ifdef BUILD_OPTYPE
2344         .def_property_readonly("op_type",
2345             [](VARP *self){
2346                 auto op = (*self)->expr().first->get();
2347                 if (nullptr == op) {
2348                     switch ((*self)->expr().first->inputType()) {
2349                         case VARP::INPUT:
2350                             return std::string("Input");
2351                         case VARP::CONSTANT:
2352                             return std::string("Const");
2353                         case VARP::TRAINABLE:
2354                             return std::string("Trainable");
2355                     }
2356                 }
2357 
2358                 auto type = op->type();
2359                 if (type == OpType_BinaryOp) {
2360                     return std::string(MNN::EnumNameBinaryOpOperation((BinaryOpOperation)op->main_as_BinaryOp()->opType()));
2361                 }
2362                 if (type == OpType_UnaryOp) {
2363                     return std::string(MNN::EnumNameUnaryOpOperation((UnaryOpOperation)op->main_as_UnaryOp()->opType()));
2364                 }
2365                 return std::string(MNN::EnumNameOpType(type));
2366             })
2367 #endif
2368         .def_property_readonly("inputs",
2369             [] (VARP* self) {
2370                 return (*self)->expr().first->inputs();
2371             })
2372         .def("fix_as_placeholder",
2373             [] (VARP* self) {
2374                 (*self).fix(VARP::INPUT);
2375             })
2376 
2377         .def("fix_as_const",
2378             [] (VARP* self) {
2379                 (*self).fix(VARP::CONSTANT);
2380             })
2381         .def("fix_as_trainable",
2382             [] (VARP* self) {
2383                 (*self).fix(VARP::TRAINABLE);
2384             })
2385         .def("close",
2386             [] (VARP* self) {
2387                 (*self)->input(VARP(nullptr));
2388             })
2389         .def("copy_from",
2390             [] (VARP* self, VARP source) {
2391                 bool res = (*self)->input(source);
2392                 if (!res) {
2393                     throw std::runtime_error("Copy from source Error");
2394                 }
2395             })
2396         .def("set_inputs",
2397             [] (VARP* self, std::vector<VARP> source) {
2398                 if (source.empty()) {
2399                     throw std::runtime_error("Empty source");
2400                 }
2401                 auto expr = (*self)->expr();
2402                 auto newExpr = Expr::create(expr.first->extra(), std::move(source), expr.first->outputSize());
2403                 Expr::replace(expr.first, newExpr);
2404             })
2405         .def("replace",
2406             [] (VARP* self, VARP source) {
2407                 Variable::replace(*self, source);
2408             })
2409         .def("reorder",
2410             [] (VARP* self, Dimensionformat order) {
2411                 auto newInput = _ChangeInputFormat(*self, order);
2412                 (*self) = newInput;
2413             })
2414         .def("resize",
2415             [] (VARP* self, const std::vector<int>& shape) {
2416                 (*self)->resize(shape);
2417             })
2418 #ifdef PYMNN_NUMPY_USABLE
2419         .def("read",
2420             [](VARP *self){
2421                 auto info = (*self)->getInfo();
2422                 if(nullptr == info)
2423                    throw std::runtime_error("unable to get variable info");
2424                 auto dtype = htype2dtype(info->type);
2425                 auto shape = info->dim;
2426                 int64_t total_length = info->size;
2427                 auto readptr = [self](DType dtype, INTS shape, int64_t total_length) {
2428                     void *dataPtr = (void *) (*self)->readMap<void>();
                    if (nullptr == dataPtr) {
                        throw std::runtime_error("call to readMap failed");
                    }
2429                     std::vector<npy_intp> npy_dims;
2430                     for(const auto dim: shape) {
2431                         npy_dims.push_back(dim);
2432                     }
2433 
2434                     switch(dtype) {
2435                        case DType_FLOAT:
2436                            return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_FLOAT, dataPtr);
2437                        case DType_DOUBLE:
2438                            return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_DOUBLE, dataPtr);
2439                        case DType_INT32:
2440                            return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT32, dataPtr);
2441                        case DType_INT64:
2442                            return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_INT64, dataPtr);
2443                        case DType_UINT8:
2444                            return PyArray_SimpleNewFromData(npy_dims.size(), npy_dims.data(), NPY_UINT8, dataPtr);
2445                        default:
2446                           throw std::runtime_error("does not support this dtype");
2447                     }
2451                 };
2452                 auto data = readptr(dtype, shape, total_length);
2453                 (*self)->unMap();
2454                 return py::reinterpret_steal<py::object>(data);
2455             })
2456 #endif
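        // Note: "read" (above, numpy builds only) exposes the mapped buffer directly through
        // PyArray_SimpleNewFromData, while "read_as_tuple" below copies every element into a
        // plain Python tuple.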
2457         .def("read_as_tuple",
2458             [](VARP *self){
2459                 auto info = (*self)->getInfo();
2460                 if(nullptr == info)
2461                    throw std::runtime_error("unable to get variable info");
2462                 auto dtype = htype2dtype(info->type);
2463                 auto shape = info->dim;
2464                 size_t total_length = info->size;
2465                 auto readptr = [self](DType dtype, INTS shape, size_t total_length) {
2466                     void *dataPtr = (void *) (*self)->readMap<void>();
                         if (nullptr == dataPtr) {
                             throw std::runtime_error("call to readMap failed");
                         }
2467                     auto obj = PyTuple_New(total_length);
2468                     if(DType_FLOAT == dtype) {
2469                         auto data = (float*)dataPtr;
2470                         for(size_t i = 0; i < total_length; i++) {
2471                             PyTuple_SetItem(obj, i, PyFloat_FromDouble(data[i]));
2472                         }
2473                     } else if(DType_INT32 == dtype) {
2474                         auto data = (int32_t*)dataPtr;
2475                         for(size_t i = 0; i < total_length; i++) {
2476                             PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2477                         }
2478                     } else if(DType_UINT8 == dtype) {
2479                         auto data = (uint8_t*)dataPtr;
2480                         for(size_t i = 0; i < total_length; i++) {
2481                             PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2482                         }
2483                     } else if(DType_INT8 == dtype) {
2484                         auto data = (int8_t*)dataPtr;
2485                         for(size_t i = 0; i < total_length; i++) {
2486                             PyTuple_SetItem(obj, i, PyLong_FromLong(data[i]));
2487                         }
2488                     } else {
2489                         throw std::runtime_error("unsupported data type");
2490                     }
2491                     return obj;
2492                 };
2493                 auto data = readptr(dtype, shape, total_length);
2494                 (*self)->unMap();
2495                 return py::reinterpret_steal<py::object>(data);
2496             })
2497         .def("write",
2498             [](VARP *self, py::object data) {
2499                 auto info = (*self)->getInfo();
2500                 if(nullptr == info) {
2501                     throw std::runtime_error("unable to get variable info");
2502                 }
2503                 auto dtype = htype2dtype(info->type);
2504                 auto shape = info->dim;
2505                 int64_t total_length = info->size;
2506                 PyObject *obj = data.ptr();
2507                 auto write = [self](PyObject *obj, DType dtype, int64_t total_length) {
2508  #ifdef PYMNN_NUMPY_USABLE
2509                     if(PyArray_Check(obj)) {
2510                         //numpy support
2511                         if(total_length != PyArray_Size(obj)) {
2512                             throw std::runtime_error("data size does not match the variable's size");
2513                         }
2514                         int npy_type = PyArray_TYPE(obj);
2515                         int itemsize = getitemsize(dtype, npy_type);
2516                         PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
2517                         auto tmpBuffer = PyArray_DATA(obj_cont);
2518                         if(NULL == tmpBuffer) {
2519                             throw std::runtime_error("numpy failed to get buffer");
2520                         }
2521                         auto data = (*self)->writeMap<void>();
2522                         if (nullptr == data) {
2523                             throw std::runtime_error("call to writeMap failed");
2524                         }
2525                         memcpy(data, tmpBuffer, total_length * itemsize);
2526                         Py_XDECREF(obj_cont);
2527                         return;
2528                     }
2529 #endif
2530                     INTS shapeData = getshape(obj);
2531                     int64_t totalLengthData = 1;
2532                     INTS stride;
2533                     for (size_t i = 0; i < shapeData.size(); i++) {
2534                         totalLengthData *= shapeData[i];
2535                     }
2536                     int totalStride = 1;
2537                     for (int i = shapeData.size() - 1; i >= 0; i--) {
2538                        if(i + 1 < shapeData.size()) {
2539                            totalStride *= shapeData[i+1];
2540                        }
2541                        stride.push_back(totalStride);
2542                     }
2543                     std::reverse(stride.begin(), stride.end());
2544                     if(totalLengthData != total_length) {
2545                         throw std::runtime_error("data size does not match the variable's size");
2546                     }
2547                     if(DType_FLOAT == dtype) {
2548                         auto data = (*self)->writeMap<float>();
2549                         if (nullptr == data) {
2550                             throw std::runtime_error("call to writeMap failed");
2551                         }
2552                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(float));
2553                     }
2554                     else if(DType_INT32 == dtype) {
2555                         auto data = (*self)->writeMap<int>();
2556                         if (nullptr == data) {
2557                             throw std::runtime_error("call to writeMap failed");
2558                         }
2559                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int));
2560                     }
2561                     else if(DType_UINT8 == dtype) {
2562                         auto data = (*self)->writeMap<uint8_t>();
2563                         if (nullptr == data) {
2564                             throw std::runtime_error("call to writeMap failed");
2565                         }
2566                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(uint8_t));
2567                     }
2568                     else if(DType_INT8 == dtype) {
2569                         auto data = (*self)->writeMap<int8_t>();
2570                         if (nullptr == data) {
2571                             throw std::runtime_error("call to writeMap failed");
2572                         }
2573                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int8_t));
2574                     }
                         else {
                             throw std::runtime_error("unsupported dtype for write");
                         }
2575                 };
2576                 write(obj, dtype, total_length);
2577                 (*self)->unMap();
2578             });
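    // Illustrative Python usage of the read/write bindings above. This is a sketch only: it
    // assumes the expr module is exposed to Python as MNN.expr (imported here as F), and the
    // values are placeholders.
    //   import MNN.expr as F
    //   x = F.placeholder([2, 2])
    //   x.write([[1.0, 2.0], [3.0, 4.0]])
    //   print(x.read_as_tuple())   # read() returns a numpy array when numpy support is built in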
2579     // Load And Save
2580     expr_module.def("load_as_list",
2581     		[](std::string fileName) {
2582                 auto variable = Variable::load(fileName.c_str());
2583 			    return variable;
2584     });
2585     expr_module.def("save",
2586     		[](const std::vector<VARP>& vars, std::string fileName, bool forInference = true) {
2587                 std::vector<VARP> newVars;
2588                 for (auto v : vars) {
2589                     if (v.get() != nullptr) {
2590                         newVars.emplace_back(v);
2591                     }
2592                 }
2593 #ifdef PYMNN_TRAIN_API
2594                 if (forInference) {
2595                     Transformer::turnModelToInfer()->onExecute(newVars);
2596                 }
2597 #endif
2598                 Variable::save(newVars, fileName.c_str());
2599 #ifdef PYMNN_TRAIN_API
2600                 ConvertToFullQuant::convert(fileName);
2601 #endif
2602     }, py::arg("variables"), py::arg("file_name"), py::arg("for_inference") = true);
2603     expr_module.def("load_as_dict",
2604     		[](std::string fileName) {
2605                 auto variable = Variable::loadMap(fileName.c_str());
2606 			    return variable;
2607     });
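    // Illustrative Python usage of the save/load helpers above (F = MNN.expr as assumed earlier;
    // the file name and variables are placeholders):
    //   F.save([output_var], "model.mnn")
    //   vars_in_order = F.load_as_list("model.mnn")
    //   vars_by_name = F.load_as_dict("model.mnn")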
2608     expr_module.def("get_inputs_and_outputs", &Variable::getInputAndOutput);
2609     // Executor
2610     expr_module.def("gc", [](bool full) {
2611         auto exe = Executor::getGlobalExecutor();
2612         if (full) {
2613             exe->gc(Executor::FULL);
2614         } else {
2615             exe->gc(Executor::PART);
2616         }
2617     });
2618 
2619     // All backend types are exposed regardless of whether the MNN library was built with support for them,
2620     // which keeps the MNN.whl setup.py simple.
2621     py::enum_<MNNForwardType>(expr_module, "Backend")
2622         .value("CPU", MNN_FORWARD_CPU)
2623         .value("OPENCL", MNN_FORWARD_OPENCL)
2624         .value("OPENGL", MNN_FORWARD_OPENGL)
2625         .value("VULKAN", MNN_FORWARD_VULKAN)
2626         .value("METAL", MNN_FORWARD_METAL)
2627         .value("TRT", MNN_FORWARD_USER_1)
2628         .value("CUDA", MNN_FORWARD_CUDA)
2629         .value("HIAI", MNN_FORWARD_USER_0)
2630         .export_values();
2631 
2632     using MemoryMode = BackendConfig::MemoryMode;
2633     using PowerMode = BackendConfig::PowerMode;
2634     using PrecisionMode = BackendConfig::PrecisionMode;
2635     py::enum_<MemoryMode>(expr_module, "MemoryMode")
2636         .value("Normal", MemoryMode::Memory_Normal)
2637         .value("High", MemoryMode::Memory_High)
2638         .value("Low", MemoryMode::Memory_Low)
2639         .export_values();
2640     py::enum_<PowerMode>(expr_module, "PowerMode")
2641         .value("Normal", PowerMode::Power_Normal)
2642         .value("High", PowerMode::Power_High)
2643         .value("Low", PowerMode::Power_Low)
2644         .export_values();
2645     py::enum_<PrecisionMode>(expr_module, "PrecisionMode")
2646         .value("Normal", PrecisionMode::Precision_Normal)
2647         .value("High", PrecisionMode::Precision_High)
2648         .value("Low", PrecisionMode::Precision_Low)
2649         .export_values();
2650     expr_module.def("set_config",
2651     		[](MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode, PrecisionMode precision_mode, int thread_num) {
2652                 if (thread_num < 1 || thread_num > 8) {
2653                     PyErr_SetString(PyExc_Exception, "thread_num must be between 1 and 8");
2654                 }
2655                 thread_num = std::max(std::min(thread_num, 8), 1);
2656                 //auto exe = ExecutorScope::Current();
2657                 auto exe = Executor::getGlobalExecutor();
2658                 BackendConfig config;
2659                 config.memory = memory_mode;
2660                 config.power = power_mode;
2661                 config.precision = precision_mode;
2662                 exe->setGlobalExecutorConfig(backend, config, thread_num);
2663             },
2664             py::arg("backend")=MNN_FORWARD_CPU, py::arg("memory_mode")=MemoryMode::Memory_Normal,
2665             py::arg("power_mode")=PowerMode::Power_Normal, py::arg("precision_mode")=PrecisionMode::Precision_Normal,
2666             py::arg("thread_num")=1);
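    // Illustrative Python usage of set_config with the enums registered above (F = MNN.expr as
    // assumed earlier):
    //   F.set_config(backend=F.Backend.CPU, precision_mode=F.PrecisionMode.Low, thread_num=4)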
2667 
2668     //Begin of Math OPS
2669     //Unary OPS
2670     expr_module.def("sign", &Express::_Sign);
2671     expr_module.def("abs", &Express::_Abs);
2672     expr_module.def("negative", &Express::_Negative);
2673     expr_module.def("floor", &Express::_Floor);
2674     expr_module.def("ceil", &Express::_Ceil);
2675     expr_module.def("square", &Express::_Square);
2676     expr_module.def("sqrt", &Express::_Sqrt);
2677     expr_module.def("rsqrt", &Express::_Rsqrt);
2678     expr_module.def("exp", &Express::_Exp);
2679     expr_module.def("log", &Express::_Log);
2680     expr_module.def("sin", &Express::_Sin);
2681     expr_module.def("cos", &Express::_Cos);
2682     expr_module.def("tan", &Express::_Tan);
2683     expr_module.def("asin", &Express::_Asin);
2684     expr_module.def("acos", &Express::_Acos);
2685     expr_module.def("atan", &Express::_Atan);
2686     expr_module.def("reciprocal", &Express::_Reciprocal);
2687     expr_module.def("log1p", &Express::_Log1p);
2688     expr_module.def("tanh", &Express::_Tanh);
2689     expr_module.def("sigmoid", &Express::_Sigmoid);
2690     //Binary OPS
2691     expr_module.def("add", &Express::_Add);
2692     expr_module.def("subtract", &Express::_Subtract);
2693     expr_module.def("multiply", &Express::_Multiply);
2694     expr_module.def("divide", &Express::_Divide);
2695     expr_module.def("pow", &Express::_Pow);
2696     expr_module.def("minimum", &Express::_Minimum);
2697     expr_module.def("maximum", &Express::_Maximum);
2698     expr_module.def("bias_add", &Express::_BiasAdd);
2699     expr_module.def("greater", &Express::_Greater);
2700     expr_module.def("greater_equal", &Express::_GreaterEqual);
2701     expr_module.def("less", &Express::_Less);
2702     expr_module.def("floordiv", &Express::_FloorDiv);
2703     expr_module.def("squared_difference", &Express::_SquaredDifference);
2704     expr_module.def("equal", &Express::_Equal);
2705     expr_module.def("not_equal", &Express::_NotEqual);
2706     expr_module.def("less_equal", &Express::_LessEqual);
2707     expr_module.def("floormod", &Express::_FloorMod);
2708     //Reduce OPS
2709     expr_module.def("reduce_sum",
2710                     [](VARP input, INTS axis, bool keep_dims) {
2711                         return _ReduceSum(input, axis, keep_dims);
2712                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2713     expr_module.def("reduce_mean",
2714                     [](VARP input, INTS axis, bool keep_dims) {
2715                         return _ReduceMean(input, axis, keep_dims);
2716                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2717     expr_module.def("reduce_max",
2718                     [](VARP input, INTS axis, bool keep_dims) {
2719                         return _ReduceMax(input, axis, keep_dims);
2720                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2721     expr_module.def("reduce_min",
2722                     [](VARP input, INTS axis, bool keep_dims) {
2723                         return _ReduceMin(input, axis, keep_dims);
2724                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2725     expr_module.def("reduce_prod",
2726                     [](VARP input, INTS axis, bool keep_dims) {
2727                         return _ReduceProd(input, axis, keep_dims);
2728                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2729     expr_module.def("reduce_any",
2730                     [](VARP input, INTS axis, bool keep_dims) {
2731                         return _ReduceAny(input, axis, keep_dims);
2732                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
2733     expr_module.def("reduce_all",
2734                     [](VARP input, INTS axis, bool keep_dims) {
2735                         return _ReduceAll(input, axis, keep_dims);
2736                     }, py::arg("input"), py::arg("axis")=default_shape, py::arg("keep_dims")=false);
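    // Illustrative Python usage of the reduce bindings above (F = MNN.expr as assumed earlier;
    // x is a placeholder variable):
    //   s = F.reduce_sum(x, axis=[0], keep_dims=False)
    //   m = F.reduce_mean(x, axis=[1, 2], keep_dims=True)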
2737     //Eltwise OPS
2738     expr_module.def("eltwise_prod", &Express::_Prod);
2739     expr_module.def("eltwise_sum", &Express::_Sum);
2740     expr_module.def("eltwise_max", &Express::_Max);
2741     expr_module.def("eltwise_sub", &Express::_Sub);
2742     //Other OPS
2743     expr_module.def("cast",
2744 		    [](VARP x, DType dtype) {
2745 			return _Cast(x, dtype2htype(dtype));
2746                     });
2747     expr_module.def("matmul", &Express::_MatMul, py::arg("a"), py::arg("b"), py::arg("tranposeA")=false, py::arg("tranposeB")=false);
2748     expr_module.def("normalize", &Express::_Normalize);
2749     expr_module.def("argmax",
2750 		   [](VARP input, int axis) {
2751 			return _ArgMax(input, axis);
2752                    }, py::arg("input"), py::arg("axis")=0);
2753     expr_module.def("unravel_index", &Express::_UnravelIndex, py::arg("indices"), py::arg("dims"));
2754     expr_module.def("scatter_nd", &Express::_ScatterNd, py::arg("indices"), py::arg("updates"), py::arg("shape"));
2755     expr_module.def("one_hot",
2756 		   [](VARP indices, int depth, float onValue, float offValue, int axis) {
2757 			return _OneHot(indices, _Scalar<int>(depth), _Scalar<float>(onValue), _Scalar<float>(offValue), axis);
2758                    },py::arg("indices"), py::arg("depth"), py::arg("on_value")=1, py::arg("off_value")=0, py::arg("axis")=-1);
2759     expr_module.def("broadcast_to", &Express::_BroadcastTo, py::arg("input"), py::arg("shape"));
2760     //End of Math OPS
2761 
2762     //Begin of NN OPS
2763     expr_module.def("placeholder",
2764                   [](INTS shape,Dimensionformat data_format, DType dtype)->VARP{
2765     			return _Input(shape, data_format, dtype2htype(dtype));
2766                   },
2767                   py::arg("shape")=default_shape,
2768                   py::arg("data_format")=NCHW,
2769                   py::arg("dtype")=DType_FLOAT);
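    // Illustrative Python usage of placeholder (F = MNN.expr as assumed earlier; input_data is a
    // placeholder for a real nested list or numpy array):
    //   x = F.placeholder([1, 3, 224, 224])
    //   x.write(input_data)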
2770     expr_module.def("clone",
2771                    [](VARP source, bool deepCopy) {
2772 			return _Clone(source, deepCopy);
2773                    }, py::arg("source"), py::arg("deep_copy")=false);
2774     INTS default_pads = {0, 0};
2775     INTS default_axis = {};
2776     expr_module.def("const",
2777             [](py::object value, INTS shape, Dimensionformat data_format, DType dtype) {
2778                 int64_t total_length = 1;
2779                 for(size_t i = 0; i < shape.size(); i++) {
2780                     if (data_format == NC4HW4 && 1 == i)
2781                     {
2782 #ifndef ROUND_UP
2783 #define ROUND_UP(x, y) (((x) + (y) - (1)) / (y) * (y))
2784 #endif
2785                         total_length *= ROUND_UP(shape[i], 4);
2786                     }
2787                     else
2788                     {
2789                         total_length *= shape[i];
2790                     }
2791                 }
2792                 PyObject *obj = value.ptr();
2793                 auto write = [](PyObject *obj, DType dtype, int64_t total_length) {
2794  #ifdef PYMNN_NUMPY_USABLE
2795                     if(PyArray_Check(obj)) {
2796                         //numpy support
2797                         if(total_length != PyArray_Size(obj)) {
2798                             throw std::runtime_error("data size does not match the declared shape");
2799                         }
2800                         int npy_type = PyArray_TYPE(obj);
2801                         int itemsize = getitemsize(dtype, npy_type);
2802                         PyArrayObject *obj_cont= PyArray_GETCONTIGUOUS((PyArrayObject*)obj);
2803                         auto tmpBuffer = PyArray_DATA(obj_cont);
2804                         if(NULL == tmpBuffer) {
2805                             throw std::runtime_error("numpy failed to get buffer");
2806                         }
2807                         auto data = malloc(total_length * itemsize);
2808                         if (nullptr == data) {
2809                             throw std::runtime_error("not enough memory");
2810                         }
2811                         memcpy(data, tmpBuffer, total_length * itemsize);
2812                         Py_XDECREF(obj_cont);
2813                         return data;
2814                     }
2815 #endif
2816                     INTS shapeData = getshape(obj);
2817                     int64_t totalLengthData = 1;
2818                     INTS stride;
2819                     for(size_t i = 0; i < shapeData.size(); i++) {
2820                         totalLengthData *= shapeData[i];
2821                     }
2822                     int totalStride = 1;
2823                     for (int i = shapeData.size() - 1; i >= 0; i--) {
2824                        if (i + 1 < shapeData.size()) {
2825                            totalStride *= shapeData[i+1];
2826                        }
2827                        stride.push_back(totalStride);
2828                     }
2829                     std::reverse(stride.begin(), stride.end());
2830                     if(totalLengthData != total_length) {
2831                         throw std::runtime_error("data size does not match the declared shape");
2832                     }
2833                     void *data = nullptr;
2834                     if(DType_FLOAT == dtype) {
2835                         data = malloc(total_length * sizeof(float));
2836                         if (nullptr == data) {
2837                             throw std::runtime_error("not enough memory");
2838                         }
2839                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(float));
2840                     }
2841                     else if(DType_INT32 == dtype) {
2842                         data = malloc(total_length * sizeof(int));
2843                         if (nullptr == data) {
2844                             throw std::runtime_error("not enough memory");
2845                         }
2846                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int));
2847                     }
2848                     else if(DType_UINT8 == dtype) {
2849                         data = malloc(total_length * sizeof(uint8_t));
2850                         if (nullptr == data) {
2851                             throw std::runtime_error("not enough memory");
2852                         }
2853                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(uint8_t));
2854                     }
2855                     else if(DType_INT8 == dtype) {
2856                         data = malloc(total_length * sizeof(int8_t));
2857                         if (nullptr == data) {
2858                             throw std::runtime_error("not enough memory");
2859                         }
2860                         recursive_store((char*)data, shapeData, stride, 0, obj, dtype, sizeof(int8_t));
2861                     }
2862                     return data;
2863                 };
2864                 auto data = write(obj, dtype, total_length);
2865                 VARP ret = nullptr;
2866                 if(data) {
2867                     ret = _Const((const void*)data, shape, data_format, dtype2htype(dtype));
2868                     free(data);
2869                 }
2870                 return ret;
2871             },py::arg("value_list"), py::arg("shape"), py::arg("data_format")=NCHW, py::arg("dtype")=DType::DType_FLOAT);
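    // Illustrative Python usage of const (F = MNN.expr as assumed earlier). value_list may be a
    // nested list or, when numpy support is built in, a numpy array; data_format and dtype fall
    // back to the NCHW/float defaults declared above:
    //   w = F.const([[1.0, 2.0], [3.0, 4.0]], [2, 2])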
2872     INTS default_stride = {1, 1};
2873     INTS default_dilate = {1, 1};
2874     expr_module.def("conv2d",
2875             [](VARP input, VARP weight, VARP bias, INTS stride, INTS padding, INTS dilate, int group, PaddingMode padding_mode) {
2876                 return _Conv(weight, bias, input, padding_mode, stride, dilate, group, padding);
2877             },py::arg("input"), py::arg("weight"), py::arg("bias"),
2878             py::arg("stride")=default_stride,
2879             py::arg("padding")=default_pads,
2880             py::arg("dilate")=default_dilate,
2881             py::arg("group")=1,
2882             py::arg("padding_mode")=VALID);
2883     expr_module.def("conv2d_transpose",
2884             [](VARP input, VARP weight, VARP bias, INTS stride, INTS padding, INTS dilate, int group, PaddingMode padding_mode) {
2885                 return _Deconv(weight, bias, input, padding_mode, stride, dilate, group, padding);
2886             },py::arg("input"), py::arg("weight"), py::arg("bias"),
2887             py::arg("stride")=default_stride,
2888             py::arg("padding")=default_pads,
2889             py::arg("dilate")=default_dilate,
2890             py::arg("group")=1,
2891             py::arg("padding_mode")=VALID);
2892     expr_module.def("max_pool",
2893                    [](VARP x, INTS kernel, INTS stride, PaddingMode pad, INTS pads) {
2894                         return _MaxPool(x, kernel, stride, pad, pads);
2895                    }, py::arg("input"), py::arg("kernel"), py::arg("stride"),
2896 		   py::arg("padding_mode")=VALID,
2897 		   py::arg("pads")=default_pads);
2898     expr_module.def("avg_pool",
2899                    [](VARP x, INTS kernel, INTS stride, PaddingMode pad, INTS pads) {
2900                         return _AvePool(x, kernel, stride, pad, pads);
2901                    }, py::arg("input"), py::arg("kernel"), py::arg("stride"),
2902                    py::arg("padding_mode")=VALID,
2903                    py::arg("pads")=default_pads);
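    // Illustrative Python usage of the conv/pool bindings above (F = MNN.expr as assumed earlier;
    // x, weight and bias are placeholder variables, e.g. built with F.const):
    //   y = F.conv2d(x, weight, bias, stride=[1, 1])
    //   y = F.max_pool(y, kernel=[2, 2], stride=[2, 2])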
2904     expr_module.def("reshape",
2905                    [](VARP x, INTS shape, Dimensionformat original_format) {
2906                         return _Reshape(x, shape, original_format);
2907                    }, py::arg("x"), py::arg("shape"), py::arg("original_format")=NCHW);
2908     expr_module.def("reshape",
2909                    [](VARP x, VARP shape) {
2910                         return _Reshape(x, shape);
2911                    });
2912     expr_module.def("scale", &Express::_Scale, py::arg("x"), py::arg("channels"), py::arg("scales"), py::arg("bias"));
2913     expr_module.def("relu",
2914                    [](VARP x, float slope) {
2915                         return _Relu(x, slope);
2916                    }, py::arg("x"), py::arg("slope")=0.0f);
2917     expr_module.def("relu6", &Express::_Relu6, py::arg("x"), py::arg("min") = 0.0f, py::arg("max") = 6.0f);
2918     expr_module.def("prelu", &Express::_PRelu, py::arg("x"), py::arg("slopes"));
2919     expr_module.def("softmax",
2920                    [](VARP logits, int axis) {
2921                         return _Softmax(logits, axis);
2922                    }, py::arg("logits"), py::arg("axis")=-1);
2923     expr_module.def("softplus", &Express::_Softplus, py::arg("features"));
2924     expr_module.def("softsign", &Express::_Softsign, py::arg("features"));
2925     expr_module.def("slice", &Express::_Slice, py::arg("input"), py::arg("starts"), py::arg("sizes"));
2926     expr_module.def("split", &Express::_Split, py::arg("input"), py::arg("size_splits"), py::arg("axis"));
2927     expr_module.def("strided_slice", &Express::_StridedSlice, py::arg("input"), py::arg("begin"), py::arg("end"),
2928         py::arg("strides"), py::arg("begin_mask"), py::arg("end_mask"), py::arg("ellipsis_mask"), py::arg("new_axis_mask"), py::arg("shrink_axis_mask"));
2929     expr_module.def("concat", &Express::_Concat, py::arg("values"), py::arg("axis"));
2930     expr_module.def("convert", &Express::_Convert, py::arg("input"), py::arg("format"));
2931     expr_module.def("transpose",
2932                    [](VARP x, INTS perm) {
2933                         return _Transpose(x, perm);
2934                    }, py::arg("x"), py::arg("perm"));
2935     expr_module.def("transpose",
2936                    [](VARP x, VARP perm) {
2937                         return _Transpose(x, perm);
2938                    });
2939     expr_module.def("channel_shuffle", &Express::_ChannelShuffle);
2940     // change_inputformat not exposed because it's for static graphs.
2941     //expr_module.def("change_inputformat", &Express::_ChangeInputFormat);
2942     //
2943     expr_module.def("reverse_sequence", &Express::_ReverseSequence, py::arg("x"), py::arg("y"), py::arg("batch_dim"), py::arg("seq_dim"));
2944     expr_module.def("crop", &Express::_Crop, py::arg("images"), py::arg("size"), py::arg("axis"), py::arg("offset"));
2945     expr_module.def("resize", &Express::_Resize, py::arg("images"), py::arg("x_scale"), py::arg("y_scale"));
2946     expr_module.def("pad",
2947                    [](VARP x, VARP paddings, MNN::Express::PadValueMode mode) {
2948                         return Express::_Pad(x, paddings, mode);
2949                    }, py::arg("x"), py::arg("paddings"), py::arg("mode")=CONSTANT);
2950     expr_module.def("expand_dims",
2951                    [](VARP input, int axis) {
2952                         return _ExpandDims(input, axis);
2953                    });
2954     expr_module.def("expand_dims",
2955                    [](VARP input, VARP axis) {
2956                         return _ExpandDims(input, axis);
2957                    });
2958     expr_module.def("shape",
2959                     [](VARP input) {
2960                         return Express::_Shape(input, false);
2961                     }, py::arg("input"));
2962     expr_module.def("stack",
2963                    [](VARPS values, int axis) {
2964                         return _Stack(values, axis);
2965  		   }, py::arg("values"), py::arg("axis")=0);
2966     expr_module.def("crop_and_resize",
2967                    [](VARP image, VARP boxes, VARP box_ind, VARP crop_size, InterpolationMethod method, float extrapolation_value) {
2968                         return _CropAndResize(image, boxes, box_ind, crop_size, method, extrapolation_value);
2969                    }, py::arg("image"), py::arg("boxes"), py::arg("box_ind"), py::arg("crop_size"),
2970 		   py::arg("method")=BILINEAR, py::arg("extrapolation_value")=0.0f);
2971     expr_module.def("fill", &Express::_Fill, py::arg("dims"), py::arg("value"));
2972     expr_module.def("tile", &Express::_Tile, py::arg("input"), py::arg("multiples"));
2973     expr_module.def("gather", &Express::_Gather, py::arg("params"), py::arg("indices"));
2974     expr_module.def("select", &Express::_Select);
2975 
2976     // Currently only axis == 0 is supported, which is the same as gather.
2977     /*
2978     expr_module.def("gather_v2",
2979                    [](VARP params, VARP indices, VARP axis = nullptr) {
2980                         return _GatherV2(params, indices, axis);
2981                    }, py::arg("params"), py::arg("indices"), py::arg("axis")=nullptr);
2982                    */
2983 
2984     expr_module.def("squeeze",
2985                    [](VARP input, INTS axis) {
2986                         return _Squeeze(input, axis);
2987                    }, py::arg("input"), py::arg("axis")=default_axis);
2988     expr_module.def("unsqueeze",
2989                    [](VARP input, INTS axis) {
2990                         return _Unsqueeze(input, axis);
2991                    }, py::arg("input"), py::arg("axis")=default_axis);
2992     expr_module.def("batch_to_space_nd", &Express::_BatchToSpaceND, py::arg("input"), py::arg("block_shape"), py::arg("crops"));
2993     expr_module.def("gather_nd", &Express::_GatherND, py::arg("params"), py::arg("indices"));
2994     expr_module.def("selu", &Express::_Selu, py::arg("features"), py::arg("scale"), py::arg("alpha"));
2995     expr_module.def("size", &Express::_Size, py::arg("input"));
2996     expr_module.def("elu",
2997                    [](VARP features, float alpha) {
2998                         return _Elu(features, alpha);
2999                    }, py::arg("features"), py::arg("alpha")=1.0);
3000     expr_module.def("matrix_band_part", &Express::_MatrixBandPart, py::arg("input"), py::arg("num_lower"), py::arg("num_upper"));
3001     expr_module.def("moments", &Express::_Moments, py::arg("x"), py::arg("axes"), py::arg("shift"), py::arg("keep_dims"));
3002     expr_module.def("setdiff1d", &Express::_SetDiff1D, py::arg("x"), py::arg("y"));
3003     expr_module.def("space_to_depth", &Express::_SpaceToDepth, py::arg("input"), py::arg("block_size"));
3004     expr_module.def("space_to_batch_nd", &Express::_SpaceToBatchND, py::arg("input"), py::arg("block_shape"), py::arg("paddings"));
3005     expr_module.def("zeros_like", &Express::_ZerosLike, py::arg("input"));
3006     expr_module.def("unstack",
3007                    [](VARP value, int axis) {
3008                         return _Unstack(value, axis);
3009                    }, py::arg("value"), py::arg("axis")=0);
3010     expr_module.def("rank", &Express::_Rank, py::arg("input"));
3011     expr_module.def("range", &Express::_Range, py::arg("start"), py::arg("limit"), py::arg("delta"));
3012     expr_module.def("depth_to_space", &Express::_DepthToSpace, py::arg("input"), py::arg("block_size"));
3013     expr_module.def("detection_post_process", &Express::_DetectionPostProcess,
3014                    py::arg("encode_boxes"), py::arg("class_predictions"), py::arg("anchors"),
3015                    py::arg("num_classes"), py::arg("max_detections"), py::arg("max_class_per_detection"),
3016                    py::arg("detections_per_class"), py::arg("nms_threshold"), py::arg("iou_threshold"),
3017                    py::arg("use_regular_nms"), py::arg("centersize_encoding"));
3018     //End of NN OPS
3019     auto cv_module = py_module.def_submodule("cv");
3020     py::enum_<CV::ImageFormat>(cv_module, "Format")
3021         .value("RGBA", CV::RGBA)
3022         .value("RGB", CV::RGB)
3023         .value("GRAY", CV::GRAY)
3024         .value("BGR", CV::BGR)
3025         .value("YUV_NV21", CV::YUV_NV21)
3026         .value("YUV_NV12", CV::YUV_NV12)
3027         .export_values();
3028 
3029     auto nn_module = py_module.def_submodule("_nn");
3030 
3031     class PyModule : public Module {
3032     public:
3033         using Module::Module;
3034         using Module::registerModel;
3035 
3036         virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
3037             PYBIND11_OVERLOAD_PURE(std::vector<Express::VARP>, Module, forward, inputs);
3038         }
3039     };
3040     py::class_<Module, PyModule, std::shared_ptr<Module>>(nn_module, "_Module")
3041         .def(py::init())
3042         .def("__call__", &Module::forward)
3043         .def("__call__", &Module::onForward)
3044         .def("forward", &Module::forward)
3045         .def("forward", &Module::onForward)
3046         .def_property_readonly("name", &Module::name) // TODO: too ugly, find way to fix it
3047         .def("set_name", &Module::setName)
3048         .def_property_readonly("is_training", &Module::getIsTraining)
3049         .def("train", &Module::setIsTraining, py::arg("is_training") = true)
3050         .def_property_readonly("parameters", &Module::parameters)
3051         .def("load_parameters", &Module::loadParameters)
3052         .def("clear_cache", &Module::clearCache)
3053         .def("_register_submodules", &PyModule::registerModel)
3054         .def("_add_parameter", &Module::addParameter);
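    // Illustrative Python usage of the module bindings above. This is a sketch only: it assumes
    // the pure-Python layer re-exports _Module as MNN.nn.Module (imported here as nn) and that
    // forward dispatch is handled by that wrapper; F = MNN.expr as assumed earlier.
    //   class Net(nn.Module):
    //       def __init__(self):
    //           super().__init__()
    //       def forward(self, x):
    //           return F.relu(x)
    //   net = Net()
    //   net.train(True)        # toggles is_training
    //   print(net.parameters)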
3055 
3056     nn_module.def("load_module", [](vector<VARP> inputs, vector<VARP> outputs, bool fortrain){
3057 #ifdef PYMNN_TRAIN_API
3058         return NN::extract(inputs, outputs, fortrain);
3059 #else
3060         return Module::extract(inputs, outputs, fortrain);
3061 #endif
3062     });
3063     nn_module.def("load_module_from_file", [](const vector<string>& inputs, const vector<string>& outputs,
3064                                               const char* file_name, bool dynamic, bool shape_mutable, bool rearrange,
3065                                               MNNForwardType backend, MemoryMode memory_mode, PowerMode power_mode,
3066                                               PrecisionMode precision_mode, int thread_num) -> Module* {
3067         BackendConfig backend_config;
3068         backend_config.memory = memory_mode;
3069         backend_config.power = power_mode;
3070         backend_config.precision = precision_mode;
3071 
3072         Module::BackendInfo backend_info;
3073         backend_info.type = backend;
3074         backend_info.config = &backend_config;
3075 
3076         Module::Config config;
3077         config.dynamic = dynamic;
3078         config.shapeMutable = shape_mutable;
3079         config.rearrange = rearrange;
3080         config.backend = &backend_info;
3081 
3082         auto converted_file_name = convertBytesEncodeIfNeed(file_name);
3083         auto m_ptr = Module::load(inputs, outputs, converted_file_name.data(), &config);
3084         if (m_ptr == nullptr) {
3085             std::string mnn_errno = "load_module_from_file failed: ";
3086             mnn_errno = mnn_errno + std::string(file_name);
3087             PyErr_SetString(PyExc_Exception, mnn_errno.c_str());
3088         }
3089         return m_ptr;
3090     });
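    // Illustrative Python usage of load_module_from_file (nn = the Python-facing module assumed
    // to re-export _nn; tensor names and the file path are placeholders; arguments are positional
    // because no py::arg names are declared above):
    //   m = nn.load_module_from_file(["input"], ["output"], "model.mnn", False, True, False,
    //                                F.Backend.CPU, F.MemoryMode.Normal, F.PowerMode.Normal,
    //                                F.PrecisionMode.Normal, 4)
    //   output_var = m.forward(input_var)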
3091 
3092 #ifdef PYMNN_TRAIN_API
3093     // CNN
3094     nn_module.def("conv", [](int in_channel, int out_channel, INTS kernel_size, INTS stride, INTS padding,
3095                              INTS dilation, bool depthwise, bool bias, PaddingMode padding_mode) {
3096             NN::ConvOption option;
3097             option.channel = {in_channel, out_channel};
3098             option.kernelSize = kernel_size;
3099             if (!stride.empty()) {
3100                 option.stride = stride;
3101             }
3102             option.padMode = padding_mode;
3103             if (!padding.empty()) {
3104                 option.pads = padding;
3105             }
3106             if (!dilation.empty()) {
3107                 option.dilate = dilation;
3108             }
3109             option.depthwise = depthwise;
3110             return NN::Conv(std::move(option), bias);
3111         },
3112         py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_size"),
3113         py::arg("stride") = std::vector<int>({1, 1}),
3114         py::arg("padding") = std::vector<int>({0, 0}),
3115         py::arg("dilation") = std::vector<int>({1, 1}),
3116         py::arg("depthwise") = false,
3117         py::arg("bias") = true,
3118         py::arg("padding_mode") = PaddingMode::VALID
3119     );
3120     nn_module.def("linear", [](int in_channel, int out_channel, bool bias) {
3121             return NN::Linear(in_channel, out_channel, bias);
3122         },
3123         py::arg("in_channels"),
3124         py::arg("out_channels"),
3125         py::arg("bias") = true
3126     );
3127     nn_module.def("batch_norm", &NN::BatchNorm, py::arg("channels"), py::arg("dims") = 4, py::arg("momentum") = 0.99, py::arg("epsilon") = 1e-5);
3128     nn_module.def("dropout", &NN::Dropout, py::arg("dropout_ratio"));
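    // Illustrative Python usage of the layer factories above (train build only; nn as assumed
    // earlier):
    //   conv1 = nn.conv(3, 16, [3, 3], stride=[1, 1], padding=[1, 1])
    //   fc = nn.linear(256, 10)
    //   bn = nn.batch_norm(16)
    //   drop = nn.dropout(0.5)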
3129 
3130     auto optim_module = py_module.def_submodule("_optim");
3131 
3132     {
3133         py::enum_<ParameterOptimizer::RegularizationMethod>(optim_module, "Regularization_Method")
3134             .value("L1", ParameterOptimizer::RegularizationMethod::L1)
3135             .value("L2", ParameterOptimizer::RegularizationMethod::L2)
3136             .value("L1L2", ParameterOptimizer::RegularizationMethod::L1L2)
3137             .export_values();
3138 
3139         py::class_<ParameterOptimizer>(optim_module, "_Optimizer")
3140             .def_property("learning_rate", [](ParameterOptimizer* self) {
3141                     return ((SGD*)self)->currentLearningRate();
3142                 },
3143                 [](ParameterOptimizer* self, float lr) {
3144                     ((SGD*)self)->setLearningRate(lr);
3145                 }
3146             )
3147             .def_property("momentum", [](ParameterOptimizer* self) {
3148                     return ((SGD*)self)->getMomentum();
3149                 },
3150                 [](ParameterOptimizer* self, float m) {
3151                     ((SGD*)self)->setMomentum(m);
3152                 }
3153             )
3154             .def_property("momentum2", [](ParameterOptimizer* self) {
3155                     return ((ADAM*)self)->getMomentum2();
3156                 },
3157                 [](ParameterOptimizer* self, float m) {
3158                     ((ADAM*)self)->setMomentum2(m);
3159                 }
3160             )
3161             .def_property("weight_decay", [](ParameterOptimizer* self) {
3162                     return ((SGD*)self)->getWeightDecay();
3163                 },
3164                 [](ParameterOptimizer* self, float decay) {
3165                     ((SGD*)self)->setWeightDecay(decay);
3166                 }
3167             )
3168             .def_property("eps", [](ParameterOptimizer* self) {
3169                     return ((ADAM*)self)->getEps();
3170                 },
3171                 [](ParameterOptimizer* self, float eps) {
3172                     ((ADAM*)self)->setEps(eps);
3173                 }
3174             )
3175             .def_property("regularization_method", [](ParameterOptimizer* self) {
3176                     return ((SGD*)self)->getRegularizationMethod();
3177                 },
3178                 [](ParameterOptimizer* self, ParameterOptimizer::RegularizationMethod method) {
3179                     ((SGD*)self)->setRegularizationMethod(method);
3180                 }
3181             )
3182             .def("step", [](ParameterOptimizer* self, Express::VARP loss) {
3183                 return self->step(loss);
3184             })
3185         ;
3186 
3187         optim_module.def("SGD", &ParameterOptimizer::createSGD,
3188                         py::arg("module"),
3189                         py::arg("learning_rate"), py::arg("momentum") = 0.9, py::arg("weight_decay") = 0,
3190                         py::arg("regularization_method") = ParameterOptimizer::RegularizationMethod::L2);
3191         optim_module.def("ADAM", &ParameterOptimizer::createADAM,
3192                         py::arg("module"),
3193                         py::arg("learning_rate") = 1e-3, py::arg("momentum") = 0.9, py::arg("momentum2") = 0.999,
3194                         py::arg("weight_decay") = 0.0, py::arg("eps") = 1e-8,
3195                         py::arg("regularization_method") = ParameterOptimizer::RegularizationMethod::L2);
3196     }
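    // Illustrative Python usage of the optimizer bindings above (train build only; assumes _optim
    // is re-exported as MNN.optim, imported here as optim; net and loss come from user code):
    //   sgd = optim.SGD(net, 0.01, momentum=0.9, weight_decay=4e-5)
    //   sgd.step(loss)
    //   sgd.learning_rate = 0.001     # properties defined on _Optimizer above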
3197 
3198 
3199     {
3200         class PyDataset : public Dataset {
3201         public:
3202             using Dataset::Dataset;
3203 
3204             virtual Example get(size_t index) override {
3205                 PYBIND11_OVERLOAD_PURE(Example, Dataset, __getitem__, index);
3206             }
3207             virtual size_t size() override {
3208                 PYBIND11_OVERLOAD_PURE(size_t, Dataset, __len__);
3209             }
3210         };
3211 
3212         auto data_module = py_module.def_submodule("_data");
3213         py::class_<Dataset, PyDataset, std::shared_ptr<Dataset>>(data_module, "Dataset")
3214             .def(py::init())
3215             .def("__getitem__", &Dataset::get, py::arg("index"))
3216             .def("__len__", &Dataset::size)
3217         ;
3218 
3219         py::class_<DataLoader>(data_module, "DataLoader")
3220             .def(py::init([](std::shared_ptr<Dataset> dataset, const int batchsize, const bool shuffle, const int numWorkers) {
3221                 bool stack = true;
3222                 // TODO: num_workers is hard-coded to 0; enabling worker threads requires the GIL, which is removed in the private pybind build.
3223                 return DataLoader::makeDataLoader(dataset, batchsize, stack, shuffle, 0);
3224             }), py::arg("dataset"), py::arg("batch_size"), py::arg("shuffle") = true, py::arg("num_workers") = 0)
3225             .def_property_readonly("iter_number", &DataLoader::iterNumber)
3226             .def_property_readonly("size", &DataLoader::size)
3227             .def("reset", &DataLoader::reset)
3228             .def("next", [](DataLoader* self) {
3229                 return self->next()[0]; // since we always stack
3230             })
3231         ;
3232     }
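    // Illustrative Python usage of the dataset/data-loader bindings above (train build only;
    // assumes _data is re-exported as MNN.data, imported here as data; MyDataset is a placeholder):
    //   class MyDataset(data.Dataset):
    //       def __getitem__(self, index): ...
    //       def __len__(self): ...
    //   loader = data.DataLoader(MyDataset(), batch_size=32, shuffle=True)
    //   loader.reset()
    //   for _ in range(loader.iter_number):
    //       example = loader.next()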
3233 
3234     {
3235         // Loss
3236         auto loss_module = nn_module.def_submodule("loss");
3237         loss_module.def("cross_entropy", _CrossEntropy, py::arg("predicts"), py::arg("one_hot_targets"));
3238         loss_module.def("kl", _KLDivergence, py::arg("predicts"), py::arg("one_hot_targets"));
3239         loss_module.def("mse", _MSE, py::arg("predicts"), py::arg("one_hot_targets"));
3240         loss_module.def("mae", _MAE, py::arg("predicts"), py::arg("one_hot_targets"));
3241         loss_module.def("hinge", _Hinge, py::arg("predicts"), py::arg("one_hot_targets"));
3242     }
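    // Illustrative Python usage of the loss helpers above (predicts / one_hot_targets are placeholders):
    //   l = nn.loss.cross_entropy(predicts, one_hot_targets)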
3243 
3244     {
3245         auto compress_module = nn_module.def_submodule("compress");
3246         py::enum_<NN::FeatureScaleStatMethod>(compress_module, "Feature_Scale_Method")
3247             .value("PER_TENSOR", NN::PerTensor)
3248             .value("PER_CHANNEL", NN::PerChannel)
3249             .export_values();
3250         py::enum_<NN::ScaleUpdateMethod>(compress_module, "Scale_Update_Method")
3251             .value("MAXIMUM", NN::Maximum)
3252             .value("MOVING_AVERAGE", NN::MovingAverage)
3253             .export_values();
3254         compress_module.def("train_quant", &NN::turnQuantize,
3255             py::arg("module"),
3256             py::arg("quant_bits") = 8,
3257             py::arg("feature_scale_method") = NN::FeatureScaleStatMethod::PerTensor,
3258             py::arg("scale_update_method") = NN::ScaleUpdateMethod::MovingAverage);
3259     }
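    // Illustrative Python usage of the quantization-aware-training helper above (net is a placeholder module):
    //   nn.compress.train_quant(net, quant_bits=8,
    //                           feature_scale_method=nn.compress.Feature_Scale_Method.PER_TENSOR,
    //                           scale_update_method=nn.compress.Scale_Update_Method.MOVING_AVERAGE)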
3260     // End of Train
3261 #endif
3262 #endif
3263 #if PY_MAJOR_VERSION >= 3
3264     return m;
3265 #else
3266     return;
3267 #endif
3268 }
3269 
3270 // MNNPyBridge invokes loadMNN via a static block on Windows / Linux / Mac / Android
3271 #if defined(PYMNN_USE_ALINNPYTHON) && !defined(TARGET_OS_IOS)
3272 static std::once_flag mLoadFlag2;
3273 // Declared (extern "C" PYMNN_PUBLIC) in MNNPyBridge
3274 void loadMNN() {
3275     std::call_once(mLoadFlag2, [](){
3276         WeImport_AppendInittab(MOD_NAME, MOD_INIT_FUNC);
3277     });
3278 }
3279 static auto registerMNN = []() {
3280     loadMNN();
3281     return true;
3282 }();
3283 #endif
3284