1import numpy as np
2import pytest
3
4pytestmark = pytest.mark.gpu
5
6import dask.array as da
7from dask.array.numpy_compat import _numpy_120
8from dask.array.utils import assert_eq
9
10cupy = pytest.importorskip("cupy")
11
12
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize("idx_chunks", [None, 3, 2, 1])
@pytest.mark.parametrize("x_chunks", [(3, 5), (2, 3), (1, 2), (1, 1)])
def test_index_with_int_dask_array_cupy_index(x_chunks, idx_chunks):
    """Index a CuPy-backed dask array with a CuPy integer index array.

    ``idx_chunks=None`` uses the raw CuPy index array; any other value wraps
    it in a dask array with that chunk size. The selection is checked along
    both axes (``x[:, idx]`` and ``x.T[idx, :]``).
    """
    # NOTE: this name must differ from ``test_index_with_int_dask_array``
    # defined later in this module. With a duplicate name the later ``def``
    # rebinds the module attribute, so pytest never collects this test.
    # The NEP-35 gate matches the other tests here that index with CuPy arrays.

    # test data is crafted to stress use cases:
    # - pick from different chunks of x out of order
    # - a chunk of x contains no matches
    # - only one chunk of x
    x = cupy.array(
        [[10, 20, 30, 40, 50], [60, 70, 80, 90, 100], [110, 120, 130, 140, 150]]
    )
    idx = cupy.array([3, 0, 1])
    expect = cupy.array([[40, 10, 20], [90, 60, 70], [140, 110, 120]])

    x = da.from_array(x, chunks=x_chunks)
    if idx_chunks is not None:
        idx = da.from_array(idx, chunks=idx_chunks)

    assert_eq(x[:, idx], expect)
    assert_eq(x.T[idx, :], expect.T)
32
33
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize("idx_chunks", [None, 3, 2, 1])
@pytest.mark.parametrize("x_chunks", [(3, 5), (2, 3), (1, 2), (1, 1)])
def test_index_with_int_dask_array(x_chunks, idx_chunks):
    """Index a CuPy-backed dask array with a NumPy index, then a CuPy index.

    Each index is used as-is when ``idx_chunks`` is None. Otherwise it is
    wrapped in a dask array chunked by ``idx_chunks``. Both ``x[:, idx]`` and
    ``x.T[idx, :]`` are compared against the same expected selection.
    """
    # The data is chosen so that the selection:
    # - picks from different chunks of x, out of order,
    # - leaves a chunk of x with no matches,
    # - and, with x_chunks == (3, 5), reads from a single chunk of x.
    data = cupy.array(
        [[10, 20, 30, 40, 50], [60, 70, 80, 90, 100], [110, 120, 130, 140, 150]]
    )
    expect = cupy.array([[40, 10, 20], [90, 60, 70], [140, 110, 120]])
    x = data if x_chunks is None else da.from_array(data, chunks=x_chunks)

    host_idx = np.array([3, 0, 1])
    # Same index values: first on the host (NumPy), then on the device (CuPy).
    for base_idx in (host_idx, cupy.array(host_idx)):
        if idx_chunks is None:
            idx = base_idx
        else:
            idx = da.from_array(base_idx, chunks=idx_chunks)
        assert_eq(x[:, idx], expect)
        assert_eq(x.T[idx, :], expect.T)
67
68
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize("chunks", [1, 2, 3])
def test_index_with_int_dask_array_0d(chunks):
    """Slice either axis with a 0-d dask array, built from a plain int or a CuPy scalar.

    The result must equal slicing with the literal integer ``1``.
    """
    x = da.from_array(cupy.array([[10, 20, 30], [40, 50, 60]]), chunks=chunks)
    for scalar in (1, cupy.array(1)):
        zero_dim_idx = da.from_array(scalar, chunks=1)
        assert_eq(x[zero_dim_idx, :], x[1, :])
        assert_eq(x[:, zero_dim_idx], x[:, 1])
82
83
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.skip("dask.Array.nonzero() doesn't support non-NumPy arrays yet")
@pytest.mark.parametrize("chunks", [1, 2, 3, 4, 5])
def test_index_with_int_dask_array_nanchunks(chunks):
    """Index with the output of ``nonzero()``, whose chunks have nan (unknown) sizes.

    Currently skipped unconditionally; see the ``skip`` reason above.
    """
    signed = da.from_array(cupy.arange(-2, 3), chunks=chunks)
    assert_eq(signed[signed.nonzero()], cupy.array([-2, -1, 1, 2]))

    # Edge case: every element is zero, so the nan-sized chunks resolve to size 0.
    zeros = da.zeros_like(cupy.array(()), shape=5, chunks=chunks)
    assert_eq(zeros[zeros.nonzero()], cupy.array([]))
94
95
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize("chunks", [2, 4])
def test_index_with_int_dask_array_negindex(chunks):
    """Negative indices stored in a dask array count back from the end."""
    a = da.arange(4, chunks=chunks, like=cupy.array(()))
    # Same negative indices, first from a Python list, then from a CuPy array.
    for raw_idx in ([-1, -4], cupy.array([-1, -4])):
        idx = da.from_array(raw_idx, chunks=1)
        assert_eq(a[idx], cupy.array([3, 0]))
106
107
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize("chunks", [2, 4])
def test_index_with_int_dask_array_indexerror(chunks):
    """Out-of-range indices in a dask array raise IndexError once computed."""
    a = da.arange(4, chunks=chunks, like=cupy.array(()))
    # One past the end and one before the start of a length-4 array.
    # Python lists are checked first, then CuPy arrays.
    bad_indices = ([4], [-5], cupy.array([4]), cupy.array([-5]))
    for raw_idx in bad_indices:
        idx = da.from_array(raw_idx, chunks=1)
        with pytest.raises(IndexError):
            a[idx].compute()
126
127
@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.parametrize(
    "dtype", ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
)
def test_index_with_int_dask_array_dtypes(dtype):
    """Index arrays of each signed/unsigned integer width select the same elements."""
    a = da.from_array(cupy.array([10, 20, 30, 40]), chunks=-1)
    # Build the index on the host (NumPy) first, then on the device (CuPy).
    for array_module in (np, cupy):
        idx = da.from_array(array_module.array([1, 2]).astype(dtype), chunks=1)
        assert_eq(a[idx], cupy.array([20, 30]))
140
141
def test_index_with_int_dask_array_nocompute():
    """Building ``x[idx]`` with a dask-array index must not compute ``idx``.

    The only task in ``idx``'s graph raises NotImplementedError. If indexing
    evaluated it eagerly, the error would escape before ``pytest.raises``.
    """

    def _explode():
        raise NotImplementedError()

    index_graph = {("x", 0): (_explode,)}
    idx = da.Array(index_graph, name="x", chunks=((2,),), dtype=np.int64)
    x = da.arange(5, chunks=-1, like=cupy.array(()))

    # Constructing the indexed array must succeed; only compute() runs the task.
    result = x[idx]
    with pytest.raises(NotImplementedError):
        result.compute()
155