1import collections
2import itertools
3import time
4from typing import Union
5
6import numpy as np
7
8from ..exceptions import ValidationError
9from ..models import AlignmentMill
10from ..physical_constants import constants
11from ..testing import compare_values
12from ..util import distance_matrix, linear_sum_assignment, random_rotation_matrix, uno, which_import
13
14
15def _nre(Z, geom):
16    """Nuclear repulsion energy"""
17
18    nre = 0.0
19    for at1 in range(geom.shape[0]):
20        for at2 in range(at1):
21            dist = np.linalg.norm(geom[at1] - geom[at2])
22            nre += Z[at1] * Z[at2] / dist
23    return nre
24
25
26def _pseudo_nre(Zhash, geom):
27    """Pseudo nuclear repulsion energy where non-physical Z contrived from `Zhash`."""
28
29    Zidx = list(set(sorted(Zhash)))
30    pZ = [Zidx.index(z) for z in Zhash]
31    return _nre(pZ, geom)
32
33
def B787(
    cgeom: np.ndarray,
    rgeom: np.ndarray,
    cuniq: np.ndarray,
    runiq: np.ndarray,
    do_plot: bool = False,
    verbose: int = 1,
    atoms_map: bool = False,
    run_resorting: bool = False,
    mols_align: Union[bool, float] = False,
    run_to_completion: bool = False,
    algorithm: str = "hungarian_uno",
    uno_cutoff: float = 1.0e-3,
    run_mirror: bool = False,
):
    r"""Use Kabsch algorithm to find best alignment of geometry `cgeom` onto
    `rgeom` while sampling atom mappings restricted by `runiq` and `cuniq`.

    Parameters
    ----------
    rgeom
        (nat, 3) array of reference/target/unchanged geometry. Assumed [a0]
        for RMSD purposes.
    cgeom
        (nat, 3) array of concern/changeable geometry. Assumed [a0] for RMSD
        purposes. Must have same nat, units, and atom content as rgeom.
    runiq
        (nat,) array of str indicating which rows (atoms) in `rgeom` are shuffleable
        without changing the molecule. Generally hashes of element symbol and
        mass are used, but could be as simple as ['C', 'H', 'H', 'D', 'H'] for
        monodeuterated methane.
    cuniq
        (nat,) array of str indicating which rows (atoms) in `cgeom` are shuffleable.
        See `runiq` for more details. Strings and count in `cuniq` must match
        `runiq`. That is, `sorted(cuniq) == sorted(runiq)`.
    do_plot
        Pops up a mpl plot showing before, after, and ref geometries.
    verbose
        Quantity of printing. 0 to silence.
    atoms_map
        Whether atom1 of rgeom already corresponds to atom1 of cgeom and so on.
        If `True`, no resorting will be run, parameters `runiq` and `cuniq`
        may be passed as `None`, and much time will be saved.
    run_resorting
        Run the resorting machinery even if unnecessary because `atoms_map=True`.
    mols_align
        Whether ref_mol and concern_mol have identical geometries by eye
        (barring orientation or atom mapping) and expected final RMSD = 0.
        If `True`, procedure is truncated when RMSD condition met, saving time.
        If float, convcrit at which search for minimum truncates.
    run_to_completion
        Run reorderings to completion (past RMSD = 0) even if unnecessary because
        `mols_align=True`. Used to test worst-case timings.
    algorithm
        {'hungarian_uno', 'permutative'}
        When `atoms_map=False`, screening algorithm for plausible atom mappings.
        `permutative` suitable only for small systems.
    uno_cutoff
        For `algorithm='hungarian_uno'`, threshold on the reduced Hungarian cost
        matrix below which a concern/reference atom pairing is admitted as an
        edge for Uno enumeration of equally good mappings. Larger values admit
        more candidate orderings.
    run_mirror
        Run alternate geometries potentially allowing best match to `rgeom`
        from mirror image of `cgeom`. Only run if system confirmed to
        be nonsuperimposable upon mirror reflection.

    Returns
    -------
    float, AlignmentMill
        First item is RMSD [A] between `rgeom` and the optimally aligned
        geometry computed.
        Second item is an AlignmentMill with fields
        (shift, rotation, atommap, mirror) that prescribe the transformation
        from `cgeom` to the optimally aligned geometry.

    Raises
    ------
    ValidationError
        If `rgeom` and `cgeom` are not same-shaped (nat, 3) arrays, if
        `runiq` and `cuniq` hold different multisets of labels, or if no
        candidate atom ordering could be generated.

    """
    # validation
    if rgeom.shape != cgeom.shape or rgeom.shape[1] != 3:
        raise ValidationError("""natom doesn't match: {} != {}""".format(rgeom.shape, cgeom.shape))
    nat = rgeom.shape[0]
    if atoms_map and runiq is None and cuniq is None:
        runiq = np.array([""] * nat)
        cuniq = np.array([""] * nat)
    if sorted(runiq) != sorted(cuniq):
        raise ValidationError("""atom subclasses unequal:\n  {}\n  {}""".format(runiq, cuniq))

    # Mirror image of `cgeom` through the xz-plane (y -> -y). Built once and reused
    #   both for the superimposability check and for every mirror trial below.
    mcgeom = None
    do_mirror_trials = False
    if run_mirror:
        mcgeom = np.copy(cgeom)
        mcgeom[:, 1] *= -1.0
        # use aligner to check if system and its mirror image are superimposible
        #   and hence whether it's worth doubling the number of Kabsch runs below
        exact = 1.0e-6
        mrmsd, _ = B787(
            mcgeom,
            cgeom,
            cuniq,
            cuniq,
            do_plot=False,
            verbose=0,
            atoms_map=False,
            mols_align=exact,
            run_mirror=False,
            uno_cutoff=0.1,
        )
        superimposable = mrmsd < exact
        do_mirror_trials = not superimposable
        if verbose >= 1 and superimposable:
            print(
                "Not testing for mirror-image matches (despite `run_mirror`) since system and its mirror are superimposable"
            )

    # initialization
    best_rmsd = np.inf  # [A]; unbounded so the first trial always registers
    ocount = 0
    hold_solution = None
    run_resorting = run_resorting or not atoms_map
    if mols_align is True:
        a_convergence = 1.0e-3
    elif mols_align is False:
        a_convergence = 0.0
    else:
        a_convergence = mols_align

    # initial presentation
    atomfmt2 = """  {} {:16.8f} {:16.8f} {:16.8f}"""

    if verbose >= 2:
        print("<<<  Reference:")
        for at, _ in enumerate(runiq):
            print(atomfmt2.format(runiq[at][:6], *rgeom[at]))

        print("<<<  Concern:")
        for at, _ in enumerate(cuniq):
            print(atomfmt2.format(cuniq[at][:6], *cgeom[at]))

    # start_rmsd is nonsense if not atoms_map
    start_rmsd = np.linalg.norm(cgeom - rgeom) * constants.bohr2angstroms / np.sqrt(nat)
    if verbose >= 1:
        print("Start RMSD = {:8.4f} [A] (naive)".format(start_rmsd))

    def _plausible_atom_orderings_wrapper(
        runiq, cuniq, rgeom, cgeom, run_resorting, algorithm="hungarian_uno", verbose=1, uno_cutoff=1.0e-3
    ):
        """Wrapper to _plausible_atom_orderings that bypasses it (`run_resorting=False`) when
        atoms of R & C known to be ordered. Easier to put logic here because _plausible is generator.

        """
        if run_resorting:
            return _plausible_atom_orderings(
                runiq, cuniq, rgeom, cgeom, algorithm=algorithm, verbose=verbose, uno_cutoff=uno_cutoff
            )
        else:
            return [np.arange(rgeom.shape[0])]

    def _kabsch_trial(trial_cgeom, atommap, mirror):
        """Kabsch-align rows `atommap` of `trial_cgeom` onto `rgeom`.

        Returns the AlignmentMill describing the transformation from the
        (unreflected) `cgeom` and the resulting RMSD [A], rounded to 8 decimals.

        """
        _, RR, TT = kabsch_align(rgeom, trial_cgeom[atommap, :], weight=None)
        solution = AlignmentMill(shift=TT, rotation=RR, atommap=atommap, mirror=mirror)
        # the mill applies any mirror reflection itself, so it is always fed the original cgeom
        tgeom = solution.align_coordinates(cgeom, reverse=False)
        if verbose >= 4:
            print("temp geom diff\n", tgeom - rgeom)
        rmsd = np.linalg.norm(tgeom - rgeom) * constants.bohr2angstroms / np.sqrt(rgeom.shape[0])
        return solution, np.around(rmsd, decimals=8)

    t0 = time.time()
    tc = 0.0
    converged = False
    for ordering in _plausible_atom_orderings_wrapper(
        runiq, cuniq, rgeom, cgeom, run_resorting, algorithm=algorithm, verbose=verbose, uno_cutoff=uno_cutoff
    ):
        npordd = np.asarray(ordering)
        # the mirror trial of an ordering is reported under the same label, suffixed "m"
        trial_label = ocount + 1
        trials = [(cgeom, False, " ")]
        if do_mirror_trials:
            trials.append((mcgeom, True, "m"))

        for trial_cgeom, mirror, tag in trials:
            t1 = time.time()
            ocount += 1
            temp_solution, temp_rmsd = _kabsch_trial(trial_cgeom, npordd, mirror)
            tc += time.time() - t1

            if temp_rmsd < best_rmsd:
                best_rmsd = temp_rmsd
                hold_solution = temp_solution
                if verbose >= 1:
                    print("<<<  trial {:8}{} {} yields RMSD {}  >>>".format(trial_label, tag, npordd, temp_rmsd))
                if not run_to_completion and best_rmsd < a_convergence:
                    converged = True
                    break
            elif verbose >= 3:
                print("     trial {:8}{} {} yields RMSD {}".format(trial_label, tag, npordd, temp_rmsd))

        if converged:
            break

    t3 = time.time()
    if verbose >= 1:
        print("Total time [s] for {:6} iterations: {:.3}".format(ocount, t3 - t0))
        print("Hungarian time [s] for atom ordering: {:.3}".format(t3 - t0 - tc))
        print("Kabsch time [s] for mol alignment:    {:.3}".format(tc))

    if hold_solution is None:
        raise ValidationError("""No candidate atom ordering could be generated to align concern onto reference""")

    ageom, auniq = hold_solution.align_mini_system(cgeom, cuniq, reverse=False)
    final_rmsd = np.linalg.norm(ageom - rgeom) * constants.bohr2angstroms / np.sqrt(nat)
    assert abs(best_rmsd - final_rmsd) < 1.0e-3

    if verbose >= 1:
        print("Final RMSD = {:8.4f} [A]".format(final_rmsd))
        print("Mirror match:", hold_solution.mirror)
        print(hold_solution)

    # final presentation & plotting
    if verbose >= 2:
        print("<<<  Aligned:")
        for at, _ in enumerate(auniq):
            print(atomfmt2.format(auniq[at][:6], *ageom[at]))
        print("<<<  Aligned Diff:")
        for at, _ in enumerate(auniq):
            print(atomfmt2.format(auniq[at][:6], *[ageom[at][i] - rgeom[at][i] for i in range(3)]))

    if do_plot:
        # TODO Missing import: `plot_coord` is not imported in this module, so do_plot=True raises NameError
        plot_coord(ref=rgeom, cand=ageom, orig=cgeom, comment="Final RMSD = {:8.4f}".format(final_rmsd))

    # sanity checks
    assert compare_values(
        _pseudo_nre(cuniq, cgeom),
        _pseudo_nre(auniq, ageom),
        "D: concern_mol-->returned_mol pNRE uncorrupted",
        atol=1.0e-4,
        quiet=(verbose < 2),
    )

    if mols_align is True:
        assert compare_values(
            _pseudo_nre(runiq, rgeom),
            _pseudo_nre(auniq, ageom),
            "D: concern_mol-->returned_mol pNRE matches ref_mol",
            atol=1.0e-4,
            quiet=(verbose < 2),
        )
        assert compare_values(
            rgeom, ageom, "D: concern_mol-->returned_mol geometry matches ref_mol", atol=1.0e-4, quiet=(verbose < 2)
        )
        assert compare_values(0.0, final_rmsd, "D: null RMSD", atol=1.0e-4, quiet=(verbose < 2))

    return final_rmsd, hold_solution
294
295
def _plausible_atom_orderings(ref, current, rgeom, cgeom, algorithm="hungarian_uno", verbose=1, uno_cutoff=1.0e-3):
    r"""Generate candidate atom orderings that map `current` atoms onto `ref` atoms.

    Atoms are only exchanged within a class of identical hashes. Candidate
    orderings are screened per class by `algorithm`, and the per-class
    candidates are then combined across classes in every combination.

    Parameters
    ----------
    ref : list
        Hashes encoding distinguishable non-coord characteristics of reference
        molecule. Namely, atomic symbol, mass, basis sets?.
    current : list
        Hashes encoding distinguishable non-coord characteristics of trial
        molecule. Namely, atomic symbol, mass, basis sets?.
    rgeom : ndarray
        (nat, 3) reference geometry, row-aligned with `ref`.
    cgeom : ndarray
        (nat, 3) trial geometry, row-aligned with `current`.
    algorithm : {'hungarian_uno', 'permutative'}
        Per-class screening algorithm.
    verbose : int
        Quantity of printing.
    uno_cutoff : float
        For 'hungarian_uno', reduced-cost threshold below which a
        (current, ref) atom pairing is admitted as an edge for Uno enumeration.

    Yields
    ------
    list
        `atpat` of length nat where `atpat[i]` is the index of the `current`
        atom assigned to reference atom `i`.

    Raises
    ------
    ValidationError
        If `ref` and `current` are not permutations of each other, or if
        `algorithm` is unrecognized.
    ModuleNotFoundError
        If `algorithm='hungarian_uno'` and networkx is not installed.

    Notes
    -----
    This is a generator, so errors are raised on first iteration, not at call time.

    """
    if sorted(ref) != sorted(current):
        raise ValidationError(
            """ref and current can't map to each other.\n""" + "R:  " + str(ref) + "\nC:  " + str(current)
        )

    where = collections.defaultdict(list)
    for iuq, uq in enumerate(ref):
        where[uq].append(iuq)

    cwhere = collections.defaultdict(list)
    for iuq, uq in enumerate(current):
        cwhere[uq].append(iuq)

    # reference-index group -> matching current-index group, one entry per hash class
    connect = collections.OrderedDict()
    for k in where:
        connect[tuple(where[k])] = tuple(cwhere[k])

    def filter_permutative(rgp, cgp):
        """Original atom ordering generator for like subset of atoms (e.g., all carbons).
        Relies on permutation. Filtering depends on similarity of structure (see `atol` parameter).
        Only suitable for total system size up to about 20 atoms.

        """
        if verbose >= 1:
            print("""Space:     {} <--> {}""".format(rgp, cgp))
        bnbn = [rrdistmat[first, second] for first, second in zip(rgp, rgp[1:])]
        for pm in itertools.permutations(cgp):
            cncn = [ccdistmat[first, second] for first, second in zip(pm, pm[1:])]
            if np.allclose(bnbn, cncn, atol=1.0):
                if verbose >= 1:
                    print("Candidate:", rgp, "<--", pm)
                yield pm

    def filter_hungarian_uno(rgp, cgp):
        """Hungarian algorithm on cost matrix based off headless (all Z same w/i space anyways) NRE.
        Having found _a_ solution and the reduced cost matrix, this still isn't likely to produce
        atom rearrangement fit for Kabsch b/c internal coordinate cost matrix doesn't nail down
        distance-equivalent atoms with different Cartesian coordinates like Cartesian-distance-matrix
        cost matrix does. So, form a bipartite graph from all essentially-zero connections between
        ref and concern and run Uno algorithm to enumerate them.

        """
        if verbose >= 1:
            print("""Space:     {} <--> {}""".format(rgp, cgp))

        # formulate cost matrix from internal (not Cartesian) layouts of R & C
        npcgp = np.array(cgp)
        submatCC = ccnremat[np.ix_(cgp, cgp)]
        submatRR = rrnremat[np.ix_(rgp, rgp)]
        sumCC = 100.0 * np.sum(submatCC, axis=0)  # cost mat small if not scaled, this way like Z=Neon
        sumRR = 100.0 * np.sum(submatRR, axis=0)
        cost = np.zeros((len(cgp), len(rgp)))
        for j in range(cost.shape[1]):
            for i in range(cost.shape[0]):
                cost[i, j] = (sumCC[i] - sumRR[j]) ** 2
        if verbose >= 2:
            print("Cost:\n", cost)
        costcopy = np.copy(cost)  # other one gets manipulated by hungarian call

        # find _a_ best match btwn R & C atoms through Kuhn-Munkres (Hungarian) algorithm
        # * linear_sum_assigment call is exactly like `scipy.optimize.linear_sum_assignment(cost)` only with extra return
        t00 = time.time()
        (row_ind, col_ind), reducedcost = linear_sum_assignment(cost, return_cost=True)
        ptsCR = list(zip(row_ind, col_ind))
        ptsCR = sorted(ptsCR, key=lambda tup: tup[1])
        sumCR = costcopy[row_ind, col_ind].sum()
        t01 = time.time()
        if verbose >= 2:
            print("Reduced cost:\n", reducedcost)
        if verbose >= 1:
            print("Hungarian time [s] for space:         {:.3}".format(t01 - t00))

        # find _all_ best matches btwn R & C atoms through Uno algorithm, seeded from Hungarian sol'n
        edges = np.argwhere(reducedcost < uno_cutoff)
        gooduns = uno(edges, ptsCR)
        t02 = time.time()
        if verbose >= 1:
            print("Uno time [s] for space:               {:.3}".format(t02 - t01))

        for gu in gooduns:
            gu2 = gu[:]
            gu2.sort(key=lambda x: x[1])  # resorts match into (r, c) = (info, range)
            subans = [p[0] for p in gu2]  # compacted to subans/lap format

            ans = tuple(npcgp[np.array(subans)])
            if verbose >= 3:
                print("Best Candidate ({:6.3}):".format(sumCR), rgp, "<--", ans, "     from", cgp, subans)
            yield ans

    if algorithm == "permutative":
        ccdistmat = distance_matrix(cgeom, cgeom)
        rrdistmat = distance_matrix(rgeom, rgeom)
        algofn = filter_permutative

    elif algorithm == "hungarian_uno":
        ccdistmat = distance_matrix(cgeom, cgeom)
        rrdistmat = distance_matrix(rgeom, rgeom)
        with np.errstate(divide="ignore"):
            ccnremat = np.reciprocal(ccdistmat)
            rrnremat = np.reciprocal(rrdistmat)
        # self-distances are zero -> reciprocal inf; zero them so they drop out of the sums
        ccnremat[ccnremat == np.inf] = 0.0
        rrnremat[rrnremat == np.inf] = 0.0
        algofn = filter_hungarian_uno

        # Ensure (optional dependency) networkx exists
        if not which_import("networkx", return_bool=True):
            raise ModuleNotFoundError(
                """Python module networkx not found. Solve by installing it: `conda install networkx` or `pip install networkx`"""
            )  # pragma: no cover

    else:
        raise ValidationError(
            f"""Unrecognized atom-ordering algorithm {algorithm!r}; expected 'hungarian_uno' or 'permutative'."""
        )

    # collect candidate atom orderings from algofn for each of the atom classes,
    #   recombine the classes with each other in every permutation (could maybe
    #   add Hungarian here, too) as generator back to permutation_kabsch
    rgroups = list(connect.keys())  # invariant across candidates, so build once
    for cpmut in itertools.product(*itertools.starmap(algofn, connect.items())):
        atpat = [None] * len(ref)
        for rgp, group in zip(rgroups, cpmut):
            for iidx, idx in enumerate(rgp):
                atpat[idx] = group[iidx]
        yield atpat
432
433
def kabsch_align(rgeom, cgeom, weight=None):
    r"""Finds optimal translation and rotation to align `cgeom` onto `rgeom` via
    Kabsch algorithm by minimizing the norm of the residual, || R - U * C ||.

    Both geometries are shifted to their (unweighted) centroids and scaled
    row-wise by sqrt(weight); the rotation comes from `kabsch_quaternion`
    and is applied to row vectors (`C.dot(U)`).

    Parameters
    ----------
    rgeom : ndarray of float
        (nat, 3) array of reference/target/unchanged geometry. Assumed [a0]
        for RMSD purposes.
    cgeom : ndarray of float
        (nat, 3) array of concern/changeable geometry. Assumed [a0] for RMSD
        purposes. Must have same Natom, units, and 1-to-1 atom ordering as rgeom.
    weight : ndarray of float
        (nat,) array of weights applied to `rgeom`. Note that definitions of
        weights (nothing to do with atom masses) are several, and I haven't
        seen one yet that can make centroid the center-of-mass and
        also make the RMSD match the usual mass-wtd-RMSD definition.
        Also, only one weight vector used rather than split btwn R & C,
        which may be invalid if not 1-to-1. Weighting is not recommended.

    Returns
    -------
    float, ndarray, ndarray
        First item is RMSD [A] between `rgeom` and the optimally aligned
        geometry computed.
        Second item is (3, 3) rotation matrix to optimal alignment.
        Third item is (3,) translation vector [a0] to optimal alignment.
        When the two geometries are already allclose, returns exactly
        (0.0, identity, zeros).

    Raises
    ------
    ValidationError
        If `weight` is neither None, a list, nor an ndarray.

    Sources
    -------
    Kabsch: Acta Cryst. (1978). A34, 827-828 http://journals.iucr.org/a/issues/1978/05/00/a15629/a15629.pdf
    C++ affine code: https://github.com/oleg-alexandrov/projects/blob/master/eigen/Kabsch.cpp
    weighted RMSD: http://www.amber.utah.edu/AMBER-workshop/London-2015/tutorial1/
    protein wRMSD code: https://pharmacy.umich.edu/sites/default/files/global_wrmsd_v8.3.py.txt
    quaternion: https://cnx.org/contents/HV-RsdwL@23/Molecular-Distance-Measures

    Author: dsirianni

    """
    natom = rgeom.shape[0]
    if weight is None:
        wts = np.ones((natom))
    elif isinstance(weight, (list, np.ndarray)):
        wts = np.asarray(weight)
    else:
        raise ValidationError(f"""Unrecognized argument type {type(weight)} for kwarg 'weight'.""")

    if np.allclose(rgeom, cgeom):
        # can hit a mixed non-identity translation/rotation, so head off
        return 0.0, np.identity(3), np.zeros(3)

    r_center = rgeom.sum(axis=0) / natom
    c_center = cgeom.sum(axis=0) / natom
    row_scale = np.sqrt(wts)[:, np.newaxis]
    r_work = np.subtract(rgeom, r_center) * row_scale
    c_work = np.subtract(cgeom, c_center) * row_scale

    rotation = kabsch_quaternion(c_work.T, r_work.T)  # U
    translation = c_center - rotation.dot(r_center)

    residual = r_work - c_work.dot(rotation)
    rmsd = np.linalg.norm(residual) * constants.bohr2angstroms / np.sqrt(np.sum(wts))
    return rmsd, rotation, translation
502
503
def kabsch_quaternion(P, Q):
    """Computes the optimal rotation matrix U relating point sets P and Q,
    using the unit quaternion formulation of the Kabsch algorithm.

    The rotation is built from the eigenvector of the largest eigenvalue of
    the symmetric 4x4 matrix F assembled from the covariance Q @ P.T. The
    eigenvector is unit length, so U is a proper rotation (det = +1).

    Convention: U acts on *row* vectors. It minimizes || Q.T - P.T @ U ||,
    equivalently || Q - U.T @ P ||. This is how `kabsch_align` applies it,
    as `C.dot(U)`. P and Q are used as given (not re-centered), so pass
    centroid-shifted coordinates for the usual Kabsch fit.

    Arguments:
    <np.ndarray> P := 3xN array. 3=dimension of space, N=number of points.
    <np.ndarray> Q := 3xN array. 3=dimension of space, N=number of points.

    Returns:
    <np.ndarray> U := Optimal 3x3 rotation matrix such that P.T @ U best matches Q.T.

    Author: dsirianni

    """
    # Form covariance matrix
    cov = Q.dot(P.T)

    # Form the quaternion transformation matrix F
    F = np.zeros((4, 4))
    # diagonal
    F[0, 0] = cov[0, 0] + cov[1, 1] + cov[2, 2]
    F[1, 1] = cov[0, 0] - cov[1, 1] - cov[2, 2]
    F[2, 2] = -cov[0, 0] + cov[1, 1] - cov[2, 2]
    F[3, 3] = -cov[0, 0] - cov[1, 1] + cov[2, 2]
    # Upper & lower triangle
    F[1, 0] = F[0, 1] = cov[1, 2] - cov[2, 1]
    F[2, 0] = F[0, 2] = cov[2, 0] - cov[0, 2]
    F[3, 0] = F[0, 3] = cov[0, 1] - cov[1, 0]
    F[2, 1] = F[1, 2] = cov[0, 1] + cov[1, 0]
    F[3, 1] = F[1, 3] = cov[0, 2] + cov[2, 0]
    F[3, 2] = F[2, 3] = cov[1, 2] + cov[2, 1]

    # eigh returns eigenvalues in ascending order, so the last column is the
    #   eigenvector of the largest eigenvalue (maximal overlap)
    _, ev = np.linalg.eigh(F)

    # Construct optimal rotation matrix from leading ev
    q = ev[:, -1]
    U = np.zeros((3, 3))

    U[0, 0] = q[0] ** 2 + q[1] ** 2 - q[2] ** 2 - q[3] ** 2
    U[0, 1] = 2 * (q[1] * q[2] - q[0] * q[3])
    U[0, 2] = 2 * (q[1] * q[3] + q[0] * q[2])
    U[1, 0] = 2 * (q[1] * q[2] + q[0] * q[3])
    U[1, 1] = q[0] ** 2 - q[1] ** 2 + q[2] ** 2 - q[3] ** 2
    U[1, 2] = 2 * (q[2] * q[3] - q[0] * q[1])
    U[2, 0] = 2 * (q[1] * q[3] - q[0] * q[2])
    U[2, 1] = 2 * (q[2] * q[3] + q[0] * q[1])
    U[2, 2] = q[0] ** 2 - q[1] ** 2 - q[2] ** 2 + q[3] ** 2

    return U
555
556
def compute_scramble(nat, do_resort=True, do_shift=True, do_rotate=True, deflection=1.0, do_mirror=False):
    r"""Generate a random or directed translation, rotation, and atom shuffling.

    Each of the three components is resolved independently: literal `True`
    draws a random value, literal `False` gives the identity, and anything
    else is taken as an explicit value whose shape is asserted. Random draws
    happen in the order atom shuffle, then shift, then rotation, so a seeded
    `np.random` state reproduces the same scramble.

    Parameters
    ----------
    nat : int
        Number of atoms for which to prepare an atom mapping.
    do_resort : bool or array-like, optional
        Whether to randomly shuffle atoms (`True`) or leave 1st atom 1st, etc. (`False`)
        or shuffle according to specified (nat, ) indices (e.g., [2, 1, 0])
    do_shift : bool or array-like, optional
        Whether to generate a random atom shift on interval [-3, 3) in each
        dimension (`True`) or leave at current origin (`False`) or shift along
        specified (3, ) vector (e.g., np.array([0., 1., -1.])).
    do_rotate : bool or array-like, optional
        Whether to generate a random 3D rotation according to algorithm of Arvo (`True`)
        or leave at current orientation (`False`) or rotate with specified (3, 3) matrix.
    deflection : float, optional
        If `do_rotate`, how random a rotation: 0.0 is no change, 0.1 is small
        perturbation, 1.0 is completely random.
    do_mirror : bool, optional
        Whether to set mirror reflection instruction. Changes identity of
        molecule so off by default.

    Returns
    -------
    AlignmentMill
        With fields (shift, rotation, atommap, mirror) as requested:
        identity, random, or specified.

    """

    def _resolve(option, make_random, make_identity, shape):
        # `is True` / `is False` on purpose: array-likes must fall through to the explicit branch
        if option is True:
            return make_random()
        if option is False:
            return make_identity()
        explicit = np.array(option)
        assert explicit.shape == shape
        return explicit

    def _shuffled_indices():
        order = np.arange(nat)
        np.random.shuffle(order)
        return order

    atommap = _resolve(do_resort, _shuffled_indices, lambda: np.arange(nat), (nat,))
    shift = _resolve(do_shift, lambda: 6 * np.random.random_sample((3,)) - 3, lambda: np.zeros((3,)), (3,))
    rotation = _resolve(
        do_rotate, lambda: random_rotation_matrix(deflection=deflection), lambda: np.identity(3), (3, 3)
    )

    return AlignmentMill(shift=shift, rotation=rotation, atommap=atommap, mirror=do_mirror)
615