1 #pragma once
2 
3 #include <mutex>
4 
5 #include <cublas_v2.h>
6 
7 #include "chainerx/error.h"
8 
9 namespace chainerx {
10 namespace cuda {
11 
12 void CheckCublasError(cublasStatus_t status);
13 
14 namespace cuda_internal {
15 
16 class CublasHandle {
17 public:
CublasHandle(int device_index)18     explicit CublasHandle(int device_index) : device_index_{device_index} {}
19 
20     ~CublasHandle();
21 
22     CublasHandle(const CublasHandle&) = delete;
23     CublasHandle(CublasHandle&&) = delete;
24     CublasHandle& operator=(const CublasHandle&) = delete;
25     CublasHandle& operator=(CublasHandle&&) = delete;
26 
27     template <class Func, class... Args>
Call(Func && func,Args &&...args)28     void Call(Func&& func, Args&&... args) {
29         std::lock_guard<std::mutex> lock{handle_mutex_};
30         CheckCublasError(func(handle(), args...));
31     }
32 
33 private:
34     cublasHandle_t handle();
35 
36     int device_index_;
37     std::mutex handle_mutex_{};
38     cublasHandle_t handle_{};
39 };
40 
41 }  // namespace cuda_internal
42 
43 class CublasError : public ChainerxError {
44 public:
45     explicit CublasError(cublasStatus_t status);
error()46     cublasStatus_t error() const noexcept { return status_; }
47 
48 private:
49     cublasStatus_t status_;
50 };
51 
52 }  // namespace cuda
53 }  // namespace chainerx
54