1from typing import Optional
2
3import numpy as np
4from ase.units import Ha
5
6from gpaw.occupations import OccupationNumberCalculator
7from gpaw.projections import Projections
8from gpaw.utilities import pack, unpack2
9from gpaw.utilities.blas import gemm, axpy
10from gpaw.utilities.partition import AtomPartition
11
12
class WaveFunctions:
    """Mode-independent base class for wave functions.

    Stores the k-point objects (``kpt_u``, also grouped per IBZ k-point as
    ``kpt_qs``) and implements operations shared by all modes: band
    energy, pseudo-density and atomic-density-matrix contributions,
    occupation numbers, collection of distributed per-(k, s) data on the
    master rank, and gpw-file I/O of projections, eigenvalues and
    occupations.  Methods such as
    ``add_to_density_from_k_point_with_occupation``,
    ``_get_wave_function_array`` and ``empty`` are called here but not
    defined here; they are expected to come from subclasses.

    Constructor arguments:

    gd:
        Grid descriptor (its ``comm``, ``sdisp_cd``, ``collect()`` and
        ``get_ranks_from_positions()`` are used here).
    nvalence:
        Number of valence electrons.
    setups:
        List of setup objects, indexed by atom.
    bd:
        Band descriptor (``nbands``, ``mynbands``, ``comm``, ``who_has()``,
        ``collect()``, ``get_slice()``).
    dtype:
        Data type of wave functions (float or complex).
    collinear: bool
        False for non-collinear spin calculations.
    world:
        MPI communicator containing all ranks.
    kd:
        k-point descriptor (``nspins``, ``comm``, ``symmetry``,
        ``ibzk_kc``, ``bzk_kc``, ``weight_k``, ...).
    kptband_comm:
        Communicator used to sum over distributed k-points and bands.
    timer:
        Timer object.
    """

    def __init__(self, gd, nvalence, setups, bd, dtype, collinear,
                 world, kd, kptband_comm, timer):
        self.gd = gd
        self.nspins = kd.nspins
        self.collinear = collinear
        self.nvalence = nvalence
        self.bd = bd
        self.dtype = dtype
        assert dtype == float or dtype == complex
        self.world = world
        self.kd = kd
        self.kptband_comm = kptband_comm
        self.timer = timer
        self.atom_partition = None

        self.kpt_qs = kd.create_k_points(self.gd.sdisp_cd, collinear)
        self.kpt_u = [kpt for kpt_s in self.kpt_qs for kpt in kpt_s]

        self.occupations: Optional[OccupationNumberCalculator] = None
        # Fermi level(s) in Hartree; None until occupations are computed.
        self.fermi_levels: Optional[np.ndarray] = None

        self.eigensolver = None
        self.positions_set = False
        self.spos_ac = None

        self.set_setups(setups)

    @property
    def fermi_level(self):
        """The single Fermi level in Hartree.

        Only valid when exactly one Fermi level is defined."""
        assert len(self.fermi_levels) == 1
        return self.fermi_levels[0]

    def summary(self, log):
        """Log the eigenvalue table and the Fermi level(s) in eV."""
        log(eigenvalue_string(self))

        if self.fermi_levels is None:
            return

        if len(self.fermi_levels) == 1:
            log(f'Fermi level: {self.fermi_levels[0] * Ha:.5f}\n')
        else:
            f1, f2 = (f * Ha for f in self.fermi_levels)
            log(f'Fermi levels: {f1:.5f}, {f2:.5f}\n')

    def set_setups(self, setups):
        self.setups = setups

    def set_eigensolver(self, eigensolver):
        self.eigensolver = eigensolver

    def add_realspace_orbital_to_density(self, nt_G, psit_G):
        """Add |psit_G|**2 to nt_G in place."""
        if psit_G.dtype == float:
            axpy(1.0, psit_G**2, nt_G)
        else:
            assert psit_G.dtype == complex
            axpy(1.0, psit_G.real**2, nt_G)
            axpy(1.0, psit_G.imag**2, nt_G)

    def add_orbital_density(self, nt_G, kpt, n):
        """Add the density of band n of kpt to nt_G in place."""
        self.add_realspace_orbital_to_density(nt_G, kpt.psit_nG[n])

    def calculate_band_energy(self):
        """Return sum_kn f_kn * eps_kn, summed over kptband_comm.

        If the occupation-number calculator defines a
        ``calculate_band_energy(wfs)`` method (DCSF needs this), its value
        is added as well.  Errors raised inside that method propagate.
        """
        e_band = 0.0
        for kpt in self.kpt_u:
            e_band += np.dot(kpt.f_n, kpt.eps_n)

        # Only a *missing* method is tolerated (including occupations being
        # None); an AttributeError raised by the method itself is a bug and
        # must not be silently swallowed.
        extra_band_energy = getattr(self.occupations,
                                    'calculate_band_energy', None)
        if extra_band_energy is not None:
            e_band += extra_band_energy(self)

        return self.kptband_comm.sum(e_band)

    def calculate_density_contribution(self, nt_sG):
        """Calculate contribution to pseudo density from wave functions.

        Array entries are written to (not added to)."""
        nt_sG.fill(0.0)
        for kpt in self.kpt_u:
            self.add_to_density_from_k_point(nt_sG, kpt)
        self.kptband_comm.sum(nt_sG)

        self.timer.start('Symmetrize density')
        for nt_G in nt_sG:
            self.kd.symmetry.symmetrize(nt_G, self.gd)
        self.timer.stop('Symmetrize density')

    def add_to_density_from_k_point(self, nt_sG, kpt):
        """Add the density of kpt using its own occupations kpt.f_n."""
        self.add_to_density_from_k_point_with_occupation(nt_sG, kpt, kpt.f_n)

    def get_orbital_density_matrix(self, a, kpt, n):
        """Return packed density matrices D_sp of band n of kpt for atom a.

        Only the spin channel kpt.s is non-zero."""
        ni = self.setups[a].ni
        D_sii = np.zeros((self.nspins, ni, ni))
        P_i = kpt.P_ani[a][n]
        D_sii[kpt.s] += np.outer(P_i.conj(), P_i).real
        D_sp = [pack(D_ii) for D_ii in D_sii]
        return D_sp

    def calculate_atomic_density_matrices_k_point(self, D_sii, kpt, a, f_n):
        """Add the contribution of kpt to the density matrix D_sii of atom a.

        D_sii is modified in place.  With an LCAO density matrix
        (kpt.rho_MM) the projections P_aqMi are used; otherwise the band
        projections weighted by f_n.  For non-collinear spins the four
        components (total, x, y, z) are filled.
        """
        if kpt.rho_MM is not None:
            P_Mi = self.P_aqMi[a][kpt.q]
            rhoP_Mi = np.zeros_like(P_Mi)
            D_ii = np.zeros(D_sii[kpt.s].shape, kpt.rho_MM.dtype)
            gemm(1.0, P_Mi, kpt.rho_MM, 0.0, rhoP_Mi)
            gemm(1.0, rhoP_Mi, P_Mi.T.conj().copy(), 0.0, D_ii)
            D_sii[kpt.s] += D_ii.real
        else:
            if self.collinear:
                P_ni = kpt.projections[a]
                D_sii[kpt.s] += np.dot(P_ni.T.conj() * f_n, P_ni).real
            else:
                P_nsi = kpt.projections[a]
                D_ssii = np.einsum('nsi,n,nzj->szij',
                                   P_nsi.conj(), f_n, P_nsi)
                D_sii[0] += (D_ssii[0, 0] + D_ssii[1, 1]).real
                D_sii[1] += 2 * D_ssii[0, 1].real
                D_sii[2] += 2 * D_ssii[0, 1].imag
                D_sii[3] += (D_ssii[0, 0] - D_ssii[1, 1]).real

        if hasattr(kpt, 'c_on'):
            # Extra occupied linear combinations of bands (coefficients
            # c_on, occupations ne_o).  Fetch the projections here: the
            # branches above only bind P_ni in the collinear,
            # rho_MM-is-None case.
            P_ni = kpt.projections[a]
            for ne, c_n in zip(kpt.ne_o, kpt.c_on):
                d_nn = ne * np.outer(c_n.conj(), c_n)
                D_sii[kpt.s] += np.dot(P_ni.T.conj(), np.dot(d_nn, P_ni)).real

    def calculate_atomic_density_matrices(self, D_asp):
        """Calculate atomic density matrices from projections."""
        f_un = [kpt.f_n for kpt in self.kpt_u]
        self.calculate_atomic_density_matrices_with_occupation(D_asp, f_un)

    def calculate_atomic_density_matrices_with_occupation(self, D_asp, f_un):
        """Calculate atomic density matrices from projections with
        custom occupation f_un (one array per entry of kpt_u)."""

        # Parameter check (if user accidentally passes f_n instead of f_un)
        if f_un[0] is not None:  # special case for transport calculations...
            assert isinstance(f_un[0], np.ndarray)
        # Varying f_n used in calculation of response part of GLLB-potential
        for a, D_sp in D_asp.items():
            ni = self.setups[a].ni
            D_sii = np.zeros((len(D_sp), ni, ni))
            for f_n, kpt in zip(f_un, self.kpt_u):
                self.calculate_atomic_density_matrices_k_point(D_sii, kpt, a,
                                                               f_n)
            D_sp[:] = [pack(D_ii) for D_ii in D_sii]
            self.kptband_comm.sum(D_sp)

        self.symmetrize_atomic_density_matrices(D_asp)

    def symmetrize_atomic_density_matrices(self, D_asp):
        """Symmetrize D_asp in place using the setups' symmetrize method.

        Does nothing when there are no symmetry operations."""
        if len(self.kd.symmetry.op_scc) == 0:
            return

        a_sa = self.kd.symmetry.a_sa
        D_asp.redistribute(self.atom_partition.as_serial())
        # After redistributing to the serial partition, presumably only one
        # rank holds the atoms; elsewhere D_asp is empty and nothing happens.
        for s in range(self.nspins):
            D_aii = [unpack2(D_asp[a][s]) for a in range(len(D_asp))]
            for a in range(len(D_aii)):
                setup = self.setups[a]
                D_asp[a][s] = pack(setup.symmetrize(a, D_aii, a_sa))
        D_asp.redistribute(self.atom_partition)

    def calculate_occupation_numbers(self, fixed_fermi_level=False):
        """Update kpt.f_n (and the Fermi levels) from the eigenvalues.

        Returns the entropy contribution in Hartree."""
        if self.collinear and self.nspins == 1:
            degeneracy = 2
        else:
            degeneracy = 1

        f_qn, fermi_levels, e_entropy = self.occupations.calculate(
            nelectrons=self.nvalence / degeneracy,
            eigenvalues=[kpt.eps_n * Ha for kpt in self.kpt_u],
            weights=[kpt.weightk for kpt in self.kpt_u],
            fermi_levels_guess=self.fermi_levels * Ha
            if self.fermi_levels is not None else None)

        if not fixed_fermi_level or self.fermi_levels is None:
            self.fermi_levels = np.array(fermi_levels) / Ha

        for f_n, kpt in zip(f_qn, self.kpt_u):
            kpt.f_n = f_n * (kpt.weightk * degeneracy)

        return e_entropy * degeneracy / Ha

    def set_positions(self, spos_ac, atom_partition=None):
        """Store new scaled positions and redistribute projections.

        If atom_partition is None it is built from
        ``gd.get_ranks_from_positions(spos_ac)``."""
        self.positions_set = False
        # XXX pass AtomPartition around instead of spos_ac?
        # All the classes passing around spos_ac end up needing the ranks
        # anyway.

        if atom_partition is None:
            rank_a = self.gd.get_ranks_from_positions(spos_ac)
            atom_partition = AtomPartition(self.gd.comm, rank_a)

        if self.atom_partition is not None and self.kpt_u[0].P_ani is not None:
            with self.timer('Redistribute'):
                for kpt in self.kpt_u:
                    P = kpt.projections
                    assert self.atom_partition == P.atom_partition
                    kpt.projections = P.redist(atom_partition)
                    assert atom_partition == kpt.projections.atom_partition

        self.atom_partition = atom_partition
        self.kd.symmetry.check(spos_ac)
        self.spos_ac = spos_ac

    def allocate_arrays_for_projections(self, my_atom_indices):  # XXX unused
        """Allocate empty projections for every k-point.

        Projections read from file (positions not yet set) are kept."""
        if not self.positions_set and self.kpt_u[0]._projections is not None:
            # Projections have been read from file - don't delete them!
            return

        nproj_a = [setup.ni for setup in self.setups]
        for kpt in self.kpt_u:
            kpt.projections = Projections(
                self.bd.nbands, nproj_a,
                self.atom_partition,
                self.bd.comm,
                collinear=self.collinear, spin=kpt.s, dtype=self.dtype)

    def collect_eigenvalues(self, k, s):
        return self.collect_array('eps_n', k, s)

    def collect_occupations(self, k, s):
        return self.collect_array('f_n', k, s)

    def collect_array(self, name, k, s, subset=None):
        """Helper method for collect_eigenvalues and collect_occupations.

        For the parallel case find the rank in kpt_comm that contains
        the (k, s) pair, for this rank, collect on the corresponding
        domain a full array on the domain master and send this to the
        global master."""

        kpt_qs = self.kpt_qs
        kpt_rank, q = self.kd.get_rank_and_index(k)
        if self.kd.comm.rank == kpt_rank:
            a_nx = getattr(kpt_qs[q][s], name)

            if subset is not None:
                a_nx = a_nx[subset]

            # Domain master send this to the global master
            if self.gd.comm.rank == 0:
                if self.bd.comm.size == 1:
                    if kpt_rank == 0:
                        return a_nx
                    else:
                        self.kd.comm.ssend(a_nx, 0, 1301)
                else:
                    b_nx = self.bd.collect(a_nx)
                    if self.bd.comm.rank == 0:
                        if kpt_rank == 0:
                            return b_nx
                        else:
                            self.kd.comm.ssend(b_nx, 0, 1301)

        elif self.world.rank == 0 and kpt_rank != 0:
            # Only used to determine shape and dtype of receiving buffer:
            a_nx = getattr(kpt_qs[0][0], name)

            if subset is not None:
                a_nx = a_nx[subset]

            b_nx = np.zeros((self.bd.nbands,) + a_nx.shape[1:],
                            dtype=a_nx.dtype)
            self.kd.comm.receive(b_nx, kpt_rank, 1301)
            return b_nx

        return np.zeros(0)  # see comment in get_wave_function_array() method

    def collect_auxiliary(self, value, k, s, shape=1, dtype=float):
        """Helper method for collecting band-independent scalars/arrays.

        value is either an attribute name of the k-point object or a list
        indexed by u = q * nspins + s.  The result is always a fresh array
        of the given dtype.

        For the parallel case find the rank in kpt_comm that contains
        the (k, s) pair, for this rank, collect on the corresponding
        domain a full array on the domain master and send this to the
        global master."""

        kpt_rank, q = self.kd.get_rank_and_index(k)

        if self.kd.comm.rank == kpt_rank:
            if isinstance(value, str):
                a_o = getattr(self.kpt_qs[q][s], value)
            else:
                u = q * self.nspins + s
                a_o = value[u]  # assumed list

            # Always make a fresh, mutable copy with the requested dtype.
            # (The previous ``a_o.dtype is not dtype`` test compared a numpy
            # dtype with a Python type by identity and was always True.)
            a_o = np.array(a_o, dtype=dtype)

            # Domain master send this to the global master
            if self.gd.comm.rank == 0:
                if kpt_rank == 0:
                    return a_o
                else:
                    self.kd.comm.send(a_o, 0, 1302)

        elif self.world.rank == 0 and kpt_rank != 0:
            b_o = np.zeros(shape, dtype=dtype)
            self.kd.comm.receive(b_o, kpt_rank, 1302)
            return b_o

    def collect_projections(self, k, s):
        """Helper method for collecting projector overlaps across domains.

        For the parallel case find the rank in kpt_comm that contains
        the (k, s) pair, for this rank, send to the global master."""

        kpt_rank, q = self.kd.get_rank_and_index(k)

        if self.kd.comm.rank == kpt_rank:
            kpt = self.kpt_qs[q][s]
            P_nI = kpt.projections.collect()
            if self.world.rank == 0:
                return P_nI
            if P_nI is not None:
                self.kd.comm.send(np.ascontiguousarray(P_nI), 0, tag=117)
        if self.world.rank == 0:
            nproj = sum(setup.ni for setup in self.setups)
            if not self.collinear:
                nproj *= 2
            P_nI = np.empty((self.bd.nbands, nproj), self.dtype)
            self.kd.comm.receive(P_nI, kpt_rank, tag=117)
            return P_nI

    def get_wave_function_array(self, n, k, s, realspace=True, periodic=False):
        """Return pseudo-wave-function array on master.

        n: int
            Global band index.
        k: int
            Global IBZ k-point index.
        s: int
            Spin index (0 or 1).
        realspace: bool
            Transform plane wave or LCAO expansion coefficients to real-space.
        periodic: bool
            Passed on to ``_get_wave_function_array``.

        For the parallel case find the ranks in kd.comm and bd.comm
        that contain (n, k, s), and collect on the corresponding
        domain a full array on the domain master and send this to the
        global master."""

        kpt_rank, q = self.kd.get_rank_and_index(k)
        band_rank, myn = self.bd.who_has(n)

        rank = self.world.rank

        if (self.kd.comm.rank == kpt_rank and
            self.bd.comm.rank == band_rank):
            u = q * self.nspins + s
            psit_G = self._get_wave_function_array(u, myn,
                                                   realspace, periodic)

            if realspace:
                psit_G = self.gd.collect(psit_G)

            if rank == 0:
                return psit_G

            # Domain master send this to the global master
            if self.gd.comm.rank == 0:
                psit_G = np.ascontiguousarray(psit_G)
                self.world.ssend(psit_G, 0, 1398)

        if rank == 0:
            # allocate full wave function and receive
            shape = () if self.collinear else (2,)
            psit_G = self.empty(shape, global_array=True,
                                realspace=realspace)
            # XXX this will fail when using non-standard nesting
            # of communicators.
            world_rank = (kpt_rank * self.gd.comm.size *
                          self.bd.comm.size +
                          band_rank * self.gd.comm.size)
            self.world.receive(psit_G, world_rank, 1398)
            return psit_G

        # We return a number instead of None on all the slaves.  Most of
        # the time the return value will be ignored on the slaves, but
        # in some cases it will be multiplied by some other number and
        # then ignored.  Allowing for this will simplify some code here
        # and there.
        return np.nan

    def get_homo_lumo(self, spin=None):
        """Return HOMO and LUMO eigenvalues (in Hartree).

        The HOMO is taken as band nvalence // 2 - 1.  The LUMO is +inf
        when there is no band above it.  With spin=None and two spins,
        the highest HOMO and lowest LUMO of both channels are returned.
        """
        if spin is None:
            if self.nspins == 1:
                return self.get_homo_lumo(0)
            h0, l0 = self.get_homo_lumo(0)
            h1, l1 = self.get_homo_lumo(1)
            return np.array([max(h0, h1), min(l0, l1)])

        # nvalence may be a float; band indices must be integers.
        n = int(self.nvalence // 2)
        band_rank, myn = self.bd.who_has(n - 1)
        homo = -np.inf
        if self.bd.comm.rank == band_rank:
            for kpt in self.kpt_u:
                if kpt.s == spin:
                    homo = max(kpt.eps_n[myn], homo)
        homo = self.world.max(homo)

        lumo = np.inf
        if n < self.bd.nbands:  # only if there is a band above the HOMO
            band_rank, myn = self.bd.who_has(n)
            if self.bd.comm.rank == band_rank:
                for kpt in self.kpt_u:
                    if kpt.s == spin:
                        lumo = min(kpt.eps_n[myn], lumo)
            lumo = self.world.min(lumo)

        return np.array([homo, lumo])

    def write(self, writer):
        """Write Fermi levels, k-points, projections, eigenvalues and
        occupations (energies in eV)."""
        writer.write(version=2, ha=Ha)
        writer.write(fermi_levels=self.fermi_levels * Ha)
        writer.write(kpts=self.kd)
        self.write_projections(writer)
        self.write_eigenvalues(writer)
        self.write_occupations(writer)

    def write_projections(self, writer):
        nproj = sum(setup.ni for setup in self.setups)

        if self.collinear:
            shape = (self.nspins, self.kd.nibzkpts, self.bd.nbands, nproj)
        else:
            shape = (self.kd.nibzkpts, self.bd.nbands, 2, nproj)

        writer.add_array('projections', shape, self.dtype)

        for s in range(self.nspins):
            for k in range(self.kd.nibzkpts):
                P_nI = self.collect_projections(k, s)
                if not self.collinear and P_nI is not None:
                    P_nI.shape = (self.bd.nbands, 2, nproj)
                writer.fill(P_nI)

    def write_eigenvalues(self, writer):
        if self.collinear:
            shape = (self.nspins, self.kd.nibzkpts, self.bd.nbands)
        else:
            shape = (self.kd.nibzkpts, self.bd.nbands)

        writer.add_array('eigenvalues', shape)
        for s in range(self.nspins):
            for k in range(self.kd.nibzkpts):
                writer.fill(self.collect_eigenvalues(k, s) * Ha)

    def write_occupations(self, writer):

        if self.collinear:
            shape = (self.nspins, self.kd.nibzkpts, self.bd.nbands)
        else:
            shape = (self.kd.nibzkpts, self.bd.nbands)

        writer.add_array('occupations', shape)
        for s in range(self.nspins):
            for k in range(self.kd.nibzkpts):
                # Scale occupation numbers when writing:
                # XXX fix this in the code also ...
                weight = self.kd.weight_k[k] * 2 / self.nspins
                writer.fill(self.collect_occupations(k, s) / weight)

    def read(self, reader):
        """Read Fermi levels, projections, eigenvalues and occupations."""
        r = reader.wave_functions
        # Backward compatibility:
        # Take parameters from main reader
        if 'ha' not in r:
            r.ha = reader.ha
        if 'version' not in r:
            r.version = reader.version

        if reader.version >= 3:
            self.fermi_levels = r.fermi_levels / r.ha
        else:
            o = reader.occupations
            self.fermi_levels = np.array(
                [o.fermilevel + o.split / 2,
                 o.fermilevel - o.split / 2]) / r.ha
            if self.occupations.name != 'fixmagmom':
                assert o.split == 0.0
                self.fermi_levels = self.fermi_levels[:1]

        if reader.version >= 2:
            kpts = r.kpts
            assert np.allclose(kpts.ibzkpts, self.kd.ibzk_kc)
            assert np.allclose(kpts.bzkpts, self.kd.bzk_kc)
            assert (kpts.bz2ibz == self.kd.bz2ibz_k).all()
            assert np.allclose(kpts.weights, self.kd.weight_k)

        self.read_projections(r)
        self.read_eigenvalues(r, r.version <= 0)
        self.read_occupations(r, r.version <= 0)

    def read_projections(self, reader):
        """Create projections for all k-points and fill them from file on
        the domain master (all atoms placed on domain rank 0)."""
        nslice = self.bd.get_slice()
        nproj_a = [setup.ni for setup in self.setups]
        atom_partition = AtomPartition(self.gd.comm,
                                       np.zeros(len(nproj_a), int))
        for kpt in self.kpt_u:
            if self.collinear:
                index = (kpt.s, kpt.k)
            else:
                index = (kpt.k,)
            kpt.projections = Projections(
                self.bd.nbands, nproj_a,
                atom_partition, self.bd.comm,
                collinear=self.collinear, spin=kpt.s, dtype=self.dtype)
            if self.gd.comm.rank == 0:
                P_nI = reader.proxy('projections', *index)[nslice]
                if not self.collinear:
                    P_nI.shape = (self.bd.mynbands, -1)
                kpt.projections.matrix.array[:] = P_nI

    def read_eigenvalues(self, reader, old=False):
        """Read this rank's slice of eigenvalues into kpt.eps_n (Hartree)."""
        nslice = self.bd.get_slice()
        for kpt in self.kpt_u:
            if self.collinear:
                index = (kpt.s, kpt.k)
            else:
                index = (kpt.k,)
            eps_n = reader.proxy('eigenvalues', *index)[nslice]
            x = self.bd.mynbands - len(eps_n)  # missing bands?
            if x > 0:
                # Working on a real fix to this parallelization problem ...
                eps_n = np.pad(eps_n, (0, x), 'constant')
            if not old:  # skip for old tar-files gpw's
                eps_n /= reader.ha
            kpt.eps_n = eps_n

    def read_occupations(self, reader, old=False):
        """Read this rank's slice of occupations into kpt.f_n, scaled by
        kpt.weight."""
        nslice = self.bd.get_slice()
        for kpt in self.kpt_u:
            if self.collinear:
                index = (kpt.s, kpt.k)
            else:
                index = (kpt.k,)
            f_n = reader.proxy('occupations', *index)[nslice]
            x = self.bd.mynbands - len(f_n)  # missing bands?
            if x > 0:
                # Working on a real fix to this parallelization problem ...
                f_n = np.pad(f_n, (0, x), 'constant')
            if not old:  # skip for old tar-files gpw's
                f_n *= kpt.weight
            kpt.f_n = f_n
592
593
def eigenvalue_string(wfs, comment=' '):
    """Write eigenvalues and occupation numbers into a string.

    The parameter comment is prepended to every non-numeric (header) line
    and can be used to comment those lines out, for example '#' to make
    the output readable by gnuplot.

    Eigenvalues are printed in eV (eps_n * Ha) and occupations are divided
    by the k-point weight.  With a single IBZ k-point all bands are shown.
    Otherwise only the first two k-points are shown, with bands from
    nvalence/2 - 2 up to (excluding) nvalence/2 + 2, clipped to
    [0, nbands).  Only the world master adds the numeric rows.
    """
    tokens = []

    def add(*line):
        tokens.extend(line)
        tokens.append('\n')

    def eigs(k, s):
        eps_n = wfs.collect_eigenvalues(k, s)
        return eps_n * Ha

    def occs(k, s):
        occ_n = wfs.collect_occupations(k, s)
        return occ_n / wfs.kd.weight_k[k]

    nibzkpts = len(wfs.kd.ibzk_kc)

    if nibzkpts == 1:
        if wfs.nspins == 1:
            add(comment, 'Band  Eigenvalues  Occupancy')
            eps_n = eigs(0, 0)
            f_n = occs(0, 0)
            if wfs.world.rank == 0:
                for n in range(wfs.bd.nbands):
                    add('%5d  %11.5f  %9.5f' % (n, eps_n[n], f_n[n]))
        else:
            add(comment, '                  Up                     Down')
            add(comment, 'Band  Eigenvalues  Occupancy  Eigenvalues  '
                'Occupancy')
            epsa_n = eigs(0, 0)
            epsb_n = eigs(0, 1)
            fa_n = occs(0, 0)
            fb_n = occs(0, 1)
            if wfs.world.rank == 0:
                for n in range(wfs.bd.nbands):
                    add('%5d  %11.5f  %9.5f  %11.5f  %9.5f' %
                        (n, epsa_n[n], fa_n[n], epsb_n[n], fb_n[n]))
        return ''.join(tokens)

    # These are non-numeric lines too, so they get the comment prefix.
    if nibzkpts > 2:
        add(comment, 'Showing only first 2 kpts')
        print_range = 2
    else:
        add(comment, 'Showing all kpts')
        print_range = nibzkpts

    # Band window around the Fermi level: [half - 2, half + 2) clipped to
    # the available bands (equivalent to the former if/else pairs).
    half = wfs.nvalence / 2.
    m = max(int(half - 2), 0)
    j = min(int(half + 2), int(wfs.bd.nbands))

    if wfs.nspins == 1:
        add(comment, 'Kpt  Band  Eigenvalues  Occupancy')
        for i in range(print_range):
            eps_n = eigs(i, 0)
            f_n = occs(i, 0)
            if wfs.world.rank == 0:
                for n in range(m, j):
                    add('%3i %5d  %11.5f  %9.5f' % (i, n, eps_n[n], f_n[n]))
                add()
    else:
        add(comment, '                     Up                     Down')
        add(comment, 'Kpt  Band  Eigenvalues  Occupancy  Eigenvalues  '
            'Occupancy')

        for i in range(print_range):
            epsa_n = eigs(i, 0)
            epsb_n = eigs(i, 1)
            fa_n = occs(i, 0)
            fb_n = occs(i, 1)
            if wfs.world.rank == 0:
                for n in range(m, j):
                    add('%3i %5d  %11.5f  %9.5f  %11.5f  %9.5f' %
                        (i, n, epsa_n[n], fa_n[n], epsb_n[n], fb_n[n]))
                add()
    return ''.join(tokens)
678