// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP

#include "cudnn.hpp"
#include "activation.hpp"

#include "../pointer.hpp"
#include "../workspace.hpp"

#include <cudnn.h>

#include <cstddef>
#include <array>
#include <algorithm>
#include <utility>
#include <vector>
#include <type_traits>
#include <iterator>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {

    /** describes convolution filters
     *
     * @tparam T type of elements in the kernels
     */
    template <class T>
    class FilterDescriptor {
    public:
        FilterDescriptor() noexcept : descriptor{ nullptr } { }
        FilterDescriptor(const FilterDescriptor&) = delete;
        FilterDescriptor(FilterDescriptor&& other) noexcept
            : descriptor{ other.descriptor } {
            other.descriptor = nullptr;
        }

        /** constructs a filter descriptor from the filter dimensions provided in \p shape
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class SequenceContainer, typename = decltype(std::begin(std::declval<SequenceContainer>()))>
        FilterDescriptor(const SequenceContainer& shape) {
            constructor(shape.begin(), shape.end());
        }

        /** constructs a filter descriptor from the filter dimensions provided in [begin, end)
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class ForwardItr, typename = typename std::enable_if<!std::is_integral<ForwardItr>::value, void>::type> // TODO is_iterator
        FilterDescriptor(ForwardItr begin, ForwardItr end) {
            constructor(begin, end);
        }

        /** constructs a filter descriptor from the filter dimensions provided as arguments
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class ...Sizes>
        FilterDescriptor(Sizes ...sizes) {
            static_assert(sizeof...(Sizes) >= 3, "filter descriptors must have at least three dimensions");
            static_assert(sizeof...(Sizes) <= CUDNN_DIM_MAX, "required rank exceeds maximum supported rank");
            std::array<int, sizeof...(Sizes)> dims = { static_cast<int>(sizes)... };
            constructor(std::begin(dims), std::end(dims));
        }

        ~FilterDescriptor() noexcept {
            if (descriptor != nullptr) {
                /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor));
            }
        }

        FilterDescriptor& operator=(const FilterDescriptor&) = delete;
        FilterDescriptor& operator=(FilterDescriptor&& other) noexcept {
            /* swap so that any descriptor currently owned is released by `other` */
            std::swap(descriptor, other.descriptor);
            return *this;
        }

        cudnnFilterDescriptor_t get() const noexcept { return descriptor; }

    private:
        template <class ForwardItr>
        void constructor(ForwardItr start, ForwardItr end) {
            CV_Assert(start != end);
            CV_Assert(std::distance(start, end) >= 3);
            CV_Assert(std::distance(start, end) <= CUDNN_DIM_MAX);

            CUDA4DNN_CHECK_CUDNN(cudnnCreateFilterDescriptor(&descriptor));
            try {
                const auto rank = std::distance(start, end);
                if (rank == 4) {
                    std::array<int, 4> dims;
                    std::copy(start, end, std::begin(dims));
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetFilter4dDescriptor(
                            descriptor,
                            detail::get_data_type<T>(), CUDNN_TENSOR_NCHW,
                            dims[0], dims[1], dims[2], dims[3]
                        )
                    );
                } else {
                    std::vector<int> dims(start, end);
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetFilterNdDescriptor(
                            descriptor,
                            detail::get_data_type<T>(), CUDNN_TENSOR_NCHW,
                            static_cast<int>(dims.size()), dims.data()
                        )
                    );
                }
            } catch (...) {
                /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor));
                throw;
            }
        }

        cudnnFilterDescriptor_t descriptor;
    };
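
    /* A minimal usage sketch (the shape values below are hypothetical):
     *
     *     // 64 filters operating on 3 input feature maps with 5x5 kernels
     *     FilterDescriptor<float> fromArgs(64, 3, 5, 5);
     *
     *     // the same descriptor built from a shape container
     *     std::vector<int> shape { 64, 3, 5, 5 };
     *     FilterDescriptor<float> fromShape(shape);
     */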

    /** describes a convolution operation
     *
     * @tparam T type of element participating in convolution
     */
    template <class T>
    class ConvolutionDescriptor {
    public:
        ConvolutionDescriptor() noexcept : descriptor{ nullptr } { }
        ConvolutionDescriptor(const ConvolutionDescriptor&) = delete;
        ConvolutionDescriptor(ConvolutionDescriptor&& other) noexcept
            : descriptor{ other.descriptor } {
            other.descriptor = nullptr;
        }

        /** constructs a convolution descriptor
         *
         * Pre-conditions:
         * - \p zero_padding, \p stride and \p dilation must have the same size
         *
         * The length of the containers is interpreted as the order of the convolution.
         *
         * Exception Guarantee: Strong
         */
        template <class SequenceContainer, typename = decltype(std::begin(std::declval<SequenceContainer>()))>
        ConvolutionDescriptor(
            const SequenceContainer& zero_padding,
            const SequenceContainer& stride,
            const SequenceContainer& dilation,
            std::size_t group_count)
        {
            constructor(zero_padding, stride, dilation, group_count);
        }

        ~ConvolutionDescriptor() noexcept {
            if (descriptor != nullptr) {
                /* cudnnDestroyConvolutionDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor));
            }
        }

        ConvolutionDescriptor& operator=(const ConvolutionDescriptor&) = delete;
        ConvolutionDescriptor& operator=(ConvolutionDescriptor&& other) noexcept {
            /* swap so that any descriptor currently owned is released by `other` */
            std::swap(descriptor, other.descriptor);
            return *this;
        }

        cudnnConvolutionDescriptor_t get() const noexcept { return descriptor; }

    private:
        template <class SequenceContainer>
        void constructor(
            const SequenceContainer& zero_padding,
            const SequenceContainer& stride,
            const SequenceContainer& dilation,
            std::size_t group_count)
        {
            CV_Assert(zero_padding.size() == stride.size());
            CV_Assert(zero_padding.size() == dilation.size());

            CUDA4DNN_CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&descriptor));
            try {
                const auto rank = zero_padding.size();
                if (rank == 2) {
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetConvolution2dDescriptor(
                            descriptor,
                            zero_padding[0], zero_padding[1],
                            stride[0], stride[1],
                            dilation[0], dilation[1],
                            CUDNN_CROSS_CORRELATION,
                            detail::get_data_type<T>()
                        )
                    );
                } else {
                    std::vector<int> ipadding(std::begin(zero_padding), std::end(zero_padding));
                    std::vector<int> istride(std::begin(stride), std::end(stride));
                    std::vector<int> idilation(std::begin(dilation), std::end(dilation));
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetConvolutionNdDescriptor(
                            descriptor,
                            static_cast<int>(rank), ipadding.data(), istride.data(), idilation.data(),
                            CUDNN_CROSS_CORRELATION,
                            detail::get_data_type<T>()
                        )
                    );
                }
                CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionGroupCount(descriptor, static_cast<int>(group_count)));

#if CUDNN_MAJOR >= 8
                /* cuDNN 7 and below use FMA math by default. cuDNN 8 includes TF32 Tensor Ops
                 * in the default setting. TF32 convolutions have lower precision than FP32.
                 * Hence, we set the math type to CUDNN_FMA_MATH to reproduce the old behavior.
                 */
                CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_FMA_MATH));
#endif

                /* half-precision convolutions may use Tensor Core math */
                if (std::is_same<T, half>::value)
                    CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_TENSOR_OP_MATH));
            } catch (...) {
                /* cudnnDestroyConvolutionDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor));
                throw;
            }
        }

        cudnnConvolutionDescriptor_t descriptor;
    };
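
    /* A minimal usage sketch (the padding/stride/dilation values are hypothetical):
     *
     *     // 2D convolution with 2 pixels of zero padding on each side,
     *     // unit stride, no dilation and a single filter group
     *     std::vector<int> padding { 2, 2 }, stride { 1, 1 }, dilation { 1, 1 };
     *     ConvolutionDescriptor<float> convDesc(padding, stride, dilation, 1);
     */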

    /** wrapper around a convolution algorithm
     *
     * @tparam T type of elements being convolved
     */
    template <class T>
    class ConvolutionAlgorithm {
    public:
        ConvolutionAlgorithm() noexcept : algo{ }, workspace_size{ 0 } { }
        ConvolutionAlgorithm(const ConvolutionAlgorithm&) = default;
        ConvolutionAlgorithm(ConvolutionAlgorithm&&) = default;

        /** selects a good algorithm for convolution for a given configuration
         *
         * Exception Guarantee: Strong
         */
        ConvolutionAlgorithm(
            const Handle& handle,
            const ConvolutionDescriptor<T>& convDesc,
            const FilterDescriptor<T>& filterDesc,
            const TensorDescriptor<T>& inputDesc,
            const TensorDescriptor<T>& outputDesc)
        {
#if CUDNN_MAJOR >= 8
            int requestedAlgoCount = 0, returnedAlgoCount = 0;
            CUDA4DNN_CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(handle.get(), &requestedAlgoCount));
            std::vector<cudnnConvolutionFwdAlgoPerf_t> results(requestedAlgoCount);
            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardAlgorithm_v7(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    requestedAlgoCount,
                    &returnedAlgoCount,
                    results.data()
                )
            );

            size_t free_memory, total_memory;
            CUDA4DNN_CHECK_CUDA(cudaMemGetInfo(&free_memory, &total_memory));

            /* pick the first (fastest) algorithm whose workspace fits in free device memory */
            bool found_conv_algorithm = false;
            for (int i = 0; i < returnedAlgoCount; i++)
            {
                if (results[i].status == CUDNN_STATUS_SUCCESS &&
                    results[i].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
                    results[i].memory < free_memory)
                {
                    found_conv_algorithm = true;
                    algo = results[i].algo;
                    workspace_size = results[i].memory;
                    break;
                }
            }

            if (!found_conv_algorithm)
                CV_Error(cv::Error::GpuApiCallError, "cuDNN did not return a suitable algorithm for convolution.");
#else
            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardAlgorithm(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                    0, /* no memory limit */
                    &algo
                )
            );

            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardWorkspaceSize(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    algo, &workspace_size
                )
            );
#endif
        }

        ConvolutionAlgorithm& operator=(const ConvolutionAlgorithm&) = default;
        ConvolutionAlgorithm& operator=(ConvolutionAlgorithm&& other) = default;

        cudnnConvolutionFwdAlgo_t get() const noexcept { return algo; }

        /** number of bytes of workspace memory required by the algorithm */
        std::size_t get_workspace_size() const noexcept { return workspace_size; }

    private:
        cudnnConvolutionFwdAlgo_t algo;
        std::size_t workspace_size;
    };
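
    /* A minimal usage sketch, assuming `handle`, `convDesc`, `filterDesc`, `inputDesc`
     * and `outputDesc` have already been constructed:
     *
     *     ConvolutionAlgorithm<float> algo(handle, convDesc, filterDesc, inputDesc, outputDesc);
     *     std::size_t scratch_bytes = algo.get_workspace_size(); // scratch memory to allocate
     */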

    /** gives the shape of the output tensor of convolution
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void getConvolutionForwardOutputDim(
        const ConvolutionDescriptor<T>& convDesc,
        const FilterDescriptor<T>& filterDesc,
        const TensorDescriptor<T>& inputDesc,
        std::vector<int>& output)
    {
        output.clear();
        output.resize(CUDNN_DIM_MAX); /* we use `output` to hold temporaries */

        std::vector<int> temp(CUDNN_DIM_MAX);
        cudnnDataType_t tempDataType;
        CUDA4DNN_CHECK_CUDNN(
            cudnnGetTensorNdDescriptor(
                inputDesc.get(),
                CUDNN_DIM_MAX + 1, /* according to docs, this is what we do to get the rank */
                &tempDataType,
                output.data(),
                temp.data(),
                temp.data()
            )
        );

        const auto rank = output[0];
        output.resize(rank);
        CUDA4DNN_CHECK_CUDNN(
            cudnnGetConvolutionNdForwardOutputDim(
                convDesc.get(), inputDesc.get(), filterDesc.get(), rank, output.data()
            )
        );
    }

    /** @brief performs convolution
     *
     * dstValue = alpha * result + beta * priorDstValue
     *
     * @tparam T convolution element type (must be `half` or `float`)
     *
     * @param handle valid cuDNN Handle
     * @param convDesc convolution description
     * @param convAlgo algorithm to use for convolution
     * @param workspace workspace memory which meets the requirements of \p convAlgo
     * @param filterDesc filter descriptor
     * @param[in] filterPtr pointer to device memory containing the filters
     * @param inputDesc tensor descriptor describing the input
     * @param[in] inputPtr pointer to input tensor in device memory
     * @param alpha result scale factor
     * @param beta previous value scale factor
     * @param outputDesc tensor descriptor describing the output
     * @param[out] outputPtr pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve(
        const Handle& handle,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        T alpha, T beta,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        CUDA4DNN_CHECK_CUDNN(
            cudnnConvolutionForward(
                handle.get(),
                &alpha, inputDesc.get(), inputPtr.get(),
                filterDesc.get(), filterPtr.get(),
                convDesc.get(), convAlgo.get(),
                static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
                &beta, outputDesc.get(), outputPtr.get()
            )
        );
    }
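
    /* A minimal usage sketch, reusing the hypothetical descriptors from the sketches
     * above, with device pointers `filterPtr`/`inputPtr`/`outputPtr` and a `workspace`
     * wrapping at least algo.get_workspace_size() bytes of device scratch memory:
     *
     *     std::vector<int> output_shape;
     *     getConvolutionForwardOutputDim(convDesc, filterDesc, inputDesc, output_shape);
     *
     *     // plain convolution: alpha = 1 and beta = 0 discard prior output values
     *     convolve<float>(handle, convDesc, algo, workspace, filterDesc, filterPtr,
     *                     inputDesc, inputPtr, 1.0f, 0.0f, outputDesc, outputPtr);
     */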

    template <> inline
    void convolve(
        const Handle& handle,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        half alpha, half beta,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        /* we specialize for fp16 as the scaling factors must be provided as `float` */
        float alpha_ = alpha, beta_ = beta;
        CUDA4DNN_CHECK_CUDNN(
            cudnnConvolutionForward(
                handle.get(),
                &alpha_, inputDesc.get(), inputPtr.get(),
                filterDesc.get(), filterPtr.get(),
                convDesc.get(), convAlgo.get(),
                static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
                &beta_, outputDesc.get(), outputPtr.get()
            )
        );
    }

    /** @brief performs convolution, bias addition and activation simultaneously
     *
     * dstValue = act(alpha * conv(input) + bias)
     *
     * @tparam T convolution element type (must be `half` or `float`)
     *
     * @param handle valid cuDNN Handle
     * @param convDesc convolution description
     * @param convAlgo algorithm to use for convolution
     * @param workspace workspace memory which meets the requirements of \p convAlgo
     * @param filterDesc filter descriptor
     * @param[in] filterPtr pointer to device memory containing the filters
     * @param alpha convolution scale factor
     * @param inputDesc tensor descriptor describing the input
     * @param[in] inputPtr pointer to input tensor in device memory
     * @param biasDesc tensor descriptor describing the bias
     * @param[in] biasPtr pointer to bias tensor in device memory
     * @param actDesc activation descriptor
     * @param outputDesc tensor descriptor describing the output
     * @param[out] outputPtr pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve_with_bias_activation(
        const Handle& handle,
        T alpha,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        const TensorDescriptor<T>& biasDesc,
        DevicePtr<const T> biasPtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        /* passing the output as the `z` operand with alpha2 = 0 disables the eltwise addition */
        T alpha2 = 0.0;
        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha, inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
            &alpha2, outputDesc.get(), outputPtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }
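
    /* A minimal usage sketch, reusing the hypothetical descriptors and pointers from
     * the sketches above; `actDesc` is assumed to describe e.g. a ReLU activation:
     *
     *     // output = relu(conv(input) + bias)
     *     convolve_with_bias_activation<float>(handle, 1.0f, convDesc, algo, workspace,
     *                                          filterDesc, filterPtr, inputDesc, inputPtr,
     *                                          biasDesc, biasPtr, actDesc,
     *                                          outputDesc, outputPtr);
     */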

    template <> inline
    void convolve_with_bias_activation(
        const Handle& handle,
        half alpha,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        const TensorDescriptor<half>& biasDesc,
        DevicePtr<const half> biasPtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        /* we specialize for fp16 as the scaling factors must be provided as `float` */
        float alpha_ = alpha, alpha2 = 0.0;
        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha_, inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
            &alpha2, outputDesc.get(), outputPtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

    /** @brief performs convolution, bias addition, eltwise addition and activation simultaneously
     *
     * dstValue = act(alpha1 * conv(input) + bias + alpha2 * eltwise)
     *
     * @tparam T convolution element type (must be `half` or `float`)
     *
     * @param handle valid cuDNN Handle
     * @param convDesc convolution description
     * @param convAlgo algorithm to use for convolution
     * @param workspace workspace memory which meets the requirements of \p convAlgo
     * @param filterDesc filter descriptor
     * @param[in] filterPtr pointer to device memory containing the filters
     * @param alpha1 convolution scale factor
     * @param inputDesc tensor descriptor describing the input
     * @param[in] inputPtr pointer to input tensor in device memory
     * @param biasDesc tensor descriptor describing the bias
     * @param[in] biasPtr pointer to bias tensor in device memory
     * @param alpha2 eltwise scale factor
     * @param eltwiseDesc tensor descriptor describing the eltwise tensor
     * @param[in] eltwisePtr pointer to the eltwise tensor in device memory
     * @param actDesc activation descriptor
     * @param outputDesc tensor descriptor describing the output
     * @param[out] outputPtr pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve_with_bias_eltwise_activation(
        const Handle& handle,
        T alpha1,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        const TensorDescriptor<T>& biasDesc,
        DevicePtr<const T> biasPtr,
        T alpha2,
        const TensorDescriptor<T>& eltwiseDesc,
        DevicePtr<const T> eltwisePtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha1, inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
            &alpha2, eltwiseDesc.get(), eltwisePtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }
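
    /* A minimal usage sketch for fused residual-style blocks, assuming an `eltwise`
     * tensor of the same shape as the output (e.g. a skip connection):
     *
     *     // output = relu(conv(input) + bias + eltwise)
     *     convolve_with_bias_eltwise_activation<float>(handle, 1.0f, convDesc, algo,
     *                                                  workspace, filterDesc, filterPtr,
     *                                                  inputDesc, inputPtr, biasDesc, biasPtr,
     *                                                  1.0f, eltwiseDesc, eltwisePtr,
     *                                                  actDesc, outputDesc, outputPtr);
     */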

    template <> inline
    void convolve_with_bias_eltwise_activation(
        const Handle& handle,
        half alpha1,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        const TensorDescriptor<half>& biasDesc,
        DevicePtr<const half> biasPtr,
        half alpha2,
        const TensorDescriptor<half>& eltwiseDesc,
        DevicePtr<const half> eltwisePtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        /* we specialize for fp16 as the scaling factors must be provided as `float` */
        float alpha1_ = alpha1, alpha2_ = alpha2;
        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha1_, inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
            &alpha2_, eltwiseDesc.get(), eltwisePtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */

#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP */