1 #pragma once
2
#include <cudnn.h>

#include <mutex>
#include <utility>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/dims.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/float16.h"
#include "chainerx/macro.h"
12
13 namespace chainerx {
14 namespace cuda {
15
// Exception raised when a cuDNN API call returns a non-success status.
class CudnnError : public ChainerxError {
public:
    using ChainerxError::ChainerxError;

    // Constructs an error from a cuDNN status code. Defined in the .cc file;
    // presumably formats the status into the message — confirm there.
    explicit CudnnError(cudnnStatus_t status);

    // Returns the cuDNN status code stored by the status-taking constructor.
    // Note: instances built via the inherited ChainerxError constructors keep
    // the value-initialized status_ (CUDNN_STATUS_SUCCESS == 0).
    cudnnStatus_t error() const noexcept { return status_; }

private:
    cudnnStatus_t status_{};
};
26
// Checks a cuDNN status code. Defined in the .cc file; presumably throws
// CudnnError when status != CUDNN_STATUS_SUCCESS — confirm against the definition.
void CheckCudnnError(cudnnStatus_t status);
28
29 namespace cuda_internal {
30
31 // Returns a pointer to a cuDNN coefficient value of given type, allocated on the static storage.
32 template <int kValue>
GetCudnnCoefficientPtr(Dtype dtype)33 const void* GetCudnnCoefficientPtr(Dtype dtype) {
34 // TODO(niboshi): Get rid of the assumption that native and cuda float16 share the same representation.
35 static const float kFloat32Value{kValue};
36 static const double kFloat64Value{kValue};
37
38 switch (dtype) {
39 case Dtype::kFloat16:
40 // fallthrough: cuDNN accepts float32 coefficients for float16 tensor operations.
41 case Dtype::kFloat32:
42 return &kFloat32Value;
43 case Dtype::kFloat64:
44 return &kFloat64Value;
45 default:
46 CHAINERX_NEVER_REACH();
47 }
48 }
49
// Owning RAII wrapper around a cudnnTensorDescriptor_t.
//
// Movable (ownership transfers, source is nulled) but not copyable.
class CudnnTensorDescriptor {
public:
    // Creates an empty descriptor. Defined in the .cc file.
    CudnnTensorDescriptor();
    // Creates a descriptor for the given array. Defined in the .cc file;
    // presumably configured from the array's shape/strides/dtype — confirm there.
    explicit CudnnTensorDescriptor(const Array& arr);

    ~CudnnTensorDescriptor();

    CudnnTensorDescriptor(const CudnnTensorDescriptor&) = delete;
    // Move transfers ownership of the raw descriptor; the moved-from object
    // holds nullptr so its destructor does not destroy the descriptor twice.
    CudnnTensorDescriptor(CudnnTensorDescriptor&& other) noexcept : desc_{other.desc_} { other.desc_ = nullptr; }
    CudnnTensorDescriptor& operator=(const CudnnTensorDescriptor&) = delete;
    CudnnTensorDescriptor& operator=(CudnnTensorDescriptor&&) = delete;

    // Both accessors expose the raw handle for passing to cuDNN APIs.
    cudnnTensorDescriptor_t descriptor() const { return desc_; }
    cudnnTensorDescriptor_t operator*() const { return desc_; }

    // Defined in the .cc file; presumably reads the dtype back from the
    // descriptor — confirm there.
    Dtype GetDtype() const;

private:
    cudnnTensorDescriptor_t desc_{};
};
70
// Owning RAII wrapper around a cudnnFilterDescriptor_t.
//
// Non-copyable and (unlike CudnnTensorDescriptor) currently non-movable.
class CudnnFilterDescriptor {
public:
    // Creates a descriptor for the given filter (weight) array. Defined in the .cc file.
    explicit CudnnFilterDescriptor(const Array& w);

    ~CudnnFilterDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnFilterDescriptor(const CudnnFilterDescriptor&) = delete;
    CudnnFilterDescriptor(CudnnFilterDescriptor&&) = delete;
    CudnnFilterDescriptor& operator=(const CudnnFilterDescriptor&) = delete;
    CudnnFilterDescriptor& operator=(CudnnFilterDescriptor&&) = delete;

    // Both accessors expose the raw handle for passing to cuDNN APIs.
    cudnnFilterDescriptor_t descriptor() const { return desc_; }
    cudnnFilterDescriptor_t operator*() const { return desc_; }

private:
    // Private default constructor; used internally by the public constructor (see .cc).
    CudnnFilterDescriptor();
    cudnnFilterDescriptor_t desc_{};
};
90
// Owning RAII wrapper around a cudnnConvolutionDescriptor_t.
//
// Non-copyable and currently non-movable.
class CudnnConvolutionDescriptor {
public:
    // Creates a convolution descriptor. Defined in the .cc file.
    // dilation is optional; groups selects grouped convolution — exact
    // handling of absl::nullopt dilation is defined there, confirm if relied on.
    explicit CudnnConvolutionDescriptor(Dtype dtype, const Dims& pad, const Dims& stride, const absl::optional<Dims>& dilation, int groups);

    ~CudnnConvolutionDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnConvolutionDescriptor(const CudnnConvolutionDescriptor&) = delete;
    CudnnConvolutionDescriptor(CudnnConvolutionDescriptor&&) = delete;
    CudnnConvolutionDescriptor& operator=(const CudnnConvolutionDescriptor&) = delete;
    CudnnConvolutionDescriptor& operator=(CudnnConvolutionDescriptor&&) = delete;

    // Both accessors expose the raw handle for passing to cuDNN APIs.
    cudnnConvolutionDescriptor_t descriptor() const { return desc_; }
    cudnnConvolutionDescriptor_t operator*() const { return desc_; }

private:
    // Private default constructor; used internally by the public constructor (see .cc).
    CudnnConvolutionDescriptor();
    cudnnConvolutionDescriptor_t desc_{};
};
110
// Owning RAII wrapper around a cudnnPoolingDescriptor_t.
//
// Non-copyable and currently non-movable.
class CudnnPoolingDescriptor {
public:
    // Creates a pooling descriptor. Defined in the .cc file.
    // max_pooling_nan_opt controls NaN propagation for max pooling modes.
    explicit CudnnPoolingDescriptor(
            cudnnPoolingMode_t mode,
            cudnnNanPropagation_t max_pooling_nan_opt,
            const Dims& kernel_size,
            const Dims& pad,
            const Dims& stride);

    ~CudnnPoolingDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnPoolingDescriptor(const CudnnPoolingDescriptor&) = delete;
    CudnnPoolingDescriptor(CudnnPoolingDescriptor&&) = delete;
    CudnnPoolingDescriptor& operator=(const CudnnPoolingDescriptor&) = delete;
    CudnnPoolingDescriptor& operator=(CudnnPoolingDescriptor&&) = delete;

    // Both accessors expose the raw handle for passing to cuDNN APIs.
    cudnnPoolingDescriptor_t descriptor() const { return desc_; }
    cudnnPoolingDescriptor_t operator*() const { return desc_; }

private:
    // Private default constructor; used internally by the public constructor (see .cc).
    CudnnPoolingDescriptor();
    cudnnPoolingDescriptor_t desc_{};
};
135
136 // cuDNN API calls using same handle is not thread-safe.
137 // This class ensures that the API calls are serialized using mutex lock.
138 class CudnnHandle {
139 public:
CudnnHandle(int device_index)140 explicit CudnnHandle(int device_index) : device_index_{device_index} {}
141 ~CudnnHandle();
142
143 CudnnHandle(const CudnnHandle&) = delete;
144 CudnnHandle(CudnnHandle&&) = delete;
145 CudnnHandle& operator=(const CudnnHandle&) = delete;
146 CudnnHandle& operator=(CudnnHandle&&) = delete;
147
148 template <class Func, class... Args>
Call(Func && func,Args &&...args)149 void Call(Func&& func, Args&&... args) {
150 std::lock_guard<std::mutex> lock{handle_mutex_};
151 CheckCudnnError(func(handle(), args...));
152 }
153 cudnnHandle_t handle();
154
155 private:
156 int device_index_;
157 std::mutex handle_mutex_{};
158 cudnnHandle_t handle_{};
159 };
160
161 } // namespace cuda_internal
162 } // namespace cuda
163 } // namespace chainerx
164