1 //
2 //  Convolution1x1Strassen.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/02/12.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef Convolution1x1Strassen_hpp
10 #define Convolution1x1Strassen_hpp
11 
12 #include <functional>
13 #include "backend/cpu/CPUConvolution.hpp"
14 #include "backend/cpu/compute/StrassenMatmulComputor.hpp"
15 namespace MNN {
16 class Convolution1x1Strassen : public CPUConvolution {
17 public:
18     Convolution1x1Strassen(const Convolution2DCommon *common, Backend *b, const float *originWeight,
19                            size_t originWeightSize, const float *bias, size_t biasSize);
20     Convolution1x1Strassen(std::shared_ptr<CPUConvolution::Resource> resource, const Convolution2DCommon *common, Backend* b);
21     virtual ~Convolution1x1Strassen();
22 
23     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
24 
25     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
26     virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
27 private:
28     std::shared_ptr<CPUConvolution::Resource> mResource;
29 
30     struct Unit {
31         bool mValid = true;
32         std::vector<Tensor *> mTempInputVector;
33         std::vector<Tensor *> mTempOutputVector;
34         std::shared_ptr<StrassenMatrixComputor> mStracssenComputor;
35     };
36 
37     std::vector<Unit> mUnits;
38     std::shared_ptr<Tensor> mTempInputBatch;
39     std::shared_ptr<Tensor> mTempOutputBatch;
40     bool mNeedPretreat = false;
41     std::function<void(const uint8_t* srcBatch, uint8_t* dstBatch)> mPretreatFunction;
42 };
43 } // namespace MNN
44 
45 #endif /* Convolution1x1Strassen_hpp */
46