"""Tests for the dtype helpers exposed by :mod:`rasterio.dtypes`."""

import numpy as np
import pytest

import rasterio
from rasterio import (
    complex_,
    complex_int16,
    float32,
    float64,
    int16,
    int32,
    ubyte,
    uint8,
    uint16,
    uint32,
)
from rasterio.dtypes import (
    _gdal_typename,
    _getnpdtype,
    _is_complex_int,
    can_cast_dtype,
    check_dtype,
    get_minimum_dtype,
    is_ndarray,
    validate_dtype,
)


# (dtype spelling, expected GDAL type name) pairs for test_gdal_name.
_GDAL_NAME_CASES = [
    (ubyte, "Byte"),
    (np.uint8, "Byte"),
    (np.uint16, "UInt16"),
    ("uint8", "Byte"),
    ("complex_int16", "CInt16"),
    (complex_int16, "CInt16"),
]


def test_is_ndarray():
    """Only real numpy arrays count as ndarrays; lists and tuples do not."""
    assert is_ndarray(np.zeros((1,)))
    for sequence in ([0], (0,)):
        assert not is_ndarray(sequence)


def test_np_dt_uint8():
    """The numpy uint8 type object passes dtype validation."""
    assert check_dtype(np.uint8)


def test_dt_ubyte():
    """rasterio's own ``ubyte`` alias passes dtype validation."""
    assert check_dtype(ubyte)


def test_check_dtype_invalid():
    """An unknown dtype name fails validation."""
    assert not check_dtype('foo')


@pytest.mark.parametrize(("dtype", "name"), _GDAL_NAME_CASES)
def test_gdal_name(dtype, name):
    """Each accepted dtype spelling maps to the expected GDAL type name."""
    assert _gdal_typename(dtype) == name


def test_get_minimum_dtype():
    """The narrowest dtype that can hold the given values is selected."""
    from_sequences = [
        ([0, 1], uint8),
        ([0, 1000], uint16),
        ([0, 100000], uint32),
        ([-1, 0, 1], int16),
        ([-1, 0, 100000], int32),
        ([-1.5, 0, 1.5], float32),
        ([-1.5e+100, 0, 1.5e+100], float64),
    ]
    for values, expected in from_sequences:
        assert get_minimum_dtype(values) == expected

    # Wide numpy inputs should still be narrowed to the same minimal types.
    from_arrays = [
        (np.array([0, 1], dtype=np.uint), uint8),
        (np.array([0, 1000], dtype=np.uint), uint16),
        (np.array([0, 100000], dtype=np.uint), uint32),
        (np.array([-1, 0, 1], dtype=int), int16),
        (np.array([-1, 0, 100000], dtype=int), int32),
        (np.array([-1.5, 0, 1.5], dtype=np.float64), float32),
    ]
    for values, expected in from_arrays:
        assert get_minimum_dtype(values) == expected


def test_can_cast_dtype():
    """Values cast only when the target dtype can represent them."""
    castable = [
        ((1, 2, 3), np.uint8),
        (np.array([1, 2, 3]), np.uint8),
        (np.array([1, 2, 3], dtype=np.uint8), np.uint8),
        (np.array([1, 2, 3]), np.float32),
        (np.array([1.4, 2.1, 3.65]), np.float32),
    ]
    for values, target in castable:
        assert can_cast_dtype(values, target)

    # Fractional values cannot be represented by an unsigned integer type.
    assert not can_cast_dtype(np.array([1.4, 2.1, 3.65]), np.uint8)


@pytest.mark.parametrize("dtype", ["float64", "float32"])
def test_can_cast_dtype_nan(dtype):
    """NaN is representable in floating-point dtypes."""
    assert can_cast_dtype([np.nan], dtype)


@pytest.mark.parametrize("dtype", ["uint8", "uint16", "uint32", "int32"])
def test_cant_cast_dtype_nan(dtype):
    """NaN is not representable in integer dtypes."""
    assert not can_cast_dtype([np.nan], dtype)


def test_validate_dtype():
    """Values validate only against dtype tuples that can hold them."""
    accepted = [
        ([1, 2, 3], ('uint8', 'uint16')),
        (np.array([1, 2, 3]), ('uint8', 'uint16')),
        (np.array([1.4, 2.1, 3.65]), ('float32',)),
    ]
    for values, allowed in accepted:
        assert validate_dtype(values, allowed)

    assert not validate_dtype(np.array([1.4, 2.1, 3.65]), ('uint8',))


def test_complex(tmpdir):
    """A complex-valued band survives a GTiff write/read round trip."""
    path = str(tmpdir.join("complex.tif"))
    written = np.ones((2, 2), dtype=complex_)
    profile = {
        'driver': 'GTiff',
        'width': 2,
        'height': 2,
        'count': 1,
        'dtype': complex_,
    }

    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(written, 1)

    with rasterio.open(path) as src:
        read_back = src.read(1)

    assert np.array_equal(written, read_back)


def test_is_complex_int():
    """The complex-int16 name is recognised as a complex integer type."""
    assert _is_complex_int("complex_int16")


def test_not_is_complex_int():
    """The plain complex name is not a complex integer type."""
    assert not _is_complex_int("complex")


def test_get_npdtype():
    """complex_int16 resolves to numpy's complex64, a complex ("c") kind."""
    resolved = _getnpdtype("complex_int16")
    assert resolved == np.complex64
    assert resolved.kind == "c"