1import numpy as np
2import pytest
3import warnings
4
5from ase import Atom, Atoms
6from ase.io import read
7from ase.io import NetCDFTrajectory
8
9
@pytest.fixture(scope='module')
def netCDF4():
    """Provide the optional ``netCDF4`` module, skipping tests if it is missing."""
    module = pytest.importorskip('netCDF4')
    return module
13
14
@pytest.fixture(autouse=True)
def catch_netcdf4_warning():
    """Ignore DeprecationWarnings for the duration of every test in this module.

    XXX netCDF4 uses numpy in ways that numpy deprecates; this module cannot
    fix that, so those warnings are silenced rather than reported.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        yield
22
23
@pytest.fixture
def co(netCDF4):
    """Return a periodic CO molecule in a 3x3x3 cell.

    C sits at the origin and O at z = 1.2.  Requesting the ``netCDF4``
    fixture makes every test using this one skip when netCDF4 is absent.
    """
    carbon = Atom('C', (0, 0, 0))
    oxygen = Atom('O', (0, 0, 1.2))
    return Atoms([carbon, oxygen], cell=[3, 3, 3], pbc=True)
30
31
def _assert_o_heights(traj, check_pbc=False):
    """Assert the z-coordinate of the last atom (O) in every frame of *traj*.

    Frames 0-3 were written after successive +0.1 shifts of the molecule
    (O at 1.3, 1.4, 1.5, 1.6), frame 4 is the last such frame (1.7) and
    every later frame was written after a further +1 shift (2.7, 3.7, ...).

    Parameters
    ----------
    traj : NetCDFTrajectory
        Trajectory to iterate over.
    check_pbc : bool
        If true, also assert that every frame is periodic in all directions.
    """
    for i, atoms in enumerate(traj):
        if i < 4:
            expected = 1.3 + i * 0.1
        else:
            expected = 1.7 + (i - 4)
        assert abs(atoms.positions[-1, 2] - expected) < 1e-6
        if check_pbc:
            assert atoms.pbc.all()


def test_netcdftrajectory(co):
    """Exercise writing, appending to and re-reading NetCDF trajectories.

    Covers appending frames, changing atomic numbers between frames,
    rejecting a frame that no longer matches the file, partial periodic
    boundary conditions, appending per-atom arrays to existing frames,
    cell-displacement round-trips and the 'id' array.
    """
    rng = np.random.RandomState(17)

    # Write five frames, moving the molecule up by 0.1 before each write.
    traj = NetCDFTrajectory('1.nc', 'w', co)
    for _ in range(5):
        co.positions[:, 2] += 0.1
        traj.write()
    traj.close()

    # Re-open in append mode and add one frame shifted by +1.
    traj = NetCDFTrajectory('1.nc', 'a')
    co = traj[-1]
    co.positions[:] += 1
    traj.write(co)
    traj.close()

    t = NetCDFTrajectory('1.nc', 'a')
    _assert_o_heights(t, check_pbc=True)
    # A frame appended through an open handle must show up when iterating
    # over that same handle again.
    co.positions[:] += 1
    t.write(co)
    _assert_o_heights(t)
    assert len(t) == 7

    # Change atom type and append
    co[0].number = 1
    t.write(co)
    t2 = NetCDFTrajectory('1.nc', 'r')
    co2 = t2[-1]
    assert (co2.numbers == co.numbers).all()
    t2.close()

    co[0].number = 6
    t.write(co)

    # A frame that no longer matches the file (one atom removed, pbc
    # switched off) must be rejected.
    co.pbc = False
    o = co.pop(1)
    with pytest.raises(ValueError):
        t.write(co)

    co.append(o)
    co.pbc = True
    t.write(co)
    t.close()

    # append to a nonexisting file
    t = NetCDFTrajectory('2.nc', 'a', co)
    t.close()

    fname = '3.nc'
    t = NetCDFTrajectory(fname, 'w', co)
    # File is not created before first write
    co.set_pbc([True, False, False])
    d = co.get_distance(0, 1)
    # pytest.warns(None) was removed in pytest 8.  It used to record (and
    # thereby swallow) any warning raised by the write; do the same here.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter('always')
        t.write(co)
    t.close()

    # Check pbc, reading with both a tiny and a large chunk size.
    for c in [1, 1000]:
        t = NetCDFTrajectory(fname, chunk_size=c)
        a = t[-1]
        assert a.pbc[0] and not a.pbc[1] and not a.pbc[2]
        assert abs(a.get_distance(0, 1) - d) < 1e-6
        t.close()

    # Append something in Voigt notation (6 components per atom) to every
    # existing frame.
    t = NetCDFTrajectory(fname, 'a')
    for frame, a in enumerate(t):
        test = rng.random([len(a), 6])
        a.set_array('test', test)
        t.write_arrays(a, frame, ['test'])
    t.close()

    # Check cell origin
    co.set_pbc(True)
    co.set_celldisp([1, 2, 3])
    traj = NetCDFTrajectory('4.nc', 'w', co)
    traj.write(co)
    traj.close()

    traj = NetCDFTrajectory('4.nc', 'r')
    a = traj[0]
    assert np.all(abs(a.get_celldisp() - np.array([1, 2, 3])) < 1e-12)
    traj.close()

    # Add 'id' field and check if it is read correctly: with ids [2, 1]
    # the atoms are expected back in id order, i.e. O before C.
    co.set_array('id', np.array([2, 1]))
    traj = NetCDFTrajectory('5.nc', 'w', co)
    traj.write(co, arrays=['id'])
    traj.close()

    traj = NetCDFTrajectory('5.nc', 'r')
    first = traj[0]
    assert np.all(first.numbers == [8, 6])
    assert np.all(np.abs(first.positions - np.array([[2, 2, 3.7],
                                                     [2., 2., 2.5]])) < 1e-6)
    traj.close()

    a = read('5.nc')
    assert len(a) == 2
147
148
def test_netcdf_with_variable_atomic_numbers(netCDF4):
    """Read a file whose atom types are stored once per file, not per frame.

    The file is written directly with netCDF4 because ASE's
    NetCDFTrajectory can read but not write this layout ('atom_types' has
    only the 'atom' dimension).
    """
    r0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
    r1 = 2 * r0

    # The context manager closes the dataset even if a write fails.
    with netCDF4.Dataset('6.nc', 'w') as nc:
        nc.createDimension('frame', None)
        nc.createDimension('atom', 2)
        nc.createDimension('spatial', 3)
        nc.createDimension('cell_spatial', 3)
        nc.createDimension('cell_angular', 3)

        nc.createVariable('atom_types', 'i', ('atom',))
        nc.createVariable('coordinates', 'f4',
                          ('frame', 'atom', 'spatial',))
        nc.createVariable('cell_lengths', 'f4', ('frame', 'cell_spatial',))
        nc.createVariable('cell_angles', 'f4', ('frame', 'cell_angular',))

        nc.variables['atom_types'][:] = [1, 2]
        nc.variables['coordinates'][0] = r0
        nc.variables['coordinates'][1] = r1
        # Written after the coordinates so that [:] spans both frames.
        nc.variables['cell_lengths'][:] = 0
        nc.variables['cell_angles'][:] = 90

    traj = NetCDFTrajectory('6.nc', 'r')
    try:
        assert np.allclose(traj[0].positions, r0)
        assert np.allclose(traj[1].positions, r1)
    finally:
        traj.close()
179
180
def test_netcdf_with_nonconsecutive_index(netCDF4):
    """Check atom order when frames carry a non-consecutive 'id' variable.

    The atom types [1, 2, 3] are stored once per file.  Each frame has its
    own ids: [13, 3, 5] in frame 0 and [-1, 0, -5] in frame 1.  The
    expected numbers below are those types listed in ascending id order.
    """
    r0 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
    r1 = 2 * r0

    # The context manager closes the dataset even if a write fails.
    with netCDF4.Dataset('7.nc', 'w') as nc:
        nc.createDimension('frame', None)
        nc.createDimension('atom', 3)
        nc.createDimension('spatial', 3)
        nc.createDimension('cell_spatial', 3)
        nc.createDimension('cell_angular', 3)

        nc.createVariable('atom_types', 'i', ('atom',))
        nc.createVariable('coordinates', 'f4',
                          ('frame', 'atom', 'spatial',))
        nc.createVariable('cell_lengths', 'f4', ('frame', 'cell_spatial',))
        nc.createVariable('cell_angles', 'f4', ('frame', 'cell_angular',))
        nc.createVariable('id', 'i', ('frame', 'atom',))

        nc.variables['atom_types'][:] = [1, 2, 3]
        nc.variables['coordinates'][0] = r0
        nc.variables['coordinates'][1] = r1
        # Written after the coordinates so that [:] spans both frames.
        nc.variables['cell_lengths'][:] = 0
        nc.variables['cell_angles'][:] = 90
        nc.variables['id'][0] = [13, 3, 5]
        nc.variables['id'][1] = [-1, 0, -5]

    traj = NetCDFTrajectory('7.nc', 'r')
    try:
        # Frame 0 by ascending id: 3 (type 2), 5 (type 3), 13 (type 1).
        assert (traj[0].numbers == [2, 3, 1]).all()
        # Frame 1 by ascending id: -5 (type 3), -1 (type 1), 0 (type 2).
        assert (traj[1].numbers == [3, 1, 2]).all()
    finally:
        traj.close()
212
213
def _read_last_frame(filename, types_to_numbers):
    """Open *filename* read-only with *types_to_numbers*, return its last frame.

    The trajectory is always closed before returning.
    """
    traj = NetCDFTrajectory(filename, 'r', types_to_numbers=types_to_numbers)
    try:
        return traj[-1]
    finally:
        traj.close()


def test_types_to_numbers_argument(co):
    """Check remapping of stored atom types through ``types_to_numbers``.

    The mapping is given either as a dict (stored type -> atomic number)
    or as a sequence indexed by the stored type.
    """
    traj = NetCDFTrajectory('8.nc', 'w', co)
    traj.write()
    traj.close()

    # Dict: both C (6) and O (8) become P (15), masses included.
    atoms = _read_last_frame('8.nc', {6: 15, 8: 15})
    assert np.allclose(atoms.get_masses(), 30.974)
    assert (atoms.numbers == [15, 15]).all()

    # A key that matches no stored type leaves the numbers unchanged.
    atoms = _read_last_frame('8.nc', {3: 14})
    assert (atoms.numbers == [6, 8]).all()

    # Sequence: entry 6 maps C to 15; type 8 is past the end of this
    # 7-element list and is expected to stay 8.
    atoms = _read_last_frame('8.nc', [0, 0, 0, 0, 0, 0, 15])
    assert (atoms.numbers == [15, 8]).all()
230