# -*- coding: utf-8 -*-
# MolMod is a collection of molecular modelling tools for python.
# Copyright (C) 2007 - 2019 Toon Verstraelen <Toon.Verstraelen@UGent.be>, Center
# for Molecular Modeling (CMM), Ghent University, Ghent, Belgium; all rights
# reserved unless otherwise stated.
#
# This file is part of MolMod.
#
# MolMod is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# MolMod is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
#
# --
"""Tests that compare PairSearchIntra/PairSearchInter results with brute-force pair loops."""


import unittest

import numpy as np
import pkg_resources
import pytest

from molmod import *
from molmod.io import *
from molmod.periodic import periodic
from molmod.test.test_unit_cells import get_random_uc


__all__ = ["BinningTestCase"]


class BinningTestCase(unittest.TestCase):
    """Checks the pairs produced by the pair-search classes against a reference loop."""

    def verify_distances(self, iter_pairs, cutoff, distances, idcls, unit_cell=None):
        """Compare pre-computed pair distances with a brute-force reference.

        Arguments:
          iter_pairs  --  callable returning an iterable of
                          ``((id0, coord0), (id1, coord1))`` tuples, i.e. all
                          candidate pairs to check
          cutoff  --  only reference pairs with a distance strictly below this
                      value are expected to be present in ``distances``
          distances  --  sized sequence of ``(identifier, distance)`` items
                         under test; each identifier must have length two
          idcls  --  callable that turns ``[id0, id1]`` into the identifier
                     used as key in ``distances`` (``frozenset`` or ``tuple``
                     in this file)

        Optional argument:
          unit_cell  --  when given, its ``shortest_vector`` method is applied
                         to each relative vector before taking the norm

        The test fails (with a detailed report as message) when pairs are
        missing, when a distance differs from the reference, or when
        ``distances`` contains pairs the reference loop did not produce.
        """
        # a few sanity checks first
        for key, _ in distances:
            self.assertEqual(len(key), 2, "Singletons encountered: %s" % str(key))
        count = len(distances)
        distances = dict(distances)
        self.assertEqual(len(distances), count, "Duplicate distances: %i > %i" % (count, len(distances)))

        # real check of distances
        missing_pairs = []
        wrong_distances = []
        num_total = 0
        num_correct = 0

        for (id0, coord0), (id1, coord1) in iter_pairs():
            delta = coord1 - coord0
            if unit_cell is not None:
                delta = unit_cell.shortest_vector(delta)
            distance = np.linalg.norm(delta)
            if distance >= cutoff:
                continue
            num_total += 1
            identifier = idcls([id0, id1])
            fast_distance = distances.get(identifier)
            if fast_distance is None:
                missing_pairs.append(tuple(identifier) + (distance,))
            elif fast_distance != distance:
                # The comparison is deliberately left exact, as in the
                # original test.
                wrong_distances.append(tuple(identifier) + (fast_distance, distance))
            else:
                num_correct += 1
                # Only confirmed pairs are removed; whatever is left at the
                # end is reported as unwanted.
                del distances[identifier]

        lines = ["-"*50+"\n"]
        lines.append("CUTOFF %s\n" % cutoff)
        lines.append("MISSING PAIRS: %i\n" % len(missing_pairs))
        for missing_pair in missing_pairs:
            lines.append("%10s %10s: \t % 10.7f\n" % missing_pair)
        lines.append("WRONG DISTANCES: %i\n" % len(wrong_distances))
        for wrong_distance in wrong_distances:
            lines.append("%10s %10s: \t % 10.7f != % 10.7f\n" % wrong_distance)
        lines.append("UNWANTED PAIRS: %i\n" % len(distances))
        for identifier, fast_distance in distances.items():
            # The identifier holds two ids, so it must be unpacked to match
            # the three conversion specifiers (same layout as missing pairs).
            lines.append("%10s %10s: \t % 10.7f\n" % (tuple(identifier) + (fast_distance,)))
        lines.append("TOTAL PAIRS: %i\n" % num_total)
        lines.append("CORRECT PAIRS: %i\n" % num_correct)
        lines.append("-"*50+"\n")
        message = "".join(lines)

        self.assertEqual(len(missing_pairs), 0, message)
        self.assertEqual(len(wrong_distances), 0, message)
        self.assertEqual(len(distances), 0, message)

    # NOTE(review): this is a helper, not a test_* method, so the xfail mark
    # presumably has no effect here -- confirm before removing it.
    @pytest.mark.xfail
    def verify_bins_intra_periodic(self, bins):
        """Check the neighbor bookkeeping of a single periodic bins object.

        Argument:
          bins  --  object with a ``neighbor_indexes`` attribute, an
                    ``integer_cell`` attribute and an ``iter_surrounding``
                    method; iterating over it yields ``(key, bin)`` pairs

        Verifies that ``neighbor_indexes`` has no duplicates and that, for
        every bin, the surrounding bins are unique and have fractional
        coordinates (w.r.t. ``integer_cell``) whose maximum is below 0.5.
        """
        neighbor_set = {tuple(index) for index in bins.neighbor_indexes}
        self.assertEqual(len(neighbor_set), len(bins.neighbor_indexes))

        for key0, _ in bins:
            encountered = set()
            for key1, _ in bins.iter_surrounding(key0):
                frac_key1 = bins.integer_cell.to_fractional(key1)
                self.assertLess(frac_key1.max(), 0.5, str(key1) + str(frac_key1))
                self.assertNotIn(key1, encountered, str(key0) + str(key1))
                encountered.add(key1)

    # NOTE(review): helper rather than a test_* method; the xfail mark
    # presumably has no effect here -- confirm before removing it.
    @pytest.mark.xfail
    def verify_bins_inter_periodic(self, bins0, bins1):
        """Check the neighbor bookkeeping of two periodic bins objects.

        Arguments:
          bins0, bins1  --  bins objects with the same interface as the
                            argument of ``verify_bins_intra_periodic``

        Verifies that neither object has duplicate ``neighbor_indexes`` and
        that, for every bin in ``bins0``, the surrounding bins in ``bins1``
        are unique and have fractional coordinates (w.r.t. the integer cell
        of ``bins0``) whose maximum is below 0.5.
        """
        for bins in bins0, bins1:
            neighbor_set = {tuple(index) for index in bins.neighbor_indexes}
            self.assertEqual(len(neighbor_set), len(bins.neighbor_indexes))

        for key0, _ in bins0:
            encountered = set()
            for key1, _ in bins1.iter_surrounding(key0):
                frac_key1 = bins0.integer_cell.to_fractional(key1)
                self.assertLess(frac_key1.max(), 0.5, str(key1) + str(frac_key1))
                self.assertNotIn(key1, encountered, str(key0) + str(key1))
                encountered.add(key1)

    def verify_distances_intra(self, coordinates, cutoff, distances, unit_cell=None):
        """Verify distances between all unordered pairs within one coordinate set.

        Each pair of distinct indexes is visited once and identified by a
        ``frozenset`` of the two indexes.
        """
        def iter_pairs():
            for index0, coord0 in enumerate(coordinates):
                for index1, coord1 in enumerate(coordinates[:index0]):
                    yield (index0, coord0), (index1, coord1)
        self.verify_distances(iter_pairs, cutoff, distances, frozenset, unit_cell)

    def verify_distances_inter(self, coordinates0, coordinates1, cutoff, distances, unit_cell=None):
        """Verify distances between every point of one set and every point of another.

        Pairs are identified by the ``tuple`` ``(index0, index1)``, with the
        first index referring to ``coordinates0`` and the second to
        ``coordinates1``.
        """
        def iter_pairs():
            for index0, coord0 in enumerate(coordinates0):
                for index1, coord1 in enumerate(coordinates1):
                    yield (index0, coord0), (index1, coord1)
        self.verify_distances(iter_pairs, cutoff, distances, tuple, unit_cell)

    def test_distances_intra_lau(self):
        """PairSearchIntra on the first geometry of lau.xyz, without unit cell."""
        coordinates = XYZFile(pkg_resources.resource_filename("molmod", "data/test/lau.xyz")).geometries[0]
        cutoff = periodic.max_radius*2
        distances = [
            (frozenset([i0, i1]), distance)
            for i0, i1, delta, distance
            in PairSearchIntra(coordinates, cutoff)
        ]
        self.verify_distances_intra(coordinates, cutoff, distances)

    def test_distances_intra_lau_periodic(self):
        """PairSearchIntra on the first geometry of lau.xyz, with a unit cell."""
        coordinates = XYZFile(pkg_resources.resource_filename("molmod", "data/test/lau.xyz")).geometries[0]
        cutoff = periodic.max_radius*2
        unit_cell = UnitCell.from_parameters3(
            np.array([14.59, 12.88, 7.61])*angstrom,
            np.array([ 90.0, 111.0, 90.0])*deg,
        )

        pair_search = PairSearchIntra(coordinates, cutoff, unit_cell)
        self.verify_bins_intra_periodic(pair_search.bins)

        distances = [
            (frozenset([i0, i1]), distance)
            for i0, i1, delta, distance
            in pair_search
        ]

        self.verify_distances_intra(coordinates, cutoff, distances, unit_cell)

    def test_distances_intra_random(self):
        """PairSearchIntra on ten random 20-point sets, without unit cell."""
        for _ in range(10):
            coordinates = np.random.uniform(0,5,(20,3))
            cutoff = np.random.uniform(1, 6)
            distances = [
                (frozenset([i0, i1]), distance)
                for i0, i1, delta, distance
                in PairSearchIntra(coordinates, cutoff)
            ]
            self.verify_distances_intra(coordinates, cutoff, distances)

    @pytest.mark.xfail
    def test_distances_intra_random_periodic(self):
        """PairSearchIntra on ten random 20-point sets with random unit cells."""
        for _ in range(10):
            coordinates = np.random.uniform(0,1,(20,3))
            unit_cell = get_random_uc(5.0, np.random.randint(0, 4), 0.5)
            coordinates = unit_cell.to_cartesian(coordinates)*3-unit_cell.matrix.sum(axis=1)
            cutoff = np.random.uniform(1, 6)

            pair_search = PairSearchIntra(coordinates, cutoff, unit_cell)
            self.verify_bins_intra_periodic(pair_search.bins)

            distances = [
                (frozenset([i0, i1]), distance)
                for i0, i1, delta, distance
                in pair_search
            ]
            self.verify_distances_intra(coordinates, cutoff, distances, unit_cell)

    def test_distances_inter_random(self):
        """PairSearchInter on ten random pairs of 20-point sets, without unit cell."""
        for _ in range(10):
            coordinates0 = np.random.uniform(0,5,(20,3))
            coordinates1 = np.random.uniform(0,5,(20,3))
            cutoff = np.random.uniform(1, 6)
            distances = [
                ((i0, i1), distance)
                for i0, i1, delta, distance
                in PairSearchInter(coordinates0, coordinates1, cutoff)
            ]
            self.verify_distances_inter(coordinates0, coordinates1, cutoff, distances)

    @pytest.mark.xfail
    def test_distances_inter_random_periodic(self):
        """PairSearchInter on ten random pairs of 20-point sets with random unit cells."""
        for _ in range(10):
            fractional0 = np.random.uniform(0,1,(20,3))
            fractional1 = np.random.uniform(0,1,(20,3))
            unit_cell = get_random_uc(5.0, np.random.randint(0, 4), 0.5)
            coordinates0 = unit_cell.to_cartesian(fractional0)*3-unit_cell.matrix.sum(axis=1)
            coordinates1 = unit_cell.to_cartesian(fractional1)*3-unit_cell.matrix.sum(axis=1)
            cutoff = np.random.uniform(1, 6)

            pair_search = PairSearchInter(coordinates0, coordinates1, cutoff, unit_cell)
            self.verify_bins_inter_periodic(pair_search.bins0, pair_search.bins1)

            distances = [
                ((i0, i1), distance)
                for i0, i1, delta, distance
                in pair_search
            ]
            self.verify_distances_inter(coordinates0, coordinates1, cutoff, distances, unit_cell)