1# -*- coding: utf-8 -*-
2import numpy as np
3
4from ase.units import Hartree, Bohr
5
6from ase.io.ulm import Reader
7from gpaw.io import Writer
8from gpaw.external import ConstantElectricField
9from gpaw.lcaotddft.hamiltonian import KickHamiltonian
10from gpaw.lcaotddft.utilities import collect_MM
11from gpaw.lcaotddft.utilities import distribute_nM
12from gpaw.lcaotddft.utilities import read_uMM
13from gpaw.lcaotddft.utilities import write_uMM
14from gpaw.lcaotddft.utilities import read_uX, write_uX
15from gpaw.utilities.scalapack import \
16    pblas_simple_gemm, pblas_simple_hemm, scalapack_tri2full
17from gpaw.utilities.tools import tri2full
18
19
def gauss_ij(energy_i, energy_j, sigma):
    """Return normalized Gaussians of all pairwise energy differences.

    Element ``[i, j]`` of the returned array is a Gaussian of width
    ``sigma`` evaluated at ``energy_i[i] - energy_j[j]``.
    """
    # Broadcasting builds the table of differences energy_i[i] - energy_j[j]
    delta_ij = energy_i[:, None] - energy_j[None, :]
    prefactor = 1.0 / (sigma * np.sqrt(2 * np.pi))
    exponent_ij = -0.5 * delta_ij**2 / sigma**2
    return prefactor * np.exp(exponent_ij)
24
25
def get_bfs_maps(calc):
    """Construct maps from basis function index to atom and l.

    Parameters
    ----------
    calc
        Object exposing ``wfs.basis_functions.sphere_a``.  Every sphere
        must provide ``spline_j`` whose items implement
        ``get_angular_momentum_number()``.

    Returns
    -------
    a_M, l_M : numpy.ndarray
        Integer arrays; ``a_M[M]`` is the index of the sphere (atom) and
        ``l_M[M]`` the angular momentum of basis function ``M``.
    """
    a_M = []
    l_M = []
    for a, sphere in enumerate(calc.wfs.basis_functions.sphere_a):
        for spline in sphere.spline_j:
            l = spline.get_angular_momentum_number()
            # A spline with angular momentum l contributes 2l + 1 functions
            count = 2 * l + 1
            a_M.extend([a] * count)
            l_M.extend([l] * count)
    # Explicit integer dtype keeps the maps usable as index arrays even
    # when there are no basis functions at all.
    return np.array(a_M, dtype=int), np.array(l_M, dtype=int)
43
44
class KohnShamDecomposition(object):
    """Container for Kohn-Sham decomposition data.

    The data are either computed from a calculator object in
    ``initialize()`` or read lazily from a file through ``read()`` and
    ``__getattr__``; ``write()`` stores them to a file.
    """
    # Version number stored in files written by write()
    version = 1
    # Tag passed to the writer in write() and checked in read()
    ulmtag = 'KSD'
    # Names of the attributes that write() stores when self.comm.rank == 0
    readwrite_attrs = ['fermilevel', 'only_ia', 'w_p', 'f_p', 'ia_p',
                       'P_p', 'dm_vp', 'a_M', 'l_M']
50
51    def __init__(self, paw=None, filename=None):
52        self.filename = filename
53        self.has_initialized = False
54        self.reader = None
55        if paw is not None:
56            self.world = paw.world
57            self.log = paw.log
58            self.ksl = paw.wfs.ksl
59            self.kd = paw.wfs.kd
60            self.bd = paw.wfs.bd
61            self.kpt_u = paw.wfs.kpt_u
62            self.density = paw.density
63            self.comm = paw.comms['K']
64
65            if len(paw.wfs.kpt_u) > 1:
66                raise RuntimeError('K-points are not fully supported')
67
68        if filename is not None:
69            self.read(filename)
70            return
71
    def initialize(self, paw, min_occdiff=1e-3, only_ia=True):
        """Collect the data from calculator `paw` and build the pair list.

        Does nothing if ``self.has_initialized`` is already true (it is
        also set to true by ``read()``).

        Parameters
        ----------
        paw
            Calculator object; its ``wfs``, ``hamiltonian``, ``density``
            and ``atoms`` attributes are used.
        min_occdiff
            With ``only_ia=True``, pairs whose occupation difference
            ``occ_n[i] - occ_n[a]`` is below this value are skipped.
        only_ia
            If true, only pairs with ``a > i`` that pass the
            ``min_occdiff`` test are kept; otherwise all ``(i, a)``
            pairs are kept.

        Raises
        ------
        RuntimeError
            If bands are parallelized but scalapack (blacs) is not used.

        The pair arrays ``w_p``, ``f_p``, ``ia_p``, ``P_p`` and ``dm_vp``
        are set only on the rank where ``self.comm.rank == 0``.
        """
        if self.has_initialized:
            return
        paw.initialize_positions()
        # paw.set_positions()

        assert self.bd.nbands == self.ksl.nao
        self.only_ia = only_ia

        if not self.ksl.using_blacs and self.bd.comm.size > 1:
            raise RuntimeError('Band parallelization without scalapack '
                               'is not supported')

        # Gamma-point data are stored as real numbers, otherwise complex
        if self.kd.gamma:
            self.C0_dtype = float
        else:
            self.C0_dtype = complex

        # Take quantities
        self.fermilevel = paw.wfs.fermi_level
        self.S_uMM = []
        self.C0_unM = []
        self.eig_un = []
        self.occ_un = []
        for kpt in paw.wfs.kpt_u:
            S_MM = kpt.S_MM
            # Only the real part is kept, so the imaginary part must vanish
            assert np.max(np.absolute(S_MM.imag)) == 0.0
            S_MM = np.ascontiguousarray(S_MM.real)
            if self.ksl.using_blacs:
                scalapack_tri2full(self.ksl.mmdescriptor, S_MM)
            self.S_uMM.append(S_MM)

            C_nM = kpt.C_nM
            if self.C0_dtype == float:
                assert np.max(np.absolute(C_nM.imag)) == 0.0
                C_nM = np.ascontiguousarray(C_nM.real)
            C_nM = distribute_nM(self.ksl, C_nM)
            self.C0_unM.append(C_nM)

            eig_n = paw.wfs.collect_eigenvalues(kpt.k, kpt.s)
            occ_n = paw.wfs.collect_occupations(kpt.k, kpt.s)
            self.eig_un.append(eig_n)
            self.occ_un.append(occ_n)

        self.a_M, self.l_M = get_bfs_maps(paw)
        self.atoms = paw.atoms

        # TODO: do the rest of the function with K-points

        # Construct p = (i, a) pairs
        u = 0
        eig_n = self.eig_un[u]
        occ_n = self.occ_un[u]
        C0_nM = self.C0_unM[u]

        # The pair list is built on rank 0 of self.comm only
        if self.comm.rank == 0:
            Nn = self.bd.nbands

            f_p = []
            w_p = []
            i_p = []
            a_p = []
            ia_p = []
            i0 = 0
            for i in range(i0, Nn):
                if only_ia:
                    a0 = i + 1
                else:
                    a0 = 0
                for a in range(a0, Nn):
                    # f: occupation difference, w: eigenvalue difference
                    f = occ_n[i] - occ_n[a]
                    if only_ia and f < min_occdiff:
                        continue
                    w = eig_n[a] - eig_n[i]
                    f_p.append(f)
                    w_p.append(w)
                    i_p.append(i)
                    a_p.append(a)
                    ia_p.append((i, a))
            f_p = np.array(f_p)
            w_p = np.array(w_p)
            i_p = np.array(i_p, dtype=int)
            a_p = np.array(a_p, dtype=int)
            ia_p = np.array(ia_p, dtype=int)

            # Sort according to energy difference
            p_s = np.argsort(w_p)
            f_p = f_p[p_s]
            w_p = w_p[p_s]
            i_p = i_p[p_s]
            a_p = a_p[p_s]
            ia_p = ia_p[p_s]

            # P_p: index of pair (i, a) in a flattened (Nn, Nn) matrix
            Np = len(f_p)
            P_p = []
            for p in range(Np):
                P = np.ravel_multi_index(ia_p[p], (Nn, Nn))
                P_p.append(P)
            P_p = np.array(P_p)

            dm_vp = np.empty((3, Np), dtype=float)

        # One matrix per Cartesian direction v, executed on every rank
        for v in range(3):
            direction = np.zeros(3, dtype=float)
            direction[v] = 1.0
            # NOTE(review): strength Hartree / Bohr is presumably a unit
            # field in atomic units -- confirm against ConstantElectricField
            cef = ConstantElectricField(Hartree / Bohr, direction)
            kick_hamiltonian = KickHamiltonian(paw.hamiltonian, paw.density,
                                               cef)
            dm_MM = paw.wfs.eigensolver.calculate_hamiltonian_matrix(
                kick_hamiltonian, paw.wfs, paw.wfs.kpt_u[u],
                add_kinetic=False, root=-1)

            # Both branches evaluate conj(C0) * dm * C0^T
            if self.ksl.using_blacs:
                tmp_nM = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
                pblas_simple_hemm(self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  dm_MM, C0_nM.conj(), tmp_nM,
                                  side='R', uplo='L')
                dm_nn = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
                pblas_simple_gemm(self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  tmp_nM, C0_nM, dm_nn, transb='T')
            else:
                tri2full(dm_MM)
                dm_nn = np.dot(C0_nM.conj(), np.dot(dm_MM, C0_nM.T))

            dm_nn = collect_MM(self.ksl, dm_nn)
            if self.comm.rank == 0:
                # Pick the matrix elements belonging to the stored pairs
                dm_P = dm_nn.ravel()
                dm_p = dm_P[P_p]
                dm_vp[v] = dm_p

        if self.comm.rank == 0:
            self.w_p = w_p
            self.f_p = f_p
            self.ia_p = ia_p
            self.P_p = P_p
            self.dm_vp = dm_vp

        self.has_initialized = True
214
215    def write(self, filename):
216        from ase.io.trajectory import write_atoms
217
218        self.log('%s: Writing to %s' % (self.__class__.__name__, filename))
219        writer = Writer(filename, self.world, mode='w',
220                        tag=self.__class__.ulmtag)
221        writer.write(version=self.__class__.version)
222
223        write_atoms(writer.child('atoms'), self.atoms)
224
225        writer.write(ha=Hartree)
226        write_uMM(self.kd, self.ksl, writer, 'S_uMM', self.S_uMM)
227        write_uMM(self.kd, self.ksl, writer, 'C0_unM', self.C0_unM)
228        write_uX(self.kd, self.ksl.block_comm, writer, 'eig_un', self.eig_un)
229        write_uX(self.kd, self.ksl.block_comm, writer, 'occ_un', self.occ_un)
230
231        if self.comm.rank == 0:
232            for arg in self.readwrite_attrs:
233                writer.write(arg, getattr(self, arg))
234
235        writer.close()
236
237    def read(self, filename):
238        self.reader = Reader(filename)
239        tag = self.reader.get_tag()
240        if tag != self.__class__.ulmtag:
241            raise RuntimeError('Unknown tag %s' % tag)
242        self.version = self.reader.version
243
244        # Do lazy reading in __getattr__ only if/when
245        # the variables are required
246        self.has_initialized = True
247
248    def __getattr__(self, attr):
249        if attr in ['S_uMM', 'C0_unM']:
250            val = read_uMM(self.kpt_u, self.ksl, self.reader, attr)
251            setattr(self, attr, val)
252            return val
253        if attr in ['eig_un', 'occ_un']:
254            val = read_uX(self.kpt_u, self.reader, attr)
255            setattr(self, attr, val)
256            return val
257        if attr in ['C0S_unM']:
258            C0S_unM = []
259            for u, kpt in enumerate(self.kpt_u):
260                C0_nM = self.C0_unM[u]
261                S_MM = self.S_uMM[u]
262                if self.ksl.using_blacs:
263                    C0S_nM = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
264                    pblas_simple_hemm(self.ksl.mmdescriptor,
265                                      self.ksl.mmdescriptor,
266                                      self.ksl.mmdescriptor,
267                                      S_MM, C0_nM, C0S_nM,
268                                      side='R', uplo='L')
269                else:
270                    C0S_nM = np.dot(C0_nM, S_MM)
271                C0S_unM.append(C0S_nM)
272            setattr(self, attr, C0S_unM)
273            return C0S_unM
274        if attr in ['weight_Mn']:
275            assert self.world.size == 1
276            C2_nM = np.absolute(self.C0_unM[0])**2
277            val = C2_nM.T / np.sum(C2_nM, axis=1)
278            setattr(self, attr, val)
279            return val
280
281        try:
282            val = getattr(self.reader, attr)
283            if attr == 'atoms':
284                from ase.io.trajectory import read_atoms
285                val = read_atoms(val)
286            setattr(self, attr, val)
287            return val
288        except (KeyError, AttributeError):
289            pass
290
291        raise AttributeError('Attribute %s not defined in version %s' %
292                             (repr(attr), repr(self.version)))
293
294    def distribute(self, comm):
295        self.comm = comm
296        N = comm.size
297        self.Np = len(self.P_p)
298        self.Nq = int(np.ceil(self.Np / float(N)))
299        self.NQ = self.Nq * N
300        self.w_q = self.distribute_p(self.w_p)
301        self.f_q = self.distribute_p(self.f_p)
302        self.dm_vq = self.distribute_xp(self.dm_vp)
303
304    def distribute_p(self, a_p, a_q=None, root=0):
305        if a_q is None:
306            a_q = np.zeros(self.Nq, dtype=a_p.dtype)
307        if self.comm.rank == root:
308            a_Q = np.append(a_p, np.zeros(self.NQ - self.Np, dtype=a_p.dtype))
309        else:
310            a_Q = None
311        self.comm.scatter(a_Q, a_q, root)
312        return a_q
313
314    def collect_q(self, a_q, root=0):
315        if self.comm.rank == root:
316            a_Q = np.zeros(self.NQ, dtype=a_q.dtype)
317        else:
318            a_Q = None
319        self.comm.gather(a_q, root, a_Q)
320        if self.comm.rank == root:
321            a_p = a_Q[:self.Np]
322        else:
323            a_p = None
324        return a_p
325
326    def distribute_xp(self, a_xp):
327        Nx = a_xp.shape[0]
328        a_xq = np.zeros((Nx, self.Nq), dtype=a_xp.dtype)
329        for x in range(Nx):
330            self.distribute_p(a_xp[x], a_xq[x])
331        return a_xq
332
333    def transform(self, rho_uMM, broadcast=False):
334        assert len(rho_uMM) == 1, 'K-points not implemented'
335        u = 0
336        rho_MM = np.ascontiguousarray(rho_uMM[u])
337        C0S_nM = self.C0S_unM[u].astype(rho_MM.dtype, copy=True)
338        # KS decomposition
339        if self.ksl.using_blacs:
340            tmp_nM = self.ksl.mmdescriptor.zeros(dtype=rho_MM.dtype)
341            pblas_simple_gemm(self.ksl.mmdescriptor,
342                              self.ksl.mmdescriptor,
343                              self.ksl.mmdescriptor,
344                              C0S_nM, rho_MM, tmp_nM)
345            rho_nn = self.ksl.mmdescriptor.zeros(dtype=rho_MM.dtype)
346            pblas_simple_gemm(self.ksl.mmdescriptor,
347                              self.ksl.mmdescriptor,
348                              self.ksl.mmdescriptor,
349                              tmp_nM, C0S_nM, rho_nn, transb='C')
350        else:
351            rho_nn = np.dot(np.dot(C0S_nM, rho_MM), C0S_nM.T.conj())
352
353        rho_nn = collect_MM(self.ksl, rho_nn)
354        if self.comm.rank == 0:
355            rho_P = rho_nn.ravel()
356            # Remove de-excitation terms
357            rho_p = rho_P[self.P_p]
358            if self.only_ia:
359                rho_p *= 2
360        else:
361            rho_p = None
362
363        if broadcast:
364            if self.comm.rank != 0:
365                rho_p = np.zeros_like(self.P_p, dtype=rho_MM.dtype)
366            self.comm.broadcast(rho_p, 0)
367        rho_up = [rho_p]
368        return rho_up
369
370    def ialims(self):
371        i_p = self.ia_p[:, 0]
372        a_p = self.ia_p[:, 1]
373        imin = np.min(i_p)
374        imax = np.max(i_p)
375        amin = np.min(a_p)
376        amax = np.max(a_p)
377        return imin, imax, amin, amax
378
379    def M_p_to_M_ia(self, M_p):
380        return self.M_ia_from_M_p(M_p)
381
382    def M_ia_from_M_p(self, M_p):
383        imin, imax, amin, amax = self.ialims()
384        M_ia = np.zeros((imax - imin + 1, amax - amin + 1), dtype=M_p.dtype)
385        for M, (i, a) in zip(M_p, self.ia_p):
386            M_ia[i - imin, a - amin] = M
387        return M_ia
388
389    def plot_matrix(self, M_p):
390        import matplotlib.pyplot as plt
391        M_ia = self.M_ia_from_M_p(M_p)
392        plt.imshow(M_ia, interpolation='none')
393        plt.xlabel('a')
394        plt.ylabel('i')
395
396    def get_dipole_moment_contributions(self, rho_up):
397        assert len(rho_up) == 1, 'K-points not implemented'
398        u = 0
399        rho_p = rho_up[u]
400        dmrho_vp = - self.dm_vp * rho_p
401        return dmrho_vp
402
403    def get_dipole_moment(self, rho_up):
404        assert len(rho_up) == 1, 'K-points not implemented'
405        u = 0
406        rho_p = rho_up[u]
407        dm_v = - np.dot(self.dm_vp, rho_p)
408        return dm_v
409
410    def get_density(self, wfs, rho_up, density='comp'):
411        from gpaw.lcaotddft.densitymatrix import get_density
412
413        if self.ksl.using_blacs:
414            raise NotImplementedError('Scalapack is not supported')
415
416        density_type = density
417        assert len(rho_up) == 1, 'K-points not implemented'
418        u = 0
419        rho_p = rho_up[u]
420        C0_nM = self.C0_unM[u]
421
422        rho_ia = self.M_ia_from_M_p(rho_p)
423        imin, imax, amin, amax = self.ialims()
424        C0_iM = C0_nM[imin:(imax + 1)]
425        C0_aM = C0_nM[amin:(amax + 1)]
426
427        rho_MM = np.dot(C0_iM.T, np.dot(rho_ia, C0_aM.conj()))
428        rho_MM = 0.5 * (rho_MM + rho_MM.T)
429
430        return get_density(rho_MM, wfs, self.density, density_type, u)
431
432    def get_contributions_table(self, weight_p, minweight=0.01,
433                                zero_fermilevel=True):
434        assert weight_p.dtype == float
435        u = 0  # TODO
436
437        absweight_p = np.absolute(weight_p)
438        tot_weight = weight_p.sum()
439        propweight_p = weight_p / tot_weight * 100
440        tot_propweight = propweight_p.sum()
441        rest_weight = tot_weight
442        rest_propweight = tot_propweight
443        eig_n = self.eig_un[u].copy()
444        if zero_fermilevel:
445            eig_n -= self.fermilevel
446
447        txt = ''
448        txt += ('# %6s %4s(%8s)    %4s(%8s)  %12s %14s %8s\n' %
449                ('p', 'i', 'eV', 'a', 'eV', 'Ediff (eV)', 'weight', '%'))
450        p_s = np.argsort(absweight_p)[::-1]
451        for s, p in enumerate(p_s):
452            i, a = self.ia_p[p]
453            if absweight_p[p] < minweight:
454                break
455            txt += ('  %6s %4d(%8.3f) -> %4d(%8.3f): %12.4f %14.4f %8.1f\n' %
456                    (p, i, eig_n[i] * Hartree, a, eig_n[a] * Hartree,
457                     self.w_p[p] * Hartree, weight_p[p], propweight_p[p]))
458            rest_weight -= weight_p[p]
459            rest_propweight -= propweight_p[p]
460        txt += ('  %39s: %12s %+14.4f %8.1f\n' %
461                ('rest', '', rest_weight, rest_propweight))
462        txt += ('  %39s: %12s %+14.4f %8.1f\n' %
463                ('total', '', tot_weight, tot_propweight))
464        return txt
465
466    def plot_TCM(self, weight_p, energy_o, energy_u, sigma,
467                 zero_fermilevel=True, vmax='80%'):
468        from gpaw.lcaotddft.tcm import TCMPlotter
469        plotter = TCMPlotter(self, energy_o, energy_u, sigma, zero_fermilevel)
470        ax_tcm = plotter.plot_TCM(weight_p, vmax)
471        ax_occ_dos, ax_unocc_dos = plotter.plot_DOS()
472        return ax_tcm, ax_occ_dos, ax_unocc_dos
473
474    def get_TCM(self, weight_p, eig_n, energy_o, energy_u, sigma):
475        flt_p = self.filter_by_x_ia(eig_n, energy_o, energy_u, 8 * sigma)
476        weight_f = weight_p[flt_p]
477        G_fo = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_o, sigma)
478        G_fu = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_u, sigma)
479        tcm_ou = np.dot(G_fo.T * weight_f, G_fu)
480        return tcm_ou
481
482    def get_DOS(self, eig_n, energy_o, energy_u, sigma):
483        return self.get_weighted_DOS(1, eig_n, energy_o, energy_u, sigma)
484
485    def get_weighted_DOS(self, weight_n, eig_n, energy_o, energy_u, sigma):
486        if not isinstance(weight_n, np.ndarray):
487            # Assume float
488            weight_n = weight_n * np.ones_like(eig_n)
489        G_on = gauss_ij(energy_o, eig_n, sigma)
490        G_un = gauss_ij(energy_u, eig_n, sigma)
491        dos_o = np.dot(G_on, weight_n)
492        dos_u = np.dot(G_un, weight_n)
493        return dos_o, dos_u
494
495    def get_weight_n_by_l(self, l):
496        if isinstance(l, int):
497            weight_n = np.sum(self.weight_Mn[self.l_M == l], axis=0)
498        else:
499            weight_n = np.sum([self.get_weight_n_by_l(l_) for l_ in l],
500                              axis=0)
501        return weight_n
502
503    def get_weight_n_by_a(self, a):
504        if isinstance(a, int):
505            weight_n = np.sum(self.weight_Mn[self.a_M == a], axis=0)
506        else:
507            weight_n = np.sum([self.get_weight_n_by_a(a_) for a_ in a],
508                              axis=0)
509        return weight_n
510
511    def get_distribution_i(self, weight_p, energy_e, sigma,
512                           zero_fermilevel=True):
513        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
514        flt_p = self.filter_by_x_i(eig_n, energy_e, 8 * sigma)
515        weight_f = weight_p[flt_p]
516        G_fe = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_e, sigma)
517        dist_e = np.dot(G_fe.T, weight_f)
518        return dist_e
519
520    def get_distribution_a(self, weight_p, energy_e, sigma,
521                           zero_fermilevel=True):
522        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
523        flt_p = self.filter_by_x_a(eig_n, energy_e, 8 * sigma)
524        weight_f = weight_p[flt_p]
525        G_fe = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_e, sigma)
526        dist_e = np.dot(G_fe.T, weight_f)
527        return dist_e
528
529    def get_distribution_ia(self, weight_p, energy_o, energy_u, sigma,
530                            zero_fermilevel=True):
531        """
532        Filter both i and a spaces as in TCM.
533
534        """
535        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
536        flt_p = self.filter_by_x_ia(eig_n, energy_o, energy_u, 8 * sigma)
537        weight_f = weight_p[flt_p]
538        G_fo = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_o, sigma)
539        dist_o = np.dot(G_fo.T, weight_f)
540        G_fu = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_u, sigma)
541        dist_u = np.dot(G_fu.T, weight_f)
542        return dist_o, dist_u
543
544    def get_distribution(self, weight_p, energy_e, sigma):
545        w_p = self.w_p * Hartree
546        flt_p = self.filter_by_x_p(w_p, energy_e, 8 * sigma)
547        weight_f = weight_p[flt_p]
548        G_fe = gauss_ij(w_p[flt_p], energy_e, sigma)
549        dist_e = np.dot(G_fe.T, weight_f)
550        return dist_e
551
552    def get_eig_n(self, zero_fermilevel=True):
553        u = 0  # TODO
554        eig_n = self.eig_un[u].copy()
555        if zero_fermilevel:
556            eig_n -= self.fermilevel
557            fermilevel = 0.0
558        else:
559            fermilevel = self.fermilevel
560        eig_n *= Hartree
561        fermilevel *= Hartree
562        return eig_n, fermilevel
563
564    def filter_by_x_p(self, x_p, energy_e, buf):
565        flt_p = np.logical_and((energy_e[0] - buf) <= x_p,
566                               x_p <= (energy_e[-1] + buf))
567        return flt_p
568
569    def filter_by_x_i(self, x_n, energy_e, buf):
570        return self.filter_by_x_p(x_n[self.ia_p[:, 0]], energy_e, buf)
571
572    def filter_by_x_a(self, x_n, energy_e, buf):
573        return self.filter_by_x_p(x_n[self.ia_p[:, 1]], energy_e, buf)
574
575    def filter_by_x_ia(self, x_n, energy_o, energy_u, buf):
576        flti_p = self.filter_by_x_i(x_n, energy_o, buf)
577        flta_p = self.filter_by_x_a(x_n, energy_u, buf)
578        flt_p = np.logical_and(flti_p, flta_p)
579        return flt_p
580