1# -*- coding: utf-8 -*-
2# MolMod is a collection of molecular modelling tools for python.
3# Copyright (C) 2007 - 2019 Toon Verstraelen <Toon.Verstraelen@UGent.be>, Center
4# for Molecular Modeling (CMM), Ghent University, Ghent, Belgium; all rights
5# reserved unless otherwise stated.
6#
7# This file is part of MolMod.
8#
9# MolMod is free software; you can redistribute it and/or
10# modify it under the terms of the GNU General Public License
11# as published by the Free Software Foundation; either version 3
12# of the License, or (at your option) any later version.
13#
14# MolMod is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program; if not, see <http://www.gnu.org/licenses/>
21#
22# --
23"""Data structures to handle 3D rotations and translations
24
In addition to Translation, Rotation and Complete classes, two utility
functions are provided: superpose and fit_rmsd. The former is an
implementation of the Kabsch algorithm.
28"""
29
30
31from __future__ import division
32
33from builtins import range
34import numpy as np
35
36from molmod.utils import cached, ReadOnly, ReadOnlyAttribute, compute_rmsd
37from molmod.vectors import random_unit
38from molmod.unit_cells import UnitCell
39
40
# Public names exported by a star-import of this module.
__all__ = [
    "Translation", "Rotation", "Complete", "superpose", "fit_rmsd"
]
44
45
# Absolute tolerance used by the sanity checks in this module.
eps = 1.0e-6


def check_matrix(m):
    """Validate a 4x4 matrix representation of a transformation

       Argument:
        | ``m``  --  the array to be tested

       A ValueError is raised when the array does not have shape (4, 4), when
       one of the first three elements of the last row deviates more than
       ``eps`` from zero, or when the last diagonal element deviates more than
       ``eps`` from one. Nothing is returned when all tests pass.
    """
    if m.shape != (4, 4):
        raise ValueError("The argument must be a 4x4 array.")
    # The last row of a valid matrix is [0, 0, 0, 1].
    lower_left = abs(m[3, 0:3])
    if max(lower_left) > eps:
        raise ValueError("The given matrix does not have correct translational part")
    corner_deviation = abs(m[3, 3] - 1.0)
    if corner_deviation > eps:
        raise ValueError("The lower right element of the given matrix must be 1.0.")
56
57
class Translation(ReadOnly):
    """Represents a translation in 3D

       The attribute t contains the actual translation vector, which is a numpy
       array with three elements.
    """
    t = ReadOnlyAttribute(np.ndarray, none=False, npdim=1, npshape=(3,),
        npdtype=np.floating, doc="the translation vector")

    def __init__(self, t):
        """
           Argument:
            | ``t``  --  translation vector, a list-like object with three
                         numbers
        """
        self.t = t

    @classmethod
    def from_matrix(cls, m):
        """Initialize a translation from a 4x4 matrix

           Only the first three elements of the last column are used. A
           ValueError is raised by ``check_matrix`` when ``m`` is not a valid
           4x4 transformation matrix.
        """
        check_matrix(m)
        return cls(m[0:3, 3])

    @classmethod
    def identity(cls):
        """Return the identity transformation"""
        return cls(np.zeros(3, float))

    @cached
    def matrix(self):
        """The 4x4 matrix representation of this translation"""
        result = np.identity(4, float)
        result[0:3, 3] = self.t
        return result

    @cached
    def inv(self):
        """The inverse translation"""
        result = Translation(-self.t)
        # The inverse of the inverse is this object. Presumably ``_cache_inv``
        # is the slot read by the ``cached`` decorator, so ``result.inv`` does
        # not have to be recomputed -- confirm against molmod.utils.cached.
        result._cache_inv = self
        return result

    def apply_to(self, x, columns=False):
        """Apply this translation to the given object

           The argument can be several sorts of objects:

           * ``np.array`` with shape (3, )
           * ``np.array`` with shape (N, 3)
           * ``np.array`` with shape (3, N), use ``columns=True``
           * ``Translation``
           * ``Rotation``
           * ``Complete``
           * ``UnitCell``

           In case of arrays, the 3D vectors are translated. In case of trans-
           formations, a new transformation is returned that consists of this
           translation applied AFTER the given transformation. In case of a
           unit cell, the original object is returned.

           A ValueError is raised for any other argument, including arrays
           whose shape does not match the ``columns`` option.

           This method is equivalent to ``self*x``.
        """
        if isinstance(x, np.ndarray):
            if columns:
                if len(x.shape) == 2 and x.shape[0] == 3:
                    # broadcast the translation over the columns
                    return x + self.t.reshape((3, 1))
            elif x.shape == (3, ) or (len(x.shape) == 2 and x.shape[1] == 3):
                return x + self.t
        elif isinstance(x, Complete):
            # Complete must be tested first because it is also a Translation
            # and a Rotation.
            return Complete(x.r, x.t + self.t)
        elif isinstance(x, Translation):
            return Translation(x.t + self.t)
        elif isinstance(x, Rotation):
            return Complete(x.r, self.t)
        elif isinstance(x, UnitCell):
            return x
        # Wrap x in a tuple: a bare tuple argument would otherwise be unpacked
        # by the % operator and trigger a TypeError instead of this ValueError.
        raise ValueError("Can not apply this translation to %s" % (x,))

    __mul__ = apply_to

    def compare(self, other, t_threshold=1e-3):
        """Compare two translations

           The RMSD of the translation vectors is computed. The return value
           is True when the RMSD is below the threshold, i.e. when the two
           translations are almost identical.
        """
        return compute_rmsd(self.t, other.t) < t_threshold
145
146
class Rotation(ReadOnly):
    """Represents a rotation in 3D about the origin

       The attribute r contains the actual rotation matrix, which is a numpy
       array with shape (3, 3).
    """
    def _check_r(self, r):
        """Raise a ValueError unless the columns of r are orthonormal

           This is the ``check`` callback of the attribute ``r``. All three
           column norms and all three mutual dot products are tested against
           the module-level tolerance ``eps``.
        """
        # np.dot(r.T, r) contains every dot product between two columns of r.
        # For an orthonormal matrix, this is the identity matrix.
        deviation = abs(np.dot(r.transpose(), r) - np.identity(3))
        # The negated comparison also rejects matrices that contain NaN.
        if not deviation.max() <= eps:
            raise ValueError("The rotation matrix is significantly non-orthonormal.")


    r = ReadOnlyAttribute(np.ndarray, none=False, check=_check_r, npdim=2,
        npshape=(3,3), npdtype=np.floating, doc="the rotation matrix")

    def __init__(self, r):
        """
           Argument:
            | ``r``  --  rotation matrix, a 3 by 3 orthonormal array-like object
        """
        self.r = r

    @classmethod
    def from_matrix(cls, m):
        """Initialize a rotation from a 4x4 matrix

           Only the upper left 3x3 block is used. A ValueError is raised by
           ``check_matrix`` when ``m`` is not a valid 4x4 transformation
           matrix.
        """
        check_matrix(m)
        return cls(m[0:3, 0:3])

    @classmethod
    def identity(cls):
        """Return the identity transformation"""
        return cls(np.identity(3, float))

    @classmethod
    def random(cls):
        """Return a random rotation

           The axis comes from ``random_unit``, the angle is drawn uniformly
           from [0, 2*pi) and an inversion is included with 50% probability.
        """
        axis = random_unit()
        angle = np.random.uniform(0,2*np.pi)
        invert = bool(np.random.randint(0,2))
        return Rotation.from_properties(angle, axis, invert)

    @classmethod
    def from_properties(cls, angle, axis, invert):
        """Initialize a rotation based on the properties

           Arguments:
            | ``angle``  --  the rotation angle in radians
            | ``axis``  --  the rotation axis, it does not have to be
                            normalized
            | ``invert``  --  when True, the rotation matrix is multiplied
                              by -1

           When the axis has zero length, the rotation part is the identity.
        """
        # +1 for a proper rotation, -1 when an inversion is included
        sign = 1-2*invert
        norm = np.linalg.norm(axis)
        if norm > 0:
            x = axis[0] / norm
            y = axis[1] / norm
            z = axis[2] / norm
            c = np.cos(angle)
            s = np.sin(angle)
            r = sign * np.array([
                [x*x*(1-c)+c  , x*y*(1-c)-z*s, x*z*(1-c)+y*s],
                [x*y*(1-c)+z*s, y*y*(1-c)+c  , y*z*(1-c)-x*s],
                [x*z*(1-c)-y*s, y*z*(1-c)+x*s, z*z*(1-c)+c  ]
            ])
        else:
            r = np.identity(3) * sign
        return cls(r)

    @cached
    def properties(self):
        """Rotation properties: angle, axis, invert"""
        # determine whether an inversion rotation has been applied
        invert = (np.linalg.det(self.r) < 0)
        factor = -1 if invert else 1
        # get the rotation data
        # trace(r) = 1+2*cos(angle)
        cos_angle = 0.5*(factor*np.trace(self.r) - 1)
        # keep cos_angle within [-1, 1]
        if cos_angle > 1:
            cos_angle = 1.0
        if cos_angle < -1:
            cos_angle = -1.0
        # the antisymmetric part of the off-diagonal elements is equal to
        # sin(angle) times the unit vector along the axis.
        axis = 0.5*factor*np.array([
            -self.r[1, 2] + self.r[2, 1],
            self.r[0, 2] - self.r[2, 0],
            -self.r[0, 1] + self.r[1, 0],
        ])
        sin_angle = np.linalg.norm(axis)
        # look for the best way to normalize the axis
        if (sin_angle == 0.0) and (cos_angle > 0):
            # no rotation at all: the axis is arbitrary, take the z-axis
            axis[2] = 1.0
        elif abs(sin_angle) < (1-cos_angle):
            # small sine: derive the magnitudes of the axis components from
            # the diagonal, r[i, i] = n_i**2*(1-cos) + cos, and only take the
            # signs from the antisymmetric part.
            for index in range(3):
                sign = -1 if axis[index] < 0 else 1
                axis[index] = sign * np.sqrt(abs(
                    (factor*self.r[index, index] - cos_angle) / (1 - cos_angle)
                ))
        else:
            axis = axis / sin_angle

        # Finally calculate the angle:
        angle = np.arctan2(sin_angle, cos_angle)
        return angle, axis, invert

    @cached
    def matrix(self):
        """The 4x4 matrix representation of this rotation"""
        result = np.identity(4, float)
        result[0:3, 0:3] = self.r
        return result

    @cached
    def inv(self):
        """The inverse rotation"""
        # the inverse of an orthonormal matrix is its transpose
        result = Rotation(self.r.transpose())
        # The inverse of the inverse is this object. Presumably ``_cache_inv``
        # is the slot read by the ``cached`` decorator -- confirm against
        # molmod.utils.cached.
        result._cache_inv = self
        return result

    def apply_to(self, x, columns=False):
        """Apply this rotation to the given object

           The argument can be several sorts of objects:

           * ``np.array`` with shape (3, )
           * ``np.array`` with shape (N, 3)
           * ``np.array`` with shape (3, N), use ``columns=True``
           * ``Translation``
           * ``Rotation``
           * ``Complete``
           * ``UnitCell``

           In case of arrays, the 3D vectors are rotated. In case of trans-
           formations, a transformation is returned that consists of this
           rotation applied AFTER the given transformation. In case of a unit
           cell, a unit cell with rotated cell vectors is returned.

           A ValueError is raised for any other argument, including arrays
           whose shape does not match the ``columns`` option.

           This method is equivalent to ``self*x``.
        """
        if isinstance(x, np.ndarray):
            if columns:
                if len(x.shape) == 2 and x.shape[0] == 3:
                    return np.dot(self.r, x)
            elif x.shape == (3, ) or (len(x.shape) == 2 and x.shape[1] == 3):
                # vectors are stored as rows: multiply with the transpose
                return np.dot(x, self.r.transpose())
        elif isinstance(x, Complete):
            # Complete must be tested first because it is also a Translation
            # and a Rotation.
            return Complete(np.dot(self.r, x.r), np.dot(self.r, x.t))
        elif isinstance(x, Translation):
            return Complete(self.r, np.dot(self.r, x.t))
        elif isinstance(x, Rotation):
            return Rotation(np.dot(self.r, x.r))
        elif isinstance(x, UnitCell):
            return UnitCell(np.dot(self.r, x.matrix), x.active)
        # Wrap x in a tuple: a bare tuple argument would otherwise be unpacked
        # by the % operator and trigger a TypeError instead of this ValueError.
        raise ValueError("Can not apply this rotation to %s" % (x,))

    __mul__ = apply_to

    def compare(self, other, r_threshold=1e-3):
        """Compare two rotations

           The RMSD of the rotation matrices is computed. The return value
           is True when the RMSD is below the threshold, i.e. when the two
           rotations are almost identical.
        """
        return compute_rmsd(self.r, other.r) < r_threshold
299
300
class Complete(Translation, Rotation):
    """Represents a rotation and translation in 3D

       The attribute t contains the actual translation vector, which is a numpy
       array with three elements. The attribute r contains the actual rotation
       matrix, which is a numpy array with shape (3, 3).

       Internally the translation part is always applied after the rotation
       part.
    """
    def __init__(self, r, t):
        """
           Arguments:
            | ``r``  --  rotation matrix, a 3 by 3 orthonormal array-like object
            | ``t``  --  translation vector, a list-like object with three
                         numbers
        """
        Translation.__init__(self, t)
        Rotation.__init__(self, r)

    @classmethod
    def from_matrix(cls, m):
        """Initialize a complete transformation from a 4x4 matrix

           The upper left 3x3 block is the rotation matrix and the first three
           elements of the last column are the translation vector. A
           ValueError is raised by ``check_matrix`` when ``m`` is not a valid
           4x4 transformation matrix.
        """
        check_matrix(m)
        return cls(m[0:3, 0:3], m[0:3, 3])

    @classmethod
    def identity(cls):
        """Return the identity transformation"""
        return cls(np.identity(3, float), np.zeros(3, float))

    @classmethod
    def from_properties(cls, angle, axis, invert, translation):
        """Initialize a transformation based on the properties

           The first three arguments are passed on to
           ``Rotation.from_properties``; ``translation`` is the translation
           vector.
        """
        rot = Rotation.from_properties(angle, axis, invert)
        return Complete(rot.r, translation)

    @classmethod
    def cast(cls, c):
        """Convert the first argument into a Complete object

           A Complete object is returned as is. A Translation or a Rotation is
           extended with an identity rotation or a zero translation,
           respectively. For any other argument, None is returned.
        """
        # Complete must be tested first because it is also a Translation and
        # a Rotation.
        if isinstance(c, Complete):
            return c
        elif isinstance(c, Translation):
            return Complete(np.identity(3, float), c.t)
        elif isinstance(c, Rotation):
            return Complete(c.r, np.zeros(3, float))
        return None

    @classmethod
    def about_axis(cls, center, angle, axis, invert=False):
        """Create transformation that represents a rotation about an axis

           Arguments:
            | ``center``  --  Point on the axis
            | ``angle``  --  Rotation angle
            | ``axis``  --  Rotation axis
            | ``invert``  --  When True, an inversion rotation is constructed
                              [default=False]
        """
        # Move the center to the origin, rotate about the origin and move
        # back. (The rightmost transformation is applied first.)
        return Translation(center) * \
               Rotation.from_properties(angle, axis, invert) * \
               Translation(-center)

    @cached
    def matrix(self):
        """The 4x4 matrix representation of this transformation"""
        result = np.identity(4, float)
        result[0:3, 3] = self.t
        result[0:3, 0:3] = self.r
        return result

    @cached
    def properties(self):
        """Transformation properties: angle, axis, invert, translation"""
        rot = Rotation(self.r)
        angle, axis, invert = rot.properties
        return angle, axis, invert, self.t

    @cached
    def inv(self):
        """The inverse transformation"""
        # y = r*x + t  <=>  x = r^T*y - r^T*t
        result = Complete(self.r.transpose(), np.dot(self.r.transpose(), -self.t))
        # The inverse of the inverse is this object. Presumably ``_cache_inv``
        # is the slot read by the ``cached`` decorator -- confirm against
        # molmod.utils.cached.
        result._cache_inv = self
        return result

    def apply_to(self, x, columns=False):
        """Apply this transformation to the given object

           The argument can be several sorts of objects:

           * ``np.array`` with shape (3, )
           * ``np.array`` with shape (N, 3)
           * ``np.array`` with shape (3, N), use ``columns=True``
           * ``Translation``
           * ``Rotation``
           * ``Complete``
           * ``UnitCell``

           In case of arrays, the 3D vectors are transformed. In case of trans-
           formations, a transformation is returned that consists of this
           transformation applied AFTER the given transformation. In case of a
           unit cell, a unit cell with rotated cell vectors is returned. (The
           translational part does not affect the unit cell.)

           A ValueError is raised for any other argument, including arrays
           whose shape does not match the ``columns`` option.

           This method is equivalent to self*x.
        """
        if isinstance(x, np.ndarray):
            if columns:
                if len(x.shape) == 2 and x.shape[0] == 3:
                    # broadcast the translation over the columns
                    return np.dot(self.r, x) + self.t.reshape((3, 1))
            elif x.shape == (3, ) or (len(x.shape) == 2 and x.shape[1] == 3):
                # vectors are stored as rows: multiply with the transpose
                return np.dot(x, self.r.transpose()) + self.t
        elif isinstance(x, Complete):
            # Complete must be tested first because it is also a Translation
            # and a Rotation.
            return Complete(np.dot(self.r, x.r), np.dot(self.r, x.t) + self.t)
        elif isinstance(x, Translation):
            return Complete(self.r, np.dot(self.r, x.t) + self.t)
        elif isinstance(x, Rotation):
            return Complete(np.dot(self.r, x.r), self.t)
        elif isinstance(x, UnitCell):
            return UnitCell(np.dot(self.r, x.matrix), x.active)
        # Wrap x in a tuple: a bare tuple argument would otherwise be unpacked
        # by the % operator and trigger a TypeError instead of this ValueError.
        raise ValueError("Can not apply this transformation to %s" % (x,))

    __mul__ = apply_to

    def compare(self, other, t_threshold=1e-3, r_threshold=1e-3):
        """Compare two transformations

           The RMSD values of the rotation matrices and the translation vectors
           are computed. The return value is True when the RMSD values are below
           the thresholds, i.e. when the two transformations are almost
           identical.
        """
        return (
            compute_rmsd(self.t, other.t) < t_threshold and
            compute_rmsd(self.r, other.r) < r_threshold
        )
432
433
def superpose(ras, rbs, weights=None):
    """Compute the transformation that minimizes the RMSD between the points ras and rbs

       Arguments:
        | ``ras``  --  a ``np.array`` with 3D coordinates of geometry A,
                       shape=(N,3)
        | ``rbs``  --  a ``np.array`` with 3D coordinates of geometry B,
                       shape=(N,3)

       Optional arguments:
        | ``weights``  --  a numpy array with fitting weights for each
                           coordinate, shape=(N,)

       Return value:
        | ``transformation``  --  the transformation (a ``Complete`` object)
                                  that brings geometry B into overlap with
                                  geometry A, i.e. ``transformation*rbs``
                                  approximates ``ras``

       Each row in ras and rbs represents a 3D coordinate. Corresponding rows
       contain the points that are brought into overlap by the fitting
       procedure. The implementation is based on the Kabsch Algorithm:

       http://dx.doi.org/10.1107%2FS0567739476001873
    """
    # Compute the (weighted) centers of both geometries and the covariance
    # matrix of the centered coordinates.
    if weights is None:
        ma = ras.mean(axis=0)
        mb = rbs.mean(axis=0)
        A = np.dot((rbs-mb).transpose(), ras-ma)
    else:
        total_weight = weights.sum()
        ma = np.dot(weights, ras)/total_weight
        mb = np.dot(weights, rbs)/total_weight
        # NOTE(review): both factors are multiplied by the weights, so every
        # point enters the covariance matrix with weights**2, while the
        # centers above are linear in the weights -- confirm this is intended.
        column_weights = weights.reshape((-1, 1))
        A = np.dot(((rbs-mb)*column_weights).transpose(), (ras-ma)*column_weights)

    # Kabsch: the optimal rotation follows from the singular vectors of A.
    v, s, wt = np.linalg.svd(A)
    # The singular values are not needed; reuse the array for the signs.
    s[:] = 1
    if np.linalg.det(np.dot(v, wt)) < 0:
        # flip the last axis to avoid a rotation matrix with determinant -1
        s[2] = -1
    r = np.dot(wt.T*s, v.T)
    # x -> r*(x - mb) + ma
    return Complete(r, np.dot(r, -mb) + ma)
478
479
def fit_rmsd(ras, rbs, weights=None):
    """Fit geometry rbs onto ras, returns more info than superpose

       Arguments:
        | ``ras``  --  a numpy array with 3D coordinates of geometry A,
                       shape=(N,3)
        | ``rbs``  --  a numpy array with 3D coordinates of geometry B,
                       shape=(N,3)

       Optional arguments:
        | ``weights``  --  a numpy array with fitting weights for each
                           coordinate, shape=(N,)

       Return values:
        | ``transformation``  --  the transformation that brings geometry B into
                                  overlap with geometry A
        | ``rbs_trans``  --  the transformed coordinates of geometry B
        | ``rmsd``  --  the rmsd of the distances between corresponding atoms in
                        geometry A and the transformed geometry B

       This is a utility routine based on the function superpose. It just
       computes rbs_trans and rmsd after calling superpose with the same
       arguments. The weights are only passed to superpose; they are not
       passed to compute_rmsd.
    """
    transformation = superpose(ras, rbs, weights)
    rbs_trans = transformation.apply_to(rbs)
    rmsd = compute_rmsd(ras, rbs_trans)
    return transformation, rbs_trans, rmsd
508