import collections
import itertools
import time
from typing import Union

import numpy as np

from ..exceptions import ValidationError
from ..models import AlignmentMill
from ..physical_constants import constants
from ..testing import compare_values
from ..util import distance_matrix, linear_sum_assignment, random_rotation_matrix, uno, which_import


def _nre(Z, geom):
    """Nuclear repulsion energy.

    Parameters
    ----------
    Z
        Per-atom charges, indexable by atom index.
    geom
        (nat, 3) array of Cartesian coordinates.

    Returns
    -------
    float
        Sum over unique atom pairs of ``Z[i] * Z[j] / |r_i - r_j|``.

    """
    nre = 0.0
    for at1 in range(geom.shape[0]):
        for at2 in range(at1):
            dist = np.linalg.norm(geom[at1] - geom[at2])
            nre += Z[at1] * Z[at2] / dist
    return nre


def _pseudo_nre(Zhash, geom):
    """Pseudo nuclear repulsion energy where non-physical Z contrived from `Zhash`.

    Each distinct label in `Zhash` is assigned its rank among the sorted
    distinct labels, and that rank is used as the atom's pseudo-charge.

    """
    # Sort *after* de-duplicating: iteration order of a set is arbitrary, so
    # sorting before building the set does not make the ranks reproducible.
    rank = {z: idx for idx, z in enumerate(sorted(set(Zhash)))}
    pZ = [rank[z] for z in Zhash]
    return _nre(pZ, geom)


def B787(
    cgeom: np.ndarray,
    rgeom: np.ndarray,
    cuniq: np.ndarray,
    runiq: np.ndarray,
    do_plot: bool = False,
    verbose: int = 1,
    atoms_map: bool = False,
    run_resorting: bool = False,
    mols_align: Union[bool, float] = False,
    run_to_completion: bool = False,
    algorithm: str = "hungarian_uno",
    uno_cutoff: float = 1.0e-3,
    run_mirror: bool = False,
):
    r"""Use Kabsch algorithm to find best alignment of geometry `cgeom` onto
    `rgeom` while sampling atom mappings restricted by `runiq` and `cuniq`.

    Parameters
    ----------
    rgeom
        (nat, 3) array of reference/target/unchanged geometry. Assumed [a0]
        for RMSD purposes.
    cgeom
        (nat, 3) array of concern/changeable geometry. Assumed [a0] for RMSD
        purposes. Must have same nat, units, and atom content as rgeom.
    runiq
        (nat,) array of str indicating which rows (atoms) in `rgeom` are shuffleable
        without changing the molecule. Generally hashes of element symbol and
        mass are used, but could be as simple as ['C', 'H', 'H', 'D', 'H'] for
        monodeuterated methane.
    cuniq
        (nat,) array of str indicating which rows (atoms) in `cgeom` are shuffleable.
        See `runiq` for more details. Strings and count in `cuniq` must match
        `runiq`. That is, `sorted(cuniq) == sorted(runiq)`.
    do_plot
        Pops up a mpl plot showing before, after, and ref geometries.
    verbose
        Quantity of printing. 0 to silence.
    atoms_map
        Whether atom1 of rgeom already corresponds to atom1 of cgeom and so on.
        If `True`, no resorting will be run, parameters `runiq` and `cuniq`
        may be passed as `None`, and much time will be saved.
    run_resorting
        Run the resorting machinery even if unnecessary because `atoms_map=True`.
    mols_align
        Whether ref_mol and concern_mol have identical geometries by eye
        (barring orientation or atom mapping) and expected final RMSD = 0.
        If `True`, procedure is truncated when RMSD condition met, saving time.
        If float, convcrit at which search for minimum truncates.
    run_to_completion
        Run reorderings to completion (past RMSD = 0) even if unnecessary because
        `mols_align=True`. Used to test worst-case timings.
    algorithm
        {'hungarian_uno', 'permutative'}
        When `atoms_map=False`, screening algorithm for plausible atom mappings.
        `permutative` suitable only for small systems.
    uno_cutoff
        For `algorithm='hungarian_uno'`, entries of the reduced cost matrix
        below this value are taken as the edges handed to the Uno algorithm
        when enumerating alternate atom mappings.
    run_mirror
        Run alternate geometries potentially allowing best match to `rgeom`
        from mirror image of `cgeom`. Only run if system confirmed to
        be nonsuperimposable upon mirror reflection.

    Returns
    -------
    float, AlignmentMill
        First item is RMSD [A] between `rgeom` and the optimally aligned
        geometry computed.
        Second item is a AlignmentMill with fields
        (shift, rotation, atommap, mirror) that prescribe the transformation
        from `cgeom` and the optimally aligned geometry.

    Raises
    ------
    ValidationError
        If the geometries differ in shape or are not (nat, 3), if the atom
        subclasses of `runiq` and `cuniq` differ, or if no trial alignment
        could be accepted.

    """
    # validation
    if rgeom.shape != cgeom.shape or rgeom.shape[1] != 3:
        raise ValidationError("""natom doesn't match: {} != {}""".format(rgeom.shape, cgeom.shape))
    nat = rgeom.shape[0]
    if atoms_map and runiq is None and cuniq is None:
        runiq = np.array([""] * nat)
        cuniq = np.array([""] * nat)
    if sorted(runiq) != sorted(cuniq):
        raise ValidationError("""atom subclasses unequal:\n {}\n {}""".format(runiq, cuniq))

    check_mirror = False
    if run_mirror:
        # use aligner to check if system and its (xz-plane) mirror image are
        # superimposable and hence whether its worth doubling the number of Kabsch
        # runs below to check for mirror-image matches
        mcgeom = np.copy(cgeom)
        mcgeom[:, 1] *= -1.0
        exact = 1.0e-6
        mrmsd, msolution = B787(
            mcgeom,
            cgeom,
            cuniq,
            cuniq,
            do_plot=False,
            verbose=0,
            atoms_map=False,
            mols_align=exact,
            run_mirror=False,
            uno_cutoff=0.1,
        )
        superimposable = mrmsd < exact
        check_mirror = not superimposable
        if verbose >= 1 and superimposable:
            print(
                "Not testing for mirror-image matches (despite `run_mirror`) since system and its mirror are superimposable"
            )

    # initialization
    best_rmsd = 100.0  # [A]
    ocount = 0
    hold_solution = None
    run_resorting = run_resorting or not atoms_map
    if mols_align is True:
        a_convergence = 1.0e-3
    elif mols_align is False:
        a_convergence = 0.0
    else:
        a_convergence = mols_align

    # initial presentation
    atomfmt2 = """ {} {:16.8f} {:16.8f} {:16.8f}"""

    if verbose >= 2:
        print("<<< Reference:")
        for at, hsh in enumerate(runiq):
            print(atomfmt2.format(hsh[:6], *rgeom[at]))

        print("<<< Concern:")
        for at, hsh in enumerate(cuniq):
            print(atomfmt2.format(hsh[:6], *cgeom[at]))

    # start_rmsd is nonsense if not atoms_map
    start_rmsd = np.linalg.norm(cgeom - rgeom) * constants.bohr2angstroms / np.sqrt(nat)
    if verbose >= 1:
        print("Start RMSD = {:8.4f} [A] (naive)".format(start_rmsd))

    def _plausible_atom_orderings_wrapper(
        runiq, cuniq, rgeom, cgeom, run_resorting, algorithm="hungarian_uno", verbose=1, uno_cutoff=1.0e-3
    ):
        """Wrapper to _plausible_atom_orderings that bypasses it (`run_resorting=False`) when
        atoms of R & C known to be ordered. Easier to put logic here because _plausible is generator.

        """
        if run_resorting:
            return _plausible_atom_orderings(
                runiq, cuniq, rgeom, cgeom, algorithm=algorithm, verbose=verbose, uno_cutoff=uno_cutoff
            )
        else:
            return [np.arange(rgeom.shape[0])]

    # The mirrored concern geometry is independent of the atom ordering, so
    # build it once rather than on every trial.
    if check_mirror:
        icgeom = np.copy(cgeom)
        icgeom[:, 1] *= -1.0

    t0 = time.time()
    tc = 0.0  # accumulated time spent in Kabsch alignment (vs. generating orderings)
    for ordering in _plausible_atom_orderings_wrapper(
        runiq, cuniq, rgeom, cgeom, run_resorting, algorithm=algorithm, verbose=verbose, uno_cutoff=uno_cutoff
    ):
        t1 = time.time()
        ocount += 1
        npordd = np.asarray(ordering)
        _, RR, TT = kabsch_align(rgeom, cgeom[npordd, :], weight=None)

        temp_solution = AlignmentMill(shift=TT, rotation=RR, atommap=npordd, mirror=False)
        tgeom = temp_solution.align_coordinates(cgeom, reverse=False)
        if verbose >= 4:
            print("temp geom diff\n", tgeom - rgeom)
        temp_rmsd = np.linalg.norm(tgeom - rgeom) * constants.bohr2angstroms / np.sqrt(rgeom.shape[0])
        temp_rmsd = np.around(temp_rmsd, decimals=8)
        t2 = time.time()
        tc += t2 - t1

        if temp_rmsd < best_rmsd:
            best_rmsd = temp_rmsd
            hold_solution = temp_solution
            if verbose >= 1:
                print("<<< trial {:8} {} yields RMSD {} >>>".format(ocount, npordd, temp_rmsd))
            if not run_to_completion and best_rmsd < a_convergence:
                break
        else:
            if verbose >= 3:
                print(" trial {:8} {} yields RMSD {}".format(ocount, npordd, temp_rmsd))

        if check_mirror:
            t1 = time.time()
            ocount += 1
            _, RR, TT = kabsch_align(rgeom, icgeom[npordd, :], weight=None)

            temp_solution = AlignmentMill(shift=TT, rotation=RR, atommap=npordd, mirror=True)
            tgeom = temp_solution.align_coordinates(cgeom, reverse=False)
            if verbose >= 4:
                print("temp geom diff\n", tgeom - rgeom)
            temp_rmsd = np.linalg.norm(tgeom - rgeom) * constants.bohr2angstroms / np.sqrt(rgeom.shape[0])
            temp_rmsd = np.around(temp_rmsd, decimals=8)
            t2 = time.time()
            tc += t2 - t1

            if temp_rmsd < best_rmsd:
                best_rmsd = temp_rmsd
                hold_solution = temp_solution
                if verbose >= 1:
                    # mirror trial is reported under the index of its unmirrored partner
                    print("<<< trial {:8}m {} yields RMSD {} >>>".format(ocount - 1, npordd, temp_rmsd))
                if not run_to_completion and best_rmsd < a_convergence:
                    break
            else:
                if verbose >= 3:
                    print(" trial {:8}m {} yields RMSD {}".format(ocount - 1, npordd, temp_rmsd))

    t3 = time.time()
    if verbose >= 1:
        print("Total time [s] for {:6} iterations: {:.3}".format(ocount, t3 - t0))
        print("Hungarian time [s] for atom ordering: {:.3}".format(t3 - t0 - tc))
        print("Kabsch time [s] for mol alignment: {:.3}".format(tc))

    if hold_solution is None:
        # No trial beat the initial `best_rmsd` (e.g., no orderings generated or
        # non-finite RMSDs); fail clearly rather than on a `None` attribute access.
        raise ValidationError(
            """No acceptable alignment found among {} trial(s) (RMSD threshold {} [A]).""".format(ocount, best_rmsd)
        )

    ageom, auniq = hold_solution.align_mini_system(cgeom, cuniq, reverse=False)
    final_rmsd = np.linalg.norm(ageom - rgeom) * constants.bohr2angstroms / np.sqrt(nat)
    assert abs(best_rmsd - final_rmsd) < 1.0e-3

    if verbose >= 1:
        print("Final RMSD = {:8.4f} [A]".format(final_rmsd))
        print("Mirror match:", hold_solution.mirror)
        print(hold_solution)

    # final presentation & plotting
    if verbose >= 2:
        print("<<< Aligned:")
        for at, hsh in enumerate(auniq):
            print(atomfmt2.format(hsh[:6], *ageom[at]))
        print("<<< Aligned Diff:")
        for at, hsh in enumerate(auniq):
            print(atomfmt2.format(hsh[:6], *[ageom[at][i] - rgeom[at][i] for i in range(3)]))

    if do_plot:
        # TODO Missing import
        # NOTE(review): `plot_coord` is not defined or imported in this module,
        # so `do_plot=True` raises NameError until its source is identified.
        plot_coord(ref=rgeom, cand=ageom, orig=cgeom, comment="Final RMSD = {:8.4f}".format(final_rmsd))

    # sanity checks
    assert compare_values(
        _pseudo_nre(cuniq, cgeom),
        _pseudo_nre(auniq, ageom),
        "D: concern_mol-->returned_mol pNRE uncorrupted",
        atol=1.0e-4,
        quiet=(verbose < 2),
    )

    if mols_align is True:
        assert compare_values(
            _pseudo_nre(runiq, rgeom),
            _pseudo_nre(auniq, ageom),
            "D: concern_mol-->returned_mol pNRE matches ref_mol",
            atol=1.0e-4,
            quiet=(verbose < 2),
        )
        assert compare_values(
            rgeom, ageom, "D: concern_mol-->returned_mol geometry matches ref_mol", atol=1.0e-4, quiet=(verbose < 2)
        )
        assert compare_values(0.0, final_rmsd, "D: null RMSD", atol=1.0e-4, quiet=(verbose < 2))

    return final_rmsd, hold_solution


def _plausible_atom_orderings(ref, current, rgeom, cgeom, algorithm="hungarian_uno", verbose=1, uno_cutoff=1.0e-3):
    r"""Generate candidate orderings of `current` atoms that could map onto `ref`.

    Atoms are grouped by identical hash; candidate permutations are produced
    per group by the selected `algorithm` and then combined across groups.

    Parameters
    ----------
    ref : list
        Hashes encoding distinguishable non-coord characteristics of reference
        molecule. Namely, atomic symbol, mass, basis sets?.
    current : list
        Hashes encoding distinguishable non-coord characteristics of trial
        molecule. Namely, atomic symbol, mass, basis sets?.
    rgeom
        Reference geometry, passed to `distance_matrix`.
    cgeom
        Trial geometry, passed to `distance_matrix`.
    algorithm : {'hungarian_uno', 'permutative'}
        Per-group candidate generator.
    verbose : int
        Quantity of printing.
    uno_cutoff : float
        Reduced-cost entries below this value become edges for the Uno
        enumeration (`hungarian_uno` only).

    Yields
    ------
    list
        Length ``len(ref)``; element ``i`` is the index into `current` proposed
        to correspond to reference atom ``i``.

    Raises
    ------
    ValidationError
        If `ref` and `current` do not contain the same hashes, or `algorithm`
        is not recognized.
    ModuleNotFoundError
        If `algorithm='hungarian_uno'` and networkx is not importable.

    """
    if sorted(ref) != sorted(current):
        raise ValidationError(
            """ref and current can't map to each other.\n""" + "R: " + str(ref) + "\nC: " + str(current)
        )

    where = collections.defaultdict(list)
    for iuq, uq in enumerate(ref):
        where[uq].append(iuq)

    cwhere = collections.defaultdict(list)
    for iuq, uq in enumerate(current):
        cwhere[uq].append(iuq)

    # maps (reference indices of an atom class) -> (concern indices of the same class)
    connect = collections.OrderedDict()
    for k in where:
        connect[tuple(where[k])] = tuple(cwhere[k])

    def filter_permutative(rgp, cgp):
        """Original atom ordering generator for like subset of atoms (e.g., all carbons).
        Relies on permutation. Filtering depends on similarity of structure (see `atol` parameter).
        Only suitable for total system size up to about 20 atoms.

        """
        if verbose >= 1:
            print("""Space: {} <--> {}""".format(rgp, cgp))
        bnbn = [rrdistmat[first, second] for first, second in zip(rgp, rgp[1:])]
        for pm in itertools.permutations(cgp):
            cncn = [ccdistmat[first, second] for first, second in zip(pm, pm[1:])]
            if np.allclose(bnbn, cncn, atol=1.0):
                if verbose >= 1:
                    print("Candidate:", rgp, "<--", pm)
                yield pm

    def filter_hungarian_uno(rgp, cgp):
        """Hungarian algorithm on cost matrix based off headless (all Z same w/i space anyways) NRE.
        Having found _a_ solution and the reduced cost matrix, this still isn't likely to produce
        atom rearrangement fit for Kabsch b/c internal coordinate cost matrix doesn't nail down
        distance-equivalent atoms with different Cartesian coordinates like Cartesian-distance-matrix
        cost matrix does. So, form a bipartite graph from all essentially-zero connections between
        ref and concern and run Uno algorithm to enumerate them.

        """
        if verbose >= 1:
            print("""Space: {} <--> {}""".format(rgp, cgp))

        # formulate cost matrix from internal (not Cartesian) layouts of R & C
        npcgp = np.array(cgp)
        submatCC = ccnremat[np.ix_(cgp, cgp)]
        submatRR = rrnremat[np.ix_(rgp, rgp)]
        sumCC = 100.0 * np.sum(submatCC, axis=0)  # cost mat small if not scaled, this way like Z=Neon
        sumRR = 100.0 * np.sum(submatRR, axis=0)
        cost = np.zeros((len(cgp), len(rgp)))
        for j in range(cost.shape[1]):
            for i in range(cost.shape[0]):
                cost[i, j] = (sumCC[i] - sumRR[j]) ** 2
        if verbose >= 2:
            print("Cost:\n", cost)
        costcopy = np.copy(cost)  # other one gets manipulated by hungarian call

        # find _a_ best match btwn R & C atoms through Kuhn-Munkres (Hungarian) algorithm
        # * linear_sum_assignment call is exactly like `scipy.optimize.linear_sum_assignment(cost)` only with extra return
        t00 = time.time()
        (row_ind, col_ind), reducedcost = linear_sum_assignment(cost, return_cost=True)
        ptsCR = list(zip(row_ind, col_ind))
        ptsCR = sorted(ptsCR, key=lambda tup: tup[1])
        sumCR = costcopy[row_ind, col_ind].sum()
        t01 = time.time()
        if verbose >= 2:
            print("Reduced cost:\n", cost)
        if verbose >= 1:
            print("Hungarian time [s] for space: {:.3}".format(t01 - t00))

        # find _all_ best matches btwn R & C atoms through Uno algorithm, seeded from Hungarian sol'n
        edges = np.argwhere(reducedcost < uno_cutoff)
        gooduns = uno(edges, ptsCR)
        t02 = time.time()
        if verbose >= 1:
            print("Uno time [s] for space: {:.3}".format(t02 - t01))

        for gu in gooduns:
            gu2 = gu[:]
            gu2.sort(key=lambda x: x[1])  # resorts match into (r, c) = (info, range)
            subans = [p[0] for p in gu2]  # compacted to subans/lap format

            ans = tuple(npcgp[np.array(subans)])
            if verbose >= 3:
                print("Best Candidate ({:6.3}):".format(sumCR), rgp, "<--", ans, " from", cgp, subans)
            yield ans

    if algorithm == "permutative":
        ccdistmat = distance_matrix(cgeom, cgeom)
        rrdistmat = distance_matrix(rgeom, rgeom)
        algofn = filter_permutative

    elif algorithm == "hungarian_uno":
        ccdistmat = distance_matrix(cgeom, cgeom)
        rrdistmat = distance_matrix(rgeom, rgeom)
        # inverse-distance ("headless NRE") matrices; the zero diagonal gives
        # inf under reciprocal, which is reset to zero below
        with np.errstate(divide="ignore"):
            ccnremat = np.reciprocal(ccdistmat)
            rrnremat = np.reciprocal(rrdistmat)
        ccnremat[ccnremat == np.inf] = 0.0
        rrnremat[rrnremat == np.inf] = 0.0
        algofn = filter_hungarian_uno

        # Ensure (optional dependency) networkx exists
        if not which_import("networkx", return_bool=True):
            raise ModuleNotFoundError(
                """Python module networkx not found. Solve by installing it: `conda install networkx` or `pip install networkx`"""
            )  # pragma: no cover

    else:
        # previously fell through to an unbound `algofn`
        raise ValidationError(
            """Unrecognized atom-ordering algorithm '{}'. Choose 'hungarian_uno' or 'permutative'.""".format(
                algorithm
            )
        )

    # collect candidate atom orderings from algofn for each of the atom classes,
    # recombine the classes with each other in every permutation (could maybe
    # add Hungarian here, too) as generator back to permutation_kabsch
    rgroups = list(connect)  # reference-index tuples, in the same order algofn is applied
    for cpmut in itertools.product(*itertools.starmap(algofn, connect.items())):
        atpat = [None] * len(ref)
        for rgroup, group in zip(rgroups, cpmut):
            for iidx, idx in enumerate(rgroup):
                atpat[idx] = group[iidx]
        yield atpat


def kabsch_align(rgeom, cgeom, weight=None):
    r"""Finds optimal translation and rotation to align `cgeom` onto `rgeom` via
    Kabsch algorithm by minimizing the norm of the residual, || R - U * C ||.

    Parameters
    ----------
    rgeom : ndarray of float
        (nat, 3) array of reference/target/unchanged geometry. Assumed [a0]
        for RMSD purposes.
    cgeom : ndarray of float
        (nat, 3) array of concern/changeable geometry. Assumed [a0] for RMSD
        purposes. Must have same Natom, units, and 1-to-1 atom ordering as rgeom.
    weight : ndarray of float
        (nat,) array of weights applied to `rgeom`. Note that definitions of
        weights (nothing to do with atom masses) are several, and I haven't
        seen one yet that can make centroid the center-of-mass and
        also make the RMSD match the usual mass-wtd-RMSD definition.
        Also, only one weight vector used rather than split btwn R & C,
        which may be invalid if not 1-to-1. Weighting is not recommended.

    Returns
    -------
    float, ndarray, ndarray
        First item is RMSD [A] between `rgeom` and the optimally aligned
        geometry computed.
        Second item is (3, 3) rotation matrix to optimal alignment.
        Third item is (3,) translation vector [a0] to optimal alignment.

    Raises
    ------
    ValidationError
        If `weight` is neither `None`, a list, nor an ndarray.

    Sources
    -------
    Kabsch: Acta Cryst. (1978). A34, 827-828 http://journals.iucr.org/a/issues/1978/05/00/a15629/a15629.pdf
    C++ affine code: https://github.com/oleg-alexandrov/projects/blob/master/eigen/Kabsch.cpp
    weighted RMSD: http://www.amber.utah.edu/AMBER-workshop/London-2015/tutorial1/
    protein wRMSD code: https://pharmacy.umich.edu/sites/default/files/global_wrmsd_v8.3.py.txt
    quaternion: https://cnx.org/contents/HV-RsdwL@23/Molecular-Distance-Measures

    Author: dsirianni

    """
    if weight is None:
        w = np.ones((rgeom.shape[0]))
    elif isinstance(weight, (list, np.ndarray)):
        w = np.asarray(weight)
    else:
        raise ValidationError(f"""Unrecognized argument type {type(weight)} for kwarg 'weight'.""")

    R = rgeom
    C = cgeom
    N = rgeom.shape[0]
    if np.allclose(R, C):
        # can hit a mixed non-identity translation/rotation, so head off
        return 0.0, np.identity(3), np.zeros(3)

    Rcentroid = R.sum(axis=0) / N
    Ccentroid = C.sum(axis=0) / N
    # np.subtract allocates new arrays, so the in-place scaling below does not
    # touch the caller's `rgeom`/`cgeom`
    R = np.subtract(R, Rcentroid)
    C = np.subtract(C, Ccentroid)

    R *= np.sqrt(w[:, None])
    C *= np.sqrt(w[:, None])

    RR = kabsch_quaternion(C.T, R.T)  # U
    TT = Ccentroid - RR.dot(Rcentroid)

    C = C.dot(RR)
    rmsd = np.linalg.norm(R - C) * constants.bohr2angstroms / np.sqrt(np.sum(w))

    return rmsd, RR, TT


def kabsch_quaternion(P, Q):
    """Computes the optimal rotation matrix U that maps a set of points P
    onto the set of points Q according to the minimization of || Q - U * P ||,
    using the unit quaternion formulation of the Kabsch algorithm.

    Arguments:
    <np.ndarray> P := MxN array. M=dimension of space, N=number of points.
    <np.ndarray> Q := MxN array. M=dimension of space, N=number of points.

    Returns:
    <np.ndarray> U := Optimal MxM rotation matrix mapping P onto Q.

    Author: dsirianni

    """
    # Form covariance matrix
    cov = Q.dot(P.T)

    # Form the quaternion transformation matrix F
    F = np.zeros((4, 4))
    # diagonal
    F[0, 0] = cov[0, 0] + cov[1, 1] + cov[2, 2]
    F[1, 1] = cov[0, 0] - cov[1, 1] - cov[2, 2]
    F[2, 2] = -cov[0, 0] + cov[1, 1] - cov[2, 2]
    F[3, 3] = -cov[0, 0] - cov[1, 1] + cov[2, 2]
    # Upper & lower triangle
    F[1, 0] = F[0, 1] = cov[1, 2] - cov[2, 1]
    F[2, 0] = F[0, 2] = cov[2, 0] - cov[0, 2]
    F[3, 0] = F[0, 3] = cov[0, 1] - cov[1, 0]
    F[2, 1] = F[1, 2] = cov[0, 1] + cov[1, 0]
    F[3, 1] = F[1, 3] = cov[0, 2] + cov[2, 0]
    F[3, 2] = F[2, 3] = cov[1, 2] + cov[2, 1]

    # Compute ew, ev of F
    ew, ev = np.linalg.eigh(F)

    # Construct optimal rotation matrix from leading ev
    # (eigh returns eigenvalues in ascending order, so the last column is the leading one)
    q = ev[:, -1]
    U = np.zeros((3, 3))

    U[0, 0] = q[0] ** 2 + q[1] ** 2 - q[2] ** 2 - q[3] ** 2
    U[0, 1] = 2 * (q[1] * q[2] - q[0] * q[3])
    U[0, 2] = 2 * (q[1] * q[3] + q[0] * q[2])
    U[1, 0] = 2 * (q[1] * q[2] + q[0] * q[3])
    U[1, 1] = q[0] ** 2 - q[1] ** 2 + q[2] ** 2 - q[3] ** 2
    U[1, 2] = 2 * (q[2] * q[3] - q[0] * q[1])
    U[2, 0] = 2 * (q[1] * q[3] - q[0] * q[2])
    U[2, 1] = 2 * (q[2] * q[3] + q[0] * q[1])
    U[2, 2] = q[0] ** 2 - q[1] ** 2 - q[2] ** 2 + q[3] ** 2

    return U


def compute_scramble(nat, do_resort=True, do_shift=True, do_rotate=True, deflection=1.0, do_mirror=False):
    r"""Generate a random or directed translation, rotation, and atom shuffling.

    Parameters
    ----------
    nat : int
        Number of atoms for which to prepare an atom mapping.
    do_resort : bool or array-like, optional
        Whether to randomly shuffle atoms (`True`) or leave 1st atom 1st, etc. (`False`)
        or shuffle according to specified (nat, ) indices (e.g., [2, 1, 0])
    do_shift : bool or array-like, optional
        Whether to generate a random atom shift on interval [-3, 3) in each
        dimension (`True`) or leave at current origin (`False`) or shift along
        specified (3, ) vector (e.g., np.array([0., 1., -1.])).
    do_rotate : bool or array-like, optional
        Whether to generate a random 3D rotation according to algorithm of Arvo (`True`)
        or leave at current orientation (`False`) or rotate with specified (3, 3) matrix.
    deflection : float, optional
        If `do_rotate`, how random a rotation: 0.0 is no change, 0.1 is small
        perturbation, 1.0 is completely random.
    do_mirror : bool, optional
        Whether to set mirror reflection instruction. Changes identity of
        molecule so off by default.

    Returns
    -------
    AlignmentMill
        AlignmentMill with fields (shift, rotation, atommap, mirror)
        as requested: identity, random, or specified.

    """
    # `is True` / `is False` (rather than truthiness) so that array-like
    # arguments fall through to the "specified" branch
    rand_elord = np.arange(nat)
    if do_resort is True:
        np.random.shuffle(rand_elord)
    elif do_resort is False:
        pass
    else:
        rand_elord = np.array(do_resort)
        assert rand_elord.shape == (nat,)

    if do_shift is True:
        rand_shift = 6 * np.random.random_sample((3,)) - 3
    elif do_shift is False:
        rand_shift = np.zeros((3,))
    else:
        rand_shift = np.array(do_shift)
        assert rand_shift.shape == (3,)

    if do_rotate is True:
        rand_rot3d = random_rotation_matrix(deflection=deflection)
    elif do_rotate is False:
        rand_rot3d = np.identity(3)
    else:
        rand_rot3d = np.array(do_rotate)
        assert rand_rot3d.shape == (3, 3)

    perturbation = AlignmentMill(shift=rand_shift, rotation=rand_rot3d, atommap=rand_elord, mirror=do_mirror)
    return perturbation