#include "chainerx/cuda/cudnn.h"

#include <absl/types/optional.h>
#include <cudnn.h>

#include "chainerx/array.h"
#include "chainerx/cuda/cuda_runtime.h"
#include "chainerx/dims.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/macro.h"

namespace chainerx {
namespace cuda {

CudnnError::CudnnError(cudnnStatus_t status) : ChainerxError{cudnnGetErrorString(status)}, status_{status} {}

void CheckCudnnError(cudnnStatus_t status) {
    if (status != CUDNN_STATUS_SUCCESS) {
        throw CudnnError{status};
    }
}
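
// Illustrative usage (hypothetical `handle` and `stream` variables): any cuDNN
// API call returning cudnnStatus_t can be wrapped directly, e.g.
//
//     CheckCudnnError(cudnnSetStream(handle, stream));
//
// so that a failing call surfaces as a CudnnError carrying the status message.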

namespace {

cudnnDataType_t GetCudnnDataType(Dtype dtype) {
    switch (dtype) {
        case Dtype::kFloat16:
            return CUDNN_DATA_HALF;
        case Dtype::kFloat32:
            return CUDNN_DATA_FLOAT;
        case Dtype::kFloat64:
            return CUDNN_DATA_DOUBLE;
        case Dtype::kInt8:
            return CUDNN_DATA_INT8;
        case Dtype::kInt32:
            return CUDNN_DATA_INT32;
        default:
            throw DtypeError{"Dtype ", dtype, " is not supported in cuDNN"};
    }
}
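
// For example, GetCudnnDataType(Dtype::kFloat32) yields CUDNN_DATA_FLOAT,
// while an unmapped dtype such as Dtype::kBool falls through to DtypeError.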

// Casts `u` to type T, throwing a ChainerxError built from the given message
// parts if the value cannot be represented exactly in T.
template <typename T, typename U, typename... ErrorArgs>
T narrow(U u, const ErrorArgs&... error_args) {
    auto t = static_cast<T>(u);
    if (static_cast<U>(t) != u) {
        throw ChainerxError{error_args...};
    }
    return t;
}
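
// For illustration (hypothetical values): a value that fits is returned
// unchanged, while one that does not fit throws instead of silently
// truncating.
//
//     narrow<int>(int64_t{42}, "overflow");       // returns 42
//     narrow<int>(int64_t{1} << 40, "overflow");  // throws ChainerxError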

template <typename T>
StackVector<int, kMaxNdim> GetIntStackVector(const T& container, const char* src) {
    StackVector<int, kMaxNdim> int_container;
    for (size_t i = 0; i < container.size(); ++i) {
        int_container.emplace_back(
                narrow<int>(container[i], "Casting the ", src, ": ", container[i], " at dimension: ", i, " to int failed."));
    }
    return int_container;
}

StackVector<int, kMaxNdim> GetIntShape(const Shape& shape) { return GetIntStackVector(shape, "shape size"); }

StackVector<int, kMaxNdim> GetIntKernelSize(const Dims& kernel_size) { return GetIntStackVector(kernel_size, "kernel size"); }

StackVector<int, kMaxNdim> GetIntStride(const Dims& stride) { return GetIntStackVector(stride, "stride"); }

StackVector<int, kMaxNdim> GetIntPad(const Dims& pad) { return GetIntStackVector(pad, "pad"); }

StackVector<int, kMaxNdim> GetIntDilation(const StackVector<int64_t, kMaxNdim>& dilation) {
    return GetIntStackVector(dilation, "dilation");
}

// Returns strides in units of items (byte strides divided by the item size).
StackVector<int, kMaxNdim> GetIntArrayStrides(const Strides& strides, int64_t item_size) {
    StackVector<int, kMaxNdim> int_strides;
    for (int8_t i = 0; i < strides.ndim(); ++i) {
        int64_t v = strides[i] / item_size;
        int_strides.emplace_back(
                narrow<int>(v, "Casting the array stride: ", v, " (in number of items) at dimension: ", i, " to int failed."));
    }
    return int_strides;
}
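
// Example: a contiguous float32 array of shape (2, 3) has byte strides
// (12, 4); with an item size of 4 this yields item strides (3, 1), which is
// the unit cudnnSetTensorNdDescriptor expects.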

}  // namespace

namespace cuda_internal {

CudnnTensorDescriptor::CudnnTensorDescriptor() { CheckCudnnError(cudnnCreateTensorDescriptor(&desc_)); }

CudnnTensorDescriptor::~CudnnTensorDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyTensorDescriptor(desc_);
    }
}

CudnnTensorDescriptor::CudnnTensorDescriptor(const Array& arr) : CudnnTensorDescriptor{} {
    CHAINERX_ASSERT(arr.IsContiguous());

    cudnnDataType_t cudnn_dtype = GetCudnnDataType(arr.dtype());
    if (arr.shape().ndim() == 4) {
        // Fast path: 4-d arrays use the NCHW 4-d descriptor directly.
        StackVector<int, kMaxNdim> nchw = GetIntShape(arr.shape());
        CheckCudnnError(cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, nchw[0], nchw[1], nchw[2], nchw[3]));
    } else {
        StackVector<int, kMaxNdim> int_strides = GetIntArrayStrides(arr.strides(), arr.GetItemSize());  // strides divided by item size
        StackVector<int, kMaxNdim> int_shape = GetIntShape(arr.shape());
        CheckCudnnError(cudnnSetTensorNdDescriptor(desc_, cudnn_dtype, arr.ndim(), &int_shape[0], &int_strides[0]));
    }
}
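
// A minimal usage sketch (hypothetical caller code, assuming the wrapper
// exposes the raw cudnnTensorDescriptor_t via an accessor):
//
//     Array x = ...;  // contiguous, e.g. shape (n, c, h, w)
//     CudnnTensorDescriptor x_desc{x};
//     // pass x_desc's underlying descriptor to cuDNN APIs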

Dtype CudnnTensorDescriptor::GetDtype() const {
    cudnnDataType_t cudnn_dtype{};
    int ndim{};

    // Only the data type is needed, so zero dimensions are requested.
    CheckCudnnError(cudnnGetTensorNdDescriptor(desc_, 0, &cudnn_dtype, &ndim, nullptr, nullptr));

    switch (cudnn_dtype) {
        case CUDNN_DATA_HALF:
            return Dtype::kFloat16;
        case CUDNN_DATA_FLOAT:
            return Dtype::kFloat32;
        case CUDNN_DATA_DOUBLE:
            return Dtype::kFloat64;
        default:
            throw DtypeError{"Unsupported cuDNN data type: ", cudnn_dtype};
    }
}

CudnnFilterDescriptor::CudnnFilterDescriptor() { CheckCudnnError(cudnnCreateFilterDescriptor(&desc_)); }

CudnnFilterDescriptor::~CudnnFilterDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyFilterDescriptor(desc_);
    }
}

CudnnFilterDescriptor::CudnnFilterDescriptor(const Array& w) : CudnnFilterDescriptor{} {
    CHAINERX_ASSERT(w.IsContiguous());

    cudnnDataType_t cudnn_dtype = GetCudnnDataType(w.dtype());
    if (w.shape().ndim() == 4) {
        StackVector<int, kMaxNdim> nchw = GetIntShape(w.shape());
        CheckCudnnError(cudnnSetFilter4dDescriptor(desc_, cudnn_dtype, CUDNN_TENSOR_NCHW, nchw[0], nchw[1], nchw[2], nchw[3]));
    } else {
        StackVector<int, kMaxNdim> int_shape = GetIntShape(w.shape());
        CheckCudnnError(cudnnSetFilterNdDescriptor(desc_, cudnn_dtype, CUDNN_TENSOR_NCHW, w.ndim(), &int_shape[0]));
    }
}
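
// As with CudnnTensorDescriptor, a 4-d weight array takes the NCHW fast path;
// e.g. (hypothetical values) for a contiguous weight array `w` of shape
// (out_channels, in_channels, kh, kw):
//
//     CudnnFilterDescriptor w_desc{w};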

CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() { CheckCudnnError(cudnnCreateConvolutionDescriptor(&desc_)); }

CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyConvolutionDescriptor(desc_);
    }
}

CudnnConvolutionDescriptor::CudnnConvolutionDescriptor(
        Dtype dtype, const Dims& pad, const Dims& stride, const absl::optional<Dims>& dilation, int groups)
    : CudnnConvolutionDescriptor{} {
    size_t ndim = pad.size();
    CHAINERX_ASSERT(ndim == stride.size());
    CHAINERX_ASSERT(!dilation || ndim == dilation->size());

    StackVector<int, kMaxNdim> int_stride = GetIntStride(stride);
    StackVector<int, kMaxNdim> int_pad = GetIntPad(pad);
    StackVector<int, kMaxNdim> int_dilation{};
    if (!dilation) {
        // A dilation of 1 in every dimension means no dilation.
        // TODO(sonots): Use assign(ndim, 1) if it becomes available
        for (size_t i = 0; i < ndim; ++i) {
            int_dilation.emplace_back(1);
        }
    } else {
        int_dilation = GetIntDilation(*dilation);
    }

    cudnnDataType_t compute_type = GetCudnnDataType(dtype);

    if (ndim == 2) {
        CheckCudnnError(cudnnSetConvolution2dDescriptor(
                desc_,
                int_pad[0],
                int_pad[1],
                int_stride[0],
                int_stride[1],
                int_dilation[0],
                int_dilation[1],
                CUDNN_CROSS_CORRELATION,
                compute_type));
    } else {
        CheckCudnnError(cudnnSetConvolutionNdDescriptor(
                desc_, static_cast<int>(ndim), &int_pad[0], &int_stride[0], &int_dilation[0], CUDNN_CROSS_CORRELATION, compute_type));
    }
    if (groups > 1) {
        CheckCudnnError(cudnnSetConvolutionGroupCount(desc_, groups));
    }
}
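
// Illustrative construction for a plain 2-d convolution (hypothetical values):
//
//     CudnnConvolutionDescriptor conv_desc{
//             Dtype::kFloat32, Dims{1, 1} /*pad*/, Dims{1, 1} /*stride*/, absl::nullopt /*dilation*/, 1 /*groups*/};
//
// Omitting the dilation selects a dilation of 1 in every dimension.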

CudnnPoolingDescriptor::CudnnPoolingDescriptor() { CheckCudnnError(cudnnCreatePoolingDescriptor(&desc_)); }

CudnnPoolingDescriptor::~CudnnPoolingDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyPoolingDescriptor(desc_);
    }
}

CudnnPoolingDescriptor::CudnnPoolingDescriptor(
        cudnnPoolingMode_t mode, cudnnNanPropagation_t max_pooling_nan_opt, const Dims& kernel_size, const Dims& pad, const Dims& stride)
    : CudnnPoolingDescriptor{} {
    size_t ndim = kernel_size.size();
    CHAINERX_ASSERT(ndim == pad.size());
    CHAINERX_ASSERT(ndim == stride.size());

    StackVector<int, kMaxNdim> int_kernel_size = GetIntKernelSize(kernel_size);
    StackVector<int, kMaxNdim> int_pad = GetIntPad(pad);
    StackVector<int, kMaxNdim> int_stride = GetIntStride(stride);

    if (ndim == 2) {
        CheckCudnnError(cudnnSetPooling2dDescriptor(
                desc_,
                mode,
                max_pooling_nan_opt,
                int_kernel_size[0],
                int_kernel_size[1],
                int_pad[0],
                int_pad[1],
                int_stride[0],
                int_stride[1]));
    } else {
        CheckCudnnError(cudnnSetPoolingNdDescriptor(
                desc_, mode, max_pooling_nan_opt, static_cast<int>(ndim), &int_kernel_size[0], &int_pad[0], &int_stride[0]));
    }
}
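
// Illustrative construction for 2x2 max pooling with stride 2 (hypothetical values):
//
//     CudnnPoolingDescriptor pool_desc{
//             CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, Dims{2, 2} /*kernel*/, Dims{0, 0} /*pad*/, Dims{2, 2} /*stride*/};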

CudnnHandle::~CudnnHandle() {
    if (handle_ != nullptr) {
        // TODO(hvy): Reset device upon return similar to CublasHandle?
        cudaSetDevice(device_index_);
        cudnnDestroy(handle_);
    }
}

cudnnHandle_t CudnnHandle::handle() {
    // The handle is created lazily on first use.
    if (handle_ == nullptr) {
        // TODO(hvy): Use CudaSetDeviceScope similar to CublasHandle?
        CheckCudaError(cudaSetDevice(device_index_));
        CheckCudnnError(cudnnCreate(&handle_));
    }
    return handle_;
}
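
// Usage sketch (hypothetical `cudnn_handle` instance): repeated calls return
// the same handle, so the creation cost is paid at most once.
//
//     cudnnHandle_t h = cudnn_handle.handle();  // created on the first call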

}  // namespace cuda_internal
}  // namespace cuda
}  // namespace chainerx