1import sys
2
3import numpy as np
4from scipy.optimize import leastsq
5
6from gpaw.wavefunctions.pw import PWDescriptor
7from gpaw.kpt_descriptor import KPointDescriptor
8from gpaw.response.pair import PairDensity
9
10
def check_degenerate_bands(filename, etol):
    """Print neighbouring bands that are nearly degenerate but unequally
    occupied.

    The calculation stored in ``filename`` is loaded and the number of
    electrons, bands and irreducible k-points is printed.  Then, for
    every irreducible k-point ``k`` and every band ``n >= 1``, the line
    ``k n e[k, n] e[k, n - 1]`` is printed whenever the occupation drops
    by more than 1e-5 from band ``n - 1`` to band ``n`` while the two
    eigenvalues differ by less than ``etol``.

    Parameters:

    filename: str
        File handed to the ``GPAW`` constructor.
    etol: float
        Eigenvalue difference below which two neighbouring bands are
        reported (same unit as ``calc.get_eigenvalues`` returns).
    """
    from gpaw import GPAW

    calc = GPAW(filename, txt=None)
    print('Number of Electrons   :', calc.get_number_of_electrons())
    nkpts = calc.get_ibz_k_points().shape[0]
    nbands = calc.get_number_of_bands()
    print('Number of Bands       :', nbands)
    print('Number of ibz-kpoints :', nkpts)

    kpts = range(nkpts)
    e_kn = np.array([calc.get_eigenvalues(k) for k in kpts])
    f_kn = np.array([calc.get_occupation_numbers(k) for k in kpts])

    for k in kpts:
        for upper in range(1, nbands):
            lower = upper - 1
            # Only pairs with a visible occupation step are of interest.
            if f_kn[k, lower] - f_kn[k, upper] > 1e-5:
                if np.abs(e_kn[k, upper] - e_kn[k, lower]) < etol:
                    print(k, upper, e_kn[k, upper], e_kn[k, lower])
28
29
def get_orbitals(calc):
    """Return the LCAO basis functions on the 3D grid.

    A ``BasisFunctions`` object is built from the ``phit_j`` functions of
    all setups of ``calc.wfs`` and expanded onto the grid with
    ``lcao_to_grid`` using an identity coefficient matrix.

    Parameters:

    calc: GPAW calculator
        Must provide ``wfs`` (with ``setups``, ``gd`` and ``kd.comm``),
        ``spos_ac`` and ``get_number_of_bands()``.

    Returns:

    orb_MG: ndarray
        Array created by ``calc.wfs.gd.zeros(nbands)`` and filled by
        ``lcao_to_grid``.
    """
    from gpaw.lfc import BasisFunctions

    wfs = calc.wfs
    phit_aj = [setup.phit_j for setup in wfs.setups]

    basis = BasisFunctions(wfs.gd, phit_aj, wfs.kd.comm, cut=True)
    basis.set_positions(calc.spos_ac)

    # NOTE(review): the number of orbitals is taken from the number of
    # bands, i.e. this presumably assumes nbands == number of basis
    # functions -- confirm against callers.
    norbitals = calc.get_number_of_bands()
    orb_MG = wfs.gd.zeros(norbitals)
    basis.lcao_to_grid(np.identity(norbitals), orb_MG, q=-1)

    return orb_MG
45
46
def get_pw_descriptor(q_c, calc, ecut, gammacentered=False):
    """Build the plane-wave descriptor belonging to the wave vector q_c.

    Parameters:

    q_c: sequence
        Wave vector; wrapped in a single-point ``KPointDescriptor``.
    calc: GPAW calculator
        Supplies the grid descriptor ``calc.wfs.gd``.
    ecut: float
        Cutoff forwarded unchanged to ``PWDescriptor``.
    gammacentered: bool
        Forwarded unchanged to ``PWDescriptor``.

    Returns:

    pd: PWDescriptor
        Complex plane-wave descriptor for ``q_c``.
    """
    qd = KPointDescriptor([q_c])
    return PWDescriptor(ecut, calc.wfs.gd, complex, qd,
                        gammacentered=gammacentered)
53
54
def get_bz_transitions(filename, q_c, bzk_kc,
                       response='density', spins='all',
                       ecut=50, txt=sys.stdout):
    """
    Get transitions in the Brillouin zone from kpoints bzk_kc
    contributing to the linear response at wave vector q_c.

    Parameters:

    filename: str
        Ground-state file handed to ``PairDensity``.
    q_c: sequence
        Wave vector used to build the plane-wave descriptor.
    bzk_kc: ndarray
        k-points; they are multiplied by ``2 * pi * icell_cv`` below, so
        they are presumably given in scaled (reciprocal lattice)
        coordinates.
    response: str
        Forwarded to ``PairDensity``.
    spins: 'all' or iterable of int
        'all' selects every spin index ``0 .. nspins - 1``; otherwise
        each given index must lie in that range.
    ecut: float
        Forwarded to ``PairDensity``.
    txt: file-like
        Forwarded to ``PairDensity``.

    Returns:

    pair: PairDensity
        The pair-density object that was set up.
    pd: PWDescriptor
        Plane-wave descriptor of ``q_c``.
    domainarg_td: list of tuple
        One ``(k_v, spin)`` tuple per transition domain, ordered with the
        k-point as the slow index and the spin as the fast index.
    """
    from itertools import product

    pair = PairDensity(filename, ecut=ecut, response=response, txt=txt)
    pd = get_pw_descriptor(q_c, pair.calc, pair.ecut)

    bzk_kv = np.dot(bzk_kc, pd.gd.icell_cv) * 2 * np.pi

    nspins = pair.calc.wfs.nspins
    # Compare against 'all' only for real strings: ``array == 'all'`` is
    # element-wise for ndarrays and cannot be used as a truth value.
    if isinstance(spins, str) and spins == 'all':
        spins = range(nspins)
    else:
        # Materialise once so that one-shot iterables survive validation.
        spins = list(spins)
        for spin in spins:
            assert spin in range(nspins), \
                'spin index %r out of range' % (spin,)

    # Cartesian product in row-major order (last factor varies fastest),
    # i.e. the same order np.unravel_index would enumerate.
    domainarg_td = list(product(bzk_kv, spins))

    return pair, pd, domainarg_td
86
87
def get_chi0_integrand(pair, pd, n_n, m_m, k_v, s):
    """
    Collect the ingredients of the chi0 integrand for one k-point and
    spin: pair densities, occupation differences and the eigenvalues of
    both k-points of the pair.

    Parameters:

    pair: PairDensity
        Provides ``get_kpoint_pair`` and ``get_pair_density``.
    pd: PWDescriptor
        Plane-wave descriptor; its ``gd.cell_cv`` is used to convert
        ``k_v``.
    n_n, m_m: sequence of int
        Band indices; the k-point pair is requested for the ranges
        ``n_n[0] .. n_n[-1]`` and ``m_m[0] .. m_m[-1]`` (inclusive).
    k_v: ndarray
        k-point; converted with ``cell_cv / (2 * pi)`` before use.
    s: int
        Spin index.

    Returns:

    n_nmG, df_nm, eps_n, eps_m
        Pair densities, occupation differences, and ``eps_n`` of
        ``kpt1`` and ``kpt2`` of the k-point pair.
    """
    k_c = np.dot(pd.gd.cell_cv, k_v) / (2 * np.pi)

    # Half-open band windows spanned by the requested indices.
    n1, n2 = n_n[0], n_n[-1] + 1
    m1, m2 = m_m[0], m_m[-1] + 1
    kptpair = pair.get_kpoint_pair(pd, s, k_c, n1, n2, m1, m2)

    pairdensity_nmG = pair.get_pair_density(pd, kptpair, n_n, m_m)
    occdiff_nm = kptpair.get_occupation_differences(n_n, m_m)

    return (pairdensity_nmG, occdiff_nm,
            kptpair.kpt1.eps_n, kptpair.kpt2.eps_n)
106
107
def get_degeneracy_matrix(eps_n, tol=1.e-3):
    """
    Build a 0/1 matrix that sums over groups of degenerate values.

    Consecutive entries of ``eps_n`` are put in the same group as long as
    each entry differs from its direct predecessor by less than ``tol``
    (chain-wise comparison, not comparison with the first member).

    Parameters:

    eps_n: sequence of float
        Values to group; neighbouring entries are compared, so the input
        is presumably expected to be sorted.
    tol: float
        Maximum difference between neighbours of one group.

    Returns:

    degmat: ndarray
        One row per group with ones at the indices of its members and
        zeros elsewhere, so ``np.dot(degmat, a_n)`` sums ``a_n`` over
        each group.
    eps_N: ndarray
        The first value of every group.
    """
    nvalues = len(eps_n)
    rows = []
    first_values = []

    first = 0
    while first < nvalues:
        # Extend the group while the next value is close to the last
        # member accepted so far.
        last = first
        while (last + 1 < nvalues and
               abs(eps_n[last] - eps_n[last + 1]) < tol):
            last += 1

        size = last - first + 1
        rows.append([0] * first + [1] * size + [0] * (nvalues - last - 1))
        first_values.append(eps_n[first])
        first = last + 1

    return np.array(rows), np.array(first_values)
130
131
def get_individual_transition_strengths(n_nmG, df_nm, G1, G2):
    """Return Re[df_nm * n_nm(G1) * conj(n_nm(G2))] for every band pair.

    Parameters:

    n_nmG: ndarray
        Pair densities, indexed as [n, m, G].
    df_nm: ndarray
        Occupation differences, indexed as [n, m].
    G1, G2: int
        Indices along the last axis of ``n_nmG``.

    Returns:

    ndarray
        Real part of the weighted product, indexed as [n, m].
    """
    left_nm = n_nmG[:, :, G1]
    right_nm = n_nmG[:, :, G2].conj()
    weighted_nm = df_nm * left_nm * right_nm
    return weighted_nm.real
134
135
def find_peaks(x, y, threshold=None):
    """ Find the strict local maxima of a curve inside a window.

    An interior point ``i`` is a peak when ``y[i]`` is strictly larger
    than both neighbours and ``(x[i], y[i])`` lies inside the window
    ``xmin <= x <= xmax``, ``ymin <= y <= ymax``.

    Usage:
    threshold = (xmin, xmax, ymin, ymax)

    A shorter tuple (or a single number, taken as xmin) is completed with
    x.max(), y.min() and y.max() for the missing trailing entries.  With
    threshold=None the window covers the full data range.

    Returns an array of shape (npeak, 2) holding the (x, y) pairs.

    """

    assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray)
    assert x.ndim == 1 and y.ndim == 1
    assert x.shape[0] == y.shape[0]

    if threshold is None:
        threshold = (x.min(), x.max(), y.min(), y.max())

    if not isinstance(threshold, tuple):
        threshold = (threshold, )

    # Complete a partial window with the data limits.
    ngiven = len(threshold)
    if ngiven == 1:
        threshold = threshold + (x.max(), y.min(), y.max())
    elif ngiven == 2:
        threshold = threshold + (y.min(), y.max())
    elif ngiven == 3:
        threshold = threshold + (y.max(),)

    xmin, xmax, ymin, ymax = (threshold[0], threshold[1],
                              threshold[2], threshold[3])

    def is_peak(i):
        inside = (y[i] >= ymin and y[i] <= ymax and
                  x[i] >= xmin and x[i] <= xmax)
        return inside and y[i] > y[i - 1] and y[i] > y[i + 1]

    # The end points have only one neighbour and are never peaks.
    found = [(x[i], y[i]) for i in range(1, x.shape[0] - 1) if is_peak(i)]

    peakarray = np.zeros([len(found), 2])
    for row, point in enumerate(found):
        peakarray[row] = point

    return peakarray
182
183
def lorz_fit(x, y, npeak=1, initpara=None):
    """ Fit a curve with one or two Lorentzian functions

    Note: currently only valid for one and two Lorentzians

    Each Lorentzian is defined as::

                      A w
        lorz = --------------------- + y0
                (x-x0)**2 + w**2

    where x0 is the peak position, w the width, y0 a constant offset and
    A a strength parameter (the peak rises A / w above y0).  For two
    peaks the model is the sum of two such terms.

    Parameters:

    x, y: ndarray
        Input data to analyze
    npeak: int
        Number of Lorentzians, 1 or 2.  Other values make the model
        function raise ValueError.
    initpara: ndarray
        Initial guess [A, x0, y0, w] (repeated for the second peak).
        If None, [1, 0, 0, 0.1] is used for the first and
        [3, 0, 0, 0.1] for the second peak.

    Returns:

    yfit: ndarray
        Fitted model evaluated at x
    p: ndarray
        Optimized parameters, ordered like initpara

    """

    def peak(x, strength, centre, width):
        return strength * width / ((x - centre)**2 + width**2)

    def lorz(x, p, npeak):
        if npeak == 1:
            return peak(x, p[0], p[1], p[3]) + p[2]
        if npeak == 2:
            return (peak(x, p[0], p[1], p[3]) + p[2]
                    + peak(x, p[4], p[5], p[7]) + p[6])
        raise ValueError('Larger than 2 peaks not supported yet!')

    def residual(p, x, y):
        return y - lorz(x, p, npeak)

    if initpara is None:
        first_guess = [1., 0., 0., 0.1]
        if npeak == 1:
            initpara = np.array(first_guess)
        if npeak == 2:
            initpara = np.array(first_guess + [3., 0., 0., 0.1])

    popt = leastsq(residual, initpara, args=(x, y), maxfev=2000)[0]

    return lorz(x, popt, npeak), popt
236
237
def linear_fit(x, y, initpara=None):
    """ Fit a straight line ``p[0] * x + p[1]`` by least squares

    Parameters:

    x, y: ndarray
        Input data to analyze
    initpara: ndarray
        Initial guess [slope, intercept]; [1, 1] if None.

    Returns:

    yfit: ndarray
        Fitted line evaluated at x
    p: ndarray
        Optimized [slope, intercept]

    """

    def line(x, p):
        slope, intercept = p[0], p[1]
        return slope * x + intercept

    def misfit(p, x, y):
        return y - line(x, p)

    start = np.array([1.0, 1.0]) if initpara is None else initpara
    popt = leastsq(misfit, start, args=(x, y), maxfev=2000)[0]

    return line(x, popt), popt
254
255
def plot_setfont():
    """Set the global matplotlib font sizes used for plots to 18.

    Updates ``plt.rcParams`` in place: axis labels, general text, legend
    and tick labels are set to size 18 and ``text.usetex`` is switched
    on.  The change is global and affects all subsequently drawn
    figures.
    """
    import matplotlib.pyplot as plt
    # 'font.size' replaces the former 'text.fontsize' key, which current
    # matplotlib releases reject as an invalid rc parameter.
    params = {'axes.labelsize': 18,
              'font.size': 18,
              'legend.fontsize': 18,
              'xtick.labelsize': 18,
              'ytick.labelsize': 18,
              'text.usetex': True}
    plt.rcParams.update(params)
266
267
def plot_setticks(x=True, y=True):
    """Put five minor tick intervals between major ticks on the current axes.

    For every enabled axis the major locator is reset to ``AutoLocator``
    and the minor tick spacing is set to one fifth of the average major
    tick spacing.  For a disabled axis the minor ticks are removed.

    Parameters:

    x: bool
        Show minor ticks on the x axis.
    y: bool
        Show minor ticks on the y axis.
    """
    import matplotlib.pyplot as plt

    def set_minor_ticks(axis, enabled):
        if not enabled:
            # Switch off this axis only; plt.minorticks_off() would also
            # remove the minor ticks just configured on the other axis.
            axis.set_minor_locator(plt.NullLocator())
            return
        axis.set_major_locator(plt.AutoLocator())
        major = axis.get_majorticklocs()
        if len(major) < 2:
            # No spacing can be derived from a single major tick; keep
            # the minor locator installed by minorticks_on().
            return
        spacing = (major[-1] - major[0]) / (len(major) - 1)
        axis.set_minor_locator(plt.MultipleLocator(spacing / 5.))

    plt.minorticks_on()
    ax = plt.gca()
    set_minor_ticks(ax.xaxis, x)
    set_minor_ticks(ax.yaxis, y)
287