from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Hashable,
    Iterator,
    Mapping,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

import numpy as np
import pandas as pd

from . import formatting, indexing
from .indexes import Index, Indexes
from .merge import merge_coordinates_without_align, merge_coords
from .utils import Frozen, ReprObject, either_dict_or_kwargs
from .variable import Variable

if TYPE_CHECKING:
    from .dataarray import DataArray
    from .dataset import Dataset

# Used as the key corresponding to a DataArray's variable when converting
# arbitrary DataArray objects to datasets
_THIS_ARRAY = ReprObject("<this-array>")


class Coordinates(Mapping[Any, "DataArray"]):
    __slots__ = ()

    def __getitem__(self, key: Hashable) -> "DataArray":
        raise NotImplementedError()

    def __setitem__(self, key: Hashable, value: Any) -> None:
        self.update({key: value})

    @property
    def _names(self) -> Set[Hashable]:
        raise NotImplementedError()

    @property
    def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
        raise NotImplementedError()

    @property
    def indexes(self) -> Indexes:
        return self._data.indexes  # type: ignore[attr-defined]

    @property
    def xindexes(self) -> Indexes:
        return self._data.xindexes  # type: ignore[attr-defined]

    @property
    def variables(self):
        raise NotImplementedError()

    def _update_coords(self, coords, indexes):
        raise NotImplementedError()

    def __iter__(self) -> Iterator["Hashable"]:
        # needs to be in the same order as the dataset variables
        for k in self.variables:
            if k in self._names:
                yield k

    def __len__(self) -> int:
        return len(self._names)

    def __contains__(self, key: Hashable) -> bool:
        return key in self._names

    def __repr__(self) -> str:
        return formatting.coords_repr(self)

    def to_dataset(self) -> "Dataset":
        raise NotImplementedError()

    def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
        """Convert all index coordinates into a :py:class:`pandas.Index`.

        Parameters
        ----------
        ordered_dims : sequence of hashable, optional
            Possibly reordered version of this object's dimensions indicating
            the order in which dimensions should appear on the result.

        Returns
        -------
        pandas.Index
            Index subclass corresponding to the outer-product of all dimension
            coordinates. This will be a MultiIndex if this object has more
            than one dimension.
        """
        if ordered_dims is None:
            ordered_dims = list(self.dims)
        elif set(ordered_dims) != set(self.dims):
            raise ValueError(
                "ordered_dims must match dims, but does not: "
                "{} vs {}".format(ordered_dims, self.dims)
            )

        if len(ordered_dims) == 0:
            raise ValueError("no valid index for a 0-dimensional object")
        elif len(ordered_dims) == 1:
            (dim,) = ordered_dims
            return self._data.get_index(dim)  # type: ignore[attr-defined]
        else:
            indexes = [
                self._data.get_index(k) for k in ordered_dims  # type: ignore[attr-defined]
            ]

            # compute the sizes of the repeat and tile for the cartesian product
            # (taken from pandas.core.reshape.util)
            index_lengths = np.fromiter(
                (len(index) for index in indexes), dtype=np.intp
            )
            cumprod_lengths = np.cumprod(index_lengths)

            if cumprod_lengths[-1] == 0:
                # if any factor is empty, the cartesian product is empty
                repeat_counts = np.zeros_like(cumprod_lengths)

            else:
                # sizes of the repeats
                repeat_counts = cumprod_lengths[-1] / cumprod_lengths
            # sizes of the tiles
            tile_counts = np.roll(cumprod_lengths, 1)
            tile_counts[0] = 1
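
            # Illustrative worked example (comment only, not executed): for
            # index lengths [2, 3], cumprod_lengths is [2, 6], so
            # repeat_counts is [3, 1] and tile_counts is [1, 2]; the first
            # index's codes are each repeated 3 times and tiled once, while
            # the second's are repeated once and tiled twice, yielding all
            # 2 * 3 = 6 combinations of the cartesian product.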

            # loop over the indexes
            # for each MultiIndex or Index compute the cartesian product of the codes

            code_list = []
            level_list = []
            names = []

            for i, index in enumerate(indexes):
                if isinstance(index, pd.MultiIndex):
                    codes, levels = index.codes, index.levels
                else:
                    code, level = pd.factorize(index)
                    codes = [code]
                    levels = [level]

                # compute the cartesian product
                code_list += [
                    np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
                    for code in codes
                ]
                level_list += levels
                names += index.names

        return pd.MultiIndex(level_list, code_list, names=names)
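
    # Usage sketch for ``Coordinates.to_index`` (illustrative only): for a
    # Dataset ``ds`` with dimension coordinates "x" and "y", something like
    #     ds.coords.to_index(["y", "x"])
    # would return a pandas.MultiIndex built from the cartesian product of the
    # two dimension coordinates, with "y" as the first (slowest-varying) level.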

    def update(self, other: Mapping[Any, Any]) -> None:
        other_vars = getattr(other, "variables", other)
        coords, indexes = merge_coords(
            [self.variables, other_vars], priority_arg=1, indexes=self.xindexes
        )
        self._update_coords(coords, indexes)
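
    # Usage sketch for ``Coordinates.update`` (illustrative only): given a
    # Dataset ``ds`` with a dimension "x" of size 3, something like
    #     ds.coords.update({"y": ("x", [10, 20, 30])})
    # might add "y" as a new coordinate along "x"; entries in the mapping take
    # priority over existing coordinates when names conflict.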

    def _merge_raw(self, other, reflexive):
        """For use with binary arithmetic."""
        if other is None:
            variables = dict(self.variables)
            indexes = dict(self.xindexes)
        else:
            coord_list = [self, other] if not reflexive else [other, self]
            variables, indexes = merge_coordinates_without_align(coord_list)
        return variables, indexes

    @contextmanager
    def _merge_inplace(self, other):
        """For use with in-place binary arithmetic."""
        if other is None:
            yield
        else:
            # don't include indexes in prioritized, because we didn't align
            # first and we want indexes to be checked
            prioritized = {
                k: (v, None)
                for k, v in self.variables.items()
                if k not in self.xindexes
            }
            variables, indexes = merge_coordinates_without_align(
                [self, other], prioritized
            )
            yield
            self._update_coords(variables, indexes)

    def merge(self, other: "Coordinates") -> "Dataset":
        """Merge two sets of coordinates to create a new Dataset

        The method implements the logic used for joining coordinates in the
        result of a binary operation performed on xarray objects:

        - If two index coordinates conflict (are not equal), an exception is
          raised. You must align your data before passing it to this method.
        - If an index coordinate and a non-index coordinate conflict, the non-
          index coordinate is dropped.
        - If two non-index coordinates conflict, both are dropped.

        Parameters
        ----------
        other : DatasetCoordinates or DataArrayCoordinates
            The coordinates from another dataset or data array.

        Returns
        -------
        merged : Dataset
            A new Dataset with merged coordinates.
        """
        from .dataset import Dataset

        if other is None:
            return self.to_dataset()

        if not isinstance(other, Coordinates):
            other = Dataset(coords=other).coords

        coords, indexes = merge_coordinates_without_align([self, other])
        coord_names = set(coords)
        return Dataset._construct_direct(
            variables=coords, coord_names=coord_names, indexes=indexes
        )
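
    # Usage sketch for ``Coordinates.merge`` (illustrative only): for two
    # objects sharing an aligned "x" index coordinate, something like
    #     ds.coords.merge(other.coords)
    # would return a new Dataset holding the combined coordinates, dropping
    # conflicting non-index coordinates as described above.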


class DatasetCoordinates(Coordinates):
    """Dictionary like container for Dataset coordinates.

    Essentially a dict with keys given by the dataset's coordinate names and
    values given by the corresponding DataArray objects.
    """

    __slots__ = ("_data",)

    def __init__(self, dataset: "Dataset"):
        self._data = dataset

    @property
    def _names(self) -> Set[Hashable]:
        return self._data._coord_names

    @property
    def dims(self) -> Mapping[Hashable, int]:
        return self._data.dims

    @property
    def variables(self) -> Mapping[Hashable, Variable]:
        return Frozen(
            {k: v for k, v in self._data.variables.items() if k in self._names}
        )

    def __getitem__(self, key: Hashable) -> "DataArray":
        if key in self._data.data_vars:
            raise KeyError(key)
        return cast("DataArray", self._data[key])

    def to_dataset(self) -> "Dataset":
        """Convert these coordinates into a new Dataset"""

        names = [name for name in self._data._variables if name in self._names]
        return self._data._copy_listed(names)
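
    # Usage sketch (illustrative only): ``ds.coords.to_dataset()`` would
    # return a new Dataset containing only the coordinate variables of ``ds``,
    # with its data variables dropped.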

    def _update_coords(
        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
    ) -> None:
        from .dataset import calculate_dimensions

        variables = self._data._variables.copy()
        variables.update(coords)

        # check for inconsistent state *before* modifying anything in-place
        dims = calculate_dimensions(variables)
        new_coord_names = set(coords)
        for dim, size in dims.items():
            if dim in variables:
                new_coord_names.add(dim)

        self._data._variables = variables
        self._data._coord_names.update(new_coord_names)
        self._data._dims = dims

        # TODO(shoyer): once ._indexes is always populated by a dict, modify
        # it to update inplace instead.
        original_indexes = dict(self._data.xindexes)
        original_indexes.update(indexes)
        self._data._indexes = original_indexes

    def __delitem__(self, key: Hashable) -> None:
        if key in self:
            del self._data[key]
        else:
            raise KeyError(f"{key!r} is not a coordinate variable.")

    def _ipython_key_completions_(self):
        """Provide method for the key-autocompletions in IPython."""
        return [
            key
            for key in self._data._ipython_key_completions_()
            if key not in self._data.data_vars
        ]
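
    # Usage sketch (illustrative only): ``ds.coords`` returns a
    # DatasetCoordinates instance, so for a Dataset ``ds`` with a coordinate
    # "x", expressions like ``ds.coords["x"]`` or ``del ds.coords["x"]``
    # read or modify the coordinate variables of ``ds`` in place.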


class DataArrayCoordinates(Coordinates):
    """Dictionary like container for DataArray coordinates.

    Essentially a dict with keys given by the array's coordinate names and
    values given by the corresponding DataArray objects.
    """

    __slots__ = ("_data",)

    def __init__(self, dataarray: "DataArray"):
        self._data = dataarray

    @property
    def dims(self) -> Tuple[Hashable, ...]:
        return self._data.dims

    @property
    def _names(self) -> Set[Hashable]:
        return set(self._data._coords)

    def __getitem__(self, key: Hashable) -> "DataArray":
        return self._data._getitem_coord(key)

    def _update_coords(
        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
    ) -> None:
        from .dataset import calculate_dimensions

        coords_plus_data = coords.copy()
        coords_plus_data[_THIS_ARRAY] = self._data.variable
        dims = calculate_dimensions(coords_plus_data)
        if not set(dims) <= set(self.dims):
            raise ValueError(
                "cannot add coordinates with new dimensions to a DataArray"
            )
        self._data._coords = coords

        # TODO(shoyer): once ._indexes is always populated by a dict, modify
        # it to update inplace instead.
        original_indexes = dict(self._data.xindexes)
        original_indexes.update(indexes)
        self._data._indexes = original_indexes

    @property
    def variables(self):
        return Frozen(self._data._coords)

    def to_dataset(self) -> "Dataset":
        from .dataset import Dataset

        coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()}
        return Dataset._construct_direct(coords, set(coords))

    def __delitem__(self, key: Hashable) -> None:
        if key not in self:
            raise KeyError(f"{key!r} is not a coordinate variable.")

        del self._data._coords[key]
        if self._data._indexes is not None and key in self._data._indexes:
            del self._data._indexes[key]

    def _ipython_key_completions_(self):
        """Provide method for the key-autocompletions in IPython."""
        return self._data._ipython_key_completions_()
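
    # Usage sketch (illustrative only): ``da.coords`` returns a
    # DataArrayCoordinates instance; ``da.coords["x"]`` yields the "x"
    # coordinate as a DataArray, and ``del da.coords["x"]`` removes both the
    # coordinate variable and, if present, its index.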


def assert_coordinate_consistent(
    obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable]
) -> None:
    """Make sure the dimension coordinates of obj are consistent with coords.

    obj: DataArray or Dataset
    coords: dict-like of variables
    """
    for k in obj.dims:
        # make sure there are no conflicts in dimension coordinates
        if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable):
            raise IndexError(
                f"dimension coordinate {k!r} conflicts between "
                f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}"
            )
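
# Usage sketch for ``assert_coordinate_consistent`` (illustrative only): given
# a DataArray ``da`` indexed by "x" and a dict ``coords`` whose "x" variable
# holds different values, something like
#     assert_coordinate_consistent(da, coords)
# would raise an IndexError pointing at the conflicting dimension coordinate.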


def remap_label_indexers(
    obj: Union["DataArray", "Dataset"],
    indexers: Mapping[Any, Any] = None,
    method: str = None,
    tolerance=None,
    **indexers_kwargs: Any,
) -> Tuple[dict, dict]:  # TODO more precise return type after annotations in indexing
    """Remap indexers from obj.coords.

    If an indexer is an instance of DataArray and it has its own coordinates,
    those coordinates are attached to the corresponding entry in pos_indexers.

    Returns
    -------
    pos_indexers : dict
        Positional indexers with the same types of values as the input
        indexers (np.ndarray, Variable or DataArray).
    new_indexes : dict
        Mapping of new dimension coordinates produced by the label lookup.
    """
    from .dataarray import DataArray

    indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers")

    v_indexers = {
        k: v.variable.data if isinstance(v, DataArray) else v
        for k, v in indexers.items()
    }

    pos_indexers, new_indexes = indexing.remap_label_indexers(
        obj, v_indexers, method=method, tolerance=tolerance
    )
    # attach indexer's coordinate to pos_indexers
    for k, v in indexers.items():
        if isinstance(v, Variable):
            pos_indexers[k] = Variable(v.dims, pos_indexers[k])
        elif isinstance(v, DataArray):
            # drop coordinates found in indexers since .sel() already
            # ensures alignment
            coords = {k: var for k, var in v._coords.items() if k not in indexers}
            pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims)
    return pos_indexers, new_indexes
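
# Usage sketch for ``remap_label_indexers`` (illustrative only): for a Dataset
# ``ds`` with an "x" index, something like
#     pos, idx = remap_label_indexers(ds, {"x": [10, 20]}, method="nearest")
# would translate the label-based indexer into positional indexers (``pos``)
# plus any updated dimension coordinates (``idx``) produced by the lookup.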