1 // 2 // ConvBufWinograd.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/02/01. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef MNN_OPENCL_BUFFER_CLOSED 10 11 #ifndef __CONVBUF_WINOGRAD__ 12 #define __CONVBUF_WINOGRAD__ 13 14 #include "core/Execution.hpp" 15 16 #include <array> 17 #include <memory> 18 #include <vector> 19 #include "backend/opencl/execution/buffer/ConvBufExecution.hpp" 20 #include "backend/opencl/core/OpenCLRunningUtils.hpp" 21 22 namespace MNN { 23 namespace OpenCL { 24 class ConvBufWinograd : public Execution { 25 public: 26 ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend); 27 virtual ~ConvBufWinograd(); 28 29 virtual ErrorCode onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override; 30 virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) override; 31 static bool valid(const Convolution2DCommon* common, const Tensor* input, int limit = 8192); 32 std::vector<uint32_t> getLocalWS(std::string kernelName, int index, std::vector<uint32_t> &gws, const uint32_t maxWorkGroupSize, cl::Kernel mKernel); 33 34 private: 35 OpenCLBackend* mOpenCLBackend; 36 const Convolution2DCommon* mCommon; 37 int mKernelX; 38 int mKernelY; 39 int mStrideX; 40 int mStrideY; 41 std::shared_ptr<Tensor> mWeight; 42 std::shared_ptr<Tensor> mBias; 43 44 std::shared_ptr<Tensor> mSource; 45 std::shared_ptr<Tensor> mDest; 46 47 std::vector<cl::Kernel> mSourceTransform; 48 std::vector<cl::Kernel> mDestTransform; 49 std::vector<cl::Kernel> mMatMul; 50 51 std::vector<uint32_t> mMaxWGS_S; 52 std::vector<uint32_t> mMaxWGS_D; 53 std::vector<uint32_t> mMaxWGS_M; 54 55 std::vector<std::vector<uint32_t> > mGWS_S; 56 std::vector<std::vector<uint32_t> > mGWS_D; 57 std::vector<std::vector<uint32_t> > mGWS_M; 58 59 std::vector<std::vector<uint32_t> > mLWS_S; 60 std::vector<std::vector<uint32_t> > mLWS_D; 61 std::vector<std::vector<uint32_t> > mLWS_M; 62 }; 63 64 } // namespace OpenCL 65 } // namespace MNN 66 67 #endif /* __CONVBUF_WINOGRAD__ */ 68 #endif /* MNN_OPENCL_BUFFER_CLOSED */ 69