1 // 2 // WinogradInt8Helper.hpp 3 // MNN 4 // 5 // Created by MNN on 2018/07/16. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef WinogradInt8Helper_hpp 10 #define WinogradInt8Helper_hpp 11 12 #include <cstddef> 13 #include <cstdint> 14 #include "MNN/Tensor.hpp" 15 #include "core/ConvolutionCommon.hpp" 16 17 namespace MNN { 18 struct CoreInt8Functions; 19 class WinogradInt8Helper { 20 public: 21 WinogradInt8Helper(int unitY, int unitX, const Convolution2DCommon *common, const CoreInt8Functions* core); 22 ~WinogradInt8Helper() = default; 23 std::shared_ptr<Tensor> allocTransformWeight(const Tensor* weightSrc); 24 bool transformWeight(const Tensor* weightSrc, Tensor* weightDst); 25 26 typedef void(*SrcTransFunc)(const int8_t* srcStart, int8_t* dstStart, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countUnit); 27 typedef void(*DstTransFunc)(const float* srcStart, float* dstStart, size_t srcXStep, size_t srcZStep, size_t dstXStep, size_t dstZStep, size_t countUnit); 28 static SrcTransFunc chooseSourceTransform(int alpha, int inPack, int outPack); 29 static DstTransFunc chooseDestTransform(int alpha, int unit); 30 static bool weightOverflow(const Tensor* weightSrc, int unitY, int unitX, const Convolution2DCommon* common, const CoreInt8Functions* core); 31 static bool featureOverflow(const Tensor* featureSrc, int alphaY, int alphaX); 32 private: 33 const Convolution2DCommon *mCommon; 34 int mAlphaY; 35 int mAlphaX; 36 const CoreInt8Functions* mInt8Core; 37 bool mValid = true; 38 }; 39 } 40 41 #endif // WinogradInt8Helper_hpp 42