1 //
2 //  ConvBufWinograd.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/02/01.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef MNN_OPENCL_BUFFER_CLOSED
10 
11 #ifndef __CONVBUF_WINOGRAD__
12 #define __CONVBUF_WINOGRAD__
13 
#include "core/Execution.hpp"

#include <array>
#include <memory>
#include <string>
#include <vector>
#include "backend/opencl/execution/buffer/ConvBufExecution.hpp"
#include "backend/opencl/core/OpenCLRunningUtils.hpp"
21 
22 namespace MNN {
23 namespace OpenCL {
24 class ConvBufWinograd : public Execution {
25 public:
26     ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend);
27     virtual ~ConvBufWinograd();
28 
29     virtual ErrorCode onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;
30     virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override;
31     static bool valid(const Convolution2DCommon* common, const Tensor* input, int limit = 8192);
32     std::vector<uint32_t> getLocalWS(std::string kernelName, int index, std::vector<uint32_t> &gws, const uint32_t maxWorkGroupSize, cl::Kernel mKernel);
33 
34 private:
35     OpenCLBackend* mOpenCLBackend;
36     const Convolution2DCommon* mCommon;
37     int mKernelX;
38     int mKernelY;
39     int mStrideX;
40     int mStrideY;
41     std::shared_ptr<Tensor> mWeight;
42     std::shared_ptr<Tensor> mBias;
43 
44     std::shared_ptr<Tensor> mSource;
45     std::shared_ptr<Tensor> mDest;
46 
47     std::vector<cl::Kernel> mSourceTransform;
48     std::vector<cl::Kernel> mDestTransform;
49     std::vector<cl::Kernel> mMatMul;
50 
51     std::vector<uint32_t> mMaxWGS_S;
52     std::vector<uint32_t> mMaxWGS_D;
53     std::vector<uint32_t> mMaxWGS_M;
54 
55     std::vector<std::vector<uint32_t> > mGWS_S;
56     std::vector<std::vector<uint32_t> > mGWS_D;
57     std::vector<std::vector<uint32_t> > mGWS_M;
58 
59     std::vector<std::vector<uint32_t> > mLWS_S;
60     std::vector<std::vector<uint32_t> > mLWS_D;
61     std::vector<std::vector<uint32_t> > mLWS_M;
62 };
63 
64 } // namespace OpenCL
65 } // namespace MNN
66 
67 #endif /* __CONVBUF_WINOGRAD__ */
68 #endif /* MNN_OPENCL_BUFFER_CLOSED */
69