1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 //    Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 //    Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef CUBLASXT_FUNCTIONDEFS_H
16 #define CUBLASXT_FUNCTIONDEFS_H
17 
#include <cassert>
#include <complex>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#include "cublasXt.h"
#include "AFQMC/Memory/CUDA/cuda_utilities.h"
23 
24 namespace cublas
25 {
26 using qmc_cuda::cublasOperation;
27 
28 // cublasXt Level 3
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const float alpha,const float * A,int lda,const float * B,int ldb,const float beta,float * C,int ldc)29 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
30                                     char Atrans,
31                                     char Btrans,
32                                     int M,
33                                     int N,
34                                     int K,
35                                     const float alpha,
36                                     const float* A,
37                                     int lda,
38                                     const float* B,
39                                     int ldb,
40                                     const float beta,
41                                     float* C,
42                                     int ldc)
43 {
44   cublasStatus_t sucess = cublasXtSgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K, &alpha, A,
45                                         lda, B, ldb, &beta, C, ldc);
46   cudaDeviceSynchronize();
47   return sucess;
48 }
49 
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const double alpha,const double * A,int lda,const double * B,int ldb,const double beta,double * C,int ldc)50 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
51                                     char Atrans,
52                                     char Btrans,
53                                     int M,
54                                     int N,
55                                     int K,
56                                     const double alpha,
57                                     const double* A,
58                                     int lda,
59                                     const double* B,
60                                     int ldb,
61                                     const double beta,
62                                     double* C,
63                                     int ldc)
64 {
65   cublasStatus_t sucess = cublasXtDgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K, &alpha, A,
66                                         lda, B, ldb, &beta, C, ldc);
67   /*
68 std::cout<<" Dgemm error message " <<sucess <<std::endl;
69 using std::cout;
70 using std::endl;
71 switch(sucess)
72 {
73   case CUBLAS_STATUS_NOT_INITIALIZED:
74     std::cout<<"CUBLAS_STATUS_NOT_INITIALIZED";
75     break;
76   case CUBLAS_STATUS_ALLOC_FAILED:
77     cout<<"CUBLAS_STATUS_ALLOC_FAILED";
78     break;
79   case CUBLAS_STATUS_INVALID_VALUE:
80     cout<<"CUBLAS_STATUS_INVALID_VALUE";
81     break;
82   case CUBLAS_STATUS_EXECUTION_FAILED:
83     cout<<"CUBLAS_STATUS_EXECUTION_FAILED";
84     break;
85 }
86 std::cout<<std::endl;
87 */
88   cudaDeviceSynchronize();
89   return sucess;
90 }
91 
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const std::complex<float> alpha,const std::complex<float> * A,int lda,const std::complex<float> * B,int ldb,const std::complex<float> beta,std::complex<float> * C,int ldc)92 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
93                                     char Atrans,
94                                     char Btrans,
95                                     int M,
96                                     int N,
97                                     int K,
98                                     const std::complex<float> alpha,
99                                     const std::complex<float>* A,
100                                     int lda,
101                                     const std::complex<float>* B,
102                                     int ldb,
103                                     const std::complex<float> beta,
104                                     std::complex<float>* C,
105                                     int ldc)
106 {
107   cublasStatus_t sucess =
108       cublasXtCgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K,
109                     reinterpret_cast<cuComplex const*>(&alpha), reinterpret_cast<cuComplex const*>(A), lda,
110                     reinterpret_cast<cuComplex const*>(B), ldb, reinterpret_cast<cuComplex const*>(&beta),
111                     reinterpret_cast<cuComplex*>(C), ldc);
112   cudaDeviceSynchronize();
113   return sucess;
114 }
115 
cublasXt_gemm(cublasXtHandle_t handle,char Atrans,char Btrans,int M,int N,int K,const std::complex<double> alpha,const std::complex<double> * A,int lda,const std::complex<double> * B,int ldb,const std::complex<double> beta,std::complex<double> * C,int ldc)116 inline cublasStatus_t cublasXt_gemm(cublasXtHandle_t handle,
117                                     char Atrans,
118                                     char Btrans,
119                                     int M,
120                                     int N,
121                                     int K,
122                                     const std::complex<double> alpha,
123                                     const std::complex<double>* A,
124                                     int lda,
125                                     const std::complex<double>* B,
126                                     int ldb,
127                                     const std::complex<double> beta,
128                                     std::complex<double>* C,
129                                     int ldc)
130 {
131   cublasStatus_t sucess =
132       cublasXtZgemm(handle, cublasOperation(Atrans), cublasOperation(Btrans), M, N, K,
133                     reinterpret_cast<cuDoubleComplex const*>(&alpha), reinterpret_cast<cuDoubleComplex const*>(A), lda,
134                     reinterpret_cast<cuDoubleComplex const*>(B), ldb, reinterpret_cast<cuDoubleComplex const*>(&beta),
135                     reinterpret_cast<cuDoubleComplex*>(C), ldc);
136   cudaDeviceSynchronize();
137   return sucess;
138 }
139 
140 } // namespace cublas
141 
142 #endif
143