1"""
2This module contains various support/utility functions.
3"""
4
5from typing import List, Tuple
6import numpy as np
7
8from ase import Atoms
9from ase.calculators.singlepoint import SinglePointCalculator
10from ase.geometry import find_mic
11from ase.geometry import get_distances
12from ase.neighborlist import neighbor_list
13from .cluster_space import ClusterSpace
14from .force_constants import ForceConstants
15from .input_output.logging_tools import logger
16
17
# Child of the logger imported from .input_output.logging_tools; this module
# emits its messages (e.g. the shell warning in get_neighbor_shells) through it.
logger = logger.getChild('utilities')
19
20
def get_displacements(atoms: Atoms,
                      atoms_ideal: Atoms,
                      cell_tol: float = 1e-4) -> np.ndarray:
    """Returns the smallest possible displacements of a displaced
    configuration relative to an ideal (reference) configuration.

    The raw position differences are mapped onto their minimum-image
    equivalents in the cell of ``atoms_ideal``.

    Notes
    -----
    * uses :func:`ase.geometry.find_mic`
    * assumes periodic boundary conditions in all directions

    Parameters
    ----------
    atoms
        configuration with displaced atoms
    atoms_ideal
        ideal configuration relative to which displacements are computed
    cell_tol
        cell tolerance; if the norm of the difference between the two cells
        exceeds this value an error is raised

    Returns
    -------
    minimum-image displacement vectors, one per atom

    Raises
    ------
    ValueError
        if the atomic numbers or the cells of the two configurations do not
        match
    """
    same_species = np.array_equal(atoms.numbers, atoms_ideal.numbers)
    if not same_species:
        raise ValueError('Atomic numbers do not match.')

    cell_mismatch = np.linalg.norm(atoms.cell - atoms_ideal.cell)
    if cell_mismatch > cell_tol:
        raise ValueError('Cells do not match.')

    # find_mic returns the wrapped vectors first; their lengths are not needed
    delta = atoms.positions - atoms_ideal.positions
    return find_mic(delta, atoms_ideal.cell, pbc=True)[0]
50
51
def prepare_structure(atoms: Atoms,
                      atoms_ideal: Atoms,
                      calc: SinglePointCalculator = None) -> Atoms:
    """Prepare a structure in the format suitable for a
    :class:`StructureContainer <hiphive.StructureContainer>`.

    Forces are taken from the first available of the following sources:

    1. a ``forces`` array already attached to ``atoms``
    2. the calculator ``calc``, if one is given
    3. a :class:`SinglePointCalculator` attached to ``atoms``

    The atoms are then reordered to match ``atoms_ideal`` (see
    :func:`find_permutation`), the displacements relative to ``atoms_ideal``
    are stored in the ``displacements`` array, and the positions are reset to
    those of ``atoms_ideal``.

    Parameters
    ----------
    atoms
        input structure
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces

    Returns
    -------
    ASE atoms object
        prepared ASE atoms object with forces and displacements as arrays

    Raises
    ------
    ValueError
        if forces cannot be obtained from any of the sources listed above
    """

    # get forces
    if 'forces' in atoms.arrays:
        forces = atoms.get_array('forces')
    elif calc is not None:
        # attach the calculator to a copy so the input object is left untouched
        atoms_tmp = atoms.copy()
        atoms_tmp.calc = calc
        forces = atoms_tmp.get_forces()
    elif isinstance(atoms.calc, SinglePointCalculator):
        forces = atoms.get_forces()
    else:
        raise ValueError('No forces found: attach a forces array or a '
                         'SinglePointCalculator to the structure, or provide '
                         'a calculator.')

    # reorder the atoms so that atom k matches site k of atoms_ideal;
    # indexing an Atoms object already returns a new object
    perm = find_permutation(atoms, atoms_ideal)
    atoms_new = atoms[perm]
    atoms_new.arrays['forces'] = forces[perm]
    disps = get_displacements(atoms_new, atoms_ideal)
    atoms_new.arrays['displacements'] = disps
    atoms_new.positions = atoms_ideal.positions

    return atoms_new
93
94
def prepare_structures(structures: List[Atoms],
                       atoms_ideal: Atoms,
                       calc: SinglePointCalculator = None) -> List[Atoms]:
    """Prepares a set of structures in the format suitable for adding them to
    a :class:`StructureContainer <hiphive.StructureContainer>`.

    `structures` should represent a list of supercells with displacements
    while `atoms_ideal` should provide the ideal reference structure (without
    displacements) for the given structures.

    The structures that are returned will have their positions reset to the
    ideal structures. Displacements and forces will be added as arrays to the
    atoms objects.

    For each structure the forces are taken from the first available source:

    1. a ``forces`` array already attached to the structure
    2. the calculator `calc`, if one is provided
    3. an ASE `SinglePointCalculator <ase.calculators.singlepoint>` object
       attached to the structure

    Hence a provided calculator is only used for structures that do not
    already carry a ``forces`` array.

    Example
    -------

    The following example illustrates the use of this function::

        db = connect('dft_training_structures.db')
        training_structures = [row.toatoms() for row in db.select()]
        training_structures = prepare_structures(training_structures, atoms_ideal)
        for s in training_structures:
            sc.add_structure(s)

    Parameters
    ----------
    structures
        list of input displaced structures
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces

    Returns
    -------
    list of prepared structures with forces and displacements as arrays
    """
    # errors raised while preparing an individual structure propagate as-is
    return [prepare_structure(structure, atoms_ideal, calc)
            for structure in structures]
142
143
def find_permutation(atoms: Atoms,
                     atoms_ref: Atoms) -> List[int]:
    """Returns the best permutation of atoms for mapping one
    configuration onto another.

    Each site of ``atoms_ref`` is assigned the atom of ``atoms`` closest to
    it, with distances evaluated under periodic boundary conditions in the
    cell of ``atoms_ref``.

    Parameters
    ----------
    atoms
        configuration to be permuted
    atoms_ref
        configuration onto which to map

    Returns
    -------
    list of indices into ``atoms``, one per atom of ``atoms_ref``

    Raises
    ------
    ValueError
        if the cells differ by 1e-6 or more (norm of the difference), if two
        reference sites are mapped onto the same atom, or if a mapped atom is
        of a different species than its reference site

    Examples
    --------
    After obtaining the permutation via ``p = find_permutation(atoms1, atoms2)``
    the reordered structure ``atoms1[p]`` will give the closest match
    to ``atoms2``.
    """
    # explicit check instead of assert so it is not stripped under python -O
    if not np.linalg.norm(atoms.cell - atoms_ref.cell) < 1e-6:
        raise ValueError('Cells do not match.')

    # dist_matrix[k, i]: distance between atom k of atoms and site i of atoms_ref
    dist_matrix = get_distances(
        atoms.positions, atoms_ref.positions, cell=atoms_ref.cell, pbc=True)[1]
    # for every reference site (column) pick the closest atom
    permutation = [int(np.argmin(column)) for column in dist_matrix.T]

    if len(set(permutation)) != len(permutation):
        raise ValueError('Duplicates in permutation')
    # atomic numbers identify the species one-to-one, like chemical symbols
    if np.any(atoms.numbers[permutation] != atoms_ref.numbers):
        raise ValueError('Matching lattice sites have different occupation')
    return permutation
176
177
class Shell:
    """
    Neighbor shell, i.e. a pair of atomic types together with an interatomic
    distance and the number of pairs found at that distance.

    Parameters
    ----------
    types : list or tuple
        atomic types for neighbor shell
    distance : float
        interatomic distance for neighbor shell
    count : int
        number of pairs in the neighbor shell
    """

    def __init__(self,
                 types: List[str],
                 distance: float,
                 count: int = 0):
        self.count = count
        self.distance = distance
        self.types = types

    def __str__(self):
        template = '{}-{}   distance: {:10.6f}    count: {}'
        fields = (*self.types, self.distance, self.count)
        return template.format(*fields)

    __repr__ = __str__
205
206
def get_neighbor_shells(atoms: Atoms,
                        cutoff: float,
                        dist_tol: float = 1e-5) -> List[Shell]:
    """Returns a list of neighbor shells.

    All pair distances up to `cutoff` are processed in order of increasing
    distance. Each pair is assigned to an existing shell with the same
    (sorted) pair of chemical symbols whose distance lies within `dist_tol`
    of the pair distance; otherwise a new shell is started at that distance.
    The distance of a shell is therefore the smallest pair distance assigned
    to it, and its count is the number of (i, j) entries returned by
    :func:`ase.neighborlist.neighbor_list` that fall into it.

    A warning is logged if two shells with the same types lie within
    ``2 * dist_tol`` of each other.

    Parameters
    ----------
    atoms
        configuration used for finding shells
    cutoff
        exclude neighbor shells which have a distance larger than this value
    dist_tol
        distance tolerance

    Returns
    -------
    shells sorted by distance, then types, then count
    """

    # get distances, sorted in increasing order
    ijd = neighbor_list('ijd', atoms, cutoff)
    pairs = sorted(zip(*ijd), key=lambda x: x[2])

    # sort into shells
    symbols = atoms.get_chemical_symbols()
    shells = []
    # Pairs arrive in increasing distance, so a pair can only be within
    # dist_tol of the most recently opened shell of its types: any older
    # shell of the same types is at least dist_tol below that newer one.
    latest_shell = {}
    for i, j, d in pairs:
        types = tuple(sorted([symbols[i], symbols[j]]))
        shell = latest_shell.get(types)
        if shell is not None and abs(d - shell.distance) < dist_tol:
            shell.count += 1
        else:
            shell = Shell(types, d, 1)
            shells.append(shell)
            latest_shell[types] = shell
    shells.sort(key=lambda x: (x.distance, x.types, x.count))

    # warning if two shells of the same types are within 2 * tol
    for i, s1 in enumerate(shells):
        for s2 in shells[i + 1:]:
            if s1.types != s2.types:
                continue
            if not s1.distance < s2.distance - 2 * dist_tol:
                logger.warning('Found two shells within 2 * dist_tol')

    return shells
260
261
def extract_parameters(fcs: ForceConstants,
                       cs: ClusterSpace) -> np.ndarray:
    """Extracts parameters from force constants.

    This function can be used to extract parameters to create a
    ForceConstantPotential from a known set of force constants.
    The parameters are the least-squares solution computed by NumPy's
    `lstsq function
    <https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html>`_
    for the arguments returned by ``ForceConstantModel.get_fcs_sensing``
    (presumably a sensing matrix and the corresponding target values).

    Parameters
    ----------
    fcs
        force constants
    cs
        cluster space

    Returns
    -------
    x : {(N,), (N, K)} ndarray
        Least-squares solution. Only the solution is returned; the residuals,
        rank and singular values also computed by ``lstsq`` are discarded.
    """
    # TODO: Rename this function with more explanatory name?
    # function-scope import, presumably to avoid a circular import -- confirm
    from .force_constant_model import ForceConstantModel
    fcm = ForceConstantModel(fcs.supercell, cs)
    # get_fcs_sensing's return value is unpacked into lstsq's positional args
    return np.linalg.lstsq(*fcm.get_fcs_sensing(fcs), rcond=None)[0]
299