1 //
2 //  ConvolutionCommon.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/03/02.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef ConvolutionCommon_hpp
10 #define ConvolutionCommon_hpp
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include "AutoStorage.h"
#include "Execution.hpp"
#include "MNN_generated.h"
14 namespace MNN {
15 class MNN_PUBLIC ConvolutionCommon : public Execution {
16 public:
17     struct Int8Common {
18         AutoStorage<int8_t> weight;
19         AutoStorage<float> alpha;
20         AutoStorage<float> weightFloat;
21         const IDSTQuan* quan;
22     };
23     static std::shared_ptr<Int8Common> load(const IDSTQuan* quan, bool forceFloat = false, bool forceInt8 = false);
24     static void getConvParameters(std::shared_ptr<ConvolutionCommon::Int8Common> *quanCommon, const MNN::Convolution2D *conv2d, const float** originWeight, int* originWeightSize);
25     static bool getConvInt8Parameters(const MNN::Convolution2D* conv2d, std::shared_ptr<Int8Common>& quanCommon,
26                                       const int8_t*& weight, float*& scale, int32_t*& bias, float inputScale, float outputScale, int inputZeroPoint, int outputZeroPoint);
27 
28     // Return padX, padY
29     static std::pair<int, int> convolutionPad(const Tensor* input, const Tensor* output,
30                                               const Convolution2DCommon* common);
31     // Return padLeft, padTop, padRight, padBottom
32     static std::tuple<int, int, int, int> convolutionPadFull(const Tensor* input, const Tensor* output,
33                                               const Convolution2DCommon* common);
34     static std::pair<int, int> convolutionTransposePad(const Tensor* input, const Tensor* output,
35                                                        const Convolution2DCommon* common);
36     struct Im2ColParameter {
37         int32_t padX;
38         int32_t padY;
39         int32_t dilateX;
40         int32_t dilateY;
41         int32_t strideX;
42         int32_t strideY;
43         int32_t kernelX;
44         int32_t kernelY;
45         int32_t icDiv4;
46         int32_t kernelCountUnit;
47         int32_t iw;
48         int32_t ih;
49         int32_t ow;
50         int32_t oh;
51         int32_t srcZStep;
52         int32_t srcYStep;
53     };
54 };
55 } // namespace MNN
56 #endif
57