1 #pragma once
2 
#include <mutex>
#include <utility>

#include <absl/types/optional.h>
#include <cudnn.h>

#include "chainerx/array.h"
#include "chainerx/dims.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/float16.h"
#include "chainerx/macro.h"
12 
13 namespace chainerx {
14 namespace cuda {
15 
// Exception type for failed cuDNN API calls.
//
// Carries the raw cudnnStatus_t of the failing call so callers can inspect
// the status code in addition to the error message.
class CudnnError : public ChainerxError {
public:
    using ChainerxError::ChainerxError;

    // Constructs an error from a cuDNN status code (defined out of line;
    // presumably formats the status into the message — confirm in the .cc).
    explicit CudnnError(cudnnStatus_t status);

    // Returns the cuDNN status code that caused this error.
    cudnnStatus_t error() const noexcept { return status_; }

private:
    // Value-initialized; set by the cudnnStatus_t constructor.
    cudnnStatus_t status_{};
};
26 
// Checks a cuDNN status code; presumably throws CudnnError when `status` is
// not CUDNN_STATUS_SUCCESS (defined out of line — confirm in the .cc).
void CheckCudnnError(cudnnStatus_t status);
28 
29 namespace cuda_internal {
30 
31 // Returns a pointer to a cuDNN coefficient value of given type, allocated on the static storage.
32 template <int kValue>
GetCudnnCoefficientPtr(Dtype dtype)33 const void* GetCudnnCoefficientPtr(Dtype dtype) {
34     // TODO(niboshi): Get rid of the assumption that native and cuda float16 share the same representation.
35     static const float kFloat32Value{kValue};
36     static const double kFloat64Value{kValue};
37 
38     switch (dtype) {
39         case Dtype::kFloat16:
40             // fallthrough: cuDNN accepts float32 coefficients for float16 tensor operations.
41         case Dtype::kFloat32:
42             return &kFloat32Value;
43         case Dtype::kFloat64:
44             return &kFloat64Value;
45         default:
46             CHAINERX_NEVER_REACH();
47     }
48 }
49 
// RAII owner of a cudnnTensorDescriptor_t.
//
// Movable but not copyable; moving transfers ownership of the underlying
// descriptor and leaves the source empty.
class CudnnTensorDescriptor {
public:
    // Creates an empty descriptor (defined out of line).
    CudnnTensorDescriptor();
    // Creates a descriptor configured from the given array (defined out of
    // line; presumably derives shape/strides/dtype from `arr` — confirm).
    explicit CudnnTensorDescriptor(const Array& arr);

    // Destroys the owned descriptor, if any (defined out of line).
    ~CudnnTensorDescriptor();

    CudnnTensorDescriptor(const CudnnTensorDescriptor&) = delete;
    // Move transfers ownership; `other` is left holding nullptr so its
    // destructor becomes a no-op for the transferred descriptor.
    CudnnTensorDescriptor(CudnnTensorDescriptor&& other) noexcept : desc_{other.desc_} { other.desc_ = nullptr; }
    CudnnTensorDescriptor& operator=(const CudnnTensorDescriptor&) = delete;
    CudnnTensorDescriptor& operator=(CudnnTensorDescriptor&&) = delete;

    // Accessors to the raw cuDNN descriptor; ownership is retained.
    cudnnTensorDescriptor_t descriptor() const { return desc_; }
    cudnnTensorDescriptor_t operator*() const { return desc_; }

    // Returns the ChainerX dtype corresponding to the descriptor's data type
    // (defined out of line).
    Dtype GetDtype() const;

private:
    cudnnTensorDescriptor_t desc_{};
};
70 
// RAII owner of a cudnnFilterDescriptor_t, configured from a filter array.
//
// Non-copyable and non-movable (see TODO below).
class CudnnFilterDescriptor {
public:
    // Creates a descriptor configured from the filter (weight) array
    // (defined out of line).
    explicit CudnnFilterDescriptor(const Array& w);

    // Destroys the owned descriptor (defined out of line).
    ~CudnnFilterDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnFilterDescriptor(const CudnnFilterDescriptor&) = delete;
    CudnnFilterDescriptor(CudnnFilterDescriptor&&) = delete;
    CudnnFilterDescriptor& operator=(const CudnnFilterDescriptor&) = delete;
    CudnnFilterDescriptor& operator=(CudnnFilterDescriptor&&) = delete;

    // Accessors to the raw cuDNN descriptor; ownership is retained.
    cudnnFilterDescriptor_t descriptor() const { return desc_; }
    cudnnFilterDescriptor_t operator*() const { return desc_; }

private:
    // Default constructor is private: only used internally (presumably by the
    // public constructor's implementation — confirm in the .cc).
    CudnnFilterDescriptor();
    cudnnFilterDescriptor_t desc_{};
};
90 
// RAII owner of a cudnnConvolutionDescriptor_t.
//
// Non-copyable and non-movable (see TODO below).
class CudnnConvolutionDescriptor {
public:
    // Creates a convolution descriptor.
    //
    // dtype: dtype used to select the cuDNN compute type.
    // pad/stride: per-spatial-dimension padding and stride.
    // dilation: optional per-dimension dilation; semantics of absl::nullopt
    //     are decided by the out-of-line definition — confirm in the .cc.
    // groups: number of convolution groups.
    explicit CudnnConvolutionDescriptor(Dtype dtype, const Dims& pad, const Dims& stride, const absl::optional<Dims>& dilation, int groups);

    // Destroys the owned descriptor (defined out of line).
    ~CudnnConvolutionDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnConvolutionDescriptor(const CudnnConvolutionDescriptor&) = delete;
    CudnnConvolutionDescriptor(CudnnConvolutionDescriptor&&) = delete;
    CudnnConvolutionDescriptor& operator=(const CudnnConvolutionDescriptor&) = delete;
    CudnnConvolutionDescriptor& operator=(CudnnConvolutionDescriptor&&) = delete;

    // Accessors to the raw cuDNN descriptor; ownership is retained.
    cudnnConvolutionDescriptor_t descriptor() const { return desc_; }
    cudnnConvolutionDescriptor_t operator*() const { return desc_; }

private:
    // Default constructor is private: only used internally (presumably by the
    // public constructor's implementation — confirm in the .cc).
    CudnnConvolutionDescriptor();
    cudnnConvolutionDescriptor_t desc_{};
};
110 
// RAII owner of a cudnnPoolingDescriptor_t.
//
// Non-copyable and non-movable (see TODO below).
class CudnnPoolingDescriptor {
public:
    // Creates a pooling descriptor.
    //
    // mode: cuDNN pooling mode (max / average variants).
    // max_pooling_nan_opt: NaN propagation option for max pooling.
    // kernel_size/pad/stride: per-spatial-dimension pooling parameters.
    explicit CudnnPoolingDescriptor(
            cudnnPoolingMode_t mode,
            cudnnNanPropagation_t max_pooling_nan_opt,
            const Dims& kernel_size,
            const Dims& pad,
            const Dims& stride);

    // Destroys the owned descriptor (defined out of line).
    ~CudnnPoolingDescriptor();

    // TODO(hvy): Allow move semantics as needed.
    CudnnPoolingDescriptor(const CudnnPoolingDescriptor&) = delete;
    CudnnPoolingDescriptor(CudnnPoolingDescriptor&&) = delete;
    CudnnPoolingDescriptor& operator=(const CudnnPoolingDescriptor&) = delete;
    CudnnPoolingDescriptor& operator=(CudnnPoolingDescriptor&&) = delete;

    // Accessors to the raw cuDNN descriptor; ownership is retained.
    cudnnPoolingDescriptor_t descriptor() const { return desc_; }
    cudnnPoolingDescriptor_t operator*() const { return desc_; }

private:
    // Default constructor is private: only used internally (presumably by the
    // public constructor's implementation — confirm in the .cc).
    CudnnPoolingDescriptor();
    cudnnPoolingDescriptor_t desc_{};
};
135 
136 // cuDNN API calls using same handle is not thread-safe.
137 // This class ensures that the API calls are serialized using mutex lock.
138 class CudnnHandle {
139 public:
CudnnHandle(int device_index)140     explicit CudnnHandle(int device_index) : device_index_{device_index} {}
141     ~CudnnHandle();
142 
143     CudnnHandle(const CudnnHandle&) = delete;
144     CudnnHandle(CudnnHandle&&) = delete;
145     CudnnHandle& operator=(const CudnnHandle&) = delete;
146     CudnnHandle& operator=(CudnnHandle&&) = delete;
147 
148     template <class Func, class... Args>
Call(Func && func,Args &&...args)149     void Call(Func&& func, Args&&... args) {
150         std::lock_guard<std::mutex> lock{handle_mutex_};
151         CheckCudnnError(func(handle(), args...));
152     }
153     cudnnHandle_t handle();
154 
155 private:
156     int device_index_;
157     std::mutex handle_mutex_{};
158     cudnnHandle_t handle_{};
159 };
160 
161 }  // namespace cuda_internal
162 }  // namespace cuda
163 }  // namespace chainerx
164