1# This file is part of h5py, a Python interface to the HDF5 library.
2#
3# http://www.h5py.org
4#
5# Copyright 2008-2013 Andrew Collette and contributors
6#
7# License:  Standard 3-clause BSD; see "license.txt" for full license terms
8#           and contributor agreement.
9
10"""
11    Dataset slicing test module.
12
13    Tests all supported slicing operations, including read/write and
14    broadcasting operations.  Does not test type conversion except for
15    corner cases overlapping with slicing; for example, when selecting
16    specific fields of a compound type.
17"""
18
19import numpy as np
20
21from .common import ut, TestCase
22
23import h5py
24from h5py import h5s, h5t, h5d
25from h5py import File, MultiBlockSlice
26
class BaseSlicing(TestCase):

    """
        Shared fixture: every test gets a fresh HDF5 file, opened for
        writing, available as ``self.f``.
    """

    def setUp(self):
        path = self.mktemp()
        self.f = File(path, 'w')

    def tearDown(self):
        # Only close a handle that is still truthy (h5py objects evaluate
        # False once closed), so a test may close the file itself.
        handle = self.f
        if handle:
            handle.close()
35
class TestSingleElement(BaseSlicing):

    """
        Feature: Retrieving a single element works with NumPy semantics
    """

    def test_single_index(self):
        """ Single-element selection with [index] yields array scalar """
        ds = self.f.create_dataset('x', (1,), dtype='i1')
        self.assertIsInstance(ds[0], np.int8)

    def test_single_null(self):
        """ Single-element selection with [()] yields ndarray """
        ds = self.f.create_dataset('x', (1,), dtype='i1')
        result = ds[()]
        # [()] on a 1-D dataset reads the whole thing, not one element.
        self.assertIsInstance(result, np.ndarray)
        self.assertEqual(result.shape, (1,))

    def test_scalar_index(self):
        """ Slicing with [...] yields scalar ndarray """
        ds = self.f.create_dataset('x', shape=(), dtype='f')
        result = ds[...]
        # Ellipsis keeps the array wrapper, giving a 0-d ndarray.
        self.assertIsInstance(result, np.ndarray)
        self.assertEqual(result.shape, ())

    def test_scalar_null(self):
        """ Slicing with [()] yields array scalar """
        ds = self.f.create_dataset('x', shape=(), dtype='i1')
        self.assertIsInstance(ds[()], np.int8)

    def test_compound(self):
        """ Compound scalar is numpy.void, not tuple (issue 135) """
        compound = np.dtype([('a', 'i4'), ('b', 'f8')])
        expected = np.ones((4,), dtype=compound)
        ds = self.f.create_dataset('foo', (4,), data=expected)
        self.assertEqual(ds[0], expected[0])
        self.assertIsInstance(ds[0], np.void)
75
class TestObjectIndex(BaseSlicing):

    """
        Feature: numpy.object_ subtypes map to real Python objects
    """

    # Exact type checks (not isinstance) are used throughout: the point is
    # that the returned object is precisely the h5py/Python type.

    def test_reference(self):
        """ Indexing a reference dataset returns a h5py.Reference instance """
        ds = self.f.create_dataset('x', (1,), dtype=h5py.ref_dtype)
        ds[0] = self.f.ref
        self.assertIs(type(ds[0]), h5py.Reference)

    def test_regref(self):
        """ Indexing a region reference dataset returns a h5py.RegionReference
        """
        source = self.f.create_dataset('x', (10, 10))
        whole_region = source.regionref[...]
        holder = self.f.create_dataset('y', (1,), dtype=h5py.regionref_dtype)
        holder[0] = whole_region
        self.assertIs(type(holder[0]), h5py.RegionReference)

    def test_reference_field(self):
        """ Compound types of which a reference is an element work right """
        record = np.dtype([('a', 'i'), ('b', h5py.ref_dtype)])

        ds = self.f.create_dataset('x', (1,), dtype=record)
        ds[0] = (42, self.f['/'].ref)

        row = ds[0]
        self.assertIs(type(row[1]), h5py.Reference)  # isinstance does NOT work

    def test_scalar(self):
        """ Indexing returns a real Python object on scalar datasets """
        ds = self.f.create_dataset('x', (), dtype=h5py.ref_dtype)
        ds[()] = self.f.ref
        self.assertIs(type(ds[()]), h5py.Reference)

    def test_bytestr(self):
        """ Indexing a byte string dataset returns a real python byte string
        """
        ascii_str = h5py.string_dtype(encoding='ascii')
        ds = self.f.create_dataset('x', (1,), dtype=ascii_str)
        ds[0] = b"Hello there!"
        self.assertIs(type(ds[0]), bytes)
119
class TestSimpleSlicing(BaseSlicing):

    """
        Feature: Simple NumPy-style slices (start:stop:step) are supported.
    """

    def setUp(self):
        # BaseSlicing opens self.f here and closes it in tearDown, so no
        # duplicate fixture code is needed in this class.
        super().setUp()
        self.arr = np.arange(10)
        self.dset = self.f.create_dataset('x', data=self.arr)

    def test_negative_stop(self):
        """ Negative stop indexes work as they do in NumPy """
        self.assertArrayEqual(self.dset[2:-2], self.arr[2:-2])

    def test_write(self):
        """Assigning to a 1D slice of a 2D dataset
        """
        dset = self.f.create_dataset('x2', (10, 2))

        # Use distinct non-zero values so the successful write can be told
        # apart from the dataset's default fill (normally zero).
        x = np.arange(10, dtype='f8').reshape(10, 1)
        dset[:, 0] = x[:, 0]
        np.testing.assert_array_equal(dset[:, 0], x[:, 0])

        # A (10, 1) source must not be accepted for a (10,) selection.
        with self.assertRaises(TypeError):
            dset[:, 1] = x
148
class TestArraySlicing(BaseSlicing):

    """
        Feature: Array types are handled appropriately
    """

    def test_read(self):
        """ Read arrays tack array dimensions onto end of shape tuple """
        dt = np.dtype('(3,)f8')
        dset = self.f.create_dataset('x', (10,), dtype=dt)
        self.assertEqual(dset.shape, (10,))
        self.assertEqual(dset.dtype, dt)

        # Full read
        out = dset[...]
        self.assertEqual(out.dtype, np.dtype('f8'))
        self.assertEqual(out.shape, (10, 3))

        # Single element
        out = dset[0]
        self.assertEqual(out.dtype, np.dtype('f8'))
        self.assertEqual(out.shape, (3,))

        # Range
        out = dset[2:8:2]
        self.assertEqual(out.dtype, np.dtype('f8'))
        self.assertEqual(out.shape, (3, 3))

    def test_write_broadcast(self):
        """ Array fill from constant is not supported (issue 211).
        """
        dt = np.dtype('(3,)i')

        dset = self.f.create_dataset('x', (10,), dtype=dt)

        with self.assertRaises(TypeError):
            dset[...] = 42

    def test_write_element(self):
        """ Write a single element to the array

        Issue 211.
        """
        dt = np.dtype('(3,)f8')
        dset = self.f.create_dataset('x', (10,), dtype=dt)

        data = np.array([1, 2, 3.0])
        dset[4] = data

        # assertArrayEqual also checks shape and dtype, and reports the
        # mismatch instead of a bare "False is not true".
        self.assertArrayEqual(dset[4], data)

    def test_write_slices(self):
        """ Write slices to array type """
        dt = np.dtype('(3,)i')

        data1 = np.ones((2,), dtype=dt)
        data2 = np.ones((4, 5), dtype=dt)

        dset = self.f.create_dataset('x', (10, 9, 11), dtype=dt)

        dset[0, 0, 2:4] = data1
        self.assertArrayEqual(dset[0, 0, 2:4], data1)

        dset[3, 1:5, 6:11] = data2
        self.assertArrayEqual(dset[3, 1:5, 6:11], data2)

    def test_roundtrip(self):
        """ Read the contents of an array and write them back

        Issue 211.
        """
        dt = np.dtype('(3,)f8')
        dset = self.f.create_dataset('x', (10,), dtype=dt)
        # Seed distinct values first: round-tripping the all-zero fill data
        # could not reveal corrupted or reordered elements.
        dset[...] = np.arange(30, dtype='f8').reshape(10, 3)

        out = dset[...]
        dset[...] = out

        self.assertArrayEqual(dset[...], out)
229
230
class TestZeroLengthSlicing(BaseSlicing):

    """
        Slices resulting in empty arrays
    """

    def _assert_read(self, result, expected_shape):
        # A read must come back as a real ndarray with exactly this shape.
        self.assertIsInstance(result, np.ndarray)
        self.assertEqual(result.shape, expected_shape)

    def test_slice_zero_length_dimension(self):
        """ Slice a dataset with a zero in its shape vector
            along the zero-length dimension """
        shapes = [(0,), (0, 3), (0, 2, 1)]
        for idx, shp in enumerate(shapes):
            unlimited = (None,) * len(shp)
            dset = self.f.create_dataset('x%d' % idx, shp, dtype=int,
                                         maxshape=unlimited)
            self.assertEqual(dset.shape, shp)
            # Equivalent to dset[...] followed by dset[:]
            for selection in (Ellipsis, slice(None)):
                self._assert_read(dset[selection], shp)
            if len(shp) > 1:
                partial = dset[:, :1]
                self.assertIsInstance(partial, np.ndarray)
                self.assertEqual(partial.shape[:2], (0, 1))

    def test_slice_other_dimension(self):
        """ Slice a dataset with a zero in its shape vector
            along a non-zero-length dimension """
        shapes = [(3, 0), (1, 2, 0), (2, 0, 1)]
        for idx, shp in enumerate(shapes):
            unlimited = (None,) * len(shp)
            dset = self.f.create_dataset('x%d' % idx, shp, dtype=int,
                                         maxshape=unlimited)
            self.assertEqual(dset.shape, shp)
            self._assert_read(dset[:1], (1,) + shp[1:])

    def test_slice_of_length_zero(self):
        """ Get a slice of length zero from a non-empty dataset """
        shapes = [(3,), (2, 2), (2, 1, 5)]
        for idx, shp in enumerate(shapes):
            unlimited = (None,) * len(shp)
            dset = self.f.create_dataset('x%d' % idx, data=np.zeros(shp, int),
                                         maxshape=unlimited)
            self.assertEqual(dset.shape, shp)
            self._assert_read(dset[1:1], (0,) + shp[1:])
272
class TestFieldNames(BaseSlicing):

    """
        Field names for read & write
    """

    dt = np.dtype([('a', 'f'), ('b', 'i'), ('c', 'f4')])
    # Class-level array shared by every test: tests must work on a copy()
    # rather than modifying it in place.
    data = np.ones((100,), dtype=dt)

    def setUp(self):
        super().setUp()
        self.dset = self.f.create_dataset('x', (100,), dtype=self.dt)
        self.dset[...] = self.data

    def test_read(self):
        """ Test read with field selections """
        self.assertArrayEqual(self.dset['a'], self.data['a'])

    def test_unicode_names(self):
        """ Unicode field names for read and write """
        self.assertArrayEqual(self.dset['a'], self.data['a'])
        self.dset['a'] = 42
        data = self.data.copy()
        data['a'] = 42
        self.assertArrayEqual(self.dset['a'], data['a'])

    def test_write(self):
        """ Test write with field selections """
        # Each step changes only the selected field(s) in data2, so after
        # writing that selection the whole dataset must equal data2.
        data2 = self.data.copy()
        data2['a'] *= 2
        self.dset['a'] = data2
        self.assertTrue(np.all(self.dset[...] == data2))
        data2['b'] *= 4
        self.dset['b'] = data2
        self.assertTrue(np.all(self.dset[...] == data2))
        data2['a'] *= 3
        data2['c'] *= 3
        self.dset['a', 'c'] = data2
        self.assertTrue(np.all(self.dset[...] == data2))

    def test_write_noncompound(self):
        """ Test write with non-compound source (single-field) """
        data2 = self.data.copy()
        data2['b'] = 1.0
        self.dset['b'] = 1.0
        self.assertTrue(np.all(self.dset[...] == data2))
319
320
class TestMultiBlockSlice(BaseSlicing):

    """
        Feature: MultiBlockSlice hyperslab selections, resolved by
        ``indices(length)`` into a (start, stride, count, block) tuple.
    """

    def setUp(self):
        super().setUp()
        self.arr = np.arange(10)
        self.dset = self.f.create_dataset('x', data=self.arr)

    def _check(self, mbslice, indices, expected):
        # Verify the resolved (start, stride, count, block) for an extent of
        # 10, then the values actually read through the selection.
        self.assertEqual(mbslice.indices(10), indices)
        np.testing.assert_array_equal(self.dset[mbslice], expected)

    def test_default(self):
        # Defaults cover the whole extent: ten unit-size blocks at stride 1
        self._check(MultiBlockSlice(), (0, 1, 10, 1), self.arr)

    def test_default_explicit(self):
        explicit = MultiBlockSlice(start=0, count=10, stride=1, block=1)
        self._check(explicit, (0, 1, 10, 1), self.arr)

    def test_start(self):
        self._check(MultiBlockSlice(start=4), (4, 1, 6, 1),
                    np.array([4, 5, 6, 7, 8, 9]))

    def test_count(self):
        self._check(MultiBlockSlice(count=7), (0, 1, 7, 1),
                    np.array([0, 1, 2, 3, 4, 5, 6]))

    def test_count_more_than_length_error(self):
        too_many = MultiBlockSlice(count=11)
        with self.assertRaises(ValueError):
            too_many.indices(10)

    def test_stride(self):
        self._check(MultiBlockSlice(stride=2), (0, 2, 5, 1),
                    np.array([0, 2, 4, 6, 8]))

    def test_stride_zero_error(self):
        with self.assertRaises(ValueError):
            # Must be rejected up front; otherwise computing the count would
            # divide by the zero stride.
            MultiBlockSlice(stride=0, block=0).indices(10)

    def test_stride_block_equal(self):
        self._check(MultiBlockSlice(stride=2, block=2), (0, 2, 5, 2), self.arr)

    def test_block_more_than_stride_error(self):
        # Rejected at construction time, both with the default stride and
        # with an explicit one.
        for kwds in ({'block': 3}, {'stride': 2, 'block': 3}):
            with self.assertRaises(ValueError):
                MultiBlockSlice(**kwds)

    def test_stride_more_than_block(self):
        self._check(MultiBlockSlice(stride=3, block=2), (0, 3, 3, 2),
                    np.array([0, 1, 3, 4, 6, 7]))

    def test_block_overruns_extent_error(self):
        # If fully described then must fit within extent
        overrun = MultiBlockSlice(start=2, count=2, stride=5, block=4)
        with self.assertRaises(ValueError):
            overrun.indices(10)

    def test_fully_described(self):
        full = MultiBlockSlice(start=1, count=2, stride=5, block=4)
        self._check(full, (1, 5, 2, 4),
                    np.array([1, 2, 3, 4, 6, 7, 8, 9]))

    def test_count_calculated(self):
        # Without an explicit count, as many whole blocks as fit are selected
        self._check(MultiBlockSlice(start=1, stride=3, block=2), (1, 3, 3, 2),
                    np.array([1, 2, 4, 5, 7, 8]))

    def test_zero_count_calculated_error(self):
        # Not even one whole block fits after start=8, so resolving must fail
        no_room = MultiBlockSlice(start=8, stride=4, block=3)

        with self.assertRaises(ValueError):
            no_room.indices(10)
417