1 // 2 // NN.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/11/25. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef MNN_Train_NN_hpp 10 #define MNN_Train_NN_hpp 11 #include <MNN/expr/ExprCreator.hpp> 12 #include <MNN/expr/Module.hpp> 13 #include <vector> 14 namespace MNN { 15 namespace Express { 16 class Initializer; 17 18 class MNN_PUBLIC NN { 19 public: 20 enum ActivationFunctionType { 21 None = 0, 22 Relu = 1, 23 Relu6 = 2, 24 }; 25 enum ScaleUpdateMethod { 26 Maximum = 0, 27 MovingAverage = 1 28 }; 29 enum FeatureScaleStatMethod { 30 PerTensor = 0, 31 PerChannel = 1 // Depercerate 32 }; 33 /* Unlike enum in class, class in class need be dllimport or dllexport explcility. 34 Compiling in other system will not be affected. 35 */ 36 struct MNN_PUBLIC ConvOption { 37 Express::INTS kernelSize = {1, 1}; 38 Express::INTS channel = {0, 0}; 39 Express::INTS stride = {1, 1}; 40 Express::INTS dilate = {1, 1}; 41 Express::PaddingMode padMode = Express::VALID; 42 Express::INTS pads = {0, 0}; 43 bool depthwise = false; 44 ActivationFunctionType fusedActivationFunction = None; 45 void reset(int size = 2); 46 }; 47 static Module* Conv(const ConvOption& option, bool bias = true, 48 std::shared_ptr<Initializer> weightInit = nullptr, 49 std::shared_ptr<Initializer> biasInit = nullptr); 50 static Module* ConvTranspose(const ConvOption& option, bool bias = true, 51 std::shared_ptr<Initializer> weightInit = nullptr, 52 std::shared_ptr<Initializer> biasInit = nullptr); 53 static Module* Linear(int l, int t, bool hasBias = true, 54 std::shared_ptr<Initializer> weightInit = nullptr, 55 std::shared_ptr<Initializer> biasInit = nullptr); 56 static Module* Dropout(const float dropRatio); 57 static Module* BatchNorm(const int channels, const int dims = 4, const float m = 0.999, 58 const float e = 1e-5); 59 60 static Module* ConvInt8(const ConvOption& option, int bits = 8, bool bias = true, 61 std::shared_ptr<Initializer> weightInit = nullptr, 62 std::shared_ptr<Initializer> biasInit = nullptr, 63 FeatureScaleStatMethod featureMethod = PerChannel, 64 ScaleUpdateMethod method = MovingAverage 65 ); 66 struct ConvParameters { 67 ConvOption option; 68 Express::VARP weight; 69 Express::VARP bias; 70 int group; 71 std::string name; 72 }; 73 static Module* ConvInt8(const ConvParameters& parameters, int bits, 74 FeatureScaleStatMethod featureMethod = PerChannel, 75 ScaleUpdateMethod method = MovingAverage); 76 static Module* Conv(const ConvParameters& parameters); 77 static Module* ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules, 78 NN::FeatureScaleStatMethod featureScaleStatMethod = PerTensor, 79 NN::ScaleUpdateMethod scaleUpdateMethod = MovingAverage, const int bits = 8); 80 81 class Utils { 82 public: 83 // ConvOption, Weight, Bias, Group 84 static ConvParameters ExtractConvolution(Express::EXPRP expr); 85 86 // Extract BatchNormal and Dropout 87 static Module* ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs); 88 }; 89 90 static bool turnQuantize(Module* module, const int bits = 8, NN::FeatureScaleStatMethod featureScaleStatMethod = NN::PerTensor, NN::ScaleUpdateMethod scaleUpdateMethod = NN::MovingAverage); 91 static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {}); 92 }; 93 94 } // namespace Train 95 } // namespace MNN 96 97 #endif 98