1 //
2 //  WinogradInt8Helper.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/07/16.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef WinogradInt8Helper_hpp
10 #define WinogradInt8Helper_hpp
11 
12 #include <cstddef>
13 #include <cstdint>
14 #include "MNN/Tensor.hpp"
15 #include "core/ConvolutionCommon.hpp"
16 
17 namespace MNN {
18 struct CoreInt8Functions;
19 class WinogradInt8Helper {
20 public:
21     WinogradInt8Helper(int unitY, int unitX, const Convolution2DCommon *common, const CoreInt8Functions* core);
22     ~WinogradInt8Helper() = default;
23     std::shared_ptr<Tensor> allocTransformWeight(const Tensor* weightSrc);
24     bool transformWeight(const Tensor* weightSrc, Tensor* weightDst);
25 
26     typedef void(*SrcTransFunc)(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countUnit);
27     typedef void(*DstTransFunc)(const float* srcStart, float* dstStart, size_t srcXStep, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countUnit);
28     static SrcTransFunc chooseSourceTransform(int alpha, int inPack, int outPack);
29     static DstTransFunc chooseDestTransform(int alpha, int unit);
30     static bool weightOverflow(const Tensor* weightSrc, int unitY, int unitX, const Convolution2DCommon* common, const CoreInt8Functions* core);
31     static bool featureOverflow(const Tensor* featureSrc, int alphaY, int alphaX);
32 private:
33     const Convolution2DCommon *mCommon;
34     int mAlphaY;
35     int mAlphaX;
36     const CoreInt8Functions* mInt8Core;
37     bool mValid = true;
38 };
39 }
40 
41 #endif // WinogradInt8Helper_hpp
42