//
//  MetalConvolutionWinograd.hpp
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MetalConvolutionWinograd_hpp
#define MetalConvolutionWinograd_hpp

#import "MetalConvolutionCommon.hpp"

// Included directly because this header names std::shared_ptr and std::vector itself.
#include <memory>
#include <vector>

#if MNN_METAL_ENABLED
namespace MNN {

/**
 * Convolution execution for the Metal backend, derived from MetalConvolutionCommon.
 *
 * NOTE(review): the class name indicates a Winograd-based implementation; the
 * algorithm itself lives in the implementation file and is not visible from
 * this header.
 */
class MetalConvolutionWinograd : public MetalConvolutionCommon {
public:
    /**
     * Reports whether this execution may be used for the given convolution.
     * The acceptance criteria are defined in the implementation file.
     *
     * @param conv  convolution description to inspect.
     * @param input input tensor of the convolution.
     * @return true when the convolution is accepted, false otherwise.
     */
    static bool isValid(const Convolution2D *conv, const Tensor *input);

    /**
     * @param backend backend that owns this execution.
     * @param input   input tensor of the convolution.
     * @param op      operator description this execution is created from.
     */
    MetalConvolutionWinograd(Backend *backend, const Tensor *input, const MNN::Op *op);
    virtual ~MetalConvolutionWinograd() = default;

    /**
     * Resize hook overriding the base-class implementation.
     *
     * @param inputs  input tensors.
     * @param outputs output tensors.
     * @return error code describing the outcome.
     */
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

protected:
    /**
     * Float execution hook overriding the base-class implementation.
     *
     * @param input  source tensor.
     * @param output destination tensor.
     * @return error code describing the outcome.
     */
    virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override;

    /**
     * Produces the Metal buffer holding the float weights.
     *
     * @param group number of groups.
     * @param oc    output channel count.
     * @param ic    input channel count.
     * @param kh    kernel height.
     * @param kw    kernel width.
     * @param src   source float weights.
     * @return Metal buffer with the weights in the layout chosen by the implementation.
     */
    virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override;

private:
    // NOTE(review): presumably carries shape constants for the kernels -- confirm in the .mm file.
    id<MTLBuffer> mShapeBuffer = nil;

    // Zero-initialized, like mShapeBuffer above, so these never hold indeterminate
    // values before the implementation assigns them.
    // NOTE(review): presumably the source/destination tile sizes -- confirm in the .mm file.
    int mSrcUnit = 0;
    int mDstUnit = 0;

    // Intermediate tensors; empty until the implementation allocates them.
    std::shared_ptr<Tensor> mTempSrc;
    std::shared_ptr<Tensor> mTempDst;

    // Thread sizes for the three stages named by the members; zeroed until the
    // implementation computes them.
    MTLSize mInputTransformThreads  = {0, 0, 0};
    MTLSize mMatMulThreads          = {0, 0, 0};
    MTLSize mOutputTransformThreads = {0, 0, 0};
};

} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalConvolutionWinograd_hpp */