1 //
2 //  CUDABackend.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef CUDABackend_hpp
10 #define CUDABackend_hpp
11 
12 #include <set>
13 #include <vector>
14 #include "MNN_generated.h"
15 #include "backend/cuda/core/runtime/CUDARuntime.hpp"
16 #include "core/Backend.hpp"
17 #include "core/Macro.h"
18 #include "core/ConvolutionCommon.hpp"
19 #include "core/BufferAllocator.hpp"
20 namespace MNN {
21 namespace CUDA {
22 class MNN_PUBLIC CUDARuntimeWrapper : public Runtime {
23 public:
24     CUDARuntimeWrapper(BackendConfig::PrecisionMode precision, BackendConfig::PowerMode power);
25     virtual ~CUDARuntimeWrapper();
26     virtual Backend *onCreate(const BackendConfig* config) const override;
27     virtual void onGabageCollect(int level) override;
isCreateError() const28     bool isCreateError() const {
29         return mIsCreateError;
30     }
onGetCompilerType() const31     virtual CompilerType onGetCompilerType() const override {
32         return Compiler_Loop;
33     }
34     virtual float onGetMemoryInMB() override;
35 
36 private:
37     std::shared_ptr<BufferAllocator> mBufferPool;
38     std::shared_ptr<CUDARuntime> mCUDARuntime;
39     bool mIsCreateError{false};
40 };
41 
42 class CUDABackend final : public Backend {
43 public:
44     CUDABackend(std::shared_ptr<BufferAllocator> st, std::shared_ptr<CUDARuntime> rt);
45     ~CUDABackend();
46 
47     CUDARuntime *getCUDARuntime();
48     virtual bool onAcquireBuffer(const Tensor *nativeTensor, StorageType storageType) override;
49     virtual bool onReleaseBuffer(const Tensor *nativeTensor, StorageType storageType) override;
50     virtual bool onClearBuffer() override;
51 
52     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
53                                 const MNN::Op *op) override;
54     virtual void onResizeBegin() override;
55     virtual void onResizeEnd() override;
56 
57     virtual void onExecuteBegin() const override;
58     virtual void onExecuteEnd() const override;
59 
60     virtual void onCopyBuffer(const Tensor *srcTensor, const Tensor *dstTensor) const override;
61 
62     class Creator {
63     public:
64         virtual ~Creator()                                                     = default;
65         virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &output,
66                                     const MNN::Op *op, Backend *backend) const = 0;
67     };
68 
69     static bool addCreator(OpType t, Creator *c);
70 
getBufferPool() const71     BufferAllocator *getBufferPool() const {
72         return mBufferPool.get();
73     }
getStaticBufferPool() const74     BufferAllocator *getStaticBufferPool() const {
75         return mStaticBufferPool.get();
76     }
77     virtual std::pair<float, bool> onMeasure(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
78                                              const MNN::Op *op) override;
79     static size_t realSize(const Tensor *tensor);
80 
81 private:
82     std::shared_ptr<BufferAllocator> mBufferPool;
83     std::shared_ptr<BufferAllocator> mStaticBufferPool;
84     std::shared_ptr<CUDARuntime> mCUDARuntime;
85 };
86 
87 template <class T>
88 class CUDACreatorRegister {
89 public:
CUDACreatorRegister(OpType type)90     CUDACreatorRegister(OpType type) {
91         T *t = new T;
92         CUDABackend::addCreator(type, t);
93     }
94     ~CUDACreatorRegister() = default;
95 };
96 
97 template <typename T>
98 class TypedCreator : public CUDABackend::Creator {
99 public:
100     virtual ~TypedCreator() = default;
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const101     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
102                                 const MNN::Op *op, Backend *backend) const override {
103         return new T(inputs, op, backend);
104     }
105 };
106 
107 } // namespace CUDA
108 } // namespace MNN
109 #endif /* CUDABackend_hpp */
110