1from pathlib import Path
2
3import numpy as np
4from scipy.spatial import cKDTree
5
6from gpaw.utilities import convert_string_to_fd
7from ase.utils.timing import Timer, timer
8
9from gpaw import GPAW, disable_dry_run
10import gpaw.mpi as mpi
11from gpaw.response.math_func import two_phi_planewave_integrals
12
13
class KohnShamKPoint:
    """Container for the Kohn-Sham orbital data of a single k-point."""

    def __init__(self, n_t, s_t, K,
                 eps_t, f_t, ut_tR, projections, shift_c):
        """Store the k-point index together with the orbital data arrays.

        Parameters
        ----------
        n_t :
            Band index
        s_t :
            Spin index
        K :
            BZ k-point index
        eps_t :
            Eigenvalues
        f_t :
            Occupation numbers
        ut_tR :
            Periodic part of wave functions
        projections :
            PAW projections
        shift_c :
            Long story - see the
            PairDensity.construct_symmetry_operators() method
        """
        # Which k-point and which orbitals the data belongs to
        self.K = K
        self.n_t, self.s_t = n_t, s_t

        # Orbital energies and occupations
        self.eps_t, self.f_t = eps_t, f_t

        # Wave function data
        self.ut_tR = ut_tR
        self.projections = projections
        self.shift_c = shift_c
30
31
class KohnShamKPointPair:
    """Object containing all transitions between Kohn-Sham orbitals from a
    specified k-point to another.

    The transitions may be distributed in blocks over the processes of
    ``comm``. The ``*_t`` properties use get_all() to return arrays covering
    all ``nt`` transitions.
    """

    def __init__(self, kpt1, kpt2, mynt, nt, ta, tb, comm=None):
        """
        Parameters
        ----------
        kpt1, kpt2 :
            K-point objects providing the n_t, s_t, eps_t and f_t arrays
            read by the properties of this class
        mynt : int
            Number of transitions in this block
        nt : int
            Total number of transitions between all blocks
        ta : int
            First transition index of this block
        tb : int
            First transition index of this block not included
        comm :
            MPI communicator between blocks of transitions. With None, the
            local arrays are returned as they are.
        """
        self.kpt1 = kpt1
        self.kpt2 = kpt2

        self.mynt = mynt
        self.nt = nt
        self.ta = ta
        self.tb = tb
        self.comm = comm

    def transition_distribution(self):
        """Get the distribution of transitions as (mynt, nt, ta, tb)."""
        return self.mynt, self.nt, self.ta, self.tb

    def get_transitions(self):
        """Get the band and spin indices (n1_t, n2_t, s1_t, s2_t) of all
        transitions."""
        return self.n1_t, self.n2_t, self.s1_t, self.s2_t

    def get_all(self, A_mytx):
        """Get a certain data array with all transitions.

        Parameters
        ----------
        A_mytx : np.ndarray or None
            Data array of this block, with the transition index first

        Returns
        -------
        A_mytx itself if there is no communicator or no array. Otherwise the
        first nt rows of the array gathered from all processes in comm.
        """
        if self.comm is None or A_mytx is None:
            return A_mytx

        # Receive buffer with room for mynt rows from each process
        A_tx = np.empty((self.mynt * self.comm.size,) + A_mytx.shape[1:],
                        dtype=A_mytx.dtype)

        self.comm.all_gather(A_mytx, A_tx)

        # Rows beyond the total number of transitions are discarded
        return A_tx[:self.nt]

    @property
    def n1_t(self):
        """kpt1.n_t for all transitions"""
        return self.get_all(self.kpt1.n_t)

    @property
    def n2_t(self):
        """kpt2.n_t for all transitions"""
        return self.get_all(self.kpt2.n_t)

    @property
    def s1_t(self):
        """kpt1.s_t for all transitions"""
        return self.get_all(self.kpt1.s_t)

    @property
    def s2_t(self):
        """kpt2.s_t for all transitions"""
        return self.get_all(self.kpt2.s_t)

    @property
    def deps_t(self):
        """Difference kpt2.eps_t - kpt1.eps_t for all transitions"""
        return self.get_all(self.kpt2.eps_t) - self.get_all(self.kpt1.eps_t)

    @property
    def df_t(self):
        """Difference kpt2.f_t - kpt1.f_t for all transitions"""
        return self.get_all(self.kpt2.f_t) - self.get_all(self.kpt1.f_t)

    @classmethod
    def add_mytransitions_array(cls, _key, key):
        """Add a A_tx data array class attribute.
        Handles the fact, that the transitions are distributed in blocks.

        Parameters
        ----------
        _key : str
            attribute name for the A_mytx data array
        key : str
            attribute name for the A_tx data array
        """
        # In general, the data array has not been specified to instances of
        # the class. As a result, set the _key to None
        setattr(cls, _key, None)

        def gathered(self):
            # getattr() falls back on the class level None for instances
            # without an attached array. Looking the key up in
            # self.__dict__ would raise a KeyError for those instances.
            return self.get_all(getattr(self, _key))

        # self.key should return full data array
        setattr(cls, key, property(gathered))

    def attach(self, _key, key, A_mytx):
        """Attach a data array to the k-point pair.
        Used by PairMatrixElement to attach matrix elements calculated
        between the k-points for the different transitions.

        Parameters
        ----------
        _key : str
            attribute name under which A_mytx is stored on this instance
        key : str
            name of the property returning the array for all transitions
        A_mytx :
            data array of this block
        """
        self.add_mytransitions_array(_key, key)
        setattr(self, _key, A_mytx)
114
115
116class KohnShamPair:
117    """Class for extracting pairs of Kohn-Sham orbitals from a ground
118    state calculation."""
119
120    def __init__(self, gs, world=mpi.world, transitionblockscomm=None,
121                 kptblockcomm=None, txt='-', timer=None):
122        """
123        Parameters
124        ----------
125        transitionblockscomm : gpaw.mpi.Communicator
126            Communicator for distributing the transitions among processes
127        kptblockcomm : gpaw.mpi.Communicator
128            Communicator for distributing k-points among processes
129        """
130        self.world = world
131        self.fd = convert_string_to_fd(txt, world)
132        self.timer = timer or Timer()
133        self.calc = get_calc(gs, fd=self.fd, timer=self.timer)
134        self.calc_parallel = self.check_calc_parallelisation()
135
136        self.transitionblockscomm = transitionblockscomm
137        self.kptblockcomm = kptblockcomm
138
139        # Prepare to distribute transitions
140        self.mynt = None
141        self.nt = None
142        self.ta = None
143        self.tb = None
144
145        # Prepare to find k-point data from vector
146        kd = self.calc.wfs.kd
147        self.kdtree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1))
148
149        # Prepare to use other processes' k-points
150        self._pd0 = None
151
152        # Prepare to redistribute kptdata
153        self.rrequests = []
154        self.srequests = []
155
156        # Count bands so it is possible to remove null transitions
157        self.nocc1 = None  # number of completely filled bands
158        self.nocc2 = None  # number of non-empty bands
159        self.count_occupied_bands()
160
161    def check_calc_parallelisation(self):
162        """Check how ground state calculation is distributed in memory"""
163        if self.calc.world.size == 1:
164            return False
165        else:
166            assert self.world.rank == self.calc.wfs.world.rank
167            assert self.calc.wfs.gd.comm.size == 1
168            return True
169
170    def count_occupied_bands(self):
171        """Count number of occupied and unoccupied bands in ground state
172        calculation. Can be used to omit null-transitions between two occupied
173        bands or between two unoccupied bands."""
174        ftol = 1.e-9  # Could be given as input
175        nocc1 = 9999999
176        nocc2 = 0
177        for kpt in self.calc.wfs.kpt_u:
178            f_n = kpt.f_n / kpt.weight
179            nocc1 = min((f_n > 1 - ftol).sum(), nocc1)
180            nocc2 = max((f_n > ftol).sum(), nocc2)
181        nocc1 = int(nocc1)
182        nocc2 = int(nocc2)
183
184        # Collect nocc for all k-points
185        nocc1 = self.calc.wfs.kd.comm.min(nocc1)
186        nocc2 = self.calc.wfs.kd.comm.max(nocc2)
187
188        # Sum over band distribution
189        nocc1 = self.calc.wfs.bd.comm.sum(nocc1)
190        nocc2 = self.calc.wfs.bd.comm.sum(nocc2)
191
192        self.nocc1 = int(nocc1)
193        self.nocc2 = int(nocc2)
194        print('Number of completely filled bands:', self.nocc1, file=self.fd)
195        print('Number of partially filled bands:', self.nocc2, file=self.fd)
196        print('Total number of bands:', self.calc.wfs.bd.nbands,
197              file=self.fd)
198
199    @property
200    def pd0(self):
201        """Get a PWDescriptor that includes all k-points"""
202        if self._pd0 is None:
203            from gpaw.wavefunctions.pw import PWDescriptor
204            wfs = self.calc.wfs
205            assert wfs.gd.comm.size == 1
206
207            kd0 = wfs.kd.copy()
208            pd, gd = wfs.pd, wfs.gd
209
210            # Extract stuff from self.calc.wfs.pd
211            ecut, dtype = pd.ecut, pd.dtype
212            fftwflags, gammacentered = pd.fftwflags, pd.gammacentered
213
214            # Initiate _pd0 with kd0
215            self._pd0 = PWDescriptor(ecut, gd, dtype=dtype,
216                                     kd=kd0, fftwflags=fftwflags,
217                                     gammacentered=gammacentered)
218        return self._pd0
219
220    @timer('Get Kohn-Sham pairs')
221    def get_kpoint_pairs(self, n1_t, n2_t, k1_pc, k2_pc, s1_t, s2_t):
222        """Get all pairs of Kohn-Sham orbitals for transitions:
223        (n1_t, k1_p, s1_t) -> (n2_t, k2_p, s2_t)
224        Here, t is a composite band and spin transition index
225        and p is indexing the different k-points to be distributed."""
226
227        # Distribute transitions and extract data for transitions in
228        # this process' block
229        nt = len(n1_t)
230        assert nt == len(n2_t)
231        self.distribute_transitions(nt)
232
233        kpt1 = self.get_kpoints(k1_pc, n1_t, s1_t)
234        kpt2 = self.get_kpoints(k2_pc, n2_t, s2_t)
235
236        # The process might not have any k-point pairs to evaluate, as
237        # their are distributed in the kptblockcomm
238        if kpt1 is None:
239            assert kpt2 is None
240            return None
241        assert kpt2 is not None
242
243        return KohnShamKPointPair(kpt1, kpt2,
244                                  self.mynt, nt, self.ta, self.tb,
245                                  comm=self.transitionblockscomm)
246
247    def distribute_transitions(self, nt):
248        """Distribute transitions between processes in block communicator"""
249        if self.transitionblockscomm is None:
250            mynt = nt
251            ta = 0
252            tb = nt
253        else:
254            nblocks = self.transitionblockscomm.size
255            rank = self.transitionblockscomm.rank
256
257            mynt = (nt + nblocks - 1) // nblocks
258            ta = min(rank * mynt, nt)
259            tb = min(ta + mynt, nt)
260
261        self.mynt = mynt
262        self.nt = nt
263        self.ta = ta
264        self.tb = tb
265
266    def get_kpoints(self, k_pc, n_t, s_t):
267        """Get KohnShamKPoint and help other processes extract theirs"""
268        assert len(n_t) == len(s_t)
269        assert len(k_pc) <= self.kptblockcomm.size
270        kpt = None
271
272        # Use the data extraction factory to extract the kptdata
273        _extract_kptdata = self.create_extract_kptdata()
274        kptdata = _extract_kptdata(k_pc, n_t, s_t)
275
276        # Make local n and s arrays for the KohnShamKPoint object
277        n_myt = np.empty(self.mynt, dtype=n_t.dtype)
278        n_myt[:self.tb - self.ta] = n_t[self.ta:self.tb]
279        s_myt = np.empty(self.mynt, dtype=s_t.dtype)
280        s_myt[:self.tb - self.ta] = s_t[self.ta:self.tb]
281
282        # Initiate k-point object.
283        if self.kptblockcomm.rank in range(len(k_pc)):
284            assert kptdata is not None
285            kpt = KohnShamKPoint(n_myt, s_myt, *kptdata)
286
287        return kpt
288
289    def create_extract_kptdata(self):
290        """Creator component of the data extraction factory."""
291        if self.calc_parallel:
292            return self.parallel_extract_kptdata
293        else:
294            return self.serial_extract_kptdata
295            # Useful for debugging:
296            # return self.parallel_extract_kptdata
297
298    def parallel_extract_kptdata(self, k_pc, n_t, s_t):
299        """Returns the input to KohnShamKPoint:
300        K, n_myt, s_myt, eps_myt, f_myt, ut_mytR, projections, shift_c
301        if a k-point in the given list, k_pc, belongs to the process.
302        """
303        # Extract the data from the ground state calculator object
304        data, h_myt, myt_myt = self._parallel_extract_kptdata(k_pc, n_t, s_t)
305
306        # If the process has a k-point to return, symmetrize and unfold
307        if self.kptblockcomm.rank in range(len(k_pc)):
308            assert data is not None
309            # Unpack data, apply FT and symmetrization
310            K, k_c, eps_h, f_h, Ph, psit_hG = data
311            Ph, ut_hR, shift_c = self.transform_and_symmetrize(K, k_c, Ph,
312                                                               psit_hG)
313
314            (eps_myt, f_myt,
315             P, ut_mytR) = self.unfold_arrays(eps_h, f_h, Ph, ut_hR,
316                                              h_myt, myt_myt)
317
318            data = (K, eps_myt, f_myt, ut_mytR, P, shift_c)
319
320        # Wait for communication to finish
321        with self.timer('Waiting to complete mpi.send'):
322            while self.srequests:
323                self.world.wait(self.srequests.pop(0))
324
325        return data
326
327    @timer('Extracting data from the ground state calculator object')
328    def _parallel_extract_kptdata(self, k_pc, n_t, s_t):
329        """In-place kptdata extraction."""
330        (data, myu_eu,
331         myn_eueh, ik_r2,
332         nrh_r2, eh_eur2reh,
333         rh_eur2reh, h_r1rh,
334         h_myt, myt_myt) = self.get_extraction_protocol(k_pc, n_t, s_t)
335
336        (eps_r1rh, f_r1rh,
337         P_r1rhI, psit_r1rhG,
338         eps_r2rh, f_r2rh,
339         P_r2rhI, psit_r2rhG) = self.allocate_transfer_arrays(data, nrh_r2,
340                                                              ik_r2, h_r1rh)
341
342        # Do actual extraction
343        for myu, myn_eh, eh_r2reh, rh_r2reh in zip(myu_eu, myn_eueh,
344                                                   eh_eur2reh, rh_eur2reh):
345
346            eps_eh, f_eh, P_ehI = self.extract_wfs_data(myu, myn_eh)
347
348            for r2, (eh_reh, rh_reh) in enumerate(zip(eh_r2reh, rh_r2reh)):
349                if eh_reh:
350                    eps_r2rh[r2][rh_reh] = eps_eh[eh_reh]
351                    f_r2rh[r2][rh_reh] = f_eh[eh_reh]
352                    P_r2rhI[r2][rh_reh] = P_ehI[eh_reh]
353
354            # Wavefunctions are heavy objects which can only be extracted
355            # for one band index at a time, handle them seperately
356            self.add_wave_function(myu, myn_eh, eh_r2reh,
357                                   rh_r2reh, psit_r2rhG)
358
359        self.distribute_extracted_data(eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
360                                       eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)
361
362        data = self.collect_kptdata(data, h_r1rh, eps_r1rh,
363                                    f_r1rh, P_r1rhI, psit_r1rhG)
364
365        return data, h_myt, myt_myt
366
    @timer('Create data extraction protocol')
    def get_extraction_protocol(self, k_pc, n_t, s_t):
        """Figure out how to extract data efficiently.

        For every k-point in k_pc (indexed by p), the k-point index K is
        found with self.find_kpoint() and mapped to ik by wfs.kd.bz2ibz_k.
        For every spin in s_t, get_extraction_info() tells which world rank
        to extract each band from (r1), while self.map_who_has() tells which
        world rank each transition belongs to (r2). Entries are only recorded
        for extractions and receptions involving self.world.rank.

        With get_serial_extraction_info(), r1 equals r2: all processes can
        access all data, and resultantly, there is no need to send any data.

        Returns
        -------
        data : tuple
            (K, k_c, ik) for the k-point with p == self.kptblockcomm.rank,
            (None, None, None) if there is no such k-point
        myu_eu : list
            myu index, as returned by get_extraction_info(), for each
            extraction this process performs
        myn_eueh : list
            unique local band indices for each extraction
        ik_r2 : list
            ik for each receiving world rank (None where nothing was set)
        nrh_r2 : np.ndarray
            number of entries this process sends to each world rank
        eh_eur2reh : list
            for each extraction and receiving rank, the positions of the
            entries to send within the extracted bands
        rh_eur2reh : list
            for each extraction and receiving rank, the positions of the
            same entries within the data sent to that rank
        h_r1rh : list
            for each sending world rank, the h indices of the entries
            received from it
        h_myt : np.ndarray
            h index of each transition in the block of this process
        myt_myt : np.ndarray
            np.arange(self.tb - self.ta)
        """
        wfs = self.calc.wfs
        get_extraction_info = self.create_get_extraction_info()

        # Kpoint data
        data = (None, None, None)

        # Extraction protocol
        myu_eu = []
        myn_eueh = []

        # Data distribution protocol
        nrh_r2 = np.zeros(self.world.size, dtype=int)
        ik_r2 = [None for _ in range(self.world.size)]
        eh_eur2reh = []
        rh_eur2reh = []
        h_r1rh = [list([]) for _ in range(self.world.size)]

        # h to t index mapping
        myt_myt = np.arange(self.tb - self.ta)
        t_myt = range(self.ta, self.tb)
        n_myt, s_myt = n_t[t_myt], s_t[t_myt]
        h_myt = np.empty(self.tb - self.ta, dtype=int)

        nt = len(n_t)
        assert nt == len(s_t)
        t_t = np.arange(nt)
        # Running count of h indices handed out on this process. It is not
        # reset between k-points and spins.
        nh = 0
        for p, k_c in enumerate(k_pc):  # p indicates the receiving process
            K = self.find_kpoint(k_c)
            ik = wfs.kd.bz2ibz_k[K]
            # Register ik for the world ranks
            # p * size, ..., (p + 1) * size - 1 (capped at the world size),
            # where size is the size of the transitionblockscomm
            for r2 in range(p * self.transitionblockscomm.size,
                            min((p + 1) * self.transitionblockscomm.size,
                                self.world.size)):
                ik_r2[r2] = ik

            if p == self.kptblockcomm.rank:
                data = (K, k_c, ik)

            # Find out who should store the data in KSKPpoint
            # (myt_t is not used below)
            r2_t, myt_t = self.map_who_has(p, t_t)

            # Find out how to extract data
            # In the ground state, kpts are indexed by u=(s, k)
            for s in set(s_t):
                thiss_myt = s_myt == s
                thiss_t = s_t == s
                t_ct = t_t[thiss_t]
                n_ct = n_t[thiss_t]
                r2_ct = r2_t[t_ct]

                # Find out where data is in wfs
                u = ik * wfs.nspins + s
                myu, r1_ct, myn_ct = get_extraction_info(u, n_ct, r2_ct)

                # If the process is extracting or receiving data,
                # figure out how to do so
                if self.world.rank in np.append(r1_ct, r2_ct):
                    # Does this process have anything to send?
                    thisr1_ct = r1_ct == self.world.rank
                    if np.any(thisr1_ct):
                        eh_r2reh = [list([]) for _ in range(self.world.size)]
                        rh_r2reh = [list([]) for _ in range(self.world.size)]
                        # Find composite indices h = (n, s)
                        n_et = n_ct[thisr1_ct]
                        n_eh = np.unique(n_et)
                        # Find composite local band indices
                        myn_eh = np.unique(myn_ct[thisr1_ct])

                        # Where to send the data
                        r2_et = r2_ct[thisr1_ct]
                        for r2 in np.unique(r2_et):
                            thisr2_et = r2_et == r2
                            # What ns are the process sending?
                            n_reh = np.unique(n_et[thisr2_et])
                            # Position of each band within n_eh
                            eh_reh = []
                            for n in n_reh:
                                eh_reh.append(np.where(n_eh == n)[0][0])
                            # How to send it
                            eh_r2reh[r2] = eh_reh
                            nreh = len(eh_reh)
                            # The entries are placed after those already
                            # counted for r2 in nrh_r2
                            rh_r2reh[r2] = np.arange(nreh) + nrh_r2[r2]
                            nrh_r2[r2] += nreh

                        myu_eu.append(myu)
                        myn_eueh.append(myn_eh)
                        eh_eur2reh.append(eh_r2reh)
                        rh_eur2reh.append(rh_r2reh)

                    # Does this process have anything to receive?
                    thisr2_ct = r2_ct == self.world.rank
                    if np.any(thisr2_ct):
                        # Find unique composite indices h = (n, s)
                        n_rt = n_ct[thisr2_ct]
                        n_rn = np.unique(n_rt)
                        nrn = len(n_rn)
                        h_rn = np.arange(nrn) + nh
                        nh += nrn

                        # Where to get the data from
                        r1_rt = r1_ct[thisr2_ct]
                        for r1 in np.unique(r1_rt):
                            thisr1_rt = r1_rt == r1
                            # What ns are the process getting?
                            n_reh = np.unique(n_rt[thisr1_rt])
                            # Where to put them
                            for n in n_reh:
                                h = h_rn[np.where(n_rn == n)[0][0]]
                                h_r1rh[r1].append(h)

                                # h to t mapping
                                thisn_myt = n_myt == n
                                thish_myt = np.logical_and(thisn_myt,
                                                           thiss_myt)
                                h_myt[thish_myt] = h

        return (data, myu_eu, myn_eueh, ik_r2, nrh_r2,
                eh_eur2reh, rh_eur2reh, h_r1rh, h_myt, myt_myt)
490
491    def create_get_extraction_info(self):
492        """Creator component of the extraction information factory."""
493        if self.calc_parallel:
494            return self.get_parallel_extraction_info
495        else:
496            return self.get_serial_extraction_info
497
498    @staticmethod
499    def get_serial_extraction_info(u, n_ct, r2_ct):
500        """Figure out where to extract the data from in the gs calc"""
501        # Let the process extract its own data
502        myu = u  # The process has access to all data
503        r1_ct = r2_ct
504        myn_ct = n_ct
505
506        return myu, r1_ct, myn_ct
507
508    def get_parallel_extraction_info(self, u, n_ct, *unused):
509        """Figure out where to extract the data from in the gs calc"""
510        wfs = self.calc.wfs
511        # Find out where data is in wfs
512        k, s = divmod(u, wfs.nspins)
513        kptrank, q = wfs.kd.who_has(k)
514        myu = q * wfs.nspins + s
515        r1_ct, myn_ct = [], []
516        for n in n_ct:
517            bandrank, myn = wfs.bd.who_has(n)
518            # XXX this will fail when using non-standard nesting
519            # of communicators.
520            r1 = (kptrank * wfs.gd.comm.size * wfs.bd.comm.size
521                  + bandrank * wfs.gd.comm.size)
522            r1_ct.append(r1)
523            myn_ct.append(myn)
524
525        return myu, np.array(r1_ct), np.array(myn_ct)
526
527    @timer('Allocate transfer arrays')
528    def allocate_transfer_arrays(self, data, nrh_r2, ik_r2, h_r1rh):
529        """Allocate arrays for intermediate storage of data."""
530        wfs = self.calc.wfs
531        kptex = wfs.kpt_u[0]
532        Pshape = kptex.projections.array.shape
533        Pdtype = kptex.projections.matrix.dtype
534        psitdtype = kptex.psit.array.dtype
535
536        # Number of h-indeces to receive
537        nrh_r1 = [len(h_rh) for h_rh in h_r1rh]
538
539        # if self.kptblockcomm.rank in range(len(ik_p)):
540        if data[2] is not None:
541            ik = data[2]
542            ng = self.pd0.ng_q[ik]
543            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = [], [], [], []
544            for nrh in nrh_r1:
545                if nrh >= 1:
546                    eps_r1rh.append(np.empty(nrh))
547                    f_r1rh.append(np.empty(nrh))
548                    P_r1rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
549                    psit_r1rhG.append(np.empty((nrh, ng), dtype=psitdtype))
550                else:
551                    eps_r1rh.append(None)
552                    f_r1rh.append(None)
553                    P_r1rhI.append(None)
554                    psit_r1rhG.append(None)
555        else:
556            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = None, None, None, None
557
558        eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG = [], [], [], []
559        for nrh, ik in zip(nrh_r2, ik_r2):
560            if nrh:
561                eps_r2rh.append(np.empty(nrh))
562                f_r2rh.append(np.empty(nrh))
563                P_r2rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
564                ng = self.pd0.ng_q[ik]
565                psit_r2rhG.append(np.empty((nrh, ng), dtype=psitdtype))
566            else:
567                eps_r2rh.append(None)
568                f_r2rh.append(None)
569                P_r2rhI.append(None)
570                psit_r2rhG.append(None)
571
572        return (eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
573                eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)
574
575    def map_who_has(self, p, t_t):
576        """Convert k-point and transition index to global world rank
577        and local transition index"""
578        trank_t, myt_t = np.divmod(t_t, self.mynt)
579        return p * self.transitionblockscomm.size + trank_t, myt_t
580
581    @timer('Extracting eps, f and P_I from wfs')
582    def extract_wfs_data(self, myu, myn_eh):
583        wfs = self.calc.wfs
584        kpt = wfs.kpt_u[myu]
585        # Get eig and occ
586        eps_eh, f_eh = kpt.eps_n[myn_eh], kpt.f_n[myn_eh] / kpt.weight
587
588        # Get projections
589        assert kpt.projections.atom_partition.comm.size == 1
590        P_ehI = kpt.projections.array[myn_eh]
591
592        return eps_eh, f_eh, P_ehI
593
594    @timer('Extracting wave function from wfs')
595    def add_wave_function(self, myu, myn_eh,
596                          eh_r2reh, rh_r2reh, psit_r2rhG):
597        """Add the plane wave coefficients of the smooth part of
598        the wave function to the psit_r2rtG arrays."""
599        wfs = self.calc.wfs
600        kpt = wfs.kpt_u[myu]
601
602        for eh_reh, rh_reh, psit_rhG in zip(eh_r2reh, rh_r2reh, psit_r2rhG):
603            if eh_reh:
604                for eh, rh in zip(eh_reh, rh_reh):
605                    psit_rhG[rh] = kpt.psit_nG[myn_eh[eh]]
606
607    @timer('Distributing kptdata')
608    def distribute_extracted_data(self, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
609                                  eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG):
610        """Send the extracted data to appropriate destinations"""
611        # Store the data extracted by the process itself
612        rank = self.world.rank
613        # Check if there is actually some data to store
614        if eps_r2rh[rank] is not None:
615            eps_r1rh[rank] = eps_r2rh[rank]
616            f_r1rh[rank] = f_r2rh[rank]
617            P_r1rhI[rank] = P_r2rhI[rank]
618            psit_r1rhG[rank] = psit_r2rhG[rank]
619
620        # Receive data
621        if eps_r1rh is not None:  # The process may not be receiving anything
622            for r1, (eps_rh, f_rh,
623                     P_rhI, psit_rhG) in enumerate(zip(eps_r1rh, f_r1rh,
624                                                       P_r1rhI, psit_r1rhG)):
625                # Check if there is any data to receive
626                if r1 != rank and eps_rh is not None:
627                    rreq1 = self.world.receive(eps_rh, r1,
628                                               tag=201, block=False)
629                    rreq2 = self.world.receive(f_rh, r1,
630                                               tag=202, block=False)
631                    rreq3 = self.world.receive(P_rhI, r1,
632                                               tag=203, block=False)
633                    rreq4 = self.world.receive(psit_rhG, r1,
634                                               tag=204, block=False)
635                    self.rrequests += [rreq1, rreq2, rreq3, rreq4]
636
637        # Send data
638        for r2, (eps_rh, f_rh,
639                 P_rhI, psit_rhG) in enumerate(zip(eps_r2rh, f_r2rh,
640                                                   P_r2rhI, psit_r2rhG)):
641            # Check if there is any data to send
642            if r2 != rank and eps_rh is not None:
643                sreq1 = self.world.send(eps_rh, r2, tag=201, block=False)
644                sreq2 = self.world.send(f_rh, r2, tag=202, block=False)
645                sreq3 = self.world.send(P_rhI, r2, tag=203, block=False)
646                sreq4 = self.world.send(psit_rhG, r2, tag=204, block=False)
647                self.srequests += [sreq1, sreq2, sreq3, sreq4]
648
649        with self.timer('Waiting to complete mpi.receive'):
650            while self.rrequests:
651                self.world.wait(self.rrequests.pop(0))
652
653    @timer('Collecting kptdata')
654    def collect_kptdata(self, data, h_r1rh,
655                        eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG):
656        """From the extracted data, collect the KohnShamKPoint data arrays"""
657
658        # Some processes may not have to return a k-point
659        if data[0] is None:
660            return None
661        K, k_c, ik = data
662
663        # Allocate data arrays
664        wfs = self.calc.wfs
665        maxh_r1 = [max(h_rh) for h_rh in h_r1rh if h_rh]
666        if maxh_r1:
667            nh = max(maxh_r1) + 1
668        else:  # Carry around empty array
669            assert self.ta == self.tb
670            nh = 1
671        eps_h = np.empty(nh)
672        f_h = np.empty(nh)
673        Ph = wfs.kpt_u[0].projections.new(nbands=nh, bcomm=None)
674        psit_hG = np.empty((nh, self.pd0.ng_q[ik]),
675                           dtype=wfs.kpt_u[0].psit.array.dtype)
676
677        # Store extracted data in the arrays
678        for (h_rh, eps_rh,
679             f_rh, P_rhI, psit_rhG) in zip(h_r1rh, eps_r1rh,
680                                           f_r1rh, P_r1rhI, psit_r1rhG):
681            if h_rh:
682                eps_h[h_rh] = eps_rh
683                f_h[h_rh] = f_rh
684                Ph.array[h_rh] = P_rhI
685                psit_hG[h_rh] = psit_rhG
686
687        return (K, k_c, eps_h, f_h, Ph, psit_hG)
688
689    @timer('Unfolding arrays')
690    def unfold_arrays(self, eps_h, f_h, Ph, ut_hR, h_myt, myt_myt):
691        """Create transition data arrays from the composite h = (n, s) index"""
692
693        wfs = self.calc.wfs
694        # Allocate data arrays for the k-point
695        mynt = self.mynt
696        eps_myt = np.empty(mynt)
697        f_myt = np.empty(mynt)
698        P = wfs.kpt_u[0].projections.new(nbands=mynt, bcomm=None)
699        ut_mytR = wfs.gd.empty(self.mynt, wfs.dtype)
700
701        # Unfold k-point data
702        eps_myt[myt_myt] = eps_h[h_myt]
703        f_myt[myt_myt] = f_h[h_myt]
704        P.array[myt_myt] = Ph.array[h_myt]
705        ut_mytR[myt_myt] = ut_hR[h_myt]
706
707        return eps_myt, f_myt, P, ut_mytR
708
709    @timer('Extracting data from the ground state calculator object')
710    def serial_extract_kptdata(self, k_pc, n_t, s_t):
711        # All processes can access all data. Each process extracts it own data.
712        wfs = self.calc.wfs
713
714        # Do data extraction for the processes, which have data to extract
715        if self.kptblockcomm.rank in range(len(k_pc)):
716            # Find k-point indeces
717            k_c = k_pc[self.kptblockcomm.rank]
718            K = self.find_kpoint(k_c)
719            ik = wfs.kd.bz2ibz_k[K]
720            # Construct symmetry operators
721            (_, T, a_a, U_aii, shift_c,
722             time_reversal) = self.construct_symmetry_operators(K, k_c=k_c)
723
724            (myu_eu, myn_eurn, nh, h_eurn, h_myt,
725             myt_myt) = self.get_serial_extraction_protocol(ik, n_t, s_t)
726
727            # Allocate transfer arrays
728            eps_h = np.empty(nh)
729            f_h = np.empty(nh)
730            Ph = wfs.kpt_u[0].projections.new(nbands=nh, bcomm=None)
731            ut_hR = wfs.gd.empty(nh, wfs.dtype)
732
733            # Extract data from the ground state
734            for myu, myn_rn, h_rn in zip(myu_eu, myn_eurn, h_eurn):
735                kpt = wfs.kpt_u[myu]
736                with self.timer('Extracting eps, f and P_I from wfs'):
737                    eps_h[h_rn] = kpt.eps_n[myn_rn]
738                    f_h[h_rn] = kpt.f_n[myn_rn] / kpt.weight
739                    Ph.array[h_rn] = kpt.projections.array[myn_rn]
740
741                with self.timer('Extracting, fourier transforming and '
742                                'symmetrizing wave function'):
743                    for myn, h in zip(myn_rn, h_rn):
744                        ut_hR[h] = T(wfs.pd.ifft(kpt.psit_nG[myn], kpt.q))
745
746            # Symmetrize projections
747            with self.timer('Apply symmetry operations'):
748                P_ahi = []
749                for a1, U_ii in zip(a_a, U_aii):
750                    P_hi = np.ascontiguousarray(Ph[a1])
751                    # Apply symmetry operations. This will map a1 onto a2
752                    np.dot(P_hi, U_ii, out=P_hi)
753                    if time_reversal:
754                        np.conj(P_hi, out=P_hi)
755                    P_ahi.append(P_hi)
756
757                # Store symmetrized projectors
758                for a2, P_hi in enumerate(P_ahi):
759                    I1, I2 = Ph.map[a2]
760                    Ph.array[..., I1:I2] = P_hi
761
762            (eps_myt, f_myt,
763             P, ut_mytR) = self.unfold_arrays(eps_h, f_h, Ph, ut_hR,
764                                              h_myt, myt_myt)
765
766            return (K, eps_myt, f_myt, ut_mytR, P, shift_c)
767
768    @timer('Create data extraction protocol')
769    def get_serial_extraction_protocol(self, ik, n_t, s_t):
770        """Figure out how to extract data efficiently.
771        For the serial communicator, all processes can access all data,
772        and resultantly, there is no need to send any data.
773        """
774        wfs = self.calc.wfs
775
776        # Only extract the transitions handled by the process itself
777        myt_myt = np.arange(self.tb - self.ta)
778        t_myt = range(self.ta, self.tb)
779        n_myt = n_t[t_myt]
780        s_myt = s_t[t_myt]
781
782        # In the ground state, kpts are indexed by u=(s, k)
783        myu_eu = []
784        myn_eurn = []
785        nh = 0
786        h_eurn = []
787        h_myt = np.empty(self.tb - self.ta, dtype=int)
788        for s in set(s_myt):
789            thiss_myt = s_myt == s
790            n_ct = n_myt[thiss_myt]
791
792            # Find unique composite h = (n, u) indeces
793            n_rn = np.unique(n_ct)
794            nrn = len(n_rn)
795            h_eurn.append(np.arange(nrn) + nh)
796            nh += nrn
797
798            # Find mapping between h and the transition index
799            for n, h in zip(n_rn, h_eurn[-1]):
800                thisn_myt = n_myt == n
801                thish_myt = np.logical_and(thisn_myt, thiss_myt)
802                h_myt[thish_myt] = h
803
804            # Find out where data is in wfs
805            u = ik * wfs.nspins + s
806            # The process has access to all data
807            myu = u
808            myn_rn = n_rn
809
810            myu_eu.append(myu)
811            myn_eurn.append(myn_rn)
812
813        return myu_eu, myn_eurn, nh, h_eurn, h_myt, myt_myt
814
815    @timer('Identifying k-points')
816    def find_kpoint(self, k_c):
817        return self.kdtree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1]
818
819    @timer('Apply symmetry operations')
820    def transform_and_symmetrize(self, K, k_c, Ph, psit_hG):
821        """Get wave function on a real space grid and symmetrize it
822        along with the corresponding PAW projections."""
823        (_, T, a_a, U_aii, shift_c,
824         time_reversal) = self.construct_symmetry_operators(K, k_c=k_c)
825
826        # Symmetrize wave functions
827        wfs = self.calc.wfs
828        ik = wfs.kd.bz2ibz_k[K]
829        ut_hR = wfs.gd.empty(len(psit_hG), wfs.dtype)
830        with self.timer('Fourier transform and symmetrize wave functions'):
831            for h, psit_G in enumerate(psit_hG):
832                ut_hR[h] = T(self.pd0.ifft(psit_G, ik))
833
834        # Symmetrize projections
835        P_ahi = []
836        for a1, U_ii in zip(a_a, U_aii):
837            P_hi = np.ascontiguousarray(Ph[a1])
838            # Apply symmetry operations. This will map a1 onto a2
839            np.dot(P_hi, U_ii, out=P_hi)
840            if time_reversal:
841                np.conj(P_hi, out=P_hi)
842            P_ahi.append(P_hi)
843
844        # Store symmetrized projectors
845        for a2, P_hi in enumerate(P_ahi):
846            I1, I2 = Ph.map[a2]
847            Ph.array[..., I1:I2] = P_hi
848
849        return Ph, ut_hR, shift_c
850
851    @timer('Construct symmetry operators')
852    def construct_symmetry_operators(self, K, k_c=None):
853        """Construct symmetry operators for wave function and PAW projections.
854
855        We want to transform a k-point in the irreducible part of the BZ to
856        the corresponding k-point with index K.
857
858        Returns U_cc, T, a_a, U_aii, shift_c and time_reversal, where:
859
860        * U_cc is a rotation matrix.
861        * T() is a function that transforms the periodic part of the wave
862          function.
863        * a_a is a list of symmetry related atom indices
864        * U_aii is a list of rotation matrices for the PAW projections
865        * shift_c is three integers: see code below.
866        * time_reversal is a flag - if True, projections should be complex
867          conjugated.
868
869        See the extract_orbitals() method for how to use these tuples.
870        """
871
872        wfs = self.calc.wfs
873        kd = wfs.kd
874
875        s = kd.sym_k[K]
876        U_cc = kd.symmetry.op_scc[s]
877        time_reversal = kd.time_reversal_k[K]
878        ik = kd.bz2ibz_k[K]
879        if k_c is None:
880            k_c = kd.bzk_kc[K]
881        ik_c = kd.ibzk_kc[ik]
882
883        sign = 1 - 2 * time_reversal
884        shift_c = np.dot(U_cc, ik_c) - k_c * sign
885
886        try:
887            assert np.allclose(shift_c.round(), shift_c)
888        except AssertionError:
889            print('shift_c ' + str(shift_c), file=self.fd)
890            print('k_c ' + str(k_c), file=self.fd)
891            print('kd.bzk_kc[K] ' + str(kd.bzk_kc[K]), file=self.fd)
892            print('ik_c ' + str(ik_c), file=self.fd)
893            print('U_cc ' + str(U_cc), file=self.fd)
894            print('sign ' + str(sign), file=self.fd)
895            raise AssertionError
896
897        shift_c = shift_c.round().astype(int)
898
899        if (U_cc == np.eye(3)).all():
900            def T(f_R):
901                return f_R
902        else:
903            N_c = self.calc.wfs.gd.N_c
904            i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1)))
905            i = np.ravel_multi_index(i_cr, N_c, 'wrap')
906
907            def T(f_R):
908                return f_R.ravel()[i].reshape(N_c)
909
910        if time_reversal:
911            T0 = T
912
913            def T(f_R):
914                return T0(f_R).conj()
915
916            shift_c *= -1
917
918        a_a = []
919        U_aii = []
920        for a, id in enumerate(self.calc.wfs.setups.id_a):
921            b = kd.symmetry.a_sa[s, a]
922            S_c = np.dot(self.calc.spos_ac[a], U_cc) - self.calc.spos_ac[b]
923            x = np.exp(2j * np.pi * np.dot(ik_c, S_c))
924            U_ii = wfs.setups[a].R_sii[s].T * x
925            a_a.append(b)
926            U_aii.append(U_ii)
927
928        shift0_c = (kd.bzk_kc[K] - k_c).round().astype(int)
929        shift_c += -shift0_c
930
931        return U_cc, T, a_a, U_aii, shift_c, time_reversal
932
933
def get_calc(gs, fd=None, timer=None):
    """Get ground state calculation object.

    Parameters
    ----------
    gs : GPAW instance or path to a file
        A GPAW instance is returned as it is. Anything else is taken to be
        the path of a file, from which a GPAW object is created using the
        serial communicator.
    fd : file-like object, optional
        If given, a message naming the file being read is printed to it.
    timer : callable, optional
        Called as timer('Read ground state') and must return a context
        manager. If None, the reading is not timed.

    Raises
    ------
    AssertionError
        If gs is not a GPAW instance and does not point to a file.
    """
    if isinstance(gs, GPAW):
        return gs

    if timer is None:
        # Imported locally so that no file-level import is needed
        from contextlib import contextmanager

        # No-op stand-in. It has to return an actual context manager,
        # since the result is used in a with statement below.
        @contextmanager
        def timer(*unused):
            yield

    with timer('Read ground state'):
        # Explicit check, so that it is not stripped under python -O
        if not Path(gs).is_file():
            raise AssertionError('Ground state file not found: ' + str(gs))
        if fd is not None:
            print('Reading ground state calculation:\n  %s' % gs,
                  file=fd)
        with disable_dry_run():
            return GPAW(gs, txt=None, communicator=mpi.serial_comm)
954
955
class PairMatrixElement:
    """Base class for matrix elements of transitions in Kohn-Sham
    linear response functions.

    Subclasses define __call__() and may override initialize()."""

    def __init__(self, kspair):
        """
        Parameters
        ----------
        kspair : object
            Provides the calc, fd, timer and transitionblockscomm
            attributes, which are all stored on this instance.
        """
        for name in ('calc', 'fd', 'timer', 'transitionblockscomm'):
            setattr(self, name, getattr(kspair, name))

    def initialize(self, *args, **kwargs):
        """Hook for work that can be done ahead in time of integration,
        e.g. PAW corrections. Does nothing by default."""
        return None

    def __call__(self, kskptpair, *args, **kwargs):
        """Calculate the matrix element for all transitions in kskptpair.

        Must be defined by the subclass; here it always raises
        NotImplementedError."""
        raise NotImplementedError('Define specific matrix element')
978
979
class PlaneWavePairDensity(PairMatrixElement):
    """Class for calculating pair densities:

    n_T(q+G) = <s'n'k'| e^(i (q + G) r) |snk>

    in the plane wave mode"""
    def __init__(self, kspair):
        super().__init__(kspair)

        # PAW corrections are kept between calls and reused for as long
        # as q_c stays the same
        self.Q_aGii = None
        self.currentq_c = None

    def initialize(self, pd):
        """Initialize PAW corrections ahead in time of integration."""
        self.initialize_paw_corrections(pd)

    @timer('Initialize PAW corrections')
    def initialize_paw_corrections(self, pd):
        """Make sure the stored PAW corrections belong to the q-point of pd.

        They are recalculated if none are stored, or if pd.kd.bzk_kc[0]
        differs from the q-point they were last calculated for."""
        q_c = pd.kd.bzk_kc[0]
        up_to_date = (self.Q_aGii is not None
                      and np.allclose(q_c - self.currentq_c, 0.))
        if up_to_date:
            return

        self.Q_aGii = self._initialize_paw_corrections(pd)
        self.currentq_c = q_c

    def _initialize_paw_corrections(self, pd):
        """Return a list with one array of PAW corrections per atom.

        The integrals from two_phi_planewave_integrals() are evaluated
        once per entry in wfs.setups.setups and then multiplied by an
        atom dependent phase factor."""
        setups = self.calc.wfs.setups
        G_Gv = pd.get_reciprocal_vectors()

        # Atomic positions: scaled positions times the cell of pd.gd
        pos_av = np.dot(self.calc.spos_ac, pd.gd.cell_cv)

        # Collect integrals for all species:
        Q_xGii = {}
        for setup_id, atomdata in setups.setups.items():
            Q_Gii = two_phi_planewave_integrals(G_Gv, atomdata)
            ni = atomdata.ni
            Q_Gii.shape = (-1, ni, ni)
            Q_xGii[setup_id] = Q_Gii

        # Attach the phase factor exp(-i G.R_a) of each atom
        Q_aGii = []
        for a, _ in enumerate(setups):
            Q_Gii = Q_xGii[setups.id_a[a]]
            phase_G = np.exp(-1j * np.dot(G_Gv, pos_av[a]))
            Q_aGii.append(phase_G[:, np.newaxis, np.newaxis] * Q_Gii)

        return Q_aGii

    @timer('Calculate pair density')
    def __call__(self, kskptpair, pd):
        """Calculate the pair densities for all transitions:
        n_t(q+G) = <s'n'k+q| e^(i (q + G) r) |snk>
                 = <snk| e^(-i (q + G) r) |s'n'k+q>

        The result is not returned, but attached to kskptpair through
        kskptpair.attach('n_mytG', 'n_tG', n_mytG). The array is allocated
        with pd.empty(mynt), and only its first tb - ta rows are filled in.
        """
        Q_aGii = self.get_paw_projectors(pd)
        Q_G = self.get_fft_indices(kskptpair, pd)
        mynt, nt, ta, tb = kskptpair.transition_distribution()
        nblock = tb - ta
        kpt1, kpt2 = kskptpair.kpt1, kskptpair.kpt2

        n_mytG = pd.empty(mynt)

        # Smooth part: product of the periodic parts on the real space
        # grid, Fourier transformed one transition at a time
        with self.timer('Calculate smooth part'):
            n_mytR = kpt1.ut_tR.conj() * kpt2.ut_tR
            for myt in range(nblock):
                n_mytG[myt] = pd.fft(n_mytR[myt], 0, Q_G) * pd.gd.dv

        # PAW corrections, one atom at a time
        with self.timer('PAW corrections'):
            atoms = zip(Q_aGii, kpt1.projections.items(),
                        kpt2.projections.items())
            for Q_Gii, (_, P1_myti), (_, P2_myti) in atoms:
                P1cc_myti = P1_myti[:nblock].conj()
                # Contract the first projector index of Q_Gii with P1*
                C1_Gimyt = np.tensordot(Q_Gii, P1cc_myti, axes=(1, 1))
                P2_imyt = P2_myti.T[:, :nblock]
                # Contract the remaining projector index with P2
                correction_Gmyt = np.sum(
                    C1_Gimyt * P2_imyt[np.newaxis, :, :], axis=1)
                n_mytG[:nblock] += correction_Gmyt.T

        # Attach the calculated pair density to the KohnShamKPointPair object
        kskptpair.attach('n_mytG', 'n_tG', n_mytG)

    def get_paw_projectors(self, pd):
        """Return the PAW corrections for pd, initializing them first
        if that has not been done already."""
        self.initialize_paw_corrections(pd)
        return self.Q_aGii

    @timer('Get G-vector indices')
    def get_fft_indices(self, kskptpair, pd):
        """Get indices for G-vectors inside cutoff sphere.

        The indices pd.Q_qG[0] are returned as they are, unless the two
        k-points of kskptpair give rise to a nonzero integer shift. In that
        case the indices are displaced by the shift, wrapping around at
        the edges of the grid pd.gd.N_c."""
        kpt1 = kskptpair.kpt1
        kpt2 = kskptpair.kpt2
        bzk_kc = self.calc.wfs.kd.bzk_kc
        q_c = pd.kd.bzk_kc[0]

        N_G = pd.Q_qG[0]

        # Shifts carried by the two k-points, plus the number of
        # reciprocal lattice vectors by which k1 + q differs from k2
        shift_c = kpt1.shift_c - kpt2.shift_c
        shift_c += (q_c - bzk_kc[kpt2.K]
                    + bzk_kc[kpt1.K]).round().astype(int)
        if not shift_c.any():
            return N_G

        grid_N_c = pd.gd.N_c
        n_cG = np.unravel_index(N_G, grid_N_c)
        shifted_cG = [n_G + shift for n_G, shift in zip(n_cG, shift_c)]
        return np.ravel_multi_index(shifted_cG, grid_N_c, 'wrap')
1089