1# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*- 2# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 fileencoding=utf-8 3# 4# MDAnalysis --- https://www.mdanalysis.org 5# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors 6# (see the file AUTHORS for the full list of names) 7# 8# Released under the GNU Public Licence, v2 or any higher version 9# 10# Please cite your use of MDAnalysis in published work: 11# 12# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler, 13# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein. 14# MDAnalysis: A Python package for the rapid analysis of molecular dynamics 15# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th 16# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy. 17# 18# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein. 19# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations. 20# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 21# 22from __future__ import absolute_import, division 23 24 25from six.moves import range, StringIO 26import pytest 27import os 28import warnings 29import re 30import textwrap 31 32import numpy as np 33from numpy.testing import (assert_equal, assert_almost_equal, 34 assert_array_almost_equal, assert_array_equal) 35 36import MDAnalysis as mda 37import MDAnalysis.lib.util as util 38import MDAnalysis.lib.mdamath as mdamath 39from MDAnalysis.lib.util import (cached, static_variables, warn_if_not_unique, 40 check_coords) 41from MDAnalysis.core.topologyattrs import Bonds 42from MDAnalysis.exceptions import NoDataError, DuplicateWarning 43 44 45from MDAnalysisTests.datafiles import ( 46 Make_Whole, TPR, GRO, fullerene, two_water_gro, 47) 48 49 50def convert_aa_code_long_data(): 51 aa = [ 52 ('H', 53 ('HIS', 'HISA', 'HISB', 'HSE', 'HSD', 'HIS1', 'HIS2', 'HIE', 'HID')), 54 ('K', ('LYS', 'LYSH', 'LYN')), 55 ('A', ('ALA',)), 56 ('D', ('ASP', 'ASPH', 'ASH')), 57 ('E', ('GLU', 'GLUH', 'GLH')), 58 ('N', ('ASN',)), 59 ('Q', ('GLN',)), 60 ('C', ('CYS', 'CYSH', 'CYS1', 'CYS2')), 61 ] 62 for resname1, strings in aa: 63 for resname3 in strings: 64 yield (resname3, resname1) 65 66 67class TestStringFunctions(object): 68 # (1-letter, (canonical 3 letter, other 3/4 letter, ....)) 69 aa = [ 70 ('H', 71 ('HIS', 'HISA', 'HISB', 'HSE', 'HSD', 'HIS1', 'HIS2', 'HIE', 'HID')), 72 ('K', ('LYS', 'LYSH', 'LYN')), 73 ('A', ('ALA',)), 74 ('D', ('ASP', 'ASPH', 'ASH')), 75 ('E', ('GLU', 'GLUH', 'GLH')), 76 ('N', ('ASN',)), 77 ('Q', ('GLN',)), 78 ('C', ('CYS', 'CYSH', 'CYS1', 'CYS2')), 79 ] 80 81 residues = [ 82 ("LYS300:HZ1", ("LYS", 300, "HZ1")), 83 ("K300:HZ1", ("LYS", 300, "HZ1")), 84 ("K300", ("LYS", 300, None)), 85 ("LYS 300:HZ1", ("LYS", 300, "HZ1")), 86 ("M1:CA", ("MET", 1, "CA")), 87 ] 88 89 @pytest.mark.parametrize('rstring, residue', residues) 90 def test_parse_residue(self, rstring, residue): 91 assert util.parse_residue(rstring) == residue 92 93 def test_parse_residue_ValueError(self): 94 with pytest.raises(ValueError): 95 util.parse_residue('ZZZ') 96 97 @pytest.mark.parametrize('resname3, resname1', convert_aa_code_long_data()) 98 def test_convert_aa_3to1(self, resname3, resname1): 99 assert util.convert_aa_code(resname3) == resname1 100 101 @pytest.mark.parametrize('resname1, strings', aa) 102 def test_convert_aa_1to3(self, resname1, strings): 103 assert util.convert_aa_code(resname1) == strings[0] 104 105 @pytest.mark.parametrize('x', ( 106 'XYZXYZ', 107 '£' 108 )) 109 def test_ValueError(self, x): 110 with pytest.raises(ValueError): 111 util.convert_aa_code(x) 112 113 114def test_greedy_splitext(inp="foo/bar/boing.2.pdb.bz2", 115 ref=["foo/bar/boing", ".2.pdb.bz2"]): 116 inp = os.path.normpath(inp) 117 ref[0] = os.path.normpath(ref[0]) 118 ref[1] = os.path.normpath(ref[1]) 119 root, ext = util.greedy_splitext(inp) 120 assert root == ref[0], "root incorrect" 121 assert ext == ref[1], "extension incorrect" 122 123 124@pytest.mark.parametrize('iterable, value', [ 125 ([1, 2, 3], True), 126 ([], True), 127 ((1, 2, 3), True), 128 ((), True), 129 (range(3), True), 130 (np.array([1, 2, 3]), True), 131 (123, False), 132 ("byte string", False), 133 (u"unicode string", False) 134]) 135def test_iterable(iterable, value): 136 assert util.iterable(iterable) == value 137 138 139class TestFilename(object): 140 root = "foo" 141 filename = "foo.psf" 142 ext = "pdb" 143 filename2 = "foo.pdb" 144 145 @pytest.mark.parametrize('name, ext, keep, actual_name', [ 146 (filename, None, False, filename), 147 (filename, ext, False, filename2), 148 (filename, ext, True, filename), 149 (root, ext, False, filename2), 150 (root, ext, True, filename2) 151 ]) 152 def test_string(self, name, ext, keep, actual_name): 153 file_name = util.filename(name, ext, keep) 154 assert file_name == actual_name 155 156 def test_named_stream(self): 157 ns = util.NamedStream(StringIO(), self.filename) 158 fn = util.filename(ns, ext=self.ext) 159 # assert_equal replace by this if loop to avoid segfault on some systems 160 if fn != ns: 161 pytest.fail("fn and ns are different") 162 assert str(fn) == self.filename2 163 assert ns.name == self.filename2 164 165 166class TestGeometryFunctions(object): 167 e1, e2, e3 = np.eye(3) 168 a = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3), 0]) 169 null = np.zeros(3) 170 171 @pytest.mark.parametrize('x_axis, y_axis, value', [ 172 # Unit vectors 173 (e1, e2, np.pi / 2), 174 (e1, a, np.pi / 3), 175 # Angle vectors 176 (2 * e1, e2, np.pi / 2), 177 (-2 * e1, e2, np.pi - np.pi / 2), 178 (23.3 * e1, a, np.pi / 3), 179 # Null vector 180 (e1, null, np.nan), 181 # Coleniar 182 (a, a, 0.0) 183 ]) 184 def test_vectors(self, x_axis, y_axis, value): 185 assert_equal(mdamath.angle(x_axis, y_axis), value) 186 187 @pytest.mark.parametrize('x_axis, y_axis, value', [ 188 (-2.3456e7 * e1, 3.4567e-6 * e1, np.pi), 189 (2.3456e7 * e1, 3.4567e-6 * e1, 0.0) 190 ]) 191 def test_angle_pi(self, x_axis, y_axis, value): 192 assert_almost_equal(mdamath.angle(x_axis, y_axis), value) 193 194 @pytest.mark.parametrize('x', np.linspace(0, np.pi, 20)) 195 def test_angle_range(self, x): 196 r = 1000. 197 v = r * np.array([np.cos(x), np.sin(x), 0]) 198 assert_almost_equal(mdamath.angle(self.e1, v), x, 6) 199 200 @pytest.mark.parametrize('vector, value', [ 201 (e3, 1), 202 (a, np.linalg.norm(a)), 203 (null, 0.0) 204 ]) 205 def test_norm(self, vector, value): 206 assert mdamath.norm(vector) == value 207 208 @pytest.mark.parametrize('x', np.linspace(0, np.pi, 20)) 209 def test_norm_range(self, x): 210 r = 1000. 211 v = r * np.array([np.cos(x), np.sin(x), 0]) 212 assert_almost_equal(mdamath.norm(v), r, 6) 213 214 @pytest.mark.parametrize('vec1, vec2, value', [ 215 (e1, e2, e3), 216 (e1, null, 0.0) 217 ]) 218 def test_normal(self, vec1, vec2, value): 219 assert_equal(mdamath.normal(vec1, vec2), value) 220 # add more non-trivial tests 221 222 def test_stp(self): 223 assert mdamath.stp(self.e1, self.e2, self.e3) == 1.0 224 # add more non-trivial tests 225 226 def test_dihedral(self): 227 ab = self.e1 228 bc = ab + self.e2 229 cd = bc + self.e3 230 assert_almost_equal(mdamath.dihedral(ab, bc, cd), -np.pi / 2) 231 232 233class TestMakeWhole(object): 234 """Set up a simple system: 235 236 +-----------+ 237 | | 238 | 6 3 | 6 239 | ! ! | ! 240 |-5-8 1-2-|-5-8 241 | ! ! | ! 242 | 7 4 | 7 243 | | 244 +-----------+ 245 """ 246 247 prec = 5 248 249 @pytest.fixture() 250 def universe(self): 251 universe = mda.Universe(Make_Whole) 252 bondlist = [(0, 1), (1, 2), (1, 3), (1, 4), (4, 5), (4, 6), (4, 7)] 253 universe.add_TopologyAttr(Bonds(bondlist)) 254 return universe 255 256 def test_single_atom_no_bonds(self): 257 # Call make_whole on single atom with no bonds, shouldn't move 258 u = mda.Universe(Make_Whole) 259 # Atom0 is isolated 260 bondlist = [(1, 2), (1, 3), (1, 4), (4, 5), (4, 6), (4, 7)] 261 u.add_TopologyAttr(Bonds(bondlist)) 262 263 ag = u.atoms[[0]] 264 refpos = ag.positions.copy() 265 mdamath.make_whole(ag) 266 267 assert_array_almost_equal(ag.positions, refpos) 268 269 def test_scrambled_ag(self, universe): 270 # if order of atomgroup is mixed 271 ag = universe.atoms[[1, 3, 2, 4, 0, 6, 5, 7]] 272 273 mdamath.make_whole(ag) 274 275 # artificial system which uses 1nm bonds, so 276 # largest bond should be 20A 277 assert ag.bonds.values().max() < 20.1 278 279 @staticmethod 280 @pytest.fixture() 281 def ag(universe): 282 return universe.residues[0].atoms 283 284 def test_no_bonds(self): 285 # NoData caused by no bonds 286 universe = mda.Universe(Make_Whole) 287 ag = universe.residues[0].atoms 288 with pytest.raises(NoDataError): 289 mdamath.make_whole(ag) 290 291 def test_zero_box_size(self, universe, ag): 292 universe.dimensions = [0., 0., 0., 90., 90., 90.] 293 with pytest.raises(ValueError): 294 mdamath.make_whole(ag) 295 296 def test_wrong_reference_atom(self, universe, ag): 297 # Reference atom not in atomgroup 298 with pytest.raises(ValueError): 299 mdamath.make_whole(ag, reference_atom=universe.atoms[-1]) 300 301 def test_impossible_solve(self, universe): 302 # check that the algorithm sees the bad walk 303 with pytest.raises(ValueError): 304 mdamath.make_whole(universe.atoms) 305 306 def test_solve_1(self, universe, ag): 307 # regular usage of function 308 309 refpos = universe.atoms[:4].positions.copy() 310 311 mdamath.make_whole(ag) 312 313 assert_array_almost_equal(universe.atoms[:4].positions, refpos) 314 assert_array_almost_equal(universe.atoms[4].position, 315 np.array([110.0, 50.0, 0.0]), decimal=self.prec) 316 assert_array_almost_equal(universe.atoms[5].position, 317 np.array([110.0, 60.0, 0.0]), decimal=self.prec) 318 assert_array_almost_equal(universe.atoms[6].position, 319 np.array([110.0, 40.0, 0.0]), decimal=self.prec) 320 assert_array_almost_equal(universe.atoms[7].position, 321 np.array([120.0, 50.0, 0.0]), decimal=self.prec) 322 323 def test_solve_2(self, universe, ag): 324 # use but specify the center atom 325 326 refpos = universe.atoms[4:8].positions.copy() 327 328 mdamath.make_whole(ag, reference_atom=universe.residues[0].atoms[4]) 329 330 assert_array_almost_equal(universe.atoms[4:8].positions, refpos) 331 assert_array_almost_equal(universe.atoms[0].position, 332 np.array([-20.0, 50.0, 0.0]), decimal=self.prec) 333 assert_array_almost_equal(universe.atoms[1].position, 334 np.array([-10.0, 50.0, 0.0]), decimal=self.prec) 335 assert_array_almost_equal(universe.atoms[2].position, 336 np.array([-10.0, 60.0, 0.0]), decimal=self.prec) 337 assert_array_almost_equal(universe.atoms[3].position, 338 np.array([-10.0, 40.0, 0.0]), decimal=self.prec) 339 340 def test_solve_3(self, universe): 341 # put in a chunk that doesn't need any work 342 343 refpos = universe.atoms[:1].positions.copy() 344 345 mdamath.make_whole(universe.atoms[:1]) 346 347 assert_array_almost_equal(universe.atoms[:1].positions, refpos) 348 349 def test_solve_4(self, universe): 350 # Put in only some of a fragment, 351 # check that not everything gets moved 352 353 chunk = universe.atoms[:7] 354 refpos = universe.atoms[7].position.copy() 355 356 mdamath.make_whole(chunk) 357 358 assert_array_almost_equal(universe.atoms[7].position, refpos) 359 assert_array_almost_equal(universe.atoms[4].position, 360 np.array([110.0, 50.0, 0.0])) 361 assert_array_almost_equal(universe.atoms[5].position, 362 np.array([110.0, 60.0, 0.0])) 363 assert_array_almost_equal(universe.atoms[6].position, 364 np.array([110.0, 40.0, 0.0])) 365 366 def test_double_frag_short_bonds(self, universe, ag): 367 # previous bug where if two fragments are given 368 # but all bonds were short, the algorithm didn't 369 # complain 370 mdamath.make_whole(ag) 371 with pytest.raises(ValueError): 372 mdamath.make_whole(universe.atoms) 373 374 def test_make_whole_triclinic(self): 375 u = mda.Universe(TPR, GRO) 376 thing = u.select_atoms('not resname SOL NA+') 377 mdamath.make_whole(thing) 378 379 blengths = thing.bonds.values() 380 381 assert blengths.max() < 2.0 382 383 def test_make_whole_fullerene(self): 384 # lots of circular bonds as a nice pathological case 385 u = mda.Universe(fullerene) 386 387 bbox = u.atoms.bbox() 388 u.dimensions[:3] = bbox[1] - bbox[0] 389 u.dimensions[3:] = 90.0 390 391 blengths = u.atoms.bonds.values() 392 # kaboom 393 u.atoms[::2].translate([u.dimensions[0], -2 * u.dimensions[1], 0.0]) 394 u.atoms[1::2].translate([0.0, 7 * u.dimensions[1], -5 * u.dimensions[2]]) 395 396 mdamath.make_whole(u.atoms) 397 398 assert_array_almost_equal(u.atoms.bonds.values(), blengths, decimal=self.prec) 399 400 def test_make_whole_multiple_molecules(self): 401 u = mda.Universe(two_water_gro, guess_bonds=True) 402 403 for f in u.atoms.fragments: 404 mdamath.make_whole(f) 405 406 assert u.atoms.bonds.values().max() < 2.0 407 408class Class_with_Caches(object): 409 def __init__(self): 410 self._cache = dict() 411 self.ref1 = 1.0 412 self.ref2 = 2.0 413 self.ref3 = 3.0 414 self.ref4 = 4.0 415 self.ref5 = 5.0 416 417 @cached('val1') 418 def val1(self): 419 return self.ref1 420 421 # Do one with property decorator as these are used together often 422 @property 423 @cached('val2') 424 def val2(self): 425 return self.ref2 426 427 # Check use of property setters 428 @property 429 @cached('val3') 430 def val3(self): 431 return self.ref3 432 433 @val3.setter 434 def val3(self, new): 435 self._clear_caches('val3') 436 self._fill_cache('val3', new) 437 438 @val3.deleter 439 def val3(self): 440 self._clear_caches('val3') 441 442 # Check that args are passed through to underlying functions 443 @cached('val4') 444 def val4(self, n1, n2): 445 return self._init_val_4(n1, n2) 446 447 def _init_val_4(self, m1, m2): 448 return self.ref4 + m1 + m2 449 450 # Args and Kwargs 451 @cached('val5') 452 def val5(self, n, s=None): 453 return self._init_val_5(n, s=s) 454 455 def _init_val_5(self, n, s=None): 456 return n * s 457 458 # These are designed to mimic the AG and Universe cache methods 459 def _clear_caches(self, *args): 460 if len(args) == 0: 461 self._cache = dict() 462 else: 463 for name in args: 464 try: 465 del self._cache[name] 466 except KeyError: 467 pass 468 469 def _fill_cache(self, name, value): 470 self._cache[name] = value 471 472 473class TestCachedDecorator(object): 474 @pytest.fixture() 475 def obj(self): 476 return Class_with_Caches() 477 478 def test_val1_lookup(self, obj): 479 obj._clear_caches() 480 assert 'val1' not in obj._cache 481 assert obj.val1() == obj.ref1 482 ret = obj.val1() 483 assert 'val1' in obj._cache 484 assert obj._cache['val1'] == ret 485 assert obj.val1() is obj._cache['val1'] 486 487 def test_val1_inject(self, obj): 488 # Put something else into the cache and check it gets returned 489 # this tests that the cache is blindly being used 490 obj._clear_caches() 491 ret = obj.val1() 492 assert 'val1' in obj._cache 493 assert ret == obj.ref1 494 new = 77.0 495 obj._fill_cache('val1', new) 496 assert obj.val1() == new 497 498 # Managed property 499 def test_val2_lookup(self, obj): 500 obj._clear_caches() 501 assert 'val2' not in obj._cache 502 assert obj.val2 == obj.ref2 503 ret = obj.val2 504 assert 'val2' in obj._cache 505 assert obj._cache['val2'] == ret 506 507 def test_val2_inject(self, obj): 508 obj._clear_caches() 509 ret = obj.val2 510 assert 'val2' in obj._cache 511 assert ret == obj.ref2 512 new = 77.0 513 obj._fill_cache('val2', new) 514 assert obj.val2 == new 515 516 # Setter on cached attribute 517 518 def test_val3_set(self, obj): 519 obj._clear_caches() 520 assert obj.val3 == obj.ref3 521 new = 99.0 522 obj.val3 = new 523 assert obj.val3 == new 524 assert obj._cache['val3'] == new 525 526 def test_val3_del(self, obj): 527 # Check that deleting the property removes it from cache, 528 obj._clear_caches() 529 assert obj.val3 == obj.ref3 530 assert 'val3' in obj._cache 531 del obj.val3 532 assert 'val3' not in obj._cache 533 # But allows it to work as usual afterwards 534 assert obj.val3 == obj.ref3 535 assert 'val3' in obj._cache 536 537 # Pass args 538 def test_val4_args(self, obj): 539 obj._clear_caches() 540 assert obj.val4(1, 2) == 1 + 2 + obj.ref4 541 # Further calls should yield the old result 542 # this arguably shouldn't be cached... 543 assert obj.val4(3, 4) == 1 + 2 + obj.ref4 544 545 # Pass args and kwargs 546 def test_val5_kwargs(self, obj): 547 obj._clear_caches() 548 assert obj.val5(5, s='abc') == 5 * 'abc' 549 550 assert obj.val5(5, s='!!!') == 5 * 'abc' 551 552 553class TestConvFloat(object): 554 @pytest.mark.parametrize('s, output', [ 555 ('0.45', 0.45), 556 ('.45', 0.45), 557 ('a.b', 'a.b') 558 ]) 559 def test_float(self, s, output): 560 assert util.conv_float(s) == output 561 562 @pytest.mark.parametrize('input, output', [ 563 (('0.45', '0.56', '6.7'), [0.45, 0.56, 6.7]), 564 (('0.45', 'a.b', '!!'), [0.45, 'a.b', '!!']) 565 ]) 566 def test_map(self, input, output): 567 ret = [util.conv_float(el) for el in input] 568 assert ret == output 569 570 571class TestFixedwidthBins(object): 572 def test_keys(self): 573 ret = util.fixedwidth_bins(0.5, 1.0, 2.0) 574 for k in ['Nbins', 'delta', 'min', 'max']: 575 assert k in ret 576 577 def test_ValueError(self): 578 with pytest.raises(ValueError): 579 util.fixedwidth_bins(0.1, 5.0, 4.0) 580 581 @pytest.mark.parametrize( 582 'delta, xmin, xmax, output_Nbins, output_delta, output_min, output_max', 583 [ 584 (0.1, 4.0, 5.0, 10, 0.1, 4.0, 5.0), 585 (0.4, 4.0, 5.0, 3, 0.4, 3.9, 5.1) 586 ]) 587 def test_usage(self, delta, xmin, xmax, output_Nbins, output_delta, 588 output_min, output_max): 589 ret = util.fixedwidth_bins(delta, xmin, xmax) 590 assert ret['Nbins'] == output_Nbins 591 assert ret['delta'] == output_delta 592 assert ret['min'], output_min 593 assert ret['max'], output_max 594 595@pytest.fixture 596def atoms(): 597 from MDAnalysisTests import make_Universe 598 u = make_Universe(extras=("masses",), size=(3,1,1)) 599 return u.atoms 600 601@pytest.mark.parametrize('weights,result', 602 [ 603 (None, None), 604 ("mass", np.array([5.1, 4.2, 3.3])), 605 (np.array([12.0, 1.0, 12.0]), np.array([12.0, 1.0, 12.0])), 606 ([12.0, 1.0, 12.0], np.array([12.0, 1.0, 12.0])), 607 (range(3), np.arange(3, dtype=int)), 608 ]) 609def test_check_weights_ok(atoms, weights, result): 610 assert_array_equal(util.get_weights(atoms, weights), result) 611 612@pytest.mark.parametrize('weights', 613 [42, 614 "geometry", 615 np.array(1.0), 616 ]) 617def test_check_weights_raises_TypeError(atoms, weights): 618 with pytest.raises(TypeError): 619 util.get_weights(atoms, weights) 620 621@pytest.mark.parametrize('weights', 622 [ 623 np.array([12.0, 1.0, 12.0, 1.0]), 624 [12.0, 1.0], 625 np.array([[12.0, 1.0, 12.0]]), 626 np.array([[12.0, 1.0, 12.0], [12.0, 1.0, 12.0]]), 627 ]) 628def test_check_weights_raises_ValueError(atoms, weights): 629 with pytest.raises(ValueError): 630 util.get_weights(atoms, weights) 631 632 633class TestGuessFormat(object): 634 """Test guessing of format from filenames 635 636 Tests also getting the appropriate Parser and Reader from a 637 given filename 638 """ 639 # list of known formats, followed by the desired Parser and Reader 640 # None indicates that there isn't a Reader for this format 641 # All formats call fallback to the MinimalParser 642 formats = [ 643 ('CHAIN', mda.topology.MinimalParser.MinimalParser, mda.coordinates.chain.ChainReader), 644 ('CONFIG', mda.topology.DLPolyParser.ConfigParser, mda.coordinates.DLPoly.ConfigReader), 645 ('CRD', mda.topology.CRDParser.CRDParser, mda.coordinates.CRD.CRDReader), 646 ('DATA', mda.topology.LAMMPSParser.DATAParser, mda.coordinates.LAMMPS.DATAReader), 647 ('DCD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.DCD.DCDReader), 648 ('DMS', mda.topology.DMSParser.DMSParser, mda.coordinates.DMS.DMSReader), 649 ('GMS', mda.topology.GMSParser.GMSParser, mda.coordinates.GMS.GMSReader), 650 ('GRO', mda.topology.GROParser.GROParser, mda.coordinates.GRO.GROReader), 651 ('HISTORY', mda.topology.DLPolyParser.HistoryParser, mda.coordinates.DLPoly.HistoryReader), 652 ('INPCRD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.INPCRD.INPReader), 653 ('LAMMPS', mda.topology.MinimalParser.MinimalParser, mda.coordinates.LAMMPS.DCDReader), 654 ('MDCRD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.TRJReader), 655 ('MMTF', mda.topology.MMTFParser.MMTFParser, mda.coordinates.MMTF.MMTFReader), 656 ('MOL2', mda.topology.MOL2Parser.MOL2Parser, mda.coordinates.MOL2.MOL2Reader), 657 ('NC', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.NCDFReader), 658 ('NCDF', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.NCDFReader), 659 ('PDB', mda.topology.PDBParser.PDBParser, mda.coordinates.PDB.PDBReader), 660 ('PDBQT', mda.topology.PDBQTParser.PDBQTParser, mda.coordinates.PDBQT.PDBQTReader), 661 ('PRMTOP', mda.topology.TOPParser.TOPParser, None), 662 ('PQR', mda.topology.PQRParser.PQRParser, mda.coordinates.PQR.PQRReader), 663 ('PSF', mda.topology.PSFParser.PSFParser, None), 664 ('RESTRT', mda.topology.MinimalParser.MinimalParser, mda.coordinates.INPCRD.INPReader), 665 ('TOP', mda.topology.TOPParser.TOPParser, None), 666 ('TPR', mda.topology.TPRParser.TPRParser, None), 667 ('TRJ', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.TRJReader), 668 ('TRR', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRR.TRRReader), 669 ('XML', mda.topology.HoomdXMLParser.HoomdXMLParser, None), 670 ('XPDB', mda.topology.ExtendedPDBParser.ExtendedPDBParser, mda.coordinates.PDB.ExtendedPDBReader), 671 ('XTC', mda.topology.MinimalParser.MinimalParser, mda.coordinates.XTC.XTCReader), 672 ('XYZ', mda.topology.XYZParser.XYZParser, mda.coordinates.XYZ.XYZReader), 673 ('TRZ', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRZ.TRZReader), 674 ] 675 # list of possible compressed extensions 676 # include no extension too! 677 compressed_extensions = ['.bz2', '.gz'] 678 679 @pytest.mark.parametrize('extention', 680 [format_tuple[0].upper() for format_tuple in 681 formats] + 682 [format_tuple[0].lower() for format_tuple in 683 formats]) 684 def test_get_extention(self, extention): 685 """Check that get_ext works""" 686 file_name = 'file.{0}'.format(extention) 687 a, b = util.get_ext(file_name) 688 689 assert a == 'file' 690 assert b == extention.lower() 691 692 @pytest.mark.parametrize('extention', 693 [format_tuple[0].upper() for format_tuple in 694 formats] + 695 [format_tuple[0].lower() for format_tuple in 696 formats]) 697 def test_compressed_without_compression_extention(self, extention): 698 """Check that format suffixed by compressed extension works""" 699 file_name = 'file.{0}'.format(extention) 700 a = util.format_from_filename_extension(file_name) 701 # expect answer to always be uppercase 702 assert a == extention.upper() 703 704 @pytest.mark.parametrize('extention', 705 [format_tuple[0].upper() for format_tuple in 706 formats] + 707 [format_tuple[0].lower() for format_tuple in 708 formats]) 709 @pytest.mark.parametrize('compression_extention', compressed_extensions) 710 def test_compressed(self, extention, compression_extention): 711 """Check that format suffixed by compressed extension works""" 712 file_name = 'file.{0}{1}'.format(extention, compression_extention) 713 a = util.format_from_filename_extension(file_name) 714 # expect answer to always be uppercase 715 assert a == extention.upper() 716 717 @pytest.mark.parametrize('extention', 718 [format_tuple[0].upper() for format_tuple in 719 formats] + [format_tuple[0].lower() for 720 format_tuple in formats]) 721 def test_guess_format(self, extention): 722 file_name = 'file.{0}'.format(extention) 723 a = util.guess_format(file_name) 724 # expect answer to always be uppercase 725 assert a == extention.upper() 726 727 @pytest.mark.parametrize('extention', 728 [format_tuple[0].upper() for format_tuple in 729 formats] + [format_tuple[0].lower() for 730 format_tuple in formats]) 731 @pytest.mark.parametrize('compression_extention', compressed_extensions) 732 def test_guess_format_compressed(self, extention, compression_extention): 733 file_name = 'file.{0}{1}'.format(extention, compression_extention) 734 a = util.guess_format(file_name) 735 # expect answer to always be uppercase 736 assert a == extention.upper() 737 738 @pytest.mark.parametrize('extention, parser', 739 [(format_tuple[0], format_tuple[1]) for 740 format_tuple in formats if 741 format_tuple[1] is not None] 742 ) 743 def test_get_parser(self, extention, parser): 744 file_name = 'file.{0}'.format(extention) 745 a = mda.topology.core.get_parser_for(file_name) 746 747 assert a == parser 748 749 @pytest.mark.parametrize('extention, parser', 750 [(format_tuple[0], format_tuple[1]) for 751 format_tuple in formats if 752 format_tuple[1] is not None] 753 ) 754 @pytest.mark.parametrize('compression_extention', compressed_extensions) 755 def test_get_parser_compressed(self, extention, parser, 756 compression_extention): 757 file_name = 'file.{0}{1}'.format(extention, compression_extention) 758 a = mda.topology.core.get_parser_for(file_name) 759 760 assert a == parser 761 762 @pytest.mark.parametrize('extention', 763 [(format_tuple[0], format_tuple[1]) for 764 format_tuple in formats if 765 format_tuple[1] is None] 766 ) 767 def test_get_parser_invalid(self, extention): 768 file_name = 'file.{0}'.format(extention) 769 with pytest.raises(ValueError): 770 mda.topology.core.get_parser_for(file_name) 771 772 @pytest.mark.parametrize('extention, reader', 773 [(format_tuple[0], format_tuple[2]) for 774 format_tuple in formats if 775 format_tuple[2] is not None] 776 ) 777 def test_get_reader(self, extention, reader): 778 file_name = 'file.{0}'.format(extention) 779 a = mda.coordinates.core.get_reader_for(file_name) 780 781 assert a == reader 782 783 @pytest.mark.parametrize('extention, reader', 784 [(format_tuple[0], format_tuple[2]) for 785 format_tuple in formats if 786 format_tuple[2] is not None] 787 ) 788 @pytest.mark.parametrize('compression_extention', compressed_extensions) 789 def test_get_reader_compressed(self, extention, reader, 790 compression_extention): 791 file_name = 'file.{0}{1}'.format(extention, compression_extention) 792 a = mda.coordinates.core.get_reader_for(file_name) 793 794 assert a == reader 795 796 @pytest.mark.parametrize('extention', 797 [(format_tuple[0], format_tuple[2]) for 798 format_tuple in formats if 799 format_tuple[2] is None] 800 ) 801 def test_get_reader_invalid(self, extention): 802 file_name = 'file.{0}'.format(extention) 803 with pytest.raises(ValueError): 804 mda.coordinates.core.get_reader_for(file_name) 805 806 def test_check_compressed_format_TypeError(self): 807 with pytest.raises(TypeError): 808 util.check_compressed_format(1234, 'bz2') 809 810 def test_format_from_filename_TypeError(self): 811 with pytest.raises(TypeError): 812 util.format_from_filename_extension(1234) 813 814 def test_guess_format_stream_ValueError(self): 815 # This stream has no name, so can't guess format 816 s = StringIO('this is a very fun file') 817 with pytest.raises(ValueError): 818 util.guess_format(s) 819 820 def test_from_ndarray(self): 821 fn = np.zeros((3, 3)) 822 rd = mda.coordinates.core.get_reader_for(fn) 823 assert rd == mda.coordinates.memory.MemoryReader 824 825 826class TestUniqueRows(object): 827 def test_unique_rows_2(self): 828 a = np.array([[0, 1], [1, 2], [2, 1], [0, 1], [0, 1], [2, 1]]) 829 830 assert_array_equal(util.unique_rows(a), 831 np.array([[0, 1], [1, 2], [2, 1]])) 832 833 def test_unique_rows_3(self): 834 a = np.array([[0, 1, 2], [0, 1, 2], [2, 3, 4], [0, 1, 2]]) 835 836 assert_array_equal(util.unique_rows(a), 837 np.array([[0, 1, 2], [2, 3, 4]])) 838 839 def test_unique_rows_with_view(self): 840 # unique_rows doesn't work when flags['OWNDATA'] is False, 841 # happens when second dimension is created through broadcast 842 a = np.array([1, 2]) 843 844 assert_array_equal(util.unique_rows(a[None, :]), 845 np.array([[1, 2]])) 846 847 848class TestGetWriterFor(object): 849 def test_no_filename_argument(self): 850 # Does ``get_writer_for`` fails as expected when provided no 851 # filename arguments 852 with pytest.raises(TypeError): 853 mda.coordinates.core.get_writer_for() 854 855 def test_precedence(self): 856 writer = mda.coordinates.core.get_writer_for('test.pdb', 'GRO') 857 assert writer == mda.coordinates.GRO.GROWriter 858 # Make sure ``get_writer_for`` uses *format* if provided 859 860 def test_missing_extension(self): 861 # Make sure ``get_writer_for`` behave as expected if *filename* 862 # has no extension 863 with pytest.raises(TypeError): 864 mda.coordinates.core.get_writer_for(filename='test', format=None) 865 866 def test_wrong_format(self): 867 # Make sure ``get_writer_for`` fails if the format is unknown 868 with pytest.raises(TypeError): 869 mda.coordinates.core.get_writer_for(filename="fail_me", 870 format='UNK') 871 872 def test_compressed_extension(self): 873 for ext in ('.gz', '.bz2'): 874 fn = 'test.gro' + ext 875 writer = mda.coordinates.core.get_writer_for(filename=fn) 876 assert writer == mda.coordinates.GRO.GROWriter 877 # Make sure ``get_writer_for`` works with compressed file file names 878 879 def test_compressed_extension_fail(self): 880 for ext in ('.gz', '.bz2'): 881 fn = 'test.unk' + ext 882 # Make sure ``get_writer_for`` fails if an unknown format is compressed 883 with pytest.raises(TypeError): 884 mda.coordinates.core.get_writer_for(filename=fn) 885 886 def test_non_string_filename(self): 887 # Does ``get_writer_for`` fails with non string filename, no format 888 with pytest.raises(ValueError): 889 mda.coordinates.core.get_writer_for(filename=StringIO(), 890 format=None) 891 892 def test_multiframe_failure(self): 893 # does ``get_writer_for`` fail with invalid format and multiframe not None 894 with pytest.raises(TypeError): 895 mda.coordinates.core.get_writer_for(filename="fail_me", 896 format='UNK', multiframe=True) 897 mda.coordinates.core.get_writer_for(filename="fail_me", 898 format='UNK', multiframe=False) 899 900 def test_multiframe_nonsense(self): 901 with pytest.raises(ValueError): 902 mda.coordinates.core.get_writer_for(filename='this.gro', 903 multiframe='sandwich') 904 905 formats = [ 906 # format name, related class, singleframe, multiframe 907 ('CRD', mda.coordinates.CRD.CRDWriter, True, False), 908 ('DATA', mda.coordinates.LAMMPS.DATAWriter, True, False), 909 ('DCD', mda.coordinates.DCD.DCDWriter, True, True), 910 # ('ENT', mda.coordinates.PDB.PDBWriter, True, False), 911 ('GRO', mda.coordinates.GRO.GROWriter, True, False), 912 ('LAMMPS', mda.coordinates.LAMMPS.DCDWriter, True, True), 913 ('MOL2', mda.coordinates.MOL2.MOL2Writer, True, True), 914 ('NCDF', mda.coordinates.TRJ.NCDFWriter, True, True), 915 ('NULL', mda.coordinates.null.NullWriter, True, True), 916 # ('PDB', mda.coordinates.PDB.PDBWriter, True, True), special case, done separately 917 ('PDBQT', mda.coordinates.PDBQT.PDBQTWriter, True, False), 918 ('PQR', mda.coordinates.PQR.PQRWriter, True, False), 919 ('TRR', mda.coordinates.TRR.TRRWriter, True, True), 920 ('XTC', mda.coordinates.XTC.XTCWriter, True, True), 921 ('XYZ', mda.coordinates.XYZ.XYZWriter, True, True), 922 ('TRZ', mda.coordinates.TRZ.TRZWriter, True, True), 923 ] 924 925 @pytest.mark.parametrize('format, writer', 926 [(format_tuple[0], format_tuple[1]) for 927 format_tuple in formats if 928 format_tuple[2] is True]) 929 def test_singleframe(self, format, writer): 930 assert mda.coordinates.core.get_writer_for('this', format=format, 931 multiframe=False) == writer 932 933 @pytest.mark.parametrize('format', [(format_tuple[0], format_tuple[1]) for 934 format_tuple in formats if 935 format_tuple[2] is False]) 936 def test_singleframe_fails(self, format): 937 with pytest.raises(TypeError): 938 mda.coordinates.core.get_writer_for('this', format=format, 939 multiframe=False) 940 941 @pytest.mark.parametrize('format, writer', 942 [(format_tuple[0], format_tuple[1]) for 943 format_tuple in formats if 944 format_tuple[3] is True]) 945 def test_multiframe(self, format, writer): 946 assert mda.coordinates.core.get_writer_for('this', format=format, 947 multiframe=True) == writer 948 949 @pytest.mark.parametrize('format', 950 [format_tuple[0] for format_tuple in formats if 951 format_tuple[3] is False]) 952 def test_multiframe_fails(self, format): 953 with pytest.raises(TypeError): 954 mda.coordinates.core.get_writer_for('this', format=format, 955 multiframe=True) 956 957 def test_get_writer_for_pdb(self): 958 assert mda.coordinates.core.get_writer_for('this', format='PDB', 959 multiframe=False) == mda.coordinates.PDB.PDBWriter 960 assert mda.coordinates.core.get_writer_for('this', format='PDB', 961 multiframe=True) == mda.coordinates.PDB.MultiPDBWriter 962 assert mda.coordinates.core.get_writer_for('this', format='ENT', 963 multiframe=False) == mda.coordinates.PDB.PDBWriter 964 assert mda.coordinates.core.get_writer_for('this', format='ENT', 965 multiframe=True) == mda.coordinates.PDB.MultiPDBWriter 966 967 968class TestBlocksOf(object): 969 def test_blocks_of_1(self): 970 arr = np.arange(16).reshape(4, 4) 971 972 view = util.blocks_of(arr, 1, 1) 973 974 assert view.shape == (4, 1, 1) 975 assert_array_almost_equal(view, 976 np.array([[[0]], [[5]], [[10]], [[15]]])) 977 978 # Change my view, check changes are reflected in arr 979 view[:] = 1001 980 981 assert_array_almost_equal(arr, 982 np.array([[1001, 1, 2, 3], 983 [4, 1001, 6, 7], 984 [8, 9, 1001, 11], 985 [12, 13, 14, 1001]])) 986 987 def test_blocks_of_2(self): 988 arr = np.arange(16).reshape(4, 4) 989 990 view = util.blocks_of(arr, 2, 2) 991 992 assert view.shape == (2, 2, 2) 993 assert_array_almost_equal(view, np.array([[[0, 1], [4, 5]], 994 [[10, 11], [14, 15]]])) 995 996 view[0] = 100 997 view[1] = 200 998 999 assert_array_almost_equal(arr, 1000 np.array([[100, 100, 2, 3], 1001 [100, 100, 6, 7], 1002 [8, 9, 200, 200], 1003 [12, 13, 200, 200]])) 1004 1005 def test_blocks_of_3(self): 1006 # testing non square array 1007 arr = np.arange(32).reshape(8, 4) 1008 1009 view = util.blocks_of(arr, 2, 1) 1010 1011 assert view.shape == (4, 2, 1) 1012 1013 def test_blocks_of_4(self): 1014 # testing block exceeding array size results in empty view 1015 arr = np.arange(4).reshape(2, 2) 1016 view = util.blocks_of(arr, 3, 3) 1017 assert view.shape == (0, 3, 3) 1018 view[:] = 100 1019 assert_array_equal(arr, np.arange(4).reshape(2, 2)) 1020 1021 def test_blocks_of_ValueError(self): 1022 arr = np.arange(16).reshape(4, 4) 1023 with pytest.raises(ValueError): 1024 util.blocks_of(arr, 2, 1) # blocks don't fit 1025 with pytest.raises(ValueError): 1026 util.blocks_of(arr[:, ::2], 2, 1) # non-contiguous input 1027 1028 1029class TestNamespace(object): 1030 @staticmethod 1031 @pytest.fixture() 1032 def ns(): 1033 return util.Namespace() 1034 1035 def test_getitem(self, ns): 1036 ns.this = 42 1037 assert ns['this'] == 42 1038 1039 def test_getitem_KeyError(self, ns): 1040 with pytest.raises(KeyError): 1041 dict.__getitem__(ns, 'this') 1042 1043 def test_setitem(self, ns): 1044 ns['this'] = 42 1045 1046 assert ns['this'] == 42 1047 1048 def test_delitem(self, ns): 1049 ns['this'] = 42 1050 assert 'this' in ns 1051 del ns['this'] 1052 assert 'this' not in ns 1053 1054 def test_delitem_AttributeError(self, ns): 1055 with pytest.raises(AttributeError): 1056 del ns.this 1057 1058 def test_setattr(self, ns): 1059 ns.this = 42 1060 1061 assert ns.this == 42 1062 1063 def test_getattr(self, ns): 1064 ns['this'] = 42 1065 1066 assert ns.this == 42 1067 1068 def test_getattr_AttributeError(self, ns): 1069 with pytest.raises(AttributeError): 1070 getattr(ns, 'this') 1071 1072 def test_delattr(self, ns): 1073 ns['this'] = 42 1074 1075 assert 'this' in ns 1076 del ns.this 1077 assert 'this' not in ns 1078 1079 def test_eq(self, ns): 1080 ns['this'] = 42 1081 1082 ns2 = util.Namespace() 1083 ns2['this'] = 42 1084 1085 assert ns == ns2 1086 1087 def test_len(self, ns): 1088 assert len(ns) == 0 1089 ns['this'] = 1 1090 ns['that'] = 2 1091 assert len(ns) == 2 1092 1093 def test_iter(self, ns): 1094 ns['this'] = 12 1095 ns['that'] = 24 1096 ns['other'] = 48 1097 1098 seen = [] 1099 for val in ns: 1100 seen.append(val) 1101 for val in ['this', 'that', 'other']: 1102 assert val in seen 1103 1104 1105class TestTruncateInteger(object): 1106 @pytest.mark.parametrize('a, b', [ 1107 ((1234, 1), 4), 1108 ((1234, 2), 34), 1109 ((1234, 3), 234), 1110 ((1234, 4), 1234), 1111 ((1234, 5), 1234), 1112 ]) 1113 def test_ltruncate_int(self, a, b): 1114 assert util.ltruncate_int(*a) == b 1115 1116class TestFlattenDict(object): 1117 def test_flatten_dict(self): 1118 d = { 1119 'A' : { 1 : ('a', 'b', 'c')}, 1120 'B' : { 2 : ('c', 'd', 'e')}, 1121 'C' : { 3 : ('f', 'g', 'h')} 1122 } 1123 result = util.flatten_dict(d) 1124 1125 for k in result: 1126 assert type(k) == tuple 1127 assert len(k) == 2 1128 assert k[0] in d 1129 assert k[1] in d[k[0]] 1130 assert result[k] in d[k[0]].values() 1131 1132class TestStaticVariables(object): 1133 """Tests concerning the decorator @static_variables 1134 """ 1135 1136 def test_static_variables(self): 1137 x = [0] 1138 1139 @static_variables(foo=0, bar={'test': x}) 1140 def myfunc(): 1141 assert myfunc.foo is 0 1142 assert type(myfunc.bar) is type(dict()) 1143 if 'test2' not in myfunc.bar: 1144 myfunc.bar['test2'] = "a" 1145 else: 1146 myfunc.bar['test2'] += "a" 1147 myfunc.bar['test'][0] += 1 1148 return myfunc.bar['test'] 1149 1150 assert hasattr(myfunc, 'foo') 1151 assert hasattr(myfunc, 'bar') 1152 1153 y = myfunc() 1154 assert y is x 1155 assert x[0] is 1 1156 assert myfunc.bar['test'][0] is 1 1157 assert myfunc.bar['test2'] == "a" 1158 1159 x = [0] 1160 y = myfunc() 1161 assert y is not x 1162 assert myfunc.bar['test'][0] is 2 1163 assert myfunc.bar['test2'] == "aa" 1164 1165class TestWarnIfNotUnique(object): 1166 """Tests concerning the decorator @warn_if_not_uniue 1167 """ 1168 1169 @pytest.fixture() 1170 def warn_msg(self, func, group, group_name): 1171 msg = ("{}.{}(): {} {} contains duplicates. Results might be " 1172 "biased!".format(group.__class__.__name__, func.__name__, 1173 group_name, group.__repr__())) 1174 return msg 1175 1176 def test_warn_if_not_unique(self, atoms): 1177 # Check that the warn_if_not_unique decorator has a "static variable" 1178 # warn_if_not_unique.warned: 1179 assert hasattr(warn_if_not_unique, 'warned') 1180 assert warn_if_not_unique.warned is False 1181 1182 def test_warn_if_not_unique_once_outer(self, atoms): 1183 1184 # Construct a scenario with two nested functions, each one decorated 1185 # with @warn_if_not_unique: 1186 1187 @warn_if_not_unique 1188 def inner(group): 1189 if not group.isunique: 1190 # The inner function should not trigger a warning, and the state 1191 # of warn_if_not_unique.warned should reflect that: 1192 assert warn_if_not_unique.warned is True 1193 return 0 1194 1195 @warn_if_not_unique 1196 def outer(group): 1197 return inner(group) 1198 1199 # Check that no warning is raised for a unique group: 1200 assert atoms.isunique 1201 with pytest.warns(None) as w: 1202 x = outer(atoms) 1203 assert x is 0 1204 assert not w.list 1205 1206 # Check that a warning is raised for a group with duplicates: 1207 ag = atoms + atoms[0] 1208 msg = self.warn_msg(outer, ag, "'ag'") 1209 with pytest.warns(DuplicateWarning) as w: 1210 assert warn_if_not_unique.warned is False 1211 x = outer(ag) 1212 # Assert that the "warned" state is restored: 1213 assert warn_if_not_unique.warned is False 1214 # Check correct function execution: 1215 assert x is 0 1216 # Only one warning must have been raised: 1217 assert len(w) == 1 1218 # For whatever reason pytest.warns(DuplicateWarning, match=msg) 1219 # doesn't work, so we compare the recorded warning message instead: 1220 assert w[0].message.args[0] == msg 1221 # Make sure the warning uses the correct stacklevel and references 1222 # this file instead of MDAnalysis/lib/util.py: 1223 assert w[0].filename == __file__ 1224 1225 def test_warned_state_restored_on_failure(self, atoms): 1226 1227 # A decorated function raising an exception: 1228 @warn_if_not_unique 1229 def thisfails(group): 1230 raise ValueError() 1231 1232 ag = atoms + atoms[0] 1233 msg = self.warn_msg(thisfails, ag, "'ag'") 1234 with pytest.warns(DuplicateWarning) as w: 1235 assert warn_if_not_unique.warned is False 1236 with pytest.raises(ValueError): 1237 thisfails(ag) 1238 # Assert that the "warned" state is restored despite `thisfails` 1239 # raising an exception: 1240 assert warn_if_not_unique.warned is False 1241 assert len(w) == 1 1242 assert w[0].message.args[0] == msg 1243 assert w[0].filename == __file__ 1244 1245 def test_warn_if_not_unique_once_inner(self, atoms): 1246 1247 # Construct a scenario with two nested functions, each one decorated 1248 # with @warn_if_not_unique, but the outer function adds a duplicate 1249 # to the group: 1250 1251 @warn_if_not_unique 1252 def inner(group): 1253 return 0 1254 1255 @warn_if_not_unique 1256 def outer(group): 1257 dupgroup = group + group[0] 1258 return inner(dupgroup) 1259 1260 # Check that even though outer() is called the warning is raised for 1261 # inner(): 1262 msg = self.warn_msg(inner, atoms + atoms[0], "'dupgroup'") 1263 with pytest.warns(DuplicateWarning) as w: 1264 assert warn_if_not_unique.warned is False 1265 x = outer(atoms) 1266 # Assert that the "warned" state is restored: 1267 assert warn_if_not_unique.warned is False 1268 # Check correct function execution: 1269 assert x is 0 1270 # Only one warning must have been raised: 1271 assert len(w) == 1 1272 assert w[0].message.args[0] == msg 1273 assert w[0].filename == __file__ 1274 1275 def test_warn_if_not_unique_multiple_references(self, atoms): 1276 ag = atoms + atoms[0] 1277 aag = ag 1278 aaag = aag 1279 1280 @warn_if_not_unique 1281 def func(group): 1282 return group.isunique 1283 1284 # Check that the warning message contains the names of all references to 1285 # the group in alphabetic order: 1286 msg = self.warn_msg(func, ag, "'aaag' a.k.a. 'aag' a.k.a. 'ag'") 1287 with pytest.warns(DuplicateWarning) as w: 1288 x = func(ag) 1289 # Assert that the "warned" state is restored: 1290 assert warn_if_not_unique.warned is False 1291 # Check correct function execution: 1292 assert x is False 1293 # Check warning message: 1294 assert w[0].message.args[0] == msg 1295 # Check correct file referenced: 1296 assert w[0].filename == __file__ 1297 1298 def test_warn_if_not_unique_unnamed(self, atoms): 1299 1300 @warn_if_not_unique 1301 def func(group): 1302 pass 1303 1304 msg = self.warn_msg(func, atoms + atoms[0], 1305 "'unnamed {}'".format(atoms.__class__.__name__)) 1306 with pytest.warns(DuplicateWarning) as w: 1307 func(atoms + atoms[0]) 1308 # Check warning message: 1309 assert w[0].message.args[0] == msg 1310 1311 def test_warn_if_not_unique_fails_for_non_groupmethods(self): 1312 1313 @warn_if_not_unique 1314 def func(group): 1315 pass 1316 1317 class dummy(object): 1318 pass 1319 1320 with pytest.raises(AttributeError): 1321 func(dummy()) 1322 1323 def test_filter_duplicate_with_userwarning(self, atoms): 1324 1325 @warn_if_not_unique 1326 def func(group): 1327 pass 1328 1329 with warnings.catch_warnings(record=True) as record: 1330 warnings.resetwarnings() 1331 warnings.filterwarnings("ignore", category=UserWarning) 1332 with pytest.warns(None) as w: 1333 func(atoms) 1334 assert not w.list 1335 assert len(record) == 0 1336 1337class TestCheckCoords(object): 1338 """Tests concerning the decorator @check_coords 1339 """ 1340 1341 prec = 6 1342 1343 def test_default_options(self): 1344 a_in = np.zeros(3, dtype=np.float32) 1345 b_in = np.ones(3, dtype=np.float32) 1346 b_in2 = np.ones((2, 3), dtype=np.float32) 1347 1348 @check_coords('a','b') 1349 def func(a, b): 1350 # check that enforce_copy is True by default: 1351 assert a is not a_in 1352 assert b is not b_in 1353 # check that convert_single is True by default: 1354 assert a.shape == (1, 3) 1355 assert b.shape == (1, 3) 1356 return a + b 1357 1358 # check that allow_single is True by default: 1359 res = func(a_in, b_in) 1360 # check that reduce_result_if_single is True by default: 1361 assert res.shape == (3,) 1362 # check correct function execution: 1363 assert_array_equal(res, b_in) 1364 1365 # check that check_lenghts_match is True by default: 1366 with pytest.raises(ValueError): 1367 res = func(a_in, b_in2) 1368 1369 def test_enforce_copy(self): 1370 1371 a_2d = np.ones((1, 3), dtype=np.float32) 1372 b_1d = np.zeros(3, dtype=np.float32) 1373 c_2d = np.zeros((1, 6), dtype=np.float32)[:, ::2] 1374 d_2d = np.zeros((1, 3), dtype=np.int64) 1375 1376 @check_coords('a', 'b', 'c', 'd', enforce_copy=False) 1377 def func(a, b, c, d): 1378 # Assert that if enforce_copy is False: 1379 # no copy is made if input shape, order, and dtype are correct: 1380 assert a is a_2d 1381 # a copy is made if input shape has to be changed: 1382 assert b is not b_1d 1383 # a copy is made if input order has to be changed: 1384 assert c is not c_2d 1385 # a copy is made if input dtype has to be changed: 1386 assert d is not d_2d 1387 # Assert correct dtype conversion: 1388 assert d.dtype == np.float32 1389 assert_almost_equal(d, d_2d, self.prec) 1390 # Assert all shapes are converted to (1, 3): 1391 assert a.shape == b.shape == c.shape == d.shape == (1, 3) 1392 return a + b + c + d 1393 1394 # Call func() to: 1395 # - test the above assertions 1396 # - ensure that input of single coordinates is simultaneously possible 1397 # with different shapes (3,) and (1, 3) 1398 res = func(a_2d, b_1d, c_2d, d_2d) 1399 # Since some inputs are not 1d, even though reduce_result_if_single is 1400 # True, the result must have shape (1, 3): 1401 assert res.shape == (1, 3) 1402 # check correct function execution: 1403 assert_array_equal(res, a_2d) 1404 1405 def test_no_allow_single(self): 1406 1407 @check_coords('a', allow_single=False) 1408 def func(a): 1409 pass 1410 1411 with pytest.raises(ValueError) as err: 1412 func(np.zeros(3, dtype=np.float32)) 1413 assert err.msg == ("func(): a.shape must be (n, 3), got (3,).") 1414 1415 def test_no_convert_single(self): 1416 1417 a_1d = np.arange(-3, 0, dtype=np.float32) 1418 1419 @check_coords('a', enforce_copy=False, convert_single=False) 1420 def func(a): 1421 # assert no conversion and no copy were performed: 1422 assert a is a_1d 1423 return a 1424 1425 res = func(a_1d) 1426 # Assert result has been reduced: 1427 assert res == a_1d[0] 1428 assert type(res) is np.float32 1429 1430 def test_no_reduce_result_if_single(self): 1431 1432 a_1d = np.zeros(3, dtype=np.float32) 1433 1434 # Test without shape conversion: 1435 @check_coords('a', enforce_copy=False, convert_single=False, 1436 reduce_result_if_single=False) 1437 def func(a): 1438 return a 1439 1440 res = func(a_1d) 1441 # make sure the input array is just passed through: 1442 assert res is a_1d 1443 1444 # Test with shape conversion: 1445 @check_coords('a', enforce_copy=False, reduce_result_if_single=False) 1446 def func(a): 1447 return a 1448 1449 res = func(a_1d) 1450 assert res.shape == (1, 3) 1451 assert_array_equal(res[0], a_1d) 1452 1453 def test_no_check_lengths_match(self): 1454 1455 a_2d = np.zeros((1, 3), dtype=np.float32) 1456 b_2d = np.zeros((3, 3), dtype=np.float32) 1457 1458 @check_coords('a', 'b', enforce_copy=False, check_lengths_match=False) 1459 def func(a, b): 1460 return a, b 1461 1462 res_a, res_b = func(a_2d, b_2d) 1463 # Assert arrays are just passed through: 1464 assert res_a is a_2d 1465 assert res_b is b_2d 1466 1467 def test_invalid_input(self): 1468 1469 a_inv_dtype = np.array([['hello', 'world', '!']]) 1470 a_inv_type = [[0., 0., 0.]] 1471 a_inv_shape_1d = np.zeros(6, dtype=np.float32) 1472 a_inv_shape_2d = np.zeros((3, 2), dtype=np.float32) 1473 1474 @check_coords('a') 1475 def func(a): 1476 pass 1477 1478 with pytest.raises(TypeError) as err: 1479 func(a_inv_dtype) 1480 assert err.msg.startswith("func(): a.dtype must be convertible to " 1481 "float32, got ") 1482 1483 with pytest.raises(TypeError) as err: 1484 func(a_inv_type) 1485 assert err.msg == ("func(): Parameter 'a' must be a numpy.ndarray, " 1486 "got <class 'list'>.") 1487 1488 with pytest.raises(ValueError) as err: 1489 func(a_inv_shape_1d) 1490 assert err.msg == ("func(): a.shape must be (3,) or (n, 3), got " 1491 "(6,).") 1492 1493 with pytest.raises(ValueError) as err: 1494 func(a_inv_shape_2d) 1495 assert err.msg == ("func(): a.shape must be (3,) or (n, 3), got " 1496 "(3, 2).") 1497 1498 def test_usage_with_kwargs(self): 1499 1500 a_2d = np.zeros((1, 3), dtype=np.float32) 1501 1502 @check_coords('a', enforce_copy=False) 1503 def func(a, b, c=0): 1504 return a, b, c 1505 1506 # check correct functionality if passed as keyword argument: 1507 a, b, c = func(a=a_2d, b=0, c=1) 1508 assert a is a_2d 1509 assert b == 0 1510 assert c == 1 1511 1512 def test_wrong_func_call(self): 1513 1514 @check_coords('a', enforce_copy=False) 1515 def func(a, b, c=0): 1516 pass 1517 1518 # Make sure invalid call marker is present: 1519 func._invalid_call = False 1520 1521 # usage with posarg doubly defined: 1522 assert not func._invalid_call 1523 with pytest.raises(TypeError): 1524 func(0, a=0) # pylint: disable=redundant-keyword-arg 1525 assert func._invalid_call 1526 func._invalid_call = False 1527 1528 # usage with missing posargs: 1529 assert not func._invalid_call 1530 with pytest.raises(TypeError): 1531 func(0) 1532 assert func._invalid_call 1533 func._invalid_call = False 1534 1535 # usage with missing posargs (supplied as kwargs): 1536 assert not func._invalid_call 1537 with pytest.raises(TypeError): 1538 func(a=0, c=1) 1539 assert func._invalid_call 1540 func._invalid_call = False 1541 1542 # usage with too many posargs: 1543 assert not func._invalid_call 1544 with pytest.raises(TypeError): 1545 func(0, 0, 0, 0) 1546 assert func._invalid_call 1547 func._invalid_call = False 1548 1549 # usage with unexpected kwarg: 1550 assert not func._invalid_call 1551 with pytest.raises(TypeError): 1552 func(a=0, b=0, c=1, d=1) # pylint: disable=unexpected-keyword-arg 1553 assert func._invalid_call 1554 func._invalid_call = False 1555 1556 def test_wrong_decorator_usage(self): 1557 1558 # usage without parantheses: 1559 @check_coords 1560 def func(): 1561 pass 1562 1563 with pytest.raises(TypeError): 1564 func() 1565 1566 # usage without arguments: 1567 with pytest.raises(ValueError) as err: 1568 @check_coords() 1569 def func(): 1570 pass 1571 1572 assert err.msg == ("Decorator check_coords() cannot be used " 1573 "without positional arguments.") 1574 1575 # usage with defaultarg: 1576 with pytest.raises(ValueError) as err: 1577 @check_coords('a') 1578 def func(a=1): 1579 pass 1580 1581 assert err.msg == ("In decorator check_coords(): Name 'a' doesn't " 1582 "correspond to any positional argument of the " 1583 "decorated function func().") 1584 1585 # usage with invalid parameter name: 1586 with pytest.raises(ValueError) as err: 1587 @check_coords('b') 1588 def func(a): 1589 pass 1590 1591 assert err.msg == ("In decorator check_coords(): Name 'b' doesn't " 1592 "correspond to any positional argument of the " 1593 "decorated function func().") 1594 1595 1596@pytest.mark.parametrize("old_name", (None, "MDAnalysis.Universe")) 1597@pytest.mark.parametrize("new_name", (None, "Multiverse")) 1598@pytest.mark.parametrize("remove", (None, "99.0.0", 2099)) 1599@pytest.mark.parametrize("message", (None, "use the new stuff")) 1600def test_deprecate(old_name, new_name, remove, message, release="2.7.1"): 1601 def AlternateUniverse(anything): 1602 # important: first line needs to be """\ so that textwrap.dedent() 1603 # works 1604 """\ 1605 AlternateUniverse provides a true view of the Universe. 1606 1607 Parameters 1608 ---------- 1609 anything : object 1610 1611 Returns 1612 ------- 1613 truth 1614 1615 """ 1616 return True 1617 1618 oldfunc = util.deprecate(AlternateUniverse, old_name=old_name, 1619 new_name=new_name, 1620 release=release, remove=remove, 1621 message=message) 1622 with pytest.warns(DeprecationWarning, match_expr="`.+` is deprecated"): 1623 oldfunc(42) 1624 1625 doc = oldfunc.__doc__ 1626 name = old_name if old_name else AlternateUniverse.__name__ 1627 1628 deprecation_line_1 = ".. deprecated:: {0}".format(release) 1629 assert re.search(deprecation_line_1, doc) 1630 1631 if message: 1632 deprecation_line_2 = message 1633 else: 1634 if new_name is None: 1635 default_message = "`{0}` is deprecated!".format(name) 1636 else: 1637 default_message = "`{0}` is deprecated, use `{1}` instead!".format( 1638 name, new_name) 1639 deprecation_line_2 = default_message 1640 assert re.search(deprecation_line_2, doc) 1641 1642 if remove: 1643 deprecation_line_3 = "`{0}` will be removed in release {1}".format( 1644 name, remove) 1645 assert re.search(deprecation_line_3, doc) 1646 1647 # check that the old docs are still present 1648 assert re.search(textwrap.dedent(AlternateUniverse.__doc__), doc) 1649 1650 1651def test_deprecate_missing_release_ValueError(): 1652 with pytest.raises(ValueError): 1653 util.deprecate(mda.Universe) 1654 1655def test_set_function_name(name="bar"): 1656 def foo(): 1657 pass 1658 util._set_function_name(foo, name) 1659 assert foo.__name__ == name 1660 1661@pytest.mark.parametrize("text", 1662 ("", 1663 "one line text", 1664 " one line with leading space", 1665 "multiline\n\n with some\n leading space", 1666 " multiline\n\n with all\n leading space")) 1667def test_dedent_docstring(text): 1668 doc = util.dedent_docstring(text) 1669 for line in doc.splitlines(): 1670 assert line == line.lstrip() 1671 1672 1673class TestCheckBox(object): 1674 1675 prec = 6 1676 ref_ortho = np.ones(3, dtype=np.float32) 1677 ref_tri_vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 2 ** 0.5, 2 ** 0.5]], 1678 dtype=np.float32) 1679 1680 @pytest.mark.parametrize('box', 1681 ([1, 1, 1, 90, 90, 90], 1682 (1, 1, 1, 90, 90, 90), 1683 ['1', '1', 1, 90, '90', '90'], 1684 ('1', '1', 1, 90, '90', '90'), 1685 np.array(['1', '1', 1, 90, '90', '90']), 1686 np.array([1, 1, 1, 90, 90, 90], dtype=np.float32), 1687 np.array([1, 1, 1, 90, 90, 90], dtype=np.float64), 1688 np.array([1, 1, 1, 1, 1, 1, 90, 90, 90, 90, 90, 90], 1689 dtype=np.float32)[::2])) 1690 def test_ckeck_box_ortho(self, box): 1691 boxtype, checked_box = util.check_box(box) 1692 assert boxtype == 'ortho' 1693 assert_equal(checked_box, self.ref_ortho) 1694 assert checked_box.dtype == np.float32 1695 assert checked_box.flags['C_CONTIGUOUS'] 1696 1697 @pytest.mark.parametrize('box', 1698 ([1, 1, 2, 45, 90, 90], 1699 (1, 1, 2, 45, 90, 90), 1700 ['1', '1', 2, 45, '90', '90'], 1701 ('1', '1', 2, 45, '90', '90'), 1702 np.array(['1', '1', 2, 45, '90', '90']), 1703 np.array([1, 1, 2, 45, 90, 90], dtype=np.float32), 1704 np.array([1, 1, 2, 45, 90, 90], dtype=np.float64), 1705 np.array([1, 1, 1, 1, 2, 2, 45, 45, 90, 90, 90, 90], 1706 dtype=np.float32)[::2])) 1707 def test_check_box_tri_vecs(self, box): 1708 boxtype, checked_box = util.check_box(box) 1709 assert boxtype == 'tri_vecs' 1710 assert_almost_equal(checked_box, self.ref_tri_vecs, self.prec) 1711 assert checked_box.dtype == np.float32 1712 assert checked_box.flags['C_CONTIGUOUS'] 1713 1714 def test_check_box_wrong_data(self): 1715 with pytest.raises(ValueError): 1716 wrongbox = ['invalid', 1, 1, 90, 90, 90] 1717 boxtype, checked_box = util.check_box(wrongbox) 1718 1719 def test_check_box_wrong_shape(self): 1720 with pytest.raises(ValueError): 1721 wrongbox = np.ones((3, 3), dtype=np.float32) 1722 boxtype, checked_box = util.check_box(wrongbox) 1723