1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 // Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14
15 #ifndef CUBLASXT_FUNCTIONDEFS_H
16 #define CUBLASXT_FUNCTIONDEFS_H
17
#include <cassert>
#include <complex>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#include "cublasXt.h"
#include "AFQMC/Memory/CUDA/cuda_utilities.h"
23
24 namespace cublas
25 {
26 using qmc_cuda::cublasOperation;
27
28 // cublasXt Level 3
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const float alpha,const float * A,int lda,const float * B,int ldb,const float beta,float * C,int ldc)29 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
30 char Atrans,
31 char Btrans,
32 int M,
33 int N,
34 int K,
35 const float alpha,
36 const float* A,
37 int lda,
38 const float* B,
39 int ldb,
40 const float beta,
41 float* C,
42 int ldc)
43 {
44 cublasStatus_t sucess = cublasXtSgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K, &alpha, A,
45 lda, B, ldb, &beta, C, ldc);
46 cudaDeviceSynchronize();
47 return sucess;
48 }
49
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const double alpha,const double * A,int lda,const double * B,int ldb,const double beta,double * C,int ldc)50 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
51 char Atrans,
52 char Btrans,
53 int M,
54 int N,
55 int K,
56 const double alpha,
57 const double* A,
58 int lda,
59 const double* B,
60 int ldb,
61 const double beta,
62 double* C,
63 int ldc)
64 {
65 cublasStatus_t sucess = cublasXtDgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K, &alpha, A,
66 lda, B, ldb, &beta, C, ldc);
67 /*
68 std::cout<<" Dgemm error message " <<sucess <<std::endl;
69 using std::cout;
70 using std::endl;
71 switch(sucess)
72 {
73 case CUBLAS_STATUS_NOT_INITIALIZED:
74 std::cout<<"CUBLAS_STATUS_NOT_INITIALIZED";
75 break;
76 case CUBLAS_STATUS_ALLOC_FAILED:
77 cout<<"CUBLAS_STATUS_ALLOC_FAILED";
78 break;
79 case CUBLAS_STATUS_INVALID_VALUE:
80 cout<<"CUBLAS_STATUS_INVALID_VALUE";
81 break;
82 case CUBLAS_STATUS_EXECUTION_FAILED:
83 cout<<"CUBLAS_STATUS_EXECUTION_FAILED";
84 break;
85 }
86 std::cout<<std::endl;
87 */
88 cudaDeviceSynchronize();
89 return sucess;
90 }
91
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const std::complex<float> alpha,const std::complex<float> * A,int lda,const std::complex<float> * B,int ldb,const std::complex<float> beta,std::complex<float> * C,int ldc)92 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
93 char Atrans,
94 char Btrans,
95 int M,
96 int N,
97 int K,
98 const std::complex<float> alpha,
99 const std::complex<float>* A,
100 int lda,
101 const std::complex<float>* B,
102 int ldb,
103 const std::complex<float> beta,
104 std::complex<float>* C,
105 int ldc)
106 {
107 cublasStatus_t sucess =
108 cublasXtCgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K,
109 reinterpret_cast<cuComplex const*>(&alpha), reinterpret_cast<cuComplex const*>(A), lda,
110 reinterpret_cast<cuComplex const*>(B), ldb, reinterpret_cast<cuComplex const*>(&beta),
111 reinterpret_cast<cuComplex*>(C), ldc);
112 cudaDeviceSynchronize();
113 return sucess;
114 }
115
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const std::complex<double> alpha,const std::complex<double> * A,int lda,const std::complex<double> * B,int ldb,const std::complex<double> beta,std::complex<double> * C,int ldc)116 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
117 char Atrans,
118 char Btrans,
119 int M,
120 int N,
121 int K,
122 const std::complex<double> alpha,
123 const std::complex<double>* A,
124 int lda,
125 const std::complex<double>* B,
126 int ldb,
127 const std::complex<double> beta,
128 std::complex<double>* C,
129 int ldc)
130 {
131 cublasStatus_t sucess =
132 cublasXtZgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K,
133 reinterpret_cast<cuDoubleComplex const*>(&alpha), reinterpret_cast<cuDoubleComplex const*>(A), lda,
134 reinterpret_cast<cuDoubleComplex const*>(B), ldb, reinterpret_cast<cuDoubleComplex const*>(&beta),
135 reinterpret_cast<cuDoubleComplex*>(C), ldc);
136 cudaDeviceSynchronize();
137 return sucess;
138 }
139
140 } // namespace cublas
141
142 #endif
143