//
//  CUDARuntime.hpp
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef CUDARuntime_hpp
#define CUDARuntime_hpp

#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <cusolverDn.h>
#include <sstream>
#include "Type_generated.h"
#include "core/Macro.h"
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

typedef enum {
    CUDA_FLOAT32 = 0,
    CUDA_FLOAT16 = 1,
} MNNCUDADataType_t;

typedef enum {
    MNNMemcpyHostToDevice   = 1,
    MNNMemcpyDeviceToHost   = 2,
    MNNMemcpyDeviceToDevice = 3,
} MNNMemcpyKind_t;
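// Note: these values match the corresponding cudaMemcpyKind entries
// (cudaMemcpyHostToDevice = 1, cudaMemcpyDeviceToHost = 2, cudaMemcpyDeviceToDevice = 3).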

#define cuda_check(_x)             \
    do {                           \
        cudaError_t _err = (_x);   \
        if (_err != cudaSuccess) { \
            MNN_CHECK(_err, #_x);  \
        }                          \
    } while (0)
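// Typical use (illustrative; devPtr and sizeInBytes are placeholder names):
//   cuda_check(cudaMalloc(&devPtr, sizeInBytes));
//   cuda_check(cudaDeviceSynchronize());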

#define cublas_check(_x)                     \
    do {                                     \
        cublasStatus_t _err = (_x);          \
        if (_err != CUBLAS_STATUS_SUCCESS) { \
            MNN_CHECK(_err, #_x);            \
        }                                    \
    } while (0)

#define cudnn_check(_x)                     \
    do {                                    \
        cudnnStatus_t _err = (_x);          \
        if (_err != CUDNN_STATUS_SUCCESS) { \
            MNN_CHECK(_err, #_x);           \
        }                                   \
    } while (0)

#define cusolver_check(_x)                     \
    do {                                       \
        cusolverStatus_t _err = (_x);          \
        if (_err != CUSOLVER_STATUS_SUCCESS) { \
            MNN_CHECK(_err, #_x);              \
        }                                      \
    } while (0)

#define after_kernel_launch()           \
    do {                                \
        cuda_check(cudaGetLastError()); \
    } while (0)
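// Typical use (illustrative; someKernel, grid, and block are placeholder names):
//   someKernel<<<grid, block>>>(/* args */);
//   after_kernel_launch(); // surfaces launch errors via cudaGetLastError()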

namespace MNN {

class CUDARuntime {
public:
    CUDARuntime(bool permitFloat16, int device_id);
    ~CUDARuntime();
    CUDARuntime(const CUDARuntime &) = delete;
    CUDARuntime &operator=(const CUDARuntime &) = delete;

    bool isSupportedFP16() const;
    bool isSupportedDotInt8() const;
    bool isSupportedDotAccInt8() const;

    std::vector<size_t> getMaxImage2DSize();
    bool isCreateError() const;

    float flops() const {
        return mFlops;
    }
    int device_id() const;
    size_t mem_alignment_in_bytes() const;
    void activate();
    void *alloc(size_t size_in_bytes);
    void free(void *ptr);

    void memcpy(void *dst, const void *src, size_t size_in_bytes, MNNMemcpyKind_t kind, bool sync = false);
    void memset(void *dst, int value, size_t size_in_bytes);
    cublasHandle_t cublas_handle();
    cudnnHandle_t cudnn_handle();

    int threads_num() {
        return mThreadPerBlock;
    }
    int major_sm() const {
        return mProp.major;
    }
    int blocks_num(const int total_threads);

private:
    cudaDeviceProp mProp;
    int mDeviceId;

    cublasHandle_t mCublasHandle;
    cudnnHandle_t mCudnnHandle;

    bool mIsSupportedFP16   = false;
    bool mSupportDotInt8    = false;
    bool mSupportDotAccInt8 = false;
    float mFlops            = 4.0f;
    bool mIsCreateError{false};
    int mThreadPerBlock = 128;
};
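
// Usage sketch (illustrative only; hostData, sizeInBytes, and totalThreads are
// placeholder names, error handling omitted):
//   CUDARuntime runtime(/*permitFloat16=*/true, /*device_id=*/0);
//   void *devPtr = runtime.alloc(sizeInBytes);
//   runtime.memcpy(devPtr, hostData, sizeInBytes, MNNMemcpyHostToDevice, /*sync=*/true);
//   int blocks = runtime.blocks_num(totalThreads);
//   runtime.free(devPtr);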

} // namespace MNN
#endif /* CUDARuntime_hpp */
136