"""Tests for reading and writing trajectories with ase.io.NetCDFTrajectory."""
import warnings

import numpy as np
import pytest

from ase import Atom, Atoms
from ase.io import NetCDFTrajectory, read


@pytest.fixture(scope='module')
def netCDF4():
    """Return the netCDF4 module, skipping the tests if it is unavailable."""
    return pytest.importorskip('netCDF4')


@pytest.fixture(autouse=True)
def catch_netcdf4_warning():
    """Ignore DeprecationWarnings for the duration of every test."""
    with warnings.catch_warnings():
        # XXX Ignore deprecation warning from numpy over how netCDF4
        # uses numpy. We can't really do anything about that.
        warnings.simplefilter('ignore', DeprecationWarning)
        yield


@pytest.fixture
def co(netCDF4):
    """Periodic C-O pair in a cubic cell of edge 3.

    Depends on the netCDF4 fixture only so that tests using it are skipped
    when netCDF4 is not installed.
    """
    return Atoms([Atom('C', (0, 0, 0)),
                  Atom('O', (0, 0, 1.2))],
                 cell=[3, 3, 3],
                 pbc=True)


def _expected_last_z(frame):
    """Expected z coordinate of the last atom (O) in *frame* of '1.nc'.

    Frames 0-4 were written with O at z = 1.3 ... 1.7 (steps of 0.1);
    every later frame is shifted a further +1.
    """
    if frame < 4:
        return 1.3 + frame * 0.1
    return 1.7 + frame - 4


def _write_raw_netcdf(netCDF4, fname, atom_types, frames, ids=None):
    """Write a minimal trajectory file directly with the netCDF4 library.

    Parameters
    ----------
    netCDF4 : module
        The imported netCDF4 module.
    fname : str
        Output file name.
    atom_types : sequence of int
        Stored once per file on the 'atom' dimension.
    frames : sequence of array-like
        One position array per frame, each of shape (len(atom_types), 3).
    ids : sequence of sequence of int, optional
        If given, one id per atom per frame, stored in an 'id' variable.

    All cell lengths are set to 0 and all cell angles to 90.
    """
    # The context manager guarantees the file is closed even if a write fails.
    with netCDF4.Dataset(fname, 'w') as nc:
        nc.createDimension('frame', None)
        nc.createDimension('atom', len(atom_types))
        nc.createDimension('spatial', 3)
        nc.createDimension('cell_spatial', 3)
        nc.createDimension('cell_angular', 3)

        nc.createVariable('atom_types', 'i', ('atom',))
        nc.createVariable('coordinates', 'f4', ('frame', 'atom', 'spatial'))
        nc.createVariable('cell_lengths', 'f4', ('frame', 'cell_spatial'))
        nc.createVariable('cell_angles', 'f4', ('frame', 'cell_angular'))
        if ids is not None:
            nc.createVariable('id', 'i', ('frame', 'atom'))

        nc.variables['atom_types'][:] = atom_types
        for frame, positions in enumerate(frames):
            nc.variables['coordinates'][frame] = positions
        # Written after the coordinates so [:] spans every frame created.
        nc.variables['cell_lengths'][:] = 0
        nc.variables['cell_angles'][:] = 90
        if ids is not None:
            for frame, frame_ids in enumerate(ids):
                nc.variables['id'][frame] = frame_ids


def test_netcdftrajectory(co):
    """Round-trip, append, pbc, per-atom arrays, cell origin and ids."""
    rng = np.random.RandomState(17)

    # Five frames, moving both atoms +0.1 along z before each write.
    traj = NetCDFTrajectory('1.nc', 'w', co)
    for _ in range(5):
        co.positions[:, 2] += 0.1
        traj.write()
    del traj

    # Reopen in append mode and add a sixth frame shifted by +1.
    traj = NetCDFTrajectory('1.nc', 'a')
    co = traj[-1]
    co.positions[:] += 1
    traj.write(co)
    del traj

    t = NetCDFTrajectory('1.nc', 'a')
    for i, a in enumerate(t):
        assert abs(a.positions[-1, 2] - _expected_last_z(i)) < 1e-6
        assert a.pbc.all()

    # Append a seventh frame, another +1 higher, and re-check every frame.
    co.positions[:] += 1
    t.write(co)
    for i, a in enumerate(t):
        assert abs(a.positions[-1, 2] - _expected_last_z(i)) < 1e-6
    assert len(t) == 7

    # Change atom type and append; a separate reader must see the change.
    co[0].number = 1
    t.write(co)
    t2 = NetCDFTrajectory('1.nc', 'r')
    co2 = t2[-1]
    assert (co2.numbers == co.numbers).all()
    del t2

    co[0].number = 6
    t.write(co)

    # Writing a frame with one atom removed is expected to be rejected.
    co.pbc = False
    o = co.pop(1)
    with pytest.raises(ValueError):
        t.write(co)

    co.append(o)
    co.pbc = True
    t.write(co)
    del t

    # append to a nonexisting file
    fname = '2.nc'
    t = NetCDFTrajectory(fname, 'a', co)
    del t

    fname = '3.nc'
    t = NetCDFTrajectory(fname, 'w', co)
    # File is not created before first write
    co.set_pbc([True, False, False])
    d = co.get_distance(0, 1)
    # pytest.warns(None) is deprecated since pytest 7 and removed in
    # pytest 8; turn warnings into errors instead so this write is checked
    # to be warning-free.  DeprecationWarnings stay ignored, matching the
    # autouse catch_netcdf4_warning fixture.
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        warnings.simplefilter('ignore', DeprecationWarning)
        t.write(co)
    del t

    # Check pbc, reading with both a small and a large chunk size.
    for chunk in [1, 1000]:
        t = NetCDFTrajectory(fname, chunk_size=chunk)
        a = t[-1]
        assert a.pbc[0] and not a.pbc[1] and not a.pbc[2]
        assert abs(a.get_distance(0, 1) - d) < 1e-6
        del t

    # Append something in Voigt notation (six values per atom).
    t = NetCDFTrajectory(fname, 'a')
    for frame, a in enumerate(t):
        values = rng.random([len(a), 6])
        a.set_array('test', values)
        t.write_arrays(a, frame, ['test'])
    del t

    # Check cell origin
    co.set_pbc(True)
    co.set_celldisp([1, 2, 3])
    traj = NetCDFTrajectory('4.nc', 'w', co)
    traj.write(co)
    traj.close()

    traj = NetCDFTrajectory('4.nc', 'r')
    a = traj[0]
    assert np.all(abs(a.get_celldisp() - np.array([1, 2, 3])) < 1e-12)
    traj.close()

    # Add 'id' field and check if it is read correctly: with ids [2, 1]
    # the reader is expected to return the atoms reordered (O before C).
    co.set_array('id', np.array([2, 1]))
    traj = NetCDFTrajectory('5.nc', 'w', co)
    traj.write(co, arrays=['id'])
    traj.close()

    traj = NetCDFTrajectory('5.nc', 'r')
    assert np.all(traj[0].numbers == [8, 6])
    assert np.all(np.abs(traj[0].positions - np.array([[2, 2, 3.7],
                                                      [2., 2., 2.5]])) < 1e-6)
    traj.close()

    a = read('5.nc')
    assert len(a) == 2


def test_netcdf_with_variable_atomic_numbers(netCDF4):
    """Read a file whose atom types are defined once per file."""
    # Create a NetCDF file with a per-file definition of atomic numbers. ASE
    # NetCDFTrajectory can read but not write these types of files.
    r0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)
    r1 = 2 * r0
    _write_raw_netcdf(netCDF4, '6.nc', atom_types=[1, 2], frames=[r0, r1])

    traj = NetCDFTrajectory('6.nc', 'r')
    assert np.allclose(traj[0].positions, r0)
    assert np.allclose(traj[1].positions, r1)
    traj.close()


def test_netcdf_with_nonconsecutive_index(netCDF4):
    """Atoms come back sorted by their non-consecutive per-frame 'id'."""
    r0 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
    r1 = 2 * r0
    _write_raw_netcdf(netCDF4, '7.nc', atom_types=[1, 2, 3],
                      frames=[r0, r1], ids=[[13, 3, 5], [-1, 0, -5]])

    traj = NetCDFTrajectory('7.nc', 'r')
    # ids 3, 5, 13 -> types 2, 3, 1
    assert (traj[0].numbers == [2, 3, 1]).all()
    # ids -5, -1, 0 -> types 3, 1, 2
    assert (traj[1].numbers == [3, 1, 2]).all()
    traj.close()


def test_types_to_numbers_argument(co):
    """types_to_numbers remaps stored atom types to atomic numbers on read."""
    traj = NetCDFTrajectory('8.nc', 'w', co)
    traj.write()
    traj.close()

    # Dict mapping: both C (6) and O (8) become 15.
    d = {6: 15, 8: 15}
    traj = NetCDFTrajectory('8.nc', mode="r", types_to_numbers=d)
    assert np.allclose(traj[-1].get_masses(), 30.974)
    assert (traj[-1].numbers == [15, 15]).all()
    traj.close()

    # Types missing from the mapping are left unchanged.
    d = {3: 14}
    traj = NetCDFTrajectory('8.nc', mode="r", types_to_numbers=d)
    assert (traj[-1].numbers == [6, 8]).all()
    traj.close()

    # List mapping indexed by type: 6 -> 15; type 8 lies past the end of
    # the list and is expected to stay 8.
    traj = NetCDFTrajectory('8.nc', 'r',
                            types_to_numbers=[0, 0, 0, 0, 0, 0, 15])
    assert (traj[-1].numbers == [15, 8]).all()
    traj.close()