1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 //    Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 //    Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef AFQMC_LAPACK_CUDA_CATCH_ALL_HPP
16 #define AFQMC_LAPACK_CUDA_CATCH_ALL_HPP
17 
#include <cassert>
#include <stdexcept>
#include "AFQMC/Utilities/type_conversion.hpp"
#include "AFQMC/Memory/custom_pointers.hpp"
#include "AFQMC/Numerics/detail/CPU/lapack_cpu.hpp"
#include "AFQMC/Numerics/detail/CUDA/cublas_wrapper.hpp"
#include "AFQMC/Numerics/detail/CUDA/cusolver_wrapper.hpp"
#include "AFQMC/Numerics/detail/CUDA/Kernels/setIdentity.cuh"
25 
26 namespace device
27 {
28 using qmcplusplus::afqmc::remove_complex;
29 
30 // hevr
31 template<typename T, class ptr, class ptrR, class ptrI>
hevr(char JOBZ,char RANGE,char UPLO,int N,ptr A,int LDA,T VL,T VU,int IL,int IU,T ABSTOL,int & M,ptrR W,ptr Z,int LDZ,ptrI ISUPPZ,ptr WORK,int & LWORK,ptrR RWORK,int & LRWORK,ptrI IWORK,int & LIWORK,int & INFO)32 inline static void hevr(char JOBZ,
33                         char RANGE,
34                         char UPLO,
35                         int N,
36                         ptr A,
37                         int LDA,
38                         T VL,
39                         T VU,
40                         int IL,
41                         int IU,
42                         T ABSTOL,
43                         int& M,
44                         ptrR W,
45                         ptr Z,
46                         int LDZ,
47                         ptrI ISUPPZ,
48                         ptr WORK,
49                         int& LWORK,
50                         ptrR RWORK,
51                         int& LRWORK,
52                         ptrI IWORK,
53                         int& LIWORK,
54                         int& INFO)
55 {
56   throw std::runtime_error("Error: Calling qmc_cuda::hevr catch all.");
57 }
58 
59 // getrf
60 template<class ptr, class ptrW, class ptrI>
getrf(const int n,const int m,ptr const a,int lda,ptrI piv,int & st,ptrW work)61 inline static void getrf(const int n, const int m, ptr const a, int lda, ptrI piv, int& st, ptrW work)
62 {
63   throw std::runtime_error("Error: Calling qmc_cuda::getrf catch all.");
64 }
65 
66 // getrfBatched
67 template<class ptr, class ptrI1, class ptrI2>
getrfBatched(const int n,ptr * a,int lda,ptrI1 piv,ptrI2 info,int batchSize)68 inline static void getrfBatched(const int n, ptr* a, int lda, ptrI1 piv, ptrI2 info, int batchSize)
69 {
70   throw std::runtime_error("Error: Calling qmc_cuda::getrfBatched catch all.");
71 }
72 
73 // getri: will fail if not called correctly, but removing checks on ptrI and ptrW for now
74 template<class ptr, class ptrI, class ptrW>
getri(int n,ptr a,int n0,ptrI piv,ptrW work,int n1,int & status)75 inline static void getri(int n, ptr a, int n0, ptrI piv, ptrW work, int n1, int& status)
76 {
77   throw std::runtime_error("Error: Calling qmc_cuda::getri catch all.");
78 }
79 
80 // getriBatched
81 template<class ptr, class ptrR, class ptrI1, class ptrI2>
getriBatched(int n,ptr * a,int lda,ptrI1 piv,ptrR * c,int lwork,ptrI2 info,int batchSize)82 inline static void getriBatched(int n, ptr* a, int lda, ptrI1 piv, ptrR* c, int lwork, ptrI2 info, int batchSize)
83 {
84   throw std::runtime_error("Error: Calling qmc_cuda::getriBatched catch all.");
85 }
86 
87 template<class ptrA, class ptrB, class ptrC>
geqrf(int M,int N,ptrA A,const int LDA,ptrB TAU,ptrC WORK,int LWORK,int & INFO)88 inline static void geqrf(int M, int N, ptrA A, const int LDA, ptrB TAU, ptrC WORK, int LWORK, int& INFO)
89 {
90   throw std::runtime_error("Error: Calling qmc_cuda::geqrf catch all.");
91 }
92 
93 template<class ptrA, class ptrB, class ptrC>
gqr(int M,int N,int K,ptrA A,const int LDA,ptrB TAU,ptrC WORK,int LWORK,int & INFO)94 void static gqr(int M, int N, int K, ptrA A, const int LDA, ptrB TAU, ptrC WORK, int LWORK, int& INFO)
95 {
96   throw std::runtime_error("Error: Calling qmc_cuda::gqr catch all.");
97 }
98 
99 template<class ptrA, class ptrB, class ptrC>
gelqf(int M,int N,ptrA A,const int LDA,ptrB TAU,ptrC WORK,int LWORK,int & INFO)100 inline static void gelqf(int M, int N, ptrA A, const int LDA, ptrB TAU, ptrC WORK, int LWORK, int& INFO)
101 {
102   throw std::runtime_error("Error: Calling qmc_cuda::gelqf catch all.");
103 }
104 
105 template<class ptrA, class ptrB, class ptrC>
glq(int M,int N,int K,ptrA A,const int LDA,ptrB TAU,ptrC WORK,int LWORK,int & INFO)106 void static glq(int M, int N, int K, ptrA A, const int LDA, ptrB TAU, ptrC WORK, int LWORK, int& INFO)
107 {
108   throw std::runtime_error("Error: Calling qmc_cuda::glq catch all.");
109 }
110 
111 } // namespace device
112 
113 #endif
114