1 ///////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2020 QMCPACK developers.
6 //
7 // File developed by: Fionn Malone, malone14@llnl.gov, LLNL
8 //
9 // File created by: Fionn Malone, malone14@llnl.gov, LLNL
10 ////////////////////////////////////////////////////////////////////////////////
11 
12 #ifndef AFQMC_AUWN_BUN_CUW_KERNELS_HPP
13 #define AFQMC_AUWN_BUN_CUW_KERNELS_HPP
14 
15 #include <complex>
16 
namespace kernels
{
// Contraction over the last index of A, one independent sum per (u, w):
//   C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]
// The summed index written `a` here is the one spelled `n` in the function name.
//   nu, nw, na : presumably the extents of the u, w and a indices -- confirm against the implementation.
//   alpha      : scalar prefactor multiplying every sum.
//   A, B       : read-only inputs.
//   C          : receives the result.
// NOTE(review): no leading-dimension arguments are taken, so the arrays are presumably densely
// packed in the index order shown -- confirm against the implementation.
// Overloads: double precision, then single precision.
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  std::complex<double>* C);
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  std::complex<float>* C);

// Contraction over the middle index of A:
//   C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]
//   nu, nw, ni : presumably the extents of the u, w and i indices -- confirm against the implementation.
//   alpha      : scalar prefactor multiplying every sum.
//   A          : read-only complex input.
//   B          : read-only input; overloads accept either a real or a complex B.
//   ldb, ldc   : presumably the leading dimensions (row strides) of B and C -- confirm.
//   C          : receives the result.
// Overloads: (double, real B), (float, real B), (double, complex B), (float, complex B).
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);

// Contraction over the middle index of A (j is summed; i and k are the free indices of C):
//   C[i][k] = sum_j A[i][j][k] * B[k][j]
//   ni, nj, nk  : presumably the extents of the i, j and k indices -- confirm against the implementation.
//   A           : read-only complex input.
//   lda, stride : presumably the two strides describing the 3-index layout of A -- confirm which
//                 index each one applies to.
//   B           : read-only input; overloads accept either a complex or a real B.
//   ldb, ldc    : presumably the leading dimensions (row strides) of B and C -- confirm.
//   C           : receives the result.
// Overloads: (double, complex B), (double, real B), (float, complex B), (float, real B).
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);

// Reorders the first two indices of a 3-index array:
//   A[w][i][j] = B[i][w][j]
// Note the argument order: the read-only source B comes before the destination A.
//   nw, ni : presumably the extents of the w and i indices -- confirm against the implementation.
//   i0, iN : presumably the range of i to process (half-open [i0, iN)?) -- confirm.
//   B      : read-only source.
//   A      : destination.
// The four overloads cover every double/float pairing of source and destination element type, so the
// copy can also change precision. The extent of j is not passed; how it is obtained is not visible here.
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<double>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<float>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<double>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<float>* A);

// Element-wise product with no summation, output ordered [k][i][j]:
//   C[k][i][j] = A[i][j] * B[j][k]
//   transA     : single-character flag; presumably selects how A is read (as stored vs. transposed or
//                conjugated) -- confirm the accepted values against the implementation.
//   ni, nj, nk : presumably the extents of the i, j and k indices -- confirm.
//   A          : read-only input; overloads accept either a real or a complex A.
//   lda, ldb   : presumably the leading dimensions (row strides) of A and B -- confirm.
//   B          : read-only complex input.
//   C          : receives the result.
//   ldc1, ldc2 : presumably the two strides of the 3-index output C -- confirm which index each applies to.
// Overloads: (double, real A), (float, real A), (double, complex A), (float, complex A).
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2);

// Element-wise product with no summation, output ordered [k][j][i]:
//   C[k][j][i] = A[i][j] * B[j][k]
// Unlike element_wise_Aij_Bjk_Ckij, this variant takes no transA flag.
//   ni, nj, nk  : presumably the extents of the i, j and k indices -- confirm against the implementation.
//   A           : read-only input; overloads accept either a real or a complex A.
//   lda, ldb    : presumably the leading dimensions (row strides) of A and B -- confirm.
//   B           : read-only complex input.
//   C           : receives the result.
//   ldc, stride : presumably the two strides of the 3-index output C -- confirm which index each applies to.
// Overloads: (double, real A), (float, real A), (double, complex A), (float, complex A).
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride);

} // namespace kernels
210 
211 #endif
212