1 //
2 //  NN.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/11/25.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MNN_Train_NN_hpp
10 #define MNN_Train_NN_hpp
11 #include <MNN/expr/ExprCreator.hpp>
12 #include <MNN/expr/Module.hpp>
13 #include <vector>
14 namespace MNN {
15 namespace Express {
16 class Initializer;
17 
// NN: static factory helpers for building trainable modules (convolution,
// transposed convolution, linear, dropout, batch-norm) and for extracting /
// converting modules out of expression graphs, including int8
// quantization-aware variants.
class MNN_PUBLIC NN {
public:
    // Activation fused into a layer's output
    // (see ConvOption::fusedActivationFunction).
    enum ActivationFunctionType {
        None = 0,
        Relu = 1,
        Relu6 = 2,
    };
    // How the quantization scale statistic is updated across batches.
    enum ScaleUpdateMethod {
        Maximum = 0,       // keep the running maximum of observed scales
        MovingAverage = 1  // exponential moving average of observed scales
    };
    // Granularity used when collecting feature (activation) scale statistics.
    enum FeatureScaleStatMethod {
        PerTensor = 0,
        PerChannel = 1 // Deprecated
    };
    /* Unlike enum in class, class in class need be dllimport or dllexport explicitly.
       Compiling in other system will not be affected.
     */
    // Hyper-parameters describing a (possibly depthwise) convolution layer.
    struct MNN_PUBLIC ConvOption {
        Express::INTS kernelSize     = {1, 1};
        Express::INTS channel        = {0, 0};  // presumably {input, output} channel counts — confirm with Conv() impl
        Express::INTS stride         = {1, 1};
        Express::INTS dilate         = {1, 1};
        Express::PaddingMode padMode = Express::VALID;
        Express::INTS pads           = {0, 0};
        bool depthwise               = false;
        ActivationFunctionType fusedActivationFunction = None;
        // Restore all fields to their defaults, sizing the INTS vectors to
        // `size` entries. NOTE(review): exact semantics live in the .cpp — confirm.
        void reset(int size = 2);
    };
    // Create a convolution module. When `bias` is true a bias term is added.
    // Optional initializers override the default weight/bias initialization.
    // Returns a raw Module* — presumably caller-owned; confirm against Module
    // lifetime conventions.
    static Module* Conv(const ConvOption& option, bool bias = true,
                                        std::shared_ptr<Initializer> weightInit = nullptr,
                                        std::shared_ptr<Initializer> biasInit   = nullptr);
    // Create a transposed (deconvolution) module; parameters mirror Conv().
    static Module* ConvTranspose(const ConvOption& option, bool bias = true,
                                                 std::shared_ptr<Initializer> weightInit = nullptr,
                                                 std::shared_ptr<Initializer> biasInit   = nullptr);
    // Create a fully-connected module mapping `l` inputs to `t` outputs.
    static Module* Linear(int l, int t, bool hasBias = true,
                                          std::shared_ptr<Initializer> weightInit = nullptr,
                                          std::shared_ptr<Initializer> biasInit   = nullptr);
    // Create a dropout module; `dropRatio` is presumably the drop probability.
    static Module* Dropout(const float dropRatio);
    // Create a batch-normalization module over `channels` channels for
    // `dims`-dimensional input. NOTE(review): `m` looks like the running-stat
    // momentum and `e` the numerical-stability epsilon — confirm in the .cpp.
    static Module* BatchNorm(const int channels, const int dims = 4, const float m = 0.999,
                                             const float e = 1e-5);

    // Create a quantization-aware convolution module quantized to `bits` bits.
    // `featureMethod`/`method` select how activation scales are gathered and
    // updated (see the enums above). Note the default here is the deprecated
    // PerChannel value.
    static Module* ConvInt8(const ConvOption& option, int bits = 8, bool bias = true,
                                            std::shared_ptr<Initializer> weightInit = nullptr,
                                            std::shared_ptr<Initializer> biasInit   = nullptr,
                                            FeatureScaleStatMethod featureMethod = PerChannel,
                                            ScaleUpdateMethod method = MovingAverage
                                            );
    // A convolution's full parameter set, as recovered from an existing
    // expression graph (see Utils::ExtractConvolution).
    struct ConvParameters {
        ConvOption option;
        Express::VARP weight;
        Express::VARP bias;
        int group;
        std::string name;
    };
    // Quantization-aware convolution built from already-extracted parameters.
    static Module* ConvInt8(const ConvParameters& parameters, int bits,
                                            FeatureScaleStatMethod featureMethod = PerChannel,
                                            ScaleUpdateMethod method = MovingAverage);
    // Float convolution built from already-extracted parameters.
    static Module* Conv(const ConvParameters& parameters);
    // Fuse a sequence of modules (presumably Conv [+ BatchNorm] [+ Relu]) into
    // a single quantization-aware module — confirm accepted module kinds in
    // the implementation.
    static Module* ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules,
                                                   NN::FeatureScaleStatMethod featureScaleStatMethod = PerTensor,
                                                   NN::ScaleUpdateMethod scaleUpdateMethod = MovingAverage, const int bits = 8);

    // Helpers for recovering module descriptions from raw expression nodes.
    class Utils {
    public:
        // ConvOption, Weight, Bias, Group
        // Recover a convolution's ConvParameters from an expression node.
        static ConvParameters ExtractConvolution(Express::EXPRP expr);

        // Extract BatchNormal and Dropout
        static Module* ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs);
    };

    // Convert `module` in place to quantization-aware training with `bits`
    // precision; returns whether the conversion succeeded — confirm return
    // semantics in the .cpp.
    static bool turnQuantize(Module* module, const int bits = 8, NN::FeatureScaleStatMethod featureScaleStatMethod = NN::PerTensor, NN::ScaleUpdateMethod scaleUpdateMethod = NN::MovingAverage);
    // Build a Module from the computation graph spanning `inputs` to
    // `outputs`; `fortrain` selects the trainable variant. `subGraph` supplies
    // definitions for any referenced subgraphs.
    static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});
};
93 
} // namespace Express
95 } // namespace MNN
96 
97 #endif
98