1 //
2 //  MetalDeconvolution.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/30.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MetalDeconvolution_hpp
10 #define MetalDeconvolution_hpp
11 
12 #import "core/Execution.hpp"
13 #import "MNN_generated.h"
14 #import "MetalDefine.h"
15 
16 #if MNN_METAL_ENABLED
namespace MNN {

// Metal-backend Execution for deconvolution ops.
// Constructed from the owning Backend and the serialized Op; onResize and
// onExecute are the standard Execution overrides. How the members below are
// populated and consumed lives in the implementation file (not visible here).
class MetalDeconvolution : public Execution {
public:
    MetalDeconvolution(Backend *backend, const MNN::Op *op);
    virtual ~MetalDeconvolution() = default;

    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    // ---- Geometry parameters --------------------------------------------
    // NOTE(review): presumably copied from the Op's convolution parameters in
    // the constructor — confirm in MetalDeconvolution.mm.
    bool mDepthwise{false};           // flag distinguishing the depthwise path
    int mGroup{0};
    int mKernelX{0};
    int mKernelY{0};
    PadMode mPadMode{PadMode_CAFFE};  // defaults to PadMode_CAFFE until set
    int mPadX{0};
    int mPadY{0};
    int mStrideX{0};
    int mStrideY{0};
    int mDilateX{0};
    int mDilateY{0};

    // ---- Metal resources --------------------------------------------------
    // All start as nil; explicit here for readability (ARC would zero them too).
    id<MTLBuffer> mWeight = nil;
    id<MTLBuffer> mBias = nil;
    id<MTLBuffer> mConstBuffer = nil;  // NOTE(review): likely packed kernel constants — verify
    id<MTLComputePipelineState> mPipeline = nil;
    // NOTE(review): likely (threadgroup count, threads per threadgroup) for the
    // dispatch — confirm the ordering in the .mm.
    std::pair<MTLSize, MTLSize> mThreads{};
};

} // namespace MNN
48 #endif /* MNN_METAL_ENABLED */
49 #endif /* MetalDeconvolution_hpp */
50