1 #pragma once 2 3 #include <mutex> 4 5 #include <cublas_v2.h> 6 7 #include "chainerx/error.h" 8 9 namespace chainerx { 10 namespace cuda { 11 12 void CheckCublasError(cublasStatus_t status); 13 14 namespace cuda_internal { 15 16 class CublasHandle { 17 public: CublasHandle(int device_index)18 explicit CublasHandle(int device_index) : device_index_{device_index} {} 19 20 ~CublasHandle(); 21 22 CublasHandle(const CublasHandle&) = delete; 23 CublasHandle(CublasHandle&&) = delete; 24 CublasHandle& operator=(const CublasHandle&) = delete; 25 CublasHandle& operator=(CublasHandle&&) = delete; 26 27 template <class Func, class... Args> Call(Func && func,Args &&...args)28 void Call(Func&& func, Args&&... args) { 29 std::lock_guard<std::mutex> lock{handle_mutex_}; 30 CheckCublasError(func(handle(), args...)); 31 } 32 33 private: 34 cublasHandle_t handle(); 35 36 int device_index_; 37 std::mutex handle_mutex_{}; 38 cublasHandle_t handle_{}; 39 }; 40 41 } // namespace cuda_internal 42 43 class CublasError : public ChainerxError { 44 public: 45 explicit CublasError(cublasStatus_t status); error()46 cublasStatus_t error() const noexcept { return status_; } 47 48 private: 49 cublasStatus_t status_; 50 }; 51 52 } // namespace cuda 53 } // namespace chainerx 54