//
//  MetalDeconvolution.hpp
//  MNN
//
//  Created by MNN on 2019/01/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MetalDeconvolution_hpp
#define MetalDeconvolution_hpp

#import "core/Execution.hpp"
#import "MNN_generated.h"
#import "MetalDefine.h"

// Everything below is compiled only when MNN_METAL_ENABLED is non-zero.
#if MNN_METAL_ENABLED
namespace MNN {

/**
 * Execution for an MNN deconvolution op on the Metal backend.
 *
 * Only the declaration lives here. The constructor, onResize and onExecute
 * are defined elsewhere (presumably MetalDeconvolution.mm; confirm).
 */
class MetalDeconvolution : public Execution {
public:
    /**
     * @param backend Backend this execution is created for. Assumed to be the
     *                Metal backend; confirm at the call site.
     * @param op      Op whose parameters configure this execution.
     */
    MetalDeconvolution(Backend *backend, const MNN::Op *op);
    virtual ~MetalDeconvolution() = default;

    /// Execution hook invoked with the op's input/output tensors; see Execution.
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    /// Execution hook invoked with the op's input/output tensors; see Execution.
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    // Layer parameters. Presumably populated from `op` in the constructor;
    // confirm in the implementation file.
    bool mDepthwise = false;
    int mGroup = 0;
    int mKernelX = 0;
    int mKernelY = 0;
    PadMode mPadMode = PadMode_CAFFE;
    int mPadX = 0;
    int mPadY = 0;
    int mStrideX = 0;
    int mStrideY = 0;
    int mDilateX = 0;
    int mDilateY = 0;

    // Metal resources owned by this execution.
    id<MTLBuffer> mWeight = nil;
    id<MTLBuffer> mBias = nil;
    id<MTLBuffer> mConstBuffer = nil;
    // Explicitly nil, matching the buffer members above. Under ARC this is
    // already the implicit default; the initializer documents it.
    id<MTLComputePipelineState> mPipeline = nil;
    // Dispatch sizes. Presumably (threadgroup count, threads per threadgroup);
    // verify against onExecute.
    std::pair<MTLSize, MTLSize> mThreads;
};

} // namespace MNN
#endif /* MNN_METAL_ENABLED */
#endif /* MetalDeconvolution_hpp */