1# Copyright (C) 2013 Atsushi Togo
2# All rights reserved.
3#
4# This file is part of phonopy.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions
8# are met:
9#
10# * Redistributions of source code must retain the above copyright
11#   notice, this list of conditions and the following disclaimer.
12#
13# * Redistributions in binary form must reproduce the above copyright
14#   notice, this list of conditions and the following disclaimer in
15#   the documentation and/or other materials provided with the
16#   distribution.
17#
18# * Neither the name of the phonopy project nor the names of its
19#   contributors may be used to endorse or promote products derived
20#   from this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33# POSSIBILITY OF SUCH DAMAGE.
34
35import numpy as np
36try:
37    import phonopy._phonopy as phonoc
38except ImportError:
39    import sys
40    print("Phonopy C-extension has to be built properly.")
41    sys.exit(1)
42
# Corners of the unit parallelepiped in units of the microzone lattice
# vectors.  Corner i has coordinates (bit0, bit1, bit2) of i, which gives the
# numbering 0: O, 1: a, 2: b, 3: a + b, 4: c, 5: c + a, 6: c + b,
# 7: c + a + b used by TetrahedronMethod._create_tetrahedra.
parallelepiped_vertices = np.array(
    [[(corner >> axis) & 1 for axis in range(3)] for corner in range(8)],
    dtype='int_', order='C')
51
52
def get_tetrahedra_relative_grid_address(microzone_lattice):
    """Return relative grid addresses of the tetrahedra around a grid point.

    The addresses are differences of grid addresses, i.e. the offsets of the
    four vertices of each of the 24 tetrahedra from the central grid point.
    They are computed by the C extension.

    Parameters
    ----------
    microzone_lattice : ndarray or list of list
        Column vectors of the parallelepiped microzone lattice, i.e.,
        ``microzone_lattice = np.linalg.inv(cell.get_cell()) / mesh``.

    Returns
    -------
    ndarray
        shape=(24, 4, 3), dtype='int_', C-contiguous.

    """
    lattice = np.array(microzone_lattice, dtype='double', order='C')
    addresses = np.zeros((24, 4, 3), dtype='int_', order='C')
    phonoc.tetrahedra_relative_grid_address(addresses, lattice)
    return addresses
70
71
def get_all_tetrahedra_relative_grid_address():
    """Return the full relative grid address dataset from the C extension.

    This exists only for the test.

    Returns
    -------
    ndarray
        shape=(4, 24, 4, 3), dtype='int_'.  NOTE(review): the leading axis
        presumably enumerates the four main-diagonal choices (cf.
        TetrahedronMethod._create_tetrahedra) -- confirm against the C code.

    """
    dataset = np.zeros((4, 24, 4, 3), dtype='int_')
    phonoc.all_tetrahedra_relative_grid_address(dataset)
    return dataset
82
83
def get_tetrahedra_integration_weight(omegas,
                                      tetrahedra_omegas,
                                      function='I'):
    """Return tetrahedron-method integration weight(s).

    Parameters
    ----------
    omegas : float or sequence of float
        Energy (or energies) at which the integration weight(s) are
        computed.  Any scalar -- Python int/float, NumPy scalar or 0-d
        array -- is treated as a single energy.
    tetrahedra_omegas : ndarray or list of list
        Energies at vertices of 24 tetrahedra
        shape=(24, 4)
        dtype='double'
    function : str, 'I' or 'J'
        'J' is for integration and 'I' is for its derivative.

    Returns
    -------
    float or ndarray
        A single weight when ``omegas`` is a scalar, otherwise an ndarray of
        ``len(omegas)`` weights (dtype='double').

    """
    tetrahedra_omegas = np.array(tetrahedra_omegas, dtype='double', order='C')

    if np.ndim(omegas) == 0:
        # Scalars of any numeric type take the single-energy path; the C
        # routine is handed a plain Python float.
        return phonoc.tetrahedra_integration_weight(
            float(omegas), tetrahedra_omegas, function)

    integration_weights = np.zeros(len(omegas), dtype='double')
    phonoc.tetrahedra_integration_weight_at_omegas(
        integration_weights,
        np.array(omegas, dtype='double'),
        tetrahedra_omegas,
        function)
    return integration_weights
115
116
class TetrahedronMethod(object):
    """Linear tetrahedron method on one parallelepiped microzone.

    With ``lang='C'`` the 24 relative grid addresses and the integration
    weights come from the C extension.  With any other ``lang`` a pure-Python
    implementation is used.  The microzone is cut into six tetrahedra sharing
    its shortest main diagonal.  For each of the eight corners, the tetrahedra
    containing that corner are collected, which gives 24 tetrahedra.

    Typical use: construct, call ``set_tetrahedra_omegas`` with the (24, 4)
    vertex energies, ``run`` at one or more energies, then read the result
    with ``get_integration_weight``.
    """

    def __init__(self,
                 primitive_vectors=None,  # column vectors
                 mesh=None,
                 lang='C'):
        """Set up the tetrahedra of the microzone.

        Parameters
        ----------
        primitive_vectors : array_like, optional
            shape=(3, 3), column vectors.  Column j is divided by mesh[j] to
            give the microzone lattice.  It is required in practice, even
            though it defaults to None: with None the Python path fails in
            ``_create_tetrahedra``.
            NOTE(review): with ``lang='C'`` a None value reaches the C
            extension as a 0-d NaN array -- confirm how it is handled there.
        mesh : array_like of int, optional
            Sampling mesh, default [1, 1, 1].
        lang : str, optional
            'C' selects the C extension; any other value selects the
            pure-Python implementation.
        """
        if mesh is None:
            mesh = [1, 1, 1]
        if primitive_vectors is None:
            self._primitive_vectors = None
        else:
            self._primitive_vectors = np.array(
                primitive_vectors, dtype='double', order='C') / mesh
        self._lang = lang

        self._vertices = None
        self._relative_grid_addresses = None
        self._central_indices = None
        self._tetrahedra_omegas = None
        self._sort_indices = None
        self._omegas = None
        # Per-evaluation state read by the _f/_n/_g/_I/_J helpers.
        self._omega = None
        self._vertices_omegas = None
        self._integration_weight = None
        self._set_relative_grid_addresses()

    def run(self, omegas, value='I'):
        """Compute integration weight(s) at ``omegas``.

        ``omegas`` is a scalar energy or a sequence of energies.  ``value``
        is 'I' (derivative) or 'J' (integration).  The result is stored and
        returned by ``get_integration_weight``.
        """
        if self._lang == 'C':
            self._run_c(omegas, value=value)
        else:
            self._run_py(omegas, value=value)

    def get_tetrahedra(self):
        """
        Returns relative grid addresses at vertices of tetrahedra,
        shape=(24, 4, 3)
        """
        return self._relative_grid_addresses

    def get_unique_tetrahedra_vertices(self):
        """Return the distinct relative grid addresses of all vertices.

        Addresses are listed in order of first appearance in the flattened
        ``get_tetrahedra()`` array, as an int_ ndarray of shape (n, 3).
        """
        seen = set()
        unique_vertices = []
        for address in self._relative_grid_addresses.reshape(-1, 3):
            # Hashable key for O(1) membership instead of a linear scan.
            key = tuple(address)
            if key not in seen:
                seen.add(key)
                unique_vertices.append(address)
        return np.array(unique_vertices, dtype='int_', order='C')

    def set_tetrahedra_omegas(self, tetrahedra_omegas):
        """
        tetrahedra_omegas: (24, 4) omegas at self._relative_grid_addresses
        (ndarray or list of list)
        """
        self._tetrahedra_omegas = tetrahedra_omegas

    def get_integration_weight(self):
        """Return the weight(s) computed by the last ``run`` (None before)."""
        return self._integration_weight

    def _run_c(self, omegas, value='I'):
        """Compute weight(s) with the C extension."""
        if np.ndim(omegas) == 0:
            # Pass scalars as a Python float so that
            # get_tetrahedra_integration_weight takes its single-energy path
            # whatever the scalar's type (int, NumPy scalar, 0-d array).
            omegas = float(omegas)
        self._integration_weight = get_tetrahedra_integration_weight(
            omegas,
            self._tetrahedra_omegas,
            function=value)

    def _run_py(self, omegas, value='I'):
        """Compute weight(s) with the pure-Python implementation."""
        if np.ndim(omegas) == 0:
            iw = self._get_integration_weight_py(omegas, value=value)
        else:
            iw = np.array(
                [self._get_integration_weight_py(omega, value=value)
                 for omega in omegas],
                dtype='double')
        self._integration_weight = iw

    def _get_integration_weight_py(self, omega, value='I'):
        """Return the integration weight at a single energy ``omega``.

        Each of the 24 tetrahedra contributes IJ(interval, c) * gn(interval).
        ``interval`` says where omega lies relative to the tetrahedron's
        sorted vertex energies.  ``c`` is the sorted position of the
        tetrahedron's central vertex.  The sum is divided by 6.  Requires
        ``_central_indices``, which is only set up in the Python path.
        """
        if value == 'I':
            IJ = self._I
            gn = self._g
        else:
            IJ = self._J
            gn = self._n

        # asarray so list-of-lists input also supports omegas[indices] below.
        tetrahedra_omegas = np.asarray(self._tetrahedra_omegas)
        self._sort_indices = np.argsort(tetrahedra_omegas, axis=1)
        self._omega = omega
        sum_value = 0.0
        for omegas, indices, ci in zip(tetrahedra_omegas,
                                       self._sort_indices,
                                       self._central_indices):
            self._vertices_omegas = omegas[indices]
            v = self._vertices_omegas
            # 0: below all vertex energies, 4: above all, 1-3: strictly
            # between consecutive sorted vertex energies.
            if omega < v[0]:
                interval = 0
            elif v[0] < omega < v[1]:
                interval = 1
            elif v[1] < omega < v[2]:
                interval = 2
            elif v[2] < omega < v[3]:
                interval = 3
            elif v[3] < omega:
                interval = 4
            else:
                # Comparisons are strict, so an omega equal to a vertex
                # energy gives no contribution from this tetrahedron.
                continue
            # Position of the central vertex after sorting the energies.
            ci_sorted = np.where(indices == ci)[0][0]
            sum_value += IJ(interval, ci_sorted) * gn(interval)

        return sum_value / 6

    def _create_tetrahedra(self):
        """Split the parallelepiped into six tetrahedra.

        The split uses the shortest of the four main diagonals.  Each
        tetrahedron is stored in ``_vertices`` as a sorted row of four
        corner indices, giving shape (6, 4).
        """
        #
        #     6-------7
        #    /|      /|
        #   / |     / |
        #  4-------5  |
        #  |  2----|--3
        #  | /     | /
        #  |/      |/
        #  0-------1
        #
        # i: vec        neighbours
        # 0: O          1, 2, 4
        # 1: a          0, 3, 5
        # 2: b          0, 3, 6
        # 3: a + b      1, 2, 7
        # 4: c          0, 5, 6
        # 5: c + a      1, 4, 7
        # 6: c + b      2, 4, 7
        # 7: c + a + b  3, 5, 6
        a, b, c = self._primitive_vectors.T
        diag_vecs = np.array([a + b + c,    # 0-7
                              -a + b + c,   # 1-6
                              a - b + c,    # 2-5
                              a + b - c])   # 3-4
        shortest_index = np.argmin(np.sum(diag_vecs ** 2, axis=1))

        # For each main diagonal: its end corners, and the six corner pairs
        # that each complete one tetrahedron around that diagonal.
        splittings = (
            ((0, 7), ((1, 3), (1, 5), (2, 3), (2, 6), (4, 5), (4, 6))),
            ((1, 6), ((0, 2), (0, 4), (2, 3), (3, 7), (4, 5), (5, 7))),
            ((2, 5), ((0, 1), (0, 4), (1, 3), (3, 7), (4, 6), (6, 7))),
            ((3, 4), ((0, 1), (0, 2), (1, 5), (2, 6), (5, 7), (6, 7))),
        )
        diagonal, pairs = splittings[shortest_index]
        self._vertices = np.sort(
            [list(diagonal) + list(pair) for pair in pairs])

    def _set_relative_grid_addresses(self):
        """Fill ``_relative_grid_addresses`` (and central indices in Python).

        Python path: for each corner i, every tetrahedron containing i gives
        one entry.  That entry holds the tetrahedron's corner addresses
        relative to corner i.  ``_central_indices`` records where i sits in
        the tetrahedron's vertex list.
        """
        if self._lang == 'C':
            rga = get_tetrahedra_relative_grid_address(
                self._primitive_vectors)
            self._relative_grid_addresses = rga
            return

        self._create_tetrahedra()
        relative_grid_addresses = np.zeros((24, 4, 3), dtype='int_')
        central_indices = np.zeros(24, dtype='int_')
        pos = 0
        for i in range(8):
            ppd_shifted = (parallelepiped_vertices -
                           parallelepiped_vertices[i])
            for tetra in self._vertices:
                if i in tetra:
                    central_indices[pos] = np.where(tetra == i)[0][0]
                    relative_grid_addresses[pos, :, :] = ppd_shifted[tetra]
                    pos += 1
        self._relative_grid_addresses = relative_grid_addresses
        self._central_indices = central_indices

    # ------------------------------------------------------------------
    # Interval formulas.  v = self._vertices_omegas holds the sorted vertex
    # energies and omega1..omega4 in the docstrings are v[0]..v[3].  In
    # _I_ij/_J_ij, i is the energy interval and j is the sorted position of
    # the central vertex.
    # ------------------------------------------------------------------

    def _f(self, n, m):
        """Return (omega - v[m]) / (v[n] - v[m])."""
        return ((self._omega - self._vertices_omegas[m]) /
                (self._vertices_omegas[n] - self._vertices_omegas[m]))

    def _J(self, i, ci):
        """Dispatch to the J weight for interval i and central position ci."""
        if i == 0:
            return self._J_0()
        elif i == 1:
            if ci == 0:
                return self._J_10()
            elif ci == 1:
                return self._J_11()
            elif ci == 2:
                return self._J_12()
            elif ci == 3:
                return self._J_13()
            else:
                assert False
        elif i == 2:
            if ci == 0:
                return self._J_20()
            elif ci == 1:
                return self._J_21()
            elif ci == 2:
                return self._J_22()
            elif ci == 3:
                return self._J_23()
            else:
                assert False
        elif i == 3:
            if ci == 0:
                return self._J_30()
            elif ci == 1:
                return self._J_31()
            elif ci == 2:
                return self._J_32()
            elif ci == 3:
                return self._J_33()
            else:
                assert False
        elif i == 4:
            return self._J_4()
        else:
            assert False

    def _I(self, i, ci):
        """Dispatch to the I weight for interval i and central position ci."""
        if i == 0:
            return self._I_0()
        elif i == 1:
            if ci == 0:
                return self._I_10()
            elif ci == 1:
                return self._I_11()
            elif ci == 2:
                return self._I_12()
            elif ci == 3:
                return self._I_13()
            else:
                assert False
        elif i == 2:
            if ci == 0:
                return self._I_20()
            elif ci == 1:
                return self._I_21()
            elif ci == 2:
                return self._I_22()
            elif ci == 3:
                return self._I_23()
            else:
                assert False
        elif i == 3:
            if ci == 0:
                return self._I_30()
            elif ci == 1:
                return self._I_31()
            elif ci == 2:
                return self._I_32()
            elif ci == 3:
                return self._I_33()
            else:
                assert False
        elif i == 4:
            return self._I_4()
        else:
            assert False

    def _n(self, i):
        """Dispatch to n for energy interval i (used with J)."""
        if i == 0:
            return self._n_0()
        elif i == 1:
            return self._n_1()
        elif i == 2:
            return self._n_2()
        elif i == 3:
            return self._n_3()
        elif i == 4:
            return self._n_4()
        else:
            assert False

    def _g(self, i):
        """Dispatch to g for energy interval i (used with I)."""
        if i == 0:
            return self._g_0()
        elif i == 1:
            return self._g_1()
        elif i == 2:
            return self._g_2()
        elif i == 3:
            return self._g_3()
        elif i == 4:
            return self._g_4()
        else:
            assert False

    def _n_0(self):
        """omega < omega1"""
        return 0.0

    def _n_1(self):
        """omega1 < omega < omega2"""
        return self._f(1, 0) * self._f(2, 0) * self._f(3, 0)

    def _n_2(self):
        """omega2 < omega < omega3"""
        return (self._f(3, 1) * self._f(2, 1) +
                self._f(3, 0) * self._f(1, 3) * self._f(2, 1) +
                self._f(3, 0) * self._f(2, 0) * self._f(1, 2))

    def _n_3(self):
        """omega3 < omega < omega4"""
        return (1.0 - self._f(0, 3) * self._f(1, 3) * self._f(2, 3))

    def _n_4(self):
        """omega4 < omega"""
        return 1.0

    def _g_0(self):
        """omega < omega1"""
        return 0.0

    def _g_1(self):
        """omega1 < omega < omega2"""
        # Equivalent to 3 * n_1 / (omega - omega1), written without that
        # division.
        return (3 * self._f(1, 0) * self._f(2, 0) /
                (self._vertices_omegas[3] - self._vertices_omegas[0]))

    def _g_2(self):
        """omega2 < omega < omega3"""
        return 3 / (self._vertices_omegas[3] - self._vertices_omegas[0]) * (
            self._f(1, 2) * self._f(2, 0) +
            self._f(2, 1) * self._f(1, 3))

    def _g_3(self):
        """omega3 < omega < omega4"""
        # Equivalent to 3 * (1 - n_3) / (omega4 - omega), written without
        # that division.
        return (3 * self._f(1, 3) * self._f(2, 3) /
                (self._vertices_omegas[3] - self._vertices_omegas[0]))

    def _g_4(self):
        """omega4 < omega"""
        return 0.0

    def _J_0(self):
        return 0.0

    def _J_10(self):
        return (1.0 + self._f(0, 1) + self._f(0, 2) + self._f(0, 3)) / 4

    def _J_11(self):
        return self._f(1, 0) / 4

    def _J_12(self):
        return self._f(2, 0) / 4

    def _J_13(self):
        return self._f(3, 0) / 4

    def _J_20(self):
        return (self._f(3, 1) * self._f(2, 1) +
                self._f(3, 0) * self._f(1, 3) * self._f(2, 1) *
                (1.0 + self._f(0, 3)) +
                self._f(3, 0) * self._f(2, 0) * self._f(1, 2) *
                (1.0 + self._f(0, 3) + self._f(0, 2))) / 4 / self._n_2()

    def _J_21(self):
        return (self._f(3, 1) * self._f(2, 1) *
                (1.0 + self._f(1, 3) + self._f(1, 2)) +
                self._f(3, 0) * self._f(1, 3) * self._f(2, 1) *
                (self._f(1, 3) + self._f(1, 2)) +
                self._f(3, 0) * self._f(2, 0) * self._f(1, 2) *
                self._f(1, 2)) / 4 / self._n_2()

    def _J_22(self):
        return (self._f(3, 1) * self._f(2, 1) *
                self._f(2, 1) +
                self._f(3, 0) * self._f(1, 3) * self._f(2, 1) *
                self._f(2, 1) +
                self._f(3, 0) * self._f(2, 0) * self._f(1, 2) *
                (self._f(2, 1) + self._f(2, 0))) / 4 / self._n_2()

    def _J_23(self):
        return (self._f(3, 1) * self._f(2, 1) *
                self._f(3, 1) +
                self._f(3, 0) * self._f(1, 3) * self._f(2, 1) *
                (self._f(3, 1) + self._f(3, 0)) +
                self._f(3, 0) * self._f(2, 0) * self._f(1, 2) *
                self._f(3, 0)) / 4 / self._n_2()

    def _J_30(self):
        return ((1.0 - self._f(0, 3) ** 2 * self._f(1, 3) * self._f(2, 3)) /
                4 / self._n_3())

    def _J_31(self):
        return ((1.0 - self._f(0, 3) * self._f(1, 3) ** 2 * self._f(2, 3)) /
                4 / self._n_3())

    def _J_32(self):
        return ((1.0 - self._f(0, 3) * self._f(1, 3) * self._f(2, 3) ** 2) /
                4 / self._n_3())

    def _J_33(self):
        return ((1.0 - self._f(0, 3) * self._f(1, 3) * self._f(2, 3) *
                 (1.0 + self._f(3, 0) + self._f(3, 1) + self._f(3, 2))) /
                4 / self._n_3())

    def _J_4(self):
        return 0.25

    def _I_0(self):
        return 0.0

    def _I_10(self):
        return (self._f(0, 1) + self._f(0, 2) + self._f(0, 3)) / 3

    def _I_11(self):
        return self._f(1, 0) / 3

    def _I_12(self):
        return self._f(2, 0) / 3

    def _I_13(self):
        return self._f(3, 0) / 3

    def _I_20(self):
        return (self._f(0, 3) +
                self._f(0, 2) * self._f(2, 0) * self._f(1, 2) /
                (self._f(1, 2) * self._f(2, 0) + self._f(2, 1) * self._f(1, 3))
                ) / 3

    def _I_21(self):
        return (self._f(1, 2) +
                self._f(1, 3) ** 2 * self._f(2, 1) /
                (self._f(1, 2) * self._f(2, 0) + self._f(2, 1) * self._f(1, 3))
                ) / 3

    def _I_22(self):
        return (self._f(2, 1) +
                self._f(2, 0) ** 2 * self._f(1, 2) /
                (self._f(1, 2) * self._f(2, 0) + self._f(2, 1) * self._f(1, 3))
                ) / 3

    def _I_23(self):
        return (self._f(3, 0) +
                self._f(3, 1) * self._f(1, 3) * self._f(2, 1) /
                (self._f(1, 2) * self._f(2, 0) + self._f(2, 1) * self._f(1, 3))
                ) / 3

    def _I_30(self):
        return self._f(0, 3) / 3

    def _I_31(self):
        return self._f(1, 3) / 3

    def _I_32(self):
        return self._f(2, 3) / 3

    def _I_33(self):
        return (self._f(3, 0) + self._f(3, 1) + self._f(3, 2)) / 3

    def _I_4(self):
        return 0.0
568