1 /* Tencent is pleased to support the open source community by making ncnn available.
2  *
3  * Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4  *
5  * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6  * in compliance with the License. You may obtain a copy of the License at
7  *
8  * https://opensource.org/licenses/BSD-3-Clause
9  *
10  * Unless required by applicable law or agreed to in writing, software distributed
11  * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12  * CONDITIONS OF ANY KIND, either express or implied. See the License for the
13  * specific language governing permissions and limitations under the License.
14  */
15 
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
18 #include <pybind11/numpy.h>
19 #include <pybind11/functional.h>
20 
21 #include <cpu.h>
22 #include <gpu.h>
23 #include <net.h>
24 #include <option.h>
25 #include <blob.h>
26 #include <paramdict.h>
27 
28 #include "pybind11_mat.h"
29 #include "pybind11_datareader.h"
30 #include "pybind11_allocator.h"
31 #include "pybind11_modelbin.h"
32 #include "pybind11_layer.h"
33 using namespace ncnn;
34 
35 namespace py = pybind11;
36 
37 struct LayerFactory
38 {
39     std::string name;
40     int index;
41     std::function<Layer*()> creator;
42     std::function<void(Layer*)> destroyer;
43     layer_creator_func creator_func;
44     layer_destroyer_func destroyer_func;
45 };
46 
47 #define LayerFactoryDeclear(n)                  \
48     static ncnn::Layer* LayerCreator##n(void*); \
49     static void LayerDestroyer##n(ncnn::Layer*, void*);
50 
51 LayerFactoryDeclear(0);
52 LayerFactoryDeclear(1);
53 LayerFactoryDeclear(2);
54 LayerFactoryDeclear(3);
55 LayerFactoryDeclear(4);
56 LayerFactoryDeclear(5);
57 LayerFactoryDeclear(6);
58 LayerFactoryDeclear(7);
59 LayerFactoryDeclear(8);
60 LayerFactoryDeclear(9);
61 
62 std::vector<LayerFactory> g_layer_factroys = {
63     {"", -1, nullptr, nullptr, LayerCreator0, LayerDestroyer0},
64     {"", -1, nullptr, nullptr, LayerCreator1, LayerDestroyer1},
65     {"", -1, nullptr, nullptr, LayerCreator2, LayerDestroyer2},
66     {"", -1, nullptr, nullptr, LayerCreator3, LayerDestroyer3},
67     {"", -1, nullptr, nullptr, LayerCreator4, LayerDestroyer4},
68     {"", -1, nullptr, nullptr, LayerCreator5, LayerDestroyer5},
69     {"", -1, nullptr, nullptr, LayerCreator6, LayerDestroyer6},
70     {"", -1, nullptr, nullptr, LayerCreator7, LayerDestroyer7},
71     {"", -1, nullptr, nullptr, LayerCreator8, LayerDestroyer8},
72     {"", -1, nullptr, nullptr, LayerCreator9, LayerDestroyer9},
73 };
74 int g_layer_factroy_index = 0;
75 
76 #define LayerFactoryDefine(n)                                  \
77     static ncnn::Layer* LayerCreator##n(void* p)               \
78     {                                                          \
79         if (g_layer_factroys[n].creator != nullptr)            \
80         {                                                      \
81             return g_layer_factroys[n].creator();              \
82         }                                                      \
83         return nullptr;                                        \
84     }                                                          \
85     static void LayerDestroyer##n(ncnn::Layer* layer, void* p) \
86     {                                                          \
87         if (g_layer_factroys[n].destroyer)                     \
88         {                                                      \
89             g_layer_factroys[n].destroyer(layer);              \
90         }                                                      \
91     }
92 
93 LayerFactoryDefine(0);
94 LayerFactoryDefine(1);
95 LayerFactoryDefine(2);
96 LayerFactoryDefine(3);
97 LayerFactoryDefine(4);
98 LayerFactoryDefine(5);
99 LayerFactoryDefine(6);
100 LayerFactoryDefine(7);
101 LayerFactoryDefine(8);
102 LayerFactoryDefine(9);
103 
PYBIND11_MODULE(ncnn,m)104 PYBIND11_MODULE(ncnn, m)
105 {
106     auto atexit = py::module_::import("atexit");
107     atexit.attr("register")(py::cpp_function([]() {
108         for (int i = 0; i < g_layer_factroys.size(); i++)
109         {
110             g_layer_factroys[i].creator = nullptr;
111             g_layer_factroys[i].destroyer = nullptr;
112         }
113     }));
114 
115     py::class_<Allocator, PyAllocator<> >(m, "Allocator");
116     py::class_<PoolAllocator, Allocator, PyAllocatorOther<PoolAllocator> >(m, "PoolAllocator")
117     .def(py::init<>())
118     .def("set_size_compare_ratio", &PoolAllocator::set_size_compare_ratio, py::arg("src"))
119     .def("clear", &PoolAllocator::clear)
120     .def("fastMalloc", &PoolAllocator::fastMalloc, py::arg("size"))
121     .def("fastFree", &PoolAllocator::fastFree, py::arg("ptr"));
122     py::class_<UnlockedPoolAllocator, Allocator, PyAllocatorOther<UnlockedPoolAllocator> >(m, "UnlockedPoolAllocator")
123     .def(py::init<>())
124     .def("set_size_compare_ratio", &UnlockedPoolAllocator::set_size_compare_ratio, py::arg("src"))
125     .def("clear", &UnlockedPoolAllocator::clear)
126     .def("fastMalloc", &UnlockedPoolAllocator::fastMalloc, py::arg("size"))
127     .def("fastFree", &UnlockedPoolAllocator::fastFree, py::arg("ptr"));
128 
129     py::class_<DataReader, PyDataReader<> >(m, "DataReader")
130     .def(py::init<>())
131 #if NCNN_STRING
132     .def("scan", &DataReader::scan, py::arg("format"), py::arg("p"))
133 #endif // NCNN_STRING
134     .def("read", &DataReader::read, py::arg("buf"), py::arg("size"));
135     py::class_<DataReaderFromEmpty, DataReader, PyDataReaderOther<DataReaderFromEmpty> >(m, "DataReaderFromEmpty")
136     .def(py::init<>())
137 #if NCNN_STRING
138     .def("scan", &DataReaderFromEmpty::scan, py::arg("format"), py::arg("p"))
139 #endif // NCNN_STRING
140     .def("read", &DataReaderFromEmpty::read, py::arg("buf"), py::arg("size"));
141 
142     py::class_<Blob>(m, "Blob")
143     .def(py::init<>())
144 #if NCNN_STRING
145     .def_readwrite("name", &Blob::name)
146 #endif // NCNN_STRING
147     .def_readwrite("producer", &Blob::producer)
148     .def_readwrite("consumer", &Blob::consumer)
149     .def_readwrite("shape", &Blob::shape);
150 
151     py::class_<ModelBin, PyModelBin<> >(m, "ModelBin");
152     py::class_<ModelBinFromDataReader, ModelBin, PyModelBinOther<ModelBinFromDataReader> >(m, "ModelBinFromDataReader")
153     .def(py::init<const DataReader&>(), py::arg("dr"))
154     .def("load", &ModelBinFromDataReader::load, py::arg("w"), py::arg("type"));
155     py::class_<ModelBinFromMatArray, ModelBin, PyModelBinOther<ModelBinFromMatArray> >(m, "ModelBinFromMatArray")
156     .def(py::init<const Mat*>(), py::arg("weights"))
157     .def("load", &ModelBinFromMatArray::load, py::arg("w"), py::arg("type"));
158 
159     py::class_<ParamDict>(m, "ParamDict")
160     .def(py::init<>())
161     .def("type", &ParamDict::type, py::arg("id"))
162     .def("get", (int (ParamDict::*)(int, int) const) & ParamDict::get, py::arg("id"), py::arg("def"))
163     .def("get", (float (ParamDict::*)(int, float) const) & ParamDict::get, py::arg("id"), py::arg("def"))
164     .def("get", (Mat(ParamDict::*)(int, const Mat&) const) & ParamDict::get, py::arg("id"), py::arg("def"))
165     .def("set", (void (ParamDict::*)(int, int)) & ParamDict::set, py::arg("id"), py::arg("i"))
166     .def("set", (void (ParamDict::*)(int, float)) & ParamDict::set, py::arg("id"), py::arg("f"))
167     .def("set", (void (ParamDict::*)(int, const Mat&)) & ParamDict::set, py::arg("id"), py::arg("v"));
168 
169     py::class_<Option>(m, "Option")
170     .def(py::init<>())
171     .def_readwrite("lightmode", &Option::lightmode)
172     .def_readwrite("num_threads", &Option::num_threads)
173     .def_readwrite("blob_allocator", &Option::blob_allocator)
174     .def_readwrite("workspace_allocator", &Option::workspace_allocator)
175 #if NCNN_VULKAN
176     .def_readwrite("blob_vkallocator", &Option::blob_vkallocator)
177     .def_readwrite("workspace_vkallocator", &Option::workspace_vkallocator)
178     .def_readwrite("staging_vkallocator", &Option::staging_vkallocator)
179     //.def_readwrite("pipeline_cache", &Option::pipeline_cache)
180 #endif // NCNN_VULKAN
181     .def_readwrite("openmp_blocktime", &Option::openmp_blocktime)
182     .def_readwrite("use_winograd_convolution", &Option::use_winograd_convolution)
183     .def_readwrite("use_sgemm_convolution", &Option::use_sgemm_convolution)
184     .def_readwrite("use_int8_inference", &Option::use_int8_inference)
185     .def_readwrite("use_vulkan_compute", &Option::use_vulkan_compute)
186     .def_readwrite("use_bf16_storage", &Option::use_bf16_storage)
187     .def_readwrite("use_fp16_packed", &Option::use_fp16_packed)
188     .def_readwrite("use_fp16_storage", &Option::use_fp16_storage)
189     .def_readwrite("use_fp16_arithmetic", &Option::use_fp16_arithmetic)
190     .def_readwrite("use_int8_packed", &Option::use_int8_packed)
191     .def_readwrite("use_int8_storage", &Option::use_int8_storage)
192     .def_readwrite("use_int8_arithmetic", &Option::use_int8_arithmetic)
193     .def_readwrite("use_packing_layout", &Option::use_packing_layout)
194     .def_readwrite("use_shader_pack8", &Option::use_shader_pack8)
195     .def_readwrite("use_subgroup_basic", &Option::use_subgroup_basic)
196     .def_readwrite("use_subgroup_vote", &Option::use_subgroup_vote)
197     .def_readwrite("use_subgroup_ballot", &Option::use_subgroup_ballot)
198     .def_readwrite("use_subgroup_shuffle", &Option::use_subgroup_shuffle)
199     .def_readwrite("use_image_storage", &Option::use_image_storage)
200     .def_readwrite("use_tensor_storage", &Option::use_tensor_storage)
201     .def_readwrite("use_weight_fp16_storage", &Option::use_weight_fp16_storage);
202 
203     py::class_<Mat> mat(m, "Mat", py::buffer_protocol());
204     mat.def(py::init<>())
205     .def(py::init(
206     [](py::tuple shape, size_t elemsize, int elempack, Allocator* allocator) {
207         Mat* mat = nullptr;
208         switch (shape.size())
209         {
210         case 1:
211             mat = new Mat(shape[0].cast<int>(), elemsize, elempack, allocator);
212             break;
213         case 2:
214             mat = new Mat(shape[0].cast<int>(), shape[1].cast<int>(), elemsize, elempack, allocator);
215             break;
216         case 3:
217             mat = new Mat(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), elemsize, elempack, allocator);
218             break;
219         case 4:
220             mat = new Mat(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), shape[3].cast<int>(), elemsize, elempack, allocator);
221             break;
222         default:
223             std::stringstream ss;
224             ss << "shape must be 1, 2, 3 or 4 dims, not " << shape.size();
225             pybind11::pybind11_fail(ss.str());
226         }
227         return mat;
228     }),
229     py::arg("shape"), py::kw_only(),
230     py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
231     .def(py::init<int, size_t, int, Allocator*>(),
232          py::arg("w"), py::kw_only(),
233          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
234     .def(py::init<int, int, size_t, int, Allocator*>(),
235          py::arg("w"), py::arg("h"), py::kw_only(),
236          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
237     .def(py::init<int, int, int, size_t, int, Allocator*>(),
238          py::arg("w"), py::arg("h"), py::arg("c"), py::kw_only(),
239          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
240     .def(py::init<int, int, int, int, size_t, int, Allocator*>(),
241          py::arg("w"), py::arg("h"), py::arg("d"), py::arg("c"), py::kw_only(),
242          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
243 
244     .def(py::init<const Mat&>(), py::arg("m"))
245 
246     .def(py::init([](py::buffer const b) {
247         py::buffer_info info = b.request();
248         if (info.ndim > 4)
249         {
250             std::stringstream ss;
251             ss << "convert numpy.ndarray to ncnn.Mat only dims <=4 support now, but given " << info.ndim;
252             pybind11::pybind11_fail(ss.str());
253         }
254 
255         size_t elemsize = 4u;
256         if (info.format == py::format_descriptor<double>::format())
257         {
258             elemsize = 8u;
259         }
260         if (info.format == py::format_descriptor<float>::format() || info.format == py::format_descriptor<int>::format())
261         {
262             elemsize = 4u;
263         }
264         else if (info.format == "e")
265         {
266             elemsize = 2u;
267         }
268         else if (info.format == py::format_descriptor<int8_t>::format() || info.format == py::format_descriptor<uint8_t>::format())
269         {
270             elemsize = 1u;
271         }
272 
273         Mat* v = nullptr;
274         if (info.ndim == 1)
275         {
276             v = new Mat((int)info.shape[0], info.ptr, elemsize);
277         }
278         else if (info.ndim == 2)
279         {
280             v = new Mat((int)info.shape[1], (int)info.shape[0], info.ptr, elemsize);
281         }
282         else if (info.ndim == 3)
283         {
284             v = new Mat((int)info.shape[2], (int)info.shape[1], (int)info.shape[0], info.ptr, elemsize);
285 
286             // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
287             // with (w * h * elemsize, 16) / elemsize, but the buffer from numpy not
288             // so we set the cstep as numpy's cstep
289             v->cstep = (int)info.shape[2] * (int)info.shape[1];
290         }
291         else if (info.ndim == 4)
292         {
293             v = new Mat((int)info.shape[3], (int)info.shape[2], (int)info.shape[1], (int)info.shape[0], info.ptr, elemsize);
294 
295             // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
296             // with (w * h * d elemsize, 16) / elemsize, but the buffer from numpy not
297             // so we set the cstep as numpy's cstep
298             v->cstep = (int)info.shape[3] * (int)info.shape[2] * (int)info.shape[1];
299         }
300         return v;
301     }),
302     py::arg("array"))
303     .def_buffer([](Mat& m) -> py::buffer_info {
304         if (m.elemsize != 1 && m.elemsize != 2 && m.elemsize != 4)
305         {
306             std::stringstream ss;
307             ss << "convert ncnn.Mat to numpy.ndarray only elemsize 1, 2, 4 support now, but given " << m.elemsize;
308             pybind11::pybind11_fail(ss.str());
309         }
310         if (m.elempack != 1)
311         {
312             std::stringstream ss;
313             ss << "convert ncnn.Mat to numpy.ndarray only elempack 1 support now, but given " << m.elempack;
314             pybind11::pybind11_fail(ss.str());
315         }
316         std::string format = get_mat_format(m);
317         std::vector<py::ssize_t> shape;
318         std::vector<py::ssize_t> strides;
319         if (m.dims == 1)
320         {
321             shape.push_back(m.w);
322             strides.push_back(m.elemsize);
323         }
324         else if (m.dims == 2)
325         {
326             shape.push_back(m.h);
327             shape.push_back(m.w);
328             strides.push_back(m.w * m.elemsize);
329             strides.push_back(m.elemsize);
330         }
331         else if (m.dims == 3)
332         {
333             shape.push_back(m.c);
334             shape.push_back(m.h);
335             shape.push_back(m.w);
336             strides.push_back(m.cstep * m.elemsize);
337             strides.push_back(m.w * m.elemsize);
338             strides.push_back(m.elemsize);
339         }
340         else if (m.dims == 4)
341         {
342             shape.push_back(m.c);
343             shape.push_back(m.d);
344             shape.push_back(m.h);
345             shape.push_back(m.w);
346             strides.push_back(m.cstep * m.elemsize);
347             strides.push_back(m.w * m.h * m.elemsize);
348             strides.push_back(m.w * m.elemsize);
349             strides.push_back(m.elemsize);
350         }
351         return py::buffer_info(
352             m.data,     /* Pointer to buffer */
353             m.elemsize, /* Size of one scalar */
354             format,     /* Python struct-style format descriptor */
355             m.dims,     /* Number of dimensions */
356             shape,      /* Buffer dimensions */
357             strides     /* Strides (in bytes) for each index */
358         );
359     })
360     //.def("fill", (void (Mat::*)(int))(&Mat::fill), py::arg("v"))
361     .def("fill", (void (Mat::*)(float))(&Mat::fill), py::arg("v"))
362     .def("clone", &Mat::clone, py::arg("allocator") = nullptr)
363     .def("clone_from", &Mat::clone_from, py::arg("mat"), py::arg("allocator") = nullptr)
364     .def(
365         "reshape",
366     [](Mat& mat, py::tuple shape, Allocator* allocator) {
367         switch (shape.size())
368         {
369         case 1:
370             return mat.reshape(shape[0].cast<int>(), allocator);
371         case 2:
372             return mat.reshape(shape[0].cast<int>(), shape[1].cast<int>(), allocator);
373         case 3:
374             return mat.reshape(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), allocator);
375         case 4:
376             return mat.reshape(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), shape[3].cast<int>(), allocator);
377         default:
378             std::stringstream ss;
379             ss << "shape must be 1, 2, 3 or 4 dims, not " << shape.size();
380             pybind11::pybind11_fail(ss.str());
381         }
382         return Mat();
383     },
384     py::arg("shape") = py::tuple(1), py::arg("allocator") = nullptr)
385     .def("reshape", (Mat(Mat::*)(int, Allocator*) const) & Mat::reshape,
386          py::arg("w"), py::kw_only(), py::arg("allocator") = nullptr)
387     .def("reshape", (Mat(Mat::*)(int, int, Allocator*) const) & Mat::reshape,
388          py::arg("w"), py::arg("h"), py::kw_only(), py::arg("allocator") = nullptr)
389     .def("reshape", (Mat(Mat::*)(int, int, int, Allocator*) const) & Mat::reshape,
390          py::arg("w"), py::arg("h"), py::arg("c"), py::kw_only(), py::arg("allocator") = nullptr)
391     .def("reshape", (Mat(Mat::*)(int, int, int, int, Allocator*) const) & Mat::reshape,
392          py::arg("w"), py::arg("h"), py::arg("d"), py::arg("c"), py::kw_only(), py::arg("allocator") = nullptr)
393 
394     .def(
395         "create",
396     [](Mat& mat, py::tuple shape, size_t elemsize, int elempack, Allocator* allocator) {
397         switch (shape.size())
398         {
399         case 1:
400             return mat.create(shape[0].cast<int>(), elemsize, elempack, allocator);
401         case 2:
402             return mat.create(shape[0].cast<int>(), shape[1].cast<int>(), elemsize, elempack, allocator);
403         case 3:
404             return mat.create(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), elemsize, elempack, allocator);
405         case 4:
406             return mat.create(shape[0].cast<int>(), shape[1].cast<int>(), shape[2].cast<int>(), shape[3].cast<int>(), elemsize, elempack, allocator);
407         default:
408             std::stringstream ss;
409             ss << "shape must be 1, 2, 3 or 4 dims, not " << shape.size();
410             pybind11::pybind11_fail(ss.str());
411         }
412         return;
413     },
414     py::arg("shape"), py::kw_only(),
415     py::arg("elemsize") = 4, py::arg("elempack") = 1,
416     py::arg("allocator") = nullptr)
417     .def("create", (void (Mat::*)(int, size_t, int, Allocator*)) & Mat::create,
418          py::arg("w"), py::kw_only(),
419          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
420     .def("create", (void (Mat::*)(int, int, size_t, int, Allocator*)) & Mat::create,
421          py::arg("w"), py::arg("h"), py::kw_only(),
422          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
423     .def("create", (void (Mat::*)(int, int, int, size_t, int, Allocator*)) & Mat::create,
424          py::arg("w"), py::arg("h"), py::arg("c"), py::kw_only(),
425          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
426     .def("create", (void (Mat::*)(int, int, int, int, size_t, int, Allocator*)) & Mat::create,
427          py::arg("w"), py::arg("h"), py::arg("d"), py::arg("c"), py::kw_only(),
428          py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
429     .def("create_like", (void (Mat::*)(const Mat&, Allocator*)) & Mat::create_like,
430          py::arg("m"), py::arg("allocator") = nullptr)
431     .def("addref", &Mat::addref)
432     .def("release", &Mat::release)
433     .def("empty", &Mat::empty)
434     .def("total", &Mat::total)
435     .def("elembits", &Mat::elembits)
436     .def("shape", &Mat::shape)
437     .def("channel", (Mat(Mat::*)(int)) & Mat::channel, py::arg("c"))
438     //.def("channel", (const Mat (Mat::*)(int) const) & Mat::channel, py::arg("c"))
439     .def("depth", (Mat(Mat::*)(int)) & Mat::depth, py::arg("z"))
440     //.def("depth", (const Mat (Mat::*)(int) const) & Mat::depth, py::arg("z"))
441     .def(
442         "row",
443     [](Mat& m, int y) {
444         if (m.elempack != 1)
445         {
446             std::stringstream ss;
447             ss << "get ncnn.Mat row only elempack 1 support now, but given " << m.elempack;
448             pybind11::pybind11_fail(ss.str());
449         }
450 
451         switch (m.elemsize)
452         {
453         case 1:
454             return py::memoryview::from_buffer(m.row<int8_t>(y), {m.w}, {sizeof(int8_t)});
455         //case 2:
456         //    return py::memoryview::from_buffer(m.row<short>(y), {m.w}, {sizeof(short)});
457         case 4:
458             return py::memoryview::from_buffer(m.row<float>(y), {m.w}, {sizeof(float)});
459         default:
460             std::stringstream ss;
461             ss << "ncnn.Mat row elemsize " << m.elemsize << "not support now";
462             pybind11::pybind11_fail(ss.str());
463         }
464         return py::memoryview::from_buffer(m.row<float>(y), {m.w}, {sizeof(float)});
465     },
466     py::arg("y"))
467     .def("channel_range", (Mat(Mat::*)(int, int)) & Mat::channel_range, py::arg("c"), py::arg("channels"))
468     //.def("channel_range", (const Mat (Mat::*)(int, int) const) & Mat::channel_range, py::arg("c"), py::arg("channels"))
469     .def("depth_range", (Mat(Mat::*)(int, int)) & Mat::depth_range, py::arg("z"), py::arg("depths"))
470     //.def("depth_range", (const Mat (Mat::*)(int, int) const) & Mat::depth_range, py::arg("z"), py::arg("depths"))
471     .def("row_range", (Mat(Mat::*)(int, int)) & Mat::row_range, py::arg("y"), py::arg("rows"))
472     //.def("row_range", (const Mat (Mat::*)(int, int) const) & Mat::row_range, py::arg("y"), py::arg("rows"))
473     .def("range", (Mat(Mat::*)(int, int)) & Mat::range, py::arg("x"), py::arg("n"))
474     //.def("range", (const Mat (Mat::*)(int, int) const) & Mat::range, py::arg("x"), py::arg("n"))
475     .def(
476     "__getitem__", [](const Mat& m, size_t i) {
477         return m[i];
478     },
479     py::arg("i"))
480     .def(
481     "__setitem__", [](Mat& m, size_t i, float v) {
482         m[i] = v;
483     },
484     py::arg("i"), py::arg("v"))
485     .def("__len__", [](Mat& m) {
486         return m.w;
487     })
488 
489     //convenient construct from pixel data
490     .def_static(
491     "from_pixels", [](py::buffer const b, int type, int w, int h, Allocator* allocator) {
492         return Mat::from_pixels((const unsigned char*)b.request().ptr, type, w, h, allocator);
493     },
494     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("allocator") = nullptr)
495     .def_static(
496     "from_pixels", [](py::buffer const b, int type, int w, int h, int stride, Allocator* allocator) {
497         return Mat::from_pixels((const unsigned char*)b.request().ptr, type, w, h, stride, allocator);
498     },
499     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("stride"), py::arg("allocator") = nullptr)
500     .def_static(
501     "from_pixels_resize", [](py::buffer const b, int type, int w, int h, int target_width, int target_height, Allocator* allocator) {
502         return Mat::from_pixels_resize((const unsigned char*)b.request().ptr,
503                                        type, w, h, target_width, target_height, allocator);
504     },
505     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("target_width"), py::arg("target_height"), py::arg("allocator") = nullptr)
506     .def_static(
507     "from_pixels_resize", [](py::buffer const b, int type, int w, int h, int stride, int target_width, int target_height, Allocator* allocator) {
508         return Mat::from_pixels_resize((const unsigned char*)b.request().ptr,
509                                        type, w, h, stride, target_width, target_height, allocator);
510     },
511     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("stride"), py::arg("target_width"), py::arg("target_height"), py::arg("allocator") = nullptr)
512     .def_static(
513     "from_pixels_roi", [](py::buffer const b, int type, int w, int h, int roix, int roiy, int roiw, int roih, Allocator* allocator) {
514         return Mat::from_pixels_roi((const unsigned char*)b.request().ptr,
515                                     type, w, h, roix, roiy, roiw, roih, allocator);
516     },
517     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("roix"), py::arg("roiy"), py::arg("roiw"), py::arg("roih"), py::arg("allocator") = nullptr)
518     .def_static(
519     "from_pixels_roi", [](py::buffer const b, int type, int w, int h, int stride, int roix, int roiy, int roiw, int roih, Allocator* allocator) {
520         return Mat::from_pixels_roi((const unsigned char*)b.request().ptr,
521                                     type, w, h, stride, roix, roiy, roiw, roih, allocator);
522     },
523     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("stride"), py::arg("roix"), py::arg("roiy"), py::arg("roiw"), py::arg("roih"), py::arg("allocator") = nullptr)
524     .def_static(
525     "from_pixels_roi_resize", [](py::buffer const b, int type, int w, int h, int roix, int roiy, int roiw, int roih, int target_width, int target_height, Allocator* allocator) {
526         return Mat::from_pixels_roi_resize((const unsigned char*)b.request().ptr,
527                                            type, w, h, roix, roiy, roiw, roih, target_width, target_height, allocator);
528     },
529     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("roix"), py::arg("roiy"), py::arg("roiw"), py::arg("roih"), py::arg("target_width"), py::arg("target_height"), py::arg("allocator") = nullptr)
530     .def_static(
531     "from_pixels_roi_resize", [](py::buffer const b, int type, int w, int h, int stride, int roix, int roiy, int roiw, int roih, int target_width, int target_height, Allocator* allocator) {
532         return Mat::from_pixels_roi_resize((const unsigned char*)b.request().ptr,
533                                            type, w, h, stride, roix, roiy, roiw, roih, target_width, target_height, allocator);
534     },
535     py::arg("array"), py::arg("type"), py::arg("w"), py::arg("h"), py::arg("stride"), py::arg("roix"), py::arg("roiy"), py::arg("roiw"), py::arg("roih"), py::arg("target_width"), py::arg("target_height"), py::arg("allocator") = nullptr)
536     .def(
537     "substract_mean_normalize", [](Mat& mat, std::vector<float>& mean, std::vector<float>& norm) {
538         return mat.substract_mean_normalize(mean.size() > 0 ? &mean[0] : 0, norm.size() > 0 ? &norm[0] : 0);
539     },
540     py::arg("mean"), py::arg("norm"))
541     .def_readwrite("refcount", &Mat::refcount)
542     .def_readwrite("elemsize", &Mat::elemsize)
543     .def_readwrite("elempack", &Mat::elempack)
544     .def_readwrite("allocator", &Mat::allocator)
545     .def_readwrite("dims", &Mat::dims)
546     .def_readwrite("w", &Mat::w)
547     .def_readwrite("h", &Mat::h)
548     .def_readwrite("d", &Mat::d)
549     .def_readwrite("c", &Mat::c)
550     .def_readwrite("cstep", &Mat::cstep)
551     .def("__repr__", [](const Mat& m) {
552         std::stringstream ss;
553         ss << "<ncnn.Mat w=" << m.w << " h=" << m.h << " d=" << m.d << " c=" << m.c << " dims=" << m.dims
554            << " cstep=" << m.cstep << " elemsize=" << m.elemsize << " elempack=" << m.elempack << "\n\t"
555            << "refcount=" << (m.refcount ? *m.refcount : 0) << " data=0x" << static_cast<const void*>(m.data)
556            << " allocator=0x" << static_cast<const void*>(m.allocator) << ">\n";
557 
558         const int max_count = m.dims == 1 ? 10 : 6;
559         if (m.dims == 1)
560         {
561             ss << "[";
562             bool dot_printed_w = false;
563 
564             if (m.elemsize == 1)
565             {
566                 const int8_t* row = m.row<int8_t>(0);
567                 for (int i = 0; i < m.w; i++)
568                 {
569                     if (i < max_count / 2 || i >= m.w - max_count / 2)
570                     {
571                         if (i > 0)
572                         {
573                             ss << ", ";
574                         }
575                         ss << static_cast<int>(row[i]);
576                     }
577                     else if (!dot_printed_w)
578                     {
579                         dot_printed_w = true;
580                         ss << ", ...";
581                     }
582                 }
583             }
584             if (m.elemsize == 4)
585             {
586                 const float* row = m.row<float>(0);
587                 for (int i = 0; i < m.w; i++)
588                 {
589                     if (i < max_count / 2 || i >= m.w - max_count / 2)
590                     {
591                         if (i > 0)
592                         {
593                             ss << ", ";
594                         }
595                         ss << row[i];
596                     }
597                     else if (!dot_printed_w)
598                     {
599                         dot_printed_w = true;
600                         ss << ", ...";
601                     }
602                 }
603             }
604             ss << "]";
605         }
606         else if (m.dims == 2)
607         {
608             bool dot_printed_h = false;
609             ss << "[";
610             for (int j = 0; j < m.h; j++)
611             {
612                 bool dot_printed_w = false;
613                 if (j < max_count / 2 || j >= m.h - max_count / 2)
614                 {
615                     ss << "[";
616                     if (m.elemsize == 1)
617                     {
618                         const int8_t* row = m.row<int8_t>(j);
619                         for (int i = 0; i < m.w; i++)
620                         {
621                             if (i < max_count / 2 || i >= m.w - max_count / 2)
622                             {
623                                 if (i > 0)
624                                 {
625                                     ss << ", ";
626                                 }
627                                 ss << static_cast<int>(row[i]);
628                             }
629                             else if (!dot_printed_w)
630                             {
631                                 dot_printed_w = true;
632                                 ss << ", ...";
633                             }
634                         }
635                     }
636                     if (m.elemsize == 4)
637                     {
638                         const float* row = m.row<float>(j);
639                         for (int i = 0; i < m.w; i++)
640                         {
641                             if (i < max_count / 2 || i >= m.w - max_count / 2)
642                             {
643                                 if (i > 0)
644                                 {
645                                     ss << ", ";
646                                 }
647                                 ss << row[i];
648                             }
649                             else if (!dot_printed_w)
650                             {
651                                 dot_printed_w = true;
652                                 ss << ", ...";
653                             }
654                         }
655                     }
656                     ss << "]";
657                     if (j < m.h - 1)
658                     {
659                         ss << "\n";
660                     }
661                 }
662                 else if (!dot_printed_h)
663                 {
664                     dot_printed_h = true;
665                     ss << "...\n";
666                 }
667             }
668             ss << "]\n";
669         }
670         else if (m.dims == 3)
671         {
672             bool dot_printed_c = false;
673             ss << "[";
674             for (int k = 0; k < m.c; k++)
675             {
676                 bool dot_printed_h = false;
677                 if (k < max_count / 2 || k >= m.c - max_count / 2)
678                 {
679                     Mat channel = m.channel(k);
680                     if (k > 0)
681                     {
682                         ss << " ";
683                     }
684                     ss << "[";
685                     for (int j = 0; j < channel.h; j++)
686                     {
687                         bool dot_printed_w = false;
688                         if (j < max_count / 2 || j >= channel.h - max_count / 2)
689                         {
690                             if (j > 0)
691                             {
692                                 ss << "  ";
693                             }
694                             ss << "[";
695                             if (m.elemsize == 1)
696                             {
697                                 const int8_t* row = channel.row<int8_t>(j);
698                                 for (int i = 0; i < channel.w; i++)
699                                 {
700                                     if (i < max_count / 2 || i >= channel.w - max_count / 2)
701                                     {
702                                         if (i > 0)
703                                         {
704                                             ss << ", ";
705                                         }
706                                         ss << static_cast<int>(row[i]);
707                                     }
708                                     else if (!dot_printed_w)
709                                     {
710                                         dot_printed_w = true;
711                                         ss << ", ...";
712                                     }
713                                 }
714                             }
715                             if (m.elemsize == 4)
716                             {
717                                 const float* row = channel.row<float>(j);
718                                 for (int i = 0; i < m.w; i++)
719                                 {
720                                     if (i < max_count / 2 || i >= m.w - max_count / 2)
721                                     {
722                                         if (i > 0)
723                                         {
724                                             ss << ", ";
725                                         }
726                                         ss << row[i];
727                                     }
728                                     else if (!dot_printed_w)
729                                     {
730                                         dot_printed_w = true;
731                                         ss << ", ...";
732                                     }
733                                 }
734                             }
735                             ss << "]";
736                             if (j < channel.h - 1)
737                             {
738                                 ss << "\n";
739                             }
740                         }
741                         else if (!dot_printed_h)
742                         {
743                             dot_printed_h = true;
744                             ss << "  ...\n";
745                         }
746                     } // for j
747                     ss << "]";
748                     if (k < m.c - 1)
749                     {
750                         ss << "\n\n";
751                     }
752                 }
753                 else if (!dot_printed_c)
754                 {
755                     dot_printed_c = true;
756                     ss << " ...\n";
757                 }
758             } // for k
759             ss << "]\n";
760         }
761         else if (m.dims == 4)
762         {
763             bool dot_printed_c = false;
764             ss << "[";
765             for (int k = 0; k < m.c; k++)
766             {
767                 bool dot_printed_d = false;
768                 if (k < max_count / 2 || k >= m.c - max_count / 2)
769                 {
770                     Mat channel = m.channel(k);
771                     if (k > 0)
772                     {
773                         ss << " ";
774                     }
775                     ss << "[";
776                     for (int z = 0; z < channel.d; z++)
777                     {
778                         bool dot_printed_h = false;
779                         if (z < max_count / 2 || z >= channel.d - max_count / 2)
780                         {
781                             if (z > 0)
782                             {
783                                 ss << "  ";
784                             }
785                             ss << "[";
786                             for (int j = 0; j < channel.h; j++)
787                             {
788                                 bool dot_printed_w = false;
789                                 if (j < max_count / 2 || j >= channel.h - max_count / 2)
790                                 {
791                                     if (j > 0)
792                                     {
793                                         ss << "  ";
794                                     }
795                                     ss << "[";
796                                     if (m.elemsize == 1)
797                                     {
798                                         const int8_t* row = channel.depth(z).row<int8_t>(j);
799                                         for (int i = 0; i < channel.w; i++)
800                                         {
801                                             if (i < max_count / 2 || i >= channel.w - max_count / 2)
802                                             {
803                                                 if (i > 0)
804                                                 {
805                                                     ss << ", ";
806                                                 }
807                                                 ss << static_cast<int>(row[i]);
808                                             }
809                                             else if (!dot_printed_w)
810                                             {
811                                                 dot_printed_w = true;
812                                                 ss << ", ...";
813                                             }
814                                         }
815                                     }
816                                     if (m.elemsize == 4)
817                                     {
818                                         const float* row = channel.depth(z).row<float>(j);
819                                         for (int i = 0; i < m.w; i++)
820                                         {
821                                             if (i < max_count / 2 || i >= m.w - max_count / 2)
822                                             {
823                                                 if (i > 0)
824                                                 {
825                                                     ss << ", ";
826                                                 }
827                                                 ss << row[i];
828                                             }
829                                             else if (!dot_printed_w)
830                                             {
831                                                 dot_printed_w = true;
832                                                 ss << ", ...";
833                                             }
834                                         }
835                                     }
836                                     ss << "]";
837                                     if (j < channel.h - 1)
838                                     {
839                                         ss << "\n";
840                                     }
841                                 }
842                                 else if (!dot_printed_h)
843                                 {
844                                     dot_printed_h = true;
845                                     ss << "  ...\n";
846                                 }
847                             } // for j
848                             ss << "]";
849                             if (z < channel.d - 1)
850                             {
851                                 ss << "\n";
852                             }
853                         }
854                         else if (!dot_printed_d)
855                         {
856                             dot_printed_d = true;
857                             ss << " ...\n";
858                         }
859                     } // for z
860                     ss << "]";
861                     if (k < m.c - 1)
862                     {
863                         ss << "\n\n";
864                     }
865                 }
866                 else if (!dot_printed_c)
867                 {
868                     dot_printed_c = true;
869                     ss << " ...\n";
870                 }
871             } // for k
872             ss << "]\n";
873         }
874         return ss.str();
875     });
876 
877     py::enum_<ncnn::Mat::PixelType>(mat, "PixelType")
878     .value("PIXEL_CONVERT_SHIFT", ncnn::Mat::PixelType::PIXEL_CONVERT_SHIFT)
879     .value("PIXEL_FORMAT_MASK", ncnn::Mat::PixelType::PIXEL_FORMAT_MASK)
880     .value("PIXEL_CONVERT_MASK", ncnn::Mat::PixelType::PIXEL_CONVERT_MASK)
881 
882     .value("PIXEL_RGB", ncnn::Mat::PixelType::PIXEL_RGB)
883     .value("PIXEL_BGR", ncnn::Mat::PixelType::PIXEL_BGR)
884     .value("PIXEL_GRAY", ncnn::Mat::PixelType::PIXEL_GRAY)
885     .value("PIXEL_RGBA", ncnn::Mat::PixelType::PIXEL_RGBA)
886     .value("PIXEL_BGRA", ncnn::Mat::PixelType::PIXEL_BGRA)
887 
888     .value("PIXEL_RGB2BGR", ncnn::Mat::PixelType::PIXEL_RGB2BGR)
889     .value("PIXEL_RGB2GRAY", ncnn::Mat::PixelType::PIXEL_RGB2GRAY)
890     .value("PIXEL_RGB2RGBA", ncnn::Mat::PixelType::PIXEL_RGB2RGBA)
891     .value("PIXEL_RGB2BGRA", ncnn::Mat::PixelType::PIXEL_RGB2BGRA)
892 
893     .value("PIXEL_BGR2RGB", ncnn::Mat::PixelType::PIXEL_BGR2RGB)
894     .value("PIXEL_BGR2GRAY", ncnn::Mat::PixelType::PIXEL_BGR2GRAY)
895     .value("PIXEL_BGR2RGBA", ncnn::Mat::PixelType::PIXEL_BGR2RGBA)
896     .value("PIXEL_BGR2BGRA", ncnn::Mat::PixelType::PIXEL_BGR2BGRA)
897 
898     .value("PIXEL_GRAY2RGB", ncnn::Mat::PixelType::PIXEL_GRAY2RGB)
899     .value("PIXEL_GRAY2BGR", ncnn::Mat::PixelType::PIXEL_GRAY2BGR)
900     .value("PIXEL_GRAY2RGBA", ncnn::Mat::PixelType::PIXEL_GRAY2RGBA)
901     .value("PIXEL_GRAY2BGRA", ncnn::Mat::PixelType::PIXEL_GRAY2BGRA)
902 
903     .value("PIXEL_RGBA2RGB", ncnn::Mat::PixelType::PIXEL_RGBA2RGB)
904     .value("PIXEL_RGBA2BGR", ncnn::Mat::PixelType::PIXEL_RGBA2BGR)
905     .value("PIXEL_RGBA2GRAY", ncnn::Mat::PixelType::PIXEL_RGBA2GRAY)
906     .value("PIXEL_RGBA2BGRA", ncnn::Mat::PixelType::PIXEL_RGBA2BGRA)
907 
908     .value("PIXEL_BGRA2RGB", ncnn::Mat::PixelType::PIXEL_BGRA2RGB)
909     .value("PIXEL_BGRA2BGR", ncnn::Mat::PixelType::PIXEL_BGRA2BGR)
910     .value("PIXEL_BGRA2GRAY", ncnn::Mat::PixelType::PIXEL_BGRA2GRAY)
911     .value("PIXEL_BGRA2RGBA", ncnn::Mat::PixelType::PIXEL_BGRA2RGBA);
912 
913     py::class_<Extractor>(m, "Extractor")
914     .def("__enter__", [](Extractor& ex) -> Extractor& { return ex; })
915     .def("__exit__", [](Extractor& ex, pybind11::args) {
916         ex.clear();
917     })
918     .def("clear", &Extractor::clear)
919     .def("set_light_mode", &Extractor::set_light_mode, py::arg("enable"))
920     .def("set_num_threads", &Extractor::set_num_threads, py::arg("num_threads"))
921     .def("set_blob_allocator", &Extractor::set_blob_allocator, py::arg("allocator"))
922     .def("set_workspace_allocator", &Extractor::set_workspace_allocator, py::arg("allocator"))
923 #if NCNN_STRING
924     .def("input", (int (Extractor::*)(const char*, const Mat&)) & Extractor::input, py::arg("blob_name"), py::arg("in"))
925     .def("extract", (int (Extractor::*)(const char*, Mat&, int)) & Extractor::extract, py::arg("blob_name"), py::arg("feat"), py::arg("type") = 0)
926     .def(
927     "extract", [](Extractor& ex, const char* blob_name, int type) {
928         ncnn::Mat feat;
929         int ret = ex.extract(blob_name, feat, type);
930         return py::make_tuple(ret, feat.clone());
931     },
932     py::arg("blob_name"), py::arg("type") = 0)
933 #endif
934     .def("input", (int (Extractor::*)(int, const Mat&)) & Extractor::input)
935     .def("extract", (int (Extractor::*)(int, Mat&, int)) & Extractor::extract, py::arg("blob_index"), py::arg("feat"), py::arg("type") = 0)
936     .def(
937     "extract", [](Extractor& ex, int blob_index, int type) {
938         ncnn::Mat feat;
939         int ret = ex.extract(blob_index, feat, type);
940         return py::make_tuple(ret, feat.clone());
941     },
942     py::arg("blob_index"), py::arg("type") = 0);
943 
944     py::class_<Layer, PyLayer>(m, "Layer")
945     .def(py::init<>())
946     .def("load_param", &Layer::load_param, py::arg("pd"))
947     .def("load_model", &Layer::load_model, py::arg("mb"))
948     .def("create_pipeline", &Layer::create_pipeline, py::arg("opt"))
949     .def("destroy_pipeline", &Layer::destroy_pipeline, py::arg("opt"))
950     .def_readwrite("one_blob_only", &Layer::one_blob_only)
951     .def_readwrite("support_inplace", &Layer::support_inplace)
952     .def_readwrite("support_vulkan", &Layer::support_vulkan)
953     .def_readwrite("support_packing", &Layer::support_packing)
954     .def_readwrite("support_bf16_storage", &Layer::support_bf16_storage)
955     .def_readwrite("support_fp16_storage", &Layer::support_fp16_storage)
956     .def_readwrite("support_image_storage", &Layer::support_image_storage)
957     .def_readwrite("support_weight_fp16_storage", &Layer::support_weight_fp16_storage)
958     .def("forward", (int (Layer::*)(const std::vector<Mat>&, std::vector<Mat>&, const Option&) const) & Layer::forward,
959          py::arg("bottom_blobs"), py::arg("top_blobs"), py::arg("opt"))
960     .def("forward", (int (Layer::*)(const Mat&, Mat&, const Option&) const) & Layer::forward,
961          py::arg("bottom_blob"), py::arg("top_blob"), py::arg("opt"))
962     .def("forward_inplace", (int (Layer::*)(std::vector<Mat>&, const Option&) const) & Layer::forward_inplace,
963          py::arg("bottom_top_blobs"), py::arg("opt"))
964     .def("forward_inplace", (int (Layer::*)(Mat&, const Option&) const) & Layer::forward_inplace,
965          py::arg("bottom_top_blob"), py::arg("opt"))
966     .def_readwrite("typeindex", &Layer::typeindex)
967 #if NCNN_STRING
968     .def_readwrite("type", &Layer::type)
969     .def_readwrite("name", &Layer::name)
970 #endif // NCNN_STRING
971     .def_readwrite("bottoms", &Layer::bottoms)
972     .def_readwrite("tops", &Layer::tops)
973     .def_readwrite("bottom_shapes", &Layer::bottom_shapes)
974     .def_readwrite("top_shapes", &Layer::top_shapes);
975 
976     py::class_<Net>(m, "Net")
977     .def(py::init<>())
978     .def_readwrite("opt", &Net::opt)
979     .def("__enter__", [](Net& net) -> Net& { return net; })
980     .def("__exit__", [](Net& net, pybind11::args) {
981         net.clear();
982     })
983 
984 #if NCNN_VULKAN
985     .def("set_vulkan_device", (void (Net::*)(int)) & Net::set_vulkan_device, py::arg("device_index"))
986     .def("set_vulkan_device", (void (Net::*)(const VulkanDevice*)) & Net::set_vulkan_device, py::arg("vkdev"))
987     .def("vulkan_device", &Net::vulkan_device, py::return_value_policy::reference_internal)
988 #endif // NCNN_VULKAN
989 
990 #if NCNN_STRING
991     .def(
992     "register_custom_layer", [](Net& net, const char* type, const std::function<ncnn::Layer*()>& creator, const std::function<void(ncnn::Layer*)>& destroyer) {
993         if (g_layer_factroy_index == g_layer_factroys.size())
994         {
995             std::stringstream ss;
996             ss << "python version only support " << g_layer_factroys.size() << " custom layers now";
997             pybind11::pybind11_fail(ss.str());
998         }
999         LayerFactory& lf = g_layer_factroys[g_layer_factroy_index++];
1000         lf.name = type;
1001         lf.creator = creator;
1002         lf.destroyer = destroyer;
1003         return net.register_custom_layer(lf.name.c_str(), lf.creator_func, lf.destroyer_func);
1004     },
1005     py::arg("type"), py::arg("creator"), py::arg("destroyer"))
1006 #endif //NCNN_STRING
1007     .def(
1008     "register_custom_layer", [](Net& net, int index, const std::function<ncnn::Layer*()>& creator, const std::function<void(ncnn::Layer*)>& destroyer) {
1009         if (g_layer_factroy_index == g_layer_factroys.size())
1010         {
1011             std::stringstream ss;
1012             ss << "python version only support " << g_layer_factroys.size() << " custom layers now";
1013             pybind11::pybind11_fail(ss.str());
1014         }
1015         LayerFactory& lf = g_layer_factroys[g_layer_factroy_index++];
1016         lf.index = index;
1017         lf.creator = creator;
1018         lf.destroyer = destroyer;
1019         return net.register_custom_layer(index, lf.creator_func, lf.destroyer_func);
1020     },
1021     py::arg("index"), py::arg("creator"), py::arg("destroyer"))
1022 #if NCNN_STRING
1023     .def("load_param", (int (Net::*)(const DataReader&)) & Net::load_param, py::arg("dr"))
1024 #endif // NCNN_STRING
1025     .def("load_param_bin", (int (Net::*)(const DataReader&)) & Net::load_param_bin, py::arg("dr"))
1026     .def("load_model", (int (Net::*)(const DataReader&)) & Net::load_model, py::arg("dr"))
1027 
1028 #if NCNN_STDIO
1029 #if NCNN_STRING
1030     .def("load_param", (int (Net::*)(const char*)) & Net::load_param, py::arg("protopath"))
1031 #endif // NCNN_STRING
1032     .def("load_param_bin", (int (Net::*)(const char*)) & Net::load_param_bin, py::arg("protopath"))
1033     .def("load_model", (int (Net::*)(const char*)) & Net::load_model, py::arg("modelpath"))
1034 #endif // NCNN_STDIO
1035 
1036     .def("clear", &Net::clear)
1037     .def("create_extractor", &Net::create_extractor, py::keep_alive<0, 1>()) //net should be kept alive until retuned ex is freed by gc
1038 
1039     .def("input_indexes", &Net::input_indexes, py::return_value_policy::reference)
1040     .def("input_indexes", &Net::output_indexes, py::return_value_policy::reference)
1041 #if NCNN_STRING
1042     .def("input_names", &Net::input_names, py::return_value_policy::reference)
1043     .def("output_names", &Net::output_names, py::return_value_policy::reference)
1044 #endif // NCNN_STRING
1045 
1046     .def("blobs", &Net::blobs, py::return_value_policy::reference_internal)
1047     .def("layers", &Net::layers, py::return_value_policy::reference_internal);
1048 
1049     py::enum_<ncnn::BorderType>(m, "BorderType")
1050     .value("BORDER_CONSTANT", ncnn::BorderType::BORDER_CONSTANT)
1051     .value("BORDER_REPLICATE", ncnn::BorderType::BORDER_REPLICATE);
1052 
1053     m.def("cpu_support_arm_neon", &cpu_support_arm_neon);
1054     m.def("cpu_support_arm_vfpv4", &cpu_support_arm_vfpv4);
1055     m.def("cpu_support_arm_asimdhp", &cpu_support_arm_asimdhp);
1056     m.def("cpu_support_x86_avx2", &cpu_support_x86_avx2);
1057     m.def("cpu_support_x86_avx", &cpu_support_x86_avx);
1058     m.def("get_cpu_count", &get_cpu_count);
1059     m.def("get_little_cpu_count", &get_little_cpu_count);
1060     m.def("get_big_cpu_count", &get_big_cpu_count);
1061     m.def("get_cpu_powersave", &get_cpu_powersave);
1062     m.def("set_cpu_powersave", &set_cpu_powersave, py::arg("powersave"));
1063     m.def("get_omp_num_threads", &get_omp_num_threads);
1064     m.def("set_omp_num_threads", &set_omp_num_threads, py::arg("num_threads"));
1065     m.def("get_omp_dynamic", &get_omp_dynamic);
1066     m.def("set_omp_dynamic", &set_omp_dynamic, py::arg("dynamic"));
1067     m.def("get_omp_thread_num", &get_omp_thread_num);
1068     m.def("get_kmp_blocktime", &get_kmp_blocktime);
1069     m.def("set_kmp_blocktime", &set_kmp_blocktime, py::arg("time_ms"));
1070 
1071     m.def("copy_make_border", &copy_make_border,
1072           py::arg("src"), py::arg("dst"),
1073           py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"),
1074           py::arg("type"), py::arg("v"), py::arg("opt") = Option());
1075     m.def(
1076         "copy_make_border",
1077     [](const Mat& src, int top, int bottom, int left, int right, int type, float v, const Option& opt) {
1078         Mat dst;
1079         copy_make_border(src, dst, top, bottom, left, right, type, v, opt);
1080         return dst;
1081     },
1082     py::arg("src"),
1083     py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"),
1084     py::arg("type"), py::arg("v"), py::arg("opt") = Option());
1085 
1086     m.def("copy_make_border_3d", &copy_make_border_3d,
1087           py::arg("src"), py::arg("dst"),
1088           py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"), py::arg("front"), py::arg("behind"),
1089           py::arg("type"), py::arg("v"), py::arg("opt") = Option());
1090     m.def(
1091         "copy_make_border_3d",
1092     [](const Mat& src, int top, int bottom, int left, int right, int front, int behind, int type, float v, const Option& opt) {
1093         Mat dst;
1094         copy_make_border_3d(src, dst, top, bottom, left, right, front, behind, type, v, opt);
1095         return dst;
1096     },
1097     py::arg("src"),
1098     py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"), py::arg("front"), py::arg("behind"),
1099     py::arg("type"), py::arg("v"), py::arg("opt") = Option());
1100 
1101     m.def("copy_cut_border", &copy_cut_border,
1102           py::arg("src"), py::arg("dst"),
1103           py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"),
1104           py::arg("opt") = Option());
1105     m.def(
1106         "copy_cut_border",
1107     [](const Mat& src, int top, int bottom, int left, int right, const Option& opt) {
1108         Mat dst;
1109         copy_cut_border(src, dst, top, bottom, left, right, opt);
1110         return dst;
1111     },
1112     py::arg("src"),
1113     py::arg("top"), py::arg("bottom"), py::arg("left"), py::arg("right"),
1114     py::arg("opt") = Option());
1115 
1116     m.def("resize_nearest", &resize_nearest,
1117           py::arg("src"), py::arg("dst"),
1118           py::arg("w"), py::arg("h"),
1119           py::arg("opt") = Option());
1120     m.def(
1121         "resize_nearest",
1122     [](const Mat& src, int w, int h, const Option& opt) {
1123         Mat dst;
1124         resize_nearest(src, dst, w, h);
1125         return dst;
1126     },
1127     py::arg("src"),
1128     py::arg("w"), py::arg("h"),
1129     py::arg("opt") = Option());
1130 
1131     m.def("resize_bilinear", &resize_bilinear,
1132           py::arg("src"), py::arg("dst"),
1133           py::arg("w"), py::arg("h"),
1134           py::arg("opt") = Option());
1135     m.def(
1136         "resize_bilinear",
1137     [](const Mat& src, int w, int h, const Option& opt) {
1138         Mat dst;
1139         resize_bilinear(src, dst, w, h, opt);
1140         return dst;
1141     },
1142     py::arg("src"),
1143     py::arg("w"), py::arg("h"),
1144     py::arg("opt") = Option());
1145 
1146     m.def("resize_bicubic", &resize_bicubic,
1147           py::arg("src"), py::arg("dst"),
1148           py::arg("w"), py::arg("h"),
1149           py::arg("opt") = Option());
1150     m.def(
1151         "resize_bicubic",
1152     [](const Mat& src, int w, int h, const Option& opt) {
1153         Mat dst;
1154         resize_bicubic(src, dst, w, h, opt);
1155         return dst;
1156     },
1157     py::arg("src"),
1158     py::arg("w"), py::arg("h"),
1159     py::arg("opt") = Option());
1160 
1161     m.def("convert_packing", &convert_packing,
1162           py::arg("src"), py::arg("dst"),
1163           py::arg("elempack"),
1164           py::arg("opt") = Option());
1165     m.def(
1166         "convert_packing",
1167     [](const Mat& src, int elempack, const Option& opt) {
1168         Mat dst;
1169         convert_packing(src, dst, elempack, opt);
1170         return dst;
1171     },
1172     py::arg("src"),
1173     py::arg("elempack"),
1174     py::arg("opt") = Option());
1175 
1176     m.def("flatten", &flatten,
1177           py::arg("src"), py::arg("dst"),
1178           py::arg("opt") = Option());
1179     m.def(
1180         "flatten",
1181     [](const Mat& src, const Option& opt) {
1182         Mat dst;
1183         flatten(src, dst, opt);
1184         return dst;
1185     },
1186     py::arg("src"),
1187     py::arg("opt") = Option());
1188 
1189     m.def("cast_float32_to_float16", &cast_float32_to_float16,
1190           py::arg("src"), py::arg("dst"),
1191           py::arg("opt") = Option());
1192     m.def(
1193         "cast_float32_to_float16",
1194     [](const Mat& src, const Option& opt) {
1195         Mat dst;
1196         cast_float32_to_float16(src, dst, opt);
1197         return dst;
1198     },
1199     py::arg("src"),
1200     py::arg("opt") = Option());
1201 
1202     m.def("cast_float16_to_float32", &cast_float16_to_float32,
1203           py::arg("src"), py::arg("dst"),
1204           py::arg("opt") = Option());
1205     m.def(
1206         "cast_float16_to_float32",
1207     [](const Mat& src, const Option& opt) {
1208         Mat dst;
1209         cast_float16_to_float32(src, dst, opt);
1210         return dst;
1211     },
1212     py::arg("src"),
1213     py::arg("opt") = Option());
1214 
1215     m.def("cast_int8_to_float32", &cast_int8_to_float32,
1216           py::arg("src"), py::arg("dst"),
1217           py::arg("opt") = Option());
1218     m.def(
1219         "cast_int8_to_float32",
1220     [](const Mat& src, const Option& opt) {
1221         Mat dst;
1222         cast_int8_to_float32(src, dst, opt);
1223         return dst;
1224     },
1225     py::arg("src"),
1226     py::arg("opt") = Option());
1227 
1228     m.def("cast_float32_to_bfloat16", &cast_float32_to_bfloat16,
1229           py::arg("src"), py::arg("dst"),
1230           py::arg("opt") = Option());
1231     m.def(
1232         "cast_float32_to_bfloat16",
1233     [](const Mat& src, const Option& opt) {
1234         Mat dst;
1235         cast_float32_to_bfloat16(src, dst, opt);
1236         return dst;
1237     },
1238     py::arg("src"),
1239     py::arg("opt") = Option());
1240 
1241     m.def("cast_bfloat16_to_float32", &cast_bfloat16_to_float32,
1242           py::arg("src"), py::arg("dst"),
1243           py::arg("opt") = Option());
1244     m.def(
1245         "cast_bfloat16_to_float32",
1246     [](const Mat& src, const Option& opt) {
1247         Mat dst;
1248         cast_bfloat16_to_float32(src, dst, opt);
1249         return dst;
1250     },
1251     py::arg("src"),
1252     py::arg("opt") = Option());
1253 
1254     m.def("quantize_to_int8", &quantize_to_int8,
1255           py::arg("src"), py::arg("dst"),
1256           py::arg("scale_data"),
1257           py::arg("opt") = Option());
1258     m.def(
1259         "quantize_to_int8",
1260     [](const Mat& src, const Mat& scale_data, const Option& opt) {
1261         Mat dst;
1262         quantize_to_int8(src, dst, scale_data, opt);
1263         return dst;
1264     },
1265     py::arg("src"),
1266     py::arg("scale_data"),
1267     py::arg("opt") = Option());
1268 
1269 #if NCNN_STRING
1270     m.def("layer_to_index", &layer_to_index, py::arg("type"));
1271     m.def(
1272         "create_layer",
1273     [](const char* type) {
1274         return static_cast<Layer*>(create_layer(type));
1275     },
1276     py::arg("type"));
1277     m.def(
1278         "create_layer",
1279     [](int index) {
1280         return static_cast<Layer*>(create_layer(index));
1281     },
1282     py::arg("index"));
1283 #endif //NCNN_STRING
1284 
1285 #if NCNN_VULKAN
1286     m.def("create_gpu_instance", &create_gpu_instance);
1287     m.def("destroy_gpu_instance", &destroy_gpu_instance);
1288     m.def("get_gpu_count", &get_gpu_count);
1289     m.def("get_default_gpu_index", &get_default_gpu_index);
1290     m.def("get_gpu_info", &get_gpu_info, py::arg("device_index") = 0, py::return_value_policy::reference);
1291     m.def("get_gpu_device", &get_gpu_device, py::arg("device_index") = 0, py::return_value_policy::reference);
1292 
1293     py::class_<VkBufferMemory>(m, "VkBufferMemory")
1294     .def_readwrite("offset", &VkBufferMemory::offset)
1295     .def_readwrite("capacity", &VkBufferMemory::capacity)
1296     .def_readwrite("refcount", &VkBufferMemory::refcount);
1297 
1298     py::class_<VkImageMemory>(m, "VkImageMemory")
1299     .def_readwrite("width", &VkImageMemory::width)
1300     .def_readwrite("height", &VkImageMemory::height)
1301     .def_readwrite("depth", &VkImageMemory::depth)
1302     .def_readwrite("refcount", &VkImageMemory::refcount);
1303 
1304     py::class_<VkAllocator, PyVkAllocator<> >(m, "VkAllocator")
1305     .def_readonly("vkdev", &VkAllocator::vkdev)
1306     .def_readwrite("buffer_memory_type_index", &VkAllocator::buffer_memory_type_index)
1307     .def_readwrite("image_memory_type_index", &VkAllocator::image_memory_type_index)
1308     .def_readwrite("mappable", &VkAllocator::mappable)
1309     .def_readwrite("coherent", &VkAllocator::coherent);
1310 
1311     py::class_<VkBlobAllocator, VkAllocator, PyVkAllocatorOther<VkBlobAllocator> >(m, "VkBlobAllocator")
1312     .def(py::init<const VulkanDevice*>())
1313     .def("clear", &VkBlobAllocator::clear)
1314     .def("fastMalloc", (VkBufferMemory * (VkBlobAllocator::*)(size_t size)) & VkBlobAllocator::fastMalloc, py::return_value_policy::reference_internal)
1315     .def("fastFree", (void (VkBlobAllocator::*)(VkBufferMemory * ptr)) & VkBlobAllocator::fastFree)
1316     .def("fastMalloc", (VkImageMemory * (VkBlobAllocator::*)(int, int, int, size_t, int)) & VkBlobAllocator::fastMalloc, py::return_value_policy::reference_internal)
1317     .def("fastFree", (void (VkBlobAllocator::*)(VkImageMemory * ptr)) & VkBlobAllocator::fastFree);
1318 
1319     py::class_<VkWeightAllocator, VkAllocator, PyVkAllocatorOther<VkWeightAllocator> >(m, "VkWeightAllocator")
1320     .def(py::init<const VulkanDevice*>())
1321     .def("clear", &VkWeightAllocator::clear)
1322     .def("fastMalloc", (VkBufferMemory * (VkWeightAllocator::*)(size_t size)) & VkWeightAllocator::fastMalloc, py::return_value_policy::reference_internal)
1323     .def("fastFree", (void (VkWeightAllocator::*)(VkBufferMemory * ptr)) & VkWeightAllocator::fastFree)
1324     .def("fastMalloc", (VkImageMemory * (VkWeightAllocator::*)(int, int, int, size_t, int)) & VkWeightAllocator::fastMalloc, py::return_value_policy::reference_internal)
1325     .def("fastFree", (void (VkWeightAllocator::*)(VkImageMemory * ptr)) & VkWeightAllocator::fastFree);
1326 
1327     py::class_<VkStagingAllocator, VkAllocator, PyVkAllocatorOther<VkStagingAllocator> >(m, "VkStagingAllocator")
1328     .def(py::init<const VulkanDevice*>())
1329     .def("set_size_compare_ratio", &VkStagingAllocator::set_size_compare_ratio)
1330     .def("clear", &VkStagingAllocator::clear)
1331     .def("fastMalloc", (VkBufferMemory * (VkStagingAllocator::*)(size_t size)) & VkStagingAllocator::fastMalloc, py::return_value_policy::reference_internal)
1332     .def("fastFree", (void (VkStagingAllocator::*)(VkBufferMemory * ptr)) & VkStagingAllocator::fastFree)
1333     .def("fastMalloc", (VkImageMemory * (VkStagingAllocator::*)(int, int, int, size_t, int)) & VkStagingAllocator::fastMalloc, py::return_value_policy::reference_internal)
1334     .def("fastFree", (void (VkStagingAllocator::*)(VkImageMemory * ptr)) & VkStagingAllocator::fastFree);
1335 
1336     py::class_<VkWeightStagingAllocator, VkAllocator, PyVkAllocatorOther<VkWeightStagingAllocator> >(m, "VkWeightStagingAllocator")
1337     .def(py::init<const VulkanDevice*>())
1338     .def("fastMalloc", (VkBufferMemory * (VkWeightStagingAllocator::*)(size_t size)) & VkWeightStagingAllocator::fastMalloc, py::return_value_policy::reference_internal)
1339     .def("fastFree", (void (VkWeightStagingAllocator::*)(VkBufferMemory * ptr)) & VkWeightStagingAllocator::fastFree)
1340     .def("fastMalloc", (VkImageMemory * (VkWeightStagingAllocator::*)(int, int, int, size_t, int)) & VkWeightStagingAllocator::fastMalloc, py::return_value_policy::reference_internal)
1341     .def("fastFree", (void (VkWeightStagingAllocator::*)(VkImageMemory * ptr)) & VkWeightStagingAllocator::fastFree);
1342 
1343     py::class_<GpuInfo>(m, "GpuInfo")
1344     .def(py::init<>())
1345     .def("api_version", &GpuInfo::api_version)
1346     .def("driver_version", &GpuInfo::driver_version)
1347     .def("vendor_id", &GpuInfo::vendor_id)
1348     .def("device_id", &GpuInfo::device_id)
1349     .def("pipeline_cache_uuid", [](GpuInfo& gpuinfo) {
1350         return py::memoryview::from_buffer(gpuinfo.pipeline_cache_uuid(), {VK_UUID_SIZE}, {sizeof(uint8_t) * VK_UUID_SIZE});
1351     })
1352     .def("type", &GpuInfo::type);
1353 
    // VulkanDevice: constructible from Python with an optional device_index
    // (defaults to 0).
    py::class_<VulkanDevice>(m, "VulkanDevice")
    .def(py::init<int>(), py::arg("device_index") = 0)
    .def(
    // info() hands back a pointer to the device's `info` member rather than a
    // copy; reference_internal keeps the owning VulkanDevice alive for as long
    // as the returned object is referenced from Python.
    "info", [](VulkanDevice& dev) {
        return &dev.info;
    },
    py::return_value_policy::reference_internal);
1361 #endif // NCNN_VULKAN
1362 
    // Module docstring. The raw string holds reStructuredText directives
    // (currentmodule / autosummary) and is stored verbatim, including its
    // leading indentation.
    m.doc() = R"pbdoc(
        ncnn python wrapper
        -----------------------
        .. currentmodule:: pyncnn
        .. autosummary::
           :toctree: _generate
    )pbdoc";
1370 
#ifdef VERSION_INFO
    // Version supplied through the VERSION_INFO macro.
    // NOTE(review): assumes VERSION_INFO expands to something assignable to a
    // Python attribute (e.g. a string literal) — confirm against the build
    // definition.
    m.attr("__version__") = VERSION_INFO;
#else
    // No version was provided at compile time; fall back to "dev".
    m.attr("__version__") = "dev";
#endif
1376 }
1377