1import pickle
2
3import numpy as np
4from ase.utils import gcd
5from ase.units import Ha
6from scipy.interpolate import InterpolatedUnivariateSpline
7
8from gpaw import GPAW
9from gpaw.spinorbit import soc_eigenstates
10from gpaw.kpt_descriptor import to1bz
11
12
class GWBands:
    """Band structures from GW (or DFT) eigenvalues along a k-point path.

    The eigenvalues are only known on the k-point grid of the original
    GPAW calculation, so the band path is sampled at those grid points
    that lie on the straight segments between the given waypoints.

    Parameters
    ----------
    calc:
        Passed to ``GPAW(calc, communicator=comm, txt=None)``, e.g. the
        name of a ``.gpw`` file.
    comm:
        Communicator handed to GPAW.
    gw_file:
        Optional path of a pickled GW results dictionary; its ``'qp'``
        (or ``b'qp'``) entry is used as quasiparticle eigenvalues.
    kpoints:
        Waypoints of the band path in scaled (reciprocal-lattice)
        coordinates, one row per waypoint.
    bandrange:
        Band indices covered by the eigenvalues; defaults to all bands of
        ``calc``.  The last entry is treated as inclusive.
    """

    def __init__(self,
                 calc=None,
                 comm=None,
                 gw_file=None,
                 kpoints=None,
                 bandrange=None):

        self.calc = GPAW(calc, communicator=comm, txt=None)
        # Always define the attribute so that a missing GW file gives a
        # clear error later instead of an AttributeError.
        self.gw_file = None
        if gw_file is not None:
            # encoding='bytes' lets Python-2 pickles load; their keys then
            # come back as bytes, hence the b'qp' fallback below.
            self.gw_file = self._load_pickle(gw_file, encoding='bytes')
        self.kpoints = kpoints
        if bandrange is None:
            self.bandrange = np.arange(self.calc.get_number_of_bands())
        else:
            self.bandrange = bandrange

        self.gd = self.calc.wfs.gd.new_descriptor()
        self.kd = self.calc.wfs.kd

        self.acell_cv = self.gd.cell_cv
        self.bcell_cv = 2 * np.pi * self.gd.icell_cv
        self.vol = self.gd.volume
        # Volume of the reciprocal cell: (2 pi)^3 / real-space volume.
        self.BZvol = (2 * np.pi)**3 / self.vol

    @staticmethod
    def _load_pickle(filename, **kwargs):
        """Load and return the pickled object stored in *filename*.

        The file is opened in binary mode (``pickle.load`` on a text-mode
        handle fails under Python 3) and is always closed.  Extra keyword
        arguments are forwarded to ``pickle.load``.

        NOTE: unpickling can execute arbitrary code; only load trusted
        files.
        """
        with open(filename, 'rb') as fd:
            return pickle.load(fd, **kwargs)

    def _get_gw_qp_eigenvalues(self):
        """Return entry ``[0]`` of the ``'qp'`` array of the GW file.

        Entry ``[0]`` is presumably spin channel 0 (TODO confirm against
        the GW results format).

        Raises
        ------
        ValueError
            If no ``gw_file`` was given to the constructor.
        """
        if self.gw_file is None:
            raise ValueError('No gw_file was given; pass gw_file=... or '
                             'use dft=True / eig_file=... instead')
        try:
            return self.gw_file['qp'][0]
        except KeyError:
            # Python-2 pickles loaded with encoding='bytes' have bytes keys.
            return self.gw_file[b'qp'][0]

    def find_k_along_path(self, plot_BZ=True):
        """Find the k-points of the original calculation along the path.

        The path is the polyline through ``self.kpoints``.  Each segment
        ``vec_c`` is split into ``Nv`` equal steps, where ``Nv`` is the gcd
        of ``round(vec_c * N_c)``.  So, for waypoints on the k-point grid,
        every sample is a grid point as well.  The final waypoint itself is
        included.

        Parameters
        ----------
        plot_BZ:
            If true, plot the k-point grid, its image in the first
            Brillouin zone, the path and the sampled points (matplotlib).

        Returns
        -------
        x_x:
            Cumulative Cartesian path length at each sample.
        k_xc:
            Scaled coordinates of each sample (not folded into the 1st BZ).
        k_x:
            Index into ``kd.bzk_kc`` of each sample after folding.
        X:
            Path length at each waypoint.

        Raises
        ------
        ValueError
            If no k-points were given or fewer than two waypoints exist.
        """
        kd = self.kd
        acell_cv = self.acell_cv
        bcell_cv = self.bcell_cv
        if self.kpoints is None:
            raise ValueError('No kpoints were given for the band path')
        kpoints = np.asarray(self.kpoints, dtype=float)
        if len(kpoints) < 2:
            raise ValueError('The band path needs at least two k-points')

        if plot_BZ:
            # Plot the grid points, their first-BZ images and the path.
            import matplotlib.pyplot as plt
            kp_1bz = to1bz(kd.bzk_kc, acell_cv)
            bzk_kcv = np.dot(kd.bzk_kc, bcell_cv)
            kp_1bz_v = np.dot(kp_1bz, bcell_cv)
            plt.plot(bzk_kcv[:, 0], bzk_kcv[:, 1], 'xg')
            plt.plot(kp_1bz_v[:, 0], kp_1bz_v[:, 1], 'ob')
            for start_c, end_c in zip(kpoints[:-1], kpoints[1:]):
                kpoint1_v = np.dot(end_c, bcell_cv)
                kpoint2_v = np.dot(start_c, bcell_cv)
                plt.plot([kpoint1_v[0], kpoint2_v[0]],
                         [kpoint1_v[1], kpoint2_v[1]], '--vr')

        # Find the grid points along each segment of the path.
        print('Finding the kpoints along the path')
        N_c = kd.N_c
        nwaypoints = len(kpoints)

        x_x = []
        k_xc = []
        k_x = []
        X = []
        x = 0.
        for nwpt in range(1, nwaypoints):
            X.append(x)
            from_c = kpoints[nwpt - 1]
            to_c = kpoints[nwpt]
            vec_c = to_c - from_c
            print('From ', from_c, ' to ', to_c)
            # Segment length in grid spacings along each axis; the gcd is
            # the number of equal steps that all land on grid points.
            Nv_c = (vec_c * N_c).round().astype(int)
            Nv = abs(gcd(gcd(Nv_c[0], Nv_c[1]), Nv_c[2]))
            print(Nv, ' points found')
            if Nv == 0:
                # Degenerate segment (e.g. a repeated waypoint): take no
                # step instead of dividing by zero, which gave NaN k-points.
                dv_c = np.zeros_like(vec_c)
            else:
                dv_c = vec_c / Nv
            dx = np.linalg.norm(np.dot(dv_c, bcell_cv))
            if nwpt == nwaypoints - 1:
                # Also sample the end point of the final segment.
                Nv += 1
            for n in range(Nv):
                k_c = from_c + n * dv_c
                bzk_c = to1bz(np.array([k_c]), acell_cv)[0]
                ikpt = kd.where_is_q(bzk_c, kd.bzk_kc)
                x_x.append(x)
                k_xc.append(k_c)
                k_x.append(ikpt)
                x += dx
        X.append(x_x[-1])

        if plot_BZ:
            for k_c in k_xc:
                k_v = np.dot(k_c, bcell_cv)
                plt.plot(k_v[0], k_v[1], 'xr', markersize=10)
            plt.show()

        return x_x, k_xc, k_x, X

    def get_dft_eigenvalues(self):
        """Return the DFT eigenvalues of the bands in ``self.bandrange``.

        Bands ``bandrange[0]`` up to and including ``bandrange[-1]`` are
        taken at every IBZ k-point.  This is the same band window passed to
        ``soc_eigenstates`` (``n2=bandrange[-1] + 1``).

        Returns
        -------
        Array of shape (number of IBZ k-points, number of bands).
        """
        nibzk = len(self.calc.get_ibz_k_points())
        bands = np.arange(self.bandrange[0], self.bandrange[-1] + 1)
        return np.array([self.calc.get_eigenvalues(kpt=k)[bands]
                         for k in range(nibzk)])

    def get_vacuum_level(self, plot_pot=False):
        """Estimate the vacuum level from the electrostatic potential.

        The potential is averaged over the first two grid axes.  The value
        at the first grid plane along the third axis is returned.  That
        plane is presumably in the vacuum region for slab systems; confirm
        this for your cell.

        Parameters
        ----------
        plot_pot:
            If true, plot the averaged potential profile.
        """
        vHt_g = self.calc.get_electrostatic_potential()
        vHt_z = np.mean(np.mean(vHt_g, axis=0), axis=0)

        if plot_pot:
            import matplotlib.pyplot as plt
            plt.plot(vHt_z)
            plt.show()
        return vHt_z[0]

    def get_spinorbit_corrections(self, return_spin=True, return_wfs=False,
                                  bands=None, gwqeh_file=None, dft=False,
                                  eig_file=None):
        """Return eigenvalues including spin-orbit coupling.

        The input eigenvalues come from one of three sources:

        - GW quasiparticle energies from ``gw_file`` when ``dft`` is false;
        - the pickle ``eig_file`` when ``dft`` is true and one is given;
        - the DFT eigenvalues of ``bandrange`` otherwise.

        They are passed to ``soc_eigenstates`` for bands
        ``bandrange[0] .. bandrange[-1]``.

        ``return_spin``, ``return_wfs``, ``bands`` and ``gwqeh_file`` are
        accepted for backward compatibility but are not used.
        """
        calc = self.calc
        bandrange = self.bandrange

        if not dft:
            e_kn = self._get_gw_qp_eigenvalues()
        elif eig_file is not None:
            e_kn = self._load_pickle(eig_file)[0]
        else:
            # Already restricted to bandrange; do not slice again.
            e_kn = self.get_dft_eigenvalues()

        # NOTE: flagged upstream as likely to fail - please write a test!
        soc = soc_eigenstates(
            calc,
            n1=bandrange[0],
            n2=bandrange[-1] + 1,
            eigenvalues=e_kn[np.newaxis])
        return soc.eigenvalues()

    def get_gw_bands(self, nk_Int=50, interpolate=False, SO=False,
                     gwqeh_file=None, dft=False, eig_file=None, vac=False):
        """Return the eigenvalues along the band path and band-edge data.

        The eigenvalue source is, in order of precedence:

        - spin-orbit eigenvalues (``SO``), optionally corrected with the
          ``'qp_sin'`` entry of ``gwqeh_file`` (scaled by ``Ha``);
        - ``'qp_sin'`` of ``gwqeh_file``, scaled by ``Ha``;
        - ``eig_file``;
        - DFT eigenvalues (``dft``);
        - the GW file given to the constructor.

        Eigenvalues are sorted at each k-point, so band indices follow
        energy order.  The number of occupied bands ``N_occ`` is counted
        below the Fermi level at the first path point.

        Parameters
        ----------
        nk_Int:
            Number of uniform points for the spline interpolation.
        interpolate:
            If true, return spline-interpolated bands.
        vac:
            If true, energies are shifted by ``get_vacuum_level()``.

        Returns
        -------
        dict with these keys:

        - ``'x_k'``: path coordinates;
        - ``'X'``: waypoint coordinates;
        - ``'e_kn'``: eigenvalues minus the vacuum shift;
        - ``'ef'``: Fermi level minus the vacuum shift;
        - ``'gap'``: CBM - VBM;
        - ``'vbm'``, ``'cbm'``: band edges minus the vacuum shift;
        - ``'k_ibz_x'``: IBZ index of each path point (only without
          interpolation).

        Raises
        ------
        ValueError
            If no band, or every band, lies below the Fermi level at the
            first path point, so no gap can be defined.
        """
        kd = self.kd
        if SO:
            e_kn = self.get_spinorbit_corrections(return_wfs=True,
                                                  dft=dft,
                                                  eig_file=eig_file)
            if gwqeh_file is not None:
                gwqeh = self._load_pickle(gwqeh_file)
                eqeh_noSO_kn = gwqeh['qp_sin'][0] * Ha
                # The spin-orbit array has twice as many columns; even and
                # odd columns (presumably spin partners) get the same
                # correction.
                eqeh_kn = np.zeros_like(e_kn)
                eqeh_kn[:, ::2] = eqeh_noSO_kn
                eqeh_kn[:, 1::2] = eqeh_noSO_kn
                e_kn = e_kn + eqeh_kn
        elif gwqeh_file is not None:
            gwqeh = self._load_pickle(gwqeh_file)
            e_kn = gwqeh['qp_sin'][0] * Ha
        elif eig_file is not None:
            e_kn = self._load_pickle(eig_file)[0]
        elif dft:
            e_kn = self.get_dft_eigenvalues()
        else:
            e_kn = self._get_gw_qp_eigenvalues()
        e_kn = np.sort(e_kn, axis=1)

        bandrange = self.bandrange
        ef = self.calc.get_fermi_level()
        evac = self.get_vacuum_level() if vac else 0.0
        x_x, k_xc, k_x, X = self.find_k_along_path(plot_BZ=False)

        # Map every path point to its irreducible k-point and take the
        # eigenvalues stored there.
        k_ibz_x = np.asarray(kd.bz2ibz_k)[np.asarray(k_x, dtype=int)]
        eGW_kn = np.asarray(e_kn, dtype=float)[k_ibz_x]

        N_occ = (eGW_kn[0] < ef).sum()
        print(N_occ, bandrange[0])
        print(' ')
        if SO:
            print('The number of Occupied bands is:', N_occ + 2 * bandrange[0])
        else:
            print('The number of Occupied bands is:', N_occ + bandrange[0])
        nbands = eGW_kn.shape[1]
        if not 0 < N_occ < nbands:
            # N_occ == 0 would silently wrap to the last band below.
            raise ValueError(
                f'Cannot determine the band gap: {N_occ} of {nbands} bands '
                'lie below the Fermi level at the first k-point')
        vbm_abs = eGW_kn[:, N_occ - 1].max()
        cbm_abs = eGW_kn[:, N_occ].min()
        gap = cbm_abs - vbm_abs
        print('The bandgap is: %f' % gap)

        vbm = vbm_abs - evac
        cbm = cbm_abs - evac

        if interpolate:
            # Uniform grid plus the original path points, sorted.
            xfit_k = np.sort(np.append(np.linspace(x_x[0], x_x[-1], nk_Int),
                                       x_x))
            efit_kn = np.zeros((len(xfit_k), nbands))
            for n in range(nbands):
                fit_e = InterpolatedUnivariateSpline(x_x, eGW_kn[:, n])
                efit_kn[:, n] = fit_e(xfit_k)

            return {'x_k': xfit_k,
                    'X': X,
                    'e_kn': efit_kn - evac,
                    'ef': ef - evac,
                    'gap': gap,
                    'vbm': vbm,
                    'cbm': cbm}

        return {'x_k': x_x,
                'X': X,
                'k_ibz_x': k_ibz_x,
                'e_kn': eGW_kn - evac,
                'ef': ef - evac,
                'gap': gap,
                'vbm': vbm,
                'cbm': cbm}
244