1import os
2
3import numpy as np
4import pandas as pd
5
6import xarray as xr
7
8from . import parameterized, randint, randn, requires_dask
9
# Lengths of the "x", "y" and "t" dimensions used by the benchmark datasets.
nx, ny, nt = 2000, 1000, 500
13
# Slice/scalar indexers per dimension, keyed by the benchmark parameter name.
basic_indexes = {
    "1slice": dict(x=slice(0, 3)),
    "1slice-1scalar": dict(x=0, y=slice(None, None, 3)),
    "2slicess-1scalar": dict(x=slice(3, -3, 3), y=1, t=slice(None, -3, 3)),
}
19
# Values assigned by the basic-indexing assignment benchmarks, keyed like
# ``basic_indexes``.  Each sliced length is derived from the matching indexer:
# ``len(range(n)[s])`` is the number of positions slice ``s`` selects on an
# axis of length ``n``.  This keeps the shapes correct for any dimension size
# (the previous ``int(ny / 3) + 1`` was only right when ``ny`` is not a
# multiple of 3) and avoids allocating a throwaway array just to measure it.
basic_assignment_values = {
    "1slice": xr.DataArray(
        randn((len(range(nx)[basic_indexes["1slice"]["x"]]), ny), frac_nan=0.1),
        dims=["x", "y"],
    ),
    "1slice-1scalar": xr.DataArray(
        randn(len(range(ny)[basic_indexes["1slice-1scalar"]["y"]]), frac_nan=0.1),
        dims=["y"],
    ),
    "2slicess-1scalar": xr.DataArray(
        randn(len(range(nx)[basic_indexes["2slicess-1scalar"]["x"]]), frac_nan=0.1),
        dims=["x"],
    ),
}
27
# Indexers built from ``randint`` draws (plus one scalar), keyed by the
# benchmark parameter name.
outer_indexes = {
    "1d": dict(x=randint(0, nx, 400)),
    "2d": dict(x=randint(0, nx, 500), y=randint(0, ny, 400)),
    "2d-1scalar": dict(x=randint(0, nx, 100), y=1, t=randint(0, nt, 400)),
}
33
# Values assigned by the outer-indexing assignment benchmarks, keyed like
# ``outer_indexes``.  Each row is (key, shape passed to ``randn``, dims).
outer_assignment_values = {
    name: xr.DataArray(randn(shape, frac_nan=0.1), dims=dims)
    for name, shape, dims in (
        ("1d", (400, ny), ["x", "y"]),
        ("2d", (500, 400), ["x", "y"]),
        ("2d-1scalar", 100, ["x"]),
    )
}
39
# DataArray indexers keyed by the benchmark parameter name.  The "1-1d" and
# "2-1d" indexers share the single dimension "a"; the "3-2d" indexers are
# reshaped to (4, 100) over dimensions ("a", "b").
vectorized_indexes = {
    "1-1d": dict(x=xr.DataArray(randint(0, nx, 400), dims="a")),
    "2-1d": {
        dim: xr.DataArray(randint(0, size, 400), dims="a")
        for dim, size in (("x", nx), ("y", ny))
    },
    "3-2d": {
        dim: xr.DataArray(randint(0, size, 400).reshape(4, 100), dims=["a", "b"])
        for dim, size in (("x", nx), ("y", ny), ("t", nt))
    },
}
52
# Values assigned by the vectorized-indexing assignment benchmarks, keyed like
# ``vectorized_indexes``.  Each row is (key, data shape, dims, coordinate
# lengths); the data is generated before the coordinates for every entry.
vectorized_assignment_values = {
    name: xr.DataArray(
        randn(shape),
        dims=dims,
        coords={dim: randn(length) for dim, length in coord_lengths.items()},
    )
    for name, shape, dims, coord_lengths in (
        ("1-1d", (400, ny), ["a", "y"], {"a": 400}),
        ("2-1d", 400, ["a"], {"a": 400}),
        ("3-2d", (4, 100), ["a", "b"], {"a": 4, "b": 100}),
    )
}
60
61
class Base:
    """Fixture that builds the in-memory dataset stored on ``self.ds``."""

    def setup(self, key):
        # ``key`` is accepted but not used when building the dataset.
        variables = {
            "var1": (("x", "y"), randn((nx, ny), frac_nan=0.1)),
            "var2": (("x", "t"), randn((nx, nt))),
            "var3": (("t",), randn(nt)),
        }
        coordinates = dict(
            x=np.arange(nx),
            y=np.linspace(0, 1, ny),
            t=pd.date_range("1970-01-01", periods=nt, freq="D"),
            # Extra coordinate defined along the "x" dimension.
            x_coords=("x", np.linspace(1.1, 2.1, nx)),
        )
        self.ds = xr.Dataset(variables, coords=coordinates)
77
78
class Indexing(Base):
    """Time ``self.ds.isel(...)`` followed by ``.load()`` for each indexer table."""

    @parameterized(["key"], [list(basic_indexes)])
    def time_indexing_basic(self, key):
        indexers = basic_indexes[key]
        self.ds.isel(**indexers).load()

    @parameterized(["key"], [list(outer_indexes)])
    def time_indexing_outer(self, key):
        indexers = outer_indexes[key]
        self.ds.isel(**indexers).load()

    @parameterized(["key"], [list(vectorized_indexes)])
    def time_indexing_vectorized(self, key):
        indexers = vectorized_indexes[key]
        self.ds.isel(**indexers).load()
91
92
class Assignment(Base):
    """Time item assignment into ``self.ds["var1"]`` for each indexer table.

    Only the "x" and "y" entries of an indexer are used; a dimension without
    an entry is selected in full via ``slice(None)``.
    """

    @parameterized(["key"], [list(basic_indexes)])
    def time_assignment_basic(self, key):
        indexer = basic_indexes[key]
        new_values = basic_assignment_values[key]
        x_sel = indexer.get("x", slice(None))
        y_sel = indexer.get("y", slice(None))
        self.ds["var1"][x_sel, y_sel] = new_values

    @parameterized(["key"], [list(outer_indexes)])
    def time_assignment_outer(self, key):
        indexer = outer_indexes[key]
        new_values = outer_assignment_values[key]
        x_sel = indexer.get("x", slice(None))
        y_sel = indexer.get("y", slice(None))
        self.ds["var1"][x_sel, y_sel] = new_values

    @parameterized(["key"], [list(vectorized_indexes)])
    def time_assignment_vectorized(self, key):
        indexer = vectorized_indexes[key]
        new_values = vectorized_assignment_values[key]
        x_sel = indexer.get("x", slice(None))
        y_sel = indexer.get("y", slice(None))
        self.ds["var1"][x_sel, y_sel] = new_values
111
112
class IndexingDask(Indexing):
    """Run the ``Indexing`` benchmarks against a chunked copy of the dataset."""

    def setup(self, key):
        requires_dask()
        super().setup(key)
        # Chunk sizes along each dimension of the dataset.
        chunk_sizes = {"x": 100, "y": 50, "t": 50}
        self.ds = self.ds.chunk(chunk_sizes)
118
119
class BooleanIndexing:
    # See https://github.com/pydata/xarray/issues/2227
    def setup(self):
        # Two separate arrays are created: one for the data, one for the
        # "time" coordinate.
        dataset = xr.Dataset(
            {"a": ("time", np.arange(10_000_000))},
            coords={"time": np.arange(10_000_000)},
        )
        self.ds = dataset
        # Result of comparing the "time" coordinate against 50_000.
        self.time_filter = dataset.time > 50_000

    def time_indexing(self):
        mask = self.time_filter
        self.ds.isel(time=mask)
131
132
class HugeAxisSmallSliceIndexing:
    """Time taking a 100-element slice from a file-backed 10M-element axis.

    See https://github.com/pydata/xarray/pull/4560
    """

    def setup(self):
        """Write the netCDF fixture if it is missing, then open it as ``self.ds``."""
        self.filepath = "test_indexing_huge_axis_small_slice.nc"
        # The file is only written when absent, so later setups reuse it.
        if not os.path.isfile(self.filepath):
            xr.Dataset(
                {"a": ("x", np.arange(10_000_000))},
                coords={"x": np.arange(10_000_000)},
            ).to_netcdf(self.filepath, format="NETCDF4")

        self.ds = xr.open_dataset(self.filepath)

    def time_indexing(self):
        self.ds.isel(x=slice(100))

    def teardown(self):
        """Close the dataset opened by ``setup``.

        asv's per-benchmark hook is named ``teardown``; a method called
        ``cleanup`` is never invoked by the harness, which left the file
        handle open after every run.
        """
        ds = getattr(self, "ds", None)
        if ds is not None:
            ds.close()

    # Kept so that any code calling the old method name keeps working.
    cleanup = teardown
150