1import functools
2from io import StringIO
3from math import pi, sqrt
4from typing import List, Tuple
5
6import ase.units as units
7import numpy as np
8from ase.data import chemical_symbols
9
10from gpaw import debug
11from gpaw.basis_data import Basis
12from gpaw.gaunt import gaunt, nabla
13from gpaw.overlap import OverlapCorrections
14from gpaw.rotation import rotation
15from gpaw.setup_data import SetupData, search_for_file
16from gpaw.utilities import pack, unpack
17from gpaw.xc import XC
18
19
def parse_hubbard_string(type: str) -> Tuple[str,
                                             List[int],
                                             List[float],
                                             List[bool]]:
    """Parse DFT+U parameters from a setup type-string.

    Accepted forms are ``"type:l,U"`` and ``"type:l,U,scale"``; several
    ``l,U[,scale]`` groups may be given separated by ``;`` (one per
    Hubbard correction).  An empty *type* part means ``'paw'``.

    Returns ``(type, l, U, scale)`` where ``l`` holds the angular
    momentum numbers (0, 1, 2, 3 for s, p, d, f), ``U`` holds the U
    values divided by ``units.Hartree`` (i.e. given in eV, returned in
    atomic units) and ``scale`` holds booleans (``True`` when the field
    is omitted).

    Raises ValueError for an angular-momentum label that is not exactly
    one of ``s``, ``p``, ``d``, ``f``, or when U/scale cannot be parsed.
    """
    type, lus = type.split(':')
    if type == '':
        type = 'paw'

    l = []
    U = []
    scale = []

    for lu in lus.split(';'):  # Multiple U corrections
        # Pad with commas so the optional scale field always exists.
        l_, u_, scale_ = (lu + ',,').split(',')[:3]
        # str.find() returns -1 for unknown labels and 0 for '' or for
        # prefixes like 'sp', so validate explicitly instead of silently
        # producing a wrong angular momentum.
        lnum = 'spdf'.find(l_)
        if len(l_) != 1 or lnum < 0:
            raise ValueError(
                f'Bad angular momentum {l_!r} in Hubbard U specification '
                f'{lu!r}: expected one of s, p, d, f')
        l.append(lnum)
        U.append(float(u_) / units.Hartree)
        if scale_:
            scale.append(bool(int(scale_)))
        else:
            scale.append(True)
    return type, l, U, scale
43
44
def create_setup(symbol, xc='LDA', lmax=0,
                 type='paw', basis=None, setupdata=None,
                 filter=None, world=None):
    """Create the setup (PAW dataset or pseudopotential) for one element.

    Parameters
    ----------
    symbol:
        Chemical symbol, e.g. ``'H'``.
    xc:
        XC functional object, or a name that is converted with ``XC()``.
    lmax:
        Passed to ``setupdata.build()``; forced to 0 for HGH setups.
    type:
        Setup type: ``'paw'`` (default), ``'hgh'``, ``'hgh.sc'``,
        ``'ah'``, ``'ae'`` (hydrogen only), ``'ghost'``, ``'sg15'``, or
        any other name handed to ``SetupData``.  A ``':l,U[,scale]'``
        suffix requests DFT+U corrections (see ``parse_hubbard_string``).
    basis:
        Basis set passed on to the setup builder.
    setupdata:
        Already loaded setup data.  If given, no setup data is loaded
        and *type* only matters for its Hubbard U suffix.
    filter:
        Passed on to ``setupdata.build()``.
    world:
        Passed to ``search_for_file`` / ``SetupData`` when locating files.

    Returns
    -------
    A ``LeanSetup`` if the setup data has a ``build`` method; the built
    object for ``'ah'`` / ``'ae'``; otherwise *setupdata* itself.

    Raises
    ------
    IOError
        If the SG15 pseudopotential file cannot be found.
    ValueError
        If SG15 is requested with a functional other than PBE.
    """
    if isinstance(xc, str):
        xc = XC(xc)

    if isinstance(type, str) and ':' in type:
        type, l, U, scale = parse_hubbard_string(type)
    else:
        U = None

    if setupdata is None:
        if type == 'hgh' or type == 'hgh.sc':
            lmax = 0
            from gpaw.hgh import HGHSetupData, sc_setups, setups
            if type == 'hgh.sc':
                table = sc_setups
            else:
                table = setups
            parameters = table[symbol]
            setupdata = HGHSetupData(parameters)
        elif type == 'ah':
            from gpaw.ah import AppelbaumHamann
            ah = AppelbaumHamann()
            ah.build(basis)
            return ah
        elif type == 'ae':
            from gpaw.ae import HydrogenAllElectronSetup
            assert symbol == 'H'
            ae = HydrogenAllElectronSetup()
            ae.build(basis)
            return ae
        elif type == 'ghost':
            from gpaw.lcao.bsse import GhostSetupData
            setupdata = GhostSetupData(symbol)
        elif type == 'sg15':
            from gpaw.upf import read_sg15
            upfname = f'{symbol}_ONCV_PBE-*.upf'
            try:
                upfpath, source = search_for_file(upfname, world=world)
            except RuntimeError as err:
                # Chain the original error so the search failure details
                # are preserved in the traceback.
                raise IOError('Could not find pseudopotential file %s '
                              'in any GPAW search path.  '
                              'Please install the SG15 setups using, '
                              'e.g., "gpaw install-data".' % upfname) from err
            setupdata = read_sg15(upfpath)
            if xc.get_setup_name() != 'PBE':
                raise ValueError('SG15 pseudopotentials support only the PBE '
                                 'functional.  This calculation would use '
                                 'the %s functional.' % xc.get_setup_name())
        else:
            setupdata = SetupData(symbol, xc.get_setup_name(),
                                  type, True,
                                  world=world)
    if hasattr(setupdata, 'build'):
        setup = LeanSetup(setupdata.build(xc, lmax, basis, filter))
        if U is not None:
            setup.set_hubbard_u(U, l, scale)
        return setup
    else:
        return setupdata
106
107
def correct_occ_numbers(f_j,
                        degeneracy_j,
                        jsorted,
                        correction: float,
                        eps=1e-12) -> None:
    """Correct f_j ndarray in-place.

    A positive *correction* adds that many electrons and a negative one
    removes ``-correction`` electrons.

    Electrons are added to states in the order given by *jsorted*,
    filling each state ``j`` up to ``degeneracy_j[j]``.  They are removed
    in the reverse order, emptying each state down to zero.  The loop
    stops once the remaining correction is within *eps* of zero.  If the
    states cannot absorb the whole correction, the remainder is silently
    dropped.

    Parameters
    ----------
    f_j:
        Occupation numbers; modified in-place.
    degeneracy_j:
        Maximum occupation of each state.
    jsorted:
        State indices ordered from "fill first" to "empty first".
    correction:
        Number of electrons to add (positive) or remove (negative).
    eps:
        Tolerance below which the correction counts as fully applied.
    """

    if correction > 0:
        # Add electrons to the lowest eigenstates:
        for j in jsorted:
            c = min(correction, degeneracy_j[j] - f_j[j])
            f_j[j] += c
            correction -= c
            if correction < eps:
                break
    elif correction < 0:
        # Remove electrons from the highest eigenstates:
        for j in jsorted[::-1]:
            c = min(-correction, f_j[j])
            f_j[j] -= c
            correction += c
            if correction > -eps:
                break
131
132
class LocalCorrectionVar:
    """Holder for the work data used to compute local corrections.

    The attributes ``nq``, ``lcut``, ``n_qg``, ``nt_qg``, ``nc_g``,
    ``nct_g``, ``rgd2``, ``Delta_lq`` and ``T_Lqp`` are copied from an
    optional source object; every one the source lacks (all of them
    when no source is given) is set to ``None``.
    """
    def __init__(self, s=None):
        """Copy the local-correction attributes of *s* where available."""
        names = ('nq', 'lcut', 'n_qg', 'nt_qg', 'nc_g', 'nct_g',
                 'rgd2', 'Delta_lq', 'T_Lqp')
        for name in names:
            # getattr(None, name, None) is None, so s=None needs no
            # separate branch.
            setattr(self, name, getattr(s, name, None))
143
144
class BaseSetup:
    """Mixin-class for setups.

    This makes it possible to inherit the most important methods without
    the cumbersome constructor of the ordinary Setup class.

    Maybe this class will be removed in the future, or it could be
    made a proper base class with attributes and so on.

    No constructor is defined here: the methods read attributes such as
    ``l_j``, ``n_j``, ``f_j``, ``ni``, ``nao``, ``phit_j``, ``rgd`` and
    ``data`` that subclasses must provide.
    """

    orbital_free = False

    def print_info(self, text):
        """Print setup information; delegates to ``self.data.print_info``."""
        self.data.print_info(text, self)

    def get_basis_description(self):
        """Return ``self.basis.get_description()``."""
        return self.basis.get_description()

    def get_partial_waves_for_atomic_orbitals(self):
        """Get those states phit that represent a real atomic state.

        This typically corresponds to the (truncated) partial waves (PAW) or
        a single-zeta basis.  Only partial waves whose ``n_j`` entry is
        positive (bound states) are returned.
        """

        # XXX ugly hack for pseudopotentials:
        if not hasattr(self, 'pseudo_partial_waves_j'):
            return []

        # The zip may cut off part of phit_j if there are more states than
        # projectors.  This should be the correct behaviour for all the
        # currently supported PAW/pseudopotentials.
        return [phit for n, phit in zip(self.n_j, self.pseudo_partial_waves_j)
                if n > 0]

    def calculate_initial_occupation_numbers(self, magmom, hund, charge,
                                             nspins, f_j=None):
        """Return initial occupation numbers ``f_si`` of the atomic orbitals.

        The radial occupations ``f_j`` are corrected for *charge* and, for
        two spin channels, split into up/down parts according to
        *magmom*.  They are then spread over the ``2l + 1`` m-components
        of each basis function in ``self.phit_j``.

        If f_j is specified, custom occupation numbers will be used.
        Hund rules disabled if so.

        Parameters
        ----------
        magmom:
            Magnetic moment; must be 0 when ``nspins == 1``.
        hund:
            Fill orbitals according to Hund's rules.  This branch writes to
            both spin channels, so ``nspins == 2`` is assumed.
        charge:
            Net charge of the atom; ``-charge`` electrons are added.
        nspins:
            Number of spin channels (1 or 2).
        f_j:
            Optional custom occupations of the radial states.

        Returns
        -------
        f_si: ndarray of shape ``(nspins, nao)``.

        Raises
        ------
        RuntimeError
            If ``|magmom|`` exceeds the number of valence electrons.
        ValueError
            If Hund's rules cannot produce the requested *magmom*.
        """

        nao = self.nao
        f_si = np.zeros((nspins, nao))

        assert (not hund) or f_j is None
        if f_j is None:
            f_j = self.f_j
        f_j = np.array(f_j, float)
        l_j = np.array(self.l_j)

        if hasattr(self, 'data') and hasattr(self.data, 'eps_j'):
            eps_j = np.array(self.data.eps_j)
        else:
            eps_j = np.ones(len(self.n_j))
            # Bound states:
            for j, n in enumerate(self.n_j):
                if n > 0:
                    eps_j[j] = -1.0

        deg_j = 2 * (2 * l_j + 1)

        # Sort after:
        #
        # 1) empty state (f == 0)
        # 2) open shells (d - f)
        # 3) eigenvalues (e)
        #
        # Only states with a negative eigenvalue take part.
        states = []
        for j, (f, d, e) in enumerate(zip(f_j, deg_j, eps_j)):
            if e < 0.0:
                states.append((f == 0, d - f, e, j))
        states.sort()
        jsorted = [j for _, _, _, j in states]

        # distribute the charge to the radial orbitals
        if nspins == 1:
            assert magmom == 0.0
            f_sj = np.array([f_j])
            if not self.orbital_free:
                correct_occ_numbers(f_sj[0], deg_j, jsorted, -charge)
            else:
                # ofdft degeneracy of one orbital is infinite
                f_sj[0] += -charge
        else:
            nval = f_j.sum() - charge
            if np.abs(magmom) > nval:
                raise RuntimeError(
                    'Magnetic moment larger than number ' +
                    f'of valence electrons (|{magmom:g}| > {nval:g})')
            f_sj = 0.5 * np.array([f_j, f_j])
            nup = 0.5 * (nval + magmom)
            ndn = 0.5 * (nval - magmom)
            # Within one spin channel a state holds at most 2l + 1.
            deg_j //= 2
            correct_occ_numbers(f_sj[0], deg_j, jsorted, nup - f_sj[0].sum())
            correct_occ_numbers(f_sj[1], deg_j, jsorted, ndn - f_sj[1].sum())

        # Projector function indices:
        nj = len(self.n_j)  # or l_j?  Seriously.

        # distribute to the atomic wave functions
        i = 0
        j = 0
        for phit in self.phit_j:
            l = phit.get_angular_momentum_number()

            # Skip functions not in basis set:
            while j < nj and self.l_orb_j[j] != l:
                j += 1
            if j < len(f_j):  # lengths of f_j and l_j may differ
                f = f_j[j]
                f_s = f_sj[:, j]
            else:
                f = 0
                f_s = np.array([0, 0])

            degeneracy = 2 * l + 1

            if hund:
                # Use Hunds rules:
                # NOTE(review): f is taken from f_j, which the charge
                # correction above does not modify (it acts on the copy
                # f_sj); verify that ignoring charge here is intended.
                # assert f == int(f)
                f = int(f)
                f_si[0, i:i + min(f, degeneracy)] = 1.0      # spin up
                f_si[1, i:i + max(f - degeneracy, 0)] = 1.0  # spin down
                if f < degeneracy:
                    magmom -= f
                else:
                    magmom -= 2 * degeneracy - f
            else:
                # Spread the occupation evenly over the m-components.
                for s in range(nspins):
                    f_si[s, i:i + degeneracy] = f_s[s] / degeneracy

            i += degeneracy
            j += 1

        if hund and magmom != 0:
            raise ValueError(
                f'Bad magnetic moment {magmom:g} for {self.symbol} atom!')
        assert i == nao

        return f_si

    def get_hunds_rule_moment(self, charge=0):
        """Return the smallest integer magnetic moment (0-9) that is
        compatible with Hund's rules for the given *charge*.

        Raises RuntimeError if no moment in that range works.
        """
        for M in range(10):
            try:
                self.calculate_initial_occupation_numbers(M, True, charge, 2)
            except ValueError:
                pass
            else:
                return M
        raise RuntimeError(
            'No magnetic moment in range(10) is compatible with '
            f"Hund's rules (charge={charge})")

    def initialize_density_matrix(self, f_si):
        """Return packed initial atomic density matrices ``D_sp``.

        For each basis function in ``self.phit_j``, the diagonal of
        ``D_sii`` for the next projector with the same angular momentum
        is set from the orbital occupations *f_si* (shape
        ``(nspins, nao)``).  Each ``D_sii[s]`` is then packed with
        ``pack``.
        """
        nspins, nao = f_si.shape
        ni = self.ni

        D_sii = np.zeros((nspins, ni, ni))
        D_sp = np.zeros((nspins, ni * (ni + 1) // 2))
        nj = len(self.pt_j)
        j = 0
        i = 0
        ib = 0
        for phit in self.phit_j:
            l = phit.get_angular_momentum_number()
            # Skip functions not in basis set:
            while j < nj and self.l_j[j] != l:
                i += 2 * self.l_j[j] + 1
                j += 1
            if j == nj:
                break

            for m in range(2 * l + 1):
                D_sii[:, i + m, i + m] = f_si[:, ib + m]
            j += 1
            i += 2 * l + 1
            ib += 2 * l + 1
        for s in range(nspins):
            D_sp[s] = pack(D_sii[s])
        return D_sp

    def symmetrize(self, a, D_aii, map_sa):
        """Return the symmetrized density matrix of atom *a*.

        Averages ``R_ii D_aii[b] R_ii^T`` over all symmetry operations s,
        where ``b = map_sa[s][a]`` and ``R_ii = self.R_sii[s]`` (see
        ``calculate_rotations``).
        """
        D_ii = np.zeros((self.ni, self.ni))
        for s, R_ii in enumerate(self.R_sii):
            D_ii += np.dot(R_ii, np.dot(D_aii[map_sa[s][a]],
                                        np.transpose(R_ii)))
        return D_ii / len(map_sa)

    def calculate_rotations(self, R_slmm):
        """Build block-diagonal rotation matrices ``self.R_sii``.

        ``R_slmm[s][l]`` is the ``(2l+1) x (2l+1)`` rotation matrix of
        symmetry operation s for angular momentum l.  The blocks are placed
        along the diagonal in the order given by ``self.l_j``.
        """
        nsym = len(R_slmm)
        self.R_sii = np.zeros((nsym, self.ni, self.ni))
        i1 = 0
        for l in self.l_j:
            i2 = i1 + 2 * l + 1
            for s, R_lmm in enumerate(R_slmm):
                self.R_sii[s, i1:i2, i1:i2] = R_lmm[l]
            i1 = i2

    def get_partial_waves(self):
        """Return spline representation of partial waves and densities.

        Returns ``(phi_j, phit_j, nc, nct, tauc, tauct)``:
        - splines of the all-electron and pseudo partial waves, zeroed
          beyond twice the largest cutoff radius;
        - splines of the core and pseudo core densities;
        - splines of the core kinetic energy densities, which are all
          zeros when ``data.tauc_g`` is None.
        """

        l_j = self.l_j

        # cutoffs
        rcut2 = 2 * max(self.rcut_j)
        gcut2 = self.rgd.ceil(rcut2)

        data = self.data

        # Construct splines:
        nc_g = data.nc_g.copy()
        nct_g = data.nct_g.copy()
        tauc_g = data.tauc_g
        tauct_g = data.tauct_g
        nc = self.rgd.spline(nc_g, rcut2, points=1000)
        nct = self.rgd.spline(nct_g, rcut2, points=1000)
        if tauc_g is None:
            tauc_g = np.zeros(nct_g.shape)
            tauct_g = tauc_g
        tauc = self.rgd.spline(tauc_g, rcut2, points=1000)
        tauct = self.rgd.spline(tauct_g, rcut2, points=1000)
        phi_j = []
        phit_j = []
        for j, (phi_g, phit_g) in enumerate(zip(data.phi_jg, data.phit_jg)):
            l = l_j[j]
            phi_g = phi_g.copy()
            phit_g = phit_g.copy()
            phi_g[gcut2:] = phit_g[gcut2:] = 0.0
            phi_j.append(self.rgd.spline(phi_g, rcut2, l, points=100))
            phit_j.append(self.rgd.spline(phit_g, rcut2, l, points=100))
        return phi_j, phit_j, nc, nct, tauc, tauct

    def set_hubbard_u(self, U, l, scale=1, store=0, LinRes=0):
        """Set Hubbard parameters.

        U is in atomic units.  l is the angular momentum of the orbital to
        which we wish to add a Hubbard potential.  scale enables or
        disables the scaling of the overlap between the l orbitals; if
        true we enforce <p|p> = 1.

        The values are stored as ``HubU``, ``Hubl``, ``Hubs``,
        ``HubStore`` and ``HubLinRes``, and ``HubOcc`` is reset to an
        empty list.  ``Hubi`` is set to the index i of the first projector
        function with ``l_j == l`` (or to the sum of ``2 * l_j + 1`` over
        all entries of ``l_j`` if none matches).
        """

        self.HubLinRes = LinRes
        self.Hubs = scale
        self.HubStore = store
        self.HubOcc = []
        self.HubU = U
        self.Hubl = l
        self.Hubi = 0
        # NOTE(review): when l is a list (as create_setup passes it), no
        # integer l_j compares equal, so Hubi becomes the total count;
        # verify Hubi is unused in that case.
        for ll in self.l_j:
            if ll == self.Hubl:
                break
            self.Hubi = self.Hubi + 2 * ll + 1

    def four_phi_integrals(self):
        """Calculate four-phi integral.

        Calculate the integral over the product of four all electron
        functions in the augmentation sphere, i.e.::

          /
          | d vr  ( phi_i1 phi_i2 phi_i3 phi_i4
          /         - phit_i1 phit_i2 phit_i3 phit_i4 ),

        where phi_i1 is an all electron function and phit_i1 is its
        smooth partner.

        The result is a ``(np, np)`` array over packed index pairs
        (i1 <= i2, i3 <= i4).  It is cached as ``self.I4_pp`` and
        returned directly on later calls.
        """
        if hasattr(self, 'I4_pp'):
            return self.I4_pp

        # radial grid
        r2dr_g = self.rgd.r_g**2 * self.rgd.dr_g

        phi_jg = self.data.phi_jg
        phit_jg = self.data.phit_jg

        # compute radial parts
        nj = len(self.l_j)
        R_jjjj = np.empty((nj, nj, nj, nj))
        for j1 in range(nj):
            for j2 in range(nj):
                for j3 in range(nj):
                    for j4 in range(nj):
                        R_jjjj[j1, j2, j3, j4] = np.dot(
                            r2dr_g,
                            phi_jg[j1] * phi_jg[j2] *
                            phi_jg[j3] * phi_jg[j4] -
                            phit_jg[j1] * phit_jg[j2] *
                            phit_jg[j3] * phit_jg[j4])

        # prepare for angular parts
        L_i = []
        j_i = []
        for j, l in enumerate(self.l_j):
            for m in range(2 * l + 1):
                L_i.append(l**2 + m)
                j_i.append(j)
        ni = len(L_i)
        # j_i is the list of j values
        # L_i is the list of L (=l**2+m for 0<=m<2*l+1) values
        # https://wiki.fysik.dtu.dk/gpaw/devel/overview.html

        G_LLL = gaunt(max(self.l_j))

        # calculate the integrals
        _np = ni * (ni + 1) // 2  # length for packing
        self.I4_pp = np.empty((_np, _np))
        p1 = 0
        for i1 in range(ni):
            L1 = L_i[i1]
            j1 = j_i[i1]
            for i2 in range(i1, ni):
                L2 = L_i[i2]
                j2 = j_i[i2]
                p2 = 0
                for i3 in range(ni):
                    L3 = L_i[i3]
                    j3 = j_i[i3]
                    for i4 in range(i3, ni):
                        L4 = L_i[i4]
                        j4 = j_i[i4]
                        self.I4_pp[p1, p2] = (np.dot(G_LLL[L1, L2],
                                                     G_LLL[L3, L4]) *
                                              R_jjjj[j1, j2, j3, j4])
                        p2 += 1
                p1 += 1

        # To unpack into I4_iip do:
        # from gpaw.utilities import unpack
        # I4_iip = np.empty((ni, ni, _np)):
        # for p in range(_np):
        #     I4_iip[..., p] = unpack(I4_pp[:, p])

        return self.I4_pp

    def get_default_nbands(self):
        """Return the number of bound atomic orbitals.

        This is the sum of ``2l + 1`` over the entries of ``l_orb_j`` whose
        ``n_j`` value is positive.
        """
        assert len(self.l_orb_j) == len(self.n_j), (self.l_orb_j, self.n_j)
        return sum(2 * l + 1 for (l, n) in zip(self.l_orb_j, self.n_j)
                   if n > 0)

    def calculate_coulomb_corrections(self, wn_lqg, wnt_lqg, wg_lg, wnc_g,
                                      wmct_g):
        """Calculate "Coulomb" energies.

        The ``w*`` arguments are the potentials returned by
        ``calculate_integral_potentials``.  The densities, ``Delta_lq`` and
        ``T_Lqp`` come from ``self.local_corr``.

        Returns ``(M_p, M_pp)``: the linear and the second-order
        (packed-pair) correction terms.
        """
        # Local variables derived from instance variables
        _np = self.ni * (self.ni + 1) // 2  # change to inst. att.?
        mct_g = self.local_corr.nct_g + self.Delta0 * self.g_lg[0]  # s.a.
        rdr_g = self.local_corr.rgd2.r_g * \
            self.local_corr.rgd2.dr_g  # change to inst. att.?

        A_q = 0.5 * (np.dot(wn_lqg[0], self.local_corr.nc_g) + np.dot(
            self.local_corr.n_qg, wnc_g))
        A_q -= sqrt(4 * pi) * self.Z * np.dot(self.local_corr.n_qg, rdr_g)
        A_q -= 0.5 * (np.dot(wnt_lqg[0], mct_g) +
                      np.dot(self.local_corr.nt_qg, wmct_g))
        A_q -= 0.5 * (np.dot(mct_g, wg_lg[0]) +
                      np.dot(self.g_lg[0], wmct_g)) * \
            self.local_corr.Delta_lq[0]
        M_p = np.dot(A_q, self.local_corr.T_Lqp[0])

        A_lqq = []
        for l in range(2 * self.local_corr.lcut + 1):
            A_qq = 0.5 * np.dot(self.local_corr.n_qg, np.transpose(wn_lqg[l]))
            A_qq -= 0.5 * np.dot(self.local_corr.nt_qg,
                                 np.transpose(wnt_lqg[l]))
            if l <= self.lmax:
                A_qq -= 0.5 * np.outer(self.local_corr.Delta_lq[l],
                                       np.dot(wnt_lqg[l], self.g_lg[l]))
                A_qq -= 0.5 * np.outer(np.dot(self.local_corr.nt_qg,
                                              wg_lg[l]),
                                       self.local_corr.Delta_lq[l])
                A_qq -= 0.5 * np.dot(self.g_lg[l], wg_lg[l]) * \
                    np.outer(self.local_corr.Delta_lq[l],
                             self.local_corr.Delta_lq[l])
            A_lqq.append(A_qq)

        M_pp = np.zeros((_np, _np))
        L = 0
        for l in range(2 * self.local_corr.lcut + 1):
            for m in range(2 * l + 1):  # m?
                M_pp += np.dot(np.transpose(self.local_corr.T_Lqp[L]),
                               np.dot(A_lqq[l], self.local_corr.T_Lqp[L]))
                L += 1

        return M_p, M_pp

    def calculate_integral_potentials(self, func):
        """Calculates a set of potentials using func.

        ``func(setup, n_g, l)`` is applied to the compensation charges
        ``g_lg``, to the pair densities ``n_qg`` / ``nt_qg`` and to the
        core densities of ``self.local_corr``.

        Returns ``(wg_lg, wn_lqg, wnt_lqg, wnc_g, wnct_g, wmct_g)``.
        """
        wg_lg = [func(self, self.g_lg[l], l)
                 for l in range(self.lmax + 1)]
        wn_lqg = [np.array([func(self, self.local_corr.n_qg[q], l)
                            for q in range(self.local_corr.nq)])
                  for l in range(2 * self.local_corr.lcut + 1)]
        wnt_lqg = [np.array([func(self, self.local_corr.nt_qg[q], l)
                             for q in range(self.local_corr.nq)])
                   for l in range(2 * self.local_corr.lcut + 1)]
        wnc_g = func(self, self.local_corr.nc_g, l=0)
        wnct_g = func(self, self.local_corr.nct_g, l=0)
        wmct_g = wnct_g + self.Delta0 * wg_lg[0]
        return wg_lg, wn_lqg, wnt_lqg, wnc_g, wnct_g, wmct_g

    def calculate_yukawa_interaction(self, gamma):
        """Calculate and return the Yukawa based interaction.

        Returns the second-order correction ``M_pp`` computed with the
        screened (Yukawa) radial Poisson solver.  The result is cached
        and reused while *gamma* is unchanged.
        """
        if self._Mg_pp is not None and gamma == self._gamma:
            return self._Mg_pp  # Cached

        # Solves the radial screened poisson equation for density n_g
        def Yuk(self, n_g, l):
            """Solve radial screened poisson for density n_g."""
            gamma = self._gamma
            return self.local_corr.rgd2.yukawa(n_g, l, gamma) * \
                self.local_corr.rgd2.r_g * self.local_corr.rgd2.dr_g

        self._gamma = gamma
        (wg_lg, wn_lqg, wnt_lqg, wnc_g, wnct_g, wmct_g) = \
            self.calculate_integral_potentials(Yuk)
        self._Mg_pp = self.calculate_coulomb_corrections(
            wn_lqg, wnt_lqg, wg_lg, wnc_g, wmct_g)[1]
        return self._Mg_pp
572
573
class LeanSetup(BaseSetup):
    """Setup class with minimal attribute set.

    A setup-like class must define at least the attributes of this
    class in order to function in a calculation.
    """
    def __init__(self, s):
        """Copy exactly the attributes of setup *s* that a calculation needs.

        *s* is presumably a full ``Setup`` (``create_setup`` passes the
        result of ``setupdata.build()``); verify for other callers.
        """
        # State that may be changed after construction (which is ugly):
        # rotations are filled in when doing symmetry reductions.
        self.R_sii = None
        self.HubU = s.HubU  # XXX probably None
        self.lq = s.lq  # Required for LDA+U I think.

        # type and fingerprint are required for writing to file.
        self.type, self.fingerprint, self.filename = (
            s.type, s.fingerprint, s.filename)

        # Element identity and electron counts.
        self.symbol, self.Z = s.symbol, s.Z
        self.Nv, self.Nc = s.Nv, s.Nc

        # Sizes of the projector and basis-function sets.
        self.ni, self.nao = s.ni, s.nao

        # Projectors and basis functions.
        self.pt_j, self.phit_j = s.pt_j, s.phit_j

        # Pseudo core density.
        self.Nct, self.nct = s.Nct, s.nct

        # Compensation charges and the vbar potential.
        self.lmax, self.ghat_l, self.vbar = s.lmax, s.ghat_l, s.vbar
        self.Delta_pL, self.Delta0 = s.Delta_pL, s.Delta0

        # Energy corrections.
        self.E, self.Kc = s.E, s.Kc
        self.M, self.M_p, self.M_pp, self.K_p = s.M, s.M_p, s.M_pp, s.K_p
        self.MB, self.MB_p = s.MB, s.MB_p

        self.dO_ii = s.dO_ii
        self.xc_correction = s.xc_correction

        # Needed to calculate initial occupations.
        self.f_j, self.n_j = s.f_j, s.n_j
        self.l_j, self.l_orb_j = s.l_j, s.l_orb_j
        self.nj = len(s.l_j)

        self.data = s.data

        # The remaining attributes are not used all that much and should
        # not generally be necessary.  Maybe they could become a
        # dictionary of "optional" parameters.

        # print_info needs these.
        self.rcutfilter, self.rcore = s.rcutfilter, s.rcore
        self.basis = s.basis  # could replace nao

        # XXX figure out a better way to store these.  Ideally psit_j
        # would replace them, but the code sometimes expects psit_j to be
        # the *basis* functions.
        if hasattr(s, 'pseudo_partial_waves_j'):
            self.pseudo_partial_waves_j = s.pseudo_partial_waves_j

        self.N0_p = s.N0_p  # req. by estimate_magnetic_moments
        # lrtddft (and lrtddft2 for rxnabla_iiv).
        self.nabla_iiv, self.rxnabla_iiv = s.nabla_iiv, s.rxnabla_iiv

        # XAS: oscillator strengths only exist with a core hole.
        self.phicorehole_g = s.phicorehole_g  # should be optional
        if s.phicorehole_g is not None:
            self.A_ci = s.A_ci

        # All-electron density.
        self.rgd, self.rcut_j = s.rgd, s.rcut_j

        self.tauct = s.tauct  # TPSS, MGGA
        self.Delta_iiL = s.Delta_iiL  # external potential

        # Exact inverse overlap operator / time-propagation TDDFT.
        self.B_ii, self.dC_ii = s.B_ii, s.dC_ii

        # Exact exchange and Yukawa range separation.
        self.X_p, self.ExxC = s.X_p, s.ExxC
        self.X_pg, self.X_gamma = s.X_pg, s.X_gamma

        # Electrostatic correction.
        self.dEH0, self.dEH_p = s.dEH0, s.dEH_p

        self.g_lg = s.g_lg  # utilities/kspot.py (AllElectronPotential)
        self.extra_xc_data = s.extra_xc_data  # GLLB; probably empty dict
        self.orbital_free = s.orbital_free

        # Data for computing Mg_pp at runtime (Yukawa RSF, needed by dscf).
        self.local_corr = (s.local_corr if hasattr(s, 'local_corr')
                           else LocalCorrectionVar(s))
        self._Mg_pp = None
        self._gamma = 0
697
698
699class Setup(BaseSetup):
700    """Attributes:
701
702    ========== =====================================================
703    Name       Description
704    ========== =====================================================
705    ``Z``      Charge
706    ``type``   Type-name of setup (eg. 'paw')
707    ``symbol`` Chemical element label (eg. 'Mg')
708    ``xcname`` Name of xc
709    ``data``   Container class for information on the the atom, eg.
710               Nc, Nv, n_j, l_j, f_j, eps_j, rcut_j.
711               It defines the radial grid by ng and beta, from which
712               r_g = beta * arange(ng) / (ng - arange(ng)).
713               It stores pt_jg, phit_jg, phi_jg, vbar_g
714    ========== =====================================================
715
716
717    Attributes for making PAW corrections
718
719    ============= ==========================================================
720    Name          Description
721    ============= ==========================================================
722    ``Delta0``    Constant in compensation charge expansion coeff.
723    ``Delta_iiL`` Linear term in compensation charge expansion coeff.
724    ``Delta_pL``  Packed version of ``Delta_iiL``.
725    ``dO_ii``     Overlap coefficients
726    ``B_ii``      Projector function overlaps B_ii = <pt_i | pt_i>
727    ``dC_ii``     Inverse overlap coefficients
728    ``E``         Reference total energy of atom
729    ``M``         Constant correction to Coulomb energy
730    ``M_p``       Linear correction to Coulomb energy
731    ``M_pp``      2nd order correction to Coulomb energy and Exx energy
732    ``Kc``        Core kinetic energy
733    ``K_p``       Linear correction to kinetic energy
734    ``ExxC``      Core Exx energy
735    ``X_p``       Linear correction to Exx energy
736    ``MB``        Constant correction due to vbar potential
737    ``MB_p``      Linear correction due to vbar potential
738    ``dEH0``      Constant correction due to average electrostatic potential
739    ``dEH_p``     Linear correction due to average electrostatic potential
740    ``I4_iip``    Correction to integrals over 4 all electron wave functions
741    ``Nct``       Analytical integral of the pseudo core density ``nct``
742    ============= ==========================================================
743
744    It also has the attribute ``xc_correction`` which is an XCCorrection class
745    instance capable of calculating the corrections due to the xc functional.
746
747
748    Splines:
749
750    ========== ============================================
751    Name       Description
752    ========== ============================================
753    ``pt_j``   Projector functions
754    ``phit_j`` Pseudo partial waves
755    ``vbar``   vbar potential
756    ``nct``    Pseudo core density
757    ``ghat_l`` Compensation charge expansion functions
758    ``tauct``  Pseudo core kinetic energy density
759    ========== ============================================
760    """
761    def __init__(self, data, xc, lmax=0, basis=None, filter=None):
762        self.type = data.name
763
764        self.HubU = None
765
766        if not data.is_compatible(xc):
767            raise ValueError('Cannot use %s setup with %s functional' %
768                             (data.setupname, xc.get_setup_name()))
769
770        self.symbol = data.symbol
771        self.data = data
772
773        self.Nc = data.Nc
774        self.Nv = data.Nv
775        self.Z = data.Z
776        l_j = self.l_j = data.l_j
777        self.l_orb_j = data.l_orb_j
778        n_j = self.n_j = data.n_j
779        self.f_j = data.f_j
780        self.eps_j = data.eps_j
781        nj = self.nj = len(l_j)
782        rcut_j = self.rcut_j = data.rcut_j
783
784        self.ExxC = data.ExxC
785        self.X_p = data.X_p
786
787        self.X_gamma = data.X_gamma
788        self.X_pg = data.X_pg
789
790        self.orbital_free = data.orbital_free
791
792        pt_jg = data.pt_jg
793        phit_jg = data.phit_jg
794        phi_jg = data.phi_jg
795
796        self.fingerprint = data.fingerprint
797        self.filename = data.filename
798
799        rgd = self.rgd = data.rgd
800        r_g = rgd.r_g
801        dr_g = rgd.dr_g
802
803        self.lmax = lmax
804
805        self._Mg_pp = None  # Yukawa based corrections
806        self._gamma = 0
807        # Attributes for run-time calculation of _Mg_pp
808        self.local_corr = LocalCorrectionVar(data)
809
810        rcutmax = max(rcut_j)
811        rcut2 = 2 * rcutmax
812        gcut2 = rgd.ceil(rcut2)
813        self.gcut2 = gcut2
814
815        self.gcutmin = rgd.ceil(min(rcut_j))
816
817        vbar_g = data.vbar_g
818
819        if float(data.version) < 0.7 and data.generator_version < 2:
820            # Old-style Fourier-filtered datatsets.
821            # Find Fourier-filter cutoff radius:
822            gcutfilter = rgd.get_cutoff(pt_jg[0])
823
824        elif filter:
825            rc = rcutmax
826            vbar_g = vbar_g.copy()
827            filter(rgd, rc, vbar_g)
828
829            pt_jg = [pt_g.copy() for pt_g in pt_jg]
830            for l, pt_g in zip(l_j, pt_jg):
831                filter(rgd, rc, pt_g, l)
832
833            for l in range(max(l_j) + 1):
834                J = [j for j, lj in enumerate(l_j) if lj == l]
835                A_nn = [[rgd.integrate(phit_jg[j1] * pt_jg[j2]) / 4 / pi
836                         for j1 in J] for j2 in J]
837                B_nn = np.linalg.inv(A_nn)
838                pt_ng = np.dot(B_nn, [pt_jg[j] for j in J])
839                for n, j in enumerate(J):
840                    pt_jg[j] = pt_ng[n]
841            gcutfilter = rgd.get_cutoff(pt_jg[0])
842        else:
843            gcutfilter = rgd.ceil(max(rcut_j))
844
845        if (vbar_g[gcutfilter:] != 0.0).any():
846            gcutfilter = rgd.get_cutoff(vbar_g)
847            assert r_g[gcutfilter] < 2.0 * max(rcut_j)
848
849        self.rcutfilter = rcutfilter = r_g[gcutfilter]
850
851        ni = 0
852        i = 0
853        j = 0
854        jlL_i = []
855        for l, n in zip(l_j, n_j):
856            for m in range(2 * l + 1):
857                jlL_i.append((j, l, l**2 + m))
858                i += 1
859            j += 1
860        ni = i
861        self.ni = ni
862
863        _np = ni * (ni + 1) // 2
864        self.local_corr.nq = nj * (nj + 1) // 2
865
866        lcut = max(l_j)
867        if 2 * lcut < lmax:
868            lcut = (lmax + 1) // 2
869        self.local_corr.lcut = lcut
870
871        self.B_ii = self.calculate_projector_overlaps(pt_jg)
872
873        self.fcorehole = data.fcorehole
874        self.lcorehole = data.lcorehole
875        if data.phicorehole_g is not None:
876            if self.lcorehole == 0:
877                self.calculate_oscillator_strengths(phi_jg)
878            else:
879                self.A_ci = None
880
881        # Construct splines:
882        self.vbar = rgd.spline(vbar_g, rcutfilter)
883
884        rcore, nc_g, nct_g, nct = self.construct_core_densities(data)
885        self.rcore = rcore
886        self.nct = nct
887
888        # Construct splines for core kinetic energy density:
889        tauct_g = data.tauct_g
890        if tauct_g is not None:
891            self.tauct = rgd.spline(tauct_g, self.rcore)
892        else:
893            self.tauct = None
894
895        self.pt_j = self.create_projectors(pt_jg, rcutfilter)
896
897        partial_waves = self.create_basis_functions(phit_jg, rcut2, gcut2)
898        self.pseudo_partial_waves_j = partial_waves.tosplines()
899
900        if basis is None:
901            phit_j = self.pseudo_partial_waves_j
902            basis = partial_waves
903        else:
904            phit_j = basis.tosplines()
905        self.phit_j = phit_j
906        self.basis = basis
907
908        self.nao = 0
909        for phit in self.phit_j:
910            l = phit.get_angular_momentum_number()
911            self.nao += 2 * l + 1
912
913        rgd2 = self.local_corr.rgd2 = rgd.new(gcut2)
914        r_g = rgd2.r_g
915        dr_g = rgd2.dr_g
916        phi_jg = np.array([phi_g[:gcut2].copy() for phi_g in phi_jg])
917        phit_jg = np.array([phit_g[:gcut2].copy() for phit_g in phit_jg])
918        self.local_corr.nc_g = nc_g = nc_g[:gcut2].copy()
919        self.local_corr.nct_g = nct_g = nct_g[:gcut2].copy()
920        vbar_g = vbar_g[:gcut2].copy()
921
922        extra_xc_data = dict(data.extra_xc_data)
923        # Cut down the GLLB related extra data
924        for key, item in extra_xc_data.items():
925            if len(item) == rgd.N:
926                extra_xc_data[key] = item[:gcut2].copy()
927        self.extra_xc_data = extra_xc_data
928
929        self.phicorehole_g = data.phicorehole_g
930        if self.phicorehole_g is not None:
931            self.phicorehole_g = self.phicorehole_g[:gcut2].copy()
932
933        self.local_corr.T_Lqp = self.calculate_T_Lqp(lcut, _np, nj, jlL_i)
934        #  set the attributes directly?
935        (self.g_lg, self.local_corr.n_qg, self.local_corr.nt_qg,
936         self.local_corr.Delta_lq, self.Lmax, self.Delta_pL, self.Delta0,
937         self.N0_p) = self.get_compensation_charges(phi_jg, phit_jg, _np,
938                                                    self.local_corr.T_Lqp)
939
940        # Solves the radial poisson equation for density n_g
941        def H(self, n_g, l):
942            return rgd2.poisson(n_g, l) * r_g * dr_g
943
944        (wg_lg, wn_lqg, wnt_lqg, wnc_g, wnct_g, wmct_g) = \
945            self.calculate_integral_potentials(H)
946        self.wg_lg = wg_lg
947
948        rdr_g = r_g * dr_g
949        dv_g = r_g * rdr_g
950        A = 0.5 * np.dot(nc_g, wnc_g)
951        A -= sqrt(4 * pi) * self.Z * np.dot(rdr_g, nc_g)
952        mct_g = nct_g + self.Delta0 * self.g_lg[0]
953        # wmct_g = wnct_g + self.Delta0 * wg_lg[0]
954        A -= 0.5 * np.dot(mct_g, wmct_g)
955        self.M = A
956        self.MB = -np.dot(dv_g * nct_g, vbar_g)
957
958        AB_q = -np.dot(self.local_corr.nt_qg, dv_g * vbar_g)
959        self.MB_p = np.dot(AB_q, self.local_corr.T_Lqp[0])
960
961        # Correction for average electrostatic potential:
962        #
963        #   dEH = dEH0 + dot(D_p, dEH_p)
964        #
965        self.dEH0 = sqrt(4 * pi) * (wnc_g - wmct_g -
966                                    sqrt(4 * pi) * self.Z * r_g * dr_g).sum()
967        dEh_q = (wn_lqg[0].sum(1) - wnt_lqg[0].sum(1) -
968                 self.local_corr.Delta_lq[0] * wg_lg[0].sum())
969        self.dEH_p = np.dot(dEh_q, self.local_corr.T_Lqp[0]) * sqrt(4 * pi)
970
971        M_p, M_pp = self.calculate_coulomb_corrections(wn_lqg, wnt_lqg,
972                                                       wg_lg, wnc_g, wmct_g)
973        self.M_p = M_p
974        self.M_pp = M_pp
975
976        if xc.type == 'GLLB':
977            if 'core_f' in self.extra_xc_data:
978                self.wnt_lqg = wnt_lqg
979                self.wn_lqg = wn_lqg
980                self.fc_j = self.extra_xc_data['core_f']
981                self.lc_j = self.extra_xc_data['core_l']
982                self.njcore = len(self.lc_j)
983                if self.njcore > 0:
984                    self.uc_jg = self.extra_xc_data['core_states'].reshape(
985                        (self.njcore, -1))
986                    self.uc_jg = self.uc_jg[:, :gcut2]
987                self.phi_jg = phi_jg
988
989        self.Kc = data.e_kinetic_core - data.e_kinetic
990        self.M -= data.e_electrostatic
991        self.E = data.e_total
992
993        Delta0_ii = unpack(self.Delta_pL[:, 0].copy())
994        self.dO_ii = data.get_overlap_correction(Delta0_ii)
995        self.dC_ii = self.get_inverse_overlap_coefficients(self.B_ii,
996                                                           self.dO_ii)
997
998        self.Delta_iiL = np.zeros((ni, ni, self.Lmax))
999        for L in range(self.Lmax):
1000            self.Delta_iiL[:, :, L] = unpack(self.Delta_pL[:, L].copy())
1001
1002        self.Nct = data.get_smooth_core_density_integral(self.Delta0)
1003        self.K_p = data.get_linear_kinetic_correction(self.local_corr.T_Lqp[0])
1004
1005        self.ghat_l = [rgd2.spline(g_g, rcut2, l, 50)
1006                       for l, g_g in enumerate(self.g_lg)]
1007
1008        self.xc_correction = data.get_xc_correction(rgd2, xc, gcut2, lcut)
1009        self.nabla_iiv = self.get_derivative_integrals(rgd2, phi_jg, phit_jg)
1010        self.rxnabla_iiv = self.get_magnetic_integrals(rgd2, phi_jg, phit_jg)
1011
1012    def create_projectors(self, pt_jg, rcut):
1013        pt_j = []
1014        for j, pt_g in enumerate(pt_jg):
1015            l = self.l_j[j]
1016            pt_j.append(self.rgd.spline(pt_g, rcut, l))
1017        return pt_j
1018
1019    def get_inverse_overlap_coefficients(self, B_ii, dO_ii):
1020        ni = len(B_ii)
1021        xO_ii = np.dot(B_ii, dO_ii)
1022        return -np.dot(dO_ii, np.linalg.inv(np.identity(ni) + xO_ii))
1023
1024    def calculate_T_Lqp(self, lcut, _np, nj, jlL_i):
1025        Lcut = (2 * lcut + 1)**2
1026        G_LLL = gaunt(max(self.l_j))[:, :, :Lcut]
1027        LGcut = G_LLL.shape[2]
1028        T_Lqp = np.zeros((Lcut, self.local_corr.nq, _np))
1029        p = 0
1030        i1 = 0
1031        for j1, l1, L1 in jlL_i:
1032            for j2, l2, L2 in jlL_i[i1:]:
1033                if j1 < j2:
1034                    q = j2 + j1 * nj - j1 * (j1 + 1) // 2
1035                else:
1036                    q = j1 + j2 * nj - j2 * (j2 + 1) // 2
1037                T_Lqp[:LGcut, q, p] = G_LLL[L1, L2]
1038                p += 1
1039            i1 += 1
1040        return T_Lqp
1041
1042    def calculate_projector_overlaps(self, pt_jg):
1043        """Compute projector function overlaps B_ii = <pt_i | pt_i>."""
1044        nj = len(pt_jg)
1045        B_jj = np.zeros((nj, nj))
1046        for j1, pt1_g in enumerate(pt_jg):
1047            for j2, pt2_g in enumerate(pt_jg):
1048                B_jj[j1, j2] = self.rgd.integrate(pt1_g * pt2_g) / (4 * pi)
1049        B_ii = np.zeros((self.ni, self.ni))
1050        i1 = 0
1051        for j1, l1 in enumerate(self.l_j):
1052            for m1 in range(2 * l1 + 1):
1053                i2 = 0
1054                for j2, l2 in enumerate(self.l_j):
1055                    for m2 in range(2 * l2 + 1):
1056                        if l1 == l2 and m1 == m2:
1057                            B_ii[i1, i2] = B_jj[j1, j2]
1058                        i2 += 1
1059                i1 += 1
1060        return B_ii
1061
    def get_compensation_charges(self, phi_jg, phit_jg, _np, T_Lqp):
        """Compute compensation-charge related quantities.

        Parameters:

        phi_jg, phit_jg: all-electron and pseudo partial waves.  Their
            products are stored in arrays of length ``self.gcut2`` below,
            so they are presumably already truncated to ``gcut2`` grid
            points (TODO confirm with caller).
        _np: number of packed projector pairs; first axis of ``Delta_pL``.
        T_Lqp: coefficients mapping radial pair index q to packed projector
            pair index p for every angular index L.

        Reads ``self.local_corr.nq``, ``.rgd2``, ``.nc_g`` and ``.nct_g``,
        which must be set before this is called.  Sets ``self.lq`` as a
        side effect.

        Returns ``(g_lg, n_qg, nt_qg, Delta_lq, Lmax, Delta_pL, Delta0,
        N0_p)``; ``g_lg`` is truncated to the first ``gcut2`` points.
        """
        lmax = self.lmax
        gcut2 = self.gcut2
        nq = self.local_corr.nq

        # One compensation-charge shape function per l = 0..lmax:
        g_lg = self.data.create_compensation_charge_functions(lmax)

        # Products of partial-wave pairs (j1 <= j2), packed with index q:
        n_qg = np.zeros((nq, gcut2))
        nt_qg = np.zeros((nq, gcut2))
        q = 0  # q: common index for j1, j2
        for j1 in range(self.nj):
            for j2 in range(j1, self.nj):
                n_qg[q] = phi_jg[j1] * phi_jg[j2]
                nt_qg[q] = phit_jg[j1] * phit_jg[j2]
                q += 1

        # Integral (weight r^2 dr) of each all-electron pair product inside
        # the smallest cutoff radius:
        gcutmin = self.gcutmin
        r_g = self.local_corr.rgd2.r_g
        dr_g = self.local_corr.rgd2.dr_g
        self.lq = np.dot(n_qg[:, :gcutmin], r_g[:gcutmin]**2 * dr_g[:gcutmin])

        # Multipole moments of the all-electron minus pseudo pair products:
        # Delta_lq[l, q] = int (n_q - nt_q) r^(l + 2) dr
        Delta_lq = np.zeros((lmax + 1, nq))
        for l in range(lmax + 1):
            Delta_lq[l] = np.dot(n_qg - nt_qg, r_g**(2 + l) * dr_g)

        # Expand to packed projector pairs p and all L = l**2 + m:
        Lmax = (lmax + 1)**2
        Delta_pL = np.zeros((_np, Lmax))
        for l in range(lmax + 1):
            L = l**2
            for m in range(2 * l + 1):
                delta_p = np.dot(Delta_lq[l], T_Lqp[L + m])
                Delta_pL[:, L + m] = delta_p

        # Core-density contribution (all-electron minus pseudo core, weight
        # r^2 dr) minus the nuclear charge scaled by 1 / sqrt(4 pi):
        Delta0 = np.dot(self.local_corr.nc_g - self.local_corr.nct_g,
                        r_g**2 * dr_g) - self.Z / sqrt(4 * pi)

        # Electron density inside augmentation sphere.  Used for estimating
        # atomic magnetic moment:
        rcutmax = max(self.rcut_j)
        gcutmax = self.rgd.round(rcutmax)
        N0_q = np.dot(n_qg[:, :gcutmax], (r_g**2 * dr_g)[:gcutmax])
        N0_p = np.dot(N0_q, T_Lqp[0]) * sqrt(4 * pi)

        return (g_lg[:, :gcut2].copy(), n_qg, nt_qg,
                Delta_lq, Lmax, Delta_pL, Delta0, N0_p)
1107
    def get_derivative_integrals(self, rgd, phi_jg, phit_jg):
        """Calculate PAW-correction matrix elements of nabla.

        ::

          /  _       _  d       _     ~   _  d   ~   _
          | dr [phi (r) -- phi (r) - phi (r) -- phi (r)]
          /        1    dx    2         1    dx    2

        and similar for y and z.

        ``rgd`` is the radial grid on which the partial waves ``phi_jg``
        (all-electron) and ``phit_jg`` (pseudo) are given.

        Returns an ``(ni, ni, 3)`` array indexed by projector indices
        i1, i2 and Cartesian direction v."""
        # lmax needs to be at least 1 for evaluating
        # the Gaunt coefficients from derivatives
        lmax = max(1, max(self.l_j))
        G_LLL = gaunt(lmax)
        Y_LLv = nabla(lmax)

        r_g = rgd.r_g
        dr_g = rgd.dr_g
        nabla_iiv = np.empty((self.ni, self.ni, 3))
        if debug:
            # Pre-fill with NaN so the final check catches unset elements.
            nabla_iiv[:] = np.nan
        i1 = 0
        for j1 in range(self.nj):
            l1 = self.l_j[j1]
            nm1 = 2 * l1 + 1
            i2 = 0
            for j2 in range(self.nj):
                l2 = self.l_j[j2]
                nm2 = 2 * l2 + 1
                # Radial integral of (phi1 phi2 - phit1 phit2), weight r dr:
                f1f2or = np.dot(phi_jg[j1] * phi_jg[j2] -
                                phit_jg[j1] * phit_jg[j2], r_g * dr_g)
                # Radial derivatives of partial wave j2.  They depend only on
                # j2 but are recomputed for every j1 (could be hoisted).
                dphidr_g = np.empty_like(phi_jg[j2])
                rgd.derivative_spline(phi_jg[j2], dphidr_g)
                dphitdr_g = np.empty_like(phit_jg[j2])
                rgd.derivative_spline(phit_jg[j2], dphitdr_g)
                # Radial integral of (phi1 dphi2/dr - phit1 dphit2/dr),
                # weight r^2 dr:
                f1df2dr = np.dot(phi_jg[j1] * dphidr_g -
                                 phit_jg[j1] * dphitdr_g, r_g**2 * dr_g)
                for v in range(3):
                    # Map Cartesian index v to an l = 1 harmonic index:
                    # v=0 -> L=3, v=1 -> L=1, v=2 -> L=2 (presumably the
                    # y, z, x ordering of the l = 1 real harmonics; confirm
                    # against gpaw.spherical_harmonics).
                    Lv = 1 + (v + 2) % 3
                    # (nm1, nm2) angular blocks for this pair of l values:
                    G_12 = G_LLL[Lv, l1**2:l1**2 + nm1, l2**2:l2**2 + nm2]
                    Y_12 = Y_LLv[l1**2:l1**2 + nm1, l2**2:l2**2 + nm2, v]
                    # Radial-derivative part couples through the Gaunt
                    # coefficients; the remaining part uses the angular
                    # derivative integrals Y_LLv.
                    nabla_iiv[i1:i1 + nm1, i2:i2 + nm2, v] = (
                        sqrt(4 * pi / 3) * (f1df2dr - l2 * f1f2or) * G_12
                        + f1f2or * Y_12)
                i2 += nm2
            i1 += nm1
        if debug:
            assert not np.any(np.isnan(nabla_iiv))
        return nabla_iiv
1157
    def get_magnetic_integrals(self, rgd, phi_jg, phit_jg):
        """Calculate PAW-correction matrix elements of r x nabla.

        ::

          /  _       _          _     ~   _      ~   _
          | dr [phi (r) O  phi (r) - phi (r) O  phi (r)]
          /        1     x    2         1     x    2

                       d      d
          where O  = y -- - z --
                 x     dz     dy

        and similar for y and z.

        ``rgd`` is the radial grid on which the partial waves ``phi_jg``
        (all-electron) and ``phit_jg`` (pseudo) are given.

        Returns an ``(ni, ni, 3)`` array indexed by projector indices
        i1, i2 and Cartesian direction v."""
        # lmax needs to be at least 1 for evaluating
        # the Gaunt coefficients from derivatives
        lmax = max(1, max(self.l_j))
        G_LLL = gaunt(lmax)
        # Built with 2 * lmax, presumably so that its first axis matches the
        # last axis of G_LLL in the np.dot calls below (TODO confirm shapes).
        Y_LLv = nabla(2 * lmax)

        r_g = rgd.r_g
        dr_g = rgd.dr_g
        rxnabla_iiv = np.empty((self.ni, self.ni, 3))
        if debug:
            # Pre-fill with NaN so the final check catches unset elements.
            rxnabla_iiv[:] = np.nan
        i1 = 0
        for j1 in range(self.nj):
            l1 = self.l_j[j1]
            nm1 = 2 * l1 + 1
            i2 = 0
            for j2 in range(self.nj):
                l2 = self.l_j[j2]
                nm2 = 2 * l2 + 1
                # Radial integral of (phi1 phi2 - phit1 phit2), weight
                # r^2 dr:
                f1f2or = np.dot(phi_jg[j1] * phi_jg[j2] -
                                phit_jg[j1] * phit_jg[j2], r_g**2 * dr_g)
                for v in range(3):
                    # O_v = r_v1 d/dr_v2 - r_v2 d/dr_v1 with (v, v1, v2)
                    # cyclic, e.g. v = x: v1 = y, v2 = z.
                    v1 = (v + 1) % 3
                    v2 = (v + 2) % 3
                    # l = 1 harmonic indices for directions v1 and v2
                    # (same mapping as in get_derivative_integrals).
                    Lv1 = 1 + (v1 + 2) % 3
                    Lv2 = 1 + (v2 + 2) % 3
                    # term from radial wfs does not contribute
                    # term from spherical harmonics derivatives
                    G_12 = np.zeros((nm1, nm2))
                    G_12 += np.dot(G_LLL[Lv1, l1**2:l1**2 + nm1, :],
                                   Y_LLv[:, l2**2:l2**2 + nm2, v2])
                    G_12 -= np.dot(G_LLL[Lv2, l1**2:l1**2 + nm1, :],
                                   Y_LLv[:, l2**2:l2**2 + nm2, v1])
                    rxnabla_iiv[i1:i1 + nm1, i2:i2 + nm2, v] = (
                        sqrt(4 * pi / 3) * f1f2or * G_12)
                i2 += nm2
            i1 += nm1
        if debug:
            assert not np.any(np.isnan(rxnabla_iiv))
        return rxnabla_iiv
1212
1213    def construct_core_densities(self, setupdata):
1214        rcore = self.data.find_core_density_cutoff(setupdata.nc_g)
1215        nct = self.rgd.spline(setupdata.nct_g, rcore)
1216        return rcore, setupdata.nc_g, setupdata.nct_g, nct
1217
1218    def create_basis_functions(self, phit_jg, rcut2, gcut2):
1219        # Cutoff for atomic orbitals used for initial guess:
1220        rcut3 = 8.0  # XXXXX Should depend on the size of the atom!
1221        gcut3 = self.rgd.ceil(rcut3)
1222
1223        # We cut off the wave functions smoothly at rcut3 by the
1224        # following replacement:
1225        #
1226        #            /
1227        #           | f(r),                                   r < rcut2
1228        #  f(r) <- <  f(r) - a(r) f(rcut3) - b(r) f'(rcut3),  rcut2 < r < rcut3
1229        #           | 0,                                      r > rcut3
1230        #            \
1231        #
1232        # where a(r) and b(r) are 4. order polynomials:
1233        #
1234        #  a(rcut2) = 0,  a'(rcut2) = 0,  a''(rcut2) = 0,
1235        #  a(rcut3) = 1, a'(rcut3) = 0
1236        #  b(rcut2) = 0, b'(rcut2) = 0, b''(rcut2) = 0,
1237        #  b(rcut3) = 0, b'(rcut3) = 1
1238        #
1239        r_g = self.rgd.r_g
1240        x = (r_g[gcut2:gcut3] - rcut2) / (rcut3 - rcut2)
1241        a_g = 4 * x**3 * (1 - 0.75 * x)
1242        b_g = x**3 * (x - 1) * (rcut3 - rcut2)
1243
1244        class PartialWaveBasis(Basis):  # yuckkk
1245            def __init__(self, symbol, phit_j):
1246                Basis.__init__(self, symbol, 'partial-waves', readxml=False)
1247                self.phit_j = phit_j
1248
1249            def tosplines(self):
1250                return self.phit_j
1251
1252            def get_description(self):
1253                template = 'Using partial waves for %s as LCAO basis'
1254                string = template % self.symbol
1255                return string
1256
1257        phit_j = []
1258        for j, phit_g in enumerate(phit_jg):
1259            if self.n_j[j] > 0:
1260                l = self.l_j[j]
1261                phit = phit_g[gcut3]
1262                dphitdr = ((phit - phit_g[gcut3 - 1]) /
1263                           (r_g[gcut3] - r_g[gcut3 - 1]))
1264                phit_g[gcut2:gcut3] -= phit * a_g + dphitdr * b_g
1265                phit_g[gcut3:] = 0.0
1266                phit_j.append(self.rgd.spline(phit_g, rcut3, l, points=100))
1267        basis = PartialWaveBasis(self.symbol, phit_j)
1268        return basis
1269
1270    def calculate_oscillator_strengths(self, phi_jg):
1271        # XXX implement oscillator strengths for lcorehole != 0
1272        assert(self.lcorehole == 0)
1273        self.A_ci = np.zeros((3, self.ni))
1274        nj = len(phi_jg)
1275        i = 0
1276        for j in range(nj):
1277            l = self.l_j[j]
1278            if l == 1:
1279                a = self.rgd.integrate(phi_jg[j] * self.data.phicorehole_g,
1280                                       n=1) / (4 * pi)
1281
1282                for m in range(3):
1283                    c = (m + 1) % 3
1284                    self.A_ci[c, i] = a
1285                    i += 1
1286            else:
1287                i += 2 * l + 1
1288        assert i == self.ni
1289
1290
class Setups(list):
    """Collection of Setup objects, one list entry per atom.

    Atoms with the same atomic number, setup type and basis are
    non-distinct: they share a single Setup instance, created once and
    stored in ``self.setups`` under its ``(Z, type, basis)`` key.

    Attributes:

    ``nvalence``    Number of valence electrons.
    ``nao``         Number of atomic orbitals.
    ``Eref``        Reference energy.
    ``core_charge`` Core hole charge.
    ``M_a``         Index of the first atomic orbital of each atom.
    ``id_a``        The ``(Z, type, basis)`` key of each atom.
    """

    def __init__(self, Z_a, setup_types, basis_sets, xc,
                 filter=None, world=None):
        list.__init__(self)
        symbols = [chemical_symbols[Z] for Z in Z_a]
        type_a = types2atomtypes(symbols, setup_types, default='paw')
        basis_a = types2atomtypes(symbols, basis_sets, default=None)

        # Basis names given as strings are adjusted so that they refer to
        # the basis set belonging to the chosen setup:
        for a, (setup_type, basis) in enumerate(zip(type_a, basis_a)):
            if isinstance(basis, str):
                basis_a[a] = self._basis_name_for_setup(setup_type, basis)

        # One PAW-setup object per distinct (Z, type, basis) key, plus the
        # orbital offset of every atom:
        self.setups = {}
        natoms = {}
        self.M_a = []
        self.id_a = list(zip(Z_a, type_a, basis_a))
        orbital_offset = 0
        for key in self.id_a:
            setup = self.setups.get(key)
            if setup is None:
                setup = self._make_setup(key, xc, filter, world)
                self.setups[key] = setup
                natoms[key] = 0
            natoms[key] += 1
            self.append(setup)
            self.M_a.append(orbital_offset)
            orbital_offset += setup.nao

        # Totals, weighting every distinct setup by its number of atoms:
        weighted = [(natoms[key], setup)
                    for key, setup in self.setups.items()]
        self.nvalence = sum((n * s.Nv for n, s in weighted), 0)
        self.nao = sum((n * s.nao for n, s in weighted), 0)
        self.Eref = sum((n * s.E for n, s in weighted), 0.0)
        self.core_charge = sum((n * (s.Z - s.Nv - s.Nc)
                                for n, s in weighted), 0.0)

        self.dS = OverlapCorrections(self)

    @staticmethod
    def _basis_name_for_setup(setup_type, basis):
        """Return the basis name *basis* adjusted to match *setup_type*.

        Typically people might specify '11' as the setup but just 'dzp'
        for the basis set.  The setup name is then prepended, giving, say,
        '11.dzp', which loads the correct basis set.  Because of the
        "szp(dzp)" syntax the name goes inside the parentheses:
        "szp(name.dzp)".  A DFT+U specification (':' suffix) is dropped
        from the setup name first.  Default ('paw') setups, ghost atoms and
        empty setup names leave *basis* unchanged.

        There is no way to obtain the original 'dzp' with a custom-named
        setup except by loading directly from BasisData.
        """
        if isinstance(setup_type, str):
            setupname = setup_type
        else:
            setupname = setup_type.name  # an object like SetupData
        if hasattr(setupname, 'swapcase'):  # string-like name
            setupname = setupname.split(':')[0]
        if setupname in ('paw', 'ghost') or not setupname:
            return basis
        if '(' not in basis:
            return f'{setupname}.{basis}'
        reduced, name = basis.split('(')
        assert name.endswith(')')
        return f'{reduced}({setupname}.{name[:-1]})'

    @staticmethod
    def _make_setup(key, xc, filter_, world):
        """Create the setup for one ``(Z, type, basis)`` key.

        A non-string *type* is passed on as ``setupdata``.  *basis* may be
        None (the setup decides), a string (a basis set loaded now from a
        file) or a pre-created Basis object (passed along unchanged).
        """
        Z, setup_type, basis = key
        symbol = chemical_symbols[Z]
        setupdata = None if isinstance(setup_type, str) else setup_type
        if isinstance(basis, str):
            basis = Basis(symbol, basis, world=world)
        return create_setup(symbol, xc, 2, setup_type,
                            basis, setupdata=setupdata,
                            filter=filter_, world=world)

    def __str__(self):
        # PAW setup information for each distinct setup, in order of first
        # appearance, followed by the reference energy:
        parts = []
        for key in dict.fromkeys(self.id_a):
            setup = self.setups[key]
            buffer = StringIO()
            setup.print_info(functools.partial(print, file=buffer))
            info = buffer.getvalue()
            basis_descr = setup.get_basis_description().replace(
                '\n  ', '\n    ')
            parts.append(info + '  ' + basis_descr + '\n\n')
        parts.append(f'Reference energy: {self.Eref * units.Hartree:.6f}\n')
        return ''.join(parts)

    def set_symmetry(self, symmetry):
        """Find rotation matrices for spherical harmonics."""
        def rotations_for(op_cc):
            # Transform the operation to Cartesian coordinates and build
            # rotation matrices for l = 0, 1, 2, 3:
            op_vv = np.dot(np.linalg.inv(symmetry.cell_cv),
                           np.dot(op_cc, symmetry.cell_cv))
            return [rotation(l, op_vv) for l in range(4)]

        R_slmm = [rotations_for(op_cc) for op_cc in symmetry.op_scc]
        for setup in self.setups.values():
            setup.calculate_rotations(R_slmm)

    def empty_atomic_matrix(self, ns, atom_partition, dtype=float):
        """Return an array dict with shape (ns, ni (ni + 1) / 2) per atom."""
        shapes = []
        for setup in self:
            shapes.append((ns, setup.ni * (setup.ni + 1) // 2))
        return atom_partition.arraydict(shapes, dtype)

    def estimate_dedecut(self, ecut):
        """Sum of -dekindecut over all atoms, evaluated once per setup."""
        from gpaw.utilities.ekin import dekindecut, ekin
        per_setup = {}
        total = 0.0
        for key in self.id_a:
            if key not in per_setup:
                G, de, e0 = ekin(self.setups[key])
                per_setup[key] = -dekindecut(G, de, ecut)
            total += per_setup[key]
        return total

    def basis_indices(self):
        """Atomic-orbital index ranges of all atoms."""
        return FunctionIndices([setup.phit_j for setup in self])

    def projector_indices(self):
        """Projector index ranges of all atoms."""
        return FunctionIndices([setup.pt_j for setup in self])
1444
1445
class FunctionIndices:
    """Index ranges of per-atom radial functions expanded over m.

    ``f_aj`` holds, for every atom a, a sequence of radial functions that
    provide ``get_angular_momentum_number()``.  A function with angular
    momentum l contributes 2 l + 1 consecutive indices.

    Attributes:

    ``M_a``   Cumulative start indices (one more entry than atoms).
    ``nm_a``  Number of indices of each atom.
    ``max``   Total number of indices.
    """

    def __init__(self, f_aj):
        counts = [sum(2 * f.get_angular_momentum_number() + 1 for f in f_j)
                  for f_j in f_aj]
        self.M_a = np.cumsum([0] + counts)
        self.nm_a = np.array(counts)
        self.max = self.M_a[-1]

    def __getitem__(self, a):
        """Return the half-open index range ``(start, stop)`` of atom a."""
        start = self.M_a[a]
        stop = self.M_a[a + 1]
        return start, stop
1458
1459
def types2atomtypes(symbols, types, default):
    """Map a types identifier to a list with a type id for each atom.

    types can be a single str, which is then used for every atom, or a
    dictionary mapping chemical symbols and/or atom indices (Python or
    NumPy integers) to a type identifier.  Type identifiers are strings or
    objects with a ``symbol`` attribute (SetupData, Pseudopotential,
    Basis, etc.).  If both a symbol key and an atom-index key relate to
    the same atom, then the atom-index key is dominant.

    If types is a dictionary and contains the key 'default', its value is
    used as the default type; otherwise the input arg ``default`` is used.

    Returns a new list with one type identifier per entry of ``symbols``.
    """
    natoms = len(symbols)
    if isinstance(types, str):
        return [types] * natoms

    # The special key 'default', if present, overrides the input default:
    type_a = [types.get('default', default)] * natoms

    # First symbols ...
    for symbol, setup_type in types.items():
        # Types are given either by strings or they are objects that
        # have a 'symbol' attribute (SetupData, Pseudopotential, Basis, etc.).
        assert isinstance(setup_type, str) or hasattr(setup_type, 'symbol')
        if isinstance(symbol, str):
            for a, symbol2 in enumerate(symbols):
                if symbol == symbol2:
                    type_a[a] = setup_type

    # ... and then atom indices, so they take precedence over symbols.
    # NumPy integer keys are accepted too; previously they were silently
    # ignored.
    for a, setup_type in types.items():
        if isinstance(a, (int, np.integer)):
            type_a[a] = setup_type

    return type_a
1496
1497
if __name__ == '__main__':
    # Running this module directly is almost certainly a mistake (it is not
    # the package build script), so explain that and exit.
    message = """\
This is not the setup.py you are looking for!  This setup.py defines a
Setup class used to hold the atomic data needed for a specific atom.
For building the GPAW code you must use the setup.py distutils script
at the root of the code tree.  Just do "cd .." and you will be at the
right place."""
    print(message)
    raise SystemExit
1506