1# coding: utf-8
2"""PSPS file with tabulated data."""
3import numpy as np
4
5from collections import OrderedDict
6from monty.bisect import find_gt
7from monty.functools import lazy_property
8from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
9from abipy.iotools import ETSF_Reader
10from abipy.core.mixins import AbinitNcFile
11
12import logging
13logger = logging.getLogger(__name__)
14
15
def mklabel(fsym, der, arg):
    r"""
    Return a LaTeX label for the ``der``-th derivative of ``fsym`` evaluated at ``arg``.

    Example: mklabel("f", 2, "x") --> $f^{\prime\prime}(x)$
    """
    # Derivatives are rendered as a superscript with one \prime per order.
    symbol = fsym if der == 0 else fsym + "^{" + r"\prime" * der + "}"
    return "$%s(%s)$" % (symbol, arg)
23
24
def rescale(arr, scale=1.0):
    """
    Rescale ``arr`` so that its maximum absolute value equals ``scale``.

    Args:
        arr: Array-like with the values to rescale.
        scale: Target value for max(|arr|). None disables rescaling.

    Returns:
        (rescaled, fact) with ``rescaled == fact * arr``.
        ``fact`` is 1.0 whenever no rescaling is applied: scale is None,
        ``arr`` is empty, or ``arr`` is identically zero.
    """
    if scale is None:
        # No rescaling: the applied factor is the identity (1), not 0.
        # Callers print fact in plot labels as " x %.4f".
        return arr, 1.0

    arr = np.asarray(arr)
    if arr.size == 0:
        # .max() on an empty array raises ValueError.
        return arr, 1.0

    amax = np.abs(arr).max()
    fact = scale / amax if amax != 0 else 1.0
    return fact * arr, fact
32
33
def dataframe_from_pseudos(pseudos, index=None):
    """
    Build a pandas DataFrame summarizing the main properties of a set of pseudos.

    Args:
        pseudos: List of pseudos, or of objects that can be converted to pseudos
            via ``PseudoTable.as_table``.
        index: Index of the dataframe.

    Return: pandas Dataframe with one row per pseudo. Missing attributes become None;
        the ``ecut_normal`` / ``pawecutdg_normal`` columns come from the "normal"
        accuracy hint when the pseudo has hints.
    """
    from abipy.flowtk import PseudoTable
    table = PseudoTable.as_table(pseudos)

    import pandas as pd
    attnames = ("Z_val", "l_max", "l_local", "nlcc_radius", "xc", "supports_soc", "type")

    def _make_row(pseudo):
        # Generic attributes first, then the hint-derived cutoffs (default None).
        entry = OrderedDict((name, getattr(pseudo, name, None)) for name in attnames)
        entry["ecut_normal"] = None
        entry["pawecutdg_normal"] = None
        if pseudo.has_hints:
            hint = pseudo.hint_for_accuracy(accuracy="normal")
            entry["ecut_normal"] = hint.ecut
            if hint.pawecutdg:
                entry["pawecutdg_normal"] = hint.pawecutdg
        return entry

    rows = [_make_row(pseudo) for pseudo in table]
    # All rows share the same keys, so the first one fixes the column order.
    columns = list(rows[0].keys()) if rows else None
    return pd.DataFrame(rows, index=index, columns=columns)
61
62
class PspsFile(AbinitNcFile):
    """
    Netcdf file with the tables used in Abinit to apply the
    pseudopotential part of the KS Hamiltonian.

    Usage example:

    .. code-block:: python

        with PspsFile("foo_PSPS.nc") as psps:
            psps.plot_tcore_rspace()
    """
    # Line styles and colors indexed by the order of the derivative.
    # Both lists have six entries (derivative orders 0..5), so any selected
    # derivative of the model core can be plotted without an IndexError.
    linestyles_der = ["-", "--", '-.', ':', ":", ":"]
    color_der = ["black", "red", "green", "orange", "cyan", "magenta"]

    @classmethod
    def from_file(cls, filepath):
        """Initialize the object from a Netcdf file"""
        return cls(filepath)

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = PspsReader(filepath)

    def close(self):
        """Close the file."""
        self.reader.close()

    @lazy_property
    def params(self):
        """:class:`OrderedDict` with parameters that might be subject to convergence studies."""
        return {}

    @add_fig_kwargs
    def plot(self, **kwargs):
        """
        Driver routine to plot several quantities on the same graph.

        Args:
            ecut_ffnl: Max cutoff energy for ffnl plot (optional)

        Return: |matplotlib-Figure|
        """
        methods = [
            "plot_tcore_rspace",
            "plot_tcore_qspace",
            "plot_ffspl",
            "plot_vlocq",
        ]

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                                sharex=False, sharey=False, squeeze=True)

        # ecut_ffnl is only used by plot_ffspl; the other methods absorb it in **kwargs.
        ecut_ffnl = kwargs.pop("ecut_ffnl", None)
        for m, ax in zip(methods, ax_list.ravel()):
            getattr(self, m)(ax=ax, ecut_ffnl=ecut_ffnl, show=False)

        return fig

    @add_fig_kwargs
    def plot_tcore_rspace(self, ax=None, ders=(0, 1, 2, 3), rmax=3.0, **kwargs):
        """
        Plot the model core and its derivatives in real space.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            ders: Tuple used to select the derivatives to be plotted.
            rmax: Max radius for plot in Bohr. None if the full grid is wanted.

        Returns: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        linewidth = kwargs.pop("linewidth", 2.0)
        rmeshes, coresd = self.reader.read_coresd(rmax=rmax)

        # Each derivative is rescaled so that max(|f|) == 1; the factor goes in the label.
        scale = 1.0
        for rmesh, mcores in zip(rmeshes, coresd):
            for der, values in enumerate(mcores):
                if der not in ders: continue
                yvals, fact = rescale(values, scale=scale)
                ax.plot(rmesh, yvals, color=self.color_der[der], linewidth=linewidth,
                        linestyle=self.linestyles_der[der],
                        label=mklabel("\\tilde{n}_c", der, "r") + " x %.4f" % fact)

        ax.grid(True)
        ax.set_xlabel("r [Bohr]")
        ax.set_title("Model core in r-space")
        if kwargs.get("with_legend", False): ax.legend(loc="best")

        return fig

    @add_fig_kwargs
    def plot_tcore_qspace(self, ax=None, ders=(0,), with_fact=True, with_qn=0, **kwargs):
        """
        Plot the model core in q space

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            ders: Tuple used to select the derivatives to be plotted.
            with_fact: If True, the rescaling factor is appended to the label.
            with_qn: If True, also plot q * f(q) for the function itself (der == 0).

        Returns: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        color = kwargs.pop("color", "black")
        linewidth = kwargs.pop("linewidth", 2.0)

        qmesh, tcore_spl = self.reader.read_tcorespl()
        # x-axis: the energy (2 pi q)^2 / 2 associated to each q-point.
        ecuts = 2 * (np.pi * qmesh)**2
        # scale=None: values are plotted as read from file (no rescaling).
        scale = None
        for tcore_atype in tcore_spl:
            for der, values in enumerate(tcore_atype):
                # The second row of the table stores the second derivative.
                if der == 1: der = 2
                if der not in ders: continue
                yvals, fact = rescale(values, scale=scale)

                label = mklabel("\\tilde{n}_{c}", der, "q")
                if with_fact: label += " x %.4f" % fact

                ax.plot(ecuts, yvals, color=color, linewidth=linewidth,
                        linestyle=self.linestyles_der[der], label=label)

                if with_qn and der == 0:
                    yvals, fact = rescale(qmesh * values, scale=scale)
                    ax.plot(ecuts, yvals, color=color, linewidth=linewidth,
                            label=mklabel("q f", der, "q") + " x %.4f" % fact)

        ax.grid(True)
        ax.set_xlabel("Ecut [Hartree]")
        ax.set_title("Model core in q-space")
        if kwargs.get("with_legend", False): ax.legend(loc="best")

        return fig

    @add_fig_kwargs
    def plot_vlocq(self, ax=None, ders=(0,), with_qn=0, with_fact=True, **kwargs):
        """
        Plot the local part of the pseudopotential in q space.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            ders: Tuple used to select the derivatives to be plotted.
            with_qn: If True, also plot q * f(q) for the function itself (der == 0).
            with_fact: If True, the rescaling factor is appended to the label.

        Returns: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        color = kwargs.pop("color", "black")
        linewidth = kwargs.pop("linewidth", 2.0)

        qmesh, vlspl = self.reader.read_vlspl()
        # x-axis: the energy (2 pi q)^2 / 2 associated to each q-point.
        ecuts = 2 * (np.pi * qmesh)**2
        # scale=None: values are plotted as read from file (no rescaling).
        scale = None
        for vl_atype in vlspl:
            for der, values in enumerate(vl_atype):
                # The second row of the table stores the second derivative.
                if der == 1: der = 2
                if der not in ders: continue

                yvals, fact = rescale(values, scale=scale)
                label = mklabel("v_{loc}", der, "q")
                if with_fact: label += " x %.4f" % fact

                ax.plot(ecuts, yvals, color=color, linewidth=linewidth,
                        linestyle=self.linestyles_der[der], label=label)

                if with_qn and der == 0:
                    yvals, fact = rescale(qmesh * values, scale=scale)
                    # Same factor format as the other labels ("%2.f" printed no decimals).
                    ax.plot(ecuts, yvals, color=color, linewidth=linewidth,
                            label="q*f(q) x %.4f" % fact)

        ax.grid(True)
        ax.set_xlabel("Ecut [Hartree]")
        ax.set_title("Vloc(q)")
        if kwargs.get("with_legend", False): ax.legend(loc="best")

        return fig

    @add_fig_kwargs
    def plot_ffspl(self, ax=None, ecut_ffnl=None, ders=(0,), with_qn=0, with_fact=False, **kwargs):
        """
        Plot the nonlocal part of the pseudopotential in q-space.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            ecut_ffnl: Max cutoff energy for ffnl plot (optional). If it exceeds the
                last point of the mesh, the full mesh is plotted.
            ders: Tuple used to select the derivatives to be plotted.
            with_qn: Not used, accepted for consistency with the other plot methods.
            with_fact: Not used, accepted for consistency with the other plot methods.

        Returns: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        # Colors are fixed by the angular momentum l, so a "color" kwarg
        # (e.g. passed by compare) is accepted but ignored.
        kwargs.pop("color", None)
        linewidth = kwargs.pop("linewidth", 2.0)

        color_l = {-1: "black", 0: "red", 1: "blue", 2: "green", 3: "orange"}
        linestyles_n = ["solid", '-', '--', '-.', ":"]
        # Only the first projector with a given l gets a legend entry.
        l_seen = set()

        qmesh, vlspl = self.reader.read_vlspl()

        all_projs = self.reader.read_projectors()
        for itypat, projs_type in enumerate(all_projs):
            # Loop over the projectors for this atom type.
            for p in projs_type:
                # Plotted curve: vloc(q) + sign(ekb) * sqrt(|ekb|) * ffnl(q).
                # It does not depend on der, so compute it once per projector.
                values = vlspl[itypat, 0, :] + p.sign_sqrtekb * p.values

                stop = len(p.ecuts)
                if ecut_ffnl is not None:
                    try:
                        stop = find_gt(p.ecuts, ecut_ffnl)
                    except ValueError:
                        # monty's find_gt raises when no ecut is > ecut_ffnl: keep the whole mesh.
                        pass

                for der in range(len(p.data)):
                    # The second row of the table stores the second derivative.
                    if der == 1: der = 2
                    if der not in ders: continue
                    label = None
                    if p.l not in l_seen:
                        l_seen.add(p.l)
                        label = mklabel("v_{nl}", der, "q") + ", l=%d" % p.l

                    ax.plot(p.ecuts[:stop], values[:stop], color=color_l[p.l], linewidth=linewidth,
                            linestyle=linestyles_n[p.n], label=label)

        ax.grid(True)
        ax.set_xlabel("Ecut [Hartree]")
        ax.set_title("ffnl(q)")
        if kwargs.get("with_legend", False): ax.legend(loc="best")

        ax.axhline(y=0, linewidth=linewidth, color='k', linestyle="solid")
        fig.tight_layout()

        return fig

    @add_fig_kwargs
    def compare(self, others, **kwargs):
        """Produce matplotlib plot comparing self with another list of pseudos ``others``."""
        if not isinstance(others, (list, tuple)):
            others = [others]

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                                sharex=False, sharey=False, squeeze=True)
        ax_list = ax_list.ravel()

        def mkcolor(count):
            """Color for the count-th file: red/blue for two files, else sampled from a colormap."""
            npseudos = 1 + len(others)
            if npseudos <= 2:
                return {0: "red", 1: "blue"}[count]
            else:
                cmap = plt.get_cmap("jet")
                return cmap(float(count) / npseudos)

        ic = 0; ax = ax_list[ic]
        self.plot_tcore_rspace(ax=ax, color=mkcolor(0), show=False, with_legend=False)
        for count, other in enumerate(others):
            other.plot_tcore_rspace(ax=ax, color=mkcolor(count+1), show=False, with_legend=False)

        ic += 1; ax = ax_list[ic]
        self.plot_tcore_qspace(ax=ax, with_qn=0, color=mkcolor(0), show=False)
        for count, other in enumerate(others):
            other.plot_tcore_qspace(ax=ax, with_qn=0, color=mkcolor(count+1), show=False)

        ic += 1; ax = ax_list[ic]
        self.plot_vlocq(ax=ax, with_qn=0, color=mkcolor(0), show=False)
        for count, other in enumerate(others):
            other.plot_vlocq(ax=ax, with_qn=0, color=mkcolor(count+1), show=False)

        ic += 1; ax = ax_list[ic]
        self.plot_ffspl(ax=ax, with_qn=0, color=mkcolor(0), show=False)
        for count, other in enumerate(others):
            other.plot_ffspl(ax=ax, with_qn=0, color=mkcolor(count+1), show=False)

        return fig
350
351
class PspsReader(ETSF_Reader):
    """
    This object reads the results stored in the PSPS file produced by ABINIT.
    It provides helper functions to access the most important quantities.
    """
    def __init__(self, filepath):
        super().__init__(filepath)

        # Get important quantities.
        self.usepaw, self.useylm = self.read_value("usepaw"), self.read_value("useylm")
        # Only norm-conserving pseudos (usepaw == 0) with useylm == 0 are supported.
        assert self.usepaw == 0 and self.useylm == 0
        self.ntypat = self.read_dimvalue("ntypat")
        self.lmnmax = self.read_dimvalue("lmnmax")
        self.indlmn = self.read_value("indlmn")

        self.znucl_typat = self.read_value("znucltypat")
        self.zion_typat = self.read_value("ziontypat")

        # TODO: read the names of the pseudopotential files from the "filpsp"
        # variable (each entry is an array of characters to be joined).

    def read_coresd(self, rmax=None):
        """
        Read the core charges and derivatives for the different types of atoms.

        Args:
            rmax: Maximum radius in Bohr. If None, data on the full grid is returned.

        Returns:
            meshes: List of ntypat arrays. Each array contains the linear meshes in real space.
            coresd: List with ntypat arrays of shape [6, npts].

            If the model core is not present, ntypat copies of a dummy 100-point mesh
            on [0, 6] and of np.zeros((2, 100)) are returned.

        xccc1d[ntypat, 6, n1xccc*(1-usepaw)]

        Norm-conserving psps only
        The component xccc1d(n1xccc,1,ntypat) is the pseudo-core charge
        for each type of atom, on the radial grid. The components
        xccc1d(n1xccc,ideriv,ntypat) give the ideriv-th derivative of the
        pseudo-core charge with respect to the radial distance.
        """
        xcccrc = self.read_value("xcccrc")
        try:
            all_coresd = self.read_value("xccc1d")
        except self.Error:
            # model core may not be present!
            return self.ntypat * [np.linspace(0, 6, num=100)], self.ntypat * [np.zeros((2, 100))]

        npts = all_coresd.shape[-1]
        rmeshes, coresd = [], []
        for itypat, rc in enumerate(xcccrc):
            rvals, step = np.linspace(0, rc, num=npts, retstep=True)
            # None keeps the whole grid (a stop index of -1 would drop the last point).
            ir_stop = None
            # step == 0 when rc == 0: rmax / step would overflow, so keep the whole grid.
            if rmax is not None and step > 0:
                # Truncate mesh
                ir_stop = min(int(rmax / step), npts) + 1

            rmeshes.append(rvals[:ir_stop])
            coresd.append(all_coresd[itypat, :, :ir_stop])

        return rmeshes, coresd

    def read_tcorespl(self):
        """
        Read the pseudo core density in reciprocal space.

        Returns:
            qmesh: Linear q-mesh in G-space
            tcorespl: array [ntypat, 2, mqgrid_vl] with the pseudo core density on the
                regular q-mesh (row 0) and the second derivative used for splines (row 1).

        Only available if has_tcore.
        """
        return self.read_value("qgrid_vl"), self.read_value("nc_tcorespl")

    def read_vlspl(self):
        """
        Read the local part of each type of psp in reciprocal space.

        Returns:
            qmesh: Linear q-mesh in G-space
            vlspl: array [ntypat, 2, mqgrid_vl] with the local part on the q-mesh (row 0)
                and the second derivative used for splines (row 1).
        """
        return self.read_value("qgrid_vl"), self.read_value("vlspl")

    def read_projectors(self):
        """
        Read the non-local projectors for each type of atom.

        ffspl[ntypat, lnmax, 2, mqgrid_ff]
        Gives, on the radial grid, the different non-local projectors,
        in both the norm-conserving case, and the PAW case

        Returns:
            List of ntypat lists of :class:`VnlProjector`, one per (l, n) channel.
        """
        # ekb(dimekb,ntypat*(1-usepaw))
        ekb = self.read_value("ekb")
        qgrid_ff = self.read_value("qgrid_ff")
        ffspl = self.read_value("ffspl")

        projs = []
        for itypat in range(self.ntypat):
            ln_list = self.get_lnlist_for_type(itypat)
            projs.append([VnlProjector(itypat, ln, ekb[itypat, i], qgrid_ff, ffspl[itypat, i, :, :])
                          for i, ln in enumerate(ln_list)])

        return projs

    def get_lnlist_for_type(self, itypat):
        """Return a list of (l, n) indices for this atom type."""
        # indlmn(6,lmn_size,ntypat)=array giving l,m,n,lm,ln,s for i=lmn
        indlmn_type = self.indlmn[itypat, :, :]

        # A new (l, n) channel starts whenever the ln index (column 4) increases.
        iln0 = 0; ln_list = []
        for ilmn in range(self.lmnmax):
            iln = indlmn_type[ilmn, 4]
            if iln > iln0:
                iln0 = iln
                l = indlmn_type[ilmn, 0]  # l
                n = indlmn_type[ilmn, 2]  # n
                ln_list.append((l, n))

        return ln_list
485
486
class VnlProjector(object):
    """Data and parameters associated to a non-local projector."""

    def __init__(self, itypat, ln, ekb, qmesh, data):
        """
        Args:
            itypat: Index of the atom type this projector belongs to.
            ln: Tuple with l and n.
            ekb: KB energy in Hartree.
            qmesh: Mesh of q-points.
            data: numpy array [2, nqpt] with the projector values (row 0)
                and their second derivative (row 1).
        """
        # Previously accepted but discarded; keep it so callers can tell projectors apart.
        self.itypat = itypat
        self.ln = ln
        self.l, self.n, self.ekb = ln[0], ln[1], ekb
        self.qmesh, self.data = qmesh, data

        assert len(self.qmesh) == len(self.values)
        assert len(self.qmesh) == len(self.der2)

    def __repr__(self):
        return "%s(itypat=%s, l=%s, n=%s, ekb=%s)" % (
            type(self).__name__, self.itypat, self.l, self.n, self.ekb)

    @property
    def values(self):
        """Values of the projector in q-space."""
        return self.data[0, :]

    @property
    def der2(self):
        """Second order derivative."""
        return self.data[1, :]

    @property
    def ecuts(self):
        """List of cutoff energies corresponding to self.qmesh."""
        return 2 * (np.pi * self.qmesh)**2

    @property
    def sign_sqrtekb(self):
        """sign(ekb) * sqrt(|ekb|)."""
        return np.sign(self.ekb) * np.sqrt(np.abs(self.ekb))