///////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source
// License.  See LICENSE file in top directory for details.
//
// Copyright (c) 2020 QMCPACK developers.
//
// File developed by: Fionn Malone, malone14@llnl.gov, LLNL
//
// File created by: Fionn Malone, malone14@llnl.gov, LLNL
////////////////////////////////////////////////////////////////////////////////

#ifndef AFQMC_AUWN_BUN_CUW_KERNELS_HPP
#define AFQMC_AUWN_BUN_CUW_KERNELS_HPP

#include <complex>

// Prototypes for tensor contraction / index-reordering kernels.
// Only declarations live in this header; the definitions are provided by
// another translation unit. Each function name spells out the index pattern
// of its operands, e.g. Auwn_Bun_Cuw contracts A[u][w][n] with B[u][n] into
// C[u][w]. Each operation is overloaded for single and double precision.
// NOTE(review): all tensors are passed as flat pointers; the row-major layout
// implied by the X[a][b][c] notation below is an assumption to confirm
// against the implementations.
namespace kernels
{
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]
//   nu, nw, na : sizes of the u, w and a index ranges (by naming
//                convention -- confirm against the implementation).
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  std::complex<double>* C);
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  std::complex<float>* C);

// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]
// B may be real or complex (see overloads).
//   ldb, ldc : presumably the leading dimensions of B and C -- TODO confirm.
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int ni,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);

// C[i][k] = sum_j A[i][j][k] * B[k][j]
// (The contracted index is j; i and k are free indices of the result.)
// B may be real or complex (see overloads).
//   lda, stride : presumably describe A's memory layout (leading dimension
//                 and the stride between consecutive i-slices) -- confirm.
//   ldb, ldc    : presumably the leading dimensions of B and C -- confirm.
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc);

// A[w][i][j] = B[i][w][j]
// Overloads cover every single/double combination of source B and
// destination A, so the permutation may also convert precision.
//   i0, iN : presumably bound the sub-range of the i index that is copied,
//            with ni the full extent of i -- TODO confirm.
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<double>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<float>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<double>* A);
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<float>* A);

// element-wise C[k][i][j] = A[i][j] * B[j][k]
// A may be real or complex (see overloads).
//   transA     : NOTE(review): presumably selects whether A is read as stored
//                or transposed -- confirm accepted values in the implementation.
//   ldc1, ldc2 : presumably the two leading strides of the 3-index C -- confirm.
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2);
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2);

// element-wise C[k][j][i] = A[i][j] * B[j][k]
// A may be real or complex (see overloads).
//   ldc, stride : presumably describe C's memory layout (leading dimension
//                 and the stride between consecutive k-slices) -- confirm.
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride);
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride);

} // namespace kernels

#endif