1# Copyright 2014-2021 The PySCF Developers. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15'''
16There are two options to call wannier90 in PySCF.  One is the pyWannier90.py
17interface as implemented in this file.
18
19(1)
20pyWannier90: Wannier90 for PySCF (https://github.com/hungpham2017/pyWannier90)
21Hung Q. Pham
22email: pqh3.14@gmail.com
23
24(2)
25Another wannier90 python interface is available on the repo:
26    https://github.com/zhcui/wannier90
27Contact its author "Zhihao Cui" <zcui@caltech.edu> for more details of
28installation and implementations.
29'''
30
31import os, time
32import numpy as np
33from scipy.io import FortranFile
34import pyscf.data.nist as param
35from pyscf import lib
36from pyscf.pbc import df
37from pyscf.pbc.dft import gen_grid, numint
38
39try:
40    import libwannier90
41except ImportError:
42    print('WARNING: Check the installation of libwannier90 and its path in pyscf/pbc/tools/pywannier90.py')
43    print('libwannier90 can be found at: https://github.com/hungpham2017/pyWannier90')
44    raise
45
46
def save_kmf(kmf, chkfile):
    '''Save the k-point wavefunction of ``kmf`` into ``chkfile`` under the 'scf' key.

    Only ``kpts``, ``mo_energy_kpts`` and ``mo_coeff_kpts`` are stored; they can
    be read back with :func:`load_kmf`.
    '''
    from pyscf.lib.chkfile import save
    fields = ('kpts', 'mo_energy_kpts', 'mo_coeff_kpts')
    scf_dic = {name: getattr(kmf, name) for name in fields}
    save(chkfile, 'scf', scf_dic)
58
def load_kmf(chkfile):
    '''Load a wavefunction written by :func:`save_kmf`.

    Returns a lightweight object that exposes only the ``kpts``,
    ``mo_energy_kpts`` and ``mo_coeff_kpts`` attributes read from the 'scf'
    group of ``chkfile``.
    '''
    from pyscf.lib.chkfile import load

    class fake_kmf:
        # Minimal stand-in for a k-point mean-field object
        def __init__(self, scf_data):
            for name in ('kpts', 'mo_energy_kpts', 'mo_coeff_kpts'):
                setattr(self, name, scf_data[name])

    return fake_kmf(load(chkfile, 'scf'))
70
def angle(v1, v2):
    '''
    Return the angle (in radians) between the vectors v1 and v2.

    The cosine is clipped to [-1, 1]: for (anti)parallel vectors rounding can push
    it marginally outside the domain of arccos, which would otherwise yield NaN.
    '''
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    cosa = v1.dot(v2) / np.linalg.norm(v1) / np.linalg.norm(v2)
    return np.arccos(np.clip(cosa, -1.0, 1.0))
80
def transform(x_vec, z_vec):
    '''
    Construct the matrix that transforms a vector r to the local coordinate system
    defined by x_vec and z_vec.

    The returned (3, 3) matrix has the normalized local axes x, y, z as its rows,
    with y = -(x cross z) = z cross x, so ``T.dot(r)`` gives the components of r in
    the local frame (for a batch of row vectors: ``einsum('iv,uv->iu', r, T)``).

    Args:
        x_vec, z_vec : array-like of length 3 (lists are accepted); need not be
            normalized but must be orthogonal to one another.
    '''
    x_vec = np.asarray(x_vec, dtype=float)
    z_vec = np.asarray(z_vec, dtype=float)
    x_vec = x_vec / np.linalg.norm(x_vec)
    z_vec = z_vec / np.linalg.norm(z_vec)
    # x and z have to be orthogonal to one another; compare with a tolerance since
    # normalized floating-point axes rarely give an exactly zero dot product
    assert abs(x_vec.dot(z_vec)) < 1e-6, "x and z axes must be orthogonal"
    y_vec = -np.cross(x_vec, z_vec)
    # Element (u, v) is the cosine between new axis u and original axis v, i.e. the
    # v-th component of the unit vector of axis u
    return np.asarray([x_vec, y_vec, z_vec])
99
def cartesian_prod(arrays, out=None, order='C'):
    '''
    Cartesian product of 1-D sequences, like lib.cartesian_prod of PySCF, except
    that the rows can be enumerated in C order (last input varies fastest) or in
    Fortran order (first input varies fastest).

    Args:
        arrays : sequence of 1-D array-likes
        out : optional buffer to hold the intermediate (ndim, len0, len1, ...) grid
        order : 'C' or 'F'

    Returns:
        (prod(lengths), ndim) array; each row is one combination.
    '''
    vecs = [np.asarray(v) for v in arrays]
    ndim = len(vecs)
    grid_shape = [ndim] + [len(v) for v in vecs]
    dtype = np.result_type(*vecs)

    if out is not None:
        grid = np.ndarray(grid_shape, dtype, buffer=out)
    else:
        grid = np.empty(grid_shape, dtype)

    for axis, vec in enumerate(vecs):
        # Lay vec along `axis` of the product grid; trailing singleton dims broadcast
        grid[axis] = vec.reshape([-1] + [1] * (ndim - axis - 1))

    return grid.reshape((ndim, -1), order=order).T
120
def periodic_grid(cell, grid=(50, 50, 50), supercell=(1, 1, 1), order='C'):
    '''
    Generate a uniform real-space grid for the unit/computational cell in F/C order.

    Args:
        cell : pbc Cell (uses cell.lattice_vectors() and cell.vol)
        grid : number of grid points along each lattice vector of one unit cell
        supercell : number of unit cells along each lattice vector; along direction i
            the grid spans cells -(supercell[i]//2) .. (supercell[i]+1)//2 - 1
        order : 'C' or 'F' enumeration of the grid points (see cartesian_prod)

    Returns:
        coords : (prod(grid)*prod(supercell), 3) Cartesian coordinates, in the units
            of cell.lattice_vectors()
        weights : uniform weights cell.vol / (prod(grid)*prod(supercell)); they sum
            to the unit-cell volume
    '''
    ngrid = np.asarray(grid)
    qv = cartesian_prod([np.arange(-ngrid[i]*(supercell[i]//2), ngrid[i]*((supercell[i]+1)//2))
                         for i in range(3)], order=order)
    # Row i: lattice vector i divided by the number of grid points along it
    a_frac = (1. / ngrid)[:, None] * cell.lattice_vectors()
    coords = np.dot(qv, a_frac)

    # Compute weight
    ngrids = np.prod(grid)
    ncells = np.prod(supercell)
    weights = np.full(ngrids * ncells, cell.vol / ngrids / ncells)
    return coords, weights
136
def R_r(r_norm, r=1, zona=1):
    r'''
    Radial part of the wannier90 trial functions, multiplied by
    \Theta_{l,m_r}(\theta,\phi) to form g(r).
    ref: radial functions table, Chapter 3 of the Wannier90 User guide

    Args:
        r_norm : distance(s) from the projection centre
        r : radial index; 1 and 2 select the first two forms, any other value
            falls back to the r = 3 form
        zona : Z/a, the diffusivity of the radial function

    Returns:
        values with the same shape as r_norm
    '''
    norm = zona**(3/2)
    if r == 1:
        return 2 * norm * np.exp(-zona*r_norm)
    if r == 2:
        return 1 / 2 / np.sqrt(2) * norm * (2 - zona*r_norm) * np.exp(-zona*r_norm/2)
    return (np.sqrt(4/27) * norm
            * (1 - 2*zona*r_norm/3 + 2*(zona**2)*(r_norm**2)/27)
            * np.exp(-zona*r_norm/3))
150
def theta(func, cost, phi):
    r'''
    Basic real angular functions (s,p,d,f) used to compute \Theta_{l,m_r}(\theta,\phi)
    ref: Table 3.1 of the Wannier90 User guide
        Link: https://github.com/wannier-developers/wannier90/raw/v3.1.0/doc/compiled_docs/user_guide.pdf

    Args:
        func : one of 's', 'pz', 'px', 'py', 'dz2', 'dxz', 'dyz', 'dx2-y2', 'dxy',
               'fz3', 'fxz2', 'fyz2', 'fz(x2-y2)', 'fxyz', 'fx(x2-3y2)', 'fy(3x2-y2)'
        cost : cos(theta), scalar or array
        phi  : azimuthal angle, broadcastable with cost

    Returns:
        the function values (for 's', an array of ones-shape equal to cost's shape)

    Raises:
        ValueError: if func is not one of the names above.
    '''
    cost = np.asarray(cost)
    phi = np.asarray(phi)
    sint = np.sqrt(1 - cost**2)
    if func == 's':
        return 1 / np.sqrt(4 * np.pi) * np.ones(cost.shape)
    if func == 'pz':
        return np.sqrt(3 / 4 / np.pi) * cost
    if func == 'px':
        return np.sqrt(3 / 4 / np.pi) * sint * np.cos(phi)
    if func == 'py':
        return np.sqrt(3 / 4 / np.pi) * sint * np.sin(phi)
    if func == 'dz2':
        return np.sqrt(5 / 16 / np.pi) * (3*cost**2 - 1)
    if func == 'dxz':
        return np.sqrt(15 / 4 / np.pi) * sint * cost * np.cos(phi)
    if func == 'dyz':
        return np.sqrt(15 / 4 / np.pi) * sint * cost * np.sin(phi)
    if func == 'dx2-y2':
        return np.sqrt(15 / 16 / np.pi) * (sint**2) * np.cos(2*phi)
    if func == 'dxy':
        return np.sqrt(15 / 16 / np.pi) * (sint**2) * np.sin(2*phi)
    if func == 'fz3':
        return np.sqrt(7) / 4 / np.sqrt(np.pi) * (5*cost**3 - 3*cost)
    if func == 'fxz2':
        return np.sqrt(21) / 4 / np.sqrt(2*np.pi) * (5*cost**2 - 1) * sint * np.cos(phi)
    if func == 'fyz2':
        return np.sqrt(21) / 4 / np.sqrt(2*np.pi) * (5*cost**2 - 1) * sint * np.sin(phi)
    if func == 'fz(x2-y2)':
        return np.sqrt(105) / 4 / np.sqrt(np.pi) * sint**2 * cost * np.cos(2*phi)
    if func == 'fxyz':
        return np.sqrt(105) / 4 / np.sqrt(np.pi) * sint**2 * cost * np.sin(2*phi)
    if func == 'fx(x2-3y2)':
        return np.sqrt(35) / 4 / np.sqrt(2*np.pi) * sint**3 * (np.cos(phi)**2 - 3*np.sin(phi)**2) * np.cos(phi)
    if func == 'fy(3x2-y2)':
        return np.sqrt(35) / 4 / np.sqrt(2*np.pi) * sint**3 * (3*np.cos(phi)**2 - np.sin(phi)**2) * np.sin(phi)
    raise ValueError('Unknown angular function: %s' % func)
192
def theta_lmr(l, mr, cost, phi):
    r'''
    Compute the value of \Theta_{l,m_r}(\theta,\phi)
    ref: Table 3.1 and 3.2 of the Wannier90 User guide
        Link: https://github.com/wannier-developers/wannier90/raw/v3.1.0/doc/compiled_docs/user_guide.pdf

    Args:
        l  : 0..3 for s, p, d, f; -1..-5 for the sp, sp2, sp3, sp3d, sp3d2 hybrids
        mr : index of the function within the l family (l == 0 accepts any mr)
        cost, phi : cos(theta) and phi at the evaluation points

    Raises:
        AssertionError: if l or mr is outside the ranges of the tables.
        ValueError: if the (l, mr) combination is not tabulated (e.g. l=1, mr=4).
    '''
    assert l in [0,1,2,3,-1,-2,-3,-4,-5]
    assert mr in [1,2,3,4,5,6,7]

    if l == 0:                           # s
        theta_lmr = theta('s', cost, phi)
    elif (l == 1) and (mr == 1):         # pz
        theta_lmr = theta('pz', cost, phi)
    elif (l == 1) and (mr == 2):         # px
        theta_lmr = theta('px', cost, phi)
    elif (l == 1) and (mr == 3):         # py
        theta_lmr = theta('py', cost, phi)
    elif (l == 2) and (mr == 1):         # dz2
        theta_lmr = theta('dz2', cost, phi)
    elif (l == 2) and (mr == 2):         # dxz
        theta_lmr = theta('dxz', cost, phi)
    elif (l == 2) and (mr == 3):         # dyz
        theta_lmr = theta('dyz', cost, phi)
    elif (l == 2) and (mr == 4):         # dx2-y2
        theta_lmr = theta('dx2-y2', cost, phi)
    elif (l == 2) and (mr == 5):         # dxy
        theta_lmr = theta('dxy', cost, phi)
    elif (l == 3) and (mr == 1):         # fz3
        theta_lmr = theta('fz3', cost, phi)
    elif (l == 3) and (mr == 2):         # fxz2
        theta_lmr = theta('fxz2', cost, phi)
    elif (l == 3) and (mr == 3):         # fyz2
        theta_lmr = theta('fyz2', cost, phi)
    elif (l == 3) and (mr == 4):         # fz(x2-y2)
        theta_lmr = theta('fz(x2-y2)', cost, phi)
    elif (l == 3) and (mr == 5):         # fxyz
        theta_lmr = theta('fxyz', cost, phi)
    elif (l == 3) and (mr == 6):         # fx(x2-3y2)
        theta_lmr = theta('fx(x2-3y2)', cost, phi)
    elif (l == 3) and (mr == 7):         # fy(3x2-y2)
        theta_lmr = theta('fy(3x2-y2)', cost, phi)
    elif (l == -1) and (mr == 1):         # sp-1
        theta_lmr = 1/np.sqrt(2) * (theta('s', cost, phi) + theta('px', cost, phi))
    elif (l == -1) and (mr == 2):         # sp-2
        theta_lmr = 1/np.sqrt(2) * (theta('s', cost, phi) - theta('px', cost, phi))
    elif (l == -2) and (mr == 1):         # sp2-1
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) - 1/np.sqrt(6) *theta('px', cost, phi) + 1/np.sqrt(2) * theta('py', cost, phi)
    elif (l == -2) and (mr == 2):         # sp2-2
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) - 1/np.sqrt(6) *theta('px', cost, phi) - 1/np.sqrt(2) * theta('py', cost, phi)
    elif (l == -2) and (mr == 3):         # sp2-3
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) + 2/np.sqrt(6) *theta('px', cost, phi)
    elif (l == -3) and (mr == 1):         # sp3-1
        theta_lmr = 1/2 * (theta('s', cost, phi) + theta('px', cost, phi) + theta('py', cost, phi) + theta('pz', cost, phi))
    elif (l == -3) and (mr == 2):         # sp3-2
        theta_lmr = 1/2 * (theta('s', cost, phi) + theta('px', cost, phi) - theta('py', cost, phi) - theta('pz', cost, phi))
    elif (l == -3) and (mr == 3):         # sp3-3
        theta_lmr = 1/2 * (theta('s', cost, phi) - theta('px', cost, phi) + theta('py', cost, phi) - theta('pz', cost, phi))
    elif (l == -3) and (mr == 4):         # sp3-4
        theta_lmr = 1/2 * (theta('s', cost, phi) - theta('px', cost, phi) - theta('py', cost, phi) + theta('pz', cost, phi))
    elif (l == -4) and (mr == 1):         # sp3d-1
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) - 1/np.sqrt(6) *theta('px', cost, phi) + 1/np.sqrt(2) * theta('py', cost, phi)
    elif (l == -4) and (mr == 2):         # sp3d-2
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) - 1/np.sqrt(6) *theta('px', cost, phi) - 1/np.sqrt(2) * theta('py', cost, phi)
    elif (l == -4) and (mr == 3):         # sp3d-3
        theta_lmr = 1/np.sqrt(3) * theta('s', cost, phi) + 2/np.sqrt(6) * theta('px', cost, phi)
    elif (l == -4) and (mr == 4):         # sp3d-4
        theta_lmr = 1/np.sqrt(2) * (theta('pz', cost, phi) + theta('dz2', cost, phi))
    elif (l == -4) and (mr == 5):         # sp3d-5
        theta_lmr = 1/np.sqrt(2) * (-theta('pz', cost, phi) + theta('dz2', cost, phi))
    elif (l == -5) and (mr == 1):         # sp3d2-1
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) - 1/np.sqrt(2) *theta('px', cost, phi) - 1/np.sqrt(12) *theta('dz2', cost, phi) \
                    + 1/2 *theta('dx2-y2', cost, phi)
    elif (l == -5) and (mr == 2):         # sp3d2-2
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) + 1/np.sqrt(2) *theta('px', cost, phi) - 1/np.sqrt(12) *theta('dz2', cost, phi) \
                    + 1/2 *theta('dx2-y2', cost, phi)
    elif (l == -5) and (mr == 3):         # sp3d2-3
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) - 1/np.sqrt(2) *theta('py', cost, phi) - 1/np.sqrt(12) *theta('dz2', cost, phi) \
                    - 1/2 *theta('dx2-y2', cost, phi)
    elif (l == -5) and (mr == 4):         # sp3d2-4
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) + 1/np.sqrt(2) *theta('py', cost, phi) - 1/np.sqrt(12) *theta('dz2', cost, phi) \
                    - 1/2 *theta('dx2-y2', cost, phi)
    elif (l == -5) and (mr == 5):         # sp3d2-5
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) - 1/np.sqrt(2) *theta('pz', cost, phi) + 1/np.sqrt(3) *theta('dz2', cost, phi)
    elif (l == -5) and (mr == 6):         # sp3d2-6
        theta_lmr = 1/np.sqrt(6) * theta('s', cost, phi) + 1/np.sqrt(2) *theta('pz', cost, phi) + 1/np.sqrt(3) *theta('dz2', cost, phi)
    else:
        # Passed the range checks above but is not a tabulated (l, mr) pair
        raise ValueError('Unsupported projection l = %d, mr = %d' % (l, mr))

    return theta_lmr
280
def g_r(grids_coor, site, l, mr, r, zona, x_axis=(1, 0, 0), z_axis=(0, 0, 1), unit='B'):
    r'''
    Evaluate the projection function g(r) = \Theta_{l,m_r}(\theta,\phi) R_r(|r|) on a grid
    ref: Chapter 3, wannier90 User Guide
    Attributes:
        grids_coor : (ngrid, 3) Cartesian grid coordinates
        site       : Cartesian position of the projection centre (same units as grids_coor)
        l, mr      : l and mr value in the Table 3.1 and 3.2 of the ref
        r, zona    : radial index and Z/a of the radial part (see R_r)
        x_axis, z_axis : local x and z axes of the projection (orthogonal, need not be unit)
        unit       : 'B' passes the distances to R_r unchanged; 'A' multiplies them by
                     param.BOHR first (a Bohr -> Angstrom conversion, so presumably
                     grids_coor/site are in Bohr in that case)
    Return:
        (ngrid,) array of g(r)
    '''
    unit_conv = param.BOHR if unit == 'A' else 1

    # Rows of `rotation` are the local axes; convert list defaults to arrays first
    rotation = transform(np.asarray(x_axis, dtype=float), np.asarray(z_axis, dtype=float))
    r_vec = lib.einsum('iv,uv->iu', grids_coor - site, rotation)
    r_norm = np.linalg.norm(r_vec, axis=1)
    if (r_norm < 1e-8).any():
        # A grid point coincides with the centre, where the angles are undefined:
        # shift the whole grid slightly (same workaround as before)
        r_vec = lib.einsum('iv,uv->iu', grids_coor - site - 1e-5, rotation)
        r_norm = np.linalg.norm(r_vec, axis=1)
    cost = r_vec[:, 2] / r_norm
    # Azimuthal angle in every quadrant, including points with x == 0
    phi = np.arctan2(r_vec[:, 1], r_vec[:, 0])

    return theta_lmr(l, mr, cost, phi) * R_r(r_norm * unit_conv, r=r, zona=zona)
315
def get_wigner_seitz_supercell(w90, ws_search_size=(2, 2, 2), ws_distance_tol=1e-6):
    '''
    Return a grid that contains all the lattice within the Wigner-Seitz supercell
    Ref: the hamiltonian_wigner_seitz(count_pts) in wannier90/src/hamiltonian.F90

    Args:
        w90 : object with real_lattice_loc (3x3, lattice vectors as rows) and
              mp_grid_loc (3 ints, the Born-von Karman supercell)
        ws_search_size : number of supercells searched along each direction
        ws_distance_tol : tolerance for treating two distances as equal

    Returns:
        ndegen : (nrpts,) number of equidistant supercell images of each point
        irvec : (nrpts, 3) integer lattice coordinates of the points
        rpt_origin : index of (0, 0, 0) in irvec
    '''
    ws_search_size = np.asarray(ws_search_size)
    mp_grid = np.asarray(w90.mp_grid_loc)
    lattice = np.asarray(w90.real_lattice_loc)
    # Metric G_ij = a_i . a_j with the lattice vectors as rows, so |n A|^2 = n G n^T
    real_metric = lattice.dot(lattice.T)

    def _int_grid(ranges):
        # All integer triples, first index slowest (the loop order of wannier90)
        return np.stack(np.meshgrid(*ranges, indexing='ij'), axis=-1).reshape(-1, 3)

    n_list = _int_grid([np.arange(-s * m, s * m + 1) for s, m in zip(ws_search_size, mp_grid)])
    i_list = _int_grid([np.arange(-s - 1, s + 2) for s in ws_search_size])
    supercell_shifts = i_list * mp_grid
    center = len(i_list) // 2  # index of the zero shift
    tol2 = ws_distance_tol**2

    ndegen = []
    irvec = []
    rpt_origin = None
    for n in n_list:
        # Squared distances |n - i*mp_grid|^2 to every supercell image
        ndiff = n - supercell_shifts
        dist = np.einsum('ij,jk,ik->i', ndiff, real_metric, ndiff)
        dist_min = dist.min()
        if abs(dist[center] - dist_min) < tol2:
            if not n.any():
                rpt_origin = len(irvec)
            ndegen.append(int(np.count_nonzero(abs(dist - dist_min) < tol2)))
            irvec.append(n.tolist())

    irvec = np.asarray(irvec)
    ndegen = np.asarray(ndegen)

    # Check the "sum rule": the weights must add up to the number of cells
    tot = np.sum(1.0 / ndegen)
    assert abs(tot - np.prod(mp_grid)) < 1e-8, "Error in finding Wigner-Seitz points!!!"

    return ndegen, irvec, rpt_origin
363
def R_wz_sc(w90, R_in, R0, ws_search_size=(2, 2, 2), ws_distance_tol=1e-6):
    '''
    For each vector in R_in, find its periodic images (translations by
    Born-von Karman supercell vectors) closest to R0, including degenerate ones.
    Ref: This is the replication of the R_wz_sc function of ws_distance.F90

    Args:
        w90 : object with real_lattice_loc (lattice vectors as rows),
              recip_lattice_loc and mp_grid_loc
        R_in : (3,) or (nR, 3) Cartesian vectors; not modified
        R0 : (3,) Cartesian reference point
        ws_search_size : number of supercells searched along each direction
        ws_distance_tol : tolerance for treating two distances as equal

    Returns:
        ndeg_ : (nR, 8) float, 1.0 in the first ndeg[i] slots of row i, 0 elsewhere
        ndeg : (nR,) number of degenerate images of each vector
        R_out : (nR, 8, 3) Cartesian images (first ndeg[i] entries valid)
        shifts : (nR, 8, 3) integer lattice translations mapping R_in to R_out
    '''
    ndegenx = 8  # max number of unit cells that can touch in a single point (i.e. vertex of cube)
    # Work on a float copy: an in-place update must never alias/mutate the caller's array
    R_bz = np.array(R_in, dtype=np.float64).reshape(-1, 3)
    nR = R_bz.shape[0]
    R0 = np.asarray(R0, dtype=np.float64)
    real_lattice = np.asarray(w90.real_lattice_loc)
    to_frac = np.asarray(w90.recip_lattice_loc).T / 2 / np.pi

    # Supercell translations, first index slowest (the loop order of wannier90)
    ranges = [np.arange(-s - 1, s + 2) for s in ws_search_size]
    n_list = np.stack(np.meshgrid(*ranges, indexing='ij'), axis=-1).reshape(-1, 3)
    trans_vecs = n_list * np.asarray(w90.mp_grid_loc)

    ndeg_ = np.zeros([nR, ndegenx])
    shifts = np.zeros([nR, ndegenx, 3])
    R_out = np.zeros([nR, ndegenx, 3])

    def _images(R_cart):
        # Cartesian images of every vector under every translation: (nR, ntrans, 3)
        R_f = R_cart.dot(to_frac)[:, None, :] + trans_vecs[None, :, :]
        return R_f.dot(real_lattice)

    # First pass: the shortest image (first one in loop order on ties)
    mod2_R_bz = np.sum((R_bz - R0)**2, axis=1)
    R = _images(R_bz)
    mod2_R = np.sum((R - R0)**2, axis=2)
    imin = np.argmin(mod2_R, axis=1)
    mod2_R_min = mod2_R[np.arange(nR), imin]
    shorter = mod2_R_min < mod2_R_bz
    R_bz[shorter] = R[shorter, imin[shorter]]
    mod2_R_bz[shorter] = mod2_R_min[shorter]
    shifts[shorter] = trans_vecs[imin[shorter]][:, None, :]

    # Second pass: all images equidistant (within tolerance) with the shortest one
    R = _images(R_bz)
    mod2_R = np.sum((R - R0)**2, axis=2)
    degenerate = abs(np.sqrt(mod2_R) - np.sqrt(mod2_R_bz)[:, None]) < ws_distance_tol
    ndeg = degenerate.sum(axis=1)
    assert (ndeg <= ndegenx).all(), "The degeneracy cannot be larger than 8"
    for i in range(nR):
        n = ndeg[i]
        R_out[i, :n] = R[i, degenerate[i]]
        shifts[i, :n] += trans_vecs[degenerate[i]]
        ndeg_[i, :n] = 1.0

    # As in wannier90, a vector that coincides with R0 is reported as R0 itself
    at_R0 = mod2_R_bz < ws_distance_tol**2
    R_out[at_R0, 0] = R0

    return ndeg_, ndeg, R_out, shifts
419
def ws_translate_dist(w90, irvec, ws_search_size=[2,2,2], ws_distance_tol=1e-6):
    '''
    For every lattice vector in irvec and every pair of Wannier centres, find the
    Wigner-Seitz-supercell images of R + centre[a] - centre[b] via w90.R_wz_sc.
    Ref: This is the replication of the ws_translate_dist function of ws_distance.F90

    Args:
        w90 : object with num_wann, wann_centres (num_wann, 3), real_lattice_loc
              and an R_wz_sc(R_in, R0, ws_search_size, ws_distance_tol) method
        irvec : (nrpts, 3) integer lattice vectors

    Returns:
        wdist_ndeg : (nrpts, num_wann, num_wann) degeneracies
        wdist_ndeg_ : (ndegenx, nrpts, num_wann, num_wann) 1.0 for valid degenerate slots
        irdist_ws : (ndegenx, nrpts, num_wann, num_wann, 3) integer-lattice images
        crdist_ws : (ndegenx, nrpts, num_wann, num_wann, 3) Cartesian images
    '''
    nrpts = irvec.shape[0]
    num_wann = w90.num_wann
    ndegenx = 8  # max number of unit cells that can touch in a single point (i.e. vertex of cube)
    assert ndegenx*num_wann*nrpts > 0, "Unexpected dimensions in ws_translate_dist"

    # Flattened index runs over (ir, a, b) with b fastest
    npairs = num_wann * num_wann
    centres = np.asarray(w90.wann_centres)
    irvec_list = np.repeat(irvec, npairs, axis=0)                           # irvec[ir]
    centres_b = np.tile(centres, (nrpts * num_wann, 1))                      # centres[b]
    centres_a = np.tile(np.repeat(centres, num_wann, axis=0), (nrpts, 1))    # centres[a]

    R_in = irvec_list.dot(w90.real_lattice_loc) - centres_b + centres_a
    wdist_ndeg_, wdist_ndeg, R_out, shifts = w90.R_wz_sc(R_in, [0,0,0], ws_search_size, ws_distance_tol)
    ndeg_max = wdist_ndeg_.shape[1]
    irdist_ws = irvec_list[:, np.newaxis, :] + shifts
    crdist_ws = irdist_ws.dot(w90.real_lattice_loc)

    # Put the degeneracy axis first for the computational convenience in lib.einsum
    grid = (nrpts, num_wann, num_wann)
    wdist_ndeg = wdist_ndeg.reshape(grid)
    wdist_ndeg_ = wdist_ndeg_.reshape(grid + (ndeg_max,)).transpose(3, 0, 1, 2)
    irdist_ws = irdist_ws.reshape(grid + (ndeg_max, 3)).transpose(3, 0, 1, 2, 4)
    crdist_ws = crdist_ws.reshape(grid + (ndeg_max, 3)).transpose(3, 0, 1, 2, 4)

    return wdist_ndeg, wdist_ndeg_, irdist_ws, crdist_ws
459
460
461'''Main class of pyWannier90'''
462class W90:
463    def __init__(self, kmf, cell, mp_grid, num_wann, gamma=False, spinors=False, spin_up=None, other_keywords=None):
464
465        if isinstance(kmf, str):
466            self.kmf = load_kmf(kmf)
467        else:
468            self.kmf = kmf
469        self.cell = cell
470        self.num_wann = num_wann
471        self.keywords = other_keywords
472
473        # Collect the pyscf calculation info
474        nao_kpts = []
475        for mo_energy in kmf.mo_energy_kpts:
476            nao_kpts.append(mo_energy.shape[0])
477
478        self.num_bands_tot = np.min(nao_kpts)
479        if self.num_bands_tot < cell.nao_nr():
480            print(('The number of bands at different k-point are not the same. '
481                   'The first %d bands are used.') % (self.num_bands_tot) )
482
483        self.num_kpts_loc = self.kmf.kpts.shape[0]
484        self.mp_grid_loc = mp_grid
485        assert self.num_kpts_loc == np.asarray(self.mp_grid_loc).prod()
486        self.real_lattice_loc = self.cell.lattice_vectors() * param.BOHR
487        self.recip_lattice_loc = self.cell.reciprocal_vectors() / param.BOHR
488        self.kpt_latt_loc = self.cell.get_scaled_kpts(self.kmf.kpts)
489        self.num_atoms_loc = self.cell.natm
490        self.atom_symbols_loc = [atom[0] for atom in self.cell._atom]
491        self.atom_atomic_loc = [int(self.cell._atm[atom][0] + self.cell.atom_nelec_core(atom))
492                                for atom in range(self.num_atoms_loc)]
493        self.atoms_cart_loc = np.asarray([(np.asarray(atom[1])* param.BOHR).tolist()
494                                          for atom in self.cell._atom])
495        self.gamma_only, self.spinors = (0 , 0)
496        if gamma: self.gamma_only = 1
497        if spinors: self.spinors = 1
498
499        # Wannier90_setup outputs
500        self.num_bands_loc = None
501        self.num_wann_loc = None
502        self.nntot_loc = None
503        self.nn_list = None
504        self.proj_site = None
505        self.proj_l = None
506        self.proj_m = None
507        self.proj_radial = None
508        self.proj_z = None
509        self.proj_x = None
510        self.proj_zona = None
511        self.exclude_bands = None
512        self.proj_s = None
513        self.proj_s_qaxis = None
514
515        # Input for Wannier90_run
516        self.band_included_list = None
517        self.A_matrix_loc = None
518        self.M_matrix_loc = None
519        self.eigenvalues_loc = None
520
521        # Wannier90_run outputs
522        self.U_matrix = None
523        self.U_matrix_opt = None
524        self.lwindow = None
525        self.wann_centres = None
526        self.wann_spreads = None
527        self.spread = None
528
529        # Others
530        self.use_bloch_phases = False
531        self.spin_up = spin_up
532        self.mo_energy_kpts = []
533        self.mo_coeff_kpts = []
534        if np.mod(self.cell.nelectron,2) !=0:
535            if spin_up:
536                for kpt in range(self.num_kpts_loc):
537                    self.mo_energy_kpts.append(self.kmf.mo_energy_kpts[0][kpt][:self.num_bands_tot])
538                    self.mo_coeff_kpts.append(self.kmf.mo_coeff_kpts[0][kpt][:,:self.num_bands_tot])
539            else:
540                for kpt in range(self.num_kpts_loc):
541                    self.mo_energy_kpts.append(self.kmf.mo_energy_kpts[1][kpt][:self.num_bands_tot])
542                    self.mo_coeff_kpts.append(self.kmf.mo_coeff_kpts[1][kpt][:,:self.num_bands_tot])
543        else:
544
545            for kpt in range(self.num_kpts_loc):
546                self.mo_energy_kpts.append(self.kmf.mo_energy_kpts[kpt][:self.num_bands_tot])
547                self.mo_coeff_kpts.append(self.kmf.mo_coeff_kpts[kpt][:,:self.num_bands_tot])
548
549    def kernel(self, external_AME=None):
550        '''
551        Main kernel for pyWannier90
552        '''
553        self.make_win()
554        self.setup()
555        if external_AME is not None:
556            self.M_matrix_loc = self.read_M_mat(external_AME + '.mmn')
557            self.A_matrix_loc = self.read_A_mat(external_AME + '.amn')
558            self.eigenvalues_loc = self.read_epsilon_mat(external_AME + '.eig')
559        else:
560            self.M_matrix_loc = self.get_M_mat()
561            self.A_matrix_loc = self.get_A_mat()
562            self.eigenvalues_loc = self.get_epsilon_mat()
563        self.run()
564
565    def make_win(self):
566        '''
567        Make a basic *.win file for wannier90
568        '''
569
570        win_file = open('wannier90.win', "w")
571        win_file.write('! Basic input generated by the pyWannier90. Date: %s\n' % (time.ctime()))
572        win_file.write('\n')
573        win_file.write('num_bands       = %d\n' % (self.num_bands_tot))
574        win_file.write('num_wann       = %d\n' % (self.num_wann))
575        win_file.write('\n')
576        win_file.write('Begin Unit_Cell_Cart\n')
577        for row in range(3):
578            win_file.write('%10.7f  %10.7f  %10.7f\n' %
579                           (self.real_lattice_loc[0, row], self.real_lattice_loc[1, row],
580                            self.real_lattice_loc[2, row]))
581        win_file.write('End Unit_Cell_Cart\n')
582        win_file.write('\n')
583        win_file.write('Begin atoms_cart\n')
584        for atom in range(len(self.atom_symbols_loc)):
585            win_file.write('%s  %7.7f  %7.7f  %7.7f\n' %
586                           (self.atom_symbols_loc[atom], self.atoms_cart_loc[atom,0],
587                            self.atoms_cart_loc[atom,1], self.atoms_cart_loc[atom,2]))
588        win_file.write('End atoms_cart\n')
589        win_file.write('\n')
590        if self.use_bloch_phases: win_file.write('use_bloch_phases = T\n\n')
591        if self.keywords is not None:
592            win_file.write('!Additional keywords\n')
593            win_file.write(self.keywords)
594        win_file.write('\n\n\n')
595        win_file.write('mp_grid        = %d %d %d\n' %
596                       (self.mp_grid_loc[0], self.mp_grid_loc[1], self.mp_grid_loc[2]))
597        if self.gamma_only == 1: win_file.write('gamma_only : true\n')
598        win_file.write('begin kpoints\n')
599        for kpt in range(self.num_kpts_loc):
600            win_file.write('%7.7f  %7.7f  %7.7f\n' %
601                           (self.kpt_latt_loc[kpt][0], self.kpt_latt_loc[kpt][1], self.kpt_latt_loc[kpt][2]))
602        win_file.write('End Kpoints\n')
603        win_file.close()
604
605    def get_M_mat(self):
606        r'''
607        Construct the ovelap matrix: M_{m,n}^{(\mathbf{k,b})}
608        Equation (25) in MV, Phys. Rev. B 56, 12847
609        '''
610
611        M_matrix_loc = np.empty([self.num_kpts_loc, self.nntot_loc,
612                                 self.num_bands_loc, self.num_bands_loc],
613                                dtype = np.complex128)
614
615        for k_id in range(self.num_kpts_loc):
616            for nn in range(self.nntot_loc):
617                k1 = self.cell.get_abs_kpts(self.kpt_latt_loc[k_id])
618                k_id2 = self.nn_list[nn, k_id, 0] - 1
619                k2_ = self.kpt_latt_loc[k_id2]
620                k2_scaled = k2_ + self.nn_list[nn, k_id, 1:4]
621                k2 = self.cell.get_abs_kpts(k2_scaled)
622                s_AO = df.ft_ao.ft_aopair(self.cell, -k2+k1, kpti_kptj=[k2,k1], q = np.zeros(3))[0]
623                Cm = self.mo_coeff_kpts[k_id][:,self.band_included_list]
624                Cn = self.mo_coeff_kpts[k_id2][:,self.band_included_list]
625                M_matrix_loc[k_id, nn,:,:] = lib.einsum('nu,vm,uv->nm', Cn.T.conj(), Cm, s_AO).conj()
626
627        return M_matrix_loc
628
629    def read_M_mat(self, filename=None):
630        r'''
631        Read the ovelap matrix: M_{m,n}^{(\mathbf{k,b})} from seedname.mnn
632        '''
633        if filename is None: filename = 'wannier90.mmn'
634        assert os.path.exists(filename), "Cannot find " + filename
635
636        with open(filename, 'r') as f:
637            data = f.readlines()
638            num_bands_loc, num_kpts_loc, nntot_loc = np.int64(data[1].split())
639            data = data[2:]
640            nn_list = []
641            nline = num_bands_loc**2 + 1
642            M_matrix_loc = np.empty([num_kpts_loc, nntot_loc, num_bands_loc, num_bands_loc], dtype = np.complex128)
643            jump = 0
644            for kpt in range(num_kpts_loc):
645                for nn in range(nntot_loc):
646                    temp = data[jump : jump + nline]
647                    nn_list.append(np.int64(temp[0].split()))
648                    val_in_float = np.float64(" ".join(temp[1:]).split()).reshape(-1,2)
649                    val_in_complex = val_in_float[:,0] + 1j * val_in_float[:,1]
650                    M_matrix_loc[kpt, nn] = val_in_complex.reshape(num_bands_loc, num_bands_loc)
651                    jump += nline
652
653        return M_matrix_loc
654
655    def get_A_mat(self):
656        r'''
657        Construct the projection matrix: A_{m,n}^{\mathbf{k}}
658        Equation (62) in MV, Phys. Rev. B 56, 12847 or equation (22) in SMV, Phys. Rev. B 65, 035109
659        '''
660
661        A_matrix_loc = np.empty([self.num_kpts_loc, self.num_wann_loc, self.num_bands_loc], dtype = np.complex128)
662
663        if self.use_bloch_phases:
664            Amn = np.zeros([self.num_wann_loc, self.num_bands_loc])
665            np.fill_diagonal(Amn, 1)
666            A_matrix_loc[:,:,:] = Amn
667        else:
668            from pyscf.dft import numint as mol_numint
669            from pyscf.dft import gen_grid as mol_gen_grid
670            grids = mol_gen_grid.Grids(self.cell).build()
671            coords = grids.coords
672            weights = grids.weights
673            for ith_wann in range(self.num_wann_loc):
674                frac_site = self.proj_site[ith_wann]
675                abs_site = frac_site.dot(self.real_lattice_loc) / param.BOHR
676                l = self.proj_l[ith_wann]
677                mr = self.proj_m[ith_wann]
678                r = self.proj_radial[ith_wann]
679                zona = self.proj_zona[ith_wann]
680                x_axis = self.proj_x[ith_wann]
681                z_axis = self.proj_z[ith_wann]
682                gr = g_r(coords, abs_site, l, mr, r, zona, x_axis, z_axis, unit = 'B')
683                ao_L0 = mol_numint.eval_ao(self.cell, coords)
684                s_aoL0_g = lib.einsum('i,i,iv->v', weights, gr, ao_L0)
685                for k_id in range(self.num_kpts_loc):
686                    kpt = self.cell.get_abs_kpts(self.kpt_latt_loc[k_id])
687                    mo_included = self.mo_coeff_kpts[k_id][:,self.band_included_list]
688                    s_kpt = self.cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpt, pbcopt=lib.c_null_ptr())
689                    A_matrix_loc[k_id,ith_wann,:] = lib.einsum('v,vu,um->m', s_aoL0_g, s_kpt, mo_included,
690                                                               optimize=True).conj()
691
692        return A_matrix_loc
693
694    def read_A_mat(self, filename=None):
695        r'''
696        Read the ovelap matrix: M_{m,n}^{(\mathbf{k,b})} from seedname.mnn
697        '''
698        if filename is None: filename = 'wannier90.amn'
699        assert os.path.exists(filename), "Cannot find " + filename
700
701        with open(filename, 'r') as f:
702            data = f.readlines()
703            num_bands_loc, num_kpts_loc, num_wann_loc = np.int64(data[1].split())
704            data = data[2:]
705            A_matrix_loc = np.empty([num_kpts_loc, num_wann_loc, num_bands_loc], dtype = np.complex128)
706            val_in_float = np.float64(" ".join(data).split()).reshape(-1,5)
707            A_matrix_loc = (val_in_float[:,3] + 1j * val_in_float[:,4]).reshape(num_kpts_loc, num_wann_loc, num_bands_loc)
708
709        return A_matrix_loc
710
711    def get_epsilon_mat(self):
712        r'''
713        Construct the eigenvalues matrix: \epsilon_{n}^(\mathbf{k})
714        '''
715
716        return np.asarray(self.mo_energy_kpts, dtype=np.float64)[:,self.band_included_list] * param.HARTREE2EV
717
718    def read_epsilon_mat(self, filename=None):
719        r'''
720        Read the eigenvalues matrix: \epsilon_{n}^(\mathbf{k})
721        '''
722        if filename is None: filename = 'wannier90.eig'
723        assert os.path.exists(filename), "Cannot find " + filename
724        with open(filename, 'r') as f:
725            data = f.read()
726            temp = np.float64(data.split()).reshape(-1, 3)
727            nbands = int(temp[:,0].max())
728            nkpts = int(temp[:,1].max())
729            eigenvals = temp[:,2].reshape(nkpts, nbands)
730
731        return eigenvals
732
733    def setup(self):
734        '''
735        Execute the Wannier90_setup
736        '''
737
738        real_lattice_loc = self.real_lattice_loc.T.flatten()
739        recip_lattice_loc = self.recip_lattice_loc.T.flatten()
740        kpt_latt_loc = self.kpt_latt_loc.flatten()
741        atoms_cart_loc = self.atoms_cart_loc.flatten()
742
743        (bands_wann_nntot, nn_list, proj_site, proj_l, proj_m, proj_radial,
744         proj_z, proj_x, proj_zona, exclude_bands, proj_s, proj_s_qaxis) = \
745                libwannier90.setup(self.mp_grid_loc, self.num_kpts_loc, real_lattice_loc,
746                                   recip_lattice_loc, kpt_latt_loc,
747                                   self.num_bands_tot, self.num_atoms_loc,
748                                   self.atom_atomic_loc, atoms_cart_loc, self.gamma_only, self.spinors)
749
750        # Convert outputs to the correct data type
751        self.num_bands_loc, self.num_wann_loc, self.nntot_loc = np.int32(bands_wann_nntot)
752        self.nn_list = np.int32(nn_list)
753        self.proj_site = proj_site
754        self.proj_l = np.int32(proj_l)
755        self.proj_m = np.int32(proj_m)
756        self.proj_radial = np.int32(proj_radial)
757        self.proj_z = proj_z
758        self.proj_x = proj_x
759        self.proj_zona = proj_zona
760        self.exclude_bands = np.int32(exclude_bands)
761        self.band_included_list = [i for i in range(self.num_bands_tot) if (i + 1) not in self.exclude_bands]
762        self.proj_s = np.int32(proj_s)
763        self.proj_s_qaxis = proj_s_qaxis
764
765    def run(self):
766        '''
767        Execute the Wannier90_run
768        '''
769
770        assert self.num_wann_loc is not None
771        assert isinstance(self.M_matrix_loc, np.ndarray)
772        assert isinstance(self.A_matrix_loc, np.ndarray)
773        assert isinstance(self.eigenvalues_loc, np.ndarray)
774
775        real_lattice_loc = self.real_lattice_loc.T.flatten()
776        recip_lattice_loc = self.recip_lattice_loc.T.flatten()
777        kpt_latt_loc = self.kpt_latt_loc.flatten()
778        atoms_cart_loc = self.atoms_cart_loc.flatten()
779        M_matrix_loc = self.M_matrix_loc.flatten()
780        A_matrix_loc = self.A_matrix_loc.flatten()
781        eigenvalues_loc = self.eigenvalues_loc.flatten()
782
783        U_matrix, U_matrix_opt, lwindow, wann_centres, wann_spreads, spread = \
784                libwannier90.run(self.mp_grid_loc, self.num_kpts_loc, real_lattice_loc,
785                                 recip_lattice_loc, kpt_latt_loc, self.num_bands_loc,
786                                 self.num_wann_loc, self.nntot_loc, self.num_atoms_loc,
787                                 self.atom_atomic_loc, atoms_cart_loc, self.gamma_only,
788                                 M_matrix_loc, A_matrix_loc, eigenvalues_loc)
789
790        # Convert outputs to the correct data typ
791        self.U_matrix = U_matrix
792        self.U_matrix_opt = U_matrix_opt
793        lwindow = np.int32(np.abs(lwindow.real))
794        self.lwindow = (lwindow == 1)
795        self.wann_centres = wann_centres.real
796        self.wann_spreads = wann_spreads.real
797        self.spread = spread.real
798
    # Attach helper functions defined earlier in this module as methods of the
    # class; interpolate_ham_kpts calls them as self.get_wigner_seitz_supercell
    # and self.ws_translate_dist.
    get_wigner_seitz_supercell = get_wigner_seitz_supercell
    R_wz_sc = R_wz_sc
    ws_translate_dist = ws_translate_dist
802
803    def get_hamiltonian_kpts(self):
804        '''Get the Hamiltonian in k-space, this should be identical to Fock matrix from PySCF'''
805
806        assert self.U_matrix is not None, "You must wannierize first, then you can run this function"
807        eigenvals_in_window = []
808        for k_id in range(self.num_kpts_loc):
809            mo_included = self.mo_energy_kpts[k_id][self.band_included_list]
810            orbs_in_win = self.lwindow[k_id]
811            mo_in_window = mo_included[orbs_in_win]
812            U_matrix_opt = self.U_matrix_opt[k_id][ :, orbs_in_win].T
813            eigenvals = lib.einsum('m,mo,mo->o', mo_in_window, U_matrix_opt.conj(), U_matrix_opt)
814            eigenvals_in_window.append(eigenvals)
815
816        hamiltonian_kpts = lib.einsum('kso,ko,kto->kst', self.U_matrix.conj(), eigenvals_in_window, self.U_matrix)
817        return hamiltonian_kpts
818
819    def get_hamiltonian_Rs(self, Rs, ham_kpts=None):
820        '''Get the R-space Hamiltonian H(R0, R) centered at R0 or the first R in Rs list
821        '''
822
823        assert self.U_matrix is not None, "You must wannierize first, then you can run this function"
824        nkpts = self.kpt_latt_loc.shape[0]
825        if ham_kpts is not None:
826            hamiltonian_kpts = ham_kpts
827        else:
828            hamiltonian_kpts = self.get_hamiltonian_kpts()
829
830        # Find the center either R(0,0,0) or the first R in the Rs list
831        ngrid = len(Rs)
832        center = np.arange(ngrid)[(np.asarray(Rs)**2).sum(axis=1) < 1e-10]
833        if center.shape[0] == 1:
834            center = center[0]
835        else:
836            center = 0
837
838        # The phase factor is computed using the exp(1j*R.dot(k)) rather than exp(-1j*R.dot(k)) in wannier90
839        phase = 1/np.sqrt(nkpts) * np.exp(1j* 2*np.pi * np.dot(Rs, self.kpt_latt_loc.T))
840        hamiltonian_R0 = lib.einsum('k,kst,Rk->Rst', phase[center], hamiltonian_kpts, phase.conj())
841
842        return hamiltonian_R0
843
844    def interpolate_ham_kpts(self, frac_kpts, ham_kpts=None,
845                             use_ws_distance=True, ws_search_size=[2,2,2],
846                             ws_distance_tol=1e-6):
847        ''' Interpolate the band structure using the Slater-Koster scheme
848            Return:
849                eigenvalues and eigenvectors at the desired kpts
850        '''
851
852        assert self.U_matrix is not None, "You must wannierize first, then you can run this function"
853        ndegen, Rs, center = self.get_wigner_seitz_supercell(ws_search_size, ws_distance_tol)
854        hamiltonian_R0 = self.get_hamiltonian_Rs(Rs, ham_kpts)
855
856        # Interpolate H(kpts) at the desired k-pts
857        if use_ws_distance:
858            wdist_ndeg, wdist_ndeg_, irdist_ws, crdist_ws = self.ws_translate_dist(Rs)
859            temp = lib.einsum('iRstx,kx->iRstk', irdist_ws, frac_kpts)
860            phase = lib.einsum('iRstk,iRst->Rstk', np.exp(1j* 2*np.pi * temp), wdist_ndeg_)
861            inter_hamiltonian_kpts = \
862                    lib.einsum('R,Rst,Rts,Rstk->kst', 1/ndegen, 1/wdist_ndeg, hamiltonian_R0, phase)
863        else:
864            phase = np.exp(1j* 2*np.pi * np.dot(Rs, frac_kpts.T))
865            inter_hamiltonian_kpts = \
866                    lib.einsum('R,Rst,Rk->kst', 1/ndegen, hamiltonian_R0, phase)
867
868        return inter_hamiltonian_kpts
869
870    def interpolate_band(self, frac_kpts, ham_kpts=None, use_ws_distance=True,
871                         ws_search_size=[2,2,2], ws_distance_tol=1e-6):
872        ''' Interpolate the band structure using the Slater-Koster scheme
873            Return:
874                eigenvalues and eigenvectors at the desired kpts
875        '''
876
877        assert self.U_matrix is not None, (
878            "You must wannierize first, then you can run this function")
879        inter_hamiltonian_kpts = self.interpolate_ham_kpts(
880            frac_kpts, ham_kpts, use_ws_distance, ws_search_size, ws_distance_tol)
881        # Diagonalize H(kpts) to get eigenvalues and eigenvector
882        nkpts = frac_kpts.shape[0]
883        eigvals, eigvecs = np.linalg.eigh(inter_hamiltonian_kpts)
884        idx_kpts = eigvals.argsort()
885        eigvals = np.asarray([eigvals[kpt][idx_kpts[kpt]] for kpt in range(nkpts)])
886        eigvecs = np.asarray([eigvecs[kpt][:,idx_kpts[kpt]] for kpt in range(nkpts)])
887
888        return eigvals, eigvecs
889
890    def is_real(self, threshold=1.e-6):
891        '''
892        Fourier transform the mo coefficients to real space and check if it is real
893        '''
894
895        assert self.U_matrix is not None, "You must wannierize first, then you can run this function"
896        eigenvecs_in_window = []
897        for k_id in range(self.num_kpts_loc):
898            mo_included = self.mo_coeff_kpts[k_id][:,self.band_included_list]
899            orbs_in_win = self.lwindow[k_id]
900            mo_in_window = mo_included[:, orbs_in_win].dot(self.U_matrix_opt[k_id][ :, orbs_in_win].T)
901            eigenvecs_in_window.append(mo_in_window)
902
903        # Rotate the mo(kpts) into localized basis
904        rotated_mo_coeff_kpts = lib.einsum('kum,ksm->kus', eigenvecs_in_window, self.U_matrix)
905
906        # Fourier transform the localized mo
907        nkx, nky, nkz = self.mp_grid_loc
908        Ts = lib.cartesian_prod((np.arange(nkx), np.arange(nky), np.arange(nkz)))
909        nkpts = self.kpt_latt_loc.shape[0]
910        phase = 1/np.sqrt(nkpts) * np.exp(1j* 2*np.pi * np.dot(Ts, self.kpt_latt_loc.T))
911        mo_coeff_Rs = lib.einsum('k,kus,Rk->Rus', phase[0], rotated_mo_coeff_kpts, phase.conj())
912
913        return mo_coeff_Rs.imag.max() < threshold
914
915    def export_unk(self, grid=[50,50,50]):
916        '''
917        Export the periodic part of BF in a real space grid for plotting with wannier90
918        '''
919
920        grids_coor, weights = periodic_grid(self.cell, grid, order = 'F')
921
922        for k_id in range(self.num_kpts_loc):
923            if self.spin_up:
924                spin = '.1'
925            else:
926                spin = '.2'
927            kpt = self.cell.get_abs_kpts(self.kpt_latt_loc[k_id])
928            ao = numint.eval_ao(self.cell, grids_coor, kpt = kpt)
929            u_ao = lib.einsum('x,xi->xi', np.exp(-1j*np.dot(grids_coor, kpt)), ao)
930            unk_file = FortranFile('UNK' + "%05d" % (k_id + 1) + spin, 'w')
931            unk_file.write_record(np.asarray([grid[0], grid[1], grid[2], k_id + 1, self.num_bands_loc], dtype = np.int32))
932            mo_included = self.mo_coeff_kpts[k_id][:,self.band_included_list]
933            u_mo = lib.einsum('xi,in->xn', u_ao, mo_included)
934            for band in range(len(self.band_included_list)):
935                unk_file.write_record(np.asarray(u_mo[:,band], dtype = np.complex128))
936            unk_file.close()
937
938    def export_AME(self, grid=[50,50,50]):
939        r'''
940        Export A_{m,n}^{\mathbf{k}} and M_{m,n}^{(\mathbf{k,b})} and \epsilon_{n}^(\mathbf{k})
941        '''
942
943        if self.A_matrix_loc is None:
944            self.make_win()
945            self.setup()
946            self.M_matrix_loc = self.get_M_mat()
947            self.A_matrix_loc = self.get_A_mat()
948            self.eigenvalues_loc = self.get_epsilon_mat()
949            self.export_unk(self, grid = grid)
950
951        with open('wannier90.mmn', 'w') as f:
952            f.write('Generated by the pyWannier90. Date: %s\n' % (time.ctime()))
953            f.write('    %d    %d    %d\n' % (self.num_bands_loc, self.num_kpts_loc, self.nntot_loc))
954
955            for k_id in range(self.num_kpts_loc):
956                for nn in range(self.nntot_loc):
957                    k_id1 = k_id + 1
958                    k_id2 = self.nn_list[nn, k_id, 0]
959                    nnn, nnm, nnl = self.nn_list[nn, k_id, 1:4]
960                    f.write('    %d  %d    %d  %d  %d\n' % (k_id1, k_id2, nnn, nnm, nnl))
961                    for m in range(self.num_bands_loc):
962                        for n in range(self.num_bands_loc):
963                            f.write('    %22.18e  %22.18e\n' % (self.M_matrix_loc[k_id, nn,m,n].real,
964                                                                self.M_matrix_loc[k_id, nn,m,n].imag))
965
966        with open('wannier90.amn', 'w') as f:
967            f.write('Generated by the pyWannier90. Date: %s\n' % (time.ctime()))
968            f.write('    %d    %d    %d\n' % (self.num_bands_loc, self.num_kpts_loc, self.num_wann_loc))
969
970            for k_id in range(self.num_kpts_loc):
971                for ith_wann in range(self.num_wann_loc):
972                    for band in range(self.num_bands_loc):
973                        f.write('    %d    %d    %d    %22.18e    %22.18e\n' %
974                                (band+1, ith_wann+1, k_id+1,
975                                 self.A_matrix_loc[k_id,ith_wann,band].real,
976                                 self.A_matrix_loc[k_id,ith_wann,band].imag))
977
978        with open('wannier90.eig', 'w') as f:
979            for k_id in range(self.num_kpts_loc):
980                for band in range(self.num_bands_loc):
981                    f.write('    %d    %d    %22.18e\n' % (band+1, k_id+1, self.eigenvalues_loc[k_id,band]))
982
983    def get_wannier(self, supercell=[1,1,1], grid=[50,50,50]):
984        '''
985        Evaluate the MLWF using a periodic grid
986        '''
987
988        grids_coor, weights = periodic_grid(self.cell, grid, supercell = [1,1,1], order = 'C')
989        kpts = self.cell.get_abs_kpts(self.kpt_latt_loc)
990
991        u_mo  = []
992        for k_id in range(self.num_kpts_loc):
993            mo_included = self.mo_coeff_kpts[k_id][:,self.band_included_list]
994            mo_in_window = self.lwindow[k_id]
995            C_opt = mo_included[:,mo_in_window].dot(self.U_matrix_opt[k_id][ :, mo_in_window].T)
996            C_tildle = C_opt.dot(self.U_matrix[k_id].T)
997            kpt = kpts[k_id]
998            ao = numint.eval_ao(self.cell, grids_coor, kpt = kpt)
999            u_ao = lib.einsum('x,xi->xi', np.exp(-1j*np.dot(grids_coor, kpt)), ao)
1000            u_mo.append(lib.einsum('xi,in->xn', u_ao, C_tildle))
1001
1002        u_mo = np.asarray(u_mo)
1003        WF0 = libwannier90.get_WF0s(self.kpt_latt_loc.shape[0],self.kpt_latt_loc, supercell, grid, u_mo)
1004
1005        # Fix the global phase following the pw2wannier90 procedure
1006        max_index = (WF0*WF0.conj()).real.argmax(axis=0)
1007        norm_wfs = np.diag(WF0[max_index,:])
1008        norm_wfs = norm_wfs/np.absolute(norm_wfs)
1009        WF0 = WF0/norm_wfs/self.num_kpts_loc
1010
1011        # Check the 'reality' following the pw2wannier90 procedure
1012        for WF_id in range(self.num_wann_loc):
1013            ratio_max = np.abs(WF0[np.abs(WF0[:,WF_id].real) >= 0.01,WF_id].imag /
1014                               WF0[np.abs(WF0[:,WF_id].real) >= 0.01,WF_id].real).max(axis=0)
1015            print('The maximum imag/real for wannier function ', WF_id,' : ', ratio_max)
1016        return WF0
1017
1018    def get_guess_orb(self, frac_site=[0,0,0], l=0, mr=1, r=1,
1019                      zona=1.0, x_axis=[1,0,0], z_axis=[0,0,1],
1020                      supercell=[1,1,1], grid=[50,50,50]):
1021        '''
1022        Evaluate a guess orbital using a periodic uniform grid
1023        '''
1024        grids_coor, weights = periodic_grid(self.cell, grid, supercell = supercell, order = 'C')
1025        frac_site = np.asarray(frac_site)
1026        abs_site = frac_site.dot(self.real_lattice_loc) / param.BOHR
1027        gr = g_r(grids_coor, abs_site, l, mr, r, zona, x_axis, z_axis, unit = 'B')
1028        return gr
1029
1030    def plot_wf(self, outfile='MLWF', wf_list=None, supercell=[1,1,1], grid=[50,50,50]):
1031        '''
1032        Export Wannier function at cell R
1033        xsf format: http://web.mit.edu/xcrysden_v1.5.60/www/XCRYSDEN/doc/XSF.html
1034        Attributes:
1035            wf_list        : a list of MLWFs to plot
1036            supercell    : a supercell used for plotting
1037        '''
1038
1039        if wf_list is None:
1040            wf_list = list(range(self.num_wann_loc))
1041
1042        grid = np.asarray(grid)
1043        origin = np.asarray([-(grid[i]*(supercell[i]//2) + 1)/grid[i] for i in range(3)]).dot(
1044            self.cell.lattice_vectors().T)* param.BOHR
1045        real_lattice_loc = (grid*supercell-1)/grid * self.cell.lattice_vectors() * param.BOHR
1046        nx, ny, nz = grid*supercell
1047        WF0 = self.get_wannier(supercell = supercell, grid = grid)
1048
1049
1050        for wf_id in wf_list:
1051            assert wf_id in list(range(self.num_wann_loc))
1052            WF = WF0[:,wf_id].reshape(nx,ny,nz).real
1053
1054            with open(outfile + '-' + str(wf_id) + '.xsf', 'w') as f:
1055                f.write('Generated by the pyWannier90. Date: %s\n\n' % (time.ctime()))
1056                f.write('CRYSTAL\n')
1057                f.write('PRIMVEC\n')
1058                for row in range(3):
1059                    f.write('%10.7f  %10.7f  %10.7f\n' %
1060                            (self.real_lattice_loc[row,0], self.real_lattice_loc[row,1],
1061                             self.real_lattice_loc[row,2]))
1062                f.write('CONVVEC\n')
1063                for row in range(3):
1064                    f.write('%10.7f  %10.7f  %10.7f\n' %
1065                            (self.real_lattice_loc[row,0], self.real_lattice_loc[row,1],
1066                             self.real_lattice_loc[row,2]))
1067                f.write('PRIMCOORD\n')
1068                f.write('%3d %3d\n' % (self.num_atoms_loc, 1))
1069                for atom in range(len(self.atom_symbols_loc)):
1070                    f.write('%s  %7.7f  %7.7f  %7.7f\n' %
1071                            (self.atom_symbols_loc[atom], self.atoms_cart_loc[atom][0],
1072                             self.atoms_cart_loc[atom][1], self.atoms_cart_loc[atom][2]))
1073                f.write('\n\n')
1074                f.write('BEGIN_BLOCK_DATAGRID_3D\n3D_field\nBEGIN_DATAGRID_3D_UNKNOWN\n')
1075                f.write('   %5d     %5d  %5d\n' % (nx, ny, nz))
1076                f.write('   %10.7f  %10.7f  %10.7f\n' % (origin[0],origin[1],origin[2]))
1077                for row in range(3):
1078                    f.write('   %10.7f  %10.7f  %10.7f\n' %
1079                            (real_lattice_loc[row,0], real_lattice_loc[row,1], real_lattice_loc[row,2]))
1080
1081                fmt = ' %13.5e' * nx + '\n'
1082                for iz in range(nz):
1083                    for iy in range(ny):
1084                        f.write(fmt % tuple(WF[:,iy,iz].tolist()))
1085                f.write('END_DATAGRID_3D\nEND_BLOCK_DATAGRID_3D')
1086
1087    def plot_guess_orbs(self, outfile='guess_orb', frac_site=[0,0,0], l=0, mr=1, r=1,
1088                        zona=1.0, x_axis=[1,0,0], z_axis=[0,0,1],
1089                        supercell=[1,1,1], grid=[50,50,50]):
1090        '''
1091        Export Wannier function at cell R
1092        xsf format: http://web.mit.edu/xcrysden_v1.5.60/www/XCRYSDEN/doc/XSF.html
1093        Attributes:
1094            wf_list        : a list of MLWFs to plot
1095            supercell    : a supercell used for plotting
1096        '''
1097
1098        grid = np.asarray(grid)
1099        origin = np.asarray([-grid[i]*(supercell[i]//2)/grid[i] for i in range(3)]).dot(
1100            self.cell.lattice_vectors().T) * param.BOHR
1101        real_lattice_loc = (grid*supercell-1)/grid * self.cell.lattice_vectors() * param.BOHR
1102        nx, ny, nz = grid*supercell
1103        guess_orb = self.get_guess_orb(frac_site=frac_site, l=l, mr=mr, r=r,
1104                                       zona=zona, x_axis=x_axis, z_axis=z_axis,
1105                                       supercell=supercell, grid=grid)
1106        guess_orb = guess_orb.reshape(nx,ny,nz).real
1107
1108        with open(outfile + '.xsf', 'w') as f:
1109            f.write('Generated by the pyWannier90\n\n')
1110            f.write('CRYSTAL\n')
1111            f.write('PRIMVEC\n')
1112            for row in range(3):
1113                f.write('%10.7f  %10.7f  %10.7f\n' % (self.real_lattice_loc[row,0], self.real_lattice_loc[row,1],
1114                                                      self.real_lattice_loc[row,2]))
1115            f.write('CONVVEC\n')
1116            for row in range(3):
1117                f.write('%10.7f  %10.7f  %10.7f\n' % (self.real_lattice_loc[row,0], self.real_lattice_loc[row,1],
1118                                                      self.real_lattice_loc[row,2]))
1119            f.write('PRIMCOORD\n')
1120            f.write('%3d %3d\n' % (self.num_atoms_loc, 1))
1121            for atom in range(len(self.atom_symbols_loc)):
1122                f.write('%s  %7.7f  %7.7f  %7.7f\n' % (self.atom_symbols_loc[atom], self.atoms_cart_loc[atom][0],
1123                                                       self.atoms_cart_loc[atom][1], self.atoms_cart_loc[atom][2]))
1124            f.write('\n\n')
1125            f.write('BEGIN_BLOCK_DATAGRID_3D\n3D_field\nBEGIN_DATAGRID_3D_UNKNOWN\n')
1126            f.write('   %5d     %5d  %5d\n' % (nx, ny, nz))
1127            f.write('   %10.7f  %10.7f  %10.7f\n' % (origin[0],origin[1],origin[2]))
1128            for row in range(3):
1129                f.write('   %10.7f  %10.7f  %10.7f\n' % (real_lattice_loc[row,0], real_lattice_loc[row,1],
1130                                                         real_lattice_loc[row,2]))
1131
1132            fmt = ' %13.5e' * nx + '\n'
1133            for iz in range(nz):
1134                for iy in range(ny):
1135                    f.write(fmt % tuple(guess_orb[:,iy,iz].tolist()))
1136            f.write('END_DATAGRID_3D\nEND_BLOCK_DATAGRID_3D')
1137
if __name__ == '__main__':
    # Example: Wannierize bulk silicon (PBE, 3x1x1 k-mesh) with Si:sp3 projections.
    from pyscf.pbc import gto as pgto
    from pyscf.pbc import scf as pscf

    # build cell object
    cell = pgto.Cell()
    cell.a = [[0.0, 2.7155, 2.7155], [2.7155, 0.0, 2.7155], [2.7155, 2.7155, 0.0]]
    cell.atom = [['Si',[0.0,0.0,0.0]], ['Si',[1.35775, 1.35775, 1.35775]]]
    cell.basis = 'gth-dzv'
    cell.pseudo = 'gth-pade'
    cell.exp_to_discard = 0.1
    cell.build()

    # build and run scf object
    kmesh = [3, 1, 1]
    kpts = cell.make_kpts(kmesh)
    kmf = pscf.KKS(cell, kpts)
    kmf.xc = 'pbe'
    kmf.run()

    # build and run w90 object
    num_wann = 8
    keywords = '''
    begin projections
    Si:sp3
    end projections
    '''
    # W90 is defined in this module. Use it directly instead of re-importing this
    # file as a second module named ``pywannier90``, which would re-execute the
    # module top level and only works when the script's directory is on sys.path.
    w90 = W90(kmf, cell, kmesh, num_wann, other_keywords=keywords)
    w90.kernel()
1168