#include "chainerx/cuda/cudnn.h"

#include <absl/types/optional.h>
#include <cudnn.h>

#include "chainerx/array.h"
#include "chainerx/cuda/cuda_runtime.h"
#include "chainerx/dims.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/macro.h"

namespace chainerx {
namespace cuda {

CudnnError::CudnnError(cudnnStatus_t status) : ChainerxError{cudnnGetErrorString(status)}, status_{status} {}

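// Checks a cuDNN status code and throws a CudnnError carrying the cuDNN error string on failure.
// Illustrative usage (the wrapped call is only an example): CheckCudnnError(cudnnSetStream(handle, stream));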
void CheckCudnnError(cudnnStatus_t status) {
    if (status != CUDNN_STATUS_SUCCESS) {
        throw CudnnError{status};
    }
}

namespace {

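// Maps a ChainerX dtype to the corresponding cuDNN data type.
// Only dtypes that cuDNN supports are handled; anything else raises a DtypeError.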
cudnnDataType_t GetCudnnDataType(Dtype dtype) {
    switch (dtype) {
        case Dtype::kFloat16:
            return CUDNN_DATA_HALF;
        case Dtype::kFloat32:
            return CUDNN_DATA_FLOAT;
        case Dtype::kFloat64:
            return CUDNN_DATA_DOUBLE;
        case Dtype::kInt8:
            return CUDNN_DATA_INT8;
        case Dtype::kInt32:
            return CUDNN_DATA_INT32;
        default:
            throw DtypeError{"Dtype ", dtype, " is not supported in cuDNN"};
    }
}

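// Casts u to type T, throwing a ChainerxError built from error_args if the value does not
// survive a round-trip cast (i.e. the conversion would lose information).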
template <typename T, typename U, typename... ErrorArgs>
T narrow(U u, const ErrorArgs&... error_args) {
    auto t = static_cast<T>(u);
    if (static_cast<U>(t) != u) {
        throw ChainerxError{error_args...};
    }
    return t;
}

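// Converts a container of int64_t values (shape, kernel size, stride, pad or dilation) into the
// int-typed StackVector that the cuDNN API expects, checking each element for overflow.
// src names the converted quantity in error messages.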
template <typename T>
StackVector<int, kMaxNdim> GetIntStackVector(const T& container, const char* src) {
    StackVector<int, kMaxNdim> int_container;
    for (size_t i = 0; i < container.size(); ++i) {
        int_container.emplace_back(
                narrow<int>(container[i], "Casting the ", src, ": ", container[i], " at dimension: ", i, " to int failed."));
    }
    return int_container;
}

StackVector<int, kMaxNdim> GetIntShape(const Shape& shape) { return GetIntStackVector(shape, "shape size"); }

StackVector<int, kMaxNdim> GetIntKernelSize(const Dims& kernel_size) { return GetIntStackVector(kernel_size, "kernel size"); }

StackVector<int, kMaxNdim> GetIntStride(const Dims& stride) { return GetIntStackVector(stride, "stride"); }

StackVector<int, kMaxNdim> GetIntPad(const Dims& pad) { return GetIntStackVector(pad, "pad"); }

StackVector<int, kMaxNdim> GetIntDilation(const StackVector<int64_t, kMaxNdim>& dilation) {
    return GetIntStackVector(dilation, "dilation");
}

// Returns strides divided by item size
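// (cuDNN describes tensor strides in number of elements, not bytes.)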
StackVector<int, kMaxNdim> GetIntArrayStrides(const Strides& strides, int64_t item_size) {
    StackVector<int, kMaxNdim> int_strides;
    for (int8_t i = 0; i < strides.ndim(); ++i) {
        int64_t v = strides[i] / item_size;
        int_strides.emplace_back(
                narrow<int>(v, "Casting the array stride: ", v, " (in number of items) at dimension: ", i, " to int failed."));
    }
    return int_strides;
}

}  // namespace

namespace cuda_internal {

CudnnTensorDescriptor::CudnnTensorDescriptor() { CheckCudnnError(cudnnCreateTensorDescriptor(&desc_)); }

CudnnTensorDescriptor::~CudnnTensorDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyTensorDescriptor(desc_);
    }
}

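// Initializes the descriptor from an array. 4-D arrays use the NCHW 4-d descriptor;
// other ranks fall back to the N-d descriptor with explicit element strides.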
CudnnTensorDescriptor::CudnnTensorDescriptor(const Array& arr) : CudnnTensorDescriptor{} {
    CHAINERX_ASSERT(arr.IsContiguous());

    cudnnDataType_t cudnn_dtype = GetCudnnDataType(arr.dtype());
    if (arr.shape().ndim() == 4) {
        StackVector<int, kMaxNdim> nchw = GetIntShape(arr.shape());
        CheckCudnnError(cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, nchw[0], nchw[1], nchw[2], nchw[3]));
    } else {
        StackVector<int, kMaxNdim> int_strides = GetIntArrayStrides(arr.strides(), arr.GetItemSize());  // strides divided by item size
        StackVector<int, kMaxNdim> int_shape = GetIntShape(arr.shape());
        CheckCudnnError(cudnnSetTensorNdDescriptor(desc_, cudnn_dtype, arr.ndim(), &int_shape[0], &int_strides[0]));
    }
}

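// Queries only the data type stored in the descriptor; passing nbDimsRequested == 0 with
// null dimA/strideA skips retrieving the dimensions and strides.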
Dtype CudnnTensorDescriptor::GetDtype() const {
    cudnnDataType_t cudnn_dtype{};
    int ndim{};

    CheckCudnnError(cudnnGetTensorNdDescriptor(desc_, 0, &cudnn_dtype, &ndim, nullptr, nullptr));

    switch (cudnn_dtype) {
        case CUDNN_DATA_HALF:
            return Dtype::kFloat16;
        case CUDNN_DATA_FLOAT:
            return Dtype::kFloat32;
        case CUDNN_DATA_DOUBLE:
            return Dtype::kFloat64;
        default:
            throw DtypeError{"Unsupported cuDNN data type: ", cudnn_dtype};
    }
}

CudnnFilterDescriptor::CudnnFilterDescriptor() { CheckCudnnError(cudnnCreateFilterDescriptor(&desc_)); }

CudnnFilterDescriptor::~CudnnFilterDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyFilterDescriptor(desc_);
    }
}

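// Initializes the filter descriptor from a weight array, mirroring the tensor descriptor logic:
// the 4-d NCHW descriptor for 4-D weights, the N-d descriptor otherwise.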
CudnnFilterDescriptor::CudnnFilterDescriptor(const Array& w) : CudnnFilterDescriptor{} {
    CHAINERX_ASSERT(w.IsContiguous());

    cudnnDataType_t cudnn_dtype = GetCudnnDataType(w.dtype());
    if (w.shape().ndim() == 4) {
        StackVector<int, kMaxNdim> nchw = GetIntShape(w.shape());
        CheckCudnnError(cudnnSetFilter4dDescriptor(desc_, cudnn_dtype, CUDNN_TENSOR_NCHW, nchw[0], nchw[1], nchw[2], nchw[3]));
    } else {
        StackVector<int, kMaxNdim> int_shape = GetIntShape(w.shape());
        CheckCudnnError(cudnnSetFilterNdDescriptor(desc_, cudnn_dtype, CUDNN_TENSOR_NCHW, w.ndim(), &int_shape[0]));
    }
}

CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() { CheckCudnnError(cudnnCreateConvolutionDescriptor(&desc_)); }

CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyConvolutionDescriptor(desc_);
    }
}

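// Sets up a cross-correlation convolution descriptor. When no dilation is given, a dilation of 1
// is used for every spatial dimension; the group count is only set for grouped convolutions.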
CudnnConvolutionDescriptor::CudnnConvolutionDescriptor(
        Dtype dtype, const Dims& pad, const Dims& stride, const absl::optional<Dims>& dilation, int groups)
    : CudnnConvolutionDescriptor{} {
    size_t ndim = pad.size();
    CHAINERX_ASSERT(ndim == stride.size());
    CHAINERX_ASSERT(!dilation || ndim == dilation->size());

    StackVector<int, kMaxNdim> int_stride = GetIntStride(stride);
    StackVector<int, kMaxNdim> int_pad = GetIntPad(pad);
    StackVector<int, kMaxNdim> int_dilation{};
    if (!dilation) {
        // TODO(sonots): Use assign(ndim, 1) if it becomes available
        for (size_t i = 0; i < ndim; ++i) {
            int_dilation.emplace_back(1);
        }
    } else {
        int_dilation = GetIntDilation(*dilation);
    }

    cudnnDataType_t compute_type = GetCudnnDataType(dtype);

    if (ndim == 2) {
        CheckCudnnError(cudnnSetConvolution2dDescriptor(
                desc_,
                int_pad[0],
                int_pad[1],
                int_stride[0],
                int_stride[1],
                int_dilation[0],
                int_dilation[1],
                CUDNN_CROSS_CORRELATION,
                compute_type));
    } else {
        CheckCudnnError(cudnnSetConvolutionNdDescriptor(
                desc_, static_cast<int>(ndim), &int_pad[0], &int_stride[0], &int_dilation[0], CUDNN_CROSS_CORRELATION, compute_type));
    }
    if (groups > 1) {
        CheckCudnnError(cudnnSetConvolutionGroupCount(desc_, groups));
    }
}

CudnnPoolingDescriptor::CudnnPoolingDescriptor() { CheckCudnnError(cudnnCreatePoolingDescriptor(&desc_)); }

CudnnPoolingDescriptor::~CudnnPoolingDescriptor() {
    if (desc_ != nullptr) {
        cudnnDestroyPoolingDescriptor(desc_);
    }
}

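// Sets up a pooling descriptor. kernel_size, pad and stride must all have the same number of
// spatial dimensions; 2-D pooling uses the dedicated 2-d API, other ranks the N-d API.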
CudnnPoolingDescriptor::CudnnPoolingDescriptor(
        cudnnPoolingMode_t mode, cudnnNanPropagation_t max_pooling_nan_opt, const Dims& kernel_size, const Dims& pad, const Dims& stride)
    : CudnnPoolingDescriptor{} {
    size_t ndim = kernel_size.size();
    CHAINERX_ASSERT(ndim == pad.size());
    CHAINERX_ASSERT(ndim == stride.size());

    StackVector<int, kMaxNdim> int_kernel_size = GetIntKernelSize(kernel_size);
    StackVector<int, kMaxNdim> int_pad = GetIntPad(pad);
    StackVector<int, kMaxNdim> int_stride = GetIntStride(stride);

    if (ndim == 2) {
        CheckCudnnError(cudnnSetPooling2dDescriptor(
                desc_,
                mode,
                max_pooling_nan_opt,
                int_kernel_size[0],
                int_kernel_size[1],
                int_pad[0],
                int_pad[1],
                int_stride[0],
                int_stride[1]));
    } else {
        CheckCudnnError(cudnnSetPoolingNdDescriptor(
                desc_, mode, max_pooling_nan_opt, static_cast<int>(ndim), &int_kernel_size[0], &int_pad[0], &int_stride[0]));
    }
}

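// Destroys the handle on the device that created it. Statuses are not checked here,
// since a destructor must not throw.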
CudnnHandle::~CudnnHandle() {
    if (handle_ != nullptr) {
        // TODO(hvy): Reset device upon return similar to CublasHandle?
        cudaSetDevice(device_index_);
        cudnnDestroy(handle_);
    }
}

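// Returns the cuDNN handle, creating it lazily on first use on the device this CudnnHandle is bound to.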
cudnnHandle_t CudnnHandle::handle() {
    if (handle_ == nullptr) {
        // TODO(hvy): Use CudaSetDeviceScope similar to CublasHandle?
        CheckCudaError(cudaSetDevice(device_index_));
        CheckCudnnError(cudnnCreate(&handle_));
    }
    return handle_;
}

}  // namespace cuda_internal
}  // namespace cuda
}  // namespace chainerx