1 //
2 //  MetalConvolutionWinograd.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MetalConvolutionWinograd_hpp
10 #define MetalConvolutionWinograd_hpp
11 
12 #import "MetalConvolutionCommon.hpp"
13 
14 #if MNN_METAL_ENABLED
15 namespace MNN {
16 
/**
 * Metal convolution that takes a Winograd-style code path.
 *
 * Specializes MetalConvolutionCommon by overriding the float execution path
 * (onFloat), the weight preparation hook (weightForFloat) and onResize.
 * NOTE(review): the member names suggest execution is split into an input
 * transform, a matrix multiply and an output transform. Confirm this against
 * the .mm implementation.
 */
class MetalConvolutionWinograd : public MetalConvolutionCommon {
public:
    /// Reports whether this Winograd path can handle `conv` for `input`.
    /// NOTE(review): the exact acceptance criteria (kernel size, stride,
    /// dilation, channel counts, ...) are defined in the implementation file.
    static bool isValid(const Convolution2D *conv, const Tensor *input);

    /// Builds the execution for `op` on `backend`. `input` is also passed so
    /// construction can depend on the input tensor.
    MetalConvolutionWinograd(Backend *backend, const Tensor *input, const MNN::Op *op);
    virtual ~MetalConvolutionWinograd() = default;

    /// Recomputes size-dependent state when input/output shapes change.
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

protected:
    /// Float execution path (overrides MetalConvolutionCommon).
    virtual ErrorCode onFloat(const Tensor *input, const Tensor *output) override;

    /// Converts host float weights `src` into the GPU buffer this path uses.
    /// `group`, `oc`, `ic`, `kh`, `kw` describe the weight dimensions;
    /// presumably the target layout is the transformed (Winograd) layout,
    /// so TODO confirm in the implementation.
    virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) override;

private:
    /// Constants buffer handed to the kernels (presumably filled in onResize).
    id<MTLBuffer> mShapeBuffer = nil;

    // Presumably the Winograd source/destination tile sizes. Zero-initialized
    // so the object never holds indeterminate values before they are set.
    int mSrcUnit = 0;
    int mDstUnit = 0;

    /// Intermediate tensors, presumably holding transformed input and the
    /// pre-output-transform result.
    std::shared_ptr<Tensor> mTempSrc;
    std::shared_ptr<Tensor> mTempDst;

    // Cached dispatch sizes for each stage. They are value-initialized
    // (all zero) for the same reason as the units above.
    MTLSize mInputTransformThreads{};
    MTLSize mMatMulThreads{};
    MTLSize mOutputTransformThreads{};
};
41 
42 } // namespace MNN
43 #endif /* MNN_METAL_ENABLED */
44 #endif /* MetalConvolutionWinograd_hpp */
45