1# -*- coding: utf-8 -*-
2# MolMod is a collection of molecular modelling tools for python.
3# Copyright (C) 2007 - 2019 Toon Verstraelen <Toon.Verstraelen@UGent.be>, Center
4# for Molecular Modeling (CMM), Ghent University, Ghent, Belgium; all rights
5# reserved unless otherwise stated.
6#
7# This file is part of MolMod.
8#
9# MolMod is free software; you can redistribute it and/or
10# modify it under the terms of the GNU General Public License
11# as published by the Free Software Foundation; either version 3
12# of the License, or (at your option) any later version.
13#
14# MolMod is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program; if not, see <http://www.gnu.org/licenses/>
21#
22# --
23
24
25import unittest
26
27import numpy as np
28import pkg_resources
29import pytest
30
31from molmod import *
32from molmod.io import *
33from molmod.periodic import periodic
34from molmod.test.test_unit_cells import get_random_uc
35
36
# Public names of this module; only the test case class defined below is listed.
__all__ = ["BinningTestCase"]
38
39
class BinningTestCase(unittest.TestCase):
    """Compare the binned pair searches with brute-force distance loops.

    The ``verify_*`` methods are helpers shared by the tests; they are not
    collected as tests themselves. The ``test_*`` methods run
    ``PairSearchIntra`` / ``PairSearchInter`` and hand the resulting
    ``(identifier, distance)`` items to those helpers.
    """

    def verify_distances(self, iter_pairs, cutoff, distances, idcls, unit_cell=None):
        """Check pair-search results against a brute-force reference.

        Arguments:
          iter_pairs -- callable without arguments that returns an iterable
                        of ``((id0, coord0), (id1, coord1))`` items, i.e.
                        all candidate pairs of the reference calculation
          cutoff     -- reference pairs are only expected in ``distances``
                        when their distance is strictly below this value
          distances  -- iterable of ``(identifier, distance)`` items to be
                        verified; every identifier must have length two
          idcls      -- callable that turns ``[id0, id1]`` into the
                        identifier used as key in ``distances`` (the
                        callers in this class pass ``frozenset`` or
                        ``tuple``)
          unit_cell  -- optional object whose ``shortest_vector`` method is
                        applied to each relative vector before its norm is
                        taken

        Fails (through ``assertEqual``) when a reference pair is missing,
        when a distance differs from the reference, or when ``distances``
        contains pairs that were not confirmed by the reference. The failure
        message contains a full report of all three categories.
        """
        # Materialize once so that one-shot iterables can be verified too.
        pairs = list(distances)

        # a few sanity checks first
        for key, _ in pairs:
            self.assertEqual(len(key), 2, "Singletons encountered: %s" % str(key))
        count = len(pairs)
        remaining = dict(pairs)
        self.assertEqual(len(remaining), count, "Duplicate distances: %i > %i" % (count, len(remaining)))

        # real check of distances
        missing_pairs = []
        wrong_distances = []
        num_total = 0
        num_correct = 0

        for (id0, coord0), (id1, coord1) in iter_pairs():
            delta = coord1 - coord0
            if unit_cell is not None:
                delta = unit_cell.shortest_vector(delta)
            distance = np.linalg.norm(delta)
            if distance < cutoff:
                num_total += 1
                identifier = idcls([id0, id1])
                fast_distance = remaining.get(identifier)
                if fast_distance is None:
                    missing_pairs.append(tuple(identifier) + (distance,))
                # NOTE(review): exact floating point comparison; presumably
                # the pair search computes the norm in the same way -- confirm
                # before relaxing this to a tolerance.
                elif fast_distance != distance:
                    # Such entries stay in ``remaining`` and are therefore
                    # also listed under the unwanted pairs in the report.
                    wrong_distances.append(tuple(identifier) + (fast_distance, distance))
                else:
                    num_correct += 1
                    del remaining[identifier]

        message = self._format_report(
            cutoff, missing_pairs, wrong_distances, remaining, num_total, num_correct
        )

        self.assertEqual(len(missing_pairs), 0, message)
        self.assertEqual(len(wrong_distances), 0, message)
        self.assertEqual(len(remaining), 0, message)

    @staticmethod
    def _format_report(cutoff, missing_pairs, wrong_distances, unwanted, num_total, num_correct):
        """Return the multi-line failure report used by ``verify_distances``.

        ``missing_pairs`` holds ``(id0, id1, distance)`` tuples,
        ``wrong_distances`` holds ``(id0, id1, fast_distance, distance)``
        tuples and ``unwanted`` maps two-element identifiers to distances.
        """
        lines = ["-"*50, "CUTOFF %s" % cutoff]
        lines.append("MISSING PAIRS: %i" % len(missing_pairs))
        lines.extend("%10s %10s: \t % 10.7f" % missing_pair for missing_pair in missing_pairs)
        lines.append("WRONG DISTANCES: %i" % len(wrong_distances))
        lines.extend(
            "%10s %10s: \t % 10.7f != % 10.7f" % wrong_distance
            for wrong_distance in wrong_distances
        )
        lines.append("UNWANTED PAIRS: %i" % len(unwanted))
        # The identifier must be expanded into its two members: the format
        # string has three fields, so passing (identifier, fast_distance)
        # would raise a TypeError instead of producing the report.
        lines.extend(
            "%10s %10s: \t % 10.7f" % (tuple(identifier) + (fast_distance,))
            for identifier, fast_distance in unwanted.items()
        )
        lines.append("TOTAL PAIRS: %i" % num_total)
        lines.append("CORRECT PAIRS: %i" % num_correct)
        lines.append("-"*50)
        return "\n".join(lines) + "\n"

    def _verify_unique_neighbor_indexes(self, bins):
        """Assert that ``bins.neighbor_indexes`` contains no duplicate rows."""
        neighbor_set = {tuple(index) for index in bins.neighbor_indexes}
        self.assertEqual(len(neighbor_set), len(bins.neighbor_indexes))

    def _verify_surrounding(self, center_bins, surrounding_bins):
        """Check the bins reported around each bin of ``center_bins``.

        For every key in ``center_bins``, the keys yielded by
        ``surrounding_bins.iter_surrounding`` must be unique and their
        fractional representation in ``center_bins.integer_cell`` must have
        all components below 0.5.
        """
        for key0, _ in center_bins:
            encountered = set()
            for key1, _ in surrounding_bins.iter_surrounding(key0):
                frac_key1 = center_bins.integer_cell.to_fractional(key1)
                self.assertLess(frac_key1.max(), 0.5, str(key1) + str(frac_key1))
                self.assertNotIn(key1, encountered, str(key0) + str(key1))
                encountered.add(key1)

    # No pytest marker here: this helper is never collected as a test, so an
    # xfail mark would not have any effect. Expected failures are marked on
    # the test methods instead.
    def verify_bins_intra_periodic(self, bins):
        """Sanity-check the bins of a periodic intra pair search.

        Argument:
          bins -- the ``bins`` attribute of a ``PairSearchIntra`` object
        """
        self._verify_unique_neighbor_indexes(bins)
        self._verify_surrounding(bins, bins)

    # No pytest marker here for the same reason as above.
    def verify_bins_inter_periodic(self, bins0, bins1):
        """Sanity-check both sets of bins of a periodic inter pair search.

        Arguments:
          bins0, bins1 -- the ``bins0`` and ``bins1`` attributes of a
                          ``PairSearchInter`` object

        The bins of ``bins1`` surrounding each key of ``bins0`` are checked,
        using the integer cell of ``bins0``.
        """
        for bins in bins0, bins1:
            self._verify_unique_neighbor_indexes(bins)
        self._verify_surrounding(bins0, bins1)

    def verify_distances_intra(self, coordinates, cutoff, distances, unit_cell=None):
        """Verify distances between points of a single set of coordinates.

        Each unordered pair of distinct rows in ``coordinates`` is visited
        once; identifiers are ``frozenset`` objects of the two row indexes.
        """
        def iter_pairs():
            for index0, coord0 in enumerate(coordinates):
                for index1, coord1 in enumerate(coordinates[:index0]):
                    yield (index0, coord0), (index1, coord1)
        self.verify_distances(iter_pairs, cutoff, distances, frozenset, unit_cell)

    def verify_distances_inter(self, coordinates0, coordinates1, cutoff, distances, unit_cell=None):
        """Verify distances between points of two sets of coordinates.

        Every row of ``coordinates0`` is combined with every row of
        ``coordinates1``; identifiers are ``(index0, index1)`` tuples.
        """
        def iter_pairs():
            for index0, coord0 in enumerate(coordinates0):
                for index1, coord1 in enumerate(coordinates1):
                    yield (index0, coord0), (index1, coord1)
        self.verify_distances(iter_pairs, cutoff, distances, tuple, unit_cell)

    def test_distances_intra_lau(self):
        """Non-periodic intra search on the first geometry of lau.xyz."""
        coordinates = XYZFile(pkg_resources.resource_filename("molmod", "data/test/lau.xyz")).geometries[0]
        cutoff = periodic.max_radius*2
        distances = [
            (frozenset([i0, i1]), distance)
            for i0, i1, _, distance
            in PairSearchIntra(coordinates, cutoff)
        ]
        self.verify_distances_intra(coordinates, cutoff, distances)

    def test_distances_intra_lau_periodic(self):
        """Periodic intra search on the first geometry of lau.xyz."""
        coordinates = XYZFile(pkg_resources.resource_filename("molmod", "data/test/lau.xyz")).geometries[0]
        cutoff = periodic.max_radius*2
        unit_cell = UnitCell.from_parameters3(
            np.array([14.59, 12.88, 7.61])*angstrom,
            np.array([ 90.0, 111.0, 90.0])*deg,
        )

        pair_search = PairSearchIntra(coordinates, cutoff, unit_cell)
        self.verify_bins_intra_periodic(pair_search.bins)

        distances = [
            (frozenset([i0, i1]), distance)
            for i0, i1, _, distance
            in pair_search
        ]

        self.verify_distances_intra(coordinates, cutoff, distances, unit_cell)

    def test_distances_intra_random(self):
        """Non-periodic intra search on ten random sets of 20 points."""
        # The random numbers are drawn without setting a seed in this class.
        for _ in range(10):
            coordinates = np.random.uniform(0, 5, (20, 3))
            cutoff = np.random.uniform(1, 6)
            distances = [
                (frozenset([i0, i1]), distance)
                for i0, i1, _, distance
                in PairSearchIntra(coordinates, cutoff)
            ]
            self.verify_distances_intra(coordinates, cutoff, distances)

    @pytest.mark.xfail
    def test_distances_intra_random_periodic(self):
        """Periodic intra search on random points in random unit cells.

        Marked as an expected failure.
        """
        for _ in range(10):
            coordinates = np.random.uniform(0, 1, (20, 3))
            unit_cell = get_random_uc(5.0, np.random.randint(0, 4), 0.5)
            coordinates = unit_cell.to_cartesian(coordinates)*3-unit_cell.matrix.sum(axis=1)
            cutoff = np.random.uniform(1, 6)

            pair_search = PairSearchIntra(coordinates, cutoff, unit_cell)
            self.verify_bins_intra_periodic(pair_search.bins)

            distances = [
                (frozenset([i0, i1]), distance)
                for i0, i1, _, distance
                in pair_search
            ]
            self.verify_distances_intra(coordinates, cutoff, distances, unit_cell)

    def test_distances_inter_random(self):
        """Non-periodic inter search on ten random pairs of point sets."""
        for _ in range(10):
            coordinates0 = np.random.uniform(0, 5, (20, 3))
            coordinates1 = np.random.uniform(0, 5, (20, 3))
            cutoff = np.random.uniform(1, 6)
            distances = [
                ((i0, i1), distance)
                for i0, i1, _, distance
                in PairSearchInter(coordinates0, coordinates1, cutoff)
            ]
            self.verify_distances_inter(coordinates0, coordinates1, cutoff, distances)

    @pytest.mark.xfail
    def test_distances_inter_random_periodic(self):
        """Periodic inter search on random points in random unit cells.

        Marked as an expected failure.
        """
        for _ in range(10):
            fractional0 = np.random.uniform(0, 1, (20, 3))
            fractional1 = np.random.uniform(0, 1, (20, 3))
            unit_cell = get_random_uc(5.0, np.random.randint(0, 4), 0.5)
            coordinates0 = unit_cell.to_cartesian(fractional0)*3-unit_cell.matrix.sum(axis=1)
            coordinates1 = unit_cell.to_cartesian(fractional1)*3-unit_cell.matrix.sum(axis=1)
            cutoff = np.random.uniform(1, 6)

            pair_search = PairSearchInter(coordinates0, coordinates1, cutoff, unit_cell)
            self.verify_bins_inter_periodic(pair_search.bins0, pair_search.bins1)

            distances = [
                ((i0, i1), distance)
                for i0, i1, _, distance
                in pair_search
            ]
            self.verify_distances_inter(coordinates0, coordinates1, cutoff, distances, unit_cell)
221