# coding: utf-8
"""
Object to analyze the results stored in the V1SYM.nc file (mainly for debugging purposes).
"""
import numpy as np

from collections import OrderedDict
from monty.string import marquee
from monty.functools import lazy_property
from abipy.tools.plotting import add_fig_kwargs, get_axarray_fig_plt
from abipy.core.mixins import AbinitNcFile, Has_Structure, NotebookWriter
from abipy.core.kpoints import KpointList, Kpoint
from abipy.iotools import ETSF_Reader
from abipy.tools import duck


class V1symFile(AbinitNcFile, Has_Structure, NotebookWriter):
    """
    Interface to the V1SYM.nc file.

    For each q-point the file stores two sets of first-order potentials:
    "origin_v1scf" and "recons_v1scf" (the latter is called ``symm_v1`` in this module).
    The plotting methods compare the two for every atomic perturbation.
    """

    def __init__(self, filepath):
        super().__init__(filepath)
        self.reader = r = ETSF_Reader(filepath)
        try:
            # Read dimensions.
            self.nfft = r.read_dimvalue("nfft")
            self.nspden = r.read_dimvalue("nspden")
            self.natom3 = len(self.structure) * 3
            self.symv1scf = r.read_value("symv1scf")
        except Exception:
            # Do not leak the open netcdf file if the constructor fails.
            r.close()
            raise

    @lazy_property
    def structure(self):
        """|Structure| object."""
        return self.reader.read_structure()

    @lazy_property
    def pertsy_qpt(self):
        """
        Determine the symmetrical perturbations. Meaning of pertsy:

             0 for non-target perturbations.
             1 for basis perturbations.
            -1 for perturbations that can be found from basis perturbations.

        The Fortran array is declared as nctkarr_t("pertsy_qpt", "int", "three, mpert, nqpt"),
        hence it is indexed here as ``pertsy_qpt[iq, ipert, idir]``.
        """
        return self.reader.read_value("pertsy_qpt")

    def close(self):
        """Close the underlying netcdf reader."""
        self.reader.close()

    @lazy_property
    def params(self):
        """
        :class:`OrderedDict` with parameters that might be subject to convergence studies.
        Empty for this file.
        """
        return OrderedDict()

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with file info, structure and the value of symv1scf."""
        lines = []
        app = lines.append
        app(marquee("File Info", mark="="))
        app(self.filestat(as_string=True))
        app("")
        app(self.structure.to_string(verbose=verbose, title="Structure"))
        app("")
        app("symv1scf: %s" % self.symv1scf)

        return "\n".join(lines)

    @lazy_property
    def qpoints(self):
        """|KpointList| with the q-points stored in the file (netcdf variable "qpts")."""
        return KpointList(self.structure.reciprocal_lattice, frac_coords=self.reader.read_value("qpts"))

    def _find_iqpt_qpoint(self, qpoint):
        """
        Accept either an integer index or a q-point (anything accepted by ``Kpoint.as_kpoint``)
        and return the tuple ``(iq, qpoint)`` where ``qpoint`` is the |Kpoint| in ``self.qpoints``.
        """
        if duck.is_intlike(qpoint):
            iq = qpoint
            qpoint = self.qpoints[iq]
        else:
            qpoint = Kpoint.as_kpoint(qpoint, self.structure.reciprocal_lattice)
            iq = self.qpoints.index(qpoint)

        return iq, qpoint

    def read_v1_at_iq(self, key, iq, reshape_nfft_nspden=False):
        """
        Read the first-order potential ``key`` for the q-point with index ``iq``.

        Args:
            key: Name of the netcdf variable, e.g. "origin_v1scf" or "recons_v1scf".
            iq: Index of the q-point.
            reshape_nfft_nspden: If True, merge the spin and FFT dimensions so that
                the result has shape (natom3, nspden * nfft).

        Return: Complex array. Assuming the C-ordered view (nqpt, natom3, nspden, nfft, 2)
            of the Fortran array ("two, nfft, nspden, natom3, nqpt"), its shape is
            (natom3, nspden, nfft), or (natom3, nspden * nfft) if ``reshape_nfft_nspden``.
        """
        v1 = self.reader.read_variable(key)[iq]
        # The last dimension stores (real, imag).
        v1 = v1[..., 0] + 1j * v1[..., 1]
        # Merge (nspden, nfft) because we are not interested in the spin dependence.
        if reshape_nfft_nspden:
            v1 = np.reshape(v1, (self.natom3, self.nspden * self.nfft))
        return v1

    def _read_origin_symm_at_iq(self, iq):
        """Return the (origin, reconstructed) complex potentials at ``iq``, shape (natom3, nspden * nfft)."""
        origin_v1 = self.read_v1_at_iq("origin_v1scf", iq, reshape_nfft_nspden=True)
        symm_v1 = self.read_v1_at_iq("recons_v1scf", iq, reshape_nfft_nspden=True)
        return origin_v1, symm_v1

    @staticmethod
    def _get_diff_stats(origin_v1, symm_v1):
        """
        Compare two complex potentials defined on the same mesh.

        Return: ``(abs_diff, stats)`` where ``abs_diff = |origin_v1 - symm_v1|`` point by point and
            ``stats`` is an OrderedDict with max, min, mean, std of ``abs_diff`` and the relative
            L1 error ``L1_rerr = sum |origin_v1 - symm_v1| / sum |origin_v1|``.
        """
        abs_diff = np.abs(origin_v1 - symm_v1)
        # A vanishing reference potential gives inf/nan: show it in the plot instead of warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            l1_rerr = np.sum(abs_diff) / np.sum(np.abs(origin_v1))

        stats = OrderedDict([
            ("max", abs_diff.max()),
            ("min", abs_diff.min()),
            ("mean", abs_diff.mean()),
            ("std", abs_diff.std()),
            ("L1_rerr", l1_rerr),
        ])
        return abs_diff, stats

    def _set_title_and_stats(self, ax, iq, ipert, idir, stats, fontsize):
        """Set the perturbation title of ``ax`` and write ``stats`` in a text box inside it."""
        ax.set_title("idir: %d, iat: %d, pertsy: %d" % (idir, ipert, self.pertsy_qpt[iq, ipert, idir]),
                     fontsize=fontsize)
        ax.text(0.7, 0.7, "\n".join("%s = %.1E" % item for item in stats.items()),
                fontsize=fontsize, horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)

    @add_fig_kwargs
    def plot_diff_at_qpoint(self, qpoint=0, fontsize=8, **kwargs):
        """
        Plot, for each atomic perturbation, the histogram of |origin_v1 - symm_v1| over the mesh.

        Args:
            qpoint: Index of the q-point in ``self.qpoints`` or q-point (|Kpoint| or reduced coordinates).
            fontsize: fontsize for legends and titles.

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)
        origin_v1, symm_v1 = self._read_origin_symm_at_iq(iq)

        # One row per atom, one column per direction.
        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=self.natom3 // 3, ncols=3,
                                              sharex=False, sharey=False, squeeze=False)

        for nu, ax in enumerate(ax_list.ravel()):
            ipert, idir = divmod(nu, 3)
            abs_diff, stats = self._get_diff_stats(origin_v1[nu], symm_v1[nu])

            ax.hist(abs_diff, facecolor='g', alpha=0.75)
            ax.grid(True)
            # Mark the mean of the absolute difference.
            ax.axvline(stats["mean"], color='k', linestyle='dashed', linewidth=1)
            self._set_title_and_stats(ax, iq, ipert, idir, stats, fontsize)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig

    @add_fig_kwargs
    def plot_pots_at_qpoint(self, qpoint=0, fontsize=8, **kwargs):
        """
        Plot, for each atomic perturbation, the phase difference (degrees) between origin_v1 and
        symm_v1 at the mesh points where |origin_v1| > 1e-5. The statistics of
        |origin_v1 - symm_v1| are reported in a text box.

        Args:
            qpoint: Index of the q-point in ``self.qpoints`` or q-point (|Kpoint| or reduced coordinates).
            fontsize: fontsize for legends and titles.

        Return: |matplotlib-Figure|
        """
        iq, qpoint = self._find_iqpt_qpoint(qpoint)
        origin_v1, symm_v1 = self._read_origin_symm_at_iq(iq)

        # One row per atom, one column per direction.
        ax_list, fig, _ = get_axarray_fig_plt(None, nrows=self.natom3 // 3, ncols=3,
                                              sharex=False, sharey=False, squeeze=False)

        natom = len(self.structure)
        for nu, ax in enumerate(ax_list.ravel()):
            ipert, idir = divmod(nu, 3)
            _, stats = self._get_diff_stats(origin_v1[nu], symm_v1[nu])

            ax.grid(True)
            # Wrap to [-180, 180) so that equal phases on opposite sides of the branch cut
            # of np.angle do not appear as +-360 degree jumps.
            dphase = np.angle(origin_v1[nu], deg=True) - np.angle(symm_v1[nu], deg=True)
            dphase = (dphase + 180.0) % 360.0 - 180.0
            # Skip points where the reference potential is ~zero: its phase is not meaningful there.
            dphase = dphase[np.abs(origin_v1[nu]) > 1e-5]
            ax.plot(np.arange(len(dphase)), dphase,
                    linestyle="--", color="red", alpha=0.4, label="diff angle degrees" if nu == 0 else None)

            if nu == 0:
                ax.set_ylabel("Phase diff (deg)")
                ax.legend(loc="best", fontsize=fontsize, shadow=True)
            if ipert == natom - 1:
                ax.set_xlabel("Index of points with |V1| > 1e-5")

            self._set_title_and_stats(ax, iq, ipert, idir, stats, fontsize)

        fig.suptitle("qpoint: %s" % repr(qpoint))
        return fig

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        Only the first ``maxnq`` q-points are plotted.
        """
        maxnq = 3
        for iq, _ in enumerate(self.qpoints):
            if iq >= maxnq:
                print("Only the first %d q-points are shown..." % maxnq)
                break
            yield self.plot_pots_at_qpoint(qpoint=iq, **kwargs, show=False)

    def write_notebook(self, nbpath=None):
        """
        Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current
        working directory is created. Return path to the notebook.
        """
        nbformat, nbv, nb = self.get_nbformat_nbv_nb(title=None)

        nb.cells.extend([
            nbv.new_code_cell("ncfile = abilab.abiopen('%s')" % self.filepath),
            nbv.new_code_cell("print(ncfile)"),
        ])

        # One cell per q-point with the histograms of |origin_v1 - symm_v1|.
        for iq, _ in enumerate(self.qpoints):
            nb.cells.append(nbv.new_code_cell("ncfile.plot_diff_at_qpoint(qpoint=%d);" % iq))

        return self._write_nb_nbpath(nb, nbpath)