1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4from __future__ import division
5
6from collections import OrderedDict
7from itertools import combinations, product
8
9import numpy as np
10from numpy import dot, pi
11from numpy.linalg import norm
12
13from . import Math
14from .species_data import get_property
15
# empty: nothing is exported by ``from <this module> import *``
__all__ = ()

#: Conversion factor from angstrom to bohr, 1 / a0 with a0 = 0.52917721092 Å
#: (CODATA 2010 Bohr radius). Cartesian coordinates are multiplied by it before
#: internal coordinates are evaluated, and Cartesian steps are divided by it.
angstrom = 1 / 0.52917721092
19
20
class InternalCoord(object):
    """Base class of the internal coordinates (bonds, angles, dihedrals).

    Subclasses set ``self.idx`` (a tuple of atom indices) before calling this
    initializer. Two coordinates compare equal, and hash equally, exactly when
    their ``idx`` tuples are equal, so a freshly built coordinate can be looked
    up in a set or list of existing ones.
    """

    def __init__(self, C=None):
        # C is the fragment-connectivity matrix (as returned by get_clusters):
        # C[a, b] is truthy when atoms a and b belong to the same fragment.
        # When given, ``weak`` counts the consecutive atom pairs along ``idx``
        # that lie in different fragments.
        if C is not None:
            self.weak = sum(not C[a, b] for a, b in zip(self.idx, self.idx[1:]))

    def __eq__(self, other):
        if not isinstance(other, InternalCoord):
            return NotImplemented
        return self.idx == other.idx

    def __ne__(self, other):
        # Python 2 does not derive != from ==; keep both consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self.idx)

    def __repr__(self):
        args = [str(i) for i in self.idx]
        # ``weak`` is only set when C (or an explicit weak value) was supplied
        weak = getattr(self, 'weak', None)
        if weak is not None:
            args.append('weak=' + str(weak))
        return '{}({})'.format(self.__class__.__name__, ', '.join(args))
39
40
class Bond(InternalCoord):
    """Interatomic distance between atoms i and j.

    The two indices are stored in ascending order so that a bond has a unique
    ``idx`` tuple whichever way round it was specified.
    """

    def __init__(self, i, j, **kwargs):
        first, second = (j, i) if i > j else (i, j)
        self.i = first
        self.j = second
        self.idx = first, second
        InternalCoord.__init__(self, **kwargs)

    def hessian(self, rho):
        """Model force constant 0.45 * rho[i, j] for the diagonal Hessian guess."""
        return 0.45 * rho[self.i, self.j]

    def weight(self, rho, coords):
        """Weight of the bond: just rho[i, j] (``coords`` is not used)."""
        return rho[self.i, self.j]

    def center(self, ijk):
        """Rounded sum of the per-atom offsets ``ijk`` of the two bonded atoms.

        NOTE(review): ijk presumably holds periodic-cell offsets per atom.
        """
        offsets = ijk[[self.i, self.j]]
        return np.round(offsets.sum(0))

    def eval(self, coords, grad=False):
        """Return the bond length, with coordinates scaled by ``angstrom``.

        With ``grad=True`` return ``(r, [u, -u])``, where u is the unit vector
        pointing from atom j to atom i; these are the gradients for atoms i
        and j respectively.
        """
        diff = (coords[self.i] - coords[self.j]) * angstrom
        length = norm(diff)
        if grad:
            unit = diff / length
            return length, [unit, -unit]
        return length
65
66
class Angle(InternalCoord):
    """Bond angle i-j-k with j as the apex atom.

    The two outer indices are stored in ascending order (the apex is never
    moved), so the same angle always produces the same ``idx`` tuple.
    """

    def __init__(self, i, j, k, **kwargs):
        first, last = (k, i) if i > k else (i, k)
        self.i = first
        self.j = j
        self.k = last
        self.idx = first, j, last
        InternalCoord.__init__(self, **kwargs)

    def hessian(self, rho):
        """Model force constant 0.15 * rho[i, j] * rho[j, k]."""
        rho_ij = rho[self.i, self.j]
        rho_jk = rho[self.j, self.k]
        return 0.15 * (rho_ij * rho_jk)

    def weight(self, rho, coords):
        """Weight sqrt(rho[i, j] * rho[j, k]), damped for near-linear angles.

        The damping factor f + (1 - f) * sin(phi) with f = 0.12 drops to f for
        an angle of 0 or pi and reaches 1 at a right angle.
        """
        f = 0.12
        strength = np.sqrt(rho[self.i, self.j] * rho[self.j, self.k])
        damping = f + (1 - f) * np.sin(self.eval(coords))
        return strength * damping

    def center(self, ijk):
        """Rounded twice the offset of the apex atom.

        Doubling puts the result on the same scale as a sum of two per-atom
        offsets.
        """
        apex = ijk[self.j]
        return np.round(2 * apex)

    def eval(self, coords, grad=False):
        """Return the angle phi in radians, in [0, pi].

        Coordinates are scaled by ``angstrom`` first. With ``grad=True`` return
        ``(phi, [g_i, g_j, g_k])``, one gradient vector per atom i, j, k. For
        phi within 1e-6 of pi the general 1/sin(phi) expression is singular,
        so a separate expression proportional to (pi - phi) is used instead.
        """
        u = (coords[self.i] - coords[self.j]) * angstrom
        w = (coords[self.k] - coords[self.j]) * angstrom
        nu = norm(u)
        nw = norm(w)
        cos_phi = np.dot(u, w) / (nu * nw)
        # round-off may push the cosine slightly outside arccos's domain
        if cos_phi < -1:
            cos_phi = -1
        elif cos_phi > 1:
            cos_phi = 1
        phi = np.arccos(cos_phi)
        if not grad:
            return phi
        if abs(phi) > pi - 1e-6:
            gap = pi - phi
            gradient = [
                gap / (2 * nu ** 2) * u,
                (1 / nu - 1 / nw) * gap / (2 * nu) * u,
                gap / (2 * nw ** 2) * w,
            ]
            return phi, gradient
        cot = 1 / np.tan(phi)
        denom = nu * nw * np.sin(phi)
        gradient = [
            cot * u / nu ** 2 - w / denom,
            (u + w) / denom - cot * (u / nu ** 2 + w / nw ** 2),
            cot * w / nw ** 2 - u / denom,
        ]
        return phi, gradient
116
117
class Dihedral(InternalCoord):
    """Dihedral (torsion) angle i-j-k-l about the central j-k axis.

    ``weak`` and ``angles`` are stored as given. ``C`` is accepted but not
    used: unlike Bond and Angle, ``weak`` is not computed from it here.
    """

    def __init__(self, i, j, k, l, weak=None, angles=None, C=None, **kwargs):
        # canonical order: read the dihedral backwards so that j <= k
        if j > k:
            i, j, k, l = l, k, j, i
        self.i = i
        self.j = j
        self.k = k
        self.l = l
        self.idx = (i, j, k, l)
        # count of weak links, supplied by the caller (not derived from C)
        self.weak = weak
        # Angle objects consulted by InternalCoords.eval_geom when this
        # dihedral jumps by pi
        self.angles = angles
        # C is not forwarded, so the base initializer does not overwrite weak
        InternalCoord.__init__(self, **kwargs)

    def hessian(self, rho):
        """Model force constant 0.005 * rho[i, j] * rho[j, k] * rho[k, l]."""
        return 0.005 * rho[self.i, self.j] * rho[self.j, self.k] * rho[self.k, self.l]

    def weight(self, rho, coords):
        """Cube root of rho[i, j] * rho[j, k] * rho[k, l], damped by the angles.

        Each of the two bond angles i-j-k and j-k-l contributes a factor
        f + (1 - f) * sin(theta) with f = 0.12, which drops to f when that
        angle is linear.
        """
        f = 0.12
        th1 = Angle(self.i, self.j, self.k).eval(coords)
        th2 = Angle(self.j, self.k, self.l).eval(coords)
        return (
            (rho[self.i, self.j] * rho[self.j, self.k] * rho[self.k, self.l]) ** (1 / 3)
            * (f + (1 - f) * np.sin(th1))
            * (f + (1 - f) * np.sin(th2))
        )

    def center(self, ijk):
        """Rounded sum of the offsets ``ijk`` of the two central atoms j and k."""
        return np.round(ijk[[self.j, self.k]].sum(0))

    def eval(self, coords, grad=False):
        """Return the signed dihedral angle phi in radians, in [-pi, pi].

        Coordinates are scaled by ``angstrom`` first. With ``grad=True`` return
        ``(phi, grads)`` with one gradient vector per atom i, j, k, l. Close
        to phi = 0 and |phi| = pi, where sin(phi) vanishes, separate
        expressions along g (normalised w x a1) replace the general
        1/sin(phi) form.
        """
        # j->i, k->l and the central axis j->k
        v1 = (coords[self.i] - coords[self.j]) * angstrom
        v2 = (coords[self.l] - coords[self.k]) * angstrom
        w = (coords[self.k] - coords[self.j]) * angstrom
        ew = w / norm(w)
        # a1, a2: components of v1, v2 perpendicular to the central axis
        a1 = v1 - dot(v1, ew) * ew
        a2 = v2 - dot(v2, ew) * ew
        # orientation from the sign of det([v2, v1, w]); a zero determinant
        # (sign 0) is treated as +1
        sgn = np.sign(np.linalg.det(np.array([v2, v1, w])))
        sgn = sgn or 1
        dot_product = dot(a1, a2) / (norm(a1) * norm(a2))
        # clamp round-off so arccos stays within its domain
        if dot_product < -1:
            dot_product = -1
        elif dot_product > 1:
            dot_product = 1
        phi = np.arccos(dot_product) * sgn
        if not grad:
            return phi
        if abs(phi) > pi - 1e-6:
            # NOTE(review): Math.cross is presumably the 3-D cross product.
            # After normalisation norm(g) == 1, so g / (norm(g) * norm(a1))
            # equals g / norm(a1) up to round-off.
            g = Math.cross(w, a1)
            g = g / norm(g)
            # A, B: projections of v1, v2 on the axis, divided by |w|
            A = dot(v1, ew) / norm(w)
            B = dot(v2, ew) / norm(w)
            grad = [
                g / (norm(g) * norm(a1)),
                -((1 - A) / norm(a1) - B / norm(a2)) * g,
                -((1 + B) / norm(a2) + A / norm(a1)) * g,
                g / (norm(g) * norm(a2)),
            ]
        elif abs(phi) < 1e-6:
            # same construction as above, with the signs for phi near 0
            g = Math.cross(w, a1)
            g = g / norm(g)
            A = dot(v1, ew) / norm(w)
            B = dot(v2, ew) / norm(w)
            grad = [
                g / (norm(g) * norm(a1)),
                -((1 - A) / norm(a1) + B / norm(a2)) * g,
                ((1 + B) / norm(a2) - A / norm(a1)) * g,
                -g / (norm(g) * norm(a2)),
            ]
        else:
            # general case: expressions in cot(phi) and 1/sin(phi)
            A = dot(v1, ew) / norm(w)
            B = dot(v2, ew) / norm(w)
            grad = [
                1 / np.tan(phi) * a1 / norm(a1) ** 2
                - a2 / (norm(a1) * norm(a2) * np.sin(phi)),
                ((1 - A) * a2 - B * a1) / (norm(a1) * norm(a2) * np.sin(phi))
                - 1
                / np.tan(phi)
                * ((1 - A) * a1 / norm(a1) ** 2 - B * a2 / norm(a2) ** 2),
                ((1 + B) * a1 + A * a2) / (norm(a1) * norm(a2) * np.sin(phi))
                - 1
                / np.tan(phi)
                * ((1 + B) * a2 / norm(a2) ** 2 + A * a1 / norm(a1) ** 2),
                1 / np.tan(phi) * a2 / norm(a2) ** 2
                - a1 / (norm(a1) * norm(a2) * np.sin(phi)),
            ]
        return phi, grad
204
205
def get_clusters(C):
    """Split the graph with adjacency matrix ``C`` into connected components.

    Returns ``(clusters, M)``:

    * ``clusters`` is a list of lists of node indices, one per component.
      Components are ordered by their smallest node.
    * ``M`` has the same shape and dtype as ``C``. It is set to True exactly
      for pairs of nodes in the same component, including the diagonal.
    """
    remaining = list(range(len(C)))
    clusters = []
    while remaining:
        members = []
        frontier = {remaining[0]}  # start each component at the lowest free node
        while frontier:
            node = frontier.pop()
            members.append(node)
            remaining.remove(node)
            frontier.update(nb for nb in np.flatnonzero(C[node]) if nb in remaining)
        clusters.append(members)
    connected = np.zeros_like(C)
    for members in clusters:
        # mark the whole members x members sub-block at once
        connected[np.ix_(members, members)] = True
    return clusters, connected
222
223
class InternalCoords(object):
    """Redundant set of internal coordinates (bonds, angles, dihedrals).

    Coordinates are kept in one flat list in construction order: all bonds,
    then the angles, then the dihedrals. The row order of eval_geom,
    hessian_guess, weights and B_matrix follows this list.
    """

    def __init__(self, geom, allowed=None, dihedral=True, superweakdih=False):
        """Build the coordinates for ``geom``.

        NOTE(review): geom presumably provides len(), supercell(), dist(),
        species, coords and lattice -- TODO confirm against the geometry class.

        allowed: not used by this method.
        dihedral: when False, no dihedrals are generated.
        superweakdih: forwarded as ``superweak`` to get_dihedrals.
        """
        self._coords = []
        # atom count before building the supercell (used by _reduce)
        n = len(geom)
        geom = geom.supercell()
        dist = geom.dist(geom)
        # covalent bonds: pairs closer than 1.3 x the sum of covalent radii
        radii = np.array([get_property(sp, 'covalent_radius') for sp in geom.species])
        bondmatrix = dist < 1.3 * (radii[None, :] + radii[:, None])
        # fragments and the fragment-connectivity matrix C from covalent bonds
        # only; C is what later marks coordinates as weak
        self.fragments, C = get_clusters(bondmatrix)
        # connect separate fragments through van der Waals contacts: add
        # bonds between atoms not yet connected that lie within the sum of vdW
        # radii plus ``shift``, widening ``shift`` by 1.0 per pass until the
        # whole bond graph is connected
        radii = np.array([get_property(sp, 'vdw_radius') for sp in geom.species])
        shift = 0.0
        C_total = C.copy()
        while not C_total.all():
            bondmatrix |= ~C_total & (dist < radii[None, :] + radii[:, None] + shift)
            C_total = get_clusters(bondmatrix)[1]
            shift += 1.0
        for i, j in combinations(range(len(geom)), 2):
            if bondmatrix[i, j]:
                bond = Bond(i, j, C=C)
                self.append(bond)
        # angles at every apex j between pairs of its bonded neighbours,
        # keeping only those wider than pi/4.
        # NOTE(review): if geom.dist is 0 on the diagonal, bondmatrix[j, j] is
        # True and degenerate (i, j, j) triples are formed here; their eval is
        # NaN and they fail the > pi/4 test -- confirm this is intended.
        for j in range(len(geom)):
            for i, k in combinations(np.flatnonzero(bondmatrix[j, :]), 2):
                ang = Angle(i, j, k, C=C)
                if ang.eval(geom.coords) > pi / 4:
                    self.append(ang)
        if dihedral:
            for bond in self.bonds:
                self.extend(
                    get_dihedrals(
                        [bond.i, bond.j],
                        geom.coords,
                        bondmatrix,
                        C,
                        superweak=superweakdih,
                    )
                )
        # periodic systems: drop coordinates outside the reference images
        if geom.lattice is not None:
            self._reduce(n)

    def append(self, coord):
        self._coords.append(coord)

    def extend(self, coords):
        self._coords.extend(coords)

    def __iter__(self):
        return self._coords.__iter__()

    def __len__(self):
        return len(self._coords)

    @property
    def bonds(self):
        return [c for c in self if isinstance(c, Bond)]

    @property
    def angles(self):
        return [c for c in self if isinstance(c, Angle)]

    @property
    def dihedrals(self):
        return [c for c in self if isinstance(c, Dihedral)]

    @property
    def dict(self):
        """Coordinates grouped by kind, in the order bonds, angles, dihedrals."""
        return OrderedDict(
            [
                ('bonds', self.bonds),
                ('angles', self.angles),
                ('dihedrals', self.dihedrals),
            ]
        )

    def __repr__(self):
        return '<InternalCoords "{}">'.format(
            ', '.join(
                '{}: {}'.format(name, len(coords)) for name, coords in self.dict.items()
            )
        )

    def __str__(self):
        """Summary of fragment count and coordinate counts by kind and weakness."""
        ncoords = sum(len(coords) for coords in self.dict.values())
        s = 'Internal coordinates:\n'
        s += '* Number of fragments: {}\n'.format(len(self.fragments))
        s += '* Number of internal coordinates: {}\n'.format(ncoords)
        for name, coords in self.dict.items():
            # weak counts of 2 or more are all reported as "superweak"
            for degree, adjective in [(0, 'strong'), (1, 'weak'), (2, 'superweak')]:
                n = len([None for c in coords if min(2, c.weak) == degree])
                if n > 0:
                    s += '* Number of {} {}: {}\n'.format(adjective, name, n)
        return s.rstrip()

    def eval_geom(self, geom, template=None):
        """Evaluate all coordinates on ``geom`` and return them as an array.

        When ``template`` (a previous coordinate array) is given, values are
        made continuous with it:

        * dihedrals that differ from the template by about 2*pi are shifted
          back by 2*pi;
        * dihedrals that differ by about pi are shifted by pi and marked as
          swapped;
        * angles of swapped dihedrals are then replaced by 2*pi - angle when
          the rule in the loop below holds.
        """
        geom = geom.supercell()
        q = np.array([coord.eval(geom.coords) for coord in self])
        if template is None:
            return q
        swapped = []  # dihedrals swapped by pi
        candidates = set()  # potentially swapped angles
        for i, dih in enumerate(self):
            if not isinstance(dih, Dihedral):
                continue
            diff = q[i] - template[i]
            if abs(abs(diff) - 2 * pi) < pi / 2:
                q[i] -= 2 * pi * np.sign(diff)
            elif abs(abs(diff) - pi) < pi / 2:
                q[i] -= pi * np.sign(diff)
                swapped.append(dih)
                candidates.update(dih.angles)
        # NOTE(review): the membership tests below compare Angle objects built
        # in get_dihedrals with those stored here, so they rely on
        # InternalCoord.__eq__ comparing ``idx``.
        for i, ang in enumerate(self):
            if not isinstance(ang, Angle) or ang not in candidates:
                continue
            # candidate angle was swapped if each dihedral that contains it was
            # either swapped or all its angles are candidates
            if all(
                dih in swapped or all(a in candidates for a in dih.angles)
                for dih in self.dihedrals
                if ang in dih.angles
            ):
                q[i] = 2 * pi - q[i]
        return q

    def _reduce(self, n):
        """Keep only coordinates centred in the reference periodic images.

        Atom a of the 27*n-atom supercell gets the offset
        base-3 digits of a // n minus 1, i.e. each component in {-1, 0, 1}.
        A coordinate is kept when every component of its rounded ``center``
        is 0 or -1. Fragments that no longer contain any coordinate atom are
        dropped.
        """
        idxs = np.int64(np.floor(np.array(range(3 ** 3 * n)) / n))
        idxs, i = np.divmod(idxs, 3)
        idxs, j = np.divmod(idxs, 3)
        k = idxs % 3
        ijk = np.vstack((i, j, k)).T - 1
        self._coords = [
            coord
            for coord in self._coords
            if np.all(np.isin(coord.center(ijk), [0, -1]))
        ]
        idxs = {i for coord in self._coords for i in coord.idx}
        self.fragments = [frag for frag in self.fragments if set(frag) & idxs]

    def hessian_guess(self, geom):
        """Diagonal model Hessian built from each coordinate's ``hessian(rho)``."""
        geom = geom.supercell()
        rho = geom.rho()
        return np.diag([coord.hessian(rho) for coord in self])

    def weights(self, geom):
        """Array of per-coordinate weights from each ``weight(rho, coords)``."""
        geom = geom.supercell()
        rho = geom.rho()
        return np.array([coord.weight(rho, geom.coords) for coord in self])

    def B_matrix(self, geom):
        """Matrix of coordinate gradients, shape (len(self), 3 * len(supercell)).

        Row r holds the gradient of coordinate r with respect to the
        Cartesian positions of the supercell atoms.
        """
        geom = geom.supercell()
        B = np.zeros((len(self), len(geom), 3))
        for i, coord in enumerate(self):
            _, grads = coord.eval(geom.coords, grad=True)
            # NOTE(review): len(geom) is the supercell size here, so this
            # modulo leaves valid indices unchanged -- confirm whether folding
            # onto the original n atoms was intended.
            idx = [k % len(geom) for k in coord.idx]
            for j, grad in zip(idx, grads):
                B[i, j] += grad
        return B.reshape(len(self), 3 * len(geom))

    def update_geom(self, geom, q, dq, B_inv, log=lambda _: None):
        """Iteratively apply the internal step ``dq`` to a copy of ``geom``.

        Each of up to 20 iterations moves the Cartesian coordinates by
        ``B_inv . dq``, converted back to angstrom, re-evaluates q, and keeps
        the remaining residual in dq. Iteration stops once the RMS Cartesian
        change is below 1e-6. If that never happens, the geometry and q after
        the first iteration are used instead. Progress messages go to ``log``.

        Returns ``(q, geom)`` with the new coordinate values and geometry.
        """
        geom = geom.copy()
        thre = 1e-6
        # target = CartIter(q=q+dq)
        # prev = CartIter(geom.coords, q, dq)
        for i in range(20):
            coords_new = geom.coords + B_inv.dot(dq).reshape(-1, 3) / angstrom
            dcart_rms = Math.rms(coords_new - geom.coords)
            geom.coords = coords_new
            q_new = self.eval_geom(geom, template=q)
            dq_rms = Math.rms(q_new - q)
            # dq becomes the part of the requested step not yet achieved
            q, dq = q_new, dq - (q_new - q)
            if dcart_rms < thre:
                msg = 'Perfect transformation to cartesians in {} iterations'
                break
            if i == 0:
                keep_first = geom.copy(), q, dcart_rms, dq_rms
        else:
            # no convergence: fall back to the result of the first iteration
            msg = 'Transformation did not converge in {} iterations'
            geom, q, dcart_rms, dq_rms = keep_first
        log(msg.format(i + 1))
        log('* RMS(dcart): {:.3}, RMS(dq): {:.3}'.format(dcart_rms, dq_rms))
        return q, geom
403
404
def get_dihedrals(center, coords, bondmatrix, C, superweak=False):
    """Build the dihedrals about a central axis of atoms.

    center: list of atom indices forming the axis. Initially it is the two
        atoms of a bond. It is extended recursively (up to 4 atoms) through a
        neighbour collinear with the axis on exactly one side.
    coords: atomic coordinates, passed to Angle.eval.
    bondmatrix: boolean adjacency matrix; bondmatrix[a, b] is True if bonded.
    C: fragment-connectivity matrix; C[a, b] is True when a and b are in the
        same fragment.
    superweak: when False, dihedrals with more than one weak (inter-fragment)
        link are skipped. The flag is also applied to the recursive calls.

    Returns a list of Dihedral objects.

    Raises AssertionError if more than one neighbour on either end is
    collinear with the axis.
    """
    lin_thre = 5 * pi / 180  # within 5 degrees of 0 or pi counts as linear
    neigh_l = [n for n in np.flatnonzero(bondmatrix[center[0], :]) if n not in center]
    neigh_r = [n for n in np.flatnonzero(bondmatrix[center[-1], :]) if n not in center]
    angles_l = [Angle(i, center[0], center[1]).eval(coords) for i in neigh_l]
    angles_r = [Angle(center[-2], center[-1], j).eval(coords) for j in neigh_r]
    nonlinear_l = [
        n
        for n, ang in zip(neigh_l, angles_l)
        if ang < pi - lin_thre and ang >= lin_thre
    ]
    nonlinear_r = [
        n
        for n, ang in zip(neigh_r, angles_r)
        if ang < pi - lin_thre and ang >= lin_thre
    ]
    linear_l = [
        n for n, ang in zip(neigh_l, angles_l) if ang >= pi - lin_thre or ang < lin_thre
    ]
    linear_r = [
        n for n, ang in zip(neigh_r, angles_r) if ang >= pi - lin_thre or ang < lin_thre
    ]
    assert len(linear_l) <= 1
    assert len(linear_r) <= 1
    dihedrals = []
    # Only build from the orientation with center[0] < center[-1]; presumably
    # so that an extended axis reachable from both of its end bonds is
    # generated once.
    if center[0] < center[-1]:
        # weak links along the axis itself
        nweak = sum(1 for a, b in zip(center, center[1:]) if not C[a, b])
        for nl, nr in product(nonlinear_l, nonlinear_r):
            if nl == nr:
                continue
            # nl is bonded to center[0] and nr to center[-1]
            weak = (
                nweak
                + (0 if C[nl, center[0]] else 1)
                + (0 if C[center[-1], nr] else 1)
            )
            if not superweak and weak > 1:
                continue
            dihedrals.append(
                Dihedral(
                    nl,
                    center[0],
                    center[-1],
                    nr,
                    weak=weak,
                    angles=(
                        Angle(nl, center[0], center[1], C=C),
                        Angle(center[-2], center[-1], nr, C=C),
                    ),
                )
            )
    # extend the axis through a collinear neighbour on one side only
    if len(center) <= 3:
        if linear_l and not linear_r:
            dihedrals.extend(
                get_dihedrals(
                    linear_l + center, coords, bondmatrix, C, superweak=superweak
                )
            )
        elif linear_r and not linear_l:
            dihedrals.extend(
                get_dihedrals(
                    center + linear_r, coords, bondmatrix, C, superweak=superweak
                )
            )
    return dihedrals
464