1# coding: utf-8
2"""Scissors operator."""
3import os
4import numpy as np
5import pickle
6
7from collections import OrderedDict
8from monty.collections import AttrDict
9from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt
10
11
# Public names exported by a star-import of this module.
# ScissorsError is not listed here; it is reachable as ``Scissors.Error``.
__all__ = [
    "Scissors",
    "ScissorsBuilder",
]
16
17
class ScissorsError(Exception):
    """Base class for the exceptions raised by :class:`Scissors`."""


class Scissors(object):
    """
    This object represents an energy-dependent scissors operator.
    The operator is defined by a list of domains (energy intervals)
    and a list of functions defined in these domains.
    The domains should fulfill the constraints documented in the main constructor.

    .. note::

        eV units are assumed.

    The standard way to create this object is via the methods provided by the factory class
    :class:`ScissorsBuilder`.
    Once the instance has been created, one can correct the band structure by calling the `apply` method.

    Attributes:
        func_list: The callables given to the constructor, one per domain.
        domains: 2D numpy array with one row ``(emin, emax)`` per domain.
        residues: The residues given to the constructor (stored, not used by this class).
        func_low: Callable used by `apply` for energies below the first domain.
        func_high: Callable used by `apply` for energies above the last domain.
        out_bounds: Integer array with three counters updated by `apply`:
            ``[0]`` energies below the first domain, ``[1]`` energies above the last domain,
            ``[2]`` energies that fell in a hole between domains.
    """
    Error = ScissorsError

    def __init__(self, func_list, domains, residues, bounds=None):
        """
        Args:
            func_list: List of callable objects. Each function takes an eigenvalue and returns
                the corrected value.
            domains: Domains of each function. List of tuples [(emin1, emax1), (emin2, emax2), ...]
            residues: A list of the residues of the fitting per domain
            bounds: Specify how to handle energies that do not fall inside one of the domains.
                Expected layout: ``[(type_low, value_low), (type_high, value_high)]``.
                At present, only constant boundaries (type "c") are implemented.
                If None, or if a value cannot be converted to float, the constant is obtained
                by evaluating the first (last) function at the lower (upper) edge of the domains.

        Raises:
            NotImplementedError: if a boundary type different from "c" is requested.

        .. note::

            #. Domains should not overlap, cover e0mesh, and given in increasing order.

            #. Holes are permitted but the interpolation will raise an exception if the
               eigenvalue falls inside the hole.

            #. Errors contains a list of the fitting errors per domain

        """
        # TODO Add consistency check.
        self.func_list = func_list
        self.domains = np.atleast_2d(domains)
        self.residues = residues
        assert len(self.func_list) == len(self.domains)

        # Treat the out-of-boundary conditions. func_low and func_high are used to handle energies
        # that are below or above the min/max energy given in domains.
        blow, bhigh = "c", "c"
        if bounds is not None:
            # First entry describes the lower boundary, second entry the upper one.
            blow, bhigh = bounds[0][0], bounds[1][0]

        if blow.lower() != "c":
            raise NotImplementedError("Only constant boundaries are implemented")
        # Lower edge of the first domain.
        self.func_low = self._constant_boundary(bounds, 0, func_list[0], self.domains[0, 0])

        if bhigh.lower() != "c":
            raise NotImplementedError("Only constant boundaries are implemented")
        # Upper edge of the last domain (valid for any number of domains).
        self.func_high = self._constant_boundary(bounds, 1, func_list[-1], self.domains[-1, 1])

        # This counter stores the number of points that are out of bounds.
        self.out_bounds = np.zeros(3, dtype=int)

    @staticmethod
    def _constant_boundary(bounds, side, func, x_edge):
        """
        Return a function of one argument that always returns the same value.

        The value is ``float(bounds[side][1])`` when available, otherwise ``func(x_edge)``.
        The value is computed here, once, so that a missing or invalid bound is detected
        at construction time instead of at the first out-of-range call.
        """
        try:
            value = float(bounds[side][1])
        except (TypeError, ValueError, IndexError, KeyError):
            # No usable explicit value: extend the fit with its value at the domain edge.
            value = func(x_edge)

        return lambda x: value

    def apply(self, eig):
        """
        Correct the eigenvalue eig (eV units).

        Args:
            eig: Energy to be corrected.

        Returns:
            ``func_low(eig)`` if eig is below the first domain, ``func_high(eig)`` if eig is
            above the last domain, otherwise the value of the function associated to the
            first domain containing eig (domain edges are included).

        Raises:
            ScissorsError: if eig falls inside a hole between two domains.
        """
        # Get the list of domains.
        domains = self.domains

        if eig < domains[0, 0]:
            # Eig is below the first point of the first domain.
            # Call func_low
            print("left ", eig, " < ", domains[0, 0])
            self.out_bounds[0] += 1
            return self.func_low(eig)

        if eig > domains[-1, 1]:
            # Eig is above the last point of the last domain.
            # Call func_high
            print("right ", eig, " > ", domains[-1, 1])
            self.out_bounds[1] += 1
            return self.func_high(eig)

        # eig is inside the domains: find the domain
        # and call the corresponding function.
        for idx, dms in enumerate(domains):
            if dms[0] <= eig <= dms[1]:
                return self.func_list[idx](eig)

        # eig lies between the global min and max but in none of the domains (hole).
        self.out_bounds[2] += 1
        raise self.Error("Cannot find location of eigenvalue %s in domains:\n%s" % (eig, domains))
120
121
class ScissorsBuilder(object):
    """
    This object facilitates the creation of :class:`Scissors` instances.

    Usage:

        builder = ScissorsBuilder.from_file("out_SIGRES.nc")

        # To plot the QP results as function of the KS energy:
        builder.plot_qpe_vs_e0()

        # To select the domains explicitly (optional but highly recommended)
        builder.build(domains_spin=[[-10, 6.02], [6.1, 20]])

        # To compare the fitted results with the ab-initio data:
        builder.plot_fit()

        # To plot the corrected bands:
        builder.plot_qpbands(abidata.ref_file("si_nscf_WFK.nc"))
    """

    @classmethod
    def from_file(cls, filepath):
        """
        Generate object from (SIGRES.nc) file. Main entry point for client code.

        Args:
            filepath: Path of the file opened with ``abiopen``. The object returned by
                ``abiopen`` must provide ``qplist_spin`` and ``ebands``.
        """
        from abipy.abilab import abiopen
        with abiopen(filepath) as ncfile:
            return cls(qps_spin=ncfile.qplist_spin, sigres_ebands=ncfile.ebands)

    @classmethod
    def pickle_load(cls, filepath):
        """
        Load the object from a pickle file produced by `pickle_dump`.

        .. warning::

            The file is read with ``pickle.load`` that can execute arbitrary code.
            Only load files coming from a trusted source.
        """
        with open(filepath, "rb") as fh:
            d = AttrDict(pickle.load(fh))

        # Construct the object and compute the scissors with the saved domains and bounds.
        new = cls(d.qps_spin, d.sigres_ebands)
        new.build(d.domains_spin, d.bounds_spin)
        return new

    def pickle_dump(self, filepath, protocol=-1):
        """
        Save the object in Pickle format.

        Only the input needed to rebuild the scissors (QP data, KS bands, domains and bounds)
        is saved; the :class:`Scissors` objects are rebuilt by `pickle_load`.

        Args:
            filepath: Path of the output file.
            protocol: Pickle protocol passed to ``pickle.dump``.
        """
        assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), self.bounds_spin.keys()))
        assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), range(self.nsppol)))

        bounds_spin = None
        if any(v is not None for v in self.bounds_spin.values()):
            # Bounds may be numpy arrays, plain sequences or None (spin without bounds):
            # convert to list only the objects that support it.
            bounds_spin = [b.tolist() if hasattr(b, "tolist") else b
                           for b in self.bounds_spin.values()]

        # This trick is needed because we cannot pickle bound methods of the scissors operator.
        d = dict(qps_spin=self._qps_spin,
                 sigres_ebands=self.sigres_ebands,
                 domains_spin=list(self.domains_spin.values()),
                 bounds_spin=bounds_spin)

        with open(filepath, "wb") as fh:
            pickle.dump(d, fh, protocol=protocol)

    def __init__(self, qps_spin, sigres_ebands):
        """
        Args:
            qps_spin: List of :class:`QPlist`, for each spin.
            sigres_ebands: |ElectronBands| obtained from the SIGRES file
        """
        # Sort quasiparticle data by e0.
        self._qps_spin = tuple(qps.sort_by_e0() for qps in qps_spin)

        # Compute the boundaries of the E0 mesh.
        # The first/last point of each mesh is used since the data is sorted by e0.
        e0min, e0max = np.inf, -np.inf
        for qps in self._qps_spin:
            e0mesh = qps.get_e0mesh()
            e0min = min(e0min, e0mesh[0])
            e0max = max(e0max, e0mesh[-1])

        self._e0min, self._e0max = e0min, e0max

        # The KS bands stored in the sigres file (used to compute automatically the boundaries)
        self.sigres_ebands = sigres_ebands

        # Start with default values for domains.
        self.build()

    @property
    def nsppol(self):
        """Number of spins."""
        return len(self._qps_spin)

    @property
    def e0min(self):
        """Minimum KS energy in eV (takes into account spin)"""
        return self._e0min

    @property
    def e0max(self):
        """Maximum KS energy in eV (takes into account spin)"""
        return self._e0max

    @property
    def scissors_spin(self):
        """Returns the list of :class:`Scissors` indexed by the spin value."""
        try:
            return self._scissors_spin
        except AttributeError:
            raise AttributeError("Call self.build to create the scissors operator")

    def build(self, domains_spin=None, bounds_spin=None, k=3):
        """
        Build the scissors operator.

        Args:
            domains_spin: list of domains in eV for each spin. If domains is None,
                domains are computed automatically from the sigres bands
                (two domains separated by the middle of the gap).
                If nsppol == 1, the input is reshaped to (1, ndomains, 2) so that the list
                of domains can be passed directly.
            bounds_spin: Options specifying the boundary conditions for each spin.
                The entry of each spin is passed to ``build_scissors`` (None if not given).
            k: Parameter defining the order of the fit.

        Returns:
            The domains used for each spin.

        Raises:
            ValueError: if domains_spin is given and nsppol is not 1 or 2.
        """
        nsppol = self.nsppol

        # The parameters defining the scissors operator
        self.domains_spin = OrderedDict()
        self.bounds_spin = OrderedDict()

        if domains_spin is None:
            # Use sigres_ebands and the position of the homo, lumo to compute the domains.
            domains_spin = nsppol * [None]
            e_bands = self.sigres_ebands
            for spin in e_bands.spins:
                gap_mid = (e_bands.homos[spin].eig + e_bands.lumos[spin].eig) / 2
                # Enlarge the e0 range by 20% on both sides.
                domains_spin[spin] = [[self.e0min - 0.2 * abs(self.e0min), gap_mid],
                                      [gap_mid, self.e0max + 0.2 * abs(self.e0max)]]
                #print("domains", domains_spin[spin])
        else:
            if nsppol == 1:
                domains_spin = np.reshape(domains_spin, (1, -1, 2))
            elif nsppol == 2:
                assert len(domains_spin) == nsppol
                if bounds_spin is not None: assert len(bounds_spin) == nsppol
            else:
                raise ValueError("Wrong number of spins %d" % nsppol)
            #if len(domains_spin) != nsppol:
            #    raise ValueError("len(domains_spin) == %s != nsppol %s" % (len(domains_spin), nsppol))

        # Construct the scissors operator for each spin.
        scissors_spin = nsppol * [None]
        for spin, qps in enumerate(self._qps_spin):
            bounds = None if not bounds_spin else bounds_spin[spin]
            scissors_spin[spin] = qps.build_scissors(domains_spin[spin], bounds=bounds, k=k, plot=False)

            # Save input so that we can reconstruct Scissors.
            self.domains_spin[spin] = domains_spin[spin]
            self.bounds_spin[spin] = bounds

        self._scissors_spin = scissors_spin
        return domains_spin

    @add_fig_kwargs
    def plot_qpe_vs_e0(self, with_fields="all", **kwargs):
        """
        Plot the quasiparticle corrections as function of the KS energy.

        Args:
            with_fields: Passed to ``plot_qps_vs_e0`` of each spin.
            kwargs: Passed to ``plot_qps_vs_e0``. The "title" entry is replaced
                by "spin <index>" for each spin.

        Returns:
            The figure returned by the call for the last spin. The axes of the first
            figure are passed to the following calls via ``ax_list``.
        """
        ax_list = None
        for spin, qps in enumerate(self._qps_spin):
            kwargs["title"] = "spin %s" % spin
            fig = qps.plot_qps_vs_e0(with_fields=with_fields, ax_list=ax_list, show=False, **kwargs)
            ax_list = fig.axes

        return fig

    @add_fig_kwargs
    def plot_fit(self, ax=None, fontsize=8, **kwargs):
        """
        Compare fit functions with input quasi-particle corrections.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: fontsize for titles and legend.

        Return: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        for spin in range(self.nsppol):
            qps = self._qps_spin[spin]
            e0mesh, qpcorrs = qps.get_e0mesh(), qps.get_qpeme0().real

            ax.scatter(e0mesh, qpcorrs, label="Input QP corrections, spin %s" % spin)
            # Evaluate the scissors on the same mesh used for the input data.
            scissors = self._scissors_spin[spin]
            intp_qpc = [scissors.apply(e0) for e0 in e0mesh]
            ax.plot(e0mesh, intp_qpc, label="Scissors operator, spin %s" % spin)

        ax.grid(True)
        ax.set_xlabel('KS energy (eV)')
        ax.set_ylabel('QP-KS (eV)')
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def plot_qpbands(self, bands_filepath, bands_label=None, dos_filepath=None, dos_args=None, **kwargs):
        """
        Correct the energies found in the netcdf file bands_filepath and plot the band energies (both the initial
        and the corrected ones) with matplotlib. The plot contains the KS and the QP DOS if dos_filepath is not None.

        Args:
            bands_filepath: Path to the netcdf file containing the initial KS energies to be corrected.
            bands_label: String used to label the KS bands in the plot.
                If None, the basename of bands_filepath is used.
            dos_filepath: Optional path to a netcdf file with the initial KS energies on a homogeneous k-mesh
                (used to compute the KS and the QP dos)
            dos_args: Dictionary with the arguments passed to get_edos to compute the DOS
                Used if dos_filepath is not None.

            kwargs: Options passed to the plotter.

        Return: |matplotlib-Figure|
        """
        from abipy.abilab import abiopen, ElectronBandsPlotter

        # Read the KS band energies from bands_filepath and apply the scissors operator.
        with abiopen(bands_filepath) as ncfile:
            ks_bands = ncfile.ebands
            #structure = ncfile.structure

        qp_bands = ks_bands.apply_scissors(self._scissors_spin)

        # Read the band energies computed on the Monkhorst-Pack (MP) mesh and compute the DOS.
        ks_dos, qp_dos = None, None
        if dos_filepath is not None:
            with abiopen(dos_filepath) as ncfile:
                ks_mpbands = ncfile.ebands

            dos_args = {} if not dos_args else dos_args
            ks_dos = ks_mpbands.get_edos(**dos_args)
            # Compute the DOS with the modified QPState energies.
            qp_mpbands = ks_mpbands.apply_scissors(self._scissors_spin)
            qp_dos = qp_mpbands.get_edos(**dos_args)

        # Plot the LDA and the QPState band structure with matplotlib.
        plotter = ElectronBandsPlotter()

        bands_label = bands_label if bands_label is not None else os.path.basename(bands_filepath)
        plotter.add_ebands(bands_label, ks_bands, edos=ks_dos)
        plotter.add_ebands(bands_label + " + scissors", qp_bands, edos=qp_dos)

        #qp_marker: if int > 0, markers for the ab-initio QP energies are displayed. e.g qp_marker=50
        #qp_marker = 50
        #if qp_marker is not None:
        #    # Compute correspondence between the k-points in qp_list and the k-path in qp_bands.
        #    # TODO
        #    # WARNING: strictly speaking one should check if qp_kpoint is in the star of k-point.
        #    # but compute_star is too slow if written in pure python.
        #    x, y, s = [], [], []
        #    for ik_path, kpoint in enumerate(qp_bands.kpoints):
        #        #kstar = kpoint.compute_star(structure.fm_symmops)
        #        for spin in range(self.nsppol):
        #            for ik_qp, qp in enumerate(self._qps_spin[spin]):
        #                #if qp.kpoint in kstar:
        #                if qp.kpoint == kpoint:
        #                    x.append(ik_path)
        #                    y.append(np.real(qp.qpe))
        #                    s.append(qp_marker)
        #    plotter.set_marker("ab-initio QP", [x, y, s])

        return plotter.combiplot(**kwargs)
382