1 /* Copyright 2014-2018 The PySCF Developers. All Rights Reserved.
2 
3    Licensed under the Apache License, Version 2.0 (the "License");
4     you may not use this file except in compliance with the License.
5     You may obtain a copy of the License at
6 
7         http://www.apache.org/licenses/LICENSE-2.0
8 
9     Unless required by applicable law or agreed to in writing, software
10     distributed under the License is distributed on an "AS IS" BASIS,
11     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12     See the License for the specific language governing permissions and
13     limitations under the License.
14 
15  *
16  * Author: Qiming Sun <osirpt.sun@gmail.com>
17  */
18 
19 #include <stdlib.h>
20 #include <math.h>
21 #include <assert.h>
22 //#define NDEBUG
23 
24 //#include <omp.h>
25 #include "config.h"
26 #include "cint.h"
27 #include "np_helper/np_helper.h"
28 #include "vhf/cvhf.h"
29 #include "vhf/fblas.h"
30 #include "vhf/nr_direct.h"
31 #include "nr_ao2mo.h"
32 
33 #define MIN(X,Y)        ((X) < (Y) ? (X) : (Y))
34 #define MAX(X,Y)        ((X) > (Y) ? (X) : (Y))
35 // 9f or 7g or 5h functions should be enough
36 #define NCTRMAX         64
37 #define OUTPUTIJ        1
38 #define INPUT_IJ        2
39 
40 /*
41  * Denoting 2e integrals (ij|kl),
42  * AO2MOnr_e1_drv transforms ij for ksh_start <= k shell < ksh_end.
43  * The transformation C_pi C_qj (pq|k*) coefficients are stored in
44  * mo_coeff, C_pi and C_qj are offset by i_start and i_count, j_start and j_count.
 * The output eri is a 2D array, ordered as (kl-AO-pair,ij-MO-pair) in
 * C-order.  Transposing is needed before calling AO2MOnr_e2_drv.
47  *
48  * AO2MOnr_e2_drv transforms kl for nijcount of ij pairs.
 * vin is assumed to be a C-array of (ij-MO-pair, kl-AO-pair)
 * vout is a C-array of (ij-MO-pair, kl-MO-pair)
51  *
52  * ftranse1 and ftranse2
53  * ---------------------
54  * AO2MOtranse1_nr_s4, AO2MOtranse1_nr_s2ij, AO2MOtranse1_nr_s2kl AO2MOtranse1_nr_s1
55  * AO2MOtranse2_nr_s4, AO2MOtranse2_nr_s2ij, AO2MOtranse2_nr_s2kl AO2MOtranse2_nr_s1
 * Labels s4, s2, s1 are used to label the AO integral symmetry.  The
 * symmetry of the transformed integrals is controlled by the function fmmm
58  *
59  * fmmm
60  * ----
61  * fmmm dim requirements:
62  *                      | vout                          | eri
63  * ---------------------+-------------------------------+-------------------
64  *  AO2MOmmm_nr_s2_s2   | [:,bra_count*(bra_count+1)/2] | [:,nao*(nao+1)/2]
65  *                      |    and bra_count==ket_count   |
66  *  AO2MOmmm_nr_s2_iltj | [:,bra_count*ket_count]       | [:,nao*nao]
67  *  AO2MOmmm_nr_s2_igtj | [:,bra_count*ket_count]       | [:,nao*nao]
68  *  AO2MOmmm_nr_s1_iltj | [:,bra_count*ket_count]       | [:,nao*nao]
69  *  AO2MOmmm_nr_s1_igtj | [:,bra_count*ket_count]       | [:,nao*nao]
70  *
71  * AO2MOmmm_nr_s1_iltj, AO2MOmmm_nr_s1_igtj, AO2MOmmm_nr_s2_s2,
72  * AO2MOmmm_nr_s2_iltj, AO2MOmmm_nr_s2_igtj
73  * Pick a proper function from the 5 kinds of AO2MO transformation.
74  * 1. AO integral I_ij != I_ji, use
75  *    AO2MOmmm_nr_s1_iltj or AO2MOmmm_nr_s1_igtj
76  * 2. AO integral I_ij == I_ji, but the MO coefficients for bra and ket
77  *    are different, use
78  *    AO2MOmmm_nr_s2_iltj or AO2MOmmm_nr_s2_igtj
79  * 3. AO integral I_ij == I_ji, and the MO coefficients are the same for
80  *    bra and ket, use
81  *    AO2MOmmm_nr_s2_s2
82  *
83  *      ftrans           |     allowed fmmm
84  * ----------------------+---------------------
85  *  AO2MOtranse1_nr_s4   |  AO2MOmmm_nr_s2_s2
86  *  AO2MOtranse1_nr_s2ij |  AO2MOmmm_nr_s2_iltj
87  *  AO2MOtranse2_nr_s4   |  AO2MOmmm_nr_s2_igtj
88  *  AO2MOtranse2_nr_s2kl |
89  * ----------------------+---------------------
90  *  AO2MOtranse1_nr_s2kl |  AO2MOmmm_nr_s2_s2
91  *  AO2MOtranse2_nr_s2ij |  AO2MOmmm_nr_s2_igtj
92  *                       |  AO2MOmmm_nr_s2_iltj
93  * ----------------------+---------------------
94  *  AO2MOtranse1_nr_s1   |  AO2MOmmm_nr_s1_iltj
95  *  AO2MOtranse2_nr_s1   |  AO2MOmmm_nr_s1_igtj
96  *
97  */
98 
99 
/*
 * Blocked product c = a^T * b that skips most of the region below the
 * shifted diagonal; intended for m > n.
 *
 * All matrices are Fortran (column-major) ordered:
 *   a is k x m  (leading dimension k)
 *   b is k x n  (leading dimension k)
 *   c is m x n  (leading dimension m)
 *
 * Rows of c are processed in panels.  The leading panel (the rows left over
 * after taking whole BLK-row panels out of m - diag_off) is computed for
 * every column.  Each following BLK-row panel starting at row r is computed
 * only for columns >= r - diag_off; entries of c to the left of that column
 * are never written.
 *   _        |------- n -------| _
 *   diag_off [ . . . . . . . . ] |
 *   _        [ . . . . . . . . ] m
 *            [   . . . . . . . ] |
 *            [     . . . . . . ] _
 */
void AO2MOdtriumm_o1(int m, int n, int k, int diag_off,
                     double *a, double *b, double *c)
{
        const char op_t = 'T';
        const char op_n = 'N';
        const double one = 1;
        const double zero = 0;
        const int BLK = 48;
        int npanel, ipanel;
        int row0, col0, ncol;

        // number of whole BLK-row panels that get the trimmed treatment
        npanel = (m - diag_off) / BLK;
        if (npanel < 0) {
                npanel = 0;
        }
        row0 = m - npanel * BLK;
        col0 = row0 - diag_off;

        // leading rows: full width
        dgemm_(&op_t, &op_n, &row0, &n, &k,
               &one, a, &k, b, &k, &zero, c, &m);

        // remaining panels: only the columns at or right of the diagonal
        for (ipanel = 0; ipanel < npanel; ipanel++) {
                ncol = n - col0;
                dgemm_(&op_t, &op_n, &BLK, &ncol, &k,
                       &one, a+row0*k, &k, b+col0*k, &k,
                       &zero, c+col0*m+row0, &m);
                row0 += BLK;
                col0 += BLK;
        }
}
131 
/*
 * Blocked product c = a^T * b that skips most of the region below the
 * shifted diagonal; intended for m < n.
 *
 * All matrices are Fortran (column-major) ordered:
 *   a is k x m  (leading dimension k)
 *   b is k x n  (leading dimension k)
 *   c is m x n  (leading dimension m)
 *
 * Columns of c are processed in panels of BLK columns.  The panel starting
 * at column c0 is computed for rows [0, diag_off + c0 + BLK) only; rows
 * below that are not written.  Once fewer than BLK rows would be skipped,
 * all remaining columns are computed for all m rows in a single call.
 *   _        |------- n -------| _
 *   diag_off [ . . . . . . . . ] |
 *   _        [ . . . . . . . . ] m
 *            [   . . . . . . . ] |
 *            [     . . . . . . ] _
 */
void AO2MOdtriumm_o2(int m, int n, int k, int diag_off,
                     double *a, double *b, double *c)
{
        const char op_t = 'T';
        const char op_n = 'N';
        const double one = 1;
        const double zero = 0;
        const int BLK = 48;
        const int col_limit = m - diag_off - BLK;
        int col0 = 0;
        int nrow = diag_off;
        int ntail;

        // trimmed panels: BLK columns wide, nrow rows tall
        while (col0 < col_limit) {
                nrow += BLK;
                dgemm_(&op_t, &op_n, &nrow, &BLK, &k,
                       &one, a, &k, b+col0*k, &k,
                       &zero, c+col0*m, &m);
                col0 += BLK;
        }

        // everything to the right: full height
        ntail = n - col0;
        dgemm_(&op_t, &op_n, &m, &ntail, &k,
               &one, a, &k, b+col0*k, &k,
               &zero, c+col0*m, &m);
}
162 
163 
164 /*
165  * s1-AO integrals to s1-MO integrals, efficient for i_count < j_count
166  * shape requirements:
167  *      vout[:,bra_count*ket_count], eri[:,nao*nao]
168  * s1, s2 here to label the AO symmetry
169  */
AO2MOmmm_nr_s1_iltj(double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs,int seekdim)170 int AO2MOmmm_nr_s1_iltj(double *vout, double *eri, double *buf,
171                         struct _AO2MOEnvs *envs, int seekdim)
172 {
173         switch (seekdim) {
174                 case OUTPUTIJ: return envs->bra_count * envs->ket_count;
175                 case INPUT_IJ: return envs->nao * envs->nao;
176         }
177         const double D0 = 0;
178         const double D1 = 1;
179         const char TRANS_T = 'T';
180         const char TRANS_N = 'N';
181         int nao = envs->nao;
182         int i_start = envs->bra_start;
183         int i_count = envs->bra_count;
184         int j_start = envs->ket_start;
185         int j_count = envs->ket_count;
186         double *mo_coeff = envs->mo_coeff;
187 
188         // C_pi (pq| = (iq|, where (pq| is in C-order
189         dgemm_(&TRANS_N, &TRANS_N, &nao, &i_count, &nao,
190                &D1, eri, &nao, mo_coeff+i_start*nao, &nao,
191                &D0, buf, &nao);
192         dgemm_(&TRANS_T, &TRANS_N, &j_count, &i_count, &nao,
193                &D1, mo_coeff+j_start*nao, &nao, buf, &nao,
194                &D0, vout, &j_count);
195         return 0;
196 }
197 /*
198  * s1-AO integrals to s1-MO integrals, efficient for i_count > j_count
199  * shape requirements:
200  *      vout[:,bra_count*ket_count], eri[:,nao*nao]
201  */
AO2MOmmm_nr_s1_igtj(double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs,int seekdim)202 int AO2MOmmm_nr_s1_igtj(double *vout, double *eri, double *buf,
203                         struct _AO2MOEnvs *envs, int seekdim)
204 {
205         switch (seekdim) {
206                 case OUTPUTIJ: return envs->bra_count * envs->ket_count;
207                 case INPUT_IJ: return envs->nao * envs->nao;
208         }
209         const double D0 = 0;
210         const double D1 = 1;
211         const char TRANS_T = 'T';
212         const char TRANS_N = 'N';
213         int nao = envs->nao;
214         int i_start = envs->bra_start;
215         int i_count = envs->bra_count;
216         int j_start = envs->ket_start;
217         int j_count = envs->ket_count;
218         double *mo_coeff = envs->mo_coeff;
219 
220         // C_qj (pq| = (pj|, where (pq| is in C-order
221         dgemm_(&TRANS_T, &TRANS_N, &j_count, &nao, &nao,
222                &D1, mo_coeff+j_start*nao, &nao, eri, &nao,
223                &D0, buf, &j_count);
224         dgemm_(&TRANS_N, &TRANS_N, &j_count, &i_count, &nao,
225                &D1, buf, &j_count, mo_coeff+i_start*nao, &nao,
226                &D0, vout, &j_count);
227         return 0;
228 }
229 
230 /*
231  * s2-AO integrals to s2-MO integrals
232  * shape requirements:
233  *      vout[:,bra_count*(bra_count+1)/2] and bra_count==ket_count,
234  *      eri[:,nao*(nao+1)/2]
235  * first s2 is the AO symmetry, second s2 is the MO symmetry
236  */
AO2MOmmm_nr_s2_s2(double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs,int seekdim)237 int AO2MOmmm_nr_s2_s2(double *vout, double *eri, double *buf,
238                       struct _AO2MOEnvs *envs, int seekdim)
239 {
240         switch (seekdim) {
241                 case OUTPUTIJ: assert(envs->bra_count == envs->ket_count);
242                                return envs->bra_count * (envs->bra_count+1) / 2;
243                 case INPUT_IJ: return envs->nao * (envs->nao+1) / 2;
244         }
245         const double D0 = 0;
246         const double D1 = 1;
247         const char SIDE_L = 'L';
248         const char UPLO_U = 'U';
249         int nao = envs->nao;
250         int i_start = envs->bra_start;
251         int i_count = envs->bra_count;
252         int j_start = envs->ket_start;
253         int j_count = envs->ket_count;
254         double *mo_coeff = envs->mo_coeff;
255         double *buf1 = buf + nao*i_count;
256         int i, j, ij;
257 
258         // C_pi (pq| = (iq|, where (pq| is in C-order
259         dsymm_(&SIDE_L, &UPLO_U, &nao, &i_count,
260                &D1, eri, &nao, mo_coeff+i_start*nao, &nao,
261                &D0, buf, &nao);
262         AO2MOdtriumm_o1(j_count, i_count, nao, 0,
263                         mo_coeff+j_start*nao, buf, buf1);
264 
265         for (i = 0, ij = 0; i < i_count; i++) {
266                 for (j = 0; j <= i; j++, ij++) {
267                         vout[ij] = buf1[j];
268                 }
269                 buf1 += j_count;
270         }
271         return 0;
272 }
273 
274 /*
275  * s2-AO integrals to s1-MO integrals, efficient for i_count < j_count
276  * shape requirements:
277  *      vout[:,bra_count*ket_count], eri[:,nao*(nao+1)/2]
278  */
AO2MOmmm_nr_s2_iltj(double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs,int seekdim)279 int AO2MOmmm_nr_s2_iltj(double *vout, double *eri, double *buf,
280                         struct _AO2MOEnvs *envs, int seekdim)
281 {
282         switch (seekdim) {
283                 case OUTPUTIJ: return envs->bra_count * envs->ket_count;
284                 case INPUT_IJ: return envs->nao * (envs->nao+1) / 2;
285         }
286         const double D0 = 0;
287         const double D1 = 1;
288         const char SIDE_L = 'L';
289         const char UPLO_U = 'U';
290         const char TRANS_T = 'T';
291         const char TRANS_N = 'N';
292         int nao = envs->nao;
293         int i_start = envs->bra_start;
294         int i_count = envs->bra_count;
295         int j_start = envs->ket_start;
296         int j_count = envs->ket_count;
297         double *mo_coeff = envs->mo_coeff;
298 
299         // C_pi (pq| = (iq|, where (pq| is in C-order
300         dsymm_(&SIDE_L, &UPLO_U, &nao, &i_count,
301                &D1, eri, &nao, mo_coeff+i_start*nao, &nao,
302                &D0, buf, &nao);
303         // C_qj (iq| = (ij|
304         dgemm_(&TRANS_T, &TRANS_N, &j_count, &i_count, &nao,
305                &D1, mo_coeff+j_start*nao, &nao, buf, &nao,
306                &D0, vout, &j_count);
307         return 0;
308 }
309 
310 /*
311  * s2-AO integrals to s1-MO integrals, efficient for i_count > j_count
312  * shape requirements:
313  *      vout[:,bra_count*ket_count], eri[:,nao*(nao+1)/2]
314  */
AO2MOmmm_nr_s2_igtj(double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs,int seekdim)315 int AO2MOmmm_nr_s2_igtj(double *vout, double *eri, double *buf,
316                         struct _AO2MOEnvs *envs, int seekdim)
317 {
318         switch (seekdim) {
319                 case OUTPUTIJ: return envs->bra_count * envs->ket_count;
320                 case INPUT_IJ: return envs->nao * (envs->nao+1) / 2;
321         }
322         const double D0 = 0;
323         const double D1 = 1;
324         const char SIDE_L = 'L';
325         const char UPLO_U = 'U';
326         const char TRANS_T = 'T';
327         const char TRANS_N = 'N';
328         int nao = envs->nao;
329         int i_start = envs->bra_start;
330         int i_count = envs->bra_count;
331         int j_start = envs->ket_start;
332         int j_count = envs->ket_count;
333         double *mo_coeff = envs->mo_coeff;
334 
335         // C_qj (pq| = (pj|, where (pq| is in C-order
336         dsymm_(&SIDE_L, &UPLO_U, &nao, &j_count,
337                &D1, eri, &nao, mo_coeff+j_start*nao, &nao,
338                &D0, buf, &nao);
339         // C_pi (pj| = (ij|
340         dgemm_(&TRANS_T, &TRANS_N, &j_count, &i_count, &nao,
341                &D1, buf, &nao, mo_coeff+i_start*nao, &nao,
342                &D0, vout, &j_count);
343         return 0;
344 }
345 
346 /*
347  * transform bra, s1 to label AO symmetry
348  */
AO2MOmmm_bra_nr_s1(double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs,int seekdim)349 int AO2MOmmm_bra_nr_s1(double *vout, double *vin, double *buf,
350                        struct _AO2MOEnvs *envs, int seekdim)
351 {
352         switch (seekdim) {
353                 case 1: return envs->bra_count * envs->nao;
354                 case 2: return envs->nao * envs->nao;
355         }
356         const double D0 = 0;
357         const double D1 = 1;
358         const char TRANS_N = 'N';
359         int nao = envs->nao;
360         int i_start = envs->bra_start;
361         int i_count = envs->bra_count;
362         double *mo_coeff = envs->mo_coeff;
363 
364         dgemm_(&TRANS_N, &TRANS_N, &nao, &i_count, &nao,
365                &D1, vin, &nao, mo_coeff+i_start*nao, &nao,
366                &D0, vout, &nao);
367         return 0;
368 }
369 
370 /*
371  * transform ket, s1 to label AO symmetry
372  */
AO2MOmmm_ket_nr_s1(double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs,int seekdim)373 int AO2MOmmm_ket_nr_s1(double *vout, double *vin, double *buf,
374                        struct _AO2MOEnvs *envs, int seekdim)
375 {
376         switch (seekdim) {
377                 case OUTPUTIJ: return envs->nao * envs->ket_count;
378                 case INPUT_IJ: return envs->nao * envs->nao;
379         }
380         const double D0 = 0;
381         const double D1 = 1;
382         const char TRANS_T = 'T';
383         const char TRANS_N = 'N';
384         int nao = envs->nao;
385         int j_start = envs->ket_start;
386         int j_count = envs->ket_count;
387         double *mo_coeff = envs->mo_coeff;
388 
389         dgemm_(&TRANS_T, &TRANS_N, &j_count, &nao, &nao,
390                &D1, mo_coeff+j_start*nao, &nao, vin, &nao,
391                &D0, vout, &j_count);
392         return 0;
393 }
394 
395 /*
396  * transform bra, s2 to label AO symmetry
397  */
AO2MOmmm_bra_nr_s2(double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs,int seekdim)398 int AO2MOmmm_bra_nr_s2(double *vout, double *vin, double *buf,
399                        struct _AO2MOEnvs *envs, int seekdim)
400 {
401         switch (seekdim) {
402                 case OUTPUTIJ: return envs->bra_count * envs->nao;
403                 case INPUT_IJ: return envs->nao * (envs->nao+1) / 2;
404         }
405         const double D0 = 0;
406         const double D1 = 1;
407         const char SIDE_L = 'L';
408         const char UPLO_U = 'U';
409         int nao = envs->nao;
410         int i_start = envs->bra_start;
411         int i_count = envs->bra_count;
412         double *mo_coeff = envs->mo_coeff;
413 
414         dsymm_(&SIDE_L, &UPLO_U, &nao, &i_count,
415                &D1, vin, &nao, mo_coeff+i_start*nao, &nao,
416                &D0, vout, &nao);
417         return 0;
418 }
419 
420 /*
421  * transform ket, s2 to label AO symmetry
422  */
AO2MOmmm_ket_nr_s2(double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs,int seekdim)423 int AO2MOmmm_ket_nr_s2(double *vout, double *vin, double *buf,
424                        struct _AO2MOEnvs *envs, int seekdim)
425 {
426         switch (seekdim) {
427                 case OUTPUTIJ: return envs->nao * envs->ket_count;
428                 case INPUT_IJ: return envs->nao * (envs->nao+1) / 2;
429         }
430         const double D0 = 0;
431         const double D1 = 1;
432         const char SIDE_L = 'L';
433         const char UPLO_U = 'U';
434         int nao = envs->nao;
435         int j_start = envs->ket_start;
436         int j_count = envs->ket_count;
437         double *mo_coeff = envs->mo_coeff;
438         int i, j;
439 
440         dsymm_(&SIDE_L, &UPLO_U, &nao, &j_count,
441                &D1, vin, &nao, mo_coeff+j_start*nao, &nao,
442                &D0, buf, &nao);
443         for (j = 0; j < nao; j++) {
444                 for (i = 0; i < j_count; i++) {
445                         vout[i] = buf[i*nao+j];
446                 }
447                 vout += j_count;
448         }
449         return 0;
450 }
451 
452 
453 /*
454  * s1, s2ij, s2kl, s4 here to label the AO symmetry
455  * eris[ncomp,nkl,nao_pair_ij]
456  */
/*
 * Scatter one shell-quartet block of integrals into triangular-packed rows.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * For every (k,l), visited with l running fastest, one output row of length
 * nao2 is addressed, starting at eri and advancing by nao2 per (k,l).
 * Inside a row the i-th sub-row starts istride + (i-1) elements after the
 * previous one (offsets 0, istride, 2*istride+1, ...), and receives the dj
 * values j = 0..dj-1.
 *
 * di = 1, 2, 3 are unrolled by hand; other values take the generic loop.
 */
static void s4_copy(double *eri, double *ints, int di, int dj, int dk, int dl,
                    int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;
        double *src, *row0, *row1, *row2;

        if (di == 1) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        src = ints + dij * (l*dk+k);
                        for (j = 0; j < dj; j++) {
                                eri[j] = src[j];
                        }
                } }
        } else if (di == 2) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        src = ints + dij * (l*dk+k);
                        row0 = eri;
                        row1 = eri + istride;
                        for (j = 0; j < dj; j++) {
                                row0[j] = src[j*2+0];
                                row1[j] = src[j*2+1];
                        }
                } }
        } else if (di == 3) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        src = ints + dij * (l*dk+k);
                        row0 = eri;
                        row1 = row0 + istride;
                        row2 = row1 + istride + 1;
                        for (j = 0; j < dj; j++) {
                                row0[j] = src[j*3+0];
                                row1[j] = src[j*3+1];
                                row2[j] = src[j*3+2];
                        }
                } }
        } else {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        src = ints + dij * (l*dk+k);
                        row0 = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j < dj; j++) {
//TODO: call nontemporal write to avoid write-allocate
                                        row0[j] = src[j*di+i];
                                }
                                // triangular packing: each sub-row is one longer
                                row0 += istride + i;
                        }
                } }
        }
}
/*
 * Zero exactly the elements that s4_copy would write for the same
 * (di, dj, dk, dl, istride, nao2).  The second argument is ignored; it only
 * keeps the signature identical to the copy routine.
 *
 * For every (k,l) one output row of length nao2 is addressed; inside it the
 * i-th sub-row starts istride + (i-1) elements after the previous one and
 * dj elements of it are cleared.
 */
static void s4_set0(double *eri, double *nop,
                    int di, int dj, int dk, int dl,
                    int istride, size_t nao2)
{
        int i, j, k, l;
        double *row0, *row1, *row2;

        if (di == 1) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        for (j = 0; j < dj; j++) {
                                eri[j] = 0;
                        }
                } }
        } else if (di == 2) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        row0 = eri;
                        row1 = eri + istride;
                        for (j = 0; j < dj; j++) {
                                row0[j] = 0;
                                row1[j] = 0;
                        }
                } }
        } else if (di == 3) {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        row0 = eri;
                        row1 = row0 + istride;
                        row2 = row1 + istride + 1;
                        for (j = 0; j < dj; j++) {
                                row0[j] = 0;
                                row1[j] = 0;
                                row2[j] = 0;
                        }
                } }
        } else {
                for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        row0 = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j < dj; j++) {
//TODO: call nontemporal write to avoid write-allocate
                                        row0[j] = 0;
                                }
                                // triangular packing: each sub-row is one longer
                                row0 += istride + i;
                        }
                } }
        }
}
570 
/*
 * Same scatter as the generic path of s4_copy, but only the pairs l <= k
 * are visited (order (0,0), (1,0), (1,1), ...); dl is not used.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * Each visited (k,l) fills one output row of length nao2; inside it the
 * i-th sub-row starts istride + (i-1) elements after the previous one and
 * receives the dj values j = 0..dj-1.
 */
static void s4_copy_keql(double *eri, double *ints,
                         int di, int dj, int dk, int dl,
                         int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        const double *src = ints + dij * (l*dk+k);
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j < dj; j++) {
                                        row[j] = src[j*di+i];
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Zero exactly the elements that s4_copy_keql would write: for every pair
 * l <= k one output row of length nao2 is addressed, and dj elements of
 * each of its di triangular-packed sub-rows are cleared.  The second
 * argument and dl are not used.
 */
static void s4_set0_keql(double *eri, double *nop,
                         int di, int dj, int dk, int dl,
                         int istride, size_t nao2)
{
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j < dj; j++) {
                                        row[j] = 0;
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Scatter for the diagonal shell pair of the bra: only j <= i is copied.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * Every (k,l), with l running fastest, fills one output row of length nao2;
 * inside it the i-th sub-row starts istride + (i-1) elements after the
 * previous one and receives the i+1 values j = 0..i.
 */
static void s4_copy_ieqj(double *eri, double *ints,
                         int di, int dj, int dk, int dl,
                         int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        const double *src = ints + dij * (l*dk+k);
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j <= i; j++) {
                                        row[j] = src[j*di+i];
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Zero exactly the elements that s4_copy_ieqj would write: for every (k,l)
 * one output row of length nao2 is addressed, and the first i+1 elements of
 * its i-th triangular-packed sub-row are cleared.  The second argument and
 * dj are not used.
 */
static void s4_set0_ieqj(double *eri, double *nop,
                         int di, int dj, int dk, int dl,
                         int istride, size_t nao2)
{
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j <= i; j++) {
                                        row[j] = 0;
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Scatter for the case where both shell pairs are diagonal: only l <= k
 * pairs are visited and only j <= i is copied; dl is not used.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * Each visited (k,l) fills one output row of length nao2; inside it the
 * i-th sub-row starts istride + (i-1) elements after the previous one and
 * receives the i+1 values j = 0..i.
 */
static void s4_copy_keql_ieqj(double *eri, double *ints,
                              int di, int dj, int dk, int dl,
                              int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        const double *src = ints + dij * (l*dk+k);
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j <= i; j++) {
                                        row[j] = src[j*di+i];
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Zero exactly the elements that s4_copy_keql_ieqj would write: for every
 * pair l <= k one output row of length nao2 is addressed, and the first
 * i+1 elements of its i-th triangular-packed sub-row are cleared.  The
 * second argument, dj and dl are not used.
 */
static void s4_set0_keql_ieqj(double *eri, double *nop,
                              int di, int dj, int dk, int dl,
                              int istride, size_t nao2)
{
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        double *row = eri;
                        for (i = 0; i < di; i++) {
                                for (j = 0; j <= i; j++) {
                                        row[j] = 0;
                                }
                                row += istride + i;
                        }
                }
        }
}
/*
 * Scatter into rectangular (non-packed) rows, visiting only l <= k pairs;
 * dl is not used.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * Each visited (k,l) fills one output row of length nao2, in which element
 * (i,j) is stored at offset i*istride + j.
 */
static void s2kl_copy_keql(double *eri, double *ints,
                           int di, int dj, int dk, int dl,
                           int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        const double *src = ints + dij * (l*dk+k);
                        for (i = 0; i < di; i++) {
                                double *row = eri + i*istride;
                                for (j = 0; j < dj; j++) {
                                        row[j] = src[j*di+i];
                                }
                        }
                }
        }
}
/*
 * Zero exactly the elements that s2kl_copy_keql would write: for every pair
 * l <= k one output row of length nao2 is addressed and the di x dj
 * elements at offsets i*istride + j are cleared.  The second argument and
 * dl are not used.
 */
static void s2kl_set0_keql(double *eri, double *nop,
                           int di, int dj, int dk, int dl,
                           int istride, size_t nao2)
{
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l <= k; l++, eri += nao2) {
                        for (i = 0; i < di; i++) {
                                double *row = eri + i*istride;
                                for (j = 0; j < dj; j++) {
                                        row[j] = 0;
                                }
                        }
                }
        }
}
/*
 * Scatter one shell-quartet block into rectangular (non-packed) rows.
 *
 * ints holds element (i,j,k,l) at ints[i + di*(j + dj*(k + dk*l))].
 * Every (k,l), with l running fastest, fills one output row of length nao2,
 * in which element (i,j) is stored at offset i*istride + j.
 */
static void s1_copy(double *eri, double *ints,
                    int di, int dj, int dk, int dl,
                    int istride, size_t nao2)
{
        const int dij = di * dj;
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        const double *src = ints + dij * (l*dk+k);
                        for (i = 0; i < di; i++) {
                                double *row = eri + i*istride;
                                for (j = 0; j < dj; j++) {
                                        row[j] = src[j*di+i];
                                }
                        }
                }
        }
}
/*
 * Zero exactly the elements that s1_copy would write: for every (k,l) one
 * output row of length nao2 is addressed and the di x dj elements at
 * offsets i*istride + j are cleared.  The second argument is not used.
 */
static void s1_set0(double *eri, double *nop,
                    int di, int dj, int dk, int dl,
                    int istride, size_t nao2)
{
        int i, j, k, l;

        for (k = 0; k < dk; k++) {
                for (l = 0; l < dl; l++, eri += nao2) {
                        for (i = 0; i < di; i++) {
                                double *row = eri + i*istride;
                                for (j = 0; j < dj; j++) {
                                        row[j] = 0;
                                }
                        }
                }
        }
}
746 
/*
 * Evaluate the integrals of the shell quartet shls[] and distribute them
 * into eri with fcopy, or clear the same region with fset0 when the quartet
 * is screened out (fprescreen returns 0) or intor returns 0.
 *
 * The macro expands to an if/else statement and relies on these names being
 * in scope at the point of use: fprescreen, intor, shls, envs, buf, pbuf,
 * peri, eri, icomp, nao2, nkl, ioff, ao_loc, jsh, di, dj, dk, dl.
 *
 * The second argument of fset0 is never read by the set0 routines; buf is
 * passed there because pbuf has not been assigned yet if no earlier quartet
 * went through the copy branch.
 */
#define DISTR_INTS_BY(fcopy, fset0, istride) \
        if ((*fprescreen)(shls, envs->vhfopt, envs->atm, envs->bas, envs->env) && \
            (*intor)(buf, NULL, shls, envs->atm, envs->natm, \
                     envs->bas, envs->nbas, envs->env, envs->cintopt, NULL)) { \
                pbuf = buf; \
                for (icomp = 0; icomp < envs->ncomp; icomp++) { \
                        peri = eri + nao2 * nkl * icomp + ioff + ao_loc[jsh]; \
                        fcopy(peri, pbuf, di, dj, dk, dl, istride, nao2); \
                        pbuf += di * dj * dk * dl; \
                } \
        } else { \
                for (icomp = 0; icomp < envs->ncomp; icomp++) { \
                        peri = eri + nao2 * nkl * icomp + ioff + ao_loc[jsh]; \
                        fset0(peri, buf, di, dj, dk, dl, istride, nao2); \
                } \
        }
763 
AO2MOfill_nr_s1(int (* intor)(),int (* fprescreen)(),double * eri,double * buf,int nkl,int ish,struct _AO2MOEnvs * envs)764 void AO2MOfill_nr_s1(int (*intor)(), int (*fprescreen)(),
765                      double *eri, double *buf,
766                      int nkl, int ish, struct _AO2MOEnvs *envs)
767 {
768         const int nao = envs->nao;
769         const size_t nao2 = nao * nao;
770         const int *ao_loc = envs->ao_loc;
771         const int klsh_start = envs->klsh_start;
772         const int klsh_end = klsh_start + envs->klsh_count;
773         const int di = ao_loc[ish+1] - ao_loc[ish];
774         const int ioff = ao_loc[ish] * nao;
775         int kl, jsh, ksh, lsh, dj, dk, dl;
776         int icomp;
777         int shls[4];
778         double *pbuf, *peri;
779 
780         shls[0] = ish;
781 
782         for (kl = klsh_start; kl < klsh_end; kl++) {
783                 // kl = k * (k+1) / 2 + l
784                 ksh = kl / envs->nbas;
785                 lsh = kl - ksh * envs->nbas;
786                 dk = ao_loc[ksh+1] - ao_loc[ksh];
787                 dl = ao_loc[lsh+1] - ao_loc[lsh];
788                 shls[2] = ksh;
789                 shls[3] = lsh;
790 
791                 for (jsh = 0; jsh < envs->nbas; jsh++) {
792                         dj = ao_loc[jsh+1] - ao_loc[jsh];
793                         shls[1] = jsh;
794                         DISTR_INTS_BY(s1_copy, s1_set0, nao);
795                 }
796                 eri += nao2 * dk * dl;
797         }
798 }
799 
AO2MOfill_nr_s2ij(int (* intor)(),int (* fprescreen)(),double * eri,double * buf,int nkl,int ish,struct _AO2MOEnvs * envs)800 void AO2MOfill_nr_s2ij(int (*intor)(), int (*fprescreen)(),
801                        double *eri, double *buf,
802                        int nkl, int ish, struct _AO2MOEnvs *envs)
803 {
804         const int nao = envs->nao;
805         const size_t nao2 = nao * (nao+1) / 2;
806         const int *ao_loc = envs->ao_loc;
807         const int klsh_start = envs->klsh_start;
808         const int klsh_end = klsh_start + envs->klsh_count;
809         const int di = ao_loc[ish+1] - ao_loc[ish];
810         const int ioff = ao_loc[ish] * (ao_loc[ish]+1) / 2;
811         int kl, jsh, ksh, lsh, dj, dk, dl;
812         int icomp;
813         int shls[4];
814         double *pbuf = buf;
815         double *peri;
816 
817         shls[0] = ish;
818 
819         for (kl = klsh_start; kl < klsh_end; kl++) {
820                 // kl = k * (k+1) / 2 + l
821                 ksh = kl / envs->nbas;
822                 lsh = kl - ksh * envs->nbas;
823                 dk = ao_loc[ksh+1] - ao_loc[ksh];
824                 dl = ao_loc[lsh+1] - ao_loc[lsh];
825                 shls[2] = ksh;
826                 shls[3] = lsh;
827 
828                 for (jsh = 0; jsh < ish; jsh++) {
829                         dj = ao_loc[jsh+1] - ao_loc[jsh];
830                         shls[1] = jsh;
831                         DISTR_INTS_BY(s4_copy, s4_set0, ao_loc[ish]+1);
832                 }
833 
834                 jsh = ish;
835                 dj = di;
836                 shls[1] = jsh;
837                 DISTR_INTS_BY(s4_copy_ieqj, s4_set0_ieqj, ao_loc[ish]+1);
838                 eri += nao2 * dk * dl;
839         }
840 }
841 
AO2MOfill_nr_s2kl(int (* intor)(),int (* fprescreen)(),double * eri,double * buf,int nkl,int ish,struct _AO2MOEnvs * envs)842 void AO2MOfill_nr_s2kl(int (*intor)(), int (*fprescreen)(),
843                        double *eri, double *buf,
844                        int nkl, int ish, struct _AO2MOEnvs *envs)
845 {
846         const int nao = envs->nao;
847         const size_t nao2 = nao * nao;
848         const int *ao_loc = envs->ao_loc;
849         const int klsh_start = envs->klsh_start;
850         const int klsh_end = klsh_start + envs->klsh_count;
851         const int di = ao_loc[ish+1] - ao_loc[ish];
852         const int ioff = ao_loc[ish] * nao;
853         int kl, jsh, ksh, lsh, dj, dk, dl;
854         int icomp;
855         int shls[4];
856         double *pbuf = buf;
857         double *peri;
858 
859         shls[0] = ish;
860 
861         for (kl = klsh_start; kl < klsh_end; kl++) {
862 
863         // kl = k * (k+1) / 2 + l
864         ksh = (int)(sqrt(2*kl+.25) - .5 + 1e-7);
865         lsh = kl - ksh * (ksh+1) / 2;
866         dk = ao_loc[ksh+1] - ao_loc[ksh];
867         dl = ao_loc[lsh+1] - ao_loc[lsh];
868         shls[2] = ksh;
869         shls[3] = lsh;
870 
871         if (ksh == lsh) {
872                 for (jsh = 0; jsh < envs->nbas; jsh++) {
873                         dj = ao_loc[jsh+1] - ao_loc[jsh];
874                         shls[1] = jsh;
875                         DISTR_INTS_BY(s2kl_copy_keql, s2kl_set0_keql, nao);
876                 }
877                 eri += nao2 * dk*(dk+1)/2;
878 
879         } else {
880 
881                 for (jsh = 0; jsh < envs->nbas; jsh++) {
882                         dj = ao_loc[jsh+1] - ao_loc[jsh];
883                         shls[1] = jsh;
884                         DISTR_INTS_BY(s1_copy, s1_set0, nao);
885                 }
886                 eri += nao2 * dk * dl;
887         } }
888 }
889 
AO2MOfill_nr_s4(int (* intor)(),int (* fprescreen)(),double * eri,double * buf,int nkl,int ish,struct _AO2MOEnvs * envs)890 void AO2MOfill_nr_s4(int (*intor)(), int (*fprescreen)(),
891                      double *eri, double *buf,
892                      int nkl, int ish, struct _AO2MOEnvs *envs)
893 {
894         const int nao = envs->nao;
895         const size_t nao2 = nao * (nao+1) / 2;
896         const int *ao_loc = envs->ao_loc;
897         const int klsh_start = envs->klsh_start;
898         const int klsh_end = klsh_start + envs->klsh_count;
899         const int di = ao_loc[ish+1] - ao_loc[ish];
900         const int ioff = ao_loc[ish] * (ao_loc[ish]+1) / 2;
901         int kl, jsh, ksh, lsh, dj, dk, dl;
902         int icomp;
903         int shls[4];
904         double *pbuf = buf;
905         double *peri;
906 
907         shls[0] = ish;
908 
909         for (kl = klsh_start; kl < klsh_end; kl++) {
910 
911         // kl = k * (k+1) / 2 + l
912         ksh = (int)(sqrt(2*kl+.25) - .5 + 1e-7);
913         lsh = kl - ksh * (ksh+1) / 2;
914         dk = ao_loc[ksh+1] - ao_loc[ksh];
915         dl = ao_loc[lsh+1] - ao_loc[lsh];
916         shls[2] = ksh;
917         shls[3] = lsh;
918 
919         if (ksh == lsh) {
920                 for (jsh = 0; jsh < ish; jsh++) {
921                         dj = ao_loc[jsh+1] - ao_loc[jsh];
922                         shls[1] = jsh;
923                         DISTR_INTS_BY(s4_copy_keql, s4_set0_keql,
924                                       ao_loc[ish]+1);
925                 }
926 
927                 jsh = ish;
928                 dj = di;
929                 shls[1] = ish;
930                 DISTR_INTS_BY(s4_copy_keql_ieqj, s4_set0_keql_ieqj,
931                               ao_loc[ish]+1);
932                 eri += nao2 * dk*(dk+1)/2;
933 
934         } else {
935 
936                 for (jsh = 0; jsh < ish; jsh++) {
937                         dj = ao_loc[jsh+1] - ao_loc[jsh];
938                         shls[1] = jsh;
939                         DISTR_INTS_BY(s4_copy, s4_set0, ao_loc[ish]+1);
940                 }
941 
942                 jsh = ish;
943                 dj = di;
944                 shls[1] = ish;
945                 DISTR_INTS_BY(s4_copy_ieqj, s4_set0_ieqj, ao_loc[ish]+1);
946                 eri += nao2 * dk * dl;
947         } }
948 }
949 
950 /*
951  * ************************************************
952  * s1, s2ij, s2kl, s4 here to label the AO symmetry
953  */
AO2MOtranse1_nr_s1(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)954 void AO2MOtranse1_nr_s1(int (*fmmm)(), int row_id,
955                         double *vout, double *vin, double *buf,
956                         struct _AO2MOEnvs *envs)
957 {
958         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
959         size_t nao2 = envs->nao * envs->nao;
960         (*fmmm)(vout+ij_pair*row_id, vin+nao2*row_id, buf, envs, 0);
961 }
962 
AO2MOtranse1_nr_s2ij(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)963 void AO2MOtranse1_nr_s2ij(int (*fmmm)(), int row_id,
964                           double *vout, double *vin, double *buf,
965                           struct _AO2MOEnvs *envs)
966 {
967         int nao = envs->nao;
968         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
969         size_t nao2 = nao*(nao+1)/2;
970         NPdunpack_tril(nao, vin+nao2*row_id, buf, 0);
971         (*fmmm)(vout+ij_pair*row_id, buf, buf+nao*nao, envs, 0);
972 }
/* Alias: forwards every argument unchanged to AO2MOtranse1_nr_s2ij. */
void AO2MOtranse1_nr_s2(int (*fmmm)(), int row_id,
                        double *vout, double *vin,
                        double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse1_nr_s2ij(fmmm, row_id,
                             vout, vin, buf, envs);
}
979 
/*
 * The ij AO index is not packed in the s2kl case, so the row is handled
 * exactly like s1: all arguments are forwarded to AO2MOtranse1_nr_s1.
 */
void AO2MOtranse1_nr_s2kl(int (*fmmm)(), int row_id,
                          double *vout, double *vin,
                          double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse1_nr_s1(fmmm, row_id,
                           vout, vin, buf, envs);
}
986 
/* Alias: forwards every argument unchanged to AO2MOtranse1_nr_s2ij. */
void AO2MOtranse1_nr_s4(int (*fmmm)(), int row_id,
                        double *vout, double *vin,
                        double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse1_nr_s2ij(fmmm, row_id,
                             vout, vin, buf, envs);
}
993 
994 
995 /*
996  * ************************************************
997  * s1, s2ij, s2kl, s4 here to label the AO symmetry
998  */
AO2MOtranse2_nr_s1(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)999 void AO2MOtranse2_nr_s1(int (*fmmm)(), int row_id,
1000                         double *vout, double *vin, double *buf,
1001                         struct _AO2MOEnvs *envs)
1002 {
1003         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
1004         size_t nao2 = (*fmmm)(NULL, NULL, buf, envs, INPUT_IJ);
1005         (*fmmm)(vout+ij_pair*row_id, vin+nao2*row_id, buf, envs, 0);
1006 }
1007 
/* Alias: forwards every argument unchanged to AO2MOtranse2_nr_s1. */
void AO2MOtranse2_nr_s2ij(int (*fmmm)(), int row_id,
                          double *vout, double *vin,
                          double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse2_nr_s1(fmmm, row_id,
                           vout, vin, buf, envs);
}
1014 
AO2MOtranse2_nr_s2kl(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)1015 void AO2MOtranse2_nr_s2kl(int (*fmmm)(), int row_id,
1016                           double *vout, double *vin, double *buf,
1017                           struct _AO2MOEnvs *envs)
1018 {
1019         int nao = envs->nao;
1020         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
1021         size_t nao2 = (*fmmm)(NULL, NULL, buf, envs, INPUT_IJ);
1022         NPdunpack_tril(nao, vin+nao2*row_id, buf, 0);
1023         (*fmmm)(vout+ij_pair*row_id, buf, buf+nao*nao, envs, 0);
1024 }
/* Alias: forwards every argument unchanged to AO2MOtranse2_nr_s2kl. */
void AO2MOtranse2_nr_s2(int (*fmmm)(), int row_id,
                        double *vout, double *vin,
                        double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse2_nr_s2kl(fmmm, row_id,
                             vout, vin, buf, envs);
}
1031 
/* Alias: forwards every argument unchanged to AO2MOtranse2_nr_s2kl. */
void AO2MOtranse2_nr_s4(int (*fmmm)(), int row_id,
                        double *vout, double *vin,
                        double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOtranse2_nr_s2kl(fmmm, row_id,
                             vout, vin, buf, envs);
}
1038 
1039 
1040 
1041 /*
1042  * ************************************************
1043  * sort (shell-based) integral blocks then transform
1044  */
AO2MOsortranse2_nr_s1(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)1045 void AO2MOsortranse2_nr_s1(int (*fmmm)(), int row_id,
1046                            double *vout, double *vin, double *buf,
1047                            struct _AO2MOEnvs *envs)
1048 {
1049         int nao = envs->nao;
1050         int *ao_loc = envs->ao_loc;
1051         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
1052         size_t nao2 = (*fmmm)(NULL, NULL, buf, envs, INPUT_IJ);
1053         int ish, jsh, di, dj;
1054         int i, j, ij;
1055         double *pbuf;
1056 
1057         vin += nao2 * row_id;
1058         ij = 0;
1059         for (ish = 0; ish < envs->nbas; ish++) {
1060                 di = ao_loc[ish+1] - ao_loc[ish];
1061                 for (jsh = 0; jsh < envs->nbas; jsh++) {
1062                         dj = ao_loc[jsh+1] - ao_loc[jsh];
1063                         pbuf = buf + ao_loc[ish] * nao + ao_loc[jsh];
1064                         for (i = 0; i < di; i++) {
1065                         for (j = 0; j < dj; j++, ij++) {
1066                                 pbuf[i*nao+j] = vin[ij];
1067                         } }
1068                 }
1069         }
1070 
1071         (*fmmm)(vout+ij_pair*row_id, buf, buf+nao*nao, envs, 0);
1072 }
1073 
/* Alias: forwards every argument unchanged to AO2MOsortranse2_nr_s1. */
void AO2MOsortranse2_nr_s2ij(int (*fmmm)(), int row_id,
                             double *vout, double *vin,
                             double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOsortranse2_nr_s1(fmmm, row_id,
                              vout, vin, buf, envs);
}
1080 
AO2MOsortranse2_nr_s2kl(int (* fmmm)(),int row_id,double * vout,double * vin,double * buf,struct _AO2MOEnvs * envs)1081 void AO2MOsortranse2_nr_s2kl(int (*fmmm)(), int row_id,
1082                              double *vout, double *vin, double *buf,
1083                              struct _AO2MOEnvs *envs)
1084 {
1085         int nao = envs->nao;
1086         int *ao_loc = envs->ao_loc;
1087         size_t ij_pair = (*fmmm)(NULL, NULL, buf, envs, OUTPUTIJ);
1088         size_t nao2 = (*fmmm)(NULL, NULL, buf, envs, INPUT_IJ);
1089         int ish, jsh, di, dj;
1090         int i, j, ij;
1091         double *pbuf;
1092 
1093         vin += nao2 * row_id;
1094         for (ish = 0; ish < envs->nbas; ish++) {
1095                 di = ao_loc[ish+1] - ao_loc[ish];
1096                 for (jsh = 0; jsh < ish; jsh++) {
1097                         dj = ao_loc[jsh+1] - ao_loc[jsh];
1098                         pbuf = buf + ao_loc[ish] * nao + ao_loc[jsh];
1099                         for (i = 0; i < di; i++) {
1100                         for (j = 0; j < dj; j++) {
1101                                 pbuf[i*nao+j] = vin[i*dj+j];
1102                         } }
1103                         vin += di * dj;
1104                 }
1105 
1106                 // lower triangle block when ish == jsh
1107                 pbuf = buf + ao_loc[ish] * nao + ao_loc[ish];
1108                 for (ij = 0, i = 0; i < di; i++) {
1109                 for (j = 0; j <= i; j++, ij++) {
1110                         pbuf[i*nao+j] = vin[ij];
1111                 } }
1112                 vin += di * (di+1) / 2;
1113         }
1114 
1115         (*fmmm)(vout+ij_pair*row_id, buf, buf+nao*nao, envs, 0);
1116 }
/* Alias: forwards every argument unchanged to AO2MOsortranse2_nr_s2kl. */
void AO2MOsortranse2_nr_s2(int (*fmmm)(), int row_id,
                           double *vout, double *vin,
                           double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOsortranse2_nr_s2kl(fmmm, row_id,
                                vout, vin, buf, envs);
}
1123 
/* Alias: forwards every argument unchanged to AO2MOsortranse2_nr_s2kl. */
void AO2MOsortranse2_nr_s4(int (*fmmm)(), int row_id,
                           double *vout, double *vin,
                           double *buf, struct _AO2MOEnvs *envs)
{
        AO2MOsortranse2_nr_s2kl(fmmm, row_id,
                                vout, vin, buf, envs);
}
1130 
1131 /*
1132  * ************************************************
1133  * combine ftrans and fmmm
1134  */
1135 
AO2MOtrans_nr_s1_iltj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1136 void AO2MOtrans_nr_s1_iltj(void *nop, int row_id,
1137                            double *vout, double *eri, double *buf,
1138                            struct _AO2MOEnvs *envs)
1139 {
1140         AO2MOtranse2_nr_s1(AO2MOmmm_nr_s1_iltj, row_id, vout, eri, buf, envs);
1141 }
1142 
AO2MOtrans_nr_s1_igtj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1143 void AO2MOtrans_nr_s1_igtj(void *nop, int row_id,
1144                            double *vout, double *eri, double *buf,
1145                            struct _AO2MOEnvs *envs)
1146 {
1147         AO2MOtranse2_nr_s1(AO2MOmmm_nr_s1_igtj, row_id, vout, eri, buf, envs);
1148 }
1149 
AO2MOtrans_nr_sorts1_iltj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1150 void AO2MOtrans_nr_sorts1_iltj(void *nop, int row_id,
1151                                double *vout, double *eri, double *buf,
1152                                struct _AO2MOEnvs *envs)
1153 {
1154         AO2MOsortranse2_nr_s1(AO2MOmmm_nr_s1_iltj, row_id, vout, eri, buf,envs);
1155 }
1156 
AO2MOtrans_nr_sorts1_igtj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1157 void AO2MOtrans_nr_sorts1_igtj(void *nop, int row_id,
1158                                double *vout, double *eri, double *buf,
1159                                struct _AO2MOEnvs *envs)
1160 {
1161         AO2MOsortranse2_nr_s1(AO2MOmmm_nr_s1_igtj, row_id, vout, eri, buf,envs);
1162 }
1163 
AO2MOtrans_nr_s2_iltj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1164 void AO2MOtrans_nr_s2_iltj(void *nop, int row_id,
1165                            double *vout, double *eri, double *buf,
1166                            struct _AO2MOEnvs *envs)
1167 {
1168         AO2MOtranse2_nr_s2kl(AO2MOmmm_nr_s2_iltj, row_id, vout, eri, buf, envs);
1169 }
1170 
AO2MOtrans_nr_s2_igtj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1171 void AO2MOtrans_nr_s2_igtj(void *nop, int row_id,
1172                            double *vout, double *eri, double *buf,
1173                            struct _AO2MOEnvs *envs)
1174 {
1175         AO2MOtranse2_nr_s2kl(AO2MOmmm_nr_s2_igtj, row_id, vout, eri, buf, envs);
1176 }
1177 
AO2MOtrans_nr_s2_s2(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1178 void AO2MOtrans_nr_s2_s2(void *nop, int row_id,
1179                          double *vout, double *eri, double *buf,
1180                          struct _AO2MOEnvs *envs)
1181 {
1182         AO2MOtranse2_nr_s2kl(AO2MOmmm_nr_s2_s2, row_id, vout, eri, buf, envs);
1183 }
1184 
AO2MOtrans_nr_sorts2_iltj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1185 void AO2MOtrans_nr_sorts2_iltj(void *nop, int row_id,
1186                                double *vout, double *eri, double *buf,
1187                                struct _AO2MOEnvs *envs)
1188 {
1189         AO2MOsortranse2_nr_s2kl(AO2MOmmm_nr_s2_iltj, row_id, vout, eri, buf, envs);
1190 }
1191 
AO2MOtrans_nr_sorts2_igtj(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1192 void AO2MOtrans_nr_sorts2_igtj(void *nop, int row_id,
1193                                double *vout, double *eri, double *buf,
1194                                struct _AO2MOEnvs *envs)
1195 {
1196         AO2MOsortranse2_nr_s2kl(AO2MOmmm_nr_s2_igtj, row_id, vout, eri, buf, envs);
1197 }
1198 
AO2MOtrans_nr_sorts2_s2(void * nop,int row_id,double * vout,double * eri,double * buf,struct _AO2MOEnvs * envs)1199 void AO2MOtrans_nr_sorts2_s2(void *nop, int row_id,
1200                              double *vout, double *eri, double *buf,
1201                              struct _AO2MOEnvs *envs)
1202 {
1203         AO2MOsortranse2_nr_s2kl(AO2MOmmm_nr_s2_s2, row_id, vout, eri, buf,envs);
1204 }
1205 
1206 /*
1207  * ************************************************
1208  * Denoting 2e integrals (ij|kl),
1209  * transform ij for ksh_start <= k shell < ksh_end.
1210  * The transformation C_pi C_qj (pq|k*) coefficients are stored in
1211  * mo_coeff, C_pi and C_qj are offset by i_start and i_count, j_start and j_count
1212  *
 * The output eri is a 2D array, ordered as (kl-AO-pair,ij-MO-pair) in
1214  * C-order.  Transposing is needed before calling AO2MOnr_e2_drv.
1215  * eri[ncomp,nkl,mo_i,mo_j]
1216  */
/*
 * Two-stage driver: compute the AO integrals for the requested kl shell
 * range into a temporary array, then transform them into eri.
 *
 * The temporary holds nao*nao*nkl*ncomp doubles, where nao = ao_loc[nbas],
 * and is released before returning.  The allocation is checked only by
 * assert.  The second stage treats the temporary as nkl*ncomp rows.
 */
void AO2MOnr_e1_drv(int (*intor)(), void (*fill)(), void (*ftrans)(), int (*fmmm)(),
                    double *eri, double *mo_coeff,
                    int klsh_start, int klsh_count, int nkl, int ncomp,
                    int *orbs_slice, int *ao_loc,
                    CINTOpt *cintopt, CVHFOpt *vhfopt,
                    int *atm, int natm, int *bas, int nbas, double *env)
{
        const int nao = ao_loc[nbas];
        double *ao_ints;

        ao_ints = malloc(sizeof(double) * nao*nao*nkl*ncomp);
        assert(ao_ints);

        /* stage 1: AO integrals for the kl shell range */
        AO2MOnr_e1fill_drv(intor, fill, ao_ints,
                           klsh_start, klsh_count, nkl, ncomp,
                           ao_loc, cintopt, vhfopt,
                           atm, natm, bas, nbas, env);

        /* stage 2: transform each of the nkl*ncomp rows */
        AO2MOnr_e2_drv(ftrans, fmmm, eri, ao_ints, mo_coeff,
                       nkl*ncomp, nao, orbs_slice, ao_loc, nbas);

        free(ao_ints);
}
1234 
AO2MOnr_e2_drv(void (* ftrans)(),int (* fmmm)(),double * vout,double * vin,double * mo_coeff,int nij,int nao,int * orbs_slice,int * ao_loc,int nbas)1235 void AO2MOnr_e2_drv(void (*ftrans)(), int (*fmmm)(),
1236                     double *vout, double *vin, double *mo_coeff,
1237                     int nij, int nao, int *orbs_slice, int *ao_loc, int nbas)
1238 {
1239         struct _AO2MOEnvs envs;
1240         envs.bra_start = orbs_slice[0];
1241         envs.bra_count = orbs_slice[1] - orbs_slice[0];
1242         envs.ket_start = orbs_slice[2];
1243         envs.ket_count = orbs_slice[3] - orbs_slice[2];
1244         envs.nao = nao;
1245         envs.nbas = nbas;
1246         envs.ao_loc = ao_loc;
1247         envs.mo_coeff = mo_coeff;
1248 
1249 #pragma omp parallel default(none) \
1250         shared(ftrans, fmmm, vout, vin, nij, envs, nao, orbs_slice)
1251 {
1252         int i;
1253         int i_count = envs.bra_count;
1254         int j_count = envs.ket_count;
1255         double *buf = malloc(sizeof(double) * (nao+i_count) * (nao+j_count));
1256 #pragma omp for schedule(dynamic)
1257         for (i = 0; i < nij; i++) {
1258                 (*ftrans)(fmmm, i, vout, vin, buf, &envs);
1259         }
1260         free(buf);
1261 }
1262 }
1263 
1264 /*
1265  * The size of eri is ncomp*nkl*nao*nao, note the upper triangular part
1266  * may not be filled
1267  */
/*
 * Compute AO integrals for the kl shell range by calling fill once per
 * bra shell ish.
 *
 * The prescreen callback is vhfopt->fprescreen when vhfopt is non-NULL and
 * CVHFnoscreen otherwise.  The ish loop is distributed over OpenMP threads
 * (dynamic schedule, chunk 1).  Each thread owns a scratch buffer of
 * dmax^4 * ncomp doubles, where dmax is the largest shell dimension found
 * in ao_loc.
 */
void AO2MOnr_e1fill_drv(int (*intor)(), void (*fill)(), double *eri,
                        int klsh_start, int klsh_count, int nkl, int ncomp,
                        int *ao_loc, CINTOpt *cintopt, CVHFOpt *vhfopt,
                        int *atm, int natm, int *bas, int nbas, double *env)
{
        int n;
        int nao = ao_loc[nbas];
        int dmax = 0;

        /* widest shell decides the size of the per-thread integral buffer */
        for (n = 0; n < nbas; n++) {
                if (ao_loc[n+1]-ao_loc[n] > dmax) {
                        dmax = ao_loc[n+1]-ao_loc[n];
                }
        }

        struct _AO2MOEnvs envs = {natm, nbas, atm, bas, env, nao,
                                  klsh_start, klsh_count, 0, 0, 0, 0,
                                  ncomp, ao_loc, NULL, cintopt, vhfopt};

        int (*fprescreen)() = CVHFnoscreen;
        if (vhfopt) {
                fprescreen = vhfopt->fprescreen;
        }

#pragma omp parallel default(none) \
        shared(fill, fprescreen, eri, envs, intor, nkl, nbas, dmax, ncomp)
{
        int ish;
        double *scratch = malloc(sizeof(double)*dmax*dmax*dmax*dmax*ncomp);
#pragma omp for schedule(dynamic, 1)
        for (ish = 0; ish < nbas; ish++) {
                (*fill)(intor, fprescreen, eri, scratch, nkl, ish, &envs);
        }
        free(scratch);
}
}
1301 
1302