1# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
2# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 fileencoding=utf-8
3#
4# MDAnalysis --- https://www.mdanalysis.org
5# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
6# (see the file AUTHORS for the full list of names)
7#
8# Released under the GNU Public Licence, v2 or any higher version
9#
10# Please cite your use of MDAnalysis in published work:
11#
12# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
13# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
14# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
15# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
16# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
17#
18# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
19# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
20# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
21#
22from __future__ import absolute_import, division
23
24
25from six.moves import range, StringIO
26import pytest
27import os
28import warnings
29import re
30import textwrap
31
32import numpy as np
33from numpy.testing import (assert_equal, assert_almost_equal,
34                           assert_array_almost_equal, assert_array_equal)
35
36import MDAnalysis as mda
37import MDAnalysis.lib.util as util
38import MDAnalysis.lib.mdamath as mdamath
39from MDAnalysis.lib.util import (cached, static_variables, warn_if_not_unique,
40                                 check_coords)
41from MDAnalysis.core.topologyattrs import Bonds
42from MDAnalysis.exceptions import NoDataError, DuplicateWarning
43
44
45from MDAnalysisTests.datafiles import (
46    Make_Whole, TPR, GRO, fullerene, two_water_gro,
47)
48
49
def convert_aa_code_long_data():
    """Produce ``(resname3, resname1)`` pairs for the 3->1 conversion tests.

    Every canonical and force-field specific residue name is paired with the
    one-letter amino-acid code it is expected to convert to.
    """
    table = (
        ('H', ('HIS', 'HISA', 'HISB', 'HSE', 'HSD', 'HIS1', 'HIS2', 'HIE',
               'HID')),
        ('K', ('LYS', 'LYSH', 'LYN')),
        ('A', ('ALA',)),
        ('D', ('ASP', 'ASPH', 'ASH')),
        ('E', ('GLU', 'GLUH', 'GLH')),
        ('N', ('ASN',)),
        ('Q', ('GLN',)),
        ('C', ('CYS', 'CYSH', 'CYS1', 'CYS2')),
    )
    # lazily flatten the table: one pair per long residue name
    return ((long_name, one_letter)
            for one_letter, long_names in table
            for long_name in long_names)
65
66
class TestStringFunctions(object):
    """Residue-string parsing and amino-acid code conversion."""

    # one-letter code -> accepted 3/4-letter names; the first name of each
    # tuple is the canonical one returned by the 1->3 conversion
    aa = [
        ('H', ('HIS', 'HISA', 'HISB', 'HSE', 'HSD', 'HIS1', 'HIS2', 'HIE',
               'HID')),
        ('K', ('LYS', 'LYSH', 'LYN')),
        ('A', ('ALA',)),
        ('D', ('ASP', 'ASPH', 'ASH')),
        ('E', ('GLU', 'GLUH', 'GLH')),
        ('N', ('ASN',)),
        ('Q', ('GLN',)),
        ('C', ('CYS', 'CYSH', 'CYS1', 'CYS2')),
    ]

    # residue string -> expected (resname, resid, atom name or None)
    residues = [
        ("LYS300:HZ1", ("LYS", 300, "HZ1")),
        ("K300:HZ1", ("LYS", 300, "HZ1")),
        ("K300", ("LYS", 300, None)),
        ("LYS 300:HZ1", ("LYS", 300, "HZ1")),
        ("M1:CA", ("MET", 1, "CA")),
    ]

    @pytest.mark.parametrize('rstring, residue', residues)
    def test_parse_residue(self, rstring, residue):
        parsed = util.parse_residue(rstring)
        assert parsed == residue

    def test_parse_residue_ValueError(self):
        # 'ZZZ' has no residue number and is not a valid residue string
        with pytest.raises(ValueError):
            util.parse_residue('ZZZ')

    @pytest.mark.parametrize('resname3, resname1', convert_aa_code_long_data())
    def test_convert_aa_3to1(self, resname3, resname1):
        converted = util.convert_aa_code(resname3)
        assert converted == resname1

    @pytest.mark.parametrize('resname1, strings', aa)
    def test_convert_aa_1to3(self, resname1, strings):
        canonical = strings[0]
        assert util.convert_aa_code(resname1) == canonical

    @pytest.mark.parametrize('x', ('XYZXYZ', '£'))
    def test_ValueError(self, x):
        # unknown codes (including non-ASCII input) are rejected
        with pytest.raises(ValueError):
            util.convert_aa_code(x)
112
113
def test_greedy_splitext(inp="foo/bar/boing.2.pdb.bz2",
                         ref=("foo/bar/boing", ".2.pdb.bz2")):
    """Check that ``util.greedy_splitext`` splits off *all* extensions.

    Parameters
    ----------
    inp : str
        Path to split; normalised with :func:`os.path.normpath` first.
    ref : sequence of two str
        Expected ``(root, extension)``; both parts are normalised too.

    The default for ``ref`` is an immutable tuple, and the normalised
    values are kept in locals.  The function therefore never mutates its
    (shared) default argument, unlike an in-place ``ref[0] = ...`` on a
    list default.
    """
    inp = os.path.normpath(inp)
    ref_root, ref_ext = (os.path.normpath(part) for part in ref)
    root, ext = util.greedy_splitext(inp)
    assert root == ref_root, "root incorrect"
    assert ext == ref_ext, "extension incorrect"
122
123
@pytest.mark.parametrize('iterable, value', [
    # sequences, ranges and arrays count as iterable ...
    ([1, 2, 3], True),
    ([], True),
    ((1, 2, 3), True),
    ((), True),
    (range(3), True),
    (np.array([1, 2, 3]), True),
    # ... while scalars and (byte or unicode) strings do not
    (123, False),
    ("byte string", False),
    (u"unicode string", False)
])
def test_iterable(iterable, value):
    """``util.iterable`` distinguishes containers from scalars/strings."""
    result = util.iterable(iterable)
    assert result == value
137
138
class TestFilename(object):
    """``util.filename`` on plain strings and on ``NamedStream`` objects."""
    root = "foo"
    filename = "foo.psf"
    ext = "pdb"
    filename2 = "foo.pdb"

    @pytest.mark.parametrize('name, ext, keep, actual_name', [
        # no new extension: name is returned unchanged
        (filename, None, False, filename),
        # existing extension is replaced unless keep=True
        (filename, ext, False, filename2),
        (filename, ext, True, filename),
        # a bare root always gets the extension appended
        (root, ext, False, filename2),
        (root, ext, True, filename2)
    ])
    def test_string(self, name, ext, keep, actual_name):
        assert util.filename(name, ext, keep) == actual_name

    def test_named_stream(self):
        stream = util.NamedStream(StringIO(), self.filename)
        renamed = util.filename(stream, ext=self.ext)
        # An explicit ``!=`` check is used instead of assert_equal to avoid
        # a segfault seen on some systems.
        if renamed != stream:
            pytest.fail("fn and ns are different")
        assert str(renamed) == self.filename2
        assert stream.name == self.filename2
164
165
class TestGeometryFunctions(object):
    """Tests for the vector helpers in ``MDAnalysis.lib.mdamath``."""
    e1, e2, e3 = np.eye(3)
    # unit vector at 60 degrees to e1 in the xy-plane
    a = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3), 0])
    null = np.zeros(3)

    @pytest.mark.parametrize('x_axis, y_axis, value', [
        # unit vectors
        (e1, e2, np.pi / 2),
        (e1, a, np.pi / 3),
        # scaled (non-unit) vectors
        (2 * e1, e2, np.pi / 2),
        (-2 * e1, e2, np.pi - np.pi / 2),
        (23.3 * e1, a, np.pi / 3),
        # the angle with the null vector is expected to be NaN
        (e1, null, np.nan),
        # collinear vectors
        (a, a, 0.0)
    ])
    def test_vectors(self, x_axis, y_axis, value):
        assert_equal(mdamath.angle(x_axis, y_axis), value)

    @pytest.mark.parametrize('x_axis, y_axis, value', [
        # vectors of wildly different magnitudes, (anti)parallel
        (-2.3456e7 * e1, 3.4567e-6 * e1, np.pi),
        (2.3456e7 * e1, 3.4567e-6 * e1, 0.0)
    ])
    def test_angle_pi(self, x_axis, y_axis, value):
        assert_almost_equal(mdamath.angle(x_axis, y_axis), value)

    @pytest.mark.parametrize('x', np.linspace(0, np.pi, 20))
    def test_angle_range(self, x):
        radius = 1000.
        vec = radius * np.array([np.cos(x), np.sin(x), 0])
        assert_almost_equal(mdamath.angle(self.e1, vec), x, 6)

    @pytest.mark.parametrize('vector, value', [
        (e3, 1),
        (a, np.linalg.norm(a)),
        (null, 0.0)
    ])
    def test_norm(self, vector, value):
        assert mdamath.norm(vector) == value

    @pytest.mark.parametrize('x', np.linspace(0, np.pi, 20))
    def test_norm_range(self, x):
        radius = 1000.
        vec = radius * np.array([np.cos(x), np.sin(x), 0])
        assert_almost_equal(mdamath.norm(vec), radius, 6)

    @pytest.mark.parametrize('vec1, vec2, value', [
        (e1, e2, e3),
        (e1, null, 0.0)
    ])
    def test_normal(self, vec1, vec2, value):
        assert_equal(mdamath.normal(vec1, vec2), value)
        # TODO: add more non-trivial tests

    def test_stp(self):
        # scalar triple product of the standard basis
        assert mdamath.stp(self.e1, self.e2, self.e3) == 1.0
        # TODO: add more non-trivial tests

    def test_dihedral(self):
        first = self.e1
        second = first + self.e2
        third = second + self.e3
        assert_almost_equal(mdamath.dihedral(first, second, third),
                            -np.pi / 2)
231
232
class TestMakeWhole(object):
    """Tests for ``mdamath.make_whole``.

    Set up a simple system:

    +-----------+
    |           |
    | 6       3 | 6
    | !       ! | !
    |-5-8   1-2-|-5-8
    | !       ! | !
    | 7       4 | 7
    |           |
    +-----------+

    Atom labels in the diagram are 1-based, while the bond lists below use
    0-based atom indices (e.g. diagram bond 2-5 is ``(1, 4)``).  The bond
    2-5 is drawn across the box edge.
    """

    # number of decimals used when comparing coordinates
    prec = 5

    @pytest.fixture()
    def universe(self):
        """Make_Whole universe with the bonds from the class diagram."""
        universe = mda.Universe(Make_Whole)
        bondlist = [(0, 1), (1, 2), (1, 3), (1, 4), (4, 5), (4, 6), (4, 7)]
        universe.add_TopologyAttr(Bonds(bondlist))
        return universe

    def test_single_atom_no_bonds(self):
        # Call make_whole on single atom with no bonds, shouldn't move
        u = mda.Universe(Make_Whole)
        # Atom0 is isolated
        bondlist = [(1, 2), (1, 3), (1, 4), (4, 5), (4, 6), (4, 7)]
        u.add_TopologyAttr(Bonds(bondlist))

        ag = u.atoms[[0]]
        refpos = ag.positions.copy()
        mdamath.make_whole(ag)

        assert_array_almost_equal(ag.positions, refpos)

    def test_scrambled_ag(self, universe):
        # if order of atomgroup is mixed
        ag = universe.atoms[[1, 3, 2, 4, 0, 6, 5, 7]]

        mdamath.make_whole(ag)

        # artificial system which uses 1nm bonds, so
        # largest bond should be 20A
        assert ag.bonds.values().max() < 20.1

    @staticmethod
    @pytest.fixture()
    def ag(universe):
        """Atoms of the first residue of the ``universe`` fixture."""
        return universe.residues[0].atoms

    def test_no_bonds(self):
        # NoDataError is raised because the universe has no bonds
        universe = mda.Universe(Make_Whole)
        ag = universe.residues[0].atoms
        with pytest.raises(NoDataError):
            mdamath.make_whole(ag)

    def test_zero_box_size(self, universe, ag):
        universe.dimensions = [0., 0., 0., 90., 90., 90.]
        with pytest.raises(ValueError):
            mdamath.make_whole(ag)

    def test_wrong_reference_atom(self, universe, ag):
        # Reference atom not in atomgroup
        with pytest.raises(ValueError):
            mdamath.make_whole(ag, reference_atom=universe.atoms[-1])

    def test_impossible_solve(self, universe):
        # check that the algorithm sees the bad walk
        with pytest.raises(ValueError):
            mdamath.make_whole(universe.atoms)

    def test_solve_1(self, universe, ag):
        # regular usage of function

        refpos = universe.atoms[:4].positions.copy()

        mdamath.make_whole(ag)

        assert_array_almost_equal(universe.atoms[:4].positions, refpos)
        assert_array_almost_equal(universe.atoms[4].position,
                                  np.array([110.0, 50.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[5].position,
                                  np.array([110.0, 60.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[6].position,
                                  np.array([110.0, 40.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[7].position,
                                  np.array([120.0, 50.0, 0.0]), decimal=self.prec)

    def test_solve_2(self, universe, ag):
        # use but specify the center atom

        refpos = universe.atoms[4:8].positions.copy()

        mdamath.make_whole(ag, reference_atom=universe.residues[0].atoms[4])

        assert_array_almost_equal(universe.atoms[4:8].positions, refpos)
        assert_array_almost_equal(universe.atoms[0].position,
                                  np.array([-20.0, 50.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[1].position,
                                  np.array([-10.0, 50.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[2].position,
                                  np.array([-10.0, 60.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[3].position,
                                  np.array([-10.0, 40.0, 0.0]), decimal=self.prec)

    def test_solve_3(self, universe):
        # put in a chunk that doesn't need any work

        refpos = universe.atoms[:1].positions.copy()

        mdamath.make_whole(universe.atoms[:1])

        assert_array_almost_equal(universe.atoms[:1].positions, refpos)

    def test_solve_4(self, universe):
        # Put in only some of a fragment,
        # check that not everything gets moved

        chunk = universe.atoms[:7]
        refpos = universe.atoms[7].position.copy()

        mdamath.make_whole(chunk)

        assert_array_almost_equal(universe.atoms[7].position, refpos)
        # same expected positions and tolerance as in test_solve_1
        assert_array_almost_equal(universe.atoms[4].position,
                                  np.array([110.0, 50.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[5].position,
                                  np.array([110.0, 60.0, 0.0]), decimal=self.prec)
        assert_array_almost_equal(universe.atoms[6].position,
                                  np.array([110.0, 40.0, 0.0]), decimal=self.prec)

    def test_double_frag_short_bonds(self, universe, ag):
        # previous bug where if two fragments are given
        # but all bonds were short, the algorithm didn't
        # complain
        mdamath.make_whole(ag)
        with pytest.raises(ValueError):
            mdamath.make_whole(universe.atoms)

    def test_make_whole_triclinic(self):
        u = mda.Universe(TPR, GRO)
        thing = u.select_atoms('not resname SOL NA+')
        mdamath.make_whole(thing)

        blengths = thing.bonds.values()

        assert blengths.max() < 2.0

    def test_make_whole_fullerene(self):
        # lots of circular bonds as a nice pathological case
        u = mda.Universe(fullerene)

        bbox = u.atoms.bbox()
        u.dimensions[:3] = bbox[1] - bbox[0]
        u.dimensions[3:] = 90.0

        blengths = u.atoms.bonds.values()
        # kaboom: scatter alternate atoms by whole box vectors
        u.atoms[::2].translate([u.dimensions[0], -2 * u.dimensions[1], 0.0])
        u.atoms[1::2].translate([0.0, 7 * u.dimensions[1], -5 * u.dimensions[2]])

        mdamath.make_whole(u.atoms)

        assert_array_almost_equal(u.atoms.bonds.values(), blengths, decimal=self.prec)

    def test_make_whole_multiple_molecules(self):
        u = mda.Universe(two_water_gro, guess_bonds=True)

        for f in u.atoms.fragments:
            mdamath.make_whole(f)

        assert u.atoms.bonds.values().max() < 2.0
407
class Class_with_Caches(object):
    """Toy class used to exercise the ``cached`` decorator.

    ``_clear_caches`` and ``_fill_cache`` are designed to mimic the cache
    helpers of AtomGroup and Universe.
    """
    def __init__(self):
        self._cache = {}
        self.ref1 = 1.0
        self.ref2 = 2.0
        self.ref3 = 3.0
        self.ref4 = 4.0
        self.ref5 = 5.0

    # plain cached method
    @cached('val1')
    def val1(self):
        return self.ref1

    # cached method exposed as a property (a common combination)
    @property
    @cached('val2')
    def val2(self):
        return self.ref2

    # cached property with a setter and a deleter
    @property
    @cached('val3')
    def val3(self):
        return self.ref3

    @val3.setter
    def val3(self, new):
        self._clear_caches('val3')
        self._fill_cache('val3', new)

    @val3.deleter
    def val3(self):
        self._clear_caches('val3')

    # positional arguments must reach the wrapped method
    @cached('val4')
    def val4(self, n1, n2):
        return self._init_val_4(n1, n2)

    def _init_val_4(self, m1, m2):
        return self.ref4 + m1 + m2

    # positional plus keyword arguments
    @cached('val5')
    def val5(self, n, s=None):
        return self._init_val_5(n, s=s)

    def _init_val_5(self, n, s=None):
        return n * s

    def _clear_caches(self, *args):
        """Drop the named cache entries, or the whole cache if none given."""
        if not args:
            self._cache = {}
            return
        for name in args:
            # missing entries are silently ignored
            self._cache.pop(name, None)

    def _fill_cache(self, name, value):
        self._cache[name] = value
471
472
class TestCachedDecorator(object):
    """Tests for the ``cached`` decorator via ``Class_with_Caches``."""

    @pytest.fixture()
    def obj(self):
        return Class_with_Caches()

    # Plain cached method
    def test_val1_lookup(self, obj):
        obj._clear_caches()
        assert 'val1' not in obj._cache
        assert obj.val1() == obj.ref1
        value = obj.val1()
        assert 'val1' in obj._cache
        assert obj._cache['val1'] == value
        assert obj.val1() is obj._cache['val1']

    def test_val1_inject(self, obj):
        # A value placed directly in the cache is returned as-is, showing
        # that the cache is used blindly once an entry exists.
        obj._clear_caches()
        value = obj.val1()
        assert 'val1' in obj._cache
        assert value == obj.ref1
        injected = 77.0
        obj._fill_cache('val1', injected)
        assert obj.val1() == injected

    # Cached method wrapped in a property
    def test_val2_lookup(self, obj):
        obj._clear_caches()
        assert 'val2' not in obj._cache
        assert obj.val2 == obj.ref2
        value = obj.val2
        assert 'val2' in obj._cache
        assert obj._cache['val2'] == value

    def test_val2_inject(self, obj):
        obj._clear_caches()
        value = obj.val2
        assert 'val2' in obj._cache
        assert value == obj.ref2
        injected = 77.0
        obj._fill_cache('val2', injected)
        assert obj.val2 == injected

    # Setter on cached attribute
    def test_val3_set(self, obj):
        obj._clear_caches()
        assert obj.val3 == obj.ref3
        replacement = 99.0
        obj.val3 = replacement
        assert obj.val3 == replacement
        assert obj._cache['val3'] == replacement

    def test_val3_del(self, obj):
        # Deleting the property removes it from the cache ...
        obj._clear_caches()
        assert obj.val3 == obj.ref3
        assert 'val3' in obj._cache
        del obj.val3
        assert 'val3' not in obj._cache
        # ... and the next access recomputes and re-caches it.
        assert obj.val3 == obj.ref3
        assert 'val3' in obj._cache

    # Positional arguments are passed through
    def test_val4_args(self, obj):
        obj._clear_caches()
        assert obj.val4(1, 2) == 1 + 2 + obj.ref4
        # The cache entry does not depend on the arguments, so a later call
        # with different arguments returns the first result; arguably this
        # shouldn't be cached...
        assert obj.val4(3, 4) == 1 + 2 + obj.ref4

    # Positional and keyword arguments
    def test_val5_kwargs(self, obj):
        obj._clear_caches()
        assert obj.val5(5, s='abc') == 5 * 'abc'

        assert obj.val5(5, s='!!!') == 5 * 'abc'
551
552
class TestConvFloat(object):
    """Tests for ``util.conv_float``.

    Strings that parse as floats are converted; anything else is returned
    unchanged (e.g. ``'a.b'``).
    """
    @pytest.mark.parametrize('s, output', [
        ('0.45', 0.45),
        ('.45', 0.45),
        ('a.b', 'a.b')
    ])
    def test_float(self, s, output):
        assert util.conv_float(s) == output

    # ``inputs`` rather than ``input`` so the builtin is not shadowed
    @pytest.mark.parametrize('inputs, output', [
        (('0.45', '0.56', '6.7'), [0.45, 0.56, 6.7]),
        (('0.45', 'a.b', '!!'), [0.45, 'a.b', '!!'])
    ])
    def test_map(self, inputs, output):
        ret = [util.conv_float(el) for el in inputs]
        assert ret == output
569
570
class TestFixedwidthBins(object):
    """Tests for ``util.fixedwidth_bins``."""
    def test_keys(self):
        ret = util.fixedwidth_bins(0.5, 1.0, 2.0)
        for k in ['Nbins', 'delta', 'min', 'max']:
            assert k in ret

    def test_ValueError(self):
        # xmin > xmax must be rejected
        with pytest.raises(ValueError):
            util.fixedwidth_bins(0.1, 5.0, 4.0)

    @pytest.mark.parametrize(
        'delta, xmin, xmax, output_Nbins, output_delta, output_min, output_max',
        [
            (0.1, 4.0, 5.0, 10, 0.1, 4.0, 5.0),
            (0.4, 4.0, 5.0, 3, 0.4, 3.9, 5.1)
        ])
    def test_usage(self, delta, xmin, xmax, output_Nbins, output_delta,
                   output_min, output_max):
        ret = util.fixedwidth_bins(delta, xmin, xmax)
        assert ret['Nbins'] == output_Nbins
        assert ret['delta'] == output_delta
        # The bin edges are the result of floating point arithmetic, so
        # compare approximately.  The form ``assert ret['min'], output_min``
        # would only test truthiness: the second operand is the message.
        assert_almost_equal(ret['min'], output_min)
        assert_almost_equal(ret['max'], output_max)
594
@pytest.fixture
def atoms():
    """AtomGroup of a synthetic Universe that carries masses.

    ``size=(3, 1, 1)`` presumably means 3 atoms in 1 residue and 1 segment,
    matching the 3-element weight arrays used by the weight tests.
    """
    from MDAnalysisTests import make_Universe
    universe = make_Universe(extras=("masses",), size=(3, 1, 1))
    return universe.atoms
600
@pytest.mark.parametrize('weights,result', [
    (None, None),
    ("mass", np.array([5.1, 4.2, 3.3])),
    (np.array([12.0, 1.0, 12.0]), np.array([12.0, 1.0, 12.0])),
    ([12.0, 1.0, 12.0], np.array([12.0, 1.0, 12.0])),
    (range(3), np.arange(3, dtype=int)),
])
def test_check_weights_ok(atoms, weights, result):
    """Accepted weight specifications: None, "mass" and array-likes."""
    computed = util.get_weights(atoms, weights)
    assert_array_equal(computed, result)
611
@pytest.mark.parametrize('weights', [
    42,
    "geometry",
    np.array(1.0),
])
def test_check_weights_raises_TypeError(atoms, weights):
    """Scalars, 0-d arrays and unknown keywords raise TypeError."""
    with pytest.raises(TypeError):
        util.get_weights(atoms, weights)
620
@pytest.mark.parametrize('weights', [
    # wrong length
    np.array([12.0, 1.0, 12.0, 1.0]),
    [12.0, 1.0],
    # not one-dimensional
    np.array([[12.0, 1.0, 12.0]]),
    np.array([[12.0, 1.0, 12.0], [12.0, 1.0, 12.0]]),
])
def test_check_weights_raises_ValueError(atoms, weights):
    """Array-like weights of the wrong shape raise ValueError."""
    with pytest.raises(ValueError):
        util.get_weights(atoms, weights)
631
632
class TestGuessFormat(object):
    """Test guessing of format from filenames

    Tests also getting the appropriate Parser and Reader from a
    given filename
    """
    # list of known formats, followed by the desired Parser and Reader.
    # None indicates that there isn't a Reader for this format.
    # Formats without a dedicated topology parser fall back to the
    # MinimalParser.
    formats = [
        ('CHAIN', mda.topology.MinimalParser.MinimalParser, mda.coordinates.chain.ChainReader),
        ('CONFIG', mda.topology.DLPolyParser.ConfigParser, mda.coordinates.DLPoly.ConfigReader),
        ('CRD', mda.topology.CRDParser.CRDParser, mda.coordinates.CRD.CRDReader),
        ('DATA', mda.topology.LAMMPSParser.DATAParser, mda.coordinates.LAMMPS.DATAReader),
        ('DCD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.DCD.DCDReader),
        ('DMS', mda.topology.DMSParser.DMSParser, mda.coordinates.DMS.DMSReader),
        ('GMS', mda.topology.GMSParser.GMSParser, mda.coordinates.GMS.GMSReader),
        ('GRO', mda.topology.GROParser.GROParser, mda.coordinates.GRO.GROReader),
        ('HISTORY', mda.topology.DLPolyParser.HistoryParser, mda.coordinates.DLPoly.HistoryReader),
        ('INPCRD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.INPCRD.INPReader),
        ('LAMMPS', mda.topology.MinimalParser.MinimalParser, mda.coordinates.LAMMPS.DCDReader),
        ('MDCRD', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.TRJReader),
        ('MMTF', mda.topology.MMTFParser.MMTFParser, mda.coordinates.MMTF.MMTFReader),
        ('MOL2', mda.topology.MOL2Parser.MOL2Parser, mda.coordinates.MOL2.MOL2Reader),
        ('NC', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.NCDFReader),
        ('NCDF', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.NCDFReader),
        ('PDB', mda.topology.PDBParser.PDBParser, mda.coordinates.PDB.PDBReader),
        ('PDBQT', mda.topology.PDBQTParser.PDBQTParser, mda.coordinates.PDBQT.PDBQTReader),
        ('PRMTOP', mda.topology.TOPParser.TOPParser, None),
        ('PQR', mda.topology.PQRParser.PQRParser, mda.coordinates.PQR.PQRReader),
        ('PSF', mda.topology.PSFParser.PSFParser, None),
        ('RESTRT', mda.topology.MinimalParser.MinimalParser, mda.coordinates.INPCRD.INPReader),
        ('TOP', mda.topology.TOPParser.TOPParser, None),
        ('TPR', mda.topology.TPRParser.TPRParser, None),
        ('TRJ', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRJ.TRJReader),
        ('TRR', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRR.TRRReader),
        ('XML', mda.topology.HoomdXMLParser.HoomdXMLParser, None),
        ('XPDB', mda.topology.ExtendedPDBParser.ExtendedPDBParser, mda.coordinates.PDB.ExtendedPDBReader),
        ('XTC', mda.topology.MinimalParser.MinimalParser, mda.coordinates.XTC.XTCReader),
        ('XYZ', mda.topology.XYZParser.XYZParser, mda.coordinates.XYZ.XYZReader),
        ('TRZ', mda.topology.MinimalParser.MinimalParser, mda.coordinates.TRZ.TRZReader),
    ]
    # list of possible compressed extensions; the uncompressed case is
    # covered by the tests that do not use this list
    compressed_extensions = ['.bz2', '.gz']

    # every format name in upper and in lower case, so that extension
    # handling is checked to be case-insensitive
    extensions = ([format_tuple[0].upper() for format_tuple in formats] +
                  [format_tuple[0].lower() for format_tuple in formats])

    @pytest.mark.parametrize('extension', extensions)
    def test_get_extention(self, extension):
        """Check that get_ext splits off the extension and lower-cases it"""
        file_name = 'file.{0}'.format(extension)
        a, b = util.get_ext(file_name)

        assert a == 'file'
        assert b == extension.lower()

    @pytest.mark.parametrize('extension', extensions)
    def test_compressed_without_compression_extention(self, extension):
        """Check that the format is found from a plain, uncompressed extension"""
        file_name = 'file.{0}'.format(extension)
        a = util.format_from_filename_extension(file_name)
        # expect answer to always be uppercase
        assert a == extension.upper()

    @pytest.mark.parametrize('extension', extensions)
    @pytest.mark.parametrize('compression_extension', compressed_extensions)
    def test_compressed(self, extension, compression_extension):
        """Check that format suffixed by compressed extension works"""
        file_name = 'file.{0}{1}'.format(extension, compression_extension)
        a = util.format_from_filename_extension(file_name)
        # expect answer to always be uppercase
        assert a == extension.upper()

    @pytest.mark.parametrize('extension', extensions)
    def test_guess_format(self, extension):
        file_name = 'file.{0}'.format(extension)
        a = util.guess_format(file_name)
        # expect answer to always be uppercase
        assert a == extension.upper()

    @pytest.mark.parametrize('extension', extensions)
    @pytest.mark.parametrize('compression_extension', compressed_extensions)
    def test_guess_format_compressed(self, extension, compression_extension):
        file_name = 'file.{0}{1}'.format(extension, compression_extension)
        a = util.guess_format(file_name)
        # expect answer to always be uppercase
        assert a == extension.upper()

    @pytest.mark.parametrize('extension, parser',
                             [(format_tuple[0], format_tuple[1]) for
                              format_tuple in formats if
                              format_tuple[1] is not None]
                             )
    def test_get_parser(self, extension, parser):
        file_name = 'file.{0}'.format(extension)
        a = mda.topology.core.get_parser_for(file_name)

        assert a == parser

    @pytest.mark.parametrize('extension, parser',
                             [(format_tuple[0], format_tuple[1]) for
                              format_tuple in formats if
                              format_tuple[1] is not None]
                             )
    @pytest.mark.parametrize('compression_extension', compressed_extensions)
    def test_get_parser_compressed(self, extension, parser,
                                   compression_extension):
        file_name = 'file.{0}{1}'.format(extension, compression_extension)
        a = mda.topology.core.get_parser_for(file_name)

        assert a == parser

    # Parametrize with the bare extension string (not an (ext, parser)
    # tuple) so the filename is e.g. 'file.XYZ'.  Every entry in `formats`
    # currently has a parser, so this list is empty and pytest skips it.
    @pytest.mark.parametrize('extension',
                             [format_tuple[0].upper() for format_tuple in
                              formats if format_tuple[1] is None] +
                             [format_tuple[0].lower() for format_tuple in
                              formats if format_tuple[1] is None]
                             )
    def test_get_parser_invalid(self, extension):
        file_name = 'file.{0}'.format(extension)
        with pytest.raises(ValueError):
            mda.topology.core.get_parser_for(file_name)

    @pytest.mark.parametrize('extension, reader',
                             [(format_tuple[0], format_tuple[2]) for
                              format_tuple in formats if
                              format_tuple[2] is not None]
                             )
    def test_get_reader(self, extension, reader):
        file_name = 'file.{0}'.format(extension)
        a = mda.coordinates.core.get_reader_for(file_name)

        assert a == reader

    @pytest.mark.parametrize('extension, reader',
                             [(format_tuple[0], format_tuple[2]) for
                              format_tuple in formats if
                              format_tuple[2] is not None]
                             )
    @pytest.mark.parametrize('compression_extension', compressed_extensions)
    def test_get_reader_compressed(self, extension, reader,
                                   compression_extension):
        file_name = 'file.{0}{1}'.format(extension, compression_extension)
        a = mda.coordinates.core.get_reader_for(file_name)

        assert a == reader

    # Parametrize with the bare extension string so the filename is e.g.
    # 'file.PSF'.  With an (ext, None) tuple the name would have been
    # "file.('PSF', None)", which fails for the wrong reason.
    @pytest.mark.parametrize('extension',
                             [format_tuple[0].upper() for format_tuple in
                              formats if format_tuple[2] is None] +
                             [format_tuple[0].lower() for format_tuple in
                              formats if format_tuple[2] is None]
                             )
    def test_get_reader_invalid(self, extension):
        file_name = 'file.{0}'.format(extension)
        with pytest.raises(ValueError):
            mda.coordinates.core.get_reader_for(file_name)

    def test_check_compressed_format_TypeError(self):
        with pytest.raises(TypeError):
            util.check_compressed_format(1234, 'bz2')

    def test_format_from_filename_TypeError(self):
        with pytest.raises(TypeError):
            util.format_from_filename_extension(1234)

    def test_guess_format_stream_ValueError(self):
        # This stream has no name, so can't guess format
        s = StringIO('this is a very fun file')
        with pytest.raises(ValueError):
            util.guess_format(s)

    def test_from_ndarray(self):
        fn = np.zeros((3, 3))
        rd = mda.coordinates.core.get_reader_for(fn)
        assert rd == mda.coordinates.memory.MemoryReader
824
825
class TestUniqueRows(object):
    """Tests for ``util.unique_rows``."""

    def test_unique_rows_2(self):
        rows = np.array([[0, 1], [1, 2], [2, 1], [0, 1], [0, 1], [2, 1]])
        expected = np.array([[0, 1], [1, 2], [2, 1]])
        assert_array_equal(util.unique_rows(rows), expected)

    def test_unique_rows_3(self):
        rows = np.array([[0, 1, 2], [0, 1, 2], [2, 3, 4], [0, 1, 2]])
        expected = np.array([[0, 1, 2], [2, 3, 4]])
        assert_array_equal(util.unique_rows(rows), expected)

    def test_unique_rows_with_view(self):
        # Adding the leading axis via ``[None, :]`` yields a view
        # (flags['OWNDATA'] is False); unique_rows must cope with such input.
        row = np.array([1, 2])
        as_view = row[None, :]
        assert_array_equal(util.unique_rows(as_view), np.array([[1, 2]]))
846
847
class TestGetWriterFor(object):
    """Tests for ``mda.coordinates.core.get_writer_for``."""

    def test_no_filename_argument(self):
        # ``get_writer_for`` must fail when called without a filename
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for()

    def test_precedence(self):
        # Make sure ``get_writer_for`` uses *format* if provided, even when it
        # disagrees with the filename extension
        writer = mda.coordinates.core.get_writer_for('test.pdb', 'GRO')
        assert writer == mda.coordinates.GRO.GROWriter

    def test_missing_extension(self):
        # Make sure ``get_writer_for`` behaves as expected if *filename*
        # has no extension
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for(filename='test', format=None)

    def test_wrong_format(self):
        # Make sure ``get_writer_for`` fails if the format is unknown
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for(filename="fail_me",
                                                format='UNK')

    def test_compressed_extension(self):
        # Make sure ``get_writer_for`` works with compressed file names
        for ext in ('.gz', '.bz2'):
            fn = 'test.gro' + ext
            writer = mda.coordinates.core.get_writer_for(filename=fn)
            assert writer == mda.coordinates.GRO.GROWriter

    def test_compressed_extension_fail(self):
        # Make sure ``get_writer_for`` fails if an unknown format is compressed
        for ext in ('.gz', '.bz2'):
            fn = 'test.unk' + ext
            with pytest.raises(TypeError):
                mda.coordinates.core.get_writer_for(filename=fn)

    def test_non_string_filename(self):
        # ``get_writer_for`` fails with a non-string filename and no format
        with pytest.raises(ValueError):
            mda.coordinates.core.get_writer_for(filename=StringIO(),
                                                format=None)

    def test_multiframe_failure(self):
        # ``get_writer_for`` must fail for an unknown format whatever the value
        # of *multiframe*. Each call needs its own ``pytest.raises`` context:
        # inside a single context the second call would never be executed.
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for(filename="fail_me",
                                                format='UNK', multiframe=True)
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for(filename="fail_me",
                                                format='UNK', multiframe=False)

    def test_multiframe_nonsense(self):
        with pytest.raises(ValueError):
            mda.coordinates.core.get_writer_for(filename='this.gro',
                                                multiframe='sandwich')

    formats = [
        # format name, related class, singleframe, multiframe
        # (PDB and ENT have distinct single-/multiframe writers and are
        # checked separately in test_get_writer_for_pdb)
        ('CRD', mda.coordinates.CRD.CRDWriter, True, False),
        ('DATA', mda.coordinates.LAMMPS.DATAWriter, True, False),
        ('DCD', mda.coordinates.DCD.DCDWriter, True, True),
        ('GRO', mda.coordinates.GRO.GROWriter, True, False),
        ('LAMMPS', mda.coordinates.LAMMPS.DCDWriter, True, True),
        ('MOL2', mda.coordinates.MOL2.MOL2Writer, True, True),
        ('NCDF', mda.coordinates.TRJ.NCDFWriter, True, True),
        ('NULL', mda.coordinates.null.NullWriter, True, True),
        ('PDBQT', mda.coordinates.PDBQT.PDBQTWriter, True, False),
        ('PQR', mda.coordinates.PQR.PQRWriter, True, False),
        ('TRR', mda.coordinates.TRR.TRRWriter, True, True),
        ('XTC', mda.coordinates.XTC.XTCWriter, True, True),
        ('XYZ', mda.coordinates.XYZ.XYZWriter, True, True),
        ('TRZ', mda.coordinates.TRZ.TRZWriter, True, True),
    ]

    @pytest.mark.parametrize('format, writer',
                             [(format_tuple[0], format_tuple[1]) for
                              format_tuple in formats if
                              format_tuple[2] is True])
    def test_singleframe(self, format, writer):
        assert mda.coordinates.core.get_writer_for('this', format=format,
                                                   multiframe=False) == writer

    # Parametrize with the format name only; the test takes a single
    # ``format`` argument, so passing (name, class) tuples would be wrong.
    @pytest.mark.parametrize('format',
                             [format_tuple[0] for format_tuple in formats if
                              format_tuple[2] is False])
    def test_singleframe_fails(self, format):
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for('this', format=format,
                                                multiframe=False)

    @pytest.mark.parametrize('format, writer',
                             [(format_tuple[0], format_tuple[1]) for
                              format_tuple in formats if
                              format_tuple[3] is True])
    def test_multiframe(self, format, writer):
        assert mda.coordinates.core.get_writer_for('this', format=format,
                                                   multiframe=True) == writer

    @pytest.mark.parametrize('format',
                             [format_tuple[0] for format_tuple in formats if
                              format_tuple[3] is False])
    def test_multiframe_fails(self, format):
        with pytest.raises(TypeError):
            mda.coordinates.core.get_writer_for('this', format=format,
                                                multiframe=True)

    def test_get_writer_for_pdb(self):
        get_writer_for = mda.coordinates.core.get_writer_for
        pdb = mda.coordinates.PDB
        for fmt in ('PDB', 'ENT'):
            assert get_writer_for('this', format=fmt,
                                  multiframe=False) == pdb.PDBWriter
            assert get_writer_for('this', format=fmt,
                                  multiframe=True) == pdb.MultiPDBWriter
966
967
class TestBlocksOf(object):
    """Tests for ``util.blocks_of``, which returns writable diagonal blocks."""

    def test_blocks_of_1(self):
        grid = np.arange(16).reshape(4, 4)
        blocks = util.blocks_of(grid, 1, 1)

        assert blocks.shape == (4, 1, 1)
        diagonal = np.array([[[0]], [[5]], [[10]], [[15]]])
        assert_array_almost_equal(blocks, diagonal)

        # writing through the view must show up in the parent array
        blocks[:] = 1001
        expected = np.array([[1001, 1, 2, 3],
                             [4, 1001, 6, 7],
                             [8, 9, 1001, 11],
                             [12, 13, 14, 1001]])
        assert_array_almost_equal(grid, expected)

    def test_blocks_of_2(self):
        grid = np.arange(16).reshape(4, 4)
        blocks = util.blocks_of(grid, 2, 2)

        assert blocks.shape == (2, 2, 2)
        expected_blocks = np.array([[[0, 1], [4, 5]],
                                    [[10, 11], [14, 15]]])
        assert_array_almost_equal(blocks, expected_blocks)

        blocks[0] = 100
        blocks[1] = 200
        expected = np.array([[100, 100, 2, 3],
                             [100, 100, 6, 7],
                             [8, 9, 200, 200],
                             [12, 13, 200, 200]])
        assert_array_almost_equal(grid, expected)

    def test_blocks_of_3(self):
        # non-square input array
        tall = np.arange(32).reshape(8, 4)
        blocks = util.blocks_of(tall, 2, 1)
        assert blocks.shape == (4, 2, 1)

    def test_blocks_of_4(self):
        # a block larger than the array gives an empty view, and writing to
        # it leaves the array untouched
        small = np.arange(4).reshape(2, 2)
        blocks = util.blocks_of(small, 3, 3)
        assert blocks.shape == (0, 3, 3)
        blocks[:] = 100
        assert_array_equal(small, np.arange(4).reshape(2, 2))

    def test_blocks_of_ValueError(self):
        grid = np.arange(16).reshape(4, 4)
        # blocks that do not fit the array
        with pytest.raises(ValueError):
            util.blocks_of(grid, 2, 1)
        # non-contiguous input
        strided = grid[:, ::2]
        with pytest.raises(ValueError):
            util.blocks_of(strided, 2, 1)
1027
1028
class TestNamespace(object):
    """Tests for ``util.Namespace`` (attribute and item access to one store)."""

    @staticmethod
    @pytest.fixture()
    def ns():
        return util.Namespace()

    def test_getitem(self, ns):
        # a value set as an attribute is readable as an item
        ns.this = 42
        assert 42 == ns['this']

    def test_getitem_KeyError(self, ns):
        missing = 'this'
        with pytest.raises(KeyError):
            dict.__getitem__(ns, missing)

    def test_setitem(self, ns):
        ns['this'] = 42
        assert 42 == ns['this']

    def test_delitem(self, ns):
        ns['this'] = 42
        assert 'this' in ns
        del ns['this']
        assert 'this' not in ns

    def test_delitem_AttributeError(self, ns):
        # deleting an attribute that was never set fails
        with pytest.raises(AttributeError):
            del ns.this

    def test_setattr(self, ns):
        ns.this = 42
        assert 42 == ns.this

    def test_getattr(self, ns):
        # a value set as an item is readable as an attribute
        ns['this'] = 42
        assert 42 == ns.this

    def test_getattr_AttributeError(self, ns):
        missing = 'this'
        with pytest.raises(AttributeError):
            getattr(ns, missing)

    def test_delattr(self, ns):
        ns['this'] = 42
        assert 'this' in ns
        del ns.this
        assert 'this' not in ns

    def test_eq(self, ns):
        other = util.Namespace()
        ns['this'] = 42
        other['this'] = 42
        assert ns == other

    def test_len(self, ns):
        assert len(ns) == 0
        ns['this'] = 1
        ns['that'] = 2
        assert len(ns) == 2

    def test_iter(self, ns):
        ns['this'] = 12
        ns['that'] = 24
        ns['other'] = 48

        # iteration yields the keys
        seen = list(ns)
        for key in ('this', 'that', 'other'):
            assert key in seen
1103
1104
class TestTruncateInteger(object):
    """Tests for ``util.ltruncate_int``."""

    @pytest.mark.parametrize('a, b', [
        ((1234, 1), 4),
        ((1234, 2), 34),
        ((1234, 3), 234),
        ((1234, 4), 1234),
        ((1234, 5), 1234),
    ])
    def test_ltruncate_int(self, a, b):
        # ``a`` holds the positional arguments, ``b`` the expected result
        value, ndigits = a
        assert b == util.ltruncate_int(value, ndigits)
1115
class TestFlattenDict(object):
    """Tests for ``util.flatten_dict``."""

    def test_flatten_dict(self):
        d = {
            'A': {1: ('a', 'b', 'c')},
            'B': {2: ('c', 'd', 'e')},
            'C': {3: ('f', 'g', 'h')},
        }
        result = util.flatten_dict(d)

        # One entry per leaf value. Without this check, the loop below would
        # pass vacuously if ``result`` were empty.
        assert len(result) == 3
        for k in result:
            # keys are (outer_key, inner_key) tuples
            assert isinstance(k, tuple)
            assert len(k) == 2
            assert k[0] in d
            assert k[1] in d[k[0]]
            assert result[k] in d[k[0]].values()
1131
class TestStaticVariables(object):
    """Tests concerning the decorator @static_variables
    """

    def test_static_variables(self):
        x = [0]

        @static_variables(foo=0, bar={'test': x})
        def myfunc():
            # Integer values are compared with ``==``. ``is`` against an int
            # literal depends on CPython's small-int cache and triggers a
            # SyntaxWarning on Python >= 3.8.
            assert myfunc.foo == 0
            assert type(myfunc.bar) is dict
            if 'test2' not in myfunc.bar:
                myfunc.bar['test2'] = "a"
            else:
                myfunc.bar['test2'] += "a"
            myfunc.bar['test'][0] += 1
            return myfunc.bar['test']

        assert hasattr(myfunc, 'foo')
        assert hasattr(myfunc, 'bar')

        y = myfunc()
        # the static dict holds the very same list object that was passed in
        assert y is x
        assert x[0] == 1
        assert myfunc.bar['test'][0] == 1
        assert myfunc.bar['test2'] == "a"

        # rebinding the local name does not affect the stored object
        x = [0]
        y = myfunc()
        assert y is not x
        assert myfunc.bar['test'][0] == 2
        assert myfunc.bar['test2'] == "aa"
1164
class TestWarnIfNotUnique(object):
    """Tests concerning the decorator @warn_if_not_unique
    """

    @staticmethod
    def warn_msg(func, group, group_name):
        """Return the DuplicateWarning message expected from calling *func*.

        Parameters
        ----------
        func : callable
            The decorated function; its ``__name__`` appears in the message.
        group : object
            The group passed to *func*; its class name and ``repr`` appear in
            the message.
        group_name : str
            The quoted name(s) the group is referenced by.

        Notes
        -----
        This is a plain helper rather than a pytest fixture because it is
        called directly with explicit arguments. Recent pytest versions (>= 4)
        reject direct calls to fixture functions.
        """
        return ("{}.{}(): {} {} contains duplicates. Results might be "
                "biased!".format(group.__class__.__name__, func.__name__,
                                 group_name, group.__repr__()))

    def test_warn_if_not_unique(self, atoms):
        # Check that the warn_if_not_unique decorator has a "static variable"
        # warn_if_not_unique.warned:
        assert hasattr(warn_if_not_unique, 'warned')
        assert warn_if_not_unique.warned is False

    def test_warn_if_not_unique_once_outer(self, atoms):

        # Construct a scenario with two nested functions, each one decorated
        # with @warn_if_not_unique:

        @warn_if_not_unique
        def inner(group):
            if not group.isunique:
                # The inner function should not trigger a warning, and the state
                # of warn_if_not_unique.warned should reflect that:
                assert warn_if_not_unique.warned is True
            return 0

        @warn_if_not_unique
        def outer(group):
            return inner(group)

        # Check that no warning is raised for a unique group. All warnings are
        # recorded explicitly because ``pytest.warns(None)`` is deprecated.
        assert atoms.isunique
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            x = outer(atoms)
            assert x == 0
        assert not w

        # Check that a warning is raised for a group with duplicates:
        ag = atoms + atoms[0]
        msg = self.warn_msg(outer, ag, "'ag'")
        with pytest.warns(DuplicateWarning) as w:
            assert warn_if_not_unique.warned is False
            x = outer(ag)
            # Assert that the "warned" state is restored:
            assert warn_if_not_unique.warned is False
            # Check correct function execution:
            assert x == 0
            # Only one warning must have been raised:
            assert len(w) == 1
            # For whatever reason pytest.warns(DuplicateWarning, match=msg)
            # doesn't work, so we compare the recorded warning message instead:
            assert w[0].message.args[0] == msg
            # Make sure the warning uses the correct stacklevel and references
            # this file instead of MDAnalysis/lib/util.py:
            assert w[0].filename == __file__

    def test_warned_state_restored_on_failure(self, atoms):

        # A decorated function raising an exception:
        @warn_if_not_unique
        def thisfails(group):
            raise ValueError()

        ag = atoms + atoms[0]
        msg = self.warn_msg(thisfails, ag, "'ag'")
        with pytest.warns(DuplicateWarning) as w:
            assert warn_if_not_unique.warned is False
            with pytest.raises(ValueError):
                thisfails(ag)
            # Assert that the "warned" state is restored despite `thisfails`
            # raising an exception:
            assert warn_if_not_unique.warned is False
            assert len(w) == 1
            assert w[0].message.args[0] == msg
            assert w[0].filename == __file__

    def test_warn_if_not_unique_once_inner(self, atoms):

        # Construct a scenario with two nested functions, each one decorated
        # with @warn_if_not_unique, but the outer function adds a duplicate
        # to the group:

        @warn_if_not_unique
        def inner(group):
            return 0

        @warn_if_not_unique
        def outer(group):
            dupgroup = group + group[0]
            return inner(dupgroup)

        # Check that even though outer() is called the warning is raised for
        # inner():
        msg = self.warn_msg(inner, atoms + atoms[0], "'dupgroup'")
        with pytest.warns(DuplicateWarning) as w:
            assert warn_if_not_unique.warned is False
            x = outer(atoms)
            # Assert that the "warned" state is restored:
            assert warn_if_not_unique.warned is False
            # Check correct function execution:
            assert x == 0
            # Only one warning must have been raised:
            assert len(w) == 1
            assert w[0].message.args[0] == msg
            assert w[0].filename == __file__

    def test_warn_if_not_unique_multiple_references(self, atoms):
        ag = atoms + atoms[0]
        aag = ag
        aaag = aag

        @warn_if_not_unique
        def func(group):
            return group.isunique

        # Check that the warning message contains the names of all references to
        # the group in alphabetic order:
        msg = self.warn_msg(func, ag, "'aaag' a.k.a. 'aag' a.k.a. 'ag'")
        with pytest.warns(DuplicateWarning) as w:
            x = func(ag)
            # Assert that the "warned" state is restored:
            assert warn_if_not_unique.warned is False
            # Check correct function execution:
            assert x is False
            # Check warning message:
            assert w[0].message.args[0] == msg
            # Check correct file referenced:
            assert w[0].filename == __file__

    def test_warn_if_not_unique_unnamed(self, atoms):

        @warn_if_not_unique
        def func(group):
            pass

        msg = self.warn_msg(func, atoms + atoms[0],
                            "'unnamed {}'".format(atoms.__class__.__name__))
        with pytest.warns(DuplicateWarning) as w:
            func(atoms + atoms[0])
            # Check warning message:
            assert w[0].message.args[0] == msg

    def test_warn_if_not_unique_fails_for_non_groupmethods(self):

        @warn_if_not_unique
        def func(group):
            pass

        class dummy(object):
            pass

        with pytest.raises(AttributeError):
            func(dummy())

    def test_filter_duplicate_with_userwarning(self, atoms):

        @warn_if_not_unique
        def func(group):
            pass

        # Call the decorated function while UserWarnings are filtered out.
        # The group is unique, so nothing at all may be recorded.
        with warnings.catch_warnings(record=True) as record:
            warnings.resetwarnings()
            warnings.filterwarnings("ignore", category=UserWarning)
            func(atoms)
        assert len(record) == 0
1336
class TestCheckCoords(object):
    """Tests concerning the decorator @check_coords

    Error messages are checked *after* each ``pytest.raises`` block.
    Statements placed inside the block after the raising call never execute.
    """

    prec = 6

    def test_default_options(self):
        a_in = np.zeros(3, dtype=np.float32)
        b_in = np.ones(3, dtype=np.float32)
        b_in2 = np.ones((2, 3), dtype=np.float32)

        @check_coords('a', 'b')
        def func(a, b):
            # check that enforce_copy is True by default:
            assert a is not a_in
            assert b is not b_in
            # check that convert_single is True by default:
            assert a.shape == (1, 3)
            assert b.shape == (1, 3)
            return a + b

        # check that allow_single is True by default:
        res = func(a_in, b_in)
        # check that reduce_result_if_single is True by default:
        assert res.shape == (3,)
        # check correct function execution:
        assert_array_equal(res, b_in)

        # check that check_lengths_match is True by default:
        with pytest.raises(ValueError):
            func(a_in, b_in2)

    def test_enforce_copy(self):

        a_2d = np.ones((1, 3), dtype=np.float32)
        b_1d = np.zeros(3, dtype=np.float32)
        c_2d = np.zeros((1, 6), dtype=np.float32)[:, ::2]
        d_2d = np.zeros((1, 3), dtype=np.int64)

        @check_coords('a', 'b', 'c', 'd', enforce_copy=False)
        def func(a, b, c, d):
            # Assert that if enforce_copy is False:
            # no copy is made if input shape, order, and dtype are correct:
            assert a is a_2d
            # a copy is made if input shape has to be changed:
            assert b is not b_1d
            # a copy is made if input order has to be changed:
            assert c is not c_2d
            # a copy is made if input dtype has to be changed:
            assert d is not d_2d
            # Assert correct dtype conversion:
            assert d.dtype == np.float32
            assert_almost_equal(d, d_2d, self.prec)
            # Assert all shapes are converted to (1, 3):
            assert a.shape == b.shape == c.shape == d.shape == (1, 3)
            return a + b + c + d

        # Call func() to:
        # - test the above assertions
        # - ensure that input of single coordinates is simultaneously possible
        #   with different shapes (3,) and (1, 3)
        res = func(a_2d, b_1d, c_2d, d_2d)
        # Since some inputs are not 1d, even though reduce_result_if_single is
        # True, the result must have shape (1, 3):
        assert res.shape == (1, 3)
        # check correct function execution:
        assert_array_equal(res, a_2d)

    def test_no_allow_single(self):

        @check_coords('a', allow_single=False)
        def func(a):
            pass

        with pytest.raises(ValueError) as err:
            func(np.zeros(3, dtype=np.float32))
        assert "a.shape must be (n, 3), got (3,)" in str(err.value)

    def test_no_convert_single(self):

        a_1d = np.arange(-3, 0, dtype=np.float32)

        @check_coords('a', enforce_copy=False, convert_single=False)
        def func(a):
            # assert no conversion and no copy were performed:
            assert a is a_1d
            return a

        res = func(a_1d)
        # Assert result has been reduced:
        assert res == a_1d[0]
        assert type(res) is np.float32

    def test_no_reduce_result_if_single(self):

        a_1d = np.zeros(3, dtype=np.float32)

        # Test without shape conversion:
        @check_coords('a', enforce_copy=False, convert_single=False,
                      reduce_result_if_single=False)
        def func(a):
            return a

        res = func(a_1d)
        # make sure the input array is just passed through:
        assert res is a_1d

        # Test with shape conversion:
        @check_coords('a', enforce_copy=False, reduce_result_if_single=False)
        def func(a):
            return a

        res = func(a_1d)
        assert res.shape == (1, 3)
        assert_array_equal(res[0], a_1d)

    def test_no_check_lengths_match(self):

        a_2d = np.zeros((1, 3), dtype=np.float32)
        b_2d = np.zeros((3, 3), dtype=np.float32)

        @check_coords('a', 'b', enforce_copy=False, check_lengths_match=False)
        def func(a, b):
            return a, b

        res_a, res_b = func(a_2d, b_2d)
        # Assert arrays are just passed through:
        assert res_a is a_2d
        assert res_b is b_2d

    def test_invalid_input(self):

        a_inv_dtype = np.array([['hello', 'world', '!']])
        a_inv_type = [[0., 0., 0.]]
        a_inv_shape_1d = np.zeros(6, dtype=np.float32)
        a_inv_shape_2d = np.zeros((3, 2), dtype=np.float32)

        @check_coords('a')
        def func(a):
            pass

        with pytest.raises(TypeError) as err:
            func(a_inv_dtype)
        assert str(err.value).startswith("func(): a.dtype must be")

        with pytest.raises(TypeError) as err:
            func(a_inv_type)
        # the reported type is spelled differently on Python 2 and 3, so only
        # the message prefix is checked
        assert str(err.value).startswith(
            "func(): Parameter 'a' must be a numpy.ndarray, got ")

        with pytest.raises(ValueError) as err:
            func(a_inv_shape_1d)
        assert "a.shape must be (3,) or (n, 3), got (6,)" in str(err.value)

        with pytest.raises(ValueError) as err:
            func(a_inv_shape_2d)
        assert "a.shape must be (3,) or (n, 3), got (3, 2)" in str(err.value)

    def test_usage_with_kwargs(self):

        a_2d = np.zeros((1, 3), dtype=np.float32)

        @check_coords('a', enforce_copy=False)
        def func(a, b, c=0):
            return a, b, c

        # check correct functionality if passed as keyword argument:
        a, b, c = func(a=a_2d, b=0, c=1)
        assert a is a_2d
        assert b == 0
        assert c == 1

    def test_wrong_func_call(self):

        @check_coords('a', enforce_copy=False)
        def func(a, b, c=0):
            pass

        # Make sure invalid call marker is present:
        func._invalid_call = False

        # usage with posarg doubly defined:
        assert not func._invalid_call
        with pytest.raises(TypeError):
            func(0, a=0)  # pylint: disable=redundant-keyword-arg
        assert func._invalid_call
        func._invalid_call = False

        # usage with missing posargs:
        assert not func._invalid_call
        with pytest.raises(TypeError):
            func(0)
        assert func._invalid_call
        func._invalid_call = False

        # usage with missing posargs (supplied as kwargs):
        assert not func._invalid_call
        with pytest.raises(TypeError):
            func(a=0, c=1)
        assert func._invalid_call
        func._invalid_call = False

        # usage with too many posargs:
        assert not func._invalid_call
        with pytest.raises(TypeError):
            func(0, 0, 0, 0)
        assert func._invalid_call
        func._invalid_call = False

        # usage with unexpected kwarg:
        assert not func._invalid_call
        with pytest.raises(TypeError):
            func(a=0, b=0, c=1, d=1)  # pylint: disable=unexpected-keyword-arg
        assert func._invalid_call
        func._invalid_call = False

    def test_wrong_decorator_usage(self):

        # usage without parentheses:
        @check_coords
        def func():
            pass

        with pytest.raises(TypeError):
            func()

        # usage without arguments:
        with pytest.raises(ValueError) as err:
            @check_coords()
            def func():
                pass
        assert ("cannot be used without positional arguments"
                in str(err.value))

        # usage with defaultarg:
        with pytest.raises(ValueError) as err:
            @check_coords('a')
            def func(a=1):
                pass
        assert ("Name 'a' doesn't correspond to any positional argument"
                in str(err.value))

        # usage with invalid parameter name:
        with pytest.raises(ValueError) as err:
            @check_coords('b')
            def func(a):
                pass
        assert ("Name 'b' doesn't correspond to any positional argument"
                in str(err.value))
1594
1595
@pytest.mark.parametrize("old_name", (None, "MDAnalysis.Universe"))
@pytest.mark.parametrize("new_name", (None, "Multiverse"))
@pytest.mark.parametrize("remove", (None, "99.0.0", 2099))
@pytest.mark.parametrize("message", (None, "use the new stuff"))
def test_deprecate(old_name, new_name, remove, message, release="2.7.1"):
    """Check the warning and the docstring produced by ``util.deprecate``.

    Calling the wrapped function must emit a DeprecationWarning. Its docstring
    must contain the ``.. deprecated::`` directive, the deprecation message
    (custom or default), the removal notice when *remove* is given, and the
    original docstring.
    """
    def AlternateUniverse(anything):
        # important: first line needs to be """\ so that textwrap.dedent()
        # works
        """\
        AlternateUniverse provides a true view of the Universe.

        Parameters
        ----------
        anything : object

        Returns
        -------
        truth

        """
        return True

    oldfunc = util.deprecate(AlternateUniverse, old_name=old_name,
                             new_name=new_name,
                             release=release, remove=remove,
                             message=message)
    # ``match`` is the keyword pytest.warns uses for the message regex;
    # ``match_expr`` is not a documented argument.
    with pytest.warns(DeprecationWarning, match="`.+` is deprecated"):
        oldfunc(42)

    doc = oldfunc.__doc__
    name = old_name if old_name else AlternateUniverse.__name__

    deprecation_line_1 = ".. deprecated:: {0}".format(release)
    assert re.search(deprecation_line_1, doc)

    if message:
        deprecation_line_2 = message
    else:
        if new_name is None:
            default_message = "`{0}` is deprecated!".format(name)
        else:
            default_message = "`{0}` is deprecated, use `{1}` instead!".format(
                name, new_name)
        deprecation_line_2 = default_message
    assert re.search(deprecation_line_2, doc)

    if remove:
        deprecation_line_3 = "`{0}` will be removed in release {1}".format(
            name, remove)
        assert re.search(deprecation_line_3, doc)

    # check that the old docs are still present
    assert re.search(textwrap.dedent(AlternateUniverse.__doc__), doc)
1649
1650
def test_deprecate_missing_release_ValueError():
    """``util.deprecate`` must refuse to wrap a callable when no release is given."""
    pytest.raises(ValueError, util.deprecate, mda.Universe)
1654
def test_set_function_name(name="bar"):
    """``util._set_function_name`` renames the given function object itself."""
    def dummy():
        pass

    util._set_function_name(dummy, name)
    # the original object (not a returned copy) must carry the new name
    assert name == dummy.__name__
1660
@pytest.mark.parametrize("text", [
    "",
    "one line text",
    "  one line with leading space",
    "multiline\n\n   with some\n   leading space",
    "   multiline\n\n   with all\n   leading space",
])
def test_dedent_docstring(text):
    """After ``util.dedent_docstring`` no line may start with whitespace."""
    dedented = util.dedent_docstring(text)
    assert all(line == line.lstrip() for line in dedented.splitlines())
1671
1672
class TestCheckBox(object):
    """Tests for :func:`MDAnalysis.lib.util.check_box`.

    Many box representations are accepted: lists, tuples, mixed str/number
    entries, numpy arrays of several dtypes and non-contiguous views. Each
    must be normalised to a C-contiguous ``float32`` array and reported with
    the matching box type. Malformed input must raise :exc:`ValueError`.
    """

    # decimal places used when comparing triclinic box vectors
    prec = 6
    # expected result for the orthorhombic box [1, 1, 1, 90, 90, 90]
    ref_ortho = np.ones(3, dtype=np.float32)
    # expected box vectors for the triclinic box [1, 1, 2, 45, 90, 90]
    ref_tri_vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 2 ** 0.5, 2 ** 0.5]],
                            dtype=np.float32)

    @pytest.mark.parametrize('box', (
        [1, 1, 1, 90, 90, 90],
        (1, 1, 1, 90, 90, 90),
        ['1', '1', 1, 90, '90', '90'],
        ('1', '1', 1, 90, '90', '90'),
        np.array(['1', '1', 1, 90, '90', '90']),
        np.array([1, 1, 1, 90, 90, 90], dtype=np.float32),
        np.array([1, 1, 1, 90, 90, 90], dtype=np.float64),
        # non-contiguous view: every second element of a 12-element array
        np.array([1, 1, 1, 1, 1, 1, 90, 90, 90, 90, 90, 90],
                 dtype=np.float32)[::2],
    ))
    def test_check_box_ortho(self, box):
        boxtype, checked_box = util.check_box(box)
        assert boxtype == 'ortho'
        assert_equal(checked_box, self.ref_ortho)
        assert checked_box.dtype == np.float32
        assert checked_box.flags['C_CONTIGUOUS']

    @pytest.mark.parametrize('box', (
        [1, 1, 2, 45, 90, 90],
        (1, 1, 2, 45, 90, 90),
        ['1', '1', 2, 45, '90', '90'],
        ('1', '1', 2, 45, '90', '90'),
        np.array(['1', '1', 2, 45, '90', '90']),
        np.array([1, 1, 2, 45, 90, 90], dtype=np.float32),
        np.array([1, 1, 2, 45, 90, 90], dtype=np.float64),
        # non-contiguous view: every second element of a 12-element array
        np.array([1, 1, 1, 1, 2, 2, 45, 45, 90, 90, 90, 90],
                 dtype=np.float32)[::2],
    ))
    def test_check_box_tri_vecs(self, box):
        boxtype, checked_box = util.check_box(box)
        assert boxtype == 'tri_vecs'
        assert_almost_equal(checked_box, self.ref_tri_vecs, self.prec)
        assert checked_box.dtype == np.float32
        assert checked_box.flags['C_CONTIGUOUS']

    def test_check_box_wrong_data(self):
        # a non-numeric entry cannot be converted to a box dimension
        wrongbox = ['invalid', 1, 1, 90, 90, 90]
        # only the call under test belongs inside the raises block
        with pytest.raises(ValueError):
            util.check_box(wrongbox)

    def test_check_box_wrong_shape(self):
        # a (3, 3) array is not a valid box specification for check_box
        wrongbox = np.ones((3, 3), dtype=np.float32)
        with pytest.raises(ValueError):
            util.check_box(wrongbox)
1723