1# coding: utf-8
2"""
3Object to analyze the results stored in the WR.nc file
4"""
5import numpy as np
6
7from monty.string import marquee
8from monty.functools import lazy_property
9from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt #, get_axarray_fig_plt
10from abipy.core.mixins import AbinitNcFile, Has_Structure, NotebookWriter
11from abipy.iotools import ETSF_Reader
12
13
class WrNcFile(AbinitNcFile, Has_Structure, NotebookWriter):
    """
    Object to analyze the results stored in the WR.nc file.

    The file contains the short-range (``v1scf_rpt_sr``) and the long-range (``v1scf_rpt_lr``)
    parts of W(r, R) for each perturbation, tabulated on the unit-cell FFT mesh (``ngfft``)
    for the lattice vectors ``rpt``. It also contains the max values ``maxw_sr`` / ``maxw_lr``
    used by ``plot_maxw``.
    """

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = r = ETSF_Reader(filepath)

        # Read dimensions.
        self.nfft = r.read_dimvalue("nfft")
        self.nspden = r.read_dimvalue("nspden")
        # Three perturbations per atom: ipert = idir + 3 * iatom (see create_xsf).
        self.natom3 = len(self.structure) * 3
        self.method = r.read_value("method")
        # The analysis tools of this class have been written for method == 0 only.
        assert self.method == 0, "Unsupported method: %s" % self.method
        self.ngqpt = r.read_value("ngqpt")
        # Lattice vectors R. NOTE(review): presumably in reduced coordinates (units of unit cells).
        self.rpt = r.read_value("rpt")
        self.nrpt = len(self.rpt)
        # FFT mesh of the unit cell.
        self.ngfft = r.read_value("ngfft")

    def _get_fft_inds(self):
        """
        Return (nfft, 3) int array with the (ix, iy, iz) indices of the unit-cell FFT points.

        ix is the fastest index, i.e. ifft = ix + nx * (iy + ny * iz), because the FFT values
        produced by Fortran are accessed via ifft.

        Raises:
            ValueError: if nfft is not equal to the product of the ngfft divisions.
        """
        nx, ny, nz = (int(n) for n in self.ngfft)
        if nx * ny * nz != self.nfft:
            raise ValueError("nfft: %s != nx * ny * nz: %s" % (self.nfft, nx * ny * nz))

        # With indexing="ij" the arrays have shape (nz, ny, nx) so that ravel (C order)
        # makes ix run fastest.
        iz, iy, ix = np.meshgrid(np.arange(nz), np.arange(ny), np.arange(nx), indexing="ij")
        return np.stack((ix.ravel(), iy.ravel(), iz.ravel()), axis=1).astype(np.int64)

    def _build_box2fr(self, fft_inds, box_shape):
        """
        Map each point of the big box to the (ifft, ir) pair such that the point is equal
        to ``fft_inds[ifft] - rpt[ir] * ngfft`` modulo ``box_shape``.

        Args:
            fft_inds: (nfft, 3) array with the indices of the unit-cell FFT points.
            box_shape: Shape of the FFT mesh of the supercell (ngqpt * ngfft).

        Return:
            int64 array of shape (*box_shape, 2) with [ifft, ir] for each point.
            Points that are not reached by any pair are set to -1.
            If several pairs map onto the same point, the one with the largest ir wins.
        """
        ngfft = np.asarray(self.ngfft)
        shape = tuple(int(n) for n in box_shape)
        box2fr = np.full(shape + (2,), -1, dtype=np.int64)
        ifft_range = np.arange(len(fft_inds))

        # Within a single R all the r - R points are distinct since box_shape >= ngfft,
        # so duplicated indices can only come from different R. Processing R sequentially
        # thus reproduces "last assignment wins" semantics.
        for ir, rpt in enumerate(self.rpt):
            i, j, k = ((fft_inds - rpt * ngfft) % box_shape).astype(np.int64).T
            box2fr[i, j, k, 0] = ifft_range
            box2fr[i, j, k, 1] = ir

        return box2fr

    def create_xsf(self, iatom=0, red_dir=(1, 0, 0), u=1.0, ispden=0, prefix="foo"):
        """
        Write the long-range and the short-range part of W(r, R) for the displacement of atom
        ``iatom`` to two XSF files. Both files are defined on the supercell ``structure * ngqpt``.

        Args:
            iatom: Index of the displaced atom.
            red_dir: Three components of the displacement direction. Perturbation
                ``idir + 3 * iatom`` is weighted by ``u * red_dir[idir]``.
            u: Amplitude of the displacement.
            ispden: Index of the spin-density component.
            prefix: Data are written to ``{prefix}_lr.xsf`` and ``{prefix}_sr.xsf``.

        Raises:
            ValueError: if ``red_dir`` does not have three components.
            RuntimeError: if some point of the supercell box cannot be associated to a (r, R) pair.
        """
        if len(red_dir) != 3:
            raise ValueError("red_dir should have 3 components but got: %s" % str(red_dir))

        nfft, nrpt = self.nfft, self.nrpt
        ngfft = np.asarray(self.ngfft)
        ngqpt = np.asarray(self.ngqpt)
        # Shape of the FFT mesh of the supercell (the "big box").
        box_shape = ngqpt * ngfft
        print("ngqpt:", self.ngqpt)
        print("nrpt:", self.nrpt)
        print("Unit cell FFT shape:", self.ngfft)
        print("Big box shape:", box_shape)
        print("rpt:\n", self.rpt)

        # Find index of (r, R) in the big box.
        fft_inds = self._get_fft_inds()
        print("Building r-R lookup table")
        box2fr = self._build_box2fr(fft_inds, box_shape)
        print("Done")

        # Fortran array: nctkarr_t("v1scf_rpt_sr", "dp", "two, nrpt, nfft, nspden, natom3")
        # i.e. [natom3, nspden, nfft, nrpt, two] with C indexing.
        wsr_var = self.reader.read_variable("v1scf_rpt_sr")
        wlr_var = self.reader.read_variable("v1scf_rpt_lr")

        # Linear combination of the three perturbations of atom iatom.
        # Only the first entry of the "two" dimension is used (presumably the real part).
        wsr = np.zeros((nfft, nrpt))
        wlr = np.zeros((nfft, nrpt))
        for idir, red_comp in enumerate(red_dir):
            ip = idir + 3 * iatom
            wsr += u * red_comp * wsr_var[ip, ispden, :, :, 0]
            wlr += u * red_comp * wlr_var[ip, ispden, :, :, 0]

        print("wsr.shape:", wsr.shape)
        print("Max |Re Wsr|:", np.max(np.abs(wsr.real)), "Max |Im Wsr|:", np.max(np.abs(wsr.imag)))

        # Unary minus binds tighter than //, so this is (-(ngqpt - 1)) // 2.
        # NOTE(review): r0 is in units of unit cells, but it is added below to FFT indices
        # of the big box without multiplying by ngfft. Confirm this is the intended shift.
        r0 = -(ngqpt - 1) // 2
        print("Origin of datagrid set at R0:", r0)

        # Build datagrid in the supercell using C indexing as expected by the XSF writer.
        # Point y of the big box is shifted by r0 and folded back into the box.
        print("Filling data array")
        grid = np.indices(tuple(int(n) for n in box_shape))
        x = (grid + np.reshape(r0, (3, 1, 1, 1))) % np.reshape(box_shape, (3, 1, 1, 1))
        fr = box2fr[x[0], x[1], x[2]]
        ifft, ir = fr[..., 0], fr[..., 1]

        nmiss = int(np.count_nonzero(ifft < 0))
        if nmiss:
            raise RuntimeError("Cannot find r-R points! nmiss: %d" % nmiss)

        data_lr = wlr[ifft, ir]
        data_sr = wsr[ifft, ir]

        super_structure = self.structure * self.ngqpt

        from abipy.iotools import xsf
        for suffix, data in (("lr", data_lr), ("sr", data_sr)):
            filename = "%s_%s.xsf" % (prefix, suffix)
            xsf.xsf_write_structure_and_data_to_path(filename, super_structure, data, cplx_mode="abs")

    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.reader.read_structure()

    def close(self):
        """Close the underlying netcdf reader."""
        self.reader.close()

    @lazy_property
    def params(self):
        """Parameters that might be subject to convergence studies (empty dict for this file)."""
        return {}

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with file info and structure. ``verbose`` sets the verbosity level."""
        lines = []; app = lines.append
        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")

        return "\n".join(lines)

    @add_fig_kwargs
    def plot_maxw(self, scale="semilogy", ax=None, fontsize=8, **kwargs):
        """
        Plot the decay of max_{r,idir,ipert} |W(R,r,idir,ipert)|
        for the long-range and the short-range part.

        Args:
            scale: "semilogy", "loglog" or "plot".
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: fontsize for legends and titles
            kwargs: Passed to the matplotlib plot function.

        Return: |matplotlib-Figure|
        """
        ax, fig, _ = get_ax_fig_plt(ax=ax)
        f = {"plot": ax.plot, "semilogy": ax.semilogy, "loglog": ax.loglog}[scale]

        # |R| for each lattice vector.
        rmod = self.reader.read_value("rmod")

        # Fortran arrays: nctkarr_t("maxw_sr", "dp", "nrpt, natom3") i.e. [natom3, nrpt] with
        # C indexing, so the max along axis 0 is taken over the perturbations for each R.
        maxw_sr = self.reader.read_value("maxw_sr")
        f(rmod, np.max(maxw_sr, axis=0), marker="o", ls=":", lw=0, label="SR", **kwargs)

        maxw_lr = self.reader.read_value("maxw_lr")
        f(rmod, np.max(maxw_lr, axis=0), marker="x", ls="-", lw=0, label="LR", **kwargs)

        ax.grid(True)
        ax.set_ylabel(r"$Max_{({\bf{r}}, idir, ipert)} \| W({\bf{r}}, {\bf{R}}, idir, ipert) \|$")
        ax.set_xlabel(r"$\|{\bf{R}}\|$ (Bohr)")
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.plot_maxw(scale="semilogy")

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(ncfile)"),
        ])

        return self._write_nb_nbpath(nb, nbpath)
257
258
if __name__ == "__main__":
    # Usage: python <this_script> WR.nc
    import sys

    if len(sys.argv) < 2:
        sys.exit("Usage: %s WR.nc" % sys.argv[0])

    ncfile = WrNcFile.from_file(sys.argv[1])
    try:
        #print(ncfile)
        ncfile.plot_maxw(scale="semilogy", ax=None, fontsize=8)
        #ncfile.create_xsf(iatom=0, red_dir=(-1, +1, +1), u=0.1)
    finally:
        # Release the netcdf reader even if plotting fails.
        ncfile.close()
266