1import numpy as np
2import pytest
3
4import rasterio
5from rasterio import (
6    ubyte,
7    uint8,
8    uint16,
9    uint32,
10    int16,
11    int32,
12    float32,
13    float64,
14    complex_,
15    complex_int16,
16)
17from rasterio.dtypes import (
18    _gdal_typename,
19    is_ndarray,
20    check_dtype,
21    get_minimum_dtype,
22    can_cast_dtype,
23    validate_dtype,
24    _is_complex_int,
25    _getnpdtype,
26)
27
28
def test_is_ndarray():
    """is_ndarray is true for a numpy array and false for a list or tuple."""
    array_input = np.zeros((1,))
    assert is_ndarray(array_input)

    for sequence_input in ([0], (0,)):
        assert not is_ndarray(sequence_input)
33
34
def test_np_dt_uint8():
    """check_dtype accepts numpy's ``uint8`` type object."""
    is_valid = check_dtype(np.uint8)
    assert is_valid
37
38
def test_dt_ubyte():
    """check_dtype accepts rasterio's ``ubyte`` alias."""
    is_valid = check_dtype(ubyte)
    assert is_valid
41
42
def test_check_dtype_invalid():
    """check_dtype rejects a string that is not a known dtype name."""
    bogus_name = 'foo'
    assert not check_dtype(bogus_name)
45
46
@pytest.mark.parametrize(
    "dtype,name",
    [
        (ubyte, "Byte"),
        (np.uint8, "Byte"),
        (np.uint16, "UInt16"),
        ("uint8", "Byte"),
        ("complex_int16", "CInt16"),
        (complex_int16, "CInt16"),
    ],
)
def test_gdal_name(dtype, name):
    """_gdal_typename maps type objects and dtype-name strings to GDAL names."""
    gdal_name = _gdal_typename(dtype)
    assert gdal_name == name
60
61
def test_get_minimum_dtype():
    """get_minimum_dtype returns the narrowest dtype able to hold every value."""
    # Plain Python sequences.
    assert get_minimum_dtype([0, 1]) == uint8
    assert get_minimum_dtype([0, 1000]) == uint16
    assert get_minimum_dtype([0, 100000]) == uint32
    assert get_minimum_dtype([-1, 0, 1]) == int16
    assert get_minimum_dtype([-1, 0, 100000]) == int32
    assert get_minimum_dtype([-1.5, 0, 1.5]) == float32
    assert get_minimum_dtype([-1.5e+100, 0, 1.5e+100]) == float64

    # Unsigned limits are inclusive: the maximum representable value still
    # fits, and one past it forces the next wider type.
    assert get_minimum_dtype([0, 255]) == uint8
    assert get_minimum_dtype([0, 256]) == uint16
    assert get_minimum_dtype([0, 65535]) == uint16
    assert get_minimum_dtype([0, 65536]) == uint32

    # numpy arrays: the result follows the values, not the array's own dtype
    # (e.g. a float64 array of small values still maps to float32).
    assert get_minimum_dtype(np.array([0, 1], dtype=np.uint)) == uint8
    assert get_minimum_dtype(np.array([0, 1000], dtype=np.uint)) == uint16
    assert get_minimum_dtype(np.array([0, 100000], dtype=np.uint)) == uint32
    assert get_minimum_dtype(np.array([-1, 0, 1], dtype=int)) == int16
    assert get_minimum_dtype(np.array([-1, 0, 100000], dtype=int)) == int32
    assert get_minimum_dtype(np.array([-1.5, 0, 1.5], dtype=np.float64)) == float32

    # Mirror of the list case above: values beyond float32's range must
    # keep the array at float64.
    big_floats = np.array([-1.5e+100, 0, 1.5e+100], dtype=np.float64)
    assert get_minimum_dtype(big_floats) == float64
77
78
def test_can_cast_dtype():
    """can_cast_dtype reports whether values may be cast to a target dtype."""
    integers = np.array([1, 2, 3])
    fractional = np.array([1.4, 2.1, 3.65])

    # Small integers are castable to uint8 as a tuple, a default-dtype
    # array, or an array that is already uint8.
    assert can_cast_dtype((1, 2, 3), np.uint8)
    assert can_cast_dtype(integers, np.uint8)
    assert can_cast_dtype(np.array([1, 2, 3], dtype=np.uint8), np.uint8)

    # Both the integer and the fractional values are castable to float32.
    assert can_cast_dtype(integers, np.float32)
    assert can_cast_dtype(fractional, np.float32)

    # Fractional values are reported as not castable to uint8.
    assert not can_cast_dtype(fractional, np.uint8)
86
87
@pytest.mark.parametrize("dtype", ("float64", "float32"))
def test_can_cast_dtype_nan(dtype):
    """A NaN value is castable to floating-point dtypes."""
    nan_values = [np.nan]
    assert can_cast_dtype(nan_values, dtype)
91
92
@pytest.mark.parametrize("dtype", ("uint8", "uint16", "uint32", "int32"))
def test_cant_cast_dtype_nan(dtype):
    """A NaN value is not castable to integer dtypes."""
    nan_values = [np.nan]
    assert not can_cast_dtype(nan_values, dtype)
96
97
def test_validate_dtype():
    """validate_dtype checks values against a tuple of allowed dtype names."""
    small_unsigned = ('uint8', 'uint16')
    fractional = np.array([1.4, 2.1, 3.65])

    assert validate_dtype([1, 2, 3], small_unsigned)
    assert validate_dtype(np.array([1, 2, 3]), small_unsigned)
    assert validate_dtype(fractional, ('float32',))
    assert not validate_dtype(fractional, ('uint8',))
103
104
def test_complex(tmpdir):
    """A complex-valued band round-trips unchanged through a GTiff file."""
    path = str(tmpdir.join("complex.tif"))
    expected = np.ones((2, 2), dtype=complex_)
    profile = {
        'driver': 'GTiff',
        'width': 2,
        'height': 2,
        'count': 1,
        'dtype': complex_,
    }

    # Write the array as band 1, then read band 1 back from disk.
    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(expected, 1)

    with rasterio.open(path) as src:
        actual = src.read(1)

    assert np.array_equal(expected, actual)
117
118
def test_is_complex_int():
    """The name 'complex_int16' is recognised as a complex integer type."""
    dtype_name = "complex_int16"
    assert _is_complex_int(dtype_name)
121
122
def test_not_is_complex_int():
    """The plain name 'complex' is not treated as a complex integer type."""
    dtype_name = "complex"
    assert not _is_complex_int(dtype_name)
125
126
def test_get_npdtype():
    """'complex_int16' resolves to numpy complex64, a complex-kind dtype."""
    resolved = _getnpdtype("complex_int16")
    assert resolved == np.complex64
    assert resolved.kind == "c"
131