// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// Public API for loading a network (ncnn::Net) and running it through an
// extraction session (ncnn::Extractor).
// Parts of this API are only declared when the corresponding feature macro
// (NCNN_VULKAN, NCNN_STRING, NCNN_STDIO, NCNN_PLATFORM_API) is non-zero.

#ifndef NCNN_NET_H
#define NCNN_NET_H

#include "blob.h"
#include "layer.h"
#include "mat.h"
#include "option.h"
#include "platform.h"

#if NCNN_PLATFORM_API
#if __ANDROID_API__ >= 9
// provides AAsset / AAssetManager for the Android asset loaders declared in Net
#include <android/asset_manager.h>
#endif // __ANDROID_API__ >= 9
#endif // NCNN_PLATFORM_API

namespace ncnn {

#if NCNN_VULKAN
class VkCompute;
#endif // NCNN_VULKAN
class DataReader;
class Extractor;
class NetPrivate;

// A network: its structure (layers and blobs, loaded by the load_param*
// functions) and its weight data (loaded by the load_model* functions).
// Inputs are fed and outputs are read through an Extractor obtained from
// create_extractor().
// Net is not copyable (copy constructor and copy assignment are private).
// Apart from the public `opt` member, all state lives behind the opaque
// NetPrivate pointer `d`.
class NCNN_EXPORT Net
{
public:
    // empty init
    Net();
    // clear and destroy
    virtual ~Net();

public:
    // option can be changed before loading
    Option opt;

#if NCNN_VULKAN
    // set gpu device by index
    void set_vulkan_device(int device_index);

    // set gpu device by device handle
    // ownership is NOT transferred: the caller keeps owning vkdev
    void set_vulkan_device(const VulkanDevice* vkdev);

    // get the gpu device used by this net
    const VulkanDevice* vulkan_device() const;
#endif // NCNN_VULKAN

#if NCNN_STRING
    // register custom layer by layer type name
    // destroyer and userdata are optional (default 0)
    // NOTE(review): userdata is presumably forwarded to creator/destroyer -
    // confirm against layer_creator_func / layer_destroyer_func in layer.h
    // return 0 if success
    int register_custom_layer(const char* type, layer_creator_func creator, layer_destroyer_func destroyer = 0, void* userdata = 0);
#endif // NCNN_STRING
    // register custom layer by layer type index
    // destroyer and userdata are optional (default 0)
    // return 0 if success
    int register_custom_layer(int index, layer_creator_func creator, layer_destroyer_func destroyer = 0, void* userdata = 0);

#if NCNN_STRING
    // load network structure from a plain (text) param stream read through dr
    // NOTE(review): presumably returns 0 on success like the FILE* / path
    // overloads - confirm in net.cpp
    int load_param(const DataReader& dr);
#endif // NCNN_STRING

    // load network structure from a binary param stream read through dr
    // NOTE(review): presumably returns 0 on success - confirm in net.cpp
    int load_param_bin(const DataReader& dr);

    // load network weight data from a stream read through dr
    // NOTE(review): presumably returns 0 on success - confirm in net.cpp
    int load_model(const DataReader& dr);

#if NCNN_STDIO
#if NCNN_STRING
    // load network structure from plain param file
    // (an open FILE*, a file path, or, for load_param_mem, param text held in memory)
    // return 0 if success
    int load_param(FILE* fp);
    int load_param(const char* protopath);
    int load_param_mem(const char* mem);
#endif // NCNN_STRING
    // load network structure from binary param file
    // (an open FILE* or a file path)
    // return 0 if success
    int load_param_bin(FILE* fp);
    int load_param_bin(const char* protopath);

    // load network weight data from model file
    // (an open FILE* or a file path)
    // return 0 if success
    int load_model(FILE* fp);
    int load_model(const char* modelpath);
#endif // NCNN_STDIO

    // load network structure from external memory
    // memory pointer must be 32-bit aligned
    // return bytes consumed
    // (unlike the overloads above, the result is a byte count, not 0 on success)
    int load_param(const unsigned char* mem);

    // reference network weight data from external memory
    // weight data is not copied but referenced,
    // so the external memory must be retained while the weights are in use
    // memory pointer must be 32-bit aligned
    // return bytes consumed
    // (unlike the overloads above, the result is a byte count, not 0 on success)
    int load_model(const unsigned char* mem);

#if NCNN_PLATFORM_API
#if __ANDROID_API__ >= 9
#if NCNN_STRING
    // convenience loader: network structure from an android asset plain param file
    // (an already opened AAsset, or a path resolved through mgr)
    int load_param(AAsset* asset);
    int load_param(AAssetManager* mgr, const char* assetpath);
#endif // NCNN_STRING
    // convenience loader: network structure from an android asset binary param file
    int load_param_bin(AAsset* asset);
    int load_param_bin(AAssetManager* mgr, const char* assetpath);

    // convenience loader: network weight data from an android asset model file
    int load_model(AAsset* asset);
    int load_model(AAssetManager* mgr, const char* assetpath);
#endif // __ANDROID_API__ >= 9
#endif // NCNN_PLATFORM_API

    // unload network structure and weight data
    void clear();

    // construct an Extractor from network
    // Extractor's own constructor is protected, so this is how new
    // extractors are created (existing ones may still be copied)
    Extractor create_extractor() const;

    // get input/output indexes/names
    const std::vector<int>& input_indexes() const;
    const std::vector<int>& output_indexes() const;
#if NCNN_STRING
    const std::vector<const char*>& input_names() const;
    const std::vector<const char*>& output_names() const;
#endif // NCNN_STRING

    // read-only access to all blobs / layers of the loaded network
    const std::vector<Blob>& blobs() const;
    const std::vector<Layer*>& layers() const;

    // writable access to all blobs / layers of the loaded network
    std::vector<Blob>& mutable_blobs();
    std::vector<Layer*>& mutable_layers();

protected:
    // Extractor needs access to the protected lookup helpers below
    friend class Extractor;
#if NCNN_STRING
    // look up a blob / layer index by its name
    // NOTE(review): the value returned when the name is not found is not
    // visible here (likely -1) - confirm in net.cpp
    int find_blob_index_by_name(const char* name) const;
    int find_layer_index_by_name(const char* name) const;
    // overridable hooks for resolving / instantiating custom layer types
    virtual int custom_layer_to_index(const char* type);
    virtual Layer* create_custom_layer(const char* type);
#endif // NCNN_STRING
    // overridable hook for instantiating a custom layer by type index
    virtual Layer* create_custom_layer(int index);

private:
    // non-copyable: copy constructor and copy assignment are private
    Net(const Net&);
    Net& operator=(const Net&);

private:
    // opaque implementation pointer; the pointer itself is const, so it
    // cannot be reseated after construction
    NetPrivate* const d;
};

class ExtractorPrivate;

// An extraction session created by Net::create_extractor().
// Inputs are set with input() and results are read with extract(), by blob
// name (NCNN_STRING builds) or by blob index. Light mode, thread count and
// allocators can be configured per extractor.
// All state lives behind the opaque ExtractorPrivate pointer `d`.
class NCNN_EXPORT Extractor
{
public:
    virtual ~Extractor();

    // copy
    Extractor(const Extractor&);

    // assign
    Extractor& operator=(const Extractor&);

    // clear blob mats and allocators
    void clear();

    // enable light mode
    // intermediate blob will be recycled when enabled
    // enabled by default
    void set_light_mode(bool enable);

    // set thread count for this extractor
    // this overrides the global setting
    // default count is system dependent
    void set_num_threads(int num_threads);

    // set blob memory allocator
    void set_blob_allocator(Allocator* allocator);

    // set workspace memory allocator
    void set_workspace_allocator(Allocator* allocator);

#if NCNN_VULKAN
    // enable or disable vulkan compute for this extractor
    void set_vulkan_compute(bool enable);

    // set vulkan blob memory allocator
    void set_blob_vkallocator(VkAllocator* allocator);

    // set vulkan workspace memory allocator
    void set_workspace_vkallocator(VkAllocator* allocator);

    // set vulkan staging memory allocator
    void set_staging_vkallocator(VkAllocator* allocator);
#endif // NCNN_VULKAN

#if NCNN_STRING
    // set input by blob name
    // return 0 if success
    int input(const char* blob_name, const Mat& in);

    // get result by blob name
    // return 0 if success
    // type = 0, default
    // type = 1, do not convert fp16/bf16 and/or packing
    int extract(const char* blob_name, Mat& feat, int type = 0);
#endif // NCNN_STRING

    // set input by blob index
    // return 0 if success
    int input(int blob_index, const Mat& in);

    // get result by blob index
    // return 0 if success
    // type = 0, default
    // type = 1, do not convert fp16/bf16 and/or packing
    int extract(int blob_index, Mat& feat, int type = 0);

#if NCNN_VULKAN
    // NOTE(review): the VkCompute& cmd passed to the gpu extract overloads
    // below presumably has to be submitted/waited by the caller before feat
    // is usable - confirm in net.cpp
#if NCNN_STRING
    // set input by blob name
    // return 0 if success
    int input(const char* blob_name, const VkMat& in);

    // get result by blob name
    // return 0 if success
    int extract(const char* blob_name, VkMat& feat, VkCompute& cmd);

    // set input by blob name
    // return 0 if success
    int input(const char* blob_name, const VkImageMat& in);

    // get result by blob name
    // return 0 if success
    int extract(const char* blob_name, VkImageMat& feat, VkCompute& cmd);
#endif // NCNN_STRING

    // set input by blob index
    // return 0 if success
    int input(int blob_index, const VkMat& in);

    // get result by blob index
    // return 0 if success
    int extract(int blob_index, VkMat& feat, VkCompute& cmd);

    // set input by blob index
    // return 0 if success
    int input(int blob_index, const VkImageMat& in);

    // get result by blob index
    // return 0 if success
    int extract(int blob_index, VkImageMat& feat, VkCompute& cmd);
#endif // NCNN_VULKAN

protected:
    // only Net::create_extractor() may use the constructor below
    friend Extractor Net::create_extractor() const;
    // NOTE(review): blob_count is presumably the owning net's blob count,
    // used to size per-blob storage - confirm in net.cpp
    Extractor(const Net* net, size_t blob_count);

private:
    // opaque implementation pointer; the pointer itself is const, so it
    // cannot be reseated after construction
    ExtractorPrivate* const d;
};

} // namespace ncnn

#endif // NCNN_NET_H