1 // 2 // Convolution1x1Strassen.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/02/12. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef Convolution1x1Strassen_hpp 10 #define Convolution1x1Strassen_hpp 11 12 #include <functional> 13 #include "backend/cpu/CPUConvolution.hpp" 14 #include "backend/cpu/compute/StrassenMatmulComputor.hpp" 15 namespace MNN { 16 class Convolution1x1Strassen : public CPUConvolution { 17 public: 18 Convolution1x1Strassen(const Convolution2DCommon *common, Backend *b, const float *originWeight, 19 size_t originWeightSize, const float *bias, size_t biasSize); 20 Convolution1x1Strassen(std::shared_ptr<CPUConvolution::Resource> resource, const Convolution2DCommon *common, Backend* b); 21 virtual ~Convolution1x1Strassen(); 22 23 virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 24 25 virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; 26 virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; 27 private: 28 std::shared_ptr<CPUConvolution::Resource> mResource; 29 30 struct Unit { 31 bool mValid = true; 32 std::vector<Tensor *> mTempInputVector; 33 std::vector<Tensor *> mTempOutputVector; 34 std::shared_ptr<StrassenMatrixComputor> mStracssenComputor; 35 }; 36 37 std::vector<Unit> mUnits; 38 std::shared_ptr<Tensor> mTempInputBatch; 39 std::shared_ptr<Tensor> mTempOutputBatch; 40 bool mNeedPretreat = false; 41 std::function<void(const uint8_t* srcBatch, uint8_t* dstBatch)> mPretreatFunction; 42 }; 43 } // namespace MNN 44 45 #endif /* Convolution1x1Strassen_hpp */ 46