1from typing import Dict, List, Tuple, Union, Optional
2from numbers import Real
3from collections import namedtuple
4import re
5from string import digits
6import numpy as np
7from ase import Atoms
8from ase.units import Angstrom, Bohr, nm
9
10
# Row separator: used to break a block of text into individual rows, which
# may be delimited by newlines or by semicolons.
_re_linesplit = re.compile(r'\n|;')
# Separator between the name and the value inside one variable definition:
# either "=" (optionally surrounded by whitespace) or a run of whitespace.
# Note that re.split() yields an empty first/last field when the string
# begins/ends with a separator (e.g. an indented definition row).
_re_defs = re.compile(r'\s*=\s*|\s+')
15
16
# One Z-matrix row expressed in internal coordinates, as produced by
# _ZMatrixToAtoms.parse_row():
#   ind1, dist        zero-based index of the bonded atom and the bond
#                     distance (already multiplied by the distance factor)
#   ind2, a_bend      zero-based index of the second reference atom and the
#                     bending angle (already multiplied by the angle factor)
#   ind3, a_dihedral  zero-based index of the third reference atom and the
#                     dihedral angle; both are None for a row that provides
#                     only a distance and a bending angle
_ZMatrixRow = namedtuple(
    'ZMatrixRow', 'ind1 dist ind2 a_bend ind3 a_dihedral',
)
20
21
class _ZMatrixToAtoms:
    """Incrementally converts Z-matrix rows into symbols and positions.

    Rows are fed one at a time to :meth:`add_row`. Every row may only refer
    to atoms that were defined by earlier rows. After the last row has been
    added, :meth:`to_atoms` builds the resulting ``Atoms`` object.
    """

    # Multiplicative conversion factors for named units. Angles end up in
    # radians, since they are passed directly to np.cos / np.sin.
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(self, dconv: Union[str, Real], aconv: Union[str, Real],
                 defs: Optional[Union[Dict[str, float],
                                str, List[str]]] = None) -> None:
        """Sets up unit conversion factors and variable definitions.

        Parameters:

        dconv: str or float
            Distance units; either a (case-insensitive) key of
            ``known_units['distance']`` or a numeric conversion factor.
        aconv: str or float
            Angle units; either a (case-insensitive) key of
            ``known_units['angle']`` or a numeric conversion factor.
        defs: dict, str or list of str, optional
            Definitions of the variables used in the Z-matrix,
            see :meth:`set_defs`.
        """
        self.dconv = self.get_units('distance', dconv)  # type: float
        self.aconv = self.get_units('angle', aconv)  # type: float
        self.set_defs(defs)
        # Atom name -> row index. Replaced by None as soon as a name occurs
        # twice, because name-based references would then be ambiguous.
        self.name_to_index: Optional[Dict[str, int]] = dict()
        self.symbols = []  # type: List[str]
        self.positions = []  # type: List[np.ndarray]

    @property
    def nrows(self):
        """Number of atoms (rows) that have been added so far."""
        return len(self.symbols)

    def get_units(self, kind: str, value: Union[str, Real]) -> float:
        """Returns the conversion factor for a unit name or number.

        `kind` is a key of `known_units` ('distance' or 'angle'). A numeric
        `value` is returned as a float; a string is looked up
        case-insensitively. Raises ValueError for an unknown unit name.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError("Unknown {} units: {}"
                             .format(kind, value))
        return out

    def set_defs(self, defs: Union[Dict[str, float], str,
                                   List[str], None]) -> None:
        """Replaces the table of variable definitions.

        `defs` may be None (no variables), a dict mapping names to values,
        a string containing one "name = value" or "name value" definition
        per row (rows separated by newlines or semicolons), or a list of
        such rows. A value may be a number or a (possibly signed) name that
        was defined by an earlier row.

        Blank rows and whitespace around a row are ignored. Raises
        ValueError if a row is not a name/value pair or if its value cannot
        be resolved.
        """
        self.defs = dict()  # type: Dict[str, float]
        if defs is None:
            return

        if isinstance(defs, dict):
            self.defs.update(defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs)

        for row in defs:
            # Strip first: leading/trailing whitespace would make the
            # split below produce spurious empty fields.
            row = row.strip()
            if not row:
                continue
            fields = _re_defs.split(row)
            if len(fields) != 2:
                raise ValueError('Invalid variable definition encountered '
                                 'in Z-matrix: "{}"'.format(row))
            key, val = fields
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Converts a token into a number.

        The token is either a numeric literal or the name of a previously
        defined variable, optionally prefixed by a sign. Raises ValueError
        if it is neither.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError('Invalid value encountered in Z-matrix: {}'
                                 .format(val)) from e
        return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find index for a given atom name

        `name` is either a one-based row number (returned zero-based, so
        the token "0" yields -1) or the name of a previously defined atom.
        Raises ValueError if the name cannot be resolved.
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError('Failed to determine index for name "{}"'
                                 .format(name)) from e
        return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign index to a given atom name for name -> index lookup"""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    def validate_indices(self, *indices: int) -> None:
        """Raises an error if indices in a Z-matrix row are invalid.

        Every index must refer to an atom that has already been added, i.e.
        0 <= index < nrows, and no index may occur twice within one row.
        """
        # Negative indices are rejected explicitly; they would otherwise
        # silently wrap around to the end of the positions list.
        if any(not 0 <= ind < self.nrows for ind in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             "hasn't been defined yet!"
                             .format(self.nrows, indices))

        if len(indices) != len(set(indices)):
            raise ValueError('An atom index has been used more than once a '
                             'row of the Z-matrix! Row numbers {}, '
                             'referred indices: {}'
                             .format(self.nrows, indices))

    def parse_row(self, row: str) -> Tuple[
            str, Union[_ZMatrixRow, np.ndarray],
    ]:
        """Parses one whitespace-separated Z-matrix row.

        Recognized forms (`i`, `j`, `k` refer to earlier atoms by one-based
        row number or by name):

        * ``name``                  first atom, placed at the origin
        * ``name 0 x y z``          atom given by three coordinates
        * ``name i dist``           second atom
        * ``name i dist j bend``    third atom
        * ``name i dist j bend k dihedral``

        Returns the atom name together with either a position array (first
        three forms) or a `_ZMatrixRow` (last two forms) that `add_row`
        turns into a position.

        The row-count checks are assertions (AssertionError), kept as such
        so that the exception types seen by callers do not change.
        """
        tokens = row.split()
        name = tokens[0]
        self.set_index(name)
        if len(tokens) == 1:
            assert self.nrows == 0, \
                'Only the first row may consist of a name alone'
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # The reference token was "0": the remaining three tokens are
            # used as the position itself.
            # NOTE(review): these values are not multiplied by self.dconv,
            # unlike bond distances -- confirm whether that is intended
            # when distance units other than Angstrom are requested.
            assert len(tokens) == 5, \
                'A Cartesian row must have the form "name 0 x y z"'
            return name, np.array(list(map(self.get_var, tokens[2:])),
                                  dtype=float)

        dist = self.dconv * self.get_var(tokens[2])

        if len(tokens) == 3:
            assert self.nrows == 1, \
                'Only the second row may specify a bond distance alone'
            self.validate_indices(ind1)
            # Place the atom along +x relative to the atom it is bonded to
            # (which is not at the origin if it was given in Cartesian
            # form), so that the bond length really equals `dist`.
            return name, (self.positions[ind1]
                          + np.array([dist, 0, 0], dtype=float))

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if len(tokens) == 5:
            assert self.nrows == 2, \
                'Only the third row may omit the dihedral angle'
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: np.ndarray) -> None:
        """Sets the symbol and position of an atom.

        The symbol is `name` with all digits removed, capitalized
        (e.g. "h12" becomes "H").
        """
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parses `row` and appends the atom it defines."""
        name, zrow = self.parse_row(row)

        if not isinstance(zrow, _ZMatrixRow):
            # parse_row already produced a position.
            self.add_atom(name, zrow)
            return

        if zrow.ind3 is None:
            # This is the third atom, so only a bond distance and an angle
            # have been provided. The atom is placed in the xy plane; the
            # factor (ind2 - ind1) is +1 or -1 and selects the x direction
            # pointing from atom ind1 towards atom ind2.
            pos = self.positions[zrow.ind1].copy()
            pos[0] += zrow.dist * np.cos(zrow.a_bend) * (zrow.ind2 - zrow.ind1)
            pos[1] += zrow.dist * np.sin(zrow.a_bend)
            self.add_atom(name, pos)
            return

        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        self.add_atom(name, pos)

    def to_atoms(self) -> Atoms:
        """Returns an Atoms object built from all atoms added so far."""
        return Atoms(self.symbols, self.positions)
200
201
def parse_zmatrix(zmat: Union[str, List[str]],
                  distance_units: Union[str, Real] = 'angstrom',
                  angle_units: Union[str, Real] = 'degrees',
                  defs: Optional[Union[Dict[str, float], str,
                                       List[str]]] = None) -> Atoms:
    """Converts a Z-matrix into an Atoms object.

    Parameters:

    zmat: Iterable or str
        The Z-matrix to be parsed. Iteration over `zmat` should yield the rows
        of the Z-matrix. If `zmat` is a str, it will be automatically split
        into a list at newlines and semicolons. Rows that are empty or
        contain only whitespace are ignored.
    distance_units: str or float, optional
        The units of distance in the provided Z-matrix, given either by
        name or as a numeric conversion factor.
        Defaults to Angstrom.
    angle_units: str or float, optional
        The units for angles in the provided Z-matrix, given either by
        name or as a numeric conversion factor.
        Defaults to degrees.
    defs: dict, str or list of str, optional
        If `zmat` contains symbols for bond distances, bending angles, and/or
        dihedral angles instead of numeric values, then the definition of
        those symbols should be passed to this function using this keyword
        argument.
        Note: The symbol definitions are typically printed adjacent to the
        Z-matrix itself, but this function will not automatically separate
        the symbol definitions from the Z-matrix.

    Returns:

    atoms: Atoms object
    """
    zmatrix = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # zmat should be a list containing the rows of the z-matrix.
    # for convenience, allow block strings and split at newlines
    # (or semicolons).
    if isinstance(zmat, str):
        zmat = _re_linesplit.split(zmat)

    for row in zmat:
        # A blank row (empty line, trailing separator, ...) carries no atom;
        # passing it on would fail when the row's first token is read.
        if not row.strip():
            continue
        zmatrix.add_row(row)

    return zmatrix.to_atoms()
245