1 // 2 // ConvInt8Winograd.hpp 3 // MNN 4 // 5 // Created by MNN on 2018/08/20. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef ConvInt8Winograd_hpp 10 #define ConvInt8Winograd_hpp 11 12 #include "backend/cpu/CPUConvolution.hpp" 13 #include "backend/cpu/compute/Int8FunctionsOpt.h" 14 15 namespace MNN { 16 class ConvInt8Winograd : public CPUConvolution { 17 public: 18 using CommonPair = std::pair<const Convolution2DCommon*, unsigned char*>; 19 struct UnitAttr { 20 int kyStart; 21 int kySize; 22 int kxStart; 23 int kxSize; 24 int unitY; 25 int unitX; 26 }; 27 struct Unit { 28 UnitAttr attr; 29 std::shared_ptr<CommonPair> common; 30 std::shared_ptr<Tensor> input; 31 std::shared_ptr<Tensor> output; 32 std::shared_ptr<Execution> runner; 33 }; 34 ConvInt8Winograd(Backend *b, const Convolution2D *convOp, std::shared_ptr<ResourceInt8> res, std::vector<ConvInt8Winograd::UnitAttr>& unitAttrs); 35 virtual ~ConvInt8Winograd(); 36 virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 37 virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 38 39 static bool bestWinogradUnit(const Convolution2D *convOp, const Tensor *input, const Tensor* weightSrc, const Tensor *output, Backend* bn, std::vector<UnitAttr>& unitAttrs); 40 virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; 41 private: 42 ConvInt8Winograd(Backend* backend, const Convolution2DCommon* common, const ConvInt8Winograd& exe); 43 // transform func 44 using WinoSrcTransFunc = WinogradInt8Helper::SrcTransFunc; 45 using WinoDstTransFunc = WinogradInt8Helper::DstTransFunc; 46 // subExecutions 47 std::vector<Unit> mUnits; 48 std::shared_ptr<CPUConvolution::ResourceInt8> mResource; 49 50 class WinoExecution : public CPUConvolution { 51 public: 52 WinoExecution(Backend *b, const Convolution2DCommon* common, Tensor* weight, int unitY, int unitX, bool fastgemm); 53 54 WinoExecution(Backend* bn, const Convolution2DCommon* common, const WinoExecution& exe); 55 virtual ~WinoExecution(); 56 virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 57 virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 58 virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; 59 // weight 60 std::shared_ptr<Tensor> mWeight; 61 // buffer 62 std::shared_ptr<Tensor> mTempInputBuffer; 63 std::shared_ptr<Tensor> mTempOutputBuffer; 64 std::shared_ptr<Tensor> mTransformMidBuffer; 65 // transform func 66 WinoSrcTransFunc mSourceTransformY = nullptr; 67 WinoSrcTransFunc mSourceTransformX = nullptr; 68 WinoDstTransFunc mDestTransformY = nullptr; 69 WinoDstTransFunc mDestTransformX = nullptr; 70 // unit and kernel 71 int mUnitY, mUnitX; 72 int mKernelY, mKernelX; 73 // gemm func 74 decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel; 75 // other quan attr 76 int8_t mInputZeroPoint; 77 std::shared_ptr<Tensor> mOffsets; 78 friend class ConvInt8Winograd; 79 }; 80 81 static bool chooseTransformFuncs(int kernelY, int kernelX, int unitY, int unitX, ConvInt8Winograd::WinoExecution* exe, Backend* bn); 82 }; 83 } // namespace MNN 84 #endif /* ConvInt8Winograd_hpp */ 85