1from __future__ import annotations 2 3import collections 4import itertools 5import operator 6from typing import ( 7 TYPE_CHECKING, 8 Any, 9 Callable, 10 DefaultDict, 11 Dict, 12 Hashable, 13 Iterable, 14 List, 15 Mapping, 16 Sequence, 17 Tuple, 18 Union, 19) 20 21import numpy as np 22 23from .alignment import align 24from .dataarray import DataArray 25from .dataset import Dataset 26 27try: 28 import dask 29 import dask.array 30 from dask.array.utils import meta_from_array 31 from dask.highlevelgraph import HighLevelGraph 32 33except ImportError: 34 pass 35 36 37if TYPE_CHECKING: 38 from .types import T_Xarray 39 40 41def unzip(iterable): 42 return zip(*iterable) 43 44 45def assert_chunks_compatible(a: Dataset, b: Dataset): 46 a = a.unify_chunks() 47 b = b.unify_chunks() 48 49 for dim in set(a.chunks).intersection(set(b.chunks)): 50 if a.chunks[dim] != b.chunks[dim]: 51 raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.") 52 53 54def check_result_variables( 55 result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str 56): 57 58 if kind == "coords": 59 nice_str = "coordinate" 60 elif kind == "data_vars": 61 nice_str = "data" 62 63 # check that coords and data variables are as expected 64 missing = expected[kind] - set(getattr(result, kind)) 65 if missing: 66 raise ValueError( 67 "Result from applying user function does not contain " 68 f"{nice_str} variables {missing}." 69 ) 70 extra = set(getattr(result, kind)) - expected[kind] 71 if extra: 72 raise ValueError( 73 "Result from applying user function has unexpected " 74 f"{nice_str} variables {extra}." 75 ) 76 77 78def dataset_to_dataarray(obj: Dataset) -> DataArray: 79 if not isinstance(obj, Dataset): 80 raise TypeError(f"Expected Dataset, got {type(obj)}") 81 82 if len(obj.data_vars) > 1: 83 raise TypeError( 84 "Trying to convert Dataset with more than one data variable to DataArray" 85 ) 86 87 return next(iter(obj.data_vars.values())) 88 89 90def dataarray_to_dataset(obj: DataArray) -> Dataset: 91 # only using _to_temp_dataset would break 92 # func = lambda x: x.to_dataset() 93 # since that relies on preserving name. 94 if obj.name is None: 95 dataset = obj._to_temp_dataset() 96 else: 97 dataset = obj.to_dataset() 98 return dataset 99 100 101def make_meta(obj): 102 """If obj is a DataArray or Dataset, return a new object of the same type and with 103 the same variables and dtypes, but where all variables have size 0 and numpy 104 backend. 105 If obj is neither a DataArray nor Dataset, return it unaltered. 106 """ 107 if isinstance(obj, DataArray): 108 obj_array = obj 109 obj = dataarray_to_dataset(obj) 110 elif isinstance(obj, Dataset): 111 obj_array = None 112 else: 113 return obj 114 115 meta = Dataset() 116 for name, variable in obj.variables.items(): 117 meta_obj = meta_from_array(variable.data, ndim=variable.ndim) 118 meta[name] = (variable.dims, meta_obj, variable.attrs) 119 meta.attrs = obj.attrs 120 meta = meta.set_coords(obj.coords) 121 122 if obj_array is not None: 123 return dataset_to_dataarray(meta) 124 return meta 125 126 127def infer_template( 128 func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs 129) -> T_Xarray: 130 """Infer return object by running the function on meta objects.""" 131 meta_args = [make_meta(arg) for arg in (obj,) + args] 132 133 try: 134 template = func(*meta_args, **kwargs) 135 except Exception as e: 136 raise Exception( 137 "Cannot infer object returned from running user provided function. " 138 "Please supply the 'template' kwarg to map_blocks." 139 ) from e 140 141 if not isinstance(template, (Dataset, DataArray)): 142 raise TypeError( 143 "Function must return an xarray DataArray or Dataset. Instead it returned " 144 f"{type(template)}" 145 ) 146 147 return template 148 149 150def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: 151 """Map variable name to numpy(-like) data 152 (Dataset.to_dict() is too complicated). 153 """ 154 if isinstance(x, DataArray): 155 x = x._to_temp_dataset() 156 157 return {k: v.data for k, v in x.variables.items()} 158 159 160def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): 161 if dim in chunk_index: 162 which_chunk = chunk_index[dim] 163 return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1]) 164 return slice(None) 165 166 167def map_blocks( 168 func: Callable[..., T_Xarray], 169 obj: Union[DataArray, Dataset], 170 args: Sequence[Any] = (), 171 kwargs: Mapping[str, Any] = None, 172 template: Union[DataArray, Dataset] = None, 173) -> T_Xarray: 174 """Apply a function to each block of a DataArray or Dataset. 175 176 .. warning:: 177 This function is experimental and its signature may change. 178 179 Parameters 180 ---------- 181 func : callable 182 User-provided function that accepts a DataArray or Dataset as its first 183 parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below), 184 corresponding to one chunk along each chunked dimension. ``func`` will be 185 executed as ``func(subset_obj, *subset_args, **kwargs)``. 186 187 This function must return either a single DataArray or a single Dataset. 188 189 This function cannot add a new chunked dimension. 190 obj : DataArray, Dataset 191 Passed to the function as its first argument, one block at a time. 192 args : sequence 193 Passed to func after unpacking and subsetting any xarray objects by blocks. 194 xarray objects in args must be aligned with obj, otherwise an error is raised. 195 kwargs : mapping 196 Passed verbatim to func after unpacking. xarray objects, if any, will not be 197 subset to blocks. Passing dask collections in kwargs is not allowed. 198 template : DataArray or Dataset, optional 199 xarray object representing the final result after compute is called. If not provided, 200 the function will be first run on mocked-up data, that looks like ``obj`` but 201 has sizes 0, to determine properties of the returned object such as dtype, 202 variable names, attributes, new dimensions and new indexes (if any). 203 ``template`` must be provided if the function changes the size of existing dimensions. 204 When provided, ``attrs`` on variables in `template` are copied over to the result. Any 205 ``attrs`` set by ``func`` will be ignored. 206 207 Returns 208 ------- 209 A single DataArray or Dataset with dask backend, reassembled from the outputs of the 210 function. 211 212 Notes 213 ----- 214 This function is designed for when ``func`` needs to manipulate a whole xarray object 215 subset to each block. Each block is loaded into memory. In the more common case where 216 ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. 217 218 If none of the variables in ``obj`` is backed by dask arrays, calling this function is 219 equivalent to calling ``func(obj, *args, **kwargs)``. 220 221 See Also 222 -------- 223 dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks 224 xarray.DataArray.map_blocks 225 226 Examples 227 -------- 228 Calculate an anomaly from climatology using ``.groupby()``. Using 229 ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, 230 its indices, and its methods like ``.groupby()``. 231 232 >>> def calculate_anomaly(da, groupby_type="time.month"): 233 ... gb = da.groupby(groupby_type) 234 ... clim = gb.mean(dim="time") 235 ... return gb - clim 236 ... 237 >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") 238 >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) 239 >>> np.random.seed(123) 240 >>> array = xr.DataArray( 241 ... np.random.rand(len(time)), 242 ... dims=["time"], 243 ... coords={"time": time, "month": month}, 244 ... ).chunk() 245 >>> array.map_blocks(calculate_anomaly, template=array).compute() 246 <xarray.DataArray (time: 24)> 247 array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, 248 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, 249 -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , 250 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, 251 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) 252 Coordinates: 253 * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 254 month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 255 256 Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments 257 to the function being applied in ``xr.map_blocks()``: 258 259 >>> array.map_blocks( 260 ... calculate_anomaly, 261 ... kwargs={"groupby_type": "time.year"}, 262 ... template=array, 263 ... ) # doctest: +ELLIPSIS 264 <xarray.DataArray (time: 24)> 265 dask.array<<this-array>-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> 266 Coordinates: 267 * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 268 month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray> 269 """ 270 271 def _wrapper( 272 func: Callable, 273 args: List, 274 kwargs: dict, 275 arg_is_array: Iterable[bool], 276 expected: dict, 277 ): 278 """ 279 Wrapper function that receives datasets in args; converts to dataarrays when necessary; 280 passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. 281 """ 282 283 converted_args = [ 284 dataset_to_dataarray(arg) if is_array else arg 285 for is_array, arg in zip(arg_is_array, args) 286 ] 287 288 result = func(*converted_args, **kwargs) 289 290 # check all dims are present 291 missing_dimensions = set(expected["shapes"]) - set(result.sizes) 292 if missing_dimensions: 293 raise ValueError( 294 f"Dimensions {missing_dimensions} missing on returned object." 295 ) 296 297 # check that index lengths and values are as expected 298 for name, index in result.xindexes.items(): 299 if name in expected["shapes"]: 300 if result.sizes[name] != expected["shapes"][name]: 301 raise ValueError( 302 f"Received dimension {name!r} of length {result.sizes[name]}. " 303 f"Expected length {expected['shapes'][name]}." 304 ) 305 if name in expected["indexes"]: 306 expected_index = expected["indexes"][name] 307 if not index.equals(expected_index): 308 raise ValueError( 309 f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." 310 ) 311 312 # check that all expected variables were returned 313 check_result_variables(result, expected, "coords") 314 if isinstance(result, Dataset): 315 check_result_variables(result, expected, "data_vars") 316 317 return make_dict(result) 318 319 if template is not None and not isinstance(template, (DataArray, Dataset)): 320 raise TypeError( 321 f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." 322 ) 323 if not isinstance(args, Sequence): 324 raise TypeError("args must be a sequence (for example, a list or tuple).") 325 if kwargs is None: 326 kwargs = {} 327 elif not isinstance(kwargs, Mapping): 328 raise TypeError("kwargs must be a mapping (for example, a dict)") 329 330 for value in kwargs.values(): 331 if dask.is_dask_collection(value): 332 raise TypeError( 333 "Cannot pass dask collections in kwargs yet. Please compute or " 334 "load values before passing to map_blocks." 335 ) 336 337 if not dask.is_dask_collection(obj): 338 return func(obj, *args, **kwargs) 339 340 all_args = [obj] + list(args) 341 is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args] 342 is_array = [isinstance(arg, DataArray) for arg in all_args] 343 344 # there should be a better way to group this. partition? 345 xarray_indices, xarray_objs = unzip( 346 (index, arg) for index, arg in enumerate(all_args) if is_xarray[index] 347 ) 348 others = [ 349 (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index] 350 ] 351 352 # all xarray objects must be aligned. This is consistent with apply_ufunc. 353 aligned = align(*xarray_objs, join="exact") 354 xarray_objs = tuple( 355 dataarray_to_dataset(arg) if is_da else arg 356 for is_da, arg in zip(is_array, aligned) 357 ) 358 359 _, npargs = unzip( 360 sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) 361 ) 362 363 # check that chunk sizes are compatible 364 input_chunks = dict(npargs[0].chunks) 365 input_indexes = dict(npargs[0].xindexes) 366 for arg in xarray_objs[1:]: 367 assert_chunks_compatible(npargs[0], arg) 368 input_chunks.update(arg.chunks) 369 input_indexes.update(arg.xindexes) 370 371 if template is None: 372 # infer template by providing zero-shaped arrays 373 template = infer_template(func, aligned[0], *args, **kwargs) 374 template_indexes = set(template.xindexes) 375 preserved_indexes = template_indexes & set(input_indexes) 376 new_indexes = template_indexes - set(input_indexes) 377 indexes = {dim: input_indexes[dim] for dim in preserved_indexes} 378 indexes.update({k: template.xindexes[k] for k in new_indexes}) 379 output_chunks = { 380 dim: input_chunks[dim] for dim in template.dims if dim in input_chunks 381 } 382 383 else: 384 # template xarray object has been provided with proper sizes and chunk shapes 385 indexes = dict(template.xindexes) 386 if isinstance(template, DataArray): 387 output_chunks = dict( 388 zip(template.dims, template.chunks) # type: ignore[arg-type] 389 ) 390 else: 391 output_chunks = dict(template.chunks) 392 393 for dim in output_chunks: 394 if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): 395 raise ValueError( 396 "map_blocks requires that one block of the input maps to one block of output. " 397 f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " 398 f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or " 399 "fix the provided template." 400 ) 401 402 if isinstance(template, DataArray): 403 result_is_array = True 404 template_name = template.name 405 template = template._to_temp_dataset() 406 elif isinstance(template, Dataset): 407 result_is_array = False 408 else: 409 raise TypeError( 410 f"func output must be DataArray or Dataset; got {type(template)}" 411 ) 412 413 # We're building a new HighLevelGraph hlg. We'll have one new layer 414 # for each variable in the dataset, which is the result of the 415 # func applied to the values. 416 417 graph: Dict[Any, Any] = {} 418 new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) 419 gname = "{}-{}".format( 420 dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) 421 ) 422 423 # map dims to list of chunk indexes 424 ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} 425 # mapping from chunk index to slice bounds 426 input_chunk_bounds = { 427 dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() 428 } 429 output_chunk_bounds = { 430 dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() 431 } 432 433 def subset_dataset_to_block( 434 graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index 435 ): 436 """ 437 Creates a task that subsets an xarray dataset to a block determined by chunk_index. 438 Block extents are determined by input_chunk_bounds. 439 Also subtasks that subset the constituent variables of a dataset. 440 """ 441 442 # this will become [[name1, variable1], 443 # [name2, variable2], 444 # ...] 445 # which is passed to dict and then to Dataset 446 data_vars = [] 447 coords = [] 448 449 chunk_tuple = tuple(chunk_index.values()) 450 for name, variable in dataset.variables.items(): 451 # make a task that creates tuple of (dims, chunk) 452 if dask.is_dask_collection(variable.data): 453 # recursively index into dask_keys nested list to get chunk 454 chunk = variable.__dask_keys__() 455 for dim in variable.dims: 456 chunk = chunk[chunk_index[dim]] 457 458 chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple 459 graph[chunk_variable_task] = ( 460 tuple, 461 [variable.dims, chunk, variable.attrs], 462 ) 463 else: 464 # non-dask array possibly with dimensions chunked on other variables 465 # index into variable appropriately 466 subsetter = { 467 dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) 468 for dim in variable.dims 469 } 470 subset = variable.isel(subsetter) 471 chunk_variable_task = ( 472 f"{name}-{gname}-{dask.base.tokenize(subset)}", 473 ) + chunk_tuple 474 graph[chunk_variable_task] = ( 475 tuple, 476 [subset.dims, subset, subset.attrs], 477 ) 478 479 # this task creates dict mapping variable name to above tuple 480 if name in dataset._coord_names: 481 coords.append([name, chunk_variable_task]) 482 else: 483 data_vars.append([name, chunk_variable_task]) 484 485 return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) 486 487 # iterate over all possible chunk combinations 488 for chunk_tuple in itertools.product(*ichunk.values()): 489 # mapping from dimension name to chunk index 490 chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) 491 492 blocked_args = [ 493 subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index) 494 if isxr 495 else arg 496 for isxr, arg in zip(is_xarray, npargs) 497 ] 498 499 # expected["shapes", "coords", "data_vars", "indexes"] are used to 500 # raise nice error messages in _wrapper 501 expected = {} 502 # input chunk 0 along a dimension maps to output chunk 0 along the same dimension 503 # even if length of dimension is changed by the applied function 504 expected["shapes"] = { 505 k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks 506 } 507 expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] 508 expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] 509 expected["indexes"] = { 510 dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] 511 for dim in indexes 512 } 513 514 from_wrapper = (gname,) + chunk_tuple 515 graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) 516 517 # mapping from variable name to dask graph key 518 var_key_map: Dict[Hashable, str] = {} 519 for name, variable in template.variables.items(): 520 if name in indexes: 521 continue 522 gname_l = f"{name}-{gname}" 523 var_key_map[name] = gname_l 524 525 key: Tuple[Any, ...] = (gname_l,) 526 for dim in variable.dims: 527 if dim in chunk_index: 528 key += (chunk_index[dim],) 529 else: 530 # unchunked dimensions in the input have one chunk in the result 531 # output can have new dimensions with exactly one chunk 532 key += (0,) 533 534 # We're adding multiple new layers to the graph: 535 # The first new layer is the result of the computation on 536 # the array. 537 # Then we add one layer per variable, which extracts the 538 # result for that variable, and depends on just the first new 539 # layer. 540 new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) 541 542 hlg = HighLevelGraph.from_collections( 543 gname, 544 graph, 545 dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], 546 ) 547 548 # This adds in the getitems for each variable in the dataset. 549 hlg = HighLevelGraph( 550 {**hlg.layers, **new_layers}, 551 dependencies={ 552 **hlg.dependencies, 553 **{name: {gname} for name in new_layers.keys()}, 554 }, 555 ) 556 557 # TODO: benbovy - flexible indexes: make it work with custom indexes 558 # this will need to pass both indexes and coords to the Dataset constructor 559 result = Dataset( 560 coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, 561 attrs=template.attrs, 562 ) 563 564 for index in result.xindexes: 565 result[index].attrs = template[index].attrs 566 result[index].encoding = template[index].encoding 567 568 for name, gname_l in var_key_map.items(): 569 dims = template[name].dims 570 var_chunks = [] 571 for dim in dims: 572 if dim in output_chunks: 573 var_chunks.append(output_chunks[dim]) 574 elif dim in result.xindexes: 575 var_chunks.append((result.sizes[dim],)) 576 elif dim in template.dims: 577 # new unindexed dimension 578 var_chunks.append((template.sizes[dim],)) 579 580 data = dask.array.Array( 581 hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype 582 ) 583 result[name] = (dims, data, template[name].attrs) 584 result[name].encoding = template[name].encoding 585 586 result = result.set_coords(template._coord_names) 587 588 if result_is_array: 589 da = dataset_to_dataarray(result) 590 da.name = template_name 591 return da # type: ignore[return-value] 592 return result # type: ignore[return-value] 593