1# coding: utf-8
2"""
3Object to analyze the results stored in the V1SYM.nc file (mainly for debugging purposes)
4"""
5import numpy as np
6
7from collections import OrderedDict
8from monty.string import marquee
9from monty.functools import lazy_property
10from abipy.tools.plotting import add_fig_kwargs, get_axarray_fig_plt
11from abipy.core.mixins import AbinitNcFile, Has_Structure, NotebookWriter
12from abipy.core.kpoints import KpointList, Kpoint
13from abipy.iotools import ETSF_Reader
14from abipy.tools import duck
15
16
class V1symFile(AbinitNcFile, Has_Structure, NotebookWriter):
    """
    Interface to the V1SYM.nc file (mainly for debugging purposes).

    For each q-point the file stores two sets of first-order SCF potentials,
    ``origin_v1scf`` and ``recons_v1scf`` (presumably the explicitly computed ones and
    the ones reconstructed by symmetry), that can be compared with the plotting methods.
    """

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = r = ETSF_Reader(filepath)
        try:
            # Read dimensions.
            self.nfft = r.read_dimvalue("nfft")
            self.nspden = r.read_dimvalue("nspden")
            self.natom3 = len(self.structure) * 3
            self.symv1scf = r.read_value("symv1scf")
        except Exception:
            # Don't leak the file handle if the file cannot be parsed.
            r.close()
            raise

    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.reader.read_structure()

    @lazy_property
    def pertsy_qpt(self):
        """
        Determine the symmetrical perturbations. Meaning of pertsy:

        0 for non-target perturbations.
        1 for basis perturbations.
        -1 for perturbations that can be found from basis perturbations.
        """
        # Fortran array: nctkarr_t("pertsy_qpt", "int", "three, mpert, nqpt")))
        # --> numpy (C-order) shape: (nqpt, mpert, 3), indexed as [iq, ipert, idir].
        return self.reader.read_value("pertsy_qpt")

    def close(self):
        """Close the underlying netcdf reader."""
        self.reader.close()

    @lazy_property
    def params(self):
        """:class:`OrderedDict` with parameters that might be subject to convergence studies."""
        return OrderedDict()

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with verbosity level ``verbose``."""
        lines = []; app = lines.append
        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")
        app("symv1scf: %s" % self.symv1scf)

        return "\n".join(lines)

    @lazy_property
    def qpoints(self):
        """|KpointList| with the q-points stored in the file (read from ``qpts``)."""
        return KpointList(self.structure.reciprocal_lattice, frac_coords=self.reader.read_value("qpts"))

    def _find_iqpt_qpoint(self, qpoint):
        """
        Return ``(iq, qpoint)`` from either an integer index or a q-point
        (|Kpoint| or reduced coordinates) that must be present in ``self.qpoints``.
        """
        if duck.is_intlike(qpoint):
            iq = qpoint
            qpoint = self.qpoints[iq]
        else:
            qpoint = Kpoint.as_kpoint(qpoint, self.structure.reciprocal_lattice)
            iq = self.qpoints.index(qpoint)

        return iq, qpoint

    def read_v1_at_iq(self, key, iq, reshape_nfft_nspden=False):
        """
        Read the potential stored in variable ``key`` at q-point index ``iq``.

        Return complex array with shape (natom3, nspden, nfft) or, if ``reshape_nfft_nspden``,
        with shape (natom3, nspden * nfft).
        """
        # Fortran array ("two, nfft, nspden, natom3, nqpt")
        # --> after [iq]: (natom3, nspden, nfft, 2) with real/imag parts on the last axis.
        v1 = self.reader.read_variable(key)[iq]
        v1 = v1[..., 0] + 1j * v1[..., 1]
        # reshape (nspden, nfft) dims because we are not interested in the spin dependence.
        if reshape_nfft_nspden: v1 = np.reshape(v1, (self.natom3, self.nspden * self.nfft))
        return v1

    @staticmethod
    def _get_diff_stats(v1_ref, v1_cmp):
        """
        Compare two complex arrays point by point.

        Return ``(abs_diff, stats)`` where ``abs_diff = |v1_ref - v1_cmp|`` and ``stats`` is an
        OrderedDict with max, min, mean, std of ``abs_diff`` and the L1 relative error:

            L1_rerr = \\int |v1_ref - v1_cmp| dr / \\int |v1_ref| dr
        """
        abs_diff = np.abs(v1_ref - v1_cmp)
        # The reference may vanish identically: keep the inf/nan result but
        # don't flood the output with RuntimeWarnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            l1_rerr = np.sum(abs_diff) / np.sum(np.abs(v1_ref))

        stats = OrderedDict([
            ("max", abs_diff.max()),
            ("min", abs_diff.min()),
            ("mean", abs_diff.mean()),
            ("std", abs_diff.std()),
            ("L1_rerr", l1_rerr),
        ])

        return abs_diff, stats

    def _set_pert_title(self, ax, iq, idir, ipert, fontsize):
        """Set the title of ``ax`` with direction, atom index and pertsy value of the perturbation."""
        ax.set_title("idir: %d, iat: %d, pertsy: %d" % (idir, ipert, self.pertsy_qpt[iq, ipert, idir]),
                     fontsize=fontsize)

    @staticmethod
    def _add_stats_text(ax, stats, fontsize):
        """Write the entries of ``stats`` inside ``ax`` (axes coordinates)."""
        ax.text(0.7, 0.7, "\n".join("%s = %.1E" % item for item in stats.items()),
                fontsize=fontsize, horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)

    @add_fig_kwargs
    def plot_diff_at_qpoint(self, qpoint=0, fontsize=8, **kwargs):
        """
        Plot, for each perturbation (atom, direction), the histogram of |V1_origin - V1_recons|
        over the FFT mesh (all spin components) at the given q-point.

        Args:
            qpoint: q-point index or q-point (|Kpoint| or reduced coordinates).
            fontsize: fontsize for legends and titles

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)

        # complex arrays with shape: (natom3, nspden * nfft)
        origin_v1 = self.read_v1_at_iq("origin_v1scf", iq, reshape_nfft_nspden=True)
        symm_v1 = self.read_v1_at_iq("recons_v1scf", iq, reshape_nfft_nspden=True)

        # One row per atom, one column per direction.
        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=self.natom3 // 3, ncols=3,
                                              sharex=False, sharey=False, squeeze=False)

        for nu, ax in enumerate(ax_list.ravel()):
            ipert, idir = divmod(nu, 3)
            abs_diff, stats = self._get_diff_stats(origin_v1[nu], symm_v1[nu])

            ax.hist(abs_diff, facecolor='g', alpha=0.75)
            ax.grid(True)
            self._set_pert_title(ax, iq, idir, ipert, fontsize)
            ax.axvline(stats["mean"], color='k', linestyle='dashed', linewidth=1)
            self._add_stats_text(ax, stats, fontsize)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig

    @add_fig_kwargs
    def plot_pots_at_qpoint(self, qpoint=0, fontsize=8, abs_thresh=1e-5, **kwargs):
        """
        Plot, for each perturbation (atom, direction), the phase difference in degrees
        between V1_origin and V1_recons at the given q-point.

        Args:
            qpoint: q-point index or q-point (|Kpoint| or reduced coordinates).
            fontsize: fontsize for legends and titles
            abs_thresh: only points where |V1_origin| > abs_thresh are shown since the
                phase is ill-defined where the potential (almost) vanishes.

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)

        # complex arrays with shape: (natom3, nspden * nfft)
        origin_v1 = self.read_v1_at_iq("origin_v1scf", iq, reshape_nfft_nspden=True)
        symm_v1 = self.read_v1_at_iq("recons_v1scf", iq, reshape_nfft_nspden=True)

        natom = self.natom3 // 3
        # One row per atom, one column per direction.
        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=natom, ncols=3,
                                              sharex=False, sharey=False, squeeze=False)

        for nu, ax in enumerate(ax_list.ravel()):
            ipert, idir = divmod(nu, 3)
            _, stats = self._get_diff_stats(origin_v1[nu], symm_v1[nu])

            ax.grid(True)
            self._set_pert_title(ax, iq, idir, ipert, fontsize)

            # arg(a * conj(b)) = arg(a) - arg(b) already wrapped to (-180, 180]:
            # a naive difference of angles would show spurious +-360 jumps.
            data = np.angle(origin_v1[nu] * np.conj(symm_v1[nu]), deg=True)
            data = data[np.abs(origin_v1[nu]) > abs_thresh]
            ax.plot(np.arange(len(data)), data,
                    linestyle="--", color="red", alpha=0.4, label="diff angle degrees" if nu == 0 else None)

            if nu == 0:
                ax.set_ylabel(r"Phase diff (deg)")
                ax.legend(loc="best", fontsize=fontsize, shadow=True)
            if ipert == natom - 1:
                ax.set_xlabel(r"FFT index")

            self._add_stats_text(ax, stats, fontsize)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Only the first ``maxnq`` q-points are plotted.
        """
        maxnq = 3
        # Figures are yielded, never shown (also avoids a duplicated `show` keyword).
        kwargs["show"] = False
        for iq in range(len(self.qpoints)):
            if iq >= maxnq:
                print("Only the first %d q-points are show..." % maxnq)
                break
            yield self.plot_pots_at_qpoint(qpoint=iq, **kwargs)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        _, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(ncfile)"),
        ])

        for iq in range(len(self.qpoints)):
            nb.cells.append(nbv.new_code_cell("ncfile.plot_diff_at_qpoint(qpoint=%d);" % iq))

        return self._write_nb_nbpath(nb, nbpath)
262