# -*- coding: utf-8 -*-
"""Plot a 2-D map over (occupied, unoccupied) energies together with DOSes.

The figure is laid out on a 2x2 grid (see :func:`generate_gridspec`): the
map in the lower-left cell, the occupied-energy DOS above it, the
unoccupied-energy DOS to its right and a free cell in the top-right corner.
"TCM" presumably stands for "transition contribution map" -- confirm.
"""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm


def generate_gridspec(**kwargs):
    """Return the 2x2 GridSpec used for the map/DOS figure layout.

    The left column is three times wider than the right one and the bottom
    row three times taller than the top one.  The grid spans 0.84 of the
    figure in each direction, starting at (0.12, 0.12) in figure fraction.
    Extra keyword arguments (e.g. ``hspace``/``wspace``) are forwarded to
    :class:`matplotlib.gridspec.GridSpec`.
    """
    from matplotlib.gridspec import GridSpec
    width = 0.84
    bottom = 0.12
    left = 0.12
    return GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[1, 3],
                    bottom=bottom, top=bottom + width,
                    left=left, right=left + width,
                    **kwargs)


def plot_DOS(ax, energy_e, dos_e, base_e, dos_min, dos_max,
             flip=False, fill=None, line=None):
    """Draw a density-of-states panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.  Its tick labels are hidden and only the left/bottom
        spines are kept.
    energy_e, dos_e, base_e : array-like
        Energy grid, DOS values and the baseline the filled area starts at.
    dos_min, dos_max : float
        Limits of the DOS axis.
    flip : bool
        If False, energy runs along x and DOS along y; if True the two are
        swapped (DOS along x, energy along y).
    fill : dict or None
        Keyword arguments for the filled area between ``base_e`` and
        ``base_e + dos_e``; falsy values skip the fill.
    line : dict or None
        Keyword arguments for the DOS curve; falsy values skip the curve.
    """
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if flip:
        set_label = ax.set_xlabel
        fill_between = ax.fill_betweenx
        set_energy_lim = ax.set_ylim
        set_dos_lim = ax.set_xlim

        def plot(x, y, *args, **kwargs):
            # Energy goes on the vertical axis when flipped.
            return ax.plot(y, x, *args, **kwargs)
    else:
        set_label = ax.set_ylabel
        fill_between = ax.fill_between
        set_energy_lim = ax.set_xlim
        set_dos_lim = ax.set_ylim

        def plot(x, y, *args, **kwargs):
            return ax.plot(x, y, *args, **kwargs)
    if fill:
        fill_between(energy_e, base_e, dos_e + base_e, **fill)
    if line:
        # NOTE(review): the curve is drawn at dos_e, not dos_e + base_e, so
        # with a non-zero (stacked) base it does not follow the top of the
        # filled area -- confirm this is intended.
        plot(energy_e, dos_e, **line)
    set_label('DOS', labelpad=0)
    set_energy_lim(np.take(energy_e, (0, -1)))
    set_dos_lim(dos_min, dos_max)


def _complex_to_rgb(z, rmin, rmax, hue_start=90):
    """Map complex values to an RGB image: hue from phase, value from |z|.

    The phase in degrees, offset by ``hue_start``, sets the hue.  The
    amplitude, clipped to ``[rmin, rmax]``, is rescaled linearly to the HSV
    value in ``[0, 1]``.  ``z`` is 2-D; the result has a trailing axis of
    length 3.
    """
    from matplotlib.colors import hsv_to_rgb
    amp = np.clip(np.absolute(z), rmin, rmax)
    h = ((np.angle(z, deg=True) + hue_start) % 360) / 360
    # NOTE(review): a saturation above 1 lies outside the nominal HSV range
    # and yields RGB components below 0; presumably chosen for more vivid
    # colours and relying on imshow clipping RGB data -- confirm.
    s = 1.85 * np.ones_like(h)
    v = (amp - rmin) / (rmax - rmin)
    return hsv_to_rgb(np.dstack((h, s, v)))


class TCM(object):
    """Figure with a 2-D map over (occupied, unoccupied) energies plus DOSes.

    Axes are created lazily on first access (see ``__getattr__``):
    ``ax_tcm`` holds the map, ``ax_occ_dos`` and ``ax_unocc_dos`` the DOS
    panels above and to the right of it, ``ax_spec`` the top-right grid cell
    and ``ax_cbar`` a small inset axes used for the colorbar.
    """

    def __init__(self, energy_o, energy_u, fermilevel):
        """Store the energy grids and Fermi level and reset DOS baselines.

        ``energy_o`` is the grid along the x axis of the map (occupied
        energies), ``energy_u`` the grid along the y axis (unoccupied
        energies).  ``fermilevel`` is marked by horizontal and vertical
        lines in :meth:`plot_TCM`.
        """
        self.energy_o = energy_o
        self.energy_u = energy_u
        self.fermilevel = fermilevel

        # Accumulated DOS for plot_DOS(stack=True).  Force a float dtype so
        # that in-place accumulation of float DOS values also works when
        # the energy grids are integer arrays.
        self.base_o = np.zeros_like(energy_o, dtype=float)
        self.base_u = np.zeros_like(energy_u, dtype=float)

    def __getattr__(self, attr):
        """Create the figure axes lazily on first access.

        Only invoked for attributes not found the normal way; the created
        axes are stored on the instance, so later lookups bypass this.
        """
        if attr in ('ax_occ_dos', 'ax_unocc_dos', 'ax_tcm'):
            # These three panels share one tight layout and are created
            # together.
            gs = generate_gridspec(hspace=0.05, wspace=0.05)
            self.ax_occ_dos = plt.subplot(gs[0])
            self.ax_unocc_dos = plt.subplot(gs[3])
            self.ax_tcm = plt.subplot(gs[2])
            return getattr(self, attr)
        if attr == 'ax_spec':
            gs = generate_gridspec(hspace=0.8, wspace=0.8)
            self.ax_spec = plt.subplot(gs[1])
            return getattr(self, attr)
        if attr == 'ax_cbar':
            self.ax_cbar = plt.axes((0.15, 0.6, 0.02, 0.1))
            return getattr(self, attr)
        raise AttributeError('%s object has no attribute %s' %
                             (repr(self.__class__.__name__), repr(attr)))

    def plot_TCM(self, tcm_ou, vmax='80%', vmin='symmetrize', cmap='seismic',
                 log=False, colorbar=True, lw=None):
        """Plot the map ``tcm_ou`` on ``ax_tcm`` (clearing it first).

        Parameters
        ----------
        tcm_ou : array-like
            Map values, expected with shape
            ``(len(energy_o), len(energy_u))``; it is transposed so that
            occupied energies run along x.
        vmax : float or str
            Upper colour limit.  A string ``'P%'`` means P percent of
            ``max(abs(tcm_ou))``.
        vmin : float or 'symmetrize'
            Lower colour limit for real maps; ``'symmetrize'`` uses
            ``-vmax``.  Unused for complex maps.
        cmap : str
            Colormap for real maps.
        log : bool
            Use a logarithmic colour norm for real maps.
        colorbar : bool
            Draw a colorbar on ``ax_cbar``.
        lw : float, optional
            Width of the Fermi-level lines; defaults to
            ``rcParams['lines.linewidth']``.

        Complex maps are drawn as an RGB image whose hue encodes the phase
        and whose brightness encodes the amplitude clipped to ``[0, vmax]``.

        Raises
        ------
        ValueError
            If ``vmax`` is a string that does not end in ``'%'``.
        """
        if lw is None:
            lw = mpl.rcParams['lines.linewidth']
        # asanyarray keeps subclasses such as masked arrays intact.
        tcm_ou = np.asanyarray(tcm_ou)
        energy_o = self.energy_o
        energy_u = self.energy_u
        fermilevel = self.fermilevel

        tcmmax = np.max(np.absolute(tcm_ou))
        print('tcmmax', tcmmax)

        # Resolve colour limits before touching the axes so that a bad
        # argument does not leave a cleared plot behind.
        if isinstance(vmax, str):
            if not vmax.endswith('%'):
                raise ValueError(
                    "vmax given as a string must end with '%', "
                    "got {!r}".format(vmax))
            vmax = tcmmax * float(vmax[:-1]) / 100.0
        if isinstance(vmin, str) and vmin == 'symmetrize':
            vmin = -vmax

        # Plot TCM
        ax = self.ax_tcm
        plt.sca(ax)
        plt.cla()
        if np.iscomplexobj(tcm_ou):
            linecolor = 'w'
            img = _complex_to_rgb(tcm_ou.T, 0, vmax)
            mappable = plt.imshow(img, origin='lower',
                                  extent=[energy_o[0], energy_o[-1],
                                          energy_u[0], energy_u[-1]],
                                  interpolation='bilinear',
                                  )
        else:
            # White lines stay visible on the dark 'magma' colormap.
            linecolor = 'w' if cmap == 'magma' else 'k'
            if log:
                norm = LogNorm(vmin=vmin, vmax=vmax)
            else:
                norm = Normalize(vmin=vmin, vmax=vmax)
            mappable = plt.pcolormesh(energy_o, energy_u, tcm_ou.T,
                                      cmap=cmap, rasterized=True, norm=norm,
                                      )
        if colorbar:
            cax = self.ax_cbar
            cax.clear()
            # Pass the mappable explicitly rather than relying on pyplot's
            # "current image", which could belong to another axes.
            cb = plt.colorbar(mappable, cax=cax)
            cb.outline.set_edgecolor(linecolor)
            cax.tick_params(axis='both', colors=linecolor)
        plt.sca(ax)
        plt.axhline(fermilevel, c=linecolor, lw=lw)
        plt.axvline(fermilevel, c=linecolor, lw=lw)

        ax.tick_params(axis='both', which='major', pad=2)
        plt.xlabel(r'Occ. energy $\varepsilon_{o}$ (eV)', labelpad=0)
        plt.ylabel(r'Unocc. energy $\varepsilon_{u}$ (eV)', labelpad=0)
        plt.xlim(np.take(energy_o, (0, -1)))
        plt.ylim(np.take(energy_u, (0, -1)))

    def plot_DOS(self, dos_o, dos_u, stack=False,
                 fill={'color': '0.8'}, line={'color': 'k'}):
        """Plot the occupied and unoccupied DOS panels.

        Parameters
        ----------
        dos_o, dos_u : array
            DOS sampled on ``energy_o`` and ``energy_u``.
        stack : bool
            If True, draw on top of the DOSes accumulated by previous
            ``stack=True`` calls and add these DOSes to that accumulation.
        fill, line : dict or None
            Style keyword arguments for the filled area and the curve; pass
            None or an empty dict to omit that element.  The default dicts
            are only unpacked with ``**`` and never modified.
        """
        if stack:
            base_o = self.base_o
            base_u = self.base_u
        else:
            base_o = np.zeros_like(self.energy_o, dtype=float)
            base_u = np.zeros_like(self.energy_u, dtype=float)
        dos_min = 0.0
        # The filled areas reach up to base + dos, so the DOS axis limit
        # must include the stacked baseline or the fills get clipped.
        dos_max = 1.01 * max(np.max(dos_o + base_o), np.max(dos_u + base_u))
        plot_DOS(self.ax_occ_dos, self.energy_o, dos_o, base_o,
                 dos_min, dos_max, flip=False, fill=fill, line=line)
        plot_DOS(self.ax_unocc_dos, self.energy_u, dos_u, base_u,
                 dos_min, dos_max, flip=True, fill=fill, line=line)
        if stack:
            self.base_o += dos_o
            self.base_u += dos_u

    def plot_spectrum(self):
        """Placeholder for plotting on ``ax_spec``; not implemented."""
        raise NotImplementedError()

    def plot_TCM_diagonal(self, energy, **kwargs):
        """Draw the line ``eps_u = eps_o + energy`` across the map.

        The line spans the occupied-energy range; ``kwargs`` are passed to
        ``Axes.plot``.
        """
        x_o = np.take(self.energy_o, (0, -1))
        self.ax_tcm.plot(x_o, x_o + energy, **kwargs)

    def set_title(self, *args, **kwargs):
        """Set the figure title on the top (occupied DOS) panel."""
        self.ax_occ_dos.set_title(*args, **kwargs)


class TCMPlotter(TCM):
    """:class:`TCM` that computes the map and DOSes from a ``ksd`` object.

    ``ksd`` must provide ``get_eig_n``, ``get_TCM`` and ``get_weighted_DOS``
    with the call signatures used below (presumably a Kohn-Sham
    decomposition object -- confirm against the callers).
    """

    def __init__(self, ksd, energy_o, energy_u, sigma,
                 zero_fermilevel=True):
        """Take eigenvalues and Fermi level from ``ksd`` and set up the grids.

        ``sigma`` is forwarded to ``ksd``'s map/DOS methods (presumably a
        broadening width -- confirm).
        """
        eig_n, fermilevel = ksd.get_eig_n(zero_fermilevel)
        TCM.__init__(self, energy_o, energy_u, fermilevel)
        self.ksd = ksd
        self.sigma = sigma
        self.eig_n = eig_n

    def plot_TCM(self, weight_p, **kwargs):
        """Compute the map for ``weight_p`` via ``ksd`` and plot it.

        ``kwargs`` are forwarded to :meth:`TCM.plot_TCM`.
        """
        tcm_ou = self.ksd.get_TCM(weight_p, self.eig_n, self.energy_o,
                                  self.energy_u, self.sigma)
        TCM.plot_TCM(self, tcm_ou, **kwargs)

    def plot_DOS(self, weight_n=1.0, **kwargs):
        """Compute the weighted DOSes via ``ksd`` and plot them.

        ``kwargs`` are forwarded to :meth:`TCM.plot_DOS`.
        """
        dos_o, dos_u = self.ksd.get_weighted_DOS(weight_n, self.eig_n,
                                                 self.energy_o,
                                                 self.energy_u,
                                                 self.sigma)
        TCM.plot_DOS(self, dos_o, dos_u, **kwargs)