from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Hashable,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

import numpy as np
import pandas as pd

from . import formatting, indexing
from .indexes import Index, Indexes
from .merge import merge_coordinates_without_align, merge_coords
from .utils import Frozen, ReprObject, either_dict_or_kwargs
from .variable import Variable

if TYPE_CHECKING:
    from .dataarray import DataArray
    from .dataset import Dataset

# Used as the key corresponding to a DataArray's variable when converting
# arbitrary DataArray objects to datasets
_THIS_ARRAY = ReprObject("<this-array>")


class Coordinates(Mapping[Any, "DataArray"]):
    """Abstract base for dictionary-like coordinate containers.

    Implements the shared ``Mapping`` behavior and the coordinate merge
    logic; concrete subclasses (``DatasetCoordinates`` and
    ``DataArrayCoordinates``) supply the storage via ``_names``,
    ``variables`` and ``_update_coords``.
    """

    __slots__ = ()

    def __getitem__(self, key: Hashable) -> "DataArray":
        raise NotImplementedError()

    def __setitem__(self, key: Hashable, value: Any) -> None:
        # Delegate single-item assignment to the (subclass-backed) bulk
        # update path so merge/index checking is applied consistently.
        self.update({key: value})

    @property
    def _names(self) -> Set[Hashable]:
        raise NotImplementedError()

    @property
    def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
        raise NotImplementedError()

    @property
    def indexes(self) -> Indexes:
        # pandas indexes of the wrapped Dataset/DataArray
        return self._data.indexes  # type: ignore[attr-defined]

    @property
    def xindexes(self) -> Indexes:
        # xarray Index objects of the wrapped Dataset/DataArray
        return self._data.xindexes  # type: ignore[attr-defined]

    @property
    def variables(self):
        raise NotImplementedError()

    def _update_coords(self, coords, indexes):
        raise NotImplementedError()

    def __iter__(self) -> Iterator[Hashable]:
        # needs to be in the same order as the dataset variables
        for k in self.variables:
            if k in self._names:
                yield k

    def __len__(self) -> int:
        return len(self._names)

    def __contains__(self, key: Hashable) -> bool:
        return key in self._names

    def __repr__(self) -> str:
        return formatting.coords_repr(self)

    def to_dataset(self) -> "Dataset":
        raise NotImplementedError()

    def to_index(self, ordered_dims: Optional[Sequence[Hashable]] = None) -> pd.Index:
        """Convert all index coordinates into a :py:class:`pandas.Index`.

        Parameters
        ----------
        ordered_dims : sequence of hashable, optional
            Possibly reordered version of this object's dimensions indicating
            the order in which dimensions should appear on the result.

        Returns
        -------
        pandas.Index
            Index subclass corresponding to the outer-product of all dimension
            coordinates. This will be a MultiIndex if this object has more
            than one dimension.
        """
        if ordered_dims is None:
            ordered_dims = list(self.dims)
        elif set(ordered_dims) != set(self.dims):
            raise ValueError(
                "ordered_dims must match dims, but does not: "
                "{} vs {}".format(ordered_dims, self.dims)
            )

        if len(ordered_dims) == 0:
            raise ValueError("no valid index for a 0-dimensional object")
        elif len(ordered_dims) == 1:
            (dim,) = ordered_dims
            return self._data.get_index(dim)  # type: ignore[attr-defined]
        else:
            indexes = [
                self._data.get_index(k) for k in ordered_dims  # type: ignore[attr-defined]
            ]

            # compute the sizes of the repeat and tile for the cartesian product
            # (taken from pandas.core.reshape.util)
            index_lengths = np.fromiter(
                (len(index) for index in indexes), dtype=np.intp
            )
            # np.cumprod, not the deprecated np.cumproduct alias (the alias
            # was removed in NumPy 2.0)
            cumprod_lengths = np.cumprod(index_lengths)

            if cumprod_lengths[-1] == 0:
                # if any factor is empty, the cartesian product is empty
                repeat_counts = np.zeros_like(cumprod_lengths)
            else:
                # sizes of the repeats
                repeat_counts = cumprod_lengths[-1] / cumprod_lengths

            # sizes of the tiles; computed unconditionally so the
            # empty-product branch above does not leave it undefined
            # when the loop below reads tile_counts[i]
            tile_counts = np.roll(cumprod_lengths, 1)
            tile_counts[0] = 1

            # loop over the indexes
            # for each MultiIndex or Index compute the cartesian product of the codes
            code_list = []
            level_list = []
            names = []

            for i, index in enumerate(indexes):
                if isinstance(index, pd.MultiIndex):
                    codes, levels = index.codes, index.levels
                else:
                    code, level = pd.factorize(index)
                    codes = [code]
                    levels = [level]

                # compute the cartesian product
                code_list += [
                    np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
                    for code in codes
                ]
                level_list += levels
                names += index.names

            return pd.MultiIndex(level_list, code_list, names=names)

    def update(self, other: Mapping[Any, Any]) -> None:
        """Update this container in-place from another mapping of coords."""
        # another Coordinates/Dataset exposes .variables; plain mappings are
        # used as-is
        other_vars = getattr(other, "variables", other)
        coords, indexes = merge_coords(
            [self.variables, other_vars], priority_arg=1, indexes=self.xindexes
        )
        self._update_coords(coords, indexes)

    def _merge_raw(self, other, reflexive):
        """For use with binary arithmetic."""
        if other is None:
            variables = dict(self.variables)
            indexes = dict(self.xindexes)
        else:
            # reflexive arithmetic (other OP self) gives other priority-order
            coord_list = [self, other] if not reflexive else [other, self]
            variables, indexes = merge_coordinates_without_align(coord_list)
        return variables, indexes

    @contextmanager
    def _merge_inplace(self, other):
        """For use with in-place binary arithmetic."""
        if other is None:
            yield
        else:
            # don't include indexes in prioritized, because we didn't align
            # first and we want indexes to be checked
            prioritized = {
                k: (v, None)
                for k, v in self.variables.items()
                if k not in self.xindexes
            }
            variables, indexes = merge_coordinates_without_align(
                [self, other], prioritized
            )
            # apply the merged result only after the in-place operation ran
            yield
            self._update_coords(variables, indexes)

    def merge(self, other: "Coordinates") -> "Dataset":
        """Merge two sets of coordinates to create a new Dataset

        The method implements the logic used for joining coordinates in the
        result of a binary operation performed on xarray objects:

        - If two index coordinates conflict (are not equal), an exception is
          raised. You must align your data before passing it to this method.
        - If an index coordinate and a non-index coordinate conflict, the non-
          index coordinate is dropped.
        - If two non-index coordinates conflict, both are dropped.

        Parameters
        ----------
        other : DatasetCoordinates or DataArrayCoordinates
            The coordinates from another dataset or data array.

        Returns
        -------
        merged : Dataset
            A new Dataset with merged coordinates.
        """
        from .dataset import Dataset

        if other is None:
            return self.to_dataset()

        if not isinstance(other, Coordinates):
            other = Dataset(coords=other).coords

        coords, indexes = merge_coordinates_without_align([self, other])
        coord_names = set(coords)
        return Dataset._construct_direct(
            variables=coords, coord_names=coord_names, indexes=indexes
        )


class DatasetCoordinates(Coordinates):
    """Dictionary like container for Dataset coordinates.

    Essentially an immutable dictionary with keys given by the array's
    dimensions and the values given by the corresponding xarray.Coordinate
    objects.
    """

    __slots__ = ("_data",)

    def __init__(self, dataset: "Dataset"):
        self._data = dataset

    @property
    def _names(self) -> Set[Hashable]:
        return self._data._coord_names

    @property
    def dims(self) -> Mapping[Hashable, int]:
        return self._data.dims

    @property
    def variables(self) -> Mapping[Hashable, Variable]:
        # read-only view restricted to coordinate variables
        return Frozen(
            {k: v for k, v in self._data.variables.items() if k in self._names}
        )

    def __getitem__(self, key: Hashable) -> "DataArray":
        # data variables are reachable through self._data[key] but are not
        # coordinates, so hide them here
        if key in self._data.data_vars:
            raise KeyError(key)
        return cast("DataArray", self._data[key])

    def to_dataset(self) -> "Dataset":
        """Convert these coordinates into a new Dataset"""

        names = [name for name in self._data._variables if name in self._names]
        return self._data._copy_listed(names)

    def _update_coords(
        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
    ) -> None:
        """Replace coordinates (and dims/indexes) on the wrapped Dataset."""
        from .dataset import calculate_dimensions

        variables = self._data._variables.copy()
        variables.update(coords)

        # check for inconsistent state *before* modifying anything in-place
        dims = calculate_dimensions(variables)
        new_coord_names = set(coords)
        for dim, size in dims.items():
            if dim in variables:
                new_coord_names.add(dim)

        self._data._variables = variables
        self._data._coord_names.update(new_coord_names)
        self._data._dims = dims

        # TODO(shoyer): once ._indexes is always populated by a dict, modify
        # it to update inplace instead.
        original_indexes = dict(self._data.xindexes)
        original_indexes.update(indexes)
        self._data._indexes = original_indexes

    def __delitem__(self, key: Hashable) -> None:
        if key in self:
            del self._data[key]
        else:
            raise KeyError(f"{key!r} is not a coordinate variable.")

    def _ipython_key_completions_(self):
        """Provide method for the key-autocompletions in IPython."""
        return [
            key
            for key in self._data._ipython_key_completions_()
            if key not in self._data.data_vars
        ]


class DataArrayCoordinates(Coordinates):
    """Dictionary like container for DataArray coordinates.

    Essentially a dict with keys given by the array's
    dimensions and the values given by corresponding DataArray objects.
    """

    __slots__ = ("_data",)

    def __init__(self, dataarray: "DataArray"):
        self._data = dataarray

    @property
    def dims(self) -> Tuple[Hashable, ...]:
        return self._data.dims

    @property
    def _names(self) -> Set[Hashable]:
        return set(self._data._coords)

    def __getitem__(self, key: Hashable) -> "DataArray":
        return self._data._getitem_coord(key)

    def _update_coords(
        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
    ) -> None:
        """Replace coordinates (and indexes) on the wrapped DataArray."""
        from .dataset import calculate_dimensions

        # validate against the DataArray's own variable before mutating:
        # new coords may not introduce dimensions the array doesn't have
        coords_plus_data = coords.copy()
        coords_plus_data[_THIS_ARRAY] = self._data.variable
        dims = calculate_dimensions(coords_plus_data)
        if not set(dims) <= set(self.dims):
            raise ValueError(
                "cannot add coordinates with new dimensions to a DataArray"
            )
        self._data._coords = coords

        # TODO(shoyer): once ._indexes is always populated by a dict, modify
        # it to update inplace instead.
        original_indexes = dict(self._data.xindexes)
        original_indexes.update(indexes)
        self._data._indexes = original_indexes

    @property
    def variables(self):
        return Frozen(self._data._coords)

    def to_dataset(self) -> "Dataset":
        """Convert these coordinates into a new Dataset"""
        from .dataset import Dataset

        # shallow-copy the variables so the new Dataset doesn't alias ours
        coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()}
        return Dataset._construct_direct(coords, set(coords))

    def __delitem__(self, key: Hashable) -> None:
        if key not in self:
            raise KeyError(f"{key!r} is not a coordinate variable.")

        del self._data._coords[key]
        if self._data._indexes is not None and key in self._data._indexes:
            del self._data._indexes[key]

    def _ipython_key_completions_(self):
        """Provide method for the key-autocompletions in IPython."""
        return self._data._ipython_key_completions_()


def assert_coordinate_consistent(
    obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable]
) -> None:
    """Make sure the dimension coordinate of obj is consistent with coords.

    obj: DataArray or Dataset
    coords: Dict-like of variables
    """
    for k in obj.dims:
        # make sure there are no conflict in dimension coordinates
        if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable):
            raise IndexError(
                f"dimension coordinate {k!r} conflicts between "
                f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}"
            )


def remap_label_indexers(
    obj: Union["DataArray", "Dataset"],
    indexers: Mapping[Any, Any] = None,
    method: str = None,
    tolerance=None,
    **indexers_kwargs: Any,
) -> Tuple[dict, dict]:  # TODO more precise return type after annotations in indexing
    """Remap indexers from obj.coords.
    If indexer is an instance of DataArray and it has coordinate, then this coordinate
    will be attached to pos_indexers.

    Returns
    -------
    pos_indexers: Same type of indexers.
        np.ndarray or Variable or DataArray
    new_indexes: mapping of new dimensional-coordinate.
    """
    from .dataarray import DataArray

    indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers")

    # strip xarray wrappers down to raw data for the low-level remapping
    v_indexers = {
        k: v.variable.data if isinstance(v, DataArray) else v
        for k, v in indexers.items()
    }

    pos_indexers, new_indexes = indexing.remap_label_indexers(
        obj, v_indexers, method=method, tolerance=tolerance
    )
    # attach indexer's coordinate to pos_indexers
    for k, v in indexers.items():
        if isinstance(v, Variable):
            pos_indexers[k] = Variable(v.dims, pos_indexers[k])
        elif isinstance(v, DataArray):
            # drop coordinates found in indexers since .sel() already
            # ensures alignments
            coords = {k: var for k, var in v._coords.items() if k not in indexers}
            pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims)
    return pos_indexers, new_indexes