1
2import functools
3import numbers
4import sys
5from math import pi
6
7import numpy as np
8from scipy.spatial import Delaunay, cKDTree
9
10from ase.units import Ha
11from gpaw.utilities import convert_string_to_fd
12from ase.utils.timing import timer, Timer
13
14import gpaw.mpi as mpi
15from gpaw import GPAW, disable_dry_run
16from gpaw.fd_operators import Gradient
17from gpaw.kpt_descriptor import KPointDescriptor
18from gpaw.response.math_func import (two_phi_planewave_integrals,
19                                     two_phi_nabla_planewave_integrals)
20from gpaw.utilities.blas import gemm
21from gpaw.utilities.progressbar import ProgressBar
22from gpaw.wavefunctions.pw import PWLFC
23from gpaw.bztools import get_reduced_bz, unique_rows
24
25
class KPoint:
    """Container for the data of one spin channel at one BZ k-point.

    Attributes:

    s, K
        Spin index and BZ k-point index.
    n1, n2
        Band window: bands n1 <= n < n2.
    blocksize, na, nb
        Block of the band window held here: bands na <= n < nb.
    ut_nR
        Periodic part of the wave functions in real space.
    eps_n, f_n
        Eigenvalues and occupation numbers.
    P_ani
        PAW projections.
    shift_c
        Long story - see the PairDensity.construct_symmetry_operators()
        method.
    """
    def __init__(self, s, K, n1, n2, blocksize, na, nb,
                 ut_nR, eps_n, f_n, P_ani, shift_c):
        # Indices of the k-point and its band window
        self.s, self.K = s, K
        self.n1, self.n2 = n1, n2
        # Block decomposition of the band window
        self.blocksize, self.na, self.nb = blocksize, na, nb
        # Per-band data
        self.ut_nR = ut_nR
        self.eps_n = eps_n
        self.f_n = f_n
        self.P_ani = P_ani
        self.shift_c = shift_c
42
43
class KPointPair:
    """This class defines the kpoint-pair container object.

    Used for calculating pair quantities: it holds two kpoints and an
    associated set of Fourier components (plane-wave indices)."""
    def __init__(self, kpt1, kpt2, Q_G):
        self.kpt1 = kpt1
        self.kpt2 = kpt2
        self.Q_G = Q_G

    def get_k1(self):
        """Return the first KPoint object."""
        return self.kpt1

    def get_k2(self):
        """Return the second KPoint object."""
        return self.kpt2

    def get_planewave_indices(self):
        """Return the plane-wave indices associated with this pair."""
        return self.Q_G

    def _pair_difference(self, attr, n_n, m_m):
        """Return kpt1.<attr>[n] - kpt2.<attr>[m] as a 2D (n, m) table.

        The band indices are absolute; each k-point's first band ``n1``
        is subtracted before indexing its per-band array.
        """
        left_n = getattr(self.kpt1, attr)[np.array(n_n) - self.kpt1.n1]
        right_m = getattr(self.kpt2, attr)[np.array(m_m) - self.kpt2.n1]
        return left_n[:, np.newaxis] - right_m

    def get_transition_energies(self, n_n, m_m):
        """Return the energy difference for specified bands."""
        return self._pair_difference('eps_n', n_n, m_m)

    def get_occupation_differences(self, n_n, m_m):
        """Get difference in occupation factor between specified bands."""
        return self._pair_difference('f_n', n_n, m_m)
85
86
class PWSymmetryAnalyzer:
    """Class for handling planewave symmetries.

    Symmetry indices ``s`` run over ``2 * nU`` operations, where ``nU`` is
    the number of operations in ``kd.symmetry.op_scc``: indices ``s < nU``
    are direct operations and indices ``s >= nU`` are the same operations
    combined with time reversal (see ``get_symmetry_operator``).
    """
    def __init__(self, kd, pd, txt=sys.stdout,
                 disable_point_group=False,
                 disable_non_symmorphic=True,
                 disable_time_reversal=False,
                 timer=None):
        """Creates a PWSymmetryAnalyzer object.

        Determines which of the symmetries of the atomic structure
        are compatible with the reciprocal lattice. Contains the
        necessary functions for mapping quantities between kpoints,
        and/or symmetrizing arrays.

        kd: KPointDescriptor
            The kpoint descriptor containing the
            information about symmetries and kpoints.
        pd: PWDescriptor
            Plane wave descriptor that contains the reciprocal
            lattice.
        txt: file-like
            Output stream used for informational printing.
        disable_point_group: bool
            Switch for disabling point group symmetries.
        disable_non_symmorphic: bool
            Switch for disabling non symmorphic symmetries. Only True is
            supported.
        disable_time_reversal: bool
            Switch for disabling time reversal.
        timer: Timer or None
            Timer used for profiling; a new Timer is created if None.

        Raises NotImplementedError if disable_non_symmorphic is False.
        """
        self.pd = pd
        self.kd = kd
        self.fd = txt

        # Caveats: fractional translations are not used by the mapping
        # code below, so non-symmorphic operations must stay disabled.
        # An explicit raise (not assert) keeps the check under ``-O``.
        if not disable_non_symmorphic:
            msg = 'You are not allowed to use non symmorphic syms, sorry. '
            print(msg, file=self.fd)
            raise NotImplementedError(msg)

        # Settings
        self.disable_point_group = disable_point_group
        self.disable_time_reversal = disable_time_reversal
        self.disable_non_symmorphic = disable_non_symmorphic
        if (kd.symmetry.has_inversion or not kd.symmetry.time_reversal) and \
           not self.disable_time_reversal:
            print('\nThe ground calculation does not support time-reversal ' +
                  'symmetry possibly because it has an inversion center ' +
                  'or that it has been manually deactivated. \n', file=self.fd)
            self.disable_time_reversal = True

        self.disable_symmetries = (self.disable_point_group and
                                   self.disable_time_reversal and
                                   self.disable_non_symmorphic)

        # Number of symmetries: nU direct operations, each of which may
        # also be combined with time reversal.
        self.nU = len(kd.symmetry.op_scc)
        self.nsym = 2 * self.nU
        self.use_time_reversal = not self.disable_time_reversal

        # Which timer to use
        self.timer = timer or Timer()

        # Nearest-neighbour lookup of BZ k-points wrapped into [0, 1)
        self.KDTree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1))

        # Initialize
        self.initialize()

    @timer('Initialize')
    def initialize(self):
        """Analyze symmetries and k-points and build the G-vector maps.

        Builds ``self.infostring`` and prints it, followed by a table of
        the allowed symmetry operations, to ``self.fd``.
        """
        self.infostring = ''
        if self.disable_point_group:
            self.infostring += 'Point group not included. '
        else:
            self.infostring += 'Point group included. '

        if self.disable_time_reversal:
            self.infostring += 'Time reversal not included. '
        else:
            self.infostring += 'Time reversal included. '

        if self.disable_non_symmorphic:
            self.infostring += 'Disabled non symmorphic symmetries. '
        else:
            self.infostring += 'Non symmorphic symmetries included. '

        if self.disable_symmetries:
            self.infostring += 'All symmetries have been disabled. '

        # Do the work; analyze_symmetries sets s_s and shift_sc, which the
        # two following steps depend on.
        self.analyze_symmetries()
        self.analyze_kpoints()
        self.initialize_G_maps()

        # Print info
        print(self.infostring, file=self.fd)
        self.print_symmetries()

    def find_kpoint(self, k_c):
        """Return the index of the BZ k-point nearest to k_c (modulo 1)."""
        return self.KDTree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1]

    def print_symmetries(self):
        """Handsome print function for symmetry operations.

        Prints sign * U_cc of every allowed symmetry to ``self.fd``, with
        nx operations side by side per row of the table.
        """
        p = functools.partial(print, file=self.fd)

        p()
        nx = 6 if self.disable_non_symmorphic else 3
        ns = len(self.s_s)
        for y in range((ns + nx - 1) // nx):
            for c in range(3):
                for x in range(nx):
                    s = x + y * nx
                    if s == ns:
                        break
                    op_cc, sign = self.get_symmetry_operator(self.s_s[s])[:2]
                    op_c = sign * op_cc[c]
                    p('  (%2d %2d %2d)' % tuple(op_c), end='')
                p()
            p()

    @timer('Analyze')
    def analyze_kpoints(self):
        """Calculate the reduction in the number of kpoints.

        Appends the number of groups of symmetry-equivalent BZ k-points
        and the resulting percentage reduction to the info string.
        """
        K_gK = self.group_kpoints()
        ng = len(K_gK)
        self.infostring += '{0} groups of equivalent kpoints. '.format(ng)
        percent = (1. - (ng + 0.) / self.kd.nbzkpts) * 100
        self.infostring += '{0}% reduction. '.format(percent)

    @timer('Analyze symmetries.')
    def analyze_symmetries(self):
        r"""Determine allowed symmetries.

        An direct symmetry U must fulfill::

          U \mathbf{q} = q + \Delta

        Under time-reversal (indirect) it must fulfill::

          -U \mathbf{q} = q + \Delta

        where :math:`\Delta` is a reciprocal lattice vector.

        The condition is checked with a small absolute tolerance because
        U q is evaluated in floating point; e.g. for q_c = 1/3 the exact
        comparison ``-2/3 == 1/3 - 1`` fails by one ulp.

        Sets ``self.s_s`` (indices of allowed symmetries after applying
        the disable_* filters) and ``self.shift_sc`` (Delta for every
        symmetry index).
        """
        tol = 1e-8
        q_c = self.pd.kd.bzk_kc[0]
        U_scc = self.kd.symmetry.op_scc
        nU = self.nU

        shift_sc = np.zeros((self.nsym, 3), int)
        conserveq_s = np.zeros(self.nsym, bool)

        newq_sc = np.dot(U_scc, q_c)

        # Direct symmetries
        dshift_sc = (newq_sc - q_c[np.newaxis]).round().astype(int)
        conserveq_s[:nU] = (np.abs(newq_sc - q_c[np.newaxis] - dshift_sc)
                            < tol).all(1)
        shift_sc[:nU] = dshift_sc

        # Time reversal
        trshift_sc = (-newq_sc - q_c[np.newaxis]).round().astype(int)
        conserveq_s[nU:] = (np.abs(-newq_sc - q_c[np.newaxis] - trshift_sc)
                            < tol).all(1)
        shift_sc[nU:] = trshift_sc

        # The indices of the allowed symmetries
        s_s = conserveq_s.nonzero()[0]

        # Filter out disabled symmetries
        if self.disable_point_group:
            s_s = list(filter(self.is_not_point_group, s_s))

        if self.disable_time_reversal:
            s_s = list(filter(self.is_not_time_reversal, s_s))

        if self.disable_non_symmorphic:
            s_s = list(filter(self.is_not_non_symmorphic, s_s))

        self.infostring += 'Found {} allowed symmetries. '.format(len(s_s))
        self.s_s = s_s
        self.shift_sc = shift_sc

    def is_not_point_group(self, s):
        """Return True if symmetry s uses the identity rotation."""
        U_scc = self.kd.symmetry.op_scc
        return (U_scc[s % self.nU] == np.eye(3)).all()

    def is_not_time_reversal(self, s):
        """Return True if symmetry s is a direct operation (s < nU)."""
        return not bool(s // self.nU)

    def is_not_non_symmorphic(self, s):
        """Return True if symmetry s has no fractional translation."""
        ft_sc = self.kd.symmetry.ft_sc
        return not bool(ft_sc[s % self.nU].any())

    def how_many_symmetries(self):
        """Return number of symmetries."""
        return len(self.s_s)

    @timer('Group kpoints')
    def group_kpoints(self, K_k=None):
        """Group kpoints according to the reduced symmetries.

        K_k: array of int or None
            BZ k-point indices to group; all BZ k-points if None.

        Returns a list with one sorted array of BZ k-point indices per
        group of k-points related by the allowed symmetries ``self.s_s``.
        """
        if K_k is None:
            K_k = np.arange(self.kd.nbzkpts)
        s_s = self.s_s
        bz2bz_ks = self.kd.bz2bz_ks
        nk = len(bz2bz_ks)
        sbz2sbz_ks = bz2bz_ks[K_k][:, s_s]  # Reduced number of symmetries
        # Avoid -1 (see documentation in gpaw.symmetry); nk sorts last
        sbz2sbz_ks[sbz2sbz_ks == -1] = nk

        # k-points whose smallest image agrees belong to the same group
        smallestk_k = np.sort(sbz2sbz_ks)[:, 0]
        k2g_g = np.unique(smallestk_k, return_index=True)[1]

        K_gs = sbz2sbz_ks[k2g_g]
        K_gk = [np.unique(K_s[K_s != nk]) for K_s in K_gs]

        return K_gk

    def _get_little_group_scc(self):
        """Return sign * U_cc for every allowed symmetry (little group of q)."""
        U_scc = []
        for s in self.s_s:
            U_cc, sign, _, _, _ = self.get_symmetry_operator(s)
            U_scc.append(sign * U_cc)
        return np.array(U_scc)

    def get_BZ(self):
        """Return the k-points ``bzk_kc`` from get_reduced_bz.

        get_reduced_bz is called with the cell and the little group of q.
        """
        U_scc = self._get_little_group_scc()

        # Determine the irreducible BZ
        bzk_kc, ibzk_kc = get_reduced_bz(self.pd.gd.cell_cv,
                                         U_scc,
                                         False)

        return bzk_kc

    def get_reduced_kd(self, pbc_c=None):
        """Return a KPointDescriptor with the irreducible k-points.

        The BZ k-points of ``self.kd``, translated by every combination of
        -1, 0, +1 along each axis, are kept when they fall inside the
        Delaunay tessellation of the irreducible k-points returned by
        get_reduced_bz. Duplicate rows are removed with unique_rows.

        pbc_c: array of bool or None
            Periodic directions passed to get_reduced_bz; all True if None.
        """
        if pbc_c is None:
            pbc_c = np.ones(3, bool)
        U_scc = self._get_little_group_scc()

        # Determine the irreducible BZ
        bzk_kc, ibzk_kc = get_reduced_bz(self.pd.gd.cell_cv,
                                         U_scc,
                                         False,
                                         pbc_c=pbc_c)

        n = 3
        N_xc = np.indices((n, n, n)).reshape((3, n**3)).T - n // 2

        # Find the irreducible kpoints
        tess = Delaunay(ibzk_kc)
        ik_kc = []
        for N_c in N_xc:
            k_kc = self.kd.bzk_kc + N_c
            k_kc = k_kc[tess.find_simplex(k_kc) >= 0]
            if not len(ik_kc) and len(k_kc):
                ik_kc = unique_rows(k_kc)
            elif len(k_kc):
                ik_kc = unique_rows(np.append(k_kc, ik_kc, axis=0))

        return KPointDescriptor(ik_kc)

    def unfold_kpoints(self, points_pv, tol=1e-8, mod=None):
        """Return the images of points_pv under the little group of q.

        The Cartesian points are converted to scaled coordinates,
        transformed with every sign * U_cc of the allowed symmetries,
        reduced with unique_rows(tol=tol, mod=mod) and converted back to
        Cartesian coordinates.
        """
        points_pc = np.dot(points_pv, self.pd.gd.cell_cv.T) / (2 * np.pi)
        U_scc = self._get_little_group_scc()

        points = np.concatenate(np.dot(points_pc, U_scc.transpose(0, 2, 1)))
        points = unique_rows(points, tol=tol, mod=mod)
        points = np.dot(points, self.pd.gd.icell_cv) * (2 * np.pi)
        return points

    def get_kpoint_weight(self, k_c):
        """Return the weight of the group of equivalent k-points of k_c.

        k_c is matched to its nearest BZ k-point K. The star of K's
        irreducible k-point is grouped with the allowed symmetries. The
        weight is the size of K's group, or the sum of
        ``kd.refine_info.weight_k`` over the group when refine_info is set.
        Returns None if K is in no group.
        """
        K = self.find_kpoint(k_c)
        iK = self.kd.bz2ibz_k[K]
        Kstar_k = self.unfold_ibz_kpoint(iK)

        for Kgroup_k in self.group_kpoints(Kstar_k):
            if K in Kgroup_k:
                if self.kd.refine_info is not None:
                    return sum(self.kd.refine_info.weight_k[Kgroup_k])
                return len(Kgroup_k)
        return None

    def get_kpoint_mapping(self, K1, K2):
        """Get index of symmetry for mapping between K1 and K2.

        Raises IndexError (after printing a message) if no allowed
        symmetry maps K1 to K2.
        """
        s_s = self.s_s
        bz2bz_ks = self.kd.bz2bz_ks
        bzk2rbz_s = bz2bz_ks[K1][s_s]
        try:
            s = np.argwhere(bzk2rbz_s == K2)[0][0]
        except IndexError:
            print('K = {0} cannot be mapped into K = {1}'.format(K1, K2),
                  file=self.fd)
            raise
        return s_s[s]

    def get_shift(self, K1, K2, U_cc, sign):
        """Get the integer shift U_cc k1 - sign * k2 for mapping K1 to K2."""
        kd = self.kd
        k1_c = kd.bzk_kc[K1]
        k2_c = kd.bzk_kc[K2]

        shift_c = np.dot(U_cc, k1_c) - k2_c * sign
        assert np.allclose(shift_c.round(), shift_c)
        shift_c = shift_c.round().astype(int)

        return shift_c

    @timer('map_G')
    def map_G(self, K1, K2, a_MG):
        """Map a function of G from K1 to K2.

        The last axis of a_MG is permuted with the G-vector map of the
        symmetry relating K1 and K2 and complex conjugated for
        time-reversal symmetries.
        """
        if len(a_MG) == 0:
            return []

        if K1 == K2:
            return a_MG

        s = self.get_kpoint_mapping(K1, K2)
        G_G = self.G_sG[s][0]
        TR = self.get_symmetry_operator(s)[2]

        return TR(a_MG[..., G_G])

    def symmetrize_wGG(self, A_wGG):
        """Symmetrize an array in GG' in place.

        Each A_GG is replaced by its average over the allowed symmetries
        of the G-permuted array, transposed for time-reversal symmetries.
        """
        for A_GG in A_wGG:
            tmp_GG = np.zeros_like(A_GG)

            for s in self.s_s:
                G_G, sign, _ = self.G_sG[s]
                if sign == 1:
                    tmp_GG += A_GG[G_G, :][:, G_G]
                if sign == -1:
                    tmp_GG += A_GG[G_G, :][:, G_G].T

            A_GG[:] = tmp_GG / self.how_many_symmetries()

    def symmetrize_wxx(self, A_wxx, optical_limit=False):
        """Symmetrize an array in xx' in place.

        With optical_limit the G-map is shifted by 2 and [0, 1] is
        prepended, so the first three x-indices are treated as Cartesian
        components and rotated with M_vv.
        """
        tmp_wxx = np.zeros_like(A_wxx)

        A_cv = self.pd.gd.cell_cv
        iA_cv = self.pd.gd.icell_cv

        if self.use_time_reversal:
            AT_wxx = np.transpose(A_wxx, (0, 2, 1))

        for s in self.s_s:
            G_G, sign, shift_c = self.G_sG[s]
            if optical_limit:
                G_G = np.array(G_G) + 2
                G_G = np.insert(G_G, 0, [0, 1])
                U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s)
                # Rotation expressed in Cartesian coordinates
                M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv)

            if sign == 1:
                tmp = A_wxx[:, G_G, :][:, :, G_G]
                if optical_limit:
                    tmp[:, 0:3, :] = np.transpose(np.dot(M_vv.T,
                                                         tmp[:, 0:3, :]),
                                                  (1, 0, 2))
                    tmp[:, :, 0:3] = np.dot(tmp[..., 0:3], M_vv)
                tmp_wxx += tmp
            elif sign == -1:
                tmp = AT_wxx[:, G_G, :][:, :, G_G]
                if optical_limit:
                    tmp[:, 0:3, :] = np.transpose(np.dot(M_vv.T,
                                                         tmp[:, 0:3, :]),
                                                  (1, 0, 2)) * sign
                    tmp[:, :, 0:3] = np.dot(tmp[:, :, 0:3], M_vv) * sign
                tmp_wxx += tmp

        # Inplace overwriting
        A_wxx[:] = tmp_wxx / self.how_many_symmetries()

    def symmetrize_wxvG(self, A_wxvG):
        """Symmetrize chi0_wxvG in place.

        The v axis is rotated with the Cartesian matrix M_vv (times sign)
        and the G axis is permuted with the G-vector map.
        """
        A_cv = self.pd.gd.cell_cv
        iA_cv = self.pd.gd.icell_cv

        if self.use_time_reversal:
            # ::-1 corresponds to transpose in wing indices
            AT_wxvG = A_wxvG[:, ::-1]

        tmp_wxvG = np.zeros_like(A_wxvG)
        for s in self.s_s:
            G_G, sign, shift_c = self.G_sG[s]
            U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s)
            M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv)
            if sign == 1:
                tmp = sign * np.dot(M_vv.T, A_wxvG[..., G_G])
            elif sign == -1:
                tmp = sign * np.dot(M_vv.T, AT_wxvG[..., G_G])
            tmp_wxvG += np.transpose(tmp, (1, 2, 0, 3))

        # Overwrite the input
        A_wxvG[:] = tmp_wxvG / self.how_many_symmetries()

    def symmetrize_wvv(self, A_wvv):
        """Symmetrize chi_wvv in place by averaging M^T A M over symmetries."""
        A_cv = self.pd.gd.cell_cv
        iA_cv = self.pd.gd.icell_cv
        tmp_wvv = np.zeros_like(A_wvv)
        if self.use_time_reversal:
            AT_wvv = np.transpose(A_wvv, (0, 2, 1))

        for s in self.s_s:
            G_G, sign, shift_c = self.G_sG[s]
            U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s)
            M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv)
            if sign == 1:
                tmp = np.dot(np.dot(M_vv.T, A_wvv), M_vv)
            elif sign == -1:
                tmp = np.dot(np.dot(M_vv.T, AT_wvv), M_vv)
            tmp_wvv += np.transpose(tmp, (1, 0, 2))

        # Overwrite the input
        A_wvv[:] = tmp_wvv / self.how_many_symmetries()

    @timer('map_v')
    def map_v(self, K1, K2, a_Mv):
        """Map a function of v (cartesian component) from K1 to K2."""

        if len(a_Mv) == 0:
            return []

        if K1 == K2:
            return a_Mv

        A_cv = self.pd.gd.cell_cv
        iA_cv = self.pd.gd.icell_cv

        # Get symmetry
        s = self.get_kpoint_mapping(K1, K2)
        U_cc, sign, TR, _, ft_c = self.get_symmetry_operator(s)

        # Create cartesian operator
        M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv)
        return sign * np.dot(TR(a_Mv), M_vv)

    def timereversal(self, s):
        """Is this a time-reversal symmetry?"""
        tr = bool(s // self.nU)
        return tr

    def get_symmetry_operator(self, s):
        """Return symmetry operator s.

        Returns (U_cc, sign, TR, shift_c, ft_c):

        - the rotation;
        - -1 for time-reversal symmetries, otherwise 1;
        - np.conj for time-reversal symmetries, otherwise the identity;
        - the reciprocal lattice shift found by analyze_symmetries;
        - the fractional translation of the operation.
        """
        U_scc = self.kd.symmetry.op_scc
        ft_sc = self.kd.symmetry.ft_sc

        reds = s % self.nU
        if self.timereversal(s):
            TR = np.conj
            sign = -1
        else:
            sign = 1

            def TR(x):
                return x

        return U_scc[reds], sign, TR, self.shift_sc[s], ft_sc[reds]

    @timer('map_G_vectors')
    def map_G_vectors(self, K1, K2):
        """Return G vector mapping."""
        s = self.get_kpoint_mapping(K1, K2)
        G_G, sign, shift_c = self.G_sG[s]

        return G_G, sign

    def initialize_G_maps(self):
        """Calculate the Gvector mappings.

        For every allowed symmetry s, ``G_sG[s] = [G_G, sign, shift_c]``.
        G_G[G] is the position in ``pd.Q_qG[0]`` of the plane wave that G
        is mapped to. Also stores G_Gc, UG_sGc and Q_sG; entries of
        disallowed symmetries stay None.

        Raises IndexError if a G-vector is mapped outside the sphere.
        """
        pd = self.pd
        B_cv = 2.0 * np.pi * pd.gd.icell_cv
        G_Gv = pd.get_reciprocal_vectors(add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(B_cv))
        Q_G = pd.Q_qG[0]

        # Position of each plane-wave index in Q_G. The first occurrence
        # wins, as with a linear search. Built once, instead of scanning
        # Q_G for every G of every symmetry.
        G_Q = {}
        for G, Q in enumerate(np.asarray(Q_G).tolist()):
            G_Q.setdefault(Q, G)

        G_sG = [None] * self.nsym
        UG_sGc = [None] * self.nsym
        Q_sG = [None] * self.nsym
        for s in self.s_s:
            U_cc, sign, TR, shift_c, ft_c = self.get_symmetry_operator(s)
            iU_cc = np.linalg.inv(U_cc).T
            UG_Gc = np.dot(G_Gc - shift_c, sign * iU_cc)

            assert np.allclose(UG_Gc.round(), UG_Gc)
            UQ_G = np.ravel_multi_index(UG_Gc.round().astype(int).T,
                                        pd.gd.N_c, 'wrap')

            try:
                G_G = [G_Q[UQ] for UQ in UQ_G.tolist()]
            except KeyError:
                raise IndexError('This should not be possible but '
                                 'a G-vector was mapped outside the '
                                 'sphere') from None
            UG_sGc[s] = UG_Gc
            Q_sG[s] = UQ_G
            G_sG[s] = [G_G, sign, shift_c]
        self.G_Gc = G_Gc
        self.UG_sGc = UG_sGc
        self.Q_sG = Q_sG
        self.G_sG = G_sG

    def unfold_ibz_kpoint(self, ik):
        """Return kpoints related to irreducible kpoint."""
        kd = self.kd
        K_k = np.unique(kd.bz2bz_ks[kd.ibz2bz_k[ik]])
        K_k = K_k[K_k != -1]
        return K_k
632
633
634class PairDensity:
    def __init__(self, gs, ecut=50, response='density',
                 ftol=1e-6, threshold=1,
                 real_space_derivatives=False,
                 world=mpi.world, txt='-', timer=None,
                 nblocks=1, gate_voltage=None,
                 paw_correction='brute-force', **unused):
        """Density matrix elements

        Parameters
        ----------
        gs : GPAW or other
            Ground-state calculator object (must run on a single rank), or
            anything passed to ``GPAW(gs, communicator=mpi.serial_comm)``
            to read one (presumably a .gpw file name -- confirm).
        ecut : float or None
            Plane-wave cutoff in eV; stored in Hartree as ``self.ecut``.
        response : str
            Type of response, stored as ``self.response``.
        ftol : float
            Threshold determining whether a band is completely filled
            (f > 1 - ftol) or completely empty (f < ftol).
        threshold : float
            Numerical threshold for the optical limit k dot p perturbation
            theory expansion.
        real_space_derivatives : bool
            Calculate nabla matrix elements (in the optical limit)
            using a real space finite difference approximation.
        world : MPI communicator
            Communicator split into ``self.blockcomm`` and ``self.kncomm``.
        txt : str or file
            Output destination, converted with ``convert_string_to_fd``.
        timer : Timer or None
            Timer for profiling; a new Timer is created if None.
        nblocks : int
            Number of ranks per block; must divide ``world.size``.
        gate_voltage : float or None
            Shift the fermi level by gate_voltage, given in eV (it is
            divided by ``Ha`` before use, i.e. applied in Hartree).
        paw_correction : str
            Stored as ``self.paw_correction``.
        **unused
            Additional keyword arguments are accepted and ignored.
        """
        self.world = world
        self.fd = convert_string_to_fd(txt, world)
        self.timer = timer or Timer()

        with self.timer('Read ground state'):
            if not isinstance(gs, GPAW):
                print('Reading ground state calculation:\n  %s' % gs,
                      file=self.fd)
                with disable_dry_run():
                    calc = GPAW(gs, communicator=mpi.serial_comm)
            else:
                calc = gs
                assert calc.wfs.world.size == 1

        # Only symmorphic ground-state symmetries are supported
        assert calc.wfs.kd.symmetry.symmorphic
        self.calc = calc

        # Convert input energies from eV to Hartree
        if ecut is not None:
            ecut /= Ha

        if gate_voltage is not None:
            gate_voltage = gate_voltage / Ha

        self.response = response
        self.ecut = ecut
        self.ftol = ftol
        self.threshold = threshold
        self.real_space_derivatives = real_space_derivatives
        self.gate_voltage = gate_voltage

        # Communicator creation must happen in the same order on all ranks
        if nblocks == 1:
            # Every rank is its own block; k-points/bands use all of world
            self.blockcomm = world.new_communicator([world.rank])
            self.kncomm = world
        else:
            # Consecutive groups of nblocks ranks form a block; ranks with
            # the same position inside their block share kncomm.
            assert world.size % nblocks == 0, world.size
            rank1 = world.rank // nblocks * nblocks
            rank2 = rank1 + nblocks
            self.blockcomm = self.world.new_communicator(range(rank1, rank2))
            ranks = range(world.rank % nblocks, world.size, nblocks)
            self.kncomm = self.world.new_communicator(ranks)

        self.fermi_level = self.calc.wfs.fermi_level

        if gate_voltage is not None:
            self.add_gate_voltage(gate_voltage)

        self.spos_ac = calc.spos_ac

        self.nocc1 = None  # number of completely filled bands
        self.nocc2 = None  # number of non-empty bands
        self.count_occupied_bands()

        self.ut_sKnvR = None  # gradient of wave functions for optical limit

        # Unit-cell volume
        self.vol = abs(np.linalg.det(calc.wfs.gd.cell_cv))

        # Nearest-neighbour lookup of BZ k-points wrapped into [0, 1)
        kd = self.calc.wfs.kd
        self.KDTree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1))
        print('Number of blocks:', nblocks, file=self.fd)

        self.paw_correction = paw_correction
718
719    def find_kpoint(self, k_c):
720        return self.KDTree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1]
721
722    def add_gate_voltage(self, gate_voltage=0):
723        """Shifts the Fermi-level by e * Vg. By definition e = 1."""
724        assert self.calc.wfs.occupations.name in {'fermi-dirac', 'zero-width'}
725        print('Shifting Fermi-level by %.2f eV' % (gate_voltage * Ha),
726              file=self.fd)
727        self.fermi_level += gate_voltage
728        for kpt in self.calc.wfs.kpt_u:
729            kpt.f_n = (self.shift_occupations(kpt.eps_n, gate_voltage) *
730                       kpt.weight)
731
732    def shift_occupations(self, eps_n, gate_voltage):
733        """Shift fermilevel."""
734        fermi = self.fermi_level
735        width = getattr(self.calc.wfs.occupations, '_width', 0.0) / Ha
736        if width < 1e-9:
737            return (eps_n < fermi).astype(float)
738        else:
739            tmp = (eps_n - fermi) / width
740        f_n = np.zeros_like(eps_n)
741        f_n[tmp <= 100] = 1 / (1 + np.exp(tmp[tmp <= 100]))
742        f_n[tmp > 100] = 0.0
743        return f_n
744
745    def count_occupied_bands(self):
746        self.nocc1 = 9999999
747        self.nocc2 = 0
748        for kpt in self.calc.wfs.kpt_u:
749            f_n = kpt.f_n / kpt.weight
750            self.nocc1 = min((f_n > 1 - self.ftol).sum(), self.nocc1)
751            self.nocc2 = max((f_n > self.ftol).sum(), self.nocc2)
752        print('Number of completely filled bands:', self.nocc1, file=self.fd)
753        print('Number of partially filled bands:', self.nocc2, file=self.fd)
754        print('Total number of bands:', self.calc.wfs.bd.nbands,
755              file=self.fd)
756
757    def distribute_k_points_and_bands(self, band1, band2, kpts=None):
758        """Distribute spins, k-points and bands.
759
760        nbands: int
761            Number of bands for each spin/k-point combination.
762
763        The attribute self.mysKn1n2 will be set to a list of (s, K, n1, n2)
764        tuples that this process handles.
765        """
766
767        wfs = self.calc.wfs
768
769        if kpts is None:
770            kpts = np.arange(wfs.kd.nbzkpts)
771
772        nbands = band2 - band1
773        size = self.kncomm.size
774        rank = self.kncomm.rank
775        ns = wfs.nspins
776        nk = len(kpts)
777        n = (ns * nk * nbands + size - 1) // size
778        i1 = rank * n
779        i2 = min(i1 + n, ns * nk * nbands)
780
781        self.mysKn1n2 = []
782        i = 0
783        for s in range(ns):
784            for K in kpts:
785                n1 = min(max(0, i1 - i), nbands)
786                n2 = min(max(0, i2 - i), nbands)
787                if n1 != n2:
788                    self.mysKn1n2.append((s, K, n1 + band1, n2 + band1))
789                i += nbands
790
791        print('BZ k-points:', self.calc.wfs.kd, file=self.fd)
792        print('Distributing spins, k-points and bands (%d x %d x %d)' %
793              (ns, nk, nbands),
794              'over %d process%s' %
795              (self.kncomm.size, ['es', ''][self.kncomm.size == 1]),
796              file=self.fd)
797        print('Number of blocks:', self.blockcomm.size, file=self.fd)
798
    @timer('Get a k-point')
    def get_k_point(self, s, k_c, n1, n2, load_wfs=True, block=False):
        """Return wave functions for a specific k-point and spin.

        s: int
            Spin index (0 or 1).
        k_c: int or array-like
            BZ k-point index, or a k-point in scaled coordinates. A vector
            is matched (modulo 1) to the nearest BZ k-point, and the integer
            difference is folded into the returned shift_c.
        n1, n2: int
            Range of bands to include (n1 <= n < n2).
        load_wfs: bool
            If False, only eigenvalues and occupations are returned
            (ut_nR and P_ani are None).
        block: bool
            If True, the bands n1:n2 are split over self.blockcomm and only
            this rank's block na:nb is loaded.

        Returns a KPoint object.
        """

        wfs = self.calc.wfs
        kd = wfs.kd

        # Parse kpoint: is k_c an index or a vector
        if not isinstance(k_c, numbers.Integral):
            K = self.find_kpoint(k_c)
            shift0_c = (kd.bzk_kc[K] - k_c).round().astype(int)
        else:
            # Fall back to index
            K = k_c
            shift0_c = np.array([0, 0, 0])
            k_c = None

        if block:
            nblocks = self.blockcomm.size
            rank = self.blockcomm.rank
        else:
            nblocks = 1
            rank = 0

        # Split [n1, n2) into nblocks blocks; this rank handles [na, nb)
        blocksize = (n2 - n1 + nblocks - 1) // nblocks
        na = min(n1 + rank * blocksize, n2)
        nb = min(na + blocksize, n2)

        U_cc, T, a_a, U_aii, shift_c, time_reversal = \
            self.construct_symmetry_operators(K, k_c=k_c)

        # NOTE(review): in-place update; assumes construct_symmetry_operators
        # returns a freshly created shift_c array -- confirm.
        shift_c += -shift0_c
        ik = wfs.kd.bz2ibz_k[K]  # irreducible k-point of BZ k-point K
        assert wfs.kd.comm.size == 1
        kpt = wfs.kpt_qs[ik][s]

        assert n2 <= len(kpt.eps_n), \
            'Increase GS-nbands or decrease chi0-nbands!'
        eps_n = kpt.eps_n[n1:n2]
        # Occupations divided by the k-point weight
        f_n = kpt.f_n[n1:n2] / kpt.weight

        if not load_wfs:
            return KPoint(s, K, n1, n2, blocksize, na, nb,
                          None, eps_n, f_n, None, shift_c)

        with self.timer('load wfs'):
            # Inverse FFT of each band in the block, followed by the
            # symmetry operation T from construct_symmetry_operators
            psit_nG = kpt.psit_nG
            ut_nR = wfs.gd.empty(nb - na, wfs.dtype)
            for n in range(na, nb):
                ut_nR[n - na] = T(wfs.pd.ifft(psit_nG[n], ik))

        with self.timer('Load projections'):
            # Rotate the projections of atoms a_a with U_aii; conjugate
            # them for time-reversal symmetry
            P_ani = []
            for b, U_ii in zip(a_a, U_aii):
                P_ni = np.dot(kpt.P_ani[b][na:nb], U_ii)
                if time_reversal:
                    P_ni = P_ni.conj()
                P_ani.append(P_ni)

        return KPoint(s, K, n1, n2, blocksize, na, nb,
                      ut_nR, eps_n, f_n, P_ani, shift_c)
868
869    def generate_pair_densities(self, pd, m1, m2, spins, intraband=True,
870                                PWSA=None, disable_optical_limit=False,
871                                unsymmetrized=False, use_more_memory=1):
872        """Generator for returning pair densities.
873
874        Returns the pair densities between the occupied and
875        the states in range(m1, m2).
876
877        pd: PWDescriptor
878            Plane-wave descriptor for a single q-point.
879        m1: int
880            Index of first unoccupied band.
881        m2: int
882            Index of last unoccupied band.
883        spins: list
884            List of spin indices included.
885        intraband: bool
886            Include intraband transitions in optical limit.
887        PWSA: PlanewaveSymmetryAnalyzer
888            If supplied uses this object to determine the symmetries
889            of the pair-densities.
890        disable_optical_limit: bool
891            Disable optical limit.
892        unsymmetrized: bool
893            Only return pair-densities from one kpoint in each
894            group of equivalent kpoints.
895        use_more_memory: float
896            Group more pair densities for several occupied bands
897            together before returning. Here 0 <= use_more_memory <= 1,
898            where zero is the minimal amount of memory, and 1 is the maximal.
899        """
900        assert 0 <= use_more_memory <= 1
901
902        q_c = pd.kd.bzk_kc[0]
903        optical_limit = np.allclose(q_c, 0.0) and self.response == 'density'
904        optical_limit = not disable_optical_limit and optical_limit
905
906        Q_aGii = self.initialize_paw_corrections(pd)
907        self.Q_aGii = Q_aGii  # This is used in g0w0
908
909        if PWSA is None:
910            with self.timer('Symmetry analyzer'):
911                PWSA = PWSymmetryAnalyzer  # Line too long otherwise
912                PWSA = PWSA(self.calc.wfs.kd, pd,
913                            timer=self.timer, txt=self.fd)
914
915        pb = ProgressBar(self.fd)
916        for kn, (s, ik, n1, n2) in pb.enumerate(self.mysKn1n2):
917            Kstar_k = PWSA.unfold_ibz_kpoint(ik)
918            for K_k in PWSA.group_kpoints(Kstar_k):
919                # Let the first kpoint of the group represent
920                # the rest of the kpoints
921                K1 = K_k[0]
922                # In this way wavefunctions are only loaded into
923                # memory for this particular set of kpoints
924                kptpair = self.get_kpoint_pair(pd, s, K1, n1, n2, m1, m2)
925                kpt1 = kptpair.get_k1()  # kpt1 = k
926
927                if kpt1.s not in spins:
928                    continue
929                kpt2 = kptpair.get_k2()  # kpt2 = k + q
930
931                if unsymmetrized:
932                    # Number of times kpoints are mapped into themselves
933                    weight = np.sqrt(PWSA.how_many_symmetries() / len(K_k))
934
935                # Use kpt2 to compute intraband transitions
936                # These conditions are sufficient to make sure
937                # that it still works in parallel
938                if kpt1.n1 == 0 and self.blockcomm.rank == 0 and \
939                   optical_limit and intraband:
940                    assert self.nocc2 <= kpt2.nb, \
941                        print('Error: Too few unoccupied bands')
942                    vel0_mv = self.intraband_pair_density(kpt2)
943                    f_m = kpt2.f_n[kpt2.na - kpt2.n1:kpt2.nb - kpt2.n1]
944                    with self.timer('intraband'):
945                        if vel0_mv is not None:
946                            if unsymmetrized:
947                                yield (f_m, None, None,
948                                       None, None, vel0_mv / weight)
949                            else:
950                                for K2 in K_k:
951                                    vel_mv = PWSA.map_v(K1, K2, vel0_mv)
952                                    yield (f_m, None, None,
953                                           None, None, vel_mv)
954
955                # Divide the occupied bands into chunks
956                n_n = np.arange(n2 - n1)
957                if use_more_memory == 0:
958                    chunksize = 1
959                else:
960                    chunksize = np.ceil(len(n_n) *
961                                        use_more_memory).astype(int)
962
963                no_n = []
964                for i in range(len(n_n) // chunksize):
965                    i1 = i * chunksize
966                    i2 = min((i + 1) * chunksize, len(n_n))
967                    no_n.append(n_n[i1:i2])
968
969                # n runs over occupied bands
970                for n_n in no_n:  # n_n is a list of occupied band indices
971                    # m over unoccupied bands
972                    m_m = np.arange(0, kpt2.n2 - kpt2.n1)
973                    deps_nm = kptpair.get_transition_energies(n_n, m_m)
974                    df_nm = kptpair.get_occupation_differences(n_n, m_m)
975
976                    # This is not quite right for
977                    # degenerate partially occupied
978                    # bands, but good enough for now:
979                    df_nm[df_nm <= 1e-20] = 0.0
980
981                    # Get pair density for representative kpoint
982                    ol = optical_limit
983                    n0_nmG, n0_nmv, _ = self.get_pair_density(pd, kptpair,
984                                                              n_n, m_m,
985                                                              optical_limit=ol,
986                                                              intraband=False,
987                                                              Q_aGii=Q_aGii)
988
989                    n0_nmG[deps_nm >= 0.0] = 0.0
990                    if optical_limit:
991                        n0_nmv[deps_nm >= 0.0] = 0.0
992
993                    # Reshape nm -> m
994                    nG = pd.ngmax
995                    deps_m = deps_nm.reshape(-1)
996                    df_m = df_nm.reshape(-1)
997                    n0_mG = n0_nmG.reshape((-1, nG))
998                    if optical_limit:
999                        n0_mv = n0_nmv.reshape((-1, 3))
1000
1001                    if unsymmetrized:
1002                        if optical_limit:
1003                            yield (None, df_m, deps_m,
1004                                   n0_mG / weight, n0_mv / weight, None)
1005                        else:
1006                            yield (None, df_m, deps_m,
1007                                   n0_mG / weight, None, None)
1008                        continue
1009
1010                    # Collect pair densities in a single array
1011                    # and return them
1012                    nm = n0_mG.shape[0]
1013                    nG = n0_mG.shape[1]
1014                    nk = len(K_k)
1015
1016                    n_MG = np.empty((nm * nk, nG), complex)
1017                    if optical_limit:
1018                        n_Mv = np.empty((nm * nk, 3), complex)
1019                    deps_M = np.tile(deps_m, nk)
1020                    df_M = np.tile(df_m, nk)
1021
1022                    for i, K2 in enumerate(K_k):
1023                        i1 = i * nm
1024                        i2 = (i + 1) * nm
1025                        n_mG = PWSA.map_G(K1, K2, n0_mG)
1026
1027                        if optical_limit:
1028                            n_mv = PWSA.map_v(K1, K2, n0_mv)
1029                            n_mG[:, 0] = n_mv[:, 0]
1030                            n_Mv[i1:i2, :] = n_mv
1031
1032                        n_MG[i1:i2, :] = n_mG
1033
1034                    if optical_limit:
1035                        yield (None, df_M, deps_M, n_MG, n_Mv, None)
1036                    else:
1037                        yield (None, df_M, deps_M, n_MG, None, None)
1038
1039        pb.finish()
1040
1041    @timer('Get kpoint pair')
1042    def get_kpoint_pair(self, pd, s, Kork_c, n1, n2, m1, m2,
1043                        load_wfs=True, block=False):
1044        # wfs = self.calc.wfs
1045        # bzk_kc = wfs.kd.bzk_kc
1046
1047        if isinstance(Kork_c, int):
1048            # If k_c is an integer then it refers to
1049            # the index of the kpoint in the BZ
1050            k_c = self.calc.wfs.kd.bzk_kc[Kork_c]
1051        else:
1052            k_c = Kork_c
1053
1054        q_c = pd.kd.bzk_kc[0]
1055        with self.timer('get k-points'):
1056            kpt1 = self.get_k_point(s, k_c, n1, n2, load_wfs=load_wfs)
1057            # K2 = wfs.kd.find_k_plus_q(q_c, [kpt1.K])[0]
1058            if self.response in ['+-', '-+']:
1059                s2 = 1 - s
1060            else:
1061                s2 = s
1062            kpt2 = self.get_k_point(s2, k_c + q_c, m1, m2,
1063                                    load_wfs=load_wfs, block=block)
1064
1065        with self.timer('fft indices'):
1066            Q_G = self.get_fft_indices(kpt1.K, kpt2.K, q_c, pd,
1067                                       kpt1.shift_c - kpt2.shift_c)
1068
1069        return KPointPair(kpt1, kpt2, Q_G)
1070
1071    @timer('get_pair_density')
1072    def get_pair_density(self, pd, kptpair, n_n, m_m,
1073                         optical_limit=False, intraband=False,
1074                         Q_aGii=None, block=False, direction=2,
1075                         extend_head=True):
1076        """Get pair density for a kpoint pair."""
1077        ol = optical_limit = np.allclose(pd.kd.bzk_kc[0], 0.0) and \
1078            self.response == 'density'
1079        eh = extend_head
1080        cpd = self.calculate_pair_densities  # General pair densities
1081        opd = self.optical_pair_density  # Interband pair densities / q
1082
1083        if Q_aGii is None:
1084            Q_aGii = self.initialize_paw_corrections(pd)
1085
1086        kpt1 = kptpair.kpt1
1087        kpt2 = kptpair.kpt2
1088        Q_G = kptpair.Q_G  # Fourier components of kpoint pair
1089        nG = len(Q_G)
1090
1091        if extend_head:
1092            n_nmG = np.zeros((len(n_n), len(m_m), nG + 2 * ol), pd.dtype)
1093        else:
1094            n_nmG = np.zeros((len(n_n), len(m_m), nG), pd.dtype)
1095
1096        for j, n in enumerate(n_n):
1097            Q_G = kptpair.Q_G
1098            with self.timer('conj'):
1099                ut1cc_R = kpt1.ut_nR[n - kpt1.na].conj()
1100            with self.timer('paw'):
1101                C1_aGi = [np.dot(Q_Gii, P1_ni[n - kpt1.na].conj())
1102                          for Q_Gii, P1_ni in zip(Q_aGii, kpt1.P_ani)]
1103                n_nmG[j, :, 2 * ol * eh:] = cpd(ut1cc_R, C1_aGi, kpt2, pd, Q_G,
1104                                                block=block)
1105            if optical_limit:
1106                if extend_head:
1107                    n_nmG[j, :, 0:3] = opd(n, m_m, kpt1, kpt2,
1108                                           block=block)
1109                else:
1110                    n_nmG[j, :, 0] = opd(n, m_m, kpt1, kpt2,
1111                                         block=block)[:, direction]
1112        return n_nmG
1113
1114    @timer('get_pair_momentum')
1115    def get_pair_momentum(self, pd, kptpair, n_n, m_m, Q_avGii=None):
1116        r"""Calculate matrix elements of the momentum operator.
1117
1118        Calculates::
1119
1120          n_{nm\mathrm{k}}\int_{\Omega_{\mathrm{cell}}}\mathrm{d}\mathbf{r}
1121          \psi_{n\mathrm{k}}^*(\mathbf{r})
1122          e^{-i\,(\mathrm{q} + \mathrm{G})\cdot\mathbf{r}}
1123          \nabla\psi_{m\mathrm{k} + \mathrm{q}}(\mathbf{r})
1124
1125        pd: PlaneWaveDescriptor
1126            Plane wave descriptor of a single q_c.
1127        kptpair: KPointPair
1128            KpointPair object containing the two kpoints.
1129        n_n: list
1130            List of left-band indices (n).
1131        m_m:
1132            List of right-band indices (m).
1133        """
1134        wfs = self.calc.wfs
1135
1136        kpt1 = kptpair.kpt1
1137        kpt2 = kptpair.kpt2
1138        Q_G = kptpair.Q_G  # Fourier components of kpoint pair
1139
1140        # For the same band we
1141        kd = wfs.kd
1142        gd = wfs.gd
1143        k_c = kd.bzk_kc[kpt1.K] + kpt1.shift_c
1144        k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T)
1145
1146        # Calculate k + G
1147        G_Gv = pd.get_reciprocal_vectors(add_q=True)
1148        kqG_Gv = k_v[np.newaxis] + G_Gv
1149
1150        # Pair velocities
1151        n_nmvG = pd.zeros((len(n_n), len(m_m), 3))
1152
1153        # Calculate derivatives of left-wavefunction
1154        # (there will typically be fewer of these)
1155        ut_nvR = self.make_derivative(kpt1.s, kpt1.K, kpt1.n1, kpt1.n2)
1156
1157        # PAW-corrections
1158        if Q_avGii is None:
1159            Q_avGii = self.initialize_paw_nabla_corrections(pd)
1160
1161        # Iterate over occupied bands
1162        for j, n in enumerate(n_n):
1163            ut1cc_R = kpt1.ut_nR[n].conj()
1164
1165            n_mG = self.calculate_pair_densities(ut1cc_R,
1166                                                 [], kpt2,
1167                                                 pd, Q_G)
1168
1169            n_nmvG[j] = 1j * kqG_Gv.T[np.newaxis] * n_mG[:, np.newaxis]
1170
1171            # Treat each cartesian component at a time
1172            for v in range(3):
1173                # Minus from integration by parts
1174                utvcc_R = -ut_nvR[n, v].conj()
1175                Cv1_aGi = [np.dot(P1_ni[n].conj(), Q_vGii[v])
1176                           for Q_vGii, P1_ni in zip(Q_avGii, kpt1.P_ani)]
1177
1178                nv_mG = self.calculate_pair_densities(utvcc_R,
1179                                                      Cv1_aGi, kpt2,
1180                                                      pd, Q_G)
1181
1182                n_nmvG[j, :, v] += nv_mG
1183
1184        # We want the momentum operator
1185        n_nmvG *= -1j
1186
1187        return n_nmvG
1188
1189    @timer('Calculate pair-densities')
1190    def calculate_pair_densities(self, ut1cc_R, C1_aGi, kpt2, pd, Q_G,
1191                                 block=True):
1192        """Calculate FFT of pair-densities and add PAW corrections.
1193
1194        ut1cc_R: 3-d complex ndarray
1195            Complex conjugate of the periodic part of the left hand side
1196            wave function.
1197        C1_aGi: list of ndarrays
1198            PAW corrections for all atoms.
1199        kpt2: KPoint object
1200            Right hand side k-point object.
1201        pd: PWDescriptor
1202            Plane-wave descriptor for for q=k2-k1.
1203        Q_G: 1-d int ndarray
1204            Mapping from flattened 3-d FFT grid to 0.5(G+q)^2<ecut sphere.
1205        """
1206
1207        dv = pd.gd.dv
1208        n_mG = pd.empty(kpt2.blocksize)
1209        myblocksize = kpt2.nb - kpt2.na
1210
1211        for ut_R, n_G in zip(kpt2.ut_nR, n_mG):
1212            n_R = ut1cc_R * ut_R
1213            with self.timer('fft'):
1214                n_G[:] = pd.fft(n_R, 0, Q_G) * dv
1215        # PAW corrections:
1216        with self.timer('gemm'):
1217            for C1_Gi, P2_mi in zip(C1_aGi, kpt2.P_ani):
1218                gemm(1.0, C1_Gi, P2_mi, 1.0, n_mG[:myblocksize], 't')
1219
1220        if not block or self.blockcomm.size == 1:
1221            return n_mG
1222        else:
1223            n_MG = pd.empty(kpt2.blocksize * self.blockcomm.size)
1224            self.blockcomm.all_gather(n_mG, n_MG)
1225            return n_MG[:kpt2.n2 - kpt2.n1]
1226
1227    @timer('Optical limit')
1228    def optical_pair_velocity(self, n, m_m, kpt1, kpt2, block=False):
1229        if self.ut_sKnvR is None or kpt1.K not in self.ut_sKnvR[kpt1.s]:
1230            self.ut_sKnvR = self.calculate_derivatives(kpt1)
1231
1232        kd = self.calc.wfs.kd
1233        gd = self.calc.wfs.gd
1234        k_c = kd.bzk_kc[kpt1.K] + kpt1.shift_c
1235        k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T)
1236
1237        ut_vR = self.ut_sKnvR[kpt1.s][kpt1.K][n - kpt1.n1]
1238        atomdata_a = self.calc.wfs.setups
1239        if self.paw_correction == 'brute-force':
1240            C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[n - kpt1.na])
1241                     for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)]
1242        elif self.paw_correction == 'skip':
1243            C_avi = [np.zeros((3, P_ni.shape[1]), complex)
1244                     for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)]
1245        else:
1246            1 / 0
1247
1248        blockbands = kpt2.nb - kpt2.na
1249        n0_mv = np.empty((kpt2.blocksize, 3), dtype=complex)
1250        nt_m = np.empty(kpt2.blocksize, dtype=complex)
1251        n0_mv[:blockbands] = -self.calc.wfs.gd.integrate(ut_vR,
1252                                                         kpt2.ut_nR).T
1253        nt_m[:blockbands] = self.calc.wfs.gd.integrate(kpt1.ut_nR[n - kpt1.na],
1254                                                       kpt2.ut_nR)
1255
1256        n0_mv[:blockbands] += (1j * nt_m[:blockbands, np.newaxis] *
1257                               k_v[np.newaxis, :])
1258
1259        for C_vi, P_mi in zip(C_avi, kpt2.P_ani):
1260            gemm(1.0, C_vi, P_mi, 1.0, n0_mv[:blockbands], 'c')
1261
1262        if block and self.blockcomm.size > 1:
1263            n0_Mv = np.empty((kpt2.blocksize * self.blockcomm.size, 3),
1264                             dtype=complex)
1265            self.blockcomm.all_gather(n0_mv, n0_Mv)
1266            n0_mv = n0_Mv[:kpt2.n2 - kpt2.n1]
1267
1268        return -1j * n0_mv
1269
1270    def optical_pair_density(self, n, m_m, kpt1, kpt2,
1271                             block=False):
1272        # Relative threshold for perturbation theory
1273        threshold = self.threshold
1274
1275        eps1 = kpt1.eps_n[n - kpt1.n1]
1276        deps_m = (eps1 - kpt2.eps_n)[m_m - kpt2.n1]
1277        n0_mv = self.optical_pair_velocity(n, m_m, kpt1, kpt2,
1278                                           block=block)
1279
1280        deps_m = deps_m.copy()
1281        deps_m[deps_m == 0.0] = np.inf
1282
1283        smallness_mv = np.abs(-1e-3 * n0_mv / deps_m[:, np.newaxis])
1284        inds_mv = (np.logical_and(np.inf > smallness_mv,
1285                                  smallness_mv > threshold))
1286        n0_mv *= - 1 / deps_m[:, np.newaxis]
1287        n0_mv[inds_mv] = 0
1288
1289        return n0_mv
1290
1291    @timer('Intraband')
1292    def intraband_pair_density(self, kpt, n_n=None,
1293                               only_partially_occupied=False):
1294        """Calculate intraband matrix elements of nabla"""
1295        # Bands and check for block parallelization
1296        na, nb, n1 = kpt.na, kpt.nb, kpt.n1
1297        vel_nv = np.zeros((nb - na, 3), dtype=complex)
1298        if n_n is None:
1299            n_n = np.arange(na, nb)
1300        assert np.max(n_n) < nb, 'This is too many bands'
1301        assert np.min(n_n) >= na, 'This is too few bands'
1302
1303        # Load kpoints
1304        kd = self.calc.wfs.kd
1305        gd = self.calc.wfs.gd
1306        k_c = kd.bzk_kc[kpt.K] + kpt.shift_c
1307        k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T)
1308        atomdata_a = self.calc.wfs.setups
1309        f_n = kpt.f_n
1310
1311        # Only works with Fermi-Dirac distribution
1312        assert self.calc.wfs.occupations.name in {'fermi-dirac', 'zero-width'}
1313
1314        # No carriers when T=0
1315        width = getattr(self.calc.wfs.occupations, '_width', 0.0) / Ha
1316
1317        if width > 1e-15:
1318            dfde_n = -1 / width * (f_n - f_n**2.0)  # Analytical derivative
1319            partocc_n = np.abs(dfde_n) > 1e-5  # Is part. occupied?
1320        else:
1321            # Just include all bands to be sure
1322            partocc_n = np.ones(len(f_n), dtype=bool)
1323
1324        if only_partially_occupied and not partocc_n.any():
1325            return None
1326
1327        if only_partially_occupied:
1328            # Check for block par. consistency
1329            assert (partocc_n < nb).all(), \
1330                print('Include more unoccupied bands ', +
1331                      'or less block parr.', file=self.fd)
1332
1333        # Break bands into degenerate chunks
1334        degchunks_cn = []  # indexing c as chunk number
1335        for n in n_n:
1336            inds_n = np.nonzero(np.abs(kpt.eps_n[n - n1] -
1337                                       kpt.eps_n) < 1e-5)[0] + n1
1338
1339            # Has this chunk already been computed?
1340            oldchunk = any([n in chunk for chunk in degchunks_cn])
1341            if not oldchunk and \
1342               (partocc_n[n - n1] or not only_partially_occupied):
1343                assert all([ind in n_n for ind in inds_n]), \
1344                    print('\nYou are cutting over a degenerate band ' +
1345                          'using block parallelization.',
1346                          inds_n, n_n, file=self.fd)
1347                degchunks_cn.append((inds_n))
1348
1349        # Calculate matrix elements by diagonalizing each block
1350        for ind_n in degchunks_cn:
1351            deg = len(ind_n)
1352            ut_nvR = self.calc.wfs.gd.zeros((deg, 3), complex)
1353            vel_nnv = np.zeros((deg, deg, 3), dtype=complex)
1354            # States are included starting from kpt.na
1355            ut_nR = kpt.ut_nR[ind_n - na]
1356
1357            # Get derivatives
1358            for ind, ut_vR in zip(ind_n, ut_nvR):
1359                ut_vR[:] = self.make_derivative(kpt.s, kpt.K,
1360                                                ind, ind + 1)[0]
1361
1362            # Treat the whole degenerate chunk
1363            for n in range(deg):
1364                ut_vR = ut_nvR[n]
1365                C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[ind_n[n] - na])
1366                         for atomdata, P_ni in zip(atomdata_a, kpt.P_ani)]
1367
1368                nabla0_nv = -self.calc.wfs.gd.integrate(ut_vR, ut_nR).T
1369                nt_n = self.calc.wfs.gd.integrate(ut_nR[n], ut_nR)
1370                nabla0_nv += 1j * nt_n[:, np.newaxis] * k_v[np.newaxis, :]
1371
1372                for C_vi, P_ni in zip(C_avi, kpt.P_ani):
1373                    gemm(1.0, C_vi, P_ni[ind_n - na], 1.0, nabla0_nv, 'c')
1374
1375                vel_nnv[n] = -1j * nabla0_nv
1376
1377            for iv in range(3):
1378                vel, _ = np.linalg.eig(vel_nnv[..., iv])
1379                vel_nv[ind_n - na, iv] = vel  # Use eigenvalues
1380
1381        return vel_nv[n_n - na]
1382
1383    def get_fft_indices(self, K1, K2, q_c, pd, shift0_c):
1384        """Get indices for G-vectors inside cutoff sphere."""
1385        kd = self.calc.wfs.kd
1386        N_G = pd.Q_qG[0]
1387        shift_c = (shift0_c +
1388                   (q_c - kd.bzk_kc[K2] + kd.bzk_kc[K1]).round().astype(int))
1389        if shift_c.any():
1390            n_cG = np.unravel_index(N_G, pd.gd.N_c)
1391            n_cG = [n_G + shift for n_G, shift in zip(n_cG, shift_c)]
1392            N_G = np.ravel_multi_index(n_cG, pd.gd.N_c, 'wrap')
1393        return N_G
1394
1395    def construct_symmetry_operators(self, K, k_c=None):
1396        """Construct symmetry operators for wave function and PAW projections.
1397
1398        We want to transform a k-point in the irreducible part of the BZ to
1399        the corresponding k-point with index K.
1400
1401        Returns U_cc, T, a_a, U_aii, shift_c and time_reversal, where:
1402
1403        * U_cc is a rotation matrix.
1404        * T() is a function that transforms the periodic part of the wave
1405          function.
1406        * a_a is a list of symmetry related atom indices
1407        * U_aii is a list of rotation matrices for the PAW projections
1408        * shift_c is three integers: see code below.
1409        * time_reversal is a flag - if True, projections should be complex
1410          conjugated.
1411
1412        See the get_k_point() method for how to use these tuples.
1413        """
1414
1415        wfs = self.calc.wfs
1416        kd = wfs.kd
1417
1418        s = kd.sym_k[K]
1419        U_cc = kd.symmetry.op_scc[s]
1420        time_reversal = kd.time_reversal_k[K]
1421        ik = kd.bz2ibz_k[K]
1422        if k_c is None:
1423            k_c = kd.bzk_kc[K]
1424        ik_c = kd.ibzk_kc[ik]
1425
1426        sign = 1 - 2 * time_reversal
1427        shift_c = np.dot(U_cc, ik_c) - k_c * sign
1428
1429        try:
1430            assert np.allclose(shift_c.round(), shift_c)
1431        except AssertionError:
1432            print('shift_c ' + str(shift_c), file=self.fd)
1433            print('k_c ' + str(k_c), file=self.fd)
1434            print('kd.bzk_kc[K] ' + str(kd.bzk_kc[K]), file=self.fd)
1435            print('ik_c ' + str(ik_c), file=self.fd)
1436            print('U_cc ' + str(U_cc), file=self.fd)
1437            print('sign ' + str(sign), file=self.fd)
1438            raise AssertionError
1439
1440        shift_c = shift_c.round().astype(int)
1441
1442        if (U_cc == np.eye(3)).all():
1443            def T(f_R):
1444                return f_R
1445        else:
1446            N_c = self.calc.wfs.gd.N_c
1447            i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1)))
1448            i = np.ravel_multi_index(i_cr, N_c, 'wrap')
1449
1450            def T(f_R):
1451                return f_R.ravel()[i].reshape(N_c)
1452
1453        if time_reversal:
1454            T0 = T
1455
1456            def T(f_R):
1457                return T0(f_R).conj()
1458            shift_c *= -1
1459
1460        a_a = []
1461        U_aii = []
1462        for a, id in enumerate(self.calc.wfs.setups.id_a):
1463            b = kd.symmetry.a_sa[s, a]
1464            S_c = np.dot(self.spos_ac[a], U_cc) - self.spos_ac[b]
1465            x = np.exp(2j * pi * np.dot(ik_c, S_c))
1466            U_ii = wfs.setups[a].R_sii[s].T * x
1467            a_a.append(b)
1468            U_aii.append(U_ii)
1469
1470        return U_cc, T, a_a, U_aii, shift_c, time_reversal
1471
1472    @timer('Initialize PAW corrections')
1473    def initialize_paw_corrections(self, pd, soft=False):
1474        wfs = self.calc.wfs
1475        q_v = pd.K_qv[0]
1476        optical_limit = np.allclose(q_v, 0) and self.response == 'density'
1477
1478        G_Gv = pd.get_reciprocal_vectors()
1479        if optical_limit:
1480            G_Gv[0] = 1
1481
1482        pos_av = np.dot(self.spos_ac, pd.gd.cell_cv)
1483
1484        # Collect integrals for all species:
1485        Q_xGii = {}
1486        for id, atomdata in wfs.setups.setups.items():
1487            if soft:
1488                ghat = PWLFC([atomdata.ghat_l], pd)
1489                ghat.set_positions(np.zeros((1, 3)))
1490                Q_LG = ghat.expand().T
1491                if atomdata.Delta_iiL is None:
1492                    ni = atomdata.ni
1493                    Q_Gii = np.zeros((Q_LG.shape[1], ni, ni))
1494                else:
1495                    Q_Gii = np.dot(atomdata.Delta_iiL, Q_LG).T
1496            else:
1497                ni = atomdata.ni
1498                if self.paw_correction == 'brute-force':
1499                    Q_Gii = two_phi_planewave_integrals(G_Gv, atomdata)
1500                    Q_Gii.shape = (-1, ni, ni)
1501                elif self.paw_correction == 'skip':
1502                    Q_Gii = np.zeros((len(G_Gv), ni, ni), complex)
1503                else:
1504                    1 / 0
1505
1506            Q_xGii[id] = Q_Gii
1507
1508        Q_aGii = []
1509        for a, atomdata in enumerate(wfs.setups):
1510            id = wfs.setups.id_a[a]
1511            Q_Gii = Q_xGii[id]
1512            x_G = np.exp(-1j * np.dot(G_Gv, pos_av[a]))
1513            Q_aGii.append(x_G[:, np.newaxis, np.newaxis] * Q_Gii)
1514            if optical_limit:
1515                Q_aGii[a][0] = atomdata.dO_ii
1516
1517        return Q_aGii
1518
1519    @timer('Initialize PAW corrections')
1520    def initialize_paw_nabla_corrections(self, pd, soft=False):
1521        print('Initializing nabla PAW Corrections', file=self.fd)
1522        wfs = self.calc.wfs
1523        G_Gv = pd.get_reciprocal_vectors()
1524        pos_av = np.dot(self.spos_ac, pd.gd.cell_cv)
1525
1526        # Collect integrals for all species:
1527        Q_xvGii = {}
1528        for id, atomdata in wfs.setups.setups.items():
1529            if soft:
1530                raise NotImplementedError
1531            else:
1532                Q_vGii = two_phi_nabla_planewave_integrals(G_Gv, atomdata)
1533                ni = atomdata.ni
1534                Q_vGii.shape = (3, -1, ni, ni)
1535
1536            Q_xvGii[id] = Q_vGii
1537
1538        Q_avGii = []
1539        for a, atomdata in enumerate(wfs.setups):
1540            id = wfs.setups.id_a[a]
1541            Q_vGii = Q_xvGii[id]
1542            x_G = np.exp(-1j * np.dot(G_Gv, pos_av[a]))
1543            Q_avGii.append(x_G[np.newaxis, :, np.newaxis, np.newaxis] * Q_vGii)
1544
1545        return Q_avGii
1546
1547    def calculate_derivatives(self, kpt):
1548        ut_sKnvR = [{}, {}]
1549        ut_nvR = self.make_derivative(kpt.s, kpt.K, kpt.n1, kpt.n2)
1550        ut_sKnvR[kpt.s][kpt.K] = ut_nvR
1551
1552        return ut_sKnvR
1553
1554    @timer('Derivatives')
1555    def make_derivative(self, s, K, n1, n2):
1556        wfs = self.calc.wfs
1557        if self.real_space_derivatives:
1558            grad_v = [Gradient(wfs.gd, v, 1.0, 4, complex).apply
1559                      for v in range(3)]
1560
1561        U_cc, T, a_a, U_aii, shift_c, time_reversal = \
1562            self.construct_symmetry_operators(K)
1563        A_cv = wfs.gd.cell_cv
1564        M_vv = np.dot(np.dot(A_cv.T, U_cc.T), np.linalg.inv(A_cv).T)
1565        ik = wfs.kd.bz2ibz_k[K]
1566        assert wfs.kd.comm.size == 1
1567        kpt = wfs.kpt_qs[ik][s]
1568        psit_nG = kpt.psit_nG
1569        iG_Gv = 1j * wfs.pd.get_reciprocal_vectors(q=ik, add_q=False)
1570        ut_nvR = wfs.gd.zeros((n2 - n1, 3), complex)
1571        for n in range(n1, n2):
1572            for v in range(3):
1573                if self.real_space_derivatives:
1574                    ut_R = T(wfs.pd.ifft(psit_nG[n], ik))
1575                    grad_v[v](ut_R, ut_nvR[n - n1, v],
1576                              np.ones((3, 2), complex))
1577                else:
1578                    ut_R = T(wfs.pd.ifft(iG_Gv[:, v] * psit_nG[n], ik))
1579                    for v2 in range(3):
1580                        ut_nvR[n - n1, v2] += ut_R * M_vv[v, v2]
1581
1582        return ut_nvR
1583