"""asv benchmarks for xarray positional indexing (``isel``) and item assignment."""

import os

import numpy as np
import pandas as pd

import xarray as xr

from . import parameterized, randint, randn, requires_dask

nx = 2000
ny = 1000
nt = 500


def _slice_len(size, slc):
    """Return how many elements ``slc`` selects from an axis of length ``size``."""
    return len(range(*slc.indices(size)))


# Orthogonal indexers built only from slices and scalars.
# NOTE: "2slicess" is a historical misspelling; it is kept because the key is
# used as the benchmark parameter value, and renaming it would break continuity
# with previously recorded results.
basic_indexes = {
    "1slice": {"x": slice(0, 3)},
    "1slice-1scalar": {"x": 0, "y": slice(None, None, 3)},
    "2slicess-1scalar": {"x": slice(3, -3, 3), "y": 1, "t": slice(None, -3, 3)},
}

# Values assigned to var1[x, y] for each basic indexer. The shapes follow from
# the indexer: dimensions indexed by a scalar are dropped. Lengths are derived
# from the slices themselves, so they stay correct if nx/ny change.
basic_assignment_values = {
    "1slice": xr.DataArray(randn((3, ny), frac_nan=0.1), dims=["x", "y"]),
    "1slice-1scalar": xr.DataArray(
        randn(_slice_len(ny, basic_indexes["1slice-1scalar"]["y"]), frac_nan=0.1),
        dims=["y"],
    ),
    "2slicess-1scalar": xr.DataArray(
        randn(_slice_len(nx, basic_indexes["2slicess-1scalar"]["x"]), frac_nan=0.1),
        dims=["x"],
    ),
}

# Orthogonal ("outer") indexers using integer arrays, optionally mixed with a scalar.
outer_indexes = {
    "1d": {"x": randint(0, nx, 400)},
    "2d": {"x": randint(0, nx, 500), "y": randint(0, ny, 400)},
    "2d-1scalar": {"x": randint(0, nx, 100), "y": 1, "t": randint(0, nt, 400)},
}

# Values assigned to var1[x, y] for each outer indexer; shapes match the indexer lengths.
outer_assignment_values = {
    "1d": xr.DataArray(randn((400, ny), frac_nan=0.1), dims=["x", "y"]),
    "2d": xr.DataArray(randn((500, 400), frac_nan=0.1), dims=["x", "y"]),
    "2d-1scalar": xr.DataArray(randn(100, frac_nan=0.1), dims=["x"]),
}

# Vectorized (pointwise) indexers: DataArray indexers that share new dimensions
# ("a", and "b" for the 2-D case), so the indexed dimensions are replaced by those.
vectorized_indexes = {
    "1-1d": {"x": xr.DataArray(randint(0, nx, 400), dims="a")},
    "2-1d": {
        "x": xr.DataArray(randint(0, nx, 400), dims="a"),
        "y": xr.DataArray(randint(0, ny, 400), dims="a"),
    },
    "3-2d": {
        "x": xr.DataArray(randint(0, nx, 400).reshape(4, 100), dims=["a", "b"]),
        "y": xr.DataArray(randint(0, ny, 400).reshape(4, 100), dims=["a", "b"]),
        "t": xr.DataArray(randint(0, nt, 400).reshape(4, 100), dims=["a", "b"]),
    },
}

# Values assigned through the vectorized indexers. They carry random coordinates
# on the new dimensions.
vectorized_assignment_values = {
    "1-1d": xr.DataArray(randn((400, ny)), dims=["a", "y"], coords={"a": randn(400)}),
    "2-1d": xr.DataArray(randn(400), dims=["a"], coords={"a": randn(400)}),
    "3-2d": xr.DataArray(
        randn((4, 100)), dims=["a", "b"], coords={"a": randn(4), "b": randn(100)}
    ),
}


class Base:
    """Shared fixture: an in-memory Dataset with variables over x, y and t."""

    def setup(self, key):
        # ``key`` is the benchmark parameter; the fixture is the same for every key.
        self.ds = xr.Dataset(
            {
                "var1": (("x", "y"), randn((nx, ny), frac_nan=0.1)),
                "var2": (("x", "t"), randn((nx, nt))),
                "var3": (("t",), randn(nt)),
            },
            coords={
                "x": np.arange(nx),
                "y": np.linspace(0, 1, ny),
                "t": pd.date_range("1970-01-01", periods=nt, freq="D"),
                "x_coords": ("x", np.linspace(1.1, 2.1, nx)),
            },
        )


class Indexing(Base):
    """Time ``Dataset.isel`` followed by ``.load()`` for each indexer family."""

    @parameterized(["key"], [list(basic_indexes.keys())])
    def time_indexing_basic(self, key):
        self.ds.isel(**basic_indexes[key]).load()

    @parameterized(["key"], [list(outer_indexes.keys())])
    def time_indexing_outer(self, key):
        self.ds.isel(**outer_indexes[key]).load()

    @parameterized(["key"], [list(vectorized_indexes.keys())])
    def time_indexing_vectorized(self, key):
        self.ds.isel(**vectorized_indexes[key]).load()


class Assignment(Base):
    """Time positional ``__setitem__`` on ``var1``.

    Only the "x" and "y" entries of each indexer are used, because var1 has no
    "t" dimension; a missing entry becomes a full slice.
    """

    @parameterized(["key"], [list(basic_indexes.keys())])
    def time_assignment_basic(self, key):
        ind = basic_indexes[key]
        val = basic_assignment_values[key]
        self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val

    @parameterized(["key"], [list(outer_indexes.keys())])
    def time_assignment_outer(self, key):
        ind = outer_indexes[key]
        val = outer_assignment_values[key]
        self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val

    @parameterized(["key"], [list(vectorized_indexes.keys())])
    def time_assignment_vectorized(self, key):
        ind = vectorized_indexes[key]
        val = vectorized_assignment_values[key]
        self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val


class IndexingDask(Indexing):
    """Same indexing benchmarks as ``Indexing``, on a dask-chunked copy of the fixture."""

    def setup(self, key):
        requires_dask()
        super().setup(key)
        self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50})


class BooleanIndexing:
    """Boolean-mask ``isel`` on a long 1-D axis."""

    # https://github.com/pydata/xarray/issues/2227
    def setup(self):
        self.ds = xr.Dataset(
            {"a": ("time", np.arange(10_000_000))},
            coords={"time": np.arange(10_000_000)},
        )
        self.time_filter = self.ds.time > 50_000

    def time_indexing(self):
        self.ds.isel(time=self.time_filter)


class HugeAxisSmallSliceIndexing:
    """Take a small slice from a file-backed dataset that has a very long axis."""

    # https://github.com/pydata/xarray/pull/4560
    def setup(self):
        # The netCDF file is written to the working directory once and reused on
        # later runs.
        self.filepath = "test_indexing_huge_axis_small_slice.nc"
        if not os.path.isfile(self.filepath):
            xr.Dataset(
                {"a": ("x", np.arange(10_000_000))},
                coords={"x": np.arange(10_000_000)},
            ).to_netcdf(self.filepath, format="NETCDF4")

        self.ds = xr.open_dataset(self.filepath)

    def time_indexing(self):
        self.ds.isel(x=slice(100))

    def teardown(self):
        # asv only recognises ``setup``/``teardown`` as per-benchmark hooks. The
        # original ``cleanup`` name was never called, so the file stayed open.
        self.ds.close()

    # Keep the old method name available for any direct callers.
    cleanup = teardown