1"""Tetrahedron method for Brillouin-zone integrations.
2
3See::
4
5    Improved tetrahedron method for Brillouin-zone integrations.
6
7    Peter E. Blöchl, O. Jepsen, and O. K. Andersen
8    Phys. Rev. B 49, 16223 – Published 15 June 1994
9
10    DOI:https://doi.org/10.1103/PhysRevB.49.16223
11"""
12
13from math import nan
14from typing import List, Tuple, cast
15import numpy as np
16from scipy.spatial import Delaunay
17
18from gpaw.occupations import (ZeroWidth, findroot, collect_eigelvalues,
19                              distribute_occupation_numbers,
20                              OccupationNumberCalculator, ParallelLayout)
21from gpaw.mpi import broadcast_float
22from gpaw.typing import Array1D, Array2D, Array3D, ArrayLike1D, ArrayLike2D
23
24
def bja1(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D
         ) -> Tuple[float, Array1D]:
    """Occupation and its derivative for tetrahedra of case 1.

    Implements Eq. (A2) and (C2) from Blöchl, Jepsen and Andersen with all
    corner energies measured relative to the Fermi level.  In count() this
    case is used when only the lowest corner lies below the Fermi level.

    Returns the occupation summed over all given tetrahedra and the
    per-tetrahedron derivative with respect to the Fermi level.
    """
    inv_volume = 1.0 / ((e2 - e1) * (e3 - e1) * (e4 - e1))
    occupation = -(e1**3).dot(inv_volume)
    slope = 3 * e1**2 * inv_volume
    return occupation, slope
31
32
def bja2(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D
         ) -> Tuple[float, Array1D]:
    """Occupation and its derivative for tetrahedra of case 2.

    Implements Eq. (A3) and (C3) from Blöchl, Jepsen and Andersen with all
    corner energies measured relative to the Fermi level.  In count() this
    case is used when the two lowest corners lie below the Fermi level.

    Returns the occupation summed over all given tetrahedra and the
    per-tetrahedron derivative with respect to the Fermi level.
    """
    prefactor = 1.0 / ((e3 - e1) * (e4 - e1))
    cubic = (e3 - e1 + e4 - e2) / ((e3 - e2) * (e4 - e2))
    gap21 = e2 - e1
    polynomial = gap21**2 - 3 * gap21 * e2 + 3 * e2**2 + cubic * e2**3
    slope = prefactor * (3 * gap21 - 6 * e2 - 3 * cubic * e2**2)
    return prefactor.dot(polynomial), slope
45
46
def bja3(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D
         ) -> Tuple[float, Array1D]:
    """Occupation and its derivative for tetrahedra of case 3.

    Implements Eq. (A4) and (C4) from Blöchl, Jepsen and Andersen with all
    corner energies measured relative to the Fermi level.  In count() this
    case is used when only the highest corner lies above the Fermi level.

    Returns the occupation summed over all given tetrahedra and the
    per-tetrahedron derivative with respect to the Fermi level.
    """
    inv_volume = 1.0 / ((e4 - e1) * (e4 - e2) * (e4 - e3))
    # Each tetrahedron is full (1) minus the empty corner region near e4.
    occupation = len(e1) - inv_volume.dot(e4**3)
    slope = 3 * inv_volume * e4**2
    return occupation, slope
53
54
def bja1b(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D) -> Array2D:
    """Corner integration weights for tetrahedra of case 1.

    Implements Eq. (B2)-(B5) from Blöchl, Jepsen and Andersen (energies
    relative to the Fermi level).  Returns an array with one row per
    corner (4 rows) and one column per tetrahedron.
    """
    scale = -0.25 * e1**3 / ((e2 - e1) * (e3 - e1) * (e4 - e1))
    w2, w3, w4 = (-scale * e1 / (upper - e1) for upper in (e2, e3, e4))
    # The four weights add up to 4 * scale, the occupied fraction.
    w1 = 4 * scale - w2 - w3 - w4
    return np.array([w1, w2, w3, w4])
63
64
def bja2b(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D) -> Array2D:
    """Corner integration weights for tetrahedra of case 2.

    Implements Eq. (B7)-(B10) from Blöchl, Jepsen and Andersen (energies
    relative to the Fermi level).  Returns an array with one row per
    corner (4 rows) and one column per tetrahedron.
    """
    d31 = e3 - e1
    d41 = e4 - e1
    d32 = e3 - e2
    d42 = e4 - e2
    c1 = 0.25 * e1**2 / (d41 * d31)
    c2 = 0.25 * e1 * e2 * e3 / (d41 * d32 * d31)
    c3 = 0.25 * e2**2 * e4 / (d42 * d32 * d41)
    c12 = c1 + c2
    c23 = c2 + c3
    c123 = c12 + c3
    return np.array([
        c1 + c12 * e3 / d31 + c123 * e4 / d41,
        c123 + c23 * e3 / d32 + c3 * e4 / d42,
        c12 * e1 / (e1 - e3) - c23 * e2 / d32,
        c123 * e1 / (e1 - e4) + c3 * e2 / (e2 - e4)])
75
76
def bja3b(e1: Array1D, e2: Array1D, e3: Array1D, e4: Array1D) -> Array2D:
    """Corner integration weights for tetrahedra of case 3.

    Implements Eq. (B14)-(B17) from Blöchl, Jepsen and Andersen (energies
    relative to the Fermi level).  Returns an array with one row per
    corner (4 rows) and one column per tetrahedron.
    """
    scale = 0.25 * e4**3 / ((e4 - e1) * (e4 - e2) * (e4 - e3))
    w1, w2, w3 = (0.25 - scale * e4 / gap
                  for gap in (e4 - e1, e4 - e2, e4 - e3))
    # Total weight of a case-3 tetrahedron is 1 - 4 * scale.
    w4 = 1.0 - 4 * scale - w1 - w2 - w3
    return np.array([w1, w2, w3, w4])
85
86
def triangulate_submesh(rcell_cv: Array2D) -> Array3D:
    """Find the 6 tetrahedra.

    The 8 corners (A, B, C), A, B, C in {0, 1}, of one mesh cell spanned
    by the rows of *rcell_cv* are Delaunay-triangulated in Cartesian
    coordinates.  Returns the integer corner coordinates of the resulting
    tetrahedra with shape (6, 4, 3): tetrahedron, corner, coordinate.
    """
    # Corners in the order (0, 0, 0), (0, 0, 1), (0, 1, 0), ..., (1, 1, 1):
    corner_sc = np.indices((2, 2, 2)).reshape((3, 8)).T
    simplex_tq = Delaunay(corner_sc.dot(rcell_cv)).simplices
    tetra_tqc = corner_sc[simplex_tq]

    # Drop degenerate (zero-volume) simplices that Delaunay may return:
    edge_tqc = tetra_tqc[:, 1:] - tetra_tqc[:, :1]
    tetra_tqc = tetra_tqc[np.linalg.det(edge_tqc) != 0]

    assert tetra_tqc.shape == (6, 4, 3)
    return tetra_tqc
100
101
def triangulate_everything(size_c: Array1D,
                           ABC_tqc: Array3D,
                           i_k: Array1D) -> Array3D:
    """Triangulate the whole BZ.

    The tetrahedra of one mesh cell (*ABC_tqc*) are translated to every
    BZ k-point of the periodic mesh of size *size_c*; corner coordinates
    that fall outside the mesh are wrapped around.

    Returns i_ktq ndarray mapping:

    * k: BZ k-point index (0, 1, ...,  nbzk - 1)
    * t: tetrahedron index (0, 1, ..., 5)
    * q: tetrahedron corner index (0, 1, 2, 3)

    to i: IBZ k-point index (0, 1, ...,  nibzk - 1).
    """
    nbzk = cast(int, size_c.prod())
    # Mesh coordinates of every BZ k-point, shape (3, nbzk):
    origin_ck = np.array(np.unravel_index(np.arange(nbzk), size_c))
    # Mesh coordinates of every corner, shape (3, nbzk, 6, 4):
    corner_cktq = (origin_ck[:, :, np.newaxis, np.newaxis] +
                   ABC_tqc.transpose((2, 0, 1))[:, np.newaxis])
    flat_k = np.ravel_multi_index(corner_cktq.reshape((3, -1)),
                                  size_c,
                                  mode='wrap')  # type: ignore
    return i_k[flat_k.reshape((nbzk, 6, 4))]
125
126
class TetrahedronMethod(OccupationNumberCalculator):
    """Occupation numbers from the tetrahedron method.

    Every cell of the Monkhorst-Pack mesh is split into 6 tetrahedra and
    the Brillouin-zone integrals are done with the formulas of Blöchl,
    Jepsen and Andersen, optionally including their "improved" correction
    to the integration weights.
    """
    name = 'tetrahedron-method'
    # NOTE(review): always 0.0 for this method; presumably the base class
    # uses it to extrapolate energies for smearing methods -- confirm.
    extrapolate_factor = 0.0

    def __init__(self,
                 rcell: ArrayLike2D,
                 size: ArrayLike1D,
                 improved=False,
                 bz2ibzmap: ArrayLike1D = None,
                 parallel_layout: ParallelLayout = None):
        """Tetrahedron method for calculating occupation numbers.

        The reciprocal cell, *rcell*, can be given in arbitrary units
        (only the shape matters) and *size* is the size of the
        Monkhorst-Pack grid.  If k-points have been symmetry-reduced
        the *bz2ibzmap* parameter mapping BZ k-point indices to
        IBZ k-point indices must be given.  Set *improved* to True to
        add Blöchl's correction to the integration weights.
        """

        OccupationNumberCalculator.__init__(self, parallel_layout)

        self.rcell_cv = np.asarray(rcell)
        self.size_c = np.asarray(size)
        self.improved = improved

        nbzk = self.size_c.prod()

        # Without symmetry reduction each BZ k-point is its own IBZ point.
        if bz2ibzmap is None:
            bz2ibzmap = np.arange(nbzk)

        self.i_k = np.asarray(bz2ibzmap)

        assert self.size_c.shape == (3,)
        assert self.rcell_cv.shape == (3, 3)
        assert self.i_k.shape == (nbzk,)

        # Tetrahedra of a single mesh cell (cell vectors divided by the
        # number of mesh points along each direction):
        ABC_tqc = triangulate_submesh(
            self.rcell_cv / self.size_c[:, np.newaxis])

        # IBZ index of every corner of every tetrahedron in the BZ:
        self.i_ktq = triangulate_everything(self.size_c, ABC_tqc, self.i_k)

        self.nibzkpts = self.i_k.max() + 1

    def __repr__(self):
        # *improved* is only shown when set: the default repr stays the
        # same, but improved and plain instances can be told apart.
        improved = f'improved={self.improved!r}, ' if self.improved else ''
        return (
            'TetrahedronMethod('
            f'rcell={self.rcell_cv.tolist()}, '
            f'size={self.size_c.tolist()}, '
            f'{improved}'
            f'bz2ibzmap={self.i_k.tolist()}, '
            'parallel_layout=<'
            f'{self.bd.comm.size}x{self.kpt_comm.size}x{self.domain_comm.size}'
            '>)')

    def copy(self,
             parallel_layout: ParallelLayout = None,
             bz2ibzmap: List[int] = None
             ) -> OccupationNumberCalculator:
        """Return a new TetrahedronMethod with the same mesh settings.

        *parallel_layout* and *bz2ibzmap* replace the current values when
        given.
        """
        return TetrahedronMethod(
            self.rcell_cv,
            self.size_c,
            self.improved,
            self.i_k if bz2ibzmap is None else bz2ibzmap,
            parallel_layout or self.parallel_layout)

    def _calculate(self,
                   nelectrons,
                   eig_qn,
                   weight_q,
                   f_qn,
                   fermi_level_guess=nan) -> Tuple[float, float]:
        """Find the Fermi level and fill in the occupation numbers *f_qn*.

        Returns the Fermi level and 0.0 (NOTE(review): presumably the
        entropy term of the base-class interface, which vanishes for the
        tetrahedron method -- confirm).
        """
        if np.isnan(fermi_level_guess):
            # Use the zero-width result as starting guess; it also fills
            # f_qn, which is what is returned if it gives an infinite level.
            zero = ZeroWidth(self.parallel_layout)
            fermi_level_guess, _ = zero._calculate(
                nelectrons, eig_qn, weight_q, f_qn)
            if np.isinf(fermi_level_guess):
                return fermi_level_guess, 0.0

        x = fermi_level_guess

        eig_in, weight_i, nkpts_r = collect_eigelvalues(eig_qn, weight_q,
                                                        self.bd, self.kpt_comm)

        # Only ranks that received the collected eigenvalues do the work.
        if eig_in.size != 0:
            # One row per IBZ k-point, or two when spin-polarized
            # (the spin channels are taken as eig_in[::2] and eig_in[1::2]).
            if len(eig_in) == self.nibzkpts:
                nspins = 1
            else:
                nspins = 2
                assert len(eig_in) == 2 * self.nibzkpts

            def func(x, eig_in=eig_in):
                """Return excess electrons and derivative."""
                if nspins == 1:
                    n, dn = count(x, eig_in, self.i_ktq)
                else:
                    n1, dn1 = count(x, eig_in[::2], self.i_ktq)
                    n2, dn2 = count(x, eig_in[1::2], self.i_ktq)
                    n = n1 + n2
                    dn = dn1 + dn2
                return n - nelectrons, dn

            fermi_level, niter = findroot(func, x)

            def w(de_in):
                return weights(de_in, self.i_ktq, self.improved)

            if nspins == 1:
                f_in = w(eig_in - fermi_level)
            else:
                f_in = np.zeros_like(eig_in)
                f_in[::2] = w(eig_in[::2] - fermi_level)
                f_in[1::2] = w(eig_in[1::2] - fermi_level)

            # weights() sums over all BZ k-points mapping to an IBZ point;
            # divide by k-point weight times number of BZ k-points.
            f_in *= 1 / (weight_i[:, np.newaxis] * len(self.i_k))
        else:
            f_in = None
            fermi_level = nan

        distribute_occupation_numbers(f_in, f_qn, nkpts_r,
                                      self.bd, self.kpt_comm)

        # Share the Fermi level with all ranks.
        if self.kpt_comm.rank == 0:
            fermi_level = broadcast_float(fermi_level, self.bd.comm)
        fermi_level = broadcast_float(fermi_level, self.kpt_comm)

        return fermi_level, 0.0
252
253
def count(fermi_level: float,
          eig_in: Array2D,
          i_ktq: Array3D) -> Tuple[float, float]:
    """Count electrons.

    *eig_in* holds eigenvalues per IBZ k-point (rows) and band (columns).
    Bands that are below *fermi_level* at every k-point count fully.  The
    remaining partially filled bands are integrated tetrahedron by
    tetrahedron with bja1/bja2/bja3.

    Return number of electrons and derivative with respect to fermi level.
    """
    de_in = eig_in - fermi_level
    below_i = (de_in < 0.0).sum(axis=1)
    nlow = below_i.min()
    nhigh = below_i.max()

    if nlow == nhigh:
        return nlow, 0.0

    ntetra = 6 * i_ktq.shape[0]
    nbands = nhigh - nlow
    corners_Tq = np.sort(
        de_in[i_ktq, nlow:nhigh].transpose((0, 1, 3, 2)).reshape(
            (ntetra * nbands, 4)),
        axis=1)

    # Tetrahedra with no corner below the Fermi level contribute nothing.
    corners_Tq = corners_Tq[corners_Tq[:, 0] < 0.0]

    one_T = corners_Tq[:, 1] > 0.0
    two_T = ~one_T & (corners_Tq[:, 2] > 0.0)
    three_T = ~(one_T | two_T) & (corners_Tq[:, 3] > 0.0)
    full_T = ~(one_T | two_T | three_T)

    electrons = nlow
    slope = 0.0
    for formula, select_T in ((bja1, one_T), (bja2, two_T), (bja3, three_T)):
        n, dn_T = formula(*corners_Tq[select_T].T)
        electrons += n / ntetra
        slope += dn_T.sum() / ntetra  # type: ignore

    electrons += full_T.sum() / ntetra
    return electrons, slope
292
293
def weights(eig_in: Array2D, i_ktq: Array3D, improved=False) -> Array2D:
    """Calculate occupation numbers.

    eig_in:
        Eigenvalues relative to the Fermi level, one row per IBZ k-point
        and one column per band.  Presumably sorted along each row; the
        n1/n2 band split below relies on it.
    i_ktq:
        IBZ k-point index of corner q of tetrahedron t attached to BZ
        k-point k, as returned by triangulate_everything().
    improved:
        Add Blöchl's correction to the tetrahedron weights.

    Returns f_in with the same shape as eig_in.  Occupations are summed
    over all BZ k-points that map to IBZ k-point i, so a band that is
    fully occupied everywhere gets the number of such BZ k-points.
    """
    nocc_i = (eig_in < 0.0).sum(axis=1)
    n1 = nocc_i.min()  # bands below n1 are occupied at all k-points
    n2 = nocc_i.max()  # bands from n2 on are empty at all k-points

    f_in = np.zeros_like(eig_in)

    # For each BZ k-point add 6 (undone by the final 1/6) to the fully
    # occupied bands of its IBZ k-point.  np.add.at on the view updates
    # f_in in place and accumulates repeated IBZ indices.
    np.add.at(f_in[:, :n1], i_ktq[:, 0, 0], 6.0)

    if n1 == n2:
        return f_in / 6

    nk = len(i_ktq)
    nb = n2 - n1
    ntetra = 6 * nk
    # Row T of eig_Tq is (k, t, n) in C order; columns are the 4 corners.
    eig_Tq = eig_in[i_ktq, n1:n2].transpose((0, 1, 3, 2)).reshape(
        (ntetra * nb, 4))
    # Sort corners by energy but remember their original corner index:
    q_Tq = eig_Tq.argsort(axis=1)
    eig_Tq = np.take_along_axis(eig_Tq, q_Tq, 1)
    f_Tq = np.zeros_like(eig_Tq)

    # Classify tetrahedra by how many corners lie below the Fermi level.
    mask0_T = eig_Tq[:, 0] > 0.0  # no corner below: empty
    mask1_T = ~mask0_T & (eig_Tq[:, 1] > 0.0)
    mask2_T = ~mask0_T & ~mask1_T & (eig_Tq[:, 2] > 0.0)
    mask3_T = ~mask0_T & ~mask1_T & ~mask2_T & (eig_Tq[:, 3] > 0.0)

    for mask_T, bjab in [(mask1_T, bja1b), (mask2_T, bja2b), (mask3_T, bja3b)]:
        w_qT = bjab(*eig_Tq[mask_T].T)
        f_Tq[mask_T] += w_qT.T

    if improved:
        # Correction: d_T / 40 * sum_j (e_j - e_q) for every corner q.
        for mask_T, bja in [(mask1_T, bja1),
                            (mask2_T, bja2),
                            (mask3_T, bja3)]:
            e_Tq = eig_Tq[mask_T]
            _, d_T = bja(*e_Tq.T)
            f_Tq[mask_T] += (d_T * (e_Tq.sum(1) - 4 * e_Tq.T)).T / 40

    # All corners below the Fermi level: completely filled.
    mask4_T = ~mask0_T & ~mask1_T & ~mask2_T & ~mask3_T
    f_Tq[mask4_T] += 0.25

    # Scatter the corner weights back to (IBZ k-point, band) in one
    # vectorized step.  f_Tq[T, j] belongs to original corner q_Tq[T, j].
    k_T, t_T, n_T = np.unravel_index(np.arange(len(eig_Tq)), (nk, 6, nb))
    i_Tq = i_ktq[k_T[:, np.newaxis], t_T[:, np.newaxis], q_Tq]
    # Unlike "f_in[...] += f_Tq", np.add.at accumulates repeated indices.
    np.add.at(f_in, (i_Tq, n1 + n_T[:, np.newaxis]), f_Tq)

    f_in *= 1 / 6

    return f_in
344