1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Miguel A. Morales, moralessilva2@llnl.gov
9 //    Lawrence Livermore National Laboratory
10 //
11 // File created by:
12 // Miguel A. Morales, moralessilva2@llnl.gov
13 //    Lawrence Livermore National Laboratory
14 ////////////////////////////////////////////////////////////////////////////////
15 
16 #ifndef QMCPLUSPLUS_AFQMC_HAMILTONIANOPERATIONS_KP3INDEXFACTORIZATION_BATCHED_HPP
17 #define QMCPLUSPLUS_AFQMC_HAMILTONIANOPERATIONS_KP3INDEXFACTORIZATION_BATCHED_HPP
18 
19 #include <vector>
20 #include <type_traits>
21 #include <random>
22 #include <algorithm>
23 
24 #include "Configuration.h"
25 #include "multi/array.hpp"
26 #include "multi/array_ref.hpp"
27 #include "AFQMC/Numerics/ma_operations.hpp"
28 #include "AFQMC/Memory/buffer_managers.h"
29 
30 #include "AFQMC/Utilities/type_conversion.hpp"
31 #include "AFQMC/Utilities/Utils.hpp"
32 #include "AFQMC/Numerics/batched_operations.hpp"
33 #include "AFQMC/Numerics/tensor_operations.hpp"
34 
35 
36 namespace qmcplusplus
37 {
38 namespace afqmc
39 {
// Experimental: tests the use of dynamic data transfer during execution to reduce GPU memory usage.
// Once a suitable approach is found, integrate it into the original class through an additional template parameter.
42 
43 template<class LQKankMatrix>
44 class KP3IndexFactorization_batched
45 {
46   // allocators
47   using Allocator          = device_allocator<ComplexType>;
48   using SpAllocator        = device_allocator<SPComplexType>;
49   using BAllocator         = device_allocator<bool>;
50   using IAllocator         = device_allocator<int>;
51   using Allocator_shared   = node_allocator<ComplexType>;
52   using SpAllocator_shared = node_allocator<SPComplexType>;
53   using IAllocator_shared  = node_allocator<int>;
54 
55   using device_alloc_type  = DeviceBufferManager::template allocator_t<SPComplexType>;
56   using device_alloc_Itype = DeviceBufferManager::template allocator_t<int>;
57 
58   // type defs
59   using pointer                 = typename Allocator::pointer;
60   using const_pointer           = typename Allocator::const_pointer;
61   using sp_pointer              = typename SpAllocator::pointer;
62   using const_sp_pointer        = typename SpAllocator::const_pointer;
63   using pointer_shared          = typename Allocator_shared::pointer;
64   using const_pointer_shared    = typename Allocator_shared::const_pointer;
65   using sp_pointer_shared       = typename SpAllocator_shared::pointer;
66   using const_sp_pointer_shared = typename SpAllocator_shared::const_pointer;
67 
68   using stdIVector = boost::multi::array<int, 1>;
69 
70   using IVector    = boost::multi::array<int, 1, IAllocator>;
71   using BoolMatrix = boost::multi::array<bool, 2, BAllocator>;
72   using CVector    = ComplexVector<Allocator>;
73   using IMatrix    = IntegerMatrix<IAllocator>;
74   using CMatrix    = ComplexMatrix<Allocator>;
75   using C3Tensor   = boost::multi::array<ComplexType, 3, Allocator>;
76 
77   using SpVector  = SPComplexVector<SpAllocator>;
78   using SpMatrix  = SPComplexMatrix<SpAllocator>;
79   using Sp3Tensor = boost::multi::array<SPComplexType, 3, SpAllocator>;
80 
81   using CMatrix_cref  = boost::multi::array_ref<ComplexType const, 2, const_pointer>;
82   using CVector_ref   = ComplexVector_ref<pointer>;
83   using CMatrix_ref   = ComplexMatrix_ref<pointer>;
84   using C3Tensor_ref  = Complex3Tensor_ref<pointer>;
85   using C4Tensor_ref  = ComplexArray_ref<4, pointer>;
86   using C3Tensor_cref = boost::multi::array_ref<ComplexType const, 3, const_pointer>;
87 
88   using SpMatrix_cref = boost::multi::array_ref<SPComplexType const, 2, sp_pointer>;
89   using SpVector_ref  = SPComplexVector_ref<sp_pointer>;
90   using SpMatrix_ref  = SPComplexMatrix_ref<sp_pointer>;
91   using Sp3Tensor_ref = SPComplex3Tensor_ref<sp_pointer>;
92   using Sp4Tensor_ref = SPComplexArray_ref<4, sp_pointer>;
93   using Sp5Tensor_ref = SPComplexArray_ref<5, sp_pointer>;
94 
95   using StaticIVector = boost::multi::static_array<int, 1, device_alloc_Itype>;
96   using StaticVector  = boost::multi::static_array<SPComplexType, 1, device_alloc_type>;
97   using StaticMatrix  = boost::multi::static_array<SPComplexType, 2, device_alloc_type>;
98   using Static3Tensor = boost::multi::static_array<SPComplexType, 3, device_alloc_type>;
99   using Static4Tensor = boost::multi::static_array<SPComplexType, 4, device_alloc_type>;
100 
101   using shmCVector  = ComplexVector<Allocator_shared>;
102   using shmCMatrix  = ComplexMatrix<Allocator_shared>;
103   using shmIMatrix  = IntegerMatrix<IAllocator_shared>;
104   using shmC3Tensor = Complex3Tensor<Allocator_shared>;
105 
106   using mpi3C3Tensor = Complex3Tensor<shared_allocator<ComplexType>>;
107 
108   using shmSpVector  = SPComplexVector<SpAllocator_shared>;
109   using shmSpMatrix  = SPComplexMatrix<SpAllocator_shared>;
110   using shmSp3Tensor = SPComplex3Tensor<SpAllocator_shared>;
111 
112 public:
113   static const HamiltonianTypes HamOpType = KPFactorized;
getHamType() const114   HamiltonianTypes getHamType() const { return HamOpType; }
115 
116   // NOTE: careful with nocc_max, not consistently defined!!!
117 
118   // since arrays can be in host, can't assume that types are consistent
119   template<class shmCMatrix_, class shmSpMatrix_>
KP3IndexFactorization_batched(WALKER_TYPES type,afqmc::TaskGroup_ & tg_,stdIVector && nopk_,stdIVector && ncholpQ_,stdIVector && kminus_,boost::multi::array<int,2> && nelpk_,boost::multi::array<int,2> && QKToK2_,mpi3C3Tensor && hij_,shmCMatrix_ && h1,std::vector<shmSpMatrix_> && vik,std::vector<shmSpMatrix_> && vak,std::vector<shmSpMatrix_> && vakn,std::vector<shmSpMatrix_> && vbl,std::vector<shmSpMatrix_> && vbln,stdIVector && qqm_,mpi3C3Tensor && vn0_,std::vector<RealType> && gQ_,int nsampleQ_,ValueType e0_,Allocator const & alloc_,int cv0,int gncv,int bf_size=4096)120   KP3IndexFactorization_batched(WALKER_TYPES type,
121                                 afqmc::TaskGroup_& tg_,
122                                 stdIVector&& nopk_,
123                                 stdIVector&& ncholpQ_,
124                                 stdIVector&& kminus_,
125                                 boost::multi::array<int, 2>&& nelpk_,
126                                 boost::multi::array<int, 2>&& QKToK2_,
127                                 mpi3C3Tensor&& hij_,
128                                 shmCMatrix_&& h1,
129                                 std::vector<shmSpMatrix_>&& vik,
130                                 std::vector<shmSpMatrix_>&& vak,
131                                 std::vector<shmSpMatrix_>&& vakn,
132                                 std::vector<shmSpMatrix_>&& vbl,
133                                 std::vector<shmSpMatrix_>&& vbln,
134                                 stdIVector&& qqm_,
135                                 mpi3C3Tensor&& vn0_,
136                                 std::vector<RealType>&& gQ_,
137                                 int nsampleQ_,
138                                 ValueType e0_,
139                                 Allocator const& alloc_,
140                                 int cv0,
141                                 int gncv,
142                                 int bf_size = 4096)
143       : TG(tg_),
144         allocator_(alloc_),
145         sp_allocator_(alloc_),
146         device_buffer_manager(),
147         walker_type(type),
148         global_nCV(gncv),
149         global_origin(cv0),
150         default_buffer_size_in_MB(bf_size),
151         last_nw(-1),
152         E0(e0_),
153         H1(std::move(hij_)),
154         haj(std::move(h1)),
155         nopk(std::move(nopk_)),
156         ncholpQ(std::move(ncholpQ_)),
157         kminus(std::move(kminus_)),
158         nelpk(std::move(nelpk_)),
159         QKToK2(std::move(QKToK2_)),
160         LQKikn(std::move(move_vector<shmSpMatrix>(std::move(vik)))),
161         //LQKank(std::move(move_vector<LQKankMatrix>(std::move(vak),TG.Node()))),
162         LQKank(std::move(move_vector<LQKankMatrix>(std::move(vak)))),
163         //needs_copy(true),
164         needs_copy(not std::is_same<decltype(ma::pointer_dispatch(LQKank[0].origin())), sp_pointer>::value),
165         LQKakn(std::move(move_vector<shmSpMatrix>(std::move(vakn)))),
166         LQKbnl(std::move(move_vector<shmSpMatrix>(std::move(vbl)))),
167         LQKbln(std::move(move_vector<shmSpMatrix>(std::move(vbln)))),
168         Qmap(std::move(qqm_)),
169         Q2vbias(Qmap.size()),
170         vn0(std::move(vn0_)),
171         nsampleQ(nsampleQ_),
172         gQ(std::move(gQ_)),
173         Qwn({1, 1}),
174         generator(),
175         distribution(gQ.begin(), gQ.end()),
176         KKTransID({nopk.size(), nopk.size()}, IAllocator{allocator_}),
177         dev_nopk(nopk),
178         dev_i0pk(typename IVector::extensions_type{nopk.size()}, IAllocator{allocator_}),
179         dev_kminus(kminus),
180         dev_ncholpQ(ncholpQ),
181         dev_Q2vbias(typename IVector::extensions_type{nopk.size()}, IAllocator{allocator_}),
182         dev_Qmap(Qmap),
183         dev_nelpk(nelpk),
184         dev_a0pk(typename IMatrix::extensions_type{nelpk.size(0), nelpk.size(1)}, IAllocator{allocator_}),
185         dev_QKToK2(QKToK2),
186         EQ(nopk.size() + 2)
187   {
188     using std::copy_n;
189     using std::fill_n;
190     nocc_max = *std::max_element(nelpk.origin(), nelpk.origin() + nelpk.num_elements());
191     fill_n(EQ.data(), EQ.size(), 0);
192     int nkpts = nopk.size();
193     // Defines behavior over Q vector:
194     //   <0: Ignore (handled by another TG)
195     //    0: Calculate, without rho^+ contribution
196     //   >0: Calculate, with rho^+ contribution. LQKbln data located at Qmap[Q]-1
197     number_of_symmetric_Q = 0;
198     number_of_Q_points    = 0;
199     local_nCV             = 0;
200     std::fill_n(Q2vbias.origin(), nkpts, -1);
201     for (int Q = 0; Q < nkpts; Q++)
202     {
203       if (Q > kminus[Q])
204       {
205         if (Qmap[kminus[Q]] == 0)
206         {
207           assert(Qmap[Q] == 0);
208           Q2vbias[Q] = 2 * local_nCV;
209           local_nCV += ncholpQ[Q];
210         }
211         else
212         {
213           assert(Qmap[kminus[Q]] < 0);
214           assert(Qmap[Q] < 0);
215         }
216       }
217       else if (Qmap[Q] >= 0)
218       {
219         Q2vbias[Q] = 2 * local_nCV;
220         local_nCV += ncholpQ[Q];
221         if (Qmap[Q] > 0)
222           number_of_symmetric_Q++;
223       }
224     }
225     for (int Q = 0; Q < nkpts; Q++)
226     {
227       if (Qmap[Q] >= 0)
228         number_of_Q_points++;
229       if (Qmap[Q] > 0)
230       {
231         assert(Q == kminus[Q]);
232         assert(Qmap[Q] <= number_of_symmetric_Q);
233       }
234     }
235     copy_n(Q2vbias.data(), nkpts, dev_Q2vbias.origin());
236     // setup dev integer arrays
237     std::vector<int> i0(nkpts);
238     // dev_nopk
239     i0[0] = 0;
240     for (int i = 1; i < nkpts; i++)
241       i0[i] = i0[i - 1] + nopk[i - 1];
242     copy_n(i0.data(), nkpts, dev_i0pk.origin());
243     // dev_nelpk
244     for (int n = 0; n < nelpk.size(0); n++)
245     {
246       i0[0] = 0;
247       for (int i = 1; i < nkpts; i++)
248         i0[i] = i0[i - 1] + nelpk[n][i - 1];
249       copy_n(i0.data(), nkpts, dev_a0pk[n].origin());
250       if (walker_type == COLLINEAR)
251       {
252         i0[0] = 0;
253         for (int i = 1; i < nkpts; i++)
254           i0[i] = i0[i - 1] + nelpk[n][nkpts + i - 1];
255         copy_n(i0.data(), nkpts, dev_a0pk[n].origin() + nkpts);
256       }
257     }
258     // setup copy/transpose tags
259     // 1: copy from [Ki][Kj] without rho^+ term
260     // 2: transpose from [Ki][Kj] without rho^+ term
261     // 3: ignore
262     // -P: copy from [Ki][Kj] and transpose from [nkpts+P-1][]
263     boost::multi::array<int, 2> KKid({nkpts, nkpts});
264     std::fill_n(KKid.origin(), KKid.num_elements(), 3); // ignore everything by default
265     for (int Q = 0; Q < nkpts; ++Q)
266     { // momentum conservation index
267       if (Qmap[Q] < 0)
268         continue;
269       if (Qmap[Q] > 0)
270       { // both rho and rho^+
271         assert(Q == kminus[Q]);
272         for (int K = 0; K < nkpts; ++K)
273         { // K is the index of the kpoint pair of (i,k)
274           int QK      = QKToK2[Q][K];
275           KKid[K][QK] = -Qmap[Q];
276         }
277       }
278       else if (Q <= kminus[Q])
279       {
280         // since Qmap[Q]==0 here, Q==kminus[Q] means a hermitian L_ik
281         for (int K = 0; K < nkpts; ++K)
282         { // K is the index of the kpoint pair of (i,k)
283           int QK      = QKToK2[Q][K];
284           KKid[K][QK] = 1;
285         }
286       }
287       else if (Q > kminus[Q])
288       { // use L(-Q)(ki)*
289         for (int K = 0; K < nkpts; ++K)
290         { // K is the index of the kpoint pair of (i,k)
291           int QK      = QKToK2[Q][K];
292           KKid[K][QK] = 2;
293         }
294       }
295     }
296     copy_n(KKid.origin(), KKid.num_elements(), KKTransID.origin());
297 
298     long memank = 0;
299     if (needs_copy)
300       for (auto& v : LQKank)
301         memank = std::max(memank, 2 * v.num_elements());
302     else
303       for (auto& v : LQKank)
304         memank += v.num_elements();
305 
306     // report memory usage
307     size_t likn(0), lakn(0), lbln(0), misc(0);
308     for (auto& v : LQKikn)
309       likn += v.num_elements();
310     for (auto& v : LQKakn)
311       lakn += v.num_elements();
312     for (auto& v : LQKbln)
313       lbln += v.num_elements();
314     for (auto& v : LQKbnl)
315       lbln += v.num_elements();
316     app_log() << "****************************************************************** \n";
317     if (needs_copy)
318       app_log() << "  Using out of core storage of LQKakn \n";
319     else
320       app_log() << "  Using device storage of LQKakn \n";
321     app_log() << "  Static memory usage by KP3IndexFactorization_batched (node 0 in MB) \n"
322               << "    L[Q][K][ikn]: " << likn * sizeof(SPComplexType) / 1024.0 / 1024.0 << " \n"
323               << "    L[Q][K][akn]: " << (lakn + memank) * sizeof(SPComplexType) / 1024.0 / 1024.0 << " \n"
324               << "    L[Q][K][bln]: " << lbln * sizeof(SPComplexType) / 1024.0 / 1024.0 << " \n";
325     memory_report();
326   }
327 
~KP3IndexFactorization_batched()328   ~KP3IndexFactorization_batched() {}
329 
330   KP3IndexFactorization_batched(const KP3IndexFactorization_batched& other) = delete;
331   KP3IndexFactorization_batched& operator=(const KP3IndexFactorization_batched& other) = delete;
332   KP3IndexFactorization_batched(KP3IndexFactorization_batched&& other)                 = default;
333   KP3IndexFactorization_batched& operator=(KP3IndexFactorization_batched&& other) = default;
334 
335   // must have the same signature as shared classes, so keeping it with std::allocator
336   // NOTE: THIS SHOULD USE mpi3::shm!!!
getOneBodyPropagatorMatrix(TaskGroup_ & TG_,boost::multi::array<ComplexType,1> const & vMF)337   boost::multi::array<ComplexType, 2> getOneBodyPropagatorMatrix(TaskGroup_& TG_,
338                                                                  boost::multi::array<ComplexType, 1> const& vMF)
339   {
340     int nkpts = nopk.size();
341     int NMO   = std::accumulate(nopk.begin(), nopk.end(), 0);
342     int npol  = (walker_type == NONCOLLINEAR) ? 2 : 1;
343 
344     CVector vMF_(vMF);
345     CVector P0D(iextensions<1u>{NMO * NMO});
346     fill_n(P0D.origin(), P0D.num_elements(), ComplexType(0));
347     vHS(vMF_, P0D);
348     if (TG_.TG().size() > 1)
349       TG_.TG().all_reduce_in_place_n(to_address(P0D.origin()), P0D.num_elements(), std::plus<>());
350 
351     boost::multi::array<ComplexType, 2> P0({NMO, NMO});
352     copy_n(P0D.origin(), NMO * NMO, P0.origin());
353 
354     boost::multi::array<ComplexType, 2> P1({npol * NMO, npol * NMO});
355     std::fill_n(P1.origin(), P1.num_elements(), ComplexType(0.0));
356 
357     // add spin-dependent H1
358     for (int K = 0, nk0 = 0; K < nkpts; ++K)
359     {
360       for (int i = 0, I = nk0; i < nopk[K]; i++, I++)
361       {
362         for (int p = 0; p < npol; ++p)
363           P1[p * NMO + I][p * NMO + I] += H1[K][p * nopk[K] + i][p * nopk[K] + i];
364         for (int j = i + 1, J = I + 1; j < nopk[K]; j++, J++)
365         {
366           for (int p = 0; p < npol; ++p)
367           {
368             P1[p * NMO + I][p * NMO + J] += H1[K][p * nopk[K] + i][p * nopk[K] + j];
369             P1[p * NMO + J][p * NMO + I] += H1[K][p * nopk[K] + j][p * nopk[K] + i];
370           }
371         }
372         if (walker_type == NONCOLLINEAR)
373         {
374           // offdiagonal piece
375           for (int j = 0, J = nk0; j < nopk[K]; j++, J++)
376           {
377             P1[I][NMO + J] += H1[K][i][nopk[K] + j];
378             P1[NMO + J][I] += H1[K][nopk[K] + j][i];
379           }
380         }
381       }
382       nk0 += nopk[K];
383     }
384 
385     // add P0 (diagonal in spin)
386     for (int p = 0; p < npol; ++p)
387       for (int I = 0; I < NMO; I++)
388         for (int J = 0; J < NMO; J++)
389           P1[p * NMO + I][p * NMO + J] += P0[I][J];
390 
391     // add vn0 (diagonal in spin)
392     for (int K = 0, nk0 = 0; K < nkpts; ++K)
393     {
394       for (int i = 0, I = nk0; i < nopk[K]; i++, I++)
395       {
396         for (int p = 0; p < npol; ++p)
397           P1[p * NMO + I][p * NMO + I] += vn0[K][i][i];
398         for (int j = i + 1, J = I + 1; j < nopk[K]; j++, J++)
399         {
400           for (int p = 0; p < npol; ++p)
401           {
402             P1[p * NMO + I][p * NMO + J] += vn0[K][i][j];
403             P1[p * NMO + J][p * NMO + I] += vn0[K][j][i];
404           }
405         }
406       }
407       nk0 += nopk[K];
408     }
409 
410     using ma::conj;
411     // symmetrize
412     for (int I = 0; I < npol * NMO; I++)
413     {
414       for (int J = I + 1; J < npol * NMO; J++)
415       {
416         // This is really cutoff dependent!!!
417 #if defined(MIXED_PRECISION)
418         if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-5)
419         {
420 #else
421         if (std::abs(P1[I][J] - ma::conj(P1[J][I])) * 2.0 > 1e-6)
422         {
423 #endif
424           app_error() << " WARNING in getOneBodyPropagatorMatrix. H1 is not hermitian. \n";
425           app_error() << I << " " << J << " " << P1[I][J] << " " << P1[J][I] << std::endl;
426           //<< H1[K][i][j] << " "
427           //<< H1[K][j][i] << " " << vn0[K][i][j] << " " << vn0[K][j][i] << std::endl;
428         }
429         P1[I][J] = 0.5 * (P1[I][J] + ma::conj(P1[J][I]));
430         P1[J][I] = ma::conj(P1[I][J]);
431       }
432     }
433     return P1;
434   }
435 
// Convenience overload: evaluates the energy without requesting the
// KEleft / KEright intermediates, by forwarding null pointers for both
// to the full overload.
template<class Mat, class MatB>
void energy(Mat&& E, MatB const& G, int k = 0, bool addH1 = true, bool addEJ = true, bool addEXX = true)
{
  MatB* no_intermediate = nullptr; // used for both the left and the right slot
  energy(E, G, k, no_intermediate, no_intermediate, addH1, addEJ, addEXX);
}
443 
444   // KEleft and KEright must be in shared memory for this to work correctly
445   template<
446       class Mat,
447       class MatB,
448       class MatC,
449       class MatD
450       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<Mat>())),
451       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatB>())),
452       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatC>())),
453       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatD>()))
454       >
455   void energy(Mat&& E,
456               MatB const& Gc,
457               int nd,
458               MatC* KEleft,
459               MatD* KEright,
460               bool addH1  = true,
461               bool addEJ  = true,
462               bool addEXX = true)
463   {
464     if (nsampleQ > 0)
465       energy_sampleQ(E, Gc, nd, KEleft, KEright, addH1, addEJ, addEXX);
466     else
467       energy_exact(E, Gc, nd, KEleft, KEright, addH1, addEJ, addEXX);
468   }
469 
470   // KEleft and KEright must be in shared memory for this to work correctly
471   template<class Mat, class MatB, class MatC, class MatD>
472   void energy_exact(Mat&& E,
473                     MatB const& Gc,
474                     int nd,
475                     MatC* KEleft,
476                     MatD* KEright,
477                     bool addH1  = true,
478                     bool addEJ  = true,
479                     bool addEXX = true)
480   {
481     using std::copy_n;
482     using std::fill_n;
483     int nkpts = nopk.size();
484     assert(E.size(1) >= 3);
485     assert(nd >= 0 && nd < nelpk.size());
486 
487     int nwalk     = Gc.size(1);
488     int nspin     = (walker_type == COLLINEAR ? 2 : 1);
489     int npol      = (walker_type == NONCOLLINEAR ? 2 : 1);
490     int nmo_tot   = std::accumulate(nopk.begin(), nopk.end(), 0);
491     int nmo_max   = *std::max_element(nopk.begin(), nopk.end());
492     int nocca_tot = std::accumulate(nelpk[nd].begin(), nelpk[nd].begin() + nkpts, 0);
493     int nocca_max = *std::max_element(nelpk[nd].begin(), nelpk[nd].begin() + nkpts);
494     int nchol_max = *std::max_element(ncholpQ.begin(), ncholpQ.end());
495     int noccb_tot = 0;
496     if (walker_type == COLLINEAR)
497       noccb_tot = std::accumulate(nelpk[nd].begin() + nkpts, nelpk[nd].begin() + 2 * nkpts, 0);
498     int getKr = KEright != nullptr;
499     int getKl = KEleft != nullptr;
500     if (E.size(0) != nwalk || E.size(1) < 3)
501       APP_ABORT(" Error in AFQMC/HamiltonianOperations/sparse_matrix_energy::calculate_energy(). Incorrect matrix "
502                 "dimensions \n");
503 
504     // take from BufferManager.
505     //      long default_buffer_size_in_MB(4L*1024L);
506     long batch_size(0);
507     if (addEXX)
508     {
509       long Bytes = long(default_buffer_size_in_MB) * 1024L * 1024L;
510       Bytes /= size_t(nwalk * nocc_max * nocc_max * nchol_max * sizeof(SPComplexType));
511       long bz0 = std::max(2L, Bytes);
512       // batch_size includes the factor of 2 from Q/Qm pair
513       batch_size = std::min(bz0, long(2 * number_of_Q_points * nkpts));
514       // make sure batch_size is even
515       batch_size = batch_size - (batch_size % 2L);
516       assert(batch_size % 2L == 0);
517     }
518 
519     long Knr = 0, Knc = 0;
520     if (addEJ)
521     {
522       Knr = nwalk;
523       Knc = local_nCV;
524       if (getKr)
525       {
526         assert(KEright->size(0) == nwalk && KEright->size(1) == local_nCV);
527         assert(KEright->stride(0) == KEright->size(1));
528       }
529       if (getKl)
530       {
531         assert(KEleft->size(0) == nwalk && KEleft->size(1) == local_nCV);
532         assert(KEleft->stride(0) == KEleft->size(1));
533       }
534     }
535     else if (getKr or getKl)
536     {
537       APP_ABORT(" Error: Kr and/or Kl can only be calculated with addEJ=true.\n");
538     }
539     StaticMatrix Kl({Knr, Knc}, device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
540     StaticMatrix Kr({Knr, Knc}, device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
541     fill_n(Kr.origin(), Knr * Knc, SPComplexType(0.0));
542     fill_n(Kl.origin(), Knr * Knc, SPComplexType(0.0));
543 
544     for (int n = 0; n < nwalk; n++)
545       fill_n(E[n].origin(), 3, ComplexType(0.));
546 
547     assert(Gc.num_elements() == nwalk * (nocca_tot + noccb_tot) * npol * nmo_tot);
548     C3Tensor_cref G3Da(make_device_ptr(Gc.origin()), {nocca_tot * npol, nmo_tot, nwalk});
549     C3Tensor_cref G3Db(make_device_ptr(Gc.origin()) + G3Da.num_elements() * (nspin - 1), {noccb_tot, nmo_tot, nwalk});
550 
551     // later on, rewrite routine to loop over spins, to avoid storage of both spin
552     // components simultaneously
553     Static4Tensor GKK({nspin, nkpts, nkpts, nwalk * npol * nmo_max * nocc_max},
554                       device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
555     GKaKjw_to_GKKwaj(G3Da, GKK[0], nelpk[nd].sliced(0, nkpts), dev_nelpk[nd], dev_a0pk[nd]);
556     if (walker_type == COLLINEAR)
557       GKaKjw_to_GKKwaj(G3Db, GKK[1], nelpk[nd].sliced(nkpts, 2 * nkpts), dev_nelpk[nd].sliced(nkpts, 2 * nkpts),
558                        dev_a0pk[nd].sliced(nkpts, 2 * nkpts));
559     // one-body contribution
560     // haj[ndet*nkpts][nocc*nmo]
561     // not parallelized for now, since it would require customization of Wfn
562     if (addH1)
563     {
564       for (int n = 0; n < nwalk; n++)
565         fill_n(E[n].origin(), 1, ComplexType(E0));
      // must use Gc since GKK is SP (stored as SPComplexType)
567       int na = 0, nk = 0, nb = 0;
568       for (int K = 0; K < nkpts; ++K)
569       {
570 #if defined(MIXED_PRECISION)
571         int ni(nopk[K]);
572         CMatrix_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max, npol * nmo_max});
573         for (int a = 0; a < nelpk[nd][K]; ++a)
574           for (int pol = 0; pol < npol; ++pol)
575             ma::product(ComplexType(1.), ma::T(G3Da[(na + a) * npol + pol].sliced(nk, nk + ni)),
576                         haj_K[a].sliced(pol * ni, pol * ni + ni), ComplexType(1.), E({0, nwalk}, 0));
577         na += nelpk[nd][K];
578         if (walker_type == COLLINEAR)
579         {
580           boost::multi::array_ref<ComplexType, 2, pointer> haj_Kb(haj_K.origin() + haj_K.num_elements(),
581                                                                   {nocc_max, nmo_max});
582           for (int b = 0; b < nelpk[nd][nkpts + K]; ++b)
583             ma::product(ComplexType(1.), ma::T(G3Db[nb + b].sliced(nk, nk + ni)), haj_Kb[b].sliced(0, ni),
584                         ComplexType(1.), E({0, nwalk}, 0));
585           nb += nelpk[nd][nkpts + K];
586         }
587         nk += ni;
588 #else
589         nk = nopk[K];
590         {
591           na = nelpk[nd][K];
592           CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()), {nocc_max * npol * nmo_max});
593           SpMatrix_ref Gaj(GKK[0][K][K].origin(), {nwalk, nocc_max * npol * nmo_max});
594           ma::product(ComplexType(1.), Gaj, haj_K, ComplexType(1.), E({0, nwalk}, 0));
595         }
596         if (walker_type == COLLINEAR)
597         {
598           na = nelpk[nd][nkpts + K];
599           CVector_ref haj_K(make_device_ptr(haj[nd * nkpts + K].origin()) + nocc_max * nmo_max, {nocc_max * nmo_max});
600           SpMatrix_ref Gaj(GKK[1][K][K].origin(), {nwalk, nocc_max * nmo_max});
601           ma::product(ComplexType(1.), Gaj, haj_K, ComplexType(1.), E({0, nwalk}, 0));
602         }
603 #endif
604       }
605     }
606 
607     // move calculation of H1 here
608     // NOTE: For CLOSED/NONCOLLINEAR, can do all walkers simultaneously to improve perf. of GEMM
609     //       Not sure how to do it for COLLINEAR.
610     if (addEXX)
611     {
612       int batch_cnt(0);
613       using ma::gemmBatched;
614       std::vector<sp_pointer> Aarray;
615       std::vector<sp_pointer> Barray;
616       std::vector<sp_pointer> Carray;
617       Aarray.reserve(batch_size);
618       Barray.reserve(batch_size);
619       Carray.reserve(batch_size);
620       std::vector<SPComplexType> scl_factors;
621       scl_factors.reserve(batch_size);
622       std::vector<int> kdiag;
623       kdiag.reserve(batch_size);
624 
625       StaticIVector IMats(iextensions<1u>{batch_size},
626                           device_buffer_manager.get_generator().template get_allocator<int>());
627       fill_n(IMats.origin(), IMats.num_elements(), 0);
628       StaticVector dev_scl_factors(iextensions<1u>{batch_size},
629                                    device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
630       Static3Tensor T1({batch_size, nwalk * nocc_max, nocc_max * nchol_max},
631                        device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
632       SPRealType scl = (walker_type == CLOSED ? 2.0 : 1.0);
633 
634       // I WANT C++17!!!!!!
635       long mem_ank(0);
636       if (needs_copy)
637         mem_ank = nkpts * nocc_max * nchol_max * npol * nmo_max;
638       StaticVector LBuff(iextensions<1u>{2 * mem_ank},
639                          device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
640       sp_pointer LQptr(nullptr), LQmptr(nullptr);
641       if (needs_copy)
642       {
643         // data will be copied here
644         LQptr  = LBuff.origin();
645         LQmptr = LBuff.origin() + mem_ank;
646       }
647 
648       for (int spin = 0; spin < nspin; ++spin)
649       {
650         for (int Q = 0; Q < nkpts; ++Q)
651         {
652           if (Qmap[Q] < 0)
653             continue;
654           bool haveKE = false;
655           int Qm      = kminus[Q];
656 
657           // simple implementation for now
658           Aarray.clear();
659           Barray.clear();
660           Carray.clear();
661           scl_factors.clear();
662           kdiag.clear();
663           batch_cnt = 0;
664 
665           // choose source of data depending on whether data needs to be copied or not
666           if (!needs_copy)
667           {
668             // set to local array origin
669             LQptr  = make_device_ptr(LQKank[nd * nspin * nkpts + spin * nkpts + Q].origin());
670             LQmptr = make_device_ptr(LQKank[nd * nspin * nkpts + spin * nkpts + Qm].origin());
671           }
672 
673           SpMatrix_ref LQ(LQptr, LQKank[nd * nspin * nkpts + spin * nkpts + Q].extensions());
674           SpMatrix_ref LQm(LQmptr, LQKank[nd * nspin * nkpts + spin * nkpts + Qm].extensions());
675 
676           if (needs_copy)
677           {
678             copy_n(to_address(LQKank[nd * nspin * nkpts + spin * nkpts + Q].origin()), LQ.num_elements(), LQ.origin());
679             if (Q != Qm)
680               copy_n(to_address(LQKank[nd * nspin * nkpts + spin * nkpts + Qm].origin()), LQm.num_elements(),
681                      LQm.origin());
682           }
683 
684           for (int Ka = 0; Ka < nkpts; ++Ka)
685           {
686             int K0 = ((Qmap[Q] > 0) ? 0 : Ka);
687             for (int Kb = K0; Kb < nkpts; ++Kb)
688             {
689               int Kl_ = QKToK2[Qm][Kb];
690               int Kk  = QKToK2[Q][Ka];
691 
692               if (addEJ && Ka == Kb)
693                 kdiag.push_back(batch_cnt);
694 
695               if (Qmap[Q] > 0)
696                 Aarray.push_back(sp_pointer(
697                     LQKbnl[nd * nspin * number_of_symmetric_Q + spin * number_of_symmetric_Q + Qmap[Q] - 1][Kb]
698                         .origin()));
699               else
700                 Aarray.push_back(sp_pointer(LQm[Kb].origin()));
701 
702               Barray.push_back(GKK[spin][Ka][Kl_].origin());
703               Carray.push_back(T1[batch_cnt++].origin());
704               Aarray.push_back(sp_pointer(LQ[Ka].origin()));
705               Barray.push_back(GKK[spin][Kb][Kk].origin());
706               Carray.push_back(T1[batch_cnt++].origin());
707 
708               if (Qmap[Q] > 0 || Ka == Kb)
709                 scl_factors.push_back(SPComplexType(-scl * 0.5));
710               else
711                 scl_factors.push_back(SPComplexType(-scl));
712 
713               if (batch_cnt >= batch_size)
714               {
715                 gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, npol * nmo_max, SPComplexType(1.0),
716                             Aarray.data(), npol * nmo_max, Barray.data(), npol * nmo_max, SPComplexType(0.0),
717                             Carray.data(), nocc_max * nchol_max, Aarray.size());
718 
719                 copy_n(scl_factors.data(), scl_factors.size(), dev_scl_factors.origin());
720                 using ma::batched_dot_wabn_wban;
721                 batched_dot_wabn_wban(scl_factors.size(), nwalk, nocc_max, nchol_max, dev_scl_factors.origin(),
722                                       T1.origin(), to_address(E[0].origin()) + 1, E.stride(0));
723 
724                 if (addEJ)
725                 {
726                   int nc0 = Q2vbias[Q] / 2; //std::accumulate(ncholpQ.begin(),ncholpQ.begin()+Q,0);
727                   copy_n(kdiag.data(), kdiag.size(), IMats.origin());
728                   using ma::batched_Tab_to_Klr;
729                   batched_Tab_to_Klr(kdiag.size(), nwalk, nocc_max, nchol_max, local_nCV, ncholpQ[Q], nc0,
730                                      IMats.origin(), T1.origin(), Kl.origin(), Kr.origin());
731                 }
732 
733                 // reset
734                 Aarray.clear();
735                 Barray.clear();
736                 Carray.clear();
737                 scl_factors.clear();
738                 kdiag.clear();
739                 batch_cnt = 0;
740               }
741             }
742           }
743 
744           if (batch_cnt > 0)
745           {
746             gemmBatched('T', 'N', nocc_max * nchol_max, nwalk * nocc_max, npol * nmo_max, SPComplexType(1.0),
747                         Aarray.data(), npol * nmo_max, Barray.data(), npol * nmo_max, SPComplexType(0.0), Carray.data(),
748                         nocc_max * nchol_max, Aarray.size());
749 
750             copy_n(scl_factors.data(), scl_factors.size(), dev_scl_factors.origin());
751             using ma::batched_dot_wabn_wban;
752             batched_dot_wabn_wban(scl_factors.size(), nwalk, nocc_max, nchol_max, dev_scl_factors.origin(), T1.origin(),
753                                   to_address(E[0].origin()) + 1, E.stride(0));
754 
755             if (addEJ)
756             {
757               int nc0 = Q2vbias[Q] / 2; //std::accumulate(ncholpQ.begin(),ncholpQ.begin()+Q,0);
758               copy_n(kdiag.data(), kdiag.size(), IMats.origin());
759               using ma::batched_Tab_to_Klr;
760               batched_Tab_to_Klr(kdiag.size(), nwalk, nocc_max, nchol_max, local_nCV, ncholpQ[Q], nc0, IMats.origin(),
761                                  T1.origin(), Kl.origin(), Kr.origin());
762             }
763           }
764         } // Q
765       }   // COLLINEAR
766     }
767 
768     if (addEJ)
769     {
770       if (not addEXX)
771       {
772         // calculate Kr
773         APP_ABORT(" Error: Finish addEJ and not addEXX");
774       }
775       RealType scl = (walker_type == CLOSED ? 2.0 : 1.0);
776       using ma::adotpby;
777       for (int n = 0; n < nwalk; ++n)
778       {
779         adotpby(SPComplexType(0.5 * scl * scl), Kl[n], Kr[n], ComplexType(0.0), E[n].origin() + 2);
780       }
781       if (getKr)
782         copy_n_cast(Kr.origin(), Kr.num_elements(), make_device_ptr(KEright->origin()));
783       if (getKl)
784         copy_n_cast(Kl.origin(), Kl.num_elements(), make_device_ptr(KEleft->origin()));
785     }
786   }
787 
788   // KEleft and KEright must be in shared memory for this to work correctly
789   template<class Mat, class MatB, class MatC, class MatD>
790   void energy_sampleQ(Mat&& E,
791                       MatB const& Gc,
792                       int nd,
793                       MatC* KEleft,
794                       MatD* KEright,
795                       bool addH1  = true,
796                       bool addEJ  = true,
797                       bool addEXX = true)
798   {
799     APP_ABORT("Error: energy_sampleQ not yet implemented in batched routine.\n");
800     /*
801       using std::fill_n;
802       int nkpts = nopk.size();
803       assert(E.size(1)>=3);
804       assert(nd >= 0 && nd < nelpk.size());
805 
806       int nwalk = Gc.size(1);
807       int nspin = (walker_type==COLLINEAR?2:1);
808       int nmo_tot = std::accumulate(nopk.begin(),nopk.end(),0);
809       int nmo_max = *std::max_element(nopk.begin(),nopk.end());
810       int nocca_tot = std::accumulate(nelpk[nd].begin(),nelpk[nd].begin()+nkpts,0);
811       int nocca_max = *std::max_element(nelpk[nd].begin(),nelpk[nd].begin()+nkpts);
812       int nchol_max = *std::max_element(ncholpQ.begin(),ncholpQ.end());
813       int noccb_tot = 0;
814       if(walker_type==COLLINEAR) noccb_tot = std::accumulate(nelpk[nd].begin()+nkpts,
815                                       nelpk[nd].begin()+2*nkpts,0);
816       int getKr = KEright!=nullptr;
817       int getKl = KEleft!=nullptr;
818       if(E.size(0) != nwalk || E.size(1) < 3)
819         APP_ABORT(" Error in AFQMC/HamiltonianOperations/sparse_matrix_energy::calculate_energy(). Incorrect matrix dimensions \n");
820 
821       size_t mem_needs(nwalk*nkpts*nkpts*nspin*nocca_max*nmo_max);
822       size_t cnt(0);
823       if(addEJ) {
824 #if defined(MIXED_PRECISION)
825         mem_needs += 2*nwalk*local_nCV;
826 #else
827         if(not getKr) mem_needs += nwalk*local_nCV;
828         if(not getKl) mem_needs += nwalk*local_nCV;
829 #endif
830       }
831       set_buffer(mem_needs);
832 
833       // messy
834       sp_pointer Krptr(nullptr), Klptr(nullptr);
835       long Knr=0, Knc=0;
836       if(addEJ) {
837         Knr=nwalk;
838         Knc=local_nCV;
839         cnt=0;
840 #if defined(MIXED_PRECISION)
841         if(getKr) {
842           assert(KEright->size(0) == nwalk && KEright->size(1) == local_nCV);
843           assert(KEright->stride(0) == KEright->size(1));
844         }
845 #else
846         if(getKr) {
847           assert(KEright->size(0) == nwalk && KEright->size(1) == local_nCV);
848           assert(KEright->stride() == KEright->size(1));
849           Krptr = make_device_ptr(KEright->origin());
850         } else
851 #endif
852         {
853           Krptr = BTMats.origin();
854           cnt += nwalk*local_nCV;
855         }
856 #if defined(MIXED_PRECISION)
857         if(getKl) {
858           assert(KEleft->size(0) == nwalk && KEleft->size(1) == local_nCV);
859           assert(KEleft->stride(0) == KEleft->size(1));
860         }
861 #else
862         if(getKl) {
863           assert(KEleft->size(0) == nwalk && KEleft->size(1) == local_nCV);
864           assert(KEleft->stride(0) == KEleft->size(1));
865           Klptr = make_device_ptr(KEleft->origin());
866         } else
867 #endif
868         {
869           Klptr = BTMats.origin()+cnt;
870           cnt += nwalk*local_nCV;
871         }
872         fill_n(Krptr,Knr*Knc,SPComplexType(0.0));
873         fill_n(Klptr,Knr*Knc,SPComplexType(0.0));
874       } else if(getKr or getKl) {
875         APP_ABORT(" Error: Kr and/or Kl can only be calculated with addEJ=true.\n");
876       }
877       SpMatrix_ref Kl(Klptr,{Knr,Knc});
878       SpMatrix_ref Kr(Krptr,{Knr,Knc});
879 
880       for(int n=0; n<nwalk; n++)
881         fill_n(E[n].origin(),3,ComplexType(0.));
882 
883       assert(Gc.num_elements() == nwalk*(nocca_tot+noccb_tot)*nmo_tot);
884       C3Tensor_cref G3Da(make_device_ptr(Gc.origin()),{nocca_tot,nmo_tot,nwalk} );
885       C3Tensor_cref G3Db(make_device_ptr(Gc.origin())+G3Da.num_elements()*(nspin-1),
886                             {noccb_tot,nmo_tot,nwalk} );
887 
888       Sp4Tensor_ref GKK(BTMats.origin()+cnt,
889                         {nspin,nkpts,nkpts,nwalk*nmo_max*nocca_max});
890       cnt+=GKK.num_elements();
891       GKaKjw_to_GKKwaj(G3Da,GKK[0],nelpk[nd].sliced(0,nkpts),dev_nelpk[nd],dev_a0pk[nd]);
892       if(walker_type==COLLINEAR)
893         GKaKjw_to_GKKwaj(G3Db,GKK[1],nelpk[nd].sliced(nkpts,2*nkpts),
894                                      dev_nelpk[nd].sliced(nkpts,2*nkpts),
895                                      dev_a0pk[nd].sliced(nkpts,2*nkpts));
896 
897       // one-body contribution
898       // haj[ndet*nkpts][nocc*nmo]
899       // not parallelized for now, since it would require customization of Wfn
900       if(addH1) {
901         // must use Gc since GKK is is SP
902         int na=0, nk=0, nb=0;
903         for(int n=0; n<nwalk; n++)
904           E[n][0] = E0;
905         for(int K=0; K<nkpts; ++K) {
906 #if defined(MIXED_PRECISION)
907           CMatrix_ref haj_K(make_device_ptr(haj[nd*nkpts+K].origin()),{nocc_max,nmo_max});
908           for(int a=0; a<nelpk[nd][K]; ++a)
909             ma::product(ComplexType(1.),ma::T(G3Da[na+a].sliced(nk,nk+nopk[K])),
910                                         haj_K[a].sliced(0,nopk[K]),
911                         ComplexType(1.),E({0,nwalk},0));
912           na+=nelpk[nd][K];
913           if(walker_type==COLLINEAR) {
914             boost::multi::array_ref<ComplexType,2,pointer> haj_Kb(haj_K.origin()+haj_K.num_elements(),
915                                                       {nocc_max,nmo_max});
916             for(int b=0; b<nelpk[nd][nkpts+K]; ++b)
917               ma::product(ComplexType(1.),ma::T(G3Db[nb+b].sliced(nk,nk+nopk[K])),
918                                         haj_Kb[b].sliced(0,nopk[K]),
919                         ComplexType(1.),E({0,nwalk},0));
920             nb+=nelpk[nd][nkpts+K];
921           }
922           nk+=nopk[K];
923 #else
924           nk = nopk[K];
925           {
926             na = nelpk[nd][K];
927             CVector_ref haj_K(make_device_ptr(haj[nd*nkpts+K].origin()),{nocc_max*nmo_max});
928             SpMatrix_ref Gaj(GKK[0][K][K].origin(),{nwalk,nocc_max*nmo_max});
929             ma::product(ComplexType(1.),Gaj,haj_K,ComplexType(1.),E({0,nwalk},0));
930           }
931           if(walker_type==COLLINEAR) {
932             na = nelpk[nd][nkpts+K];
933             CVector_ref haj_K(make_device_ptr(haj[nd*nkpts+K].origin())+nocc_max*nmo_max,{nocc_max*nmo_max});
934             SpMatrix_ref Gaj(GKK[1][K][K].origin(),{nwalk,nocc_max*nmo_max});
935             ma::product(ComplexType(1.),Gaj,haj_K,ComplexType(1.),E({0,nwalk},0));
936           }
937 #endif
938         }
939       }
940 
941       // move calculation of H1 here
942       // NOTE: For CLOSED/NONCOLLINEAR, can do all walkers simultaneously to improve perf. of GEMM
943       //       Not sure how to do it for COLLINEAR.
944       if(addEXX) {
945 
946         if(Qwn.size(0) != nwalk || Qwn.size(1) != nsampleQ)
947           Qwn = std::move(boost::multi::array<int,2>({nwalk,nsampleQ}));
948         {
949           for(int n=0; n<nwalk; ++n)
950             for(int nQ=0; nQ<nsampleQ; ++nQ) {
951               Qwn[n][nQ] = distribution(generator);
952               RealType drand = distribution(generator);
953               RealType s(0.0);
954               bool found=false;
955               for(int Q=0; Q<nkpts; Q++) {
956                 s += gQ[Q];
957                 if( drand < s ) {
958                   Qwn[n][nQ] = Q;
959                   found=true;
960                   break;
961                 }
962               }
963               if(not found)
964                 APP_ABORT(" Error: sampleQ Qwn. \n");
965             }
966         }
967         size_t local_memory_needs = 2*nocca_max*nocca_max*nchol_max;
968         if(TMats.num_elements() < local_memory_needs) {
969           TMats = std::move(SpVector(iextensions<1u>{local_memory_needs}));
970           using std::fill_n;
971           fill_n(TMats.origin(),TMats.num_elements(),SPComplexType(0.0));
972         }
973         size_t local_cnt=0;
974         RealType scl = (walker_type==CLOSED?2.0:1.0);
975         size_t nqk=1;
976         for(int n=0; n<nwalk; ++n) {
977           for(int nQ=0; nQ<nsampleQ; ++nQ) {
978             int Q = Qwn[n][nQ];
979             for(int Ka=0; Ka<nkpts; ++Ka) {
980               for(int Kb=0; Kb<nkpts; ++Kb) {
981                 {
982                   int nchol = ncholpQ[Q];
983                   int Qm = kminus[Q];
984                   int Kl = QKToK2[Qm][Kb];
985                   int Kk = QKToK2[Q][Ka];
986                   int nl = nopk[Kl];
987                   int nb = nelpk[nd][Kb];
988                   int na = nelpk[nd][Ka];
989                   int nk = nopk[Kk];
990 
991                   SpMatrix_ref Gal(GKK[0][Ka][Kl].origin()+n*na*nl,{na,nl});
992                   SpMatrix_ref Gbk(GKK[0][Kb][Kk].origin()+n*nb*nk,{nb,nk});
993                   SpMatrix_ref Lank(sp_pointer(LQKank[nd*nspin*nkpts+Q][Ka].origin()),
994                                                  {na*nchol,nk});
995                   auto bnl_ptr(sp_pointer(LQKank[nd*nspin*nkpts+Qm][Kb].origin()));
996                   if( Q == Qm ) bnl_ptr = sp_pointer(LQKbnl[nd*nspin*number_of_symmetric_Q+Qmap[Q]-1][Kb].origin());
997                   SpMatrix_ref Lbnl(bnl_ptr,{nb*nchol,nl});
998 
999                   SpMatrix_ref Tban(TMats.origin()+local_cnt,{nb,na*nchol});
1000                   Sp3Tensor_ref T3Dban(TMats.origin()+local_cnt,{nb,na,nchol});
1001                   SpMatrix_ref Tabn(Tban.origin()+Tban.num_elements(),{na,nb*nchol});
1002                   Sp3Tensor_ref T3Dabn(Tban.origin()+Tban.num_elements(),{na,nb,nchol});
1003 
1004                   ma::product(Gal,ma::T(Lbnl),Tabn);
1005                   ma::product(Gbk,ma::T(Lank),Tban);
1006 
1007                   SPComplexType E_(0.0);
1008                   for(int a=0; a<na; ++a)
1009                     for(int b=0; b<nb; ++b)
1010                       E_ += ma::dot(T3Dabn[a][b],T3Dban[b][a]);
1011                   E[n][1] -= scl*0.5*static_cast<ComplexType>(E_)/gQ[Q]/double(nsampleQ);
1012 
1013                 } // if
1014 
1015                 if(walker_type==COLLINEAR) {
1016 
1017                   {
1018                     int nchol = ncholpQ[Q];
1019                     int Qm = kminus[Q];
1020                     int Kl = QKToK2[Qm][Kb];
1021                     int Kk = QKToK2[Q][Ka];
1022                     int nl = nopk[Kl];
1023                     int nb = nelpk[nd][nkpts+Kb];
1024                     int na = nelpk[nd][nkpts+Ka];
1025                     int nk = nopk[Kk];
1026 
1027                     SpMatrix_ref Gal(GKK[1][Ka][Kl].origin()+n*na*nl,{na,nl});
1028                     SpMatrix_ref Gbk(GKK[1][Kb][Kk].origin()+n*nb*nk,{nb,nk});
1029                     SpMatrix_ref Lank(sp_pointer(LQKank[(nd*nspin+1)*nkpts+Q][Ka].origin()),
1030                                                  {na*nchol,nk});
1031                     auto bnl_ptr(sp_pointer(LQKank[nd*nspin*nkpts+Qm][Kb].origin()));
1032                     if( Q == Qm ) bnl_ptr = sp_pointer(LQKbnl[(nd*nspin+1)*number_of_symmetric_Q+
1033                                                                 Qmap[Q]-1][Kb].origin());
1034                     SpMatrix_ref Lbnl(bnl_ptr,{nb*nchol,nl});
1035 
1036                     SpMatrix_ref Tban(TMats.origin()+local_cnt,{nb,na*nchol});
1037                     Sp3Tensor_ref T3Dban(TMats.origin()+local_cnt,{nb,na,nchol});
1038                     SpMatrix_ref Tabn(Tban.origin()+Tban.num_elements(),{na,nb*nchol});
1039                     Sp3Tensor_ref T3Dabn(Tban.origin()+Tban.num_elements(),{na,nb,nchol});
1040 
1041                     ma::product(Gal,ma::T(Lbnl),Tabn);
1042                     ma::product(Gbk,ma::T(Lank),Tban);
1043 
1044                     SPComplexType E_(0.0);
1045                     for(int a=0; a<na; ++a)
1046                       for(int b=0; b<nb; ++b)
1047                         E_ += ma::dot(T3Dabn[a][b],T3Dban[b][a]);
1048                     E[n][1] -= scl*0.5*static_cast<ComplexType>(E_)/gQ[Q]/double(nsampleQ);
1049 
1050                   } // if
1051                 } // COLLINEAR
1052               } // Kb
1053             } // Ka
1054           } // nQ
1055         } // n
1056       }
1057 
1058       if(addEJ) {
1059         size_t local_memory_needs = 2*nchol_max*nwalk;
1060         if(TMats.num_elements() < local_memory_needs) {
1061           TMats = std::move(SpVector(iextensions<1u>{local_memory_needs}));
1062           using std::fill_n;
1063           fill_n(TMats.origin(),TMats.num_elements(),SPComplexType(0.0));
1064         }
1065         cnt=0;
1066         SpMatrix_ref Kr_local(TMats.origin(),{nwalk,nchol_max});
1067         cnt+=Kr_local.num_elements();
1068         SpMatrix_ref Kl_local(TMats.origin()+cnt,{nwalk,nchol_max});
1069         cnt+=Kl_local.num_elements();
1070         fill_n(Kr_local.origin(),Kr_local.num_elements(),SPComplexType(0.0));
1071         fill_n(Kl_local.origin(),Kl_local.num_elements(),SPComplexType(0.0));
1072         size_t nqk=1;
1073         for(int Q=0; Q<nkpts; ++Q) {
1074           bool haveKE=false;
1075           for(int Ka=0; Ka<nkpts; ++Ka) {
1076             {
1077               haveKE=true;
1078               int nchol = ncholpQ[Q];
1079               int Qm = kminus[Q];
1080               int Kl = QKToK2[Qm][Ka];
1081               int Kk = QKToK2[Q][Ka];
1082               int nl = nopk[Kl];
1083               int na = nelpk[nd][Ka];
1084               int nk = nopk[Kk];
1085 
1086               Sp3Tensor_ref Gwal(GKK[0][Ka][Kl].origin(),{nwalk,na,nl});
1087               Sp3Tensor_ref Gwbk(GKK[0][Ka][Kk].origin(),{nwalk,na,nk});
1088               Sp3Tensor_ref Lank(sp_pointer(LQKank[nd*nspin*nkpts+Q][Ka].origin()),
1089                                                  {na,nchol,nk});
1090               auto bnl_ptr(sp_pointer(LQKank[nd*nspin*nkpts+Qm][Ka].origin()));
1091               if( Q == Qm ) bnl_ptr = sp_pointer(LQKbnl[nd*nspin*number_of_symmetric_Q+Qmap[Q]-1][Ka].origin());
1092               Sp3Tensor_ref Lbnl(bnl_ptr,{na,nchol,nl});
1093 
1094               // Twan = sum_l G[w][a][l] L[a][n][l]
1095               for(int n=0; n<nwalk; ++n)
1096                 for(int a=0; a<na; ++a)
1097                   ma::product(SPComplexType(1.0),Lbnl[a],Gwal[n][a],
1098                               SPComplexType(1.0),Kl_local[n]);
1099               for(int n=0; n<nwalk; ++n)
1100                 for(int a=0; a<na; ++a)
1101                   ma::product(SPComplexType(1.0),Lank[a],Gwbk[n][a],
1102                               SPComplexType(1.0),Kr_local[n]);
1103             } // if
1104 
1105             if(walker_type==COLLINEAR) {
1106 
1107               {
1108                 haveKE=true;
1109                 int nchol = ncholpQ[Q];
1110                 int Qm = kminus[Q];
1111                 int Kl = QKToK2[Qm][Ka];
1112                 int Kk = QKToK2[Q][Ka];
1113                 int nl = nopk[Kl];
1114                 int na = nelpk[nd][nkpts+Ka];
1115                 int nk = nopk[Kk];
1116 
1117                 Sp3Tensor_ref Gwal(GKK[1][Ka][Kl].origin(),{nwalk,na,nl});
1118                 Sp3Tensor_ref Gwbk(GKK[1][Ka][Kk].origin(),{nwalk,na,nk});
1119                 Sp3Tensor_ref Lank(sp_pointer(LQKank[(nd*nspin+1)*nkpts+Q][Ka].origin()),
1120                                                  {na,nchol,nk});
1121                 auto bnl_ptr(sp_pointer(LQKank[(nd*nspin+1)*nkpts+Qm][Ka].origin()));
1122                 if( Q == Qm ) bnl_ptr = sp_pointer(LQKbnl[(nd*nspin+1)*number_of_symmetric_Q+Qmap[Q]-1][Ka].origin());
1123                 Sp3Tensor_ref Lbnl(bnl_ptr,{na,nchol,nl});
1124 
1125                 // Twan = sum_l G[w][a][l] L[a][n][l]
1126                 for(int n=0; n<nwalk; ++n)
1127                   for(int a=0; a<na; ++a)
1128                     ma::product(SPComplexType(1.0),Lbnl[a],Gwal[n][a],
1129                                 SPComplexType(1.0),Kl_local[n]);
1130                 for(int n=0; n<nwalk; ++n)
1131                   for(int a=0; a<na; ++a)
1132                     ma::product(SPComplexType(1.0),Lank[a],Gwbk[n][a],
1133                                 SPComplexType(1.0),Kr_local[n]);
1134 
1135               } // if
1136             } // COLLINEAR
1137           } // Ka
1138           if(haveKE) {
1139             int nc0 = Q2vbias[Q]/2; //std::accumulate(ncholpQ.begin(),ncholpQ.begin()+Q,0);
1140             using ma::axpy;
1141             for(int n=0; n<nwalk; n++) {
1142               axpy(SPComplexType(1.0),Kr_local[n].sliced(0,ncholpQ[Q]),
1143                                         Kr[n].sliced(nc0,nc0+ncholpQ[Q]));
1144               axpy(SPComplexType(1.0),Kl_local[n].sliced(0,ncholpQ[Q]),
1145                                         Kl[n].sliced(nc0,nc0+ncholpQ[Q]));
1146             }
1147           } // to release the lock
1148           if(haveKE) {
1149             fill_n(Kr_local.origin(),Kr_local.num_elements(),SPComplexType(0.0));
1150             fill_n(Kl_local.origin(),Kl_local.num_elements(),SPComplexType(0.0));
1151           }
1152         } // Q
1153         nqk=0;
1154         RealType scl = (walker_type==CLOSED?2.0:1.0);
1155         for(int n=0; n<nwalk; ++n) {
1156           for(int Q=0; Q<nkpts; ++Q) {      // momentum conservation index
1157             {
1158               int nc0 = Q2vbias[Q]/2; //std::accumulate(ncholpQ.begin(),ncholpQ.begin()+Q,0);
1159               E[n][2] += 0.5*scl*scl*static_cast<ComplexType>(ma::dot(Kl[n]({nc0,nc0+ncholpQ[Q]}),
1160                                             Kr[n]({nc0,nc0+ncholpQ[Q]})));
1161             }
1162           }
1163         }
1164       }
1165 */
1166   }
1167 
1168   template<class... Args>
1169   void fast_energy(Args&&... args)
1170   {
1171     APP_ABORT(" Error: fast_energy not implemented in KP3IndexFactorization_batched. \n");
1172   }
1173 
1174   template<
1175       class MatA,
1176       class MatB,
1177       typename = typename std::enable_if_t<(std::decay<MatA>::type::dimensionality == 1)>,
1178       typename = typename std::enable_if_t<(std::decay<MatB>::type::dimensionality == 1)>,
1179       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatA>())),
1180       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatB>())),
1181       typename = void>
1182   void vHS(MatA& X, MatB&& v, double a = 1., double c = 0.)
1183   {
1184     using BType = typename std::decay<MatB>::type::element;
1185     using AType = typename std::decay<MatA>::type::element;
1186     boost::multi::array_ref<AType, 2, decltype(X.origin())> X_(X.origin(), {X.size(0), 1});
1187     boost::multi::array_ref<BType, 2, decltype(v.origin())> v_(v.origin(), {1, v.size(0)});
1188     return vHS(X_, v_, a, c);
1189   }
1190 
1191   template<
1192       class MatA,
1193       class MatB,
1194       typename = typename std::enable_if_t<(std::decay<MatA>::type::dimensionality == 2)>,
1195       typename = typename std::enable_if_t<(std::decay<MatB>::type::dimensionality == 2)>
1196       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatA>())),
1197       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatB>()))
1198       >
1199   void vHS(MatA& X, MatB&& v, double a = 1., double c = 0.)
1200   {
1201     int nkpts = nopk.size();
1202     int nwalk = X.size(1);
1203     assert(v.size(0) == nwalk);
1204     int nspin     = (walker_type == COLLINEAR ? 2 : 1);
1205     int nmo_tot   = std::accumulate(nopk.begin(), nopk.end(), 0);
1206     int nmo_max   = *std::max_element(nopk.begin(), nopk.end());
1207     int nchol_max = *std::max_element(ncholpQ.begin(), ncholpQ.end());
1208     assert(X.num_elements() == nwalk * 2 * local_nCV);
1209     assert(v.num_elements() == nwalk * nmo_tot * nmo_tot);
1210     SPComplexType one(1.0, 0.0);
1211     SPComplexType im(0.0, 1.0);
1212     SPComplexType halfa(0.5 * a, 0.0);
1213     SPComplexType minusimhalfa(0.0, -0.5 * a);
1214     SPComplexType imhalfa(0.0, 0.5 * a);
1215 
1216     Static3Tensor vKK({nkpts + number_of_symmetric_Q, nkpts, nwalk * nmo_max * nmo_max},
1217                       device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
1218     fill_n(vKK.origin(), vKK.num_elements(), SPComplexType(0.0));
1219     Static4Tensor XQnw({nkpts, 2, nchol_max, nwalk},
1220                        device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
1221     fill_n(XQnw.origin(), XQnw.num_elements(), SPComplexType(0.0));
1222 
1223     // "rotate" X
1224     //  XIJ = 0.5*a*(Xn+ -i*Xn-), XJI = 0.5*a*(Xn+ +i*Xn-)
1225 #if defined(MIXED_PRECISION)
1226     StaticMatrix Xdev(X.extensions(), device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
1227     copy_n_cast(make_device_ptr(X.origin()), X.num_elements(), Xdev.origin());
1228 #else
1229     SpMatrix_ref Xdev(make_device_ptr(X.origin()), X.extensions());
1230 #endif
1231     for (int Q = 0; Q < nkpts; ++Q)
1232     {
1233       if (Qmap[Q] < 0)
1234         continue;
1235       int nq = Q2vbias[Q];
1236       auto&& Xp(Xdev.sliced(nq, nq + ncholpQ[Q]));
1237       auto&& Xm(Xdev.sliced(nq + ncholpQ[Q], nq + 2 * ncholpQ[Q]));
1238       ma::add(halfa, Xp, minusimhalfa, Xm, XQnw[Q][0].sliced(0, ncholpQ[Q]));
1239       ma::add(halfa, Xp, imhalfa, Xm, XQnw[Q][1].sliced(0, ncholpQ[Q]));
1240       nq += 2 * ncholpQ[Q];
1241     }
1242     //  then combine Q/(-Q) pieces
1243     //  X(Q)np = (X(Q)np + X(-Q)nm)
1244     for (int Q = 0; Q < nkpts; ++Q)
1245     {
1246       if (Qmap[Q] == 0)
1247       {
1248         int Qm = kminus[Q];
1249         ma::axpy(SPComplexType(1.0), XQnw[Qm][1], XQnw[Q][0]);
1250       }
1251     }
1252     {
1253       // assuming contiguous
1254       ma::scal(c, v);
1255     }
1256 
1257     int nmo_max2 = nmo_max * nmo_max;
1258     using ma::gemmBatched;
1259     std::vector<sp_pointer> Aarray;
1260     std::vector<sp_pointer> Barray;
1261     std::vector<sp_pointer> Carray;
1262     Aarray.reserve(nkpts * nkpts);
1263     Barray.reserve(nkpts * nkpts);
1264     Carray.reserve(nkpts * nkpts);
1265     for (int Q = 0; Q < nkpts; ++Q)
1266     { // momentum conservation index
1267       if (Qmap[Q] < 0)
1268         continue;
1269       // v[nw][i(in K)][k(in Q(K))] += sum_n LQK[i][k][n] X[Q][0][n][nw]
1270       if (Q <= kminus[Q])
1271       {
1272         for (int K = 0; K < nkpts; ++K)
1273         { // K is the index of the kpoint pair of (i,k)
1274           int QK = QKToK2[Q][K];
1275           Aarray.push_back(sp_pointer(LQKikn[Q][K].origin()));
1276           Barray.push_back(XQnw[Q][0].origin());
1277           Carray.push_back(vKK[K][QK].origin());
1278         }
1279       }
1280     }
1281     // C: v = T(X) * T(Lik) --> F: T(Lik) * T(X) = v
1282     gemmBatched('T', 'T', nmo_max2, nwalk, nchol_max, SPComplexType(1.0), Aarray.data(), nchol_max, Barray.data(),
1283                 nwalk, SPComplexType(0.0), Carray.data(), nmo_max2, Aarray.size());
1284 
1285 
1286     Aarray.clear();
1287     Barray.clear();
1288     Carray.clear();
1289     for (int Q = 0; Q < nkpts; ++Q)
1290     { // momentum conservation index
1291       if (Qmap[Q] < 0)
1292         continue;
1293       // v[nw][i(in K)][k(in Q(K))] += sum_n LQK[i][k][n] X[Q][0][n][nw]
1294       if (Q > kminus[Q])
1295       { // use L(-Q)(ki)*
1296         for (int K = 0; K < nkpts; ++K)
1297         { // K is the index of the kpoint pair of (i,k)
1298           int QK = QKToK2[Q][K];
1299           Aarray.push_back(sp_pointer(LQKikn[kminus[Q]][QK].origin()));
1300           Barray.push_back(XQnw[Q][0].origin());
1301           Carray.push_back(vKK[K][QK].origin());
1302         }
1303       }
1304       else if (Qmap[Q] > 0)
1305       { // rho(Q)^+ term
1306         for (int K = 0; K < nkpts; ++K)
1307         { // K is the index of the kpoint pair of (i,k)
1308           int QK = QKToK2[Q][K];
1309           Aarray.push_back(sp_pointer(LQKikn[Q][K].origin()));
1310           Barray.push_back(XQnw[Q][1].origin());
1311           Carray.push_back(vKK[nkpts + Qmap[Q] - 1][QK].origin());
1312         }
1313       }
1314     }
1315     // C: v = T(X) * T(Lik) --> F: T(Lik) * T(X) = v
1316     gemmBatched('C', 'T', nmo_max2, nwalk, nchol_max, SPComplexType(1.0), Aarray.data(), nchol_max, Barray.data(),
1317                 nwalk, SPComplexType(0.0), Carray.data(), nmo_max2, Aarray.size());
1318 
1319 
1320     using vType = typename std::decay<MatB>::type::element;
1321     boost::multi::array_ref<vType, 3, decltype(make_device_ptr(v.origin()))> v3D(make_device_ptr(v.origin()),
1322                                                                                  {nwalk, nmo_tot, nmo_tot});
1323     vKKwij_to_vwKiKj(vKK, v3D);
1324     // do I need to "rotate" back, can be done if necessary
1325   }
1326 
1327   template<
1328       class MatA,
1329       class MatB,
1330       typename = typename std::enable_if_t<(std::decay<MatA>::type::dimensionality == 1)>,
1331       typename = typename std::enable_if_t<(std::decay<MatB>::type::dimensionality == 1)>,
1332       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatA&>())),
1333       //             typename = decltype(boost::multi::static_array_cast<ComplexType, pointer>(std::declval<MatB>())),
1334       typename = void>
1335   void vbias(const MatA& G, MatB&& v, double a = 1., double c = 0., int k = 0)
1336   {
1337     using BType = typename std::decay<MatB>::type::element;
1338     using AType = typename std::decay<MatA>::type::element;
1339     boost::multi::array_ref<BType, 2, decltype(v.origin())> v_(v.origin(), {v.size(0), 1});
1340     boost::multi::array_ref<AType const, 2, decltype(G.origin())> G_(G.origin(), {G.size(0), 1});
1341     return vbias(G_, v_, a, c, k);
1342   }
1343 
1344   /*
1345     template<class MatA, class MatB,
1346              typename = typename std::enable_if_t<(std::decay<MatA>::type::dimensionality==2)>,
1347              typename = typename std::enable_if_t<(std::decay<MatB>::type::dimensionality==1)>,
1348 //             typename = typename std::enable_if_t<(std::is_convertible<typename std::decay<MatA>::type::pointer,pointer>::value)>,
1349 //             typename = typename std::enable_if_t<(not std::is_convertible<typename std::decay<MatB>::type::pointer,pointer>::value)>,
1350               typename = void,
1351               typename = void,
1352               typename = void,
1353               typename = void,
1354               typename = void
1355             >
1356     void vbias(const MatA& G, MatB&& v, double a=1., double c=0., int nd=0) {
1357     }
1358 */
1359 
1360   template<
1361       class MatA,
1362       class MatB,
1363       typename = std::enable_if_t<(std::decay<MatA>::type::dimensionality == 2)>,
1364       typename = std::enable_if_t<(std::decay<MatB>::type::dimensionality == 2)>
1365       //             typename = std::enable_if_t<(std::is_convertible<typename std::decay<MatA>::type::element_ptr,pointer>::value)>,
1366       //             typename = std::enable_if_t<(std::is_convertible<typename std::decay<MatB>::type::element_ptr,pointer>::value)>
1367       >
1368   void vbias(const MatA& G, MatB&& v, double a = 1., double c = 0., int nd = 0)
1369   {
1370     using ma::gemmBatched;
1371 
1372     int nkpts = nopk.size();
1373     assert(nd >= 0 && nd < nelpk.size());
1374     int nwalk = G.size(1);
1375     assert(v.size(0) == 2 * local_nCV);
1376     assert(v.size(1) == nwalk);
1377     int nspin     = (walker_type == COLLINEAR ? 2 : 1);
1378     int npol      = (walker_type == NONCOLLINEAR ? 2 : 1);
1379     int nmo_tot   = std::accumulate(nopk.begin(), nopk.end(), 0);
1380     int nmo_max   = *std::max_element(nopk.begin(), nopk.end());
1381     int nocca_tot = std::accumulate(nelpk[nd].begin(), nelpk[nd].begin() + nkpts, 0);
1382     int nocca_max = *std::max_element(nelpk[nd].begin(), nelpk[nd].begin() + nkpts);
1383     int noccb_max = nocca_max;
1384     int nchol_max = *std::max_element(ncholpQ.begin(), ncholpQ.end());
1385     int noccb_tot = 0;
1386     if (walker_type == COLLINEAR)
1387     {
1388       noccb_tot = std::accumulate(nelpk[nd].begin() + nkpts, nelpk[nd].begin() + 2 * nkpts, 0);
1389       noccb_max = *std::max_element(nelpk[nd].begin() + nkpts, nelpk[nd].begin() + 2 * nkpts);
1390     }
1391     RealType scl = (walker_type == CLOSED ? 2.0 : 1.0);
1392     SPComplexType one(1.0, 0.0);
1393     SPComplexType halfa(0.5 * a * scl, 0.0);
1394     SPComplexType minusimhalfa(0.0, -0.5 * a * scl);
1395     SPComplexType imhalfa(0.0, 0.5 * a * scl);
1396 
1397     assert(G.num_elements() == nwalk * (nocca_tot + noccb_tot) * npol * nmo_tot);
1398     // MAM: use reshape when available, then no need to deal with types
1399     using GType = typename std::decay<MatA>::type::element;
1400     boost::multi::array_ref<GType const, 3, decltype(make_device_ptr(G.origin()))> G3Da(make_device_ptr(G.origin()),
1401                                                                                         {nocca_tot * npol, nmo_tot,
1402                                                                                          nwalk});
1403     boost::multi::array_ref<GType const, 3, decltype(make_device_ptr(G.origin()))> G3Db(make_device_ptr(G.origin()) +
1404                                                                                             G3Da.num_elements() *
1405                                                                                                 (nspin - 1),
1406                                                                                         {noccb_tot, nmo_tot, nwalk});
1407 
1408     // assuming contiguous
1409     ma::scal(c, v);
1410 
1411     for (int spin = 0; spin < nspin; spin++)
1412     {
1413       size_t cnt(0);
1414       Static3Tensor v1({nkpts + number_of_symmetric_Q, nchol_max, nwalk},
1415                        device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
1416       Static3Tensor GQ({nkpts, nkpts * nocc_max * npol * nmo_max, nwalk},
1417                        device_buffer_manager.get_generator().template get_allocator<SPComplexType>());
1418       fill_n(v1.origin(), v1.num_elements(), SPComplexType(0.0));
1419       fill_n(GQ.origin(), GQ.num_elements(), SPComplexType(0.0));
1420 
1421       if (spin == 0)
1422         GKaKjw_to_GQKajw(G3Da, GQ, nelpk[nd], dev_nelpk[nd], dev_a0pk[nd]);
1423       else
1424         GKaKjw_to_GQKajw(G3Db, GQ, nelpk[nd].sliced(nkpts, 2 * nkpts), dev_nelpk[nd].sliced(nkpts, 2 * nkpts),
1425                          dev_a0pk[nd].sliced(nkpts, 2 * nkpts));
1426 
1427       // can use productStridedBatched if LQKakn is changed to a 3Tensor array
1428       int Kak = nkpts * nocc_max * npol * nmo_max;
1429       std::vector<sp_pointer> Aarray;
1430       std::vector<sp_pointer> Barray;
1431       std::vector<sp_pointer> Carray;
1432       Aarray.reserve(nkpts + number_of_symmetric_Q);
1433       Barray.reserve(nkpts + number_of_symmetric_Q);
1434       Carray.reserve(nkpts + number_of_symmetric_Q);
1435       for (int Q = 0; Q < nkpts; ++Q)
1436       { // momentum conservation index
1437         if (Qmap[Q] < 0)
1438           continue;
1439         // v_[Q][n][w] = sum_Kak LQ[Kak][n]*G[Q][Kak][w]
1440         //             F: -->   G[Kak][w] * LQ[Kak][n]
1441         Aarray.push_back(GQ[Q].origin());
1442         Barray.push_back(sp_pointer(LQKakn[nd * nspin * nkpts + spin * nkpts + Q].origin()));
1443         Carray.push_back(v1[Q].origin());
1444         if (Qmap[Q] > 0)
1445         {
1446           Aarray.push_back(GQ[Q].origin());
1447           Barray.push_back(sp_pointer(
1448               LQKbln[nd * nspin * number_of_symmetric_Q + spin * number_of_symmetric_Q + Qmap[Q] - 1].origin()));
1449           Carray.push_back(v1[nkpts + Qmap[Q] - 1].origin());
1450         }
1451       }
1452       gemmBatched('N', 'T', nwalk, nchol_max, Kak, SPComplexType(1.0), Aarray.data(), nwalk, Barray.data(), nchol_max,
1453                   SPComplexType(0.0), Carray.data(), nwalk, Aarray.size());
1454       // optimize later, right now it adds contributions from Q's not assigned
1455       vbias_from_v1(halfa, v1, v);
1456     }
1457   }
1458 
1459   template<class Mat, class MatB>
1460   void generalizedFockMatrix(Mat&& G, MatB&& Fp, MatB&& Fm)
1461   {
1462     APP_ABORT(" Error: generalizedFockMatrix not implemented for this hamiltonian.\n");
1463   }
1464 
1465   bool distribution_over_cholesky_vectors() const { return true; }
1466   int number_of_ke_vectors() const { return local_nCV; }
1467   int local_number_of_cholesky_vectors() const { return 2 * local_nCV; }
1468   int global_number_of_cholesky_vectors() const { return global_nCV; }
1469   int global_origin_cholesky_vector() const { return global_origin; }
1470 
1471   // transpose=true means G[nwalk][ik], false means G[ik][nwalk]
1472   bool transposed_G_for_vbias() const { return false; }
1473   bool transposed_G_for_E() const { return false; }
1474   // transpose=true means vHS[nwalk][ik], false means vHS[ik][nwalk]
1475   bool transposed_vHS() const { return true; }
1476 
1477   bool fast_ph_energy() const { return false; }
1478 
1479   boost::multi::array<ComplexType, 2> getHSPotentials() { return boost::multi::array<ComplexType, 2>{}; }
1480 
private:
  // Occupied-orbital extent shared by all k-points when sizing padded scratch tensors and
  // when calling the rearrangement kernels (see GKaKjw_to_GQKajw and the GQ buffer in vbias).
  // NOTE(review): presumably the maximum occupation over k-points -- confirm where it is set.
  int nocc_max;

  afqmc::TaskGroup_& TG;

  Allocator allocator_;
  SpAllocator sp_allocator_;
  // provides the allocator for temporary device work arrays (see vbias)
  DeviceBufferManager device_buffer_manager;

  // COLLINEAR selects two spin sectors, NONCOLLINEAR two polarizations, and CLOSED a
  // factor of 2 in the vbias prefactor (see vbias).
  WALKER_TYPES walker_type;

  // Cholesky-vector bookkeeping exposed through the *_cholesky_vector(s)() accessors.
  // vbias expects an output with 2 * local_nCV rows.
  int global_nCV;
  int local_nCV;
  int global_origin;

  int default_buffer_size_in_MB;
  int last_nw;

  ValueType E0;

  // bare one body hamiltonian
  mpi3C3Tensor H1;

  // (potentially half rotated) one body hamiltonian
  shmCMatrix haj;
  //std::vector<shmCVector> haj;

  // number of orbitals per k-point
  boost::multi::array<int, 1> nopk;

  // number of cholesky vectors per Q-point
  boost::multi::array<int, 1> ncholpQ;

  // position of (-K) in kp-list for every K
  boost::multi::array<int, 1> kminus;

  // number of electrons per k-point
  // nelpk[ndet][nspin*nkpts]
  //shmIMatrix nelpk;
  boost::multi::array<int, 2> nelpk;

  // maps (Q,K) --> k2
  //shmIMatrix QKToK2;
  boost::multi::array<int, 2> QKToK2;

  //Cholesky Tensor Lik[Q][nk][i][k][n]
  std::vector<shmSpMatrix> LQKikn;

  // half-transformed Cholesky tensor
  std::vector<LQKankMatrix> LQKank;
  const bool needs_copy;

  // half-transformed Cholesky tensor
  // indexed in vbias as [nd * nspin * nkpts + spin * nkpts + Q]
  std::vector<shmSpMatrix> LQKakn;

  // half-transformed Cholesky tensor
  std::vector<shmSpMatrix> LQKbnl;

  // half-transformed Cholesky tensor
  // indexed in vbias as [nd * nspin * number_of_symmetric_Q + spin * number_of_symmetric_Q + Qmap[Q] - 1]
  std::vector<shmSpMatrix> LQKbln;

  // number of Q vectors that satisfy Q==-Q
  int number_of_symmetric_Q;

  // number of Q points assigned to this task
  int number_of_Q_points;

  // Defines behavior over Q vector:
  //   <0: Ignore (handled by another TG)
  //    0: Calculate, without rho^+ contribution
  //   >0: Calculate, with rho^+ contribution. LQKbln data located at Qmap[Q]-1
  stdIVector Qmap;

  // maps Q (only for those with Qmap >=0) to the corresponding sector in vbias
  stdIVector Q2vbias;

  // one-body piece of Hamiltonian factorization
  mpi3C3Tensor vn0;

  int nsampleQ;
  std::vector<RealType> gQ;
  boost::multi::array<int, 2> Qwn;
  std::default_random_engine generator;
  std::discrete_distribution<int> distribution;

  // Integer tables handed to the ma:: kernels through origin() (see the private helpers below).
  // NOTE(review): the dev_* arrays look like device-side copies of the host arrays with the
  // matching names, and dev_i0pk / dev_a0pk like per-k-point starting offsets -- confirm in
  // the constructor.
  IMatrix KKTransID;
  IVector dev_nopk;
  IVector dev_i0pk;
  IVector dev_kminus;
  IVector dev_ncholpQ;
  IVector dev_Q2vbias;
  IVector dev_Qmap;
  IMatrix dev_nelpk;
  IMatrix dev_a0pk;
  IMatrix dev_QKToK2;

  //    std::vector<std::unique_ptr<shared_mutex>> mutex;

  //    boost::multi::array<ComplexType,3> Qave;
  //    int cntQave=0;
  std::vector<ComplexType> EQ;
  //    std::default_random_engine generator;
  //    std::uniform_real_distribution<RealType> distribution(RealType(0.0),Realtype(1.0));
1584 
1585   template<class MatA, class MatB, class IVec, class IVec2>
1586   void GKaKjw_to_GKKwaj(MatA const& GKaKj, MatB&& GKKaj, IVec&& nocc, IVec2&& dev_no, IVec2&& dev_a0)
1587   {
1588     int npol    = (walker_type == NONCOLLINEAR) ? 2 : 1;
1589     int nmo_max = *std::max_element(nopk.begin(), nopk.end());
1590     //      int nocc_max = *std::max_element(nocc.begin(),nocc.end());
1591     int nmo_tot = GKaKj.size(1);
1592     int nwalk   = GKaKj.size(2);
1593     int nkpts   = nopk.size();
1594     assert(GKKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * npol * nmo_max);
1595 
1596     using ma::KaKjw_to_KKwaj;
1597     KaKjw_to_KKwaj(nwalk, nkpts, npol, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(),
1598                    dev_no.origin(), dev_a0.origin(), GKaKj.origin(), GKKaj.origin());
1599   }
1600 
1601   template<class MatA, class MatB, class IVec, class IVec2>
1602   void GKaKjw_to_GQKajw(MatA const& GKaKj, MatB&& GQKaj, IVec&& nocc, IVec2&& dev_no, IVec2&& dev_a0)
1603   {
1604     int npol    = (walker_type == NONCOLLINEAR) ? 2 : 1;
1605     int nmo_max = *std::max_element(nopk.begin(), nopk.end());
1606     //      int nocc_max = *std::max_element(nocc.begin(),nocc.end());
1607     int nmo_tot = GKaKj.size(1);
1608     int nwalk   = GKaKj.size(2);
1609     int nkpts   = nopk.size();
1610     assert(GQKaj.num_elements() >= nkpts * nkpts * nwalk * nocc_max * npol * nmo_max);
1611 
1612     using ma::KaKjw_to_QKajw;
1613     KaKjw_to_QKajw(nwalk, nkpts, npol, nmo_max, nmo_tot, nocc_max, dev_nopk.origin(), dev_i0pk.origin(),
1614                    dev_no.origin(), dev_a0.origin(), dev_QKToK2.origin(), GKaKj.origin(), GQKaj.origin());
1615   }
1616 
1617 
1618   /*
1619      *   vKiKj({nwalk,nmo_tot,nmo_tot});
1620      *   vKK({nkpts,nkpts,nwalk*nmo_max*nmo_max} );
1621      */
1622   template<class MatA, class MatB>
1623   void vKKwij_to_vwKiKj(MatA const& vKK, MatB&& vKiKj)
1624   {
1625     int nmo_max = *std::max_element(nopk.begin(), nopk.end());
1626     int nwalk   = vKiKj.size(0);
1627     int nmo_tot = vKiKj.size(1);
1628     int nkpts   = nopk.size();
1629 
1630     using ma::vKKwij_to_vwKiKj;
1631     vKKwij_to_vwKiKj(nwalk, nkpts, nmo_max, nmo_tot, KKTransID.origin(), dev_nopk.origin(), dev_i0pk.origin(),
1632                      vKK.origin(), vKiKj.origin());
1633   }
1634 
1635   template<class MatA, class MatB>
1636   void vbias_from_v1(ComplexType a, MatA const& v1, MatB&& vbias)
1637   {
1638     using BType   = typename std::decay<MatB>::type::element;
1639     int nwalk     = vbias.size(1);
1640     int nkpts     = nopk.size();
1641     int nchol_max = *std::max_element(ncholpQ.begin(), ncholpQ.end());
1642 
1643     using ma::vbias_from_v1;
1644     // using make_device_ptr(vbias.origin()) to catch errors here
1645     vbias_from_v1(nwalk, nkpts, nchol_max, dev_Qmap.origin(), dev_kminus.origin(), dev_ncholpQ.origin(),
1646                   dev_Q2vbias.origin(), static_cast<BType>(a), v1.origin(),
1647                   to_address(make_device_ptr(vbias.origin())));
1648   }
1649 };
1650 
1651 } // namespace afqmc
1652 
1653 } // namespace qmcplusplus
1654 
1655 #endif
1656