1 //
2 //  ConvInt8Winograd.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/08/20.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef ConvInt8Winograd_hpp
10 #define ConvInt8Winograd_hpp
11 
12 #include "backend/cpu/CPUConvolution.hpp"
13 #include "backend/cpu/compute/Int8FunctionsOpt.h"
14 
15 namespace MNN {
16 class ConvInt8Winograd : public CPUConvolution {
17 public:
18     using CommonPair = std::pair<const Convolution2DCommon*, unsigned char*>;
19     struct UnitAttr {
20         int kyStart;
21         int kySize;
22         int kxStart;
23         int kxSize;
24         int unitY;
25         int unitX;
26     };
27     struct Unit {
28         UnitAttr attr;
29         std::shared_ptr<CommonPair> common;
30         std::shared_ptr<Tensor> input;
31         std::shared_ptr<Tensor> output;
32         std::shared_ptr<Execution> runner;
33     };
34     ConvInt8Winograd(Backend *b, const Convolution2D *convOp, std::shared_ptr<ResourceInt8> res, std::vector<ConvInt8Winograd::UnitAttr>& unitAttrs);
35     virtual ~ConvInt8Winograd();
36     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
37     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
38 
39     static bool bestWinogradUnit(const Convolution2D *convOp, const Tensor *input, const Tensor* weightSrc, const Tensor *output, Backend* bn, std::vector<UnitAttr>& unitAttrs);
40     virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
41 private:
42     ConvInt8Winograd(Backend* backend, const Convolution2DCommon* common, const ConvInt8Winograd& exe);
43     // transform func
44     using WinoSrcTransFunc = WinogradInt8Helper::SrcTransFunc;
45     using WinoDstTransFunc = WinogradInt8Helper::DstTransFunc;
46     // subExecutions
47     std::vector<Unit> mUnits;
48     std::shared_ptr<CPUConvolution::ResourceInt8> mResource;
49 
50     class WinoExecution : public CPUConvolution {
51     public:
52         WinoExecution(Backend *b, const Convolution2DCommon* common, Tensor* weight, int unitY, int unitX, bool fastgemm);
53 
54         WinoExecution(Backend* bn, const Convolution2DCommon* common, const WinoExecution& exe);
55         virtual ~WinoExecution();
56         virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
57         virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
58         virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
59         // weight
60         std::shared_ptr<Tensor> mWeight;
61         // buffer
62         std::shared_ptr<Tensor> mTempInputBuffer;
63         std::shared_ptr<Tensor> mTempOutputBuffer;
64         std::shared_ptr<Tensor> mTransformMidBuffer;
65         // transform func
66         WinoSrcTransFunc mSourceTransformY = nullptr;
67         WinoSrcTransFunc mSourceTransformX = nullptr;
68         WinoDstTransFunc mDestTransformY = nullptr;
69         WinoDstTransFunc mDestTransformX = nullptr;
70         // unit and kernel
71         int mUnitY, mUnitX;
72         int mKernelY, mKernelX;
73         // gemm func
74         decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel;
75         // other quan attr
76         int8_t mInputZeroPoint;
77         std::shared_ptr<Tensor> mOffsets;
78         friend class ConvInt8Winograd;
79     };
80 
81     static bool chooseTransformFuncs(int kernelY, int kernelX, int unitY, int unitX, ConvInt8Winograd::WinoExecution* exe, Backend* bn);
82 };
83 } // namespace MNN
84 #endif /* ConvInt8Winograd_hpp */
85