1# Copyright (c) 2018,2019 MetPy Developers.
2# Distributed under the terms of the BSD 3-Clause License.
3# SPDX-License-Identifier: BSD-3-Clause
4"""Provide accessors to enhance interoperability between xarray and MetPy.
5
6MetPy relies upon the `CF Conventions <http://cfconventions.org/>`_. to provide helpful
7attributes and methods on xarray DataArrays and Dataset for working with
8coordinate-related metadata. Also included are several attributes and methods for unit
9operations.
10
11These accessors will be activated with any import of MetPy. Do not use the
12``MetPyDataArrayAccessor`` or ``MetPyDatasetAccessor`` classes directly, instead, utilize the
13applicable properties and methods via the ``.metpy`` attribute on an xarray DataArray or
14Dataset.
15
16See Also: :doc:`xarray with MetPy Tutorial </tutorials/xarray_tutorial>`.
17"""
18import contextlib
19import functools
20from inspect import signature
21from itertools import chain
22import logging
23import re
24import warnings
25
26import numpy as np
27from pyproj import CRS, Proj
28import xarray as xr
29
30from ._vendor.xarray import either_dict_or_kwargs, expanded_indexer, is_dict_like
31from .units import DimensionalityError, UndefinedUnitError, units
32
33__all__ = ('MetPyDataArrayAccessor', 'MetPyDatasetAccessor', 'grid_deltas_from_dataarray')
34metpy_axes = ['time', 'vertical', 'y', 'latitude', 'x', 'longitude']
35
36# Define the criteria for coordinate matches
37coordinate_criteria = {
38    'standard_name': {
39        'time': 'time',
40        'vertical': {'air_pressure', 'height', 'geopotential_height', 'altitude',
41                     'model_level_number', 'atmosphere_ln_pressure_coordinate',
42                     'atmosphere_sigma_coordinate',
43                     'atmosphere_hybrid_sigma_pressure_coordinate',
44                     'atmosphere_hybrid_height_coordinate', 'atmosphere_sleve_coordinate',
45                     'height_above_geopotential_datum', 'height_above_reference_ellipsoid',
46                     'height_above_mean_sea_level'},
47        'y': 'projection_y_coordinate',
48        'latitude': 'latitude',
49        'x': 'projection_x_coordinate',
50        'longitude': 'longitude'
51    },
52    '_CoordinateAxisType': {
53        'time': 'Time',
54        'vertical': {'GeoZ', 'Height', 'Pressure'},
55        'y': 'GeoY',
56        'latitude': 'Lat',
57        'x': 'GeoX',
58        'longitude': 'Lon'
59    },
60    'axis': {
61        'time': 'T',
62        'vertical': 'Z',
63        'y': 'Y',
64        'x': 'X'
65    },
66    'positive': {
67        'vertical': {'up', 'down'}
68    },
69    'units': {
70        'vertical': {
71            'match': 'dimensionality',
72            'units': 'Pa'
73        },
74        'latitude': {
75            'match': 'name',
76            'units': {'degree_north', 'degree_N', 'degreeN', 'degrees_north', 'degrees_N',
77                      'degreesN'}
78        },
79        'longitude': {
80            'match': 'name',
81            'units': {'degree_east', 'degree_E', 'degreeE', 'degrees_east', 'degrees_E',
82                      'degreesE'}
83        },
84    },
85    'regular_expression': {
86        'time': r'time[0-9]*',
87        'vertical': (r'(lv_|bottom_top|sigma|h(ei)?ght|altitude|depth|isobaric|pres|'
88                     r'isotherm)[a-z_]*[0-9]*'),
89        'y': r'y',
90        'latitude': r'x?lat[a-z0-9]*',
91        'x': r'x',
92        'longitude': r'x?lon[a-z0-9]*'
93    }
94}
95
96log = logging.getLogger(__name__)
97
98_axis_identifier_error = ('Given axis is not valid. Must be an axis number, a dimension '
99                          'coordinate name, or a standard axis type.')
100
101
102@xr.register_dataarray_accessor('metpy')
103class MetPyDataArrayAccessor:
104    r"""Provide custom attributes and methods on xarray DataArrays for MetPy functionality.
105
106    This accessor provides several convenient attributes and methods through the `.metpy`
107    attribute on a DataArray. For example, MetPy can identify the coordinate corresponding
108    to a particular axis (given sufficient metadata):
109
110        >>> import xarray as xr
111        >>> from metpy.units import units
112        >>> temperature = xr.DataArray([[0, 1], [2, 3]] * units.degC, dims=('lat', 'lon'),
113        ...                            coords={'lat': [40, 41], 'lon': [-105, -104]})
114        >>> temperature.metpy.x
115        <xarray.DataArray 'lon' (lon: 2)>
116        array([-105, -104])
117        Coordinates:
118          * lon      (lon) int64 -105 -104
119        Attributes:
120            _metpy_axis:  x,longitude
121
122    """
123
124    def __init__(self, data_array):  # noqa: D107
125        # Initialize accessor with a DataArray. (Do not use directly).
126        self._data_array = data_array
127
128    @property
129    def units(self):
130        """Return the units of this DataArray as a `pint.Unit`."""
131        if isinstance(self._data_array.variable._data, units.Quantity):
132            return self._data_array.variable._data.units
133        else:
134            return units.parse_units(self._data_array.attrs.get('units', 'dimensionless'))
135
136    @property
137    def magnitude(self):
138        """Return the magnitude of the data values of this DataArray (i.e., without units)."""
139        if isinstance(self._data_array.data, units.Quantity):
140            return self._data_array.data.magnitude
141        else:
142            return self._data_array.data
143
144    @property
145    def unit_array(self):
146        """Return the data values of this DataArray as a `pint.Quantity`.
147
148        Notes
149        -----
150        If not already existing as a `pint.Quantity` or Dask array, the data of this DataArray
151        will be loaded into memory by this operation. Do not utilize on moderate- to
152        large-sized remote datasets before subsetting!
153        """
154        if isinstance(self._data_array.data, units.Quantity):
155            return self._data_array.data
156        else:
157            return units.Quantity(self._data_array.data, self.units)
158
159    def convert_units(self, units):
160        """Return new DataArray with values converted to different units.
161
162        Notes
163        -----
164        Any cached/lazy-loaded data (except that in a Dask array) will be loaded into memory
165        by this operation. Do not utilize on moderate- to large-sized remote datasets before
166        subsetting!
167
168        See Also
169        --------
170        convert_coordinate_units
171        """
172        return self.quantify().copy(data=self.unit_array.to(units))
173
174    def convert_coordinate_units(self, coord, units):
175        """Return new DataArray with specified coordinate converted to different units.
176
177        This operation differs from ``.convert_units`` since xarray coordinate indexes do not
178        yet support unit-aware arrays (even though unit-aware *data* arrays are).
179
180        Notes
181        -----
182        Any cached/lazy-loaded coordinate data (except that in a Dask array) will be loaded
183        into memory by this operation.
184
185        See Also
186        --------
187        convert_units
188        """
189        new_coord_var = self._data_array[coord].copy(
190            data=self._data_array[coord].metpy.unit_array.m_as(units)
191        )
192        new_coord_var.attrs['units'] = str(units)
193        return self._data_array.assign_coords(coords={coord: new_coord_var})
194
195    def quantify(self):
196        """Return a new DataArray with the data converted to a `pint.Quantity`.
197
198        Notes
199        -----
200        Any cached/lazy-loaded data (except that in a Dask array) will be loaded into memory
201        by this operation. Do not utilize on moderate- to large-sized remote datasets before
202        subsetting!
203        """
204        if (
205            not isinstance(self._data_array.data, units.Quantity)
206            and np.issubdtype(self._data_array.data.dtype, np.number)
207        ):
208            # Only quantify if not already quantified and is quantifiable
209            quantified_dataarray = self._data_array.copy(data=self.unit_array)
210            if 'units' in quantified_dataarray.attrs:
211                del quantified_dataarray.attrs['units']
212        else:
213            quantified_dataarray = self._data_array
214        return quantified_dataarray
215
216    def dequantify(self):
217        """Return a new DataArray with the data as magnitude and the units as an attribute."""
218        if isinstance(self._data_array.data, units.Quantity):
219            # Only dequantify if quantified
220            dequantified_dataarray = self._data_array.copy(
221                data=self._data_array.data.magnitude
222            )
223            dequantified_dataarray.attrs['units'] = str(self.units)
224        else:
225            dequantified_dataarray = self._data_array
226        return dequantified_dataarray
227
228    @property
229    def crs(self):
230        """Return the coordinate reference system (CRS) as a CFProjection object."""
231        if 'metpy_crs' in self._data_array.coords:
232            return self._data_array.coords['metpy_crs'].item()
233        raise AttributeError('crs attribute is not available.')
234
235    @property
236    def cartopy_crs(self):
237        """Return the coordinate reference system (CRS) as a cartopy object."""
238        return self.crs.to_cartopy()
239
240    @property
241    def cartopy_globe(self):
242        """Return the globe belonging to the coordinate reference system (CRS)."""
243        return self.crs.cartopy_globe
244
245    @property
246    def cartopy_geodetic(self):
247        """Return the cartopy Geodetic CRS associated with the native CRS globe."""
248        return self.crs.cartopy_geodetic
249
250    @property
251    def pyproj_crs(self):
252        """Return the coordinate reference system (CRS) as a pyproj object."""
253        return self.crs.to_pyproj()
254
255    def _fixup_coordinate_map(self, coord_map):
256        """Ensure sure we have coordinate variables in map, not coordinate names."""
257        new_coord_map = {}
258        for axis in coord_map:
259            if coord_map[axis] is not None and not isinstance(coord_map[axis], xr.DataArray):
260                new_coord_map[axis] = self._data_array[coord_map[axis]]
261            else:
262                new_coord_map[axis] = coord_map[axis]
263
264        return new_coord_map
265
266    def assign_coordinates(self, coordinates):
267        """Return new DataArray with given coordinates assigned to the given MetPy axis types.
268
269        Parameters
270        ----------
271        coordinates : dict or None
272            Mapping from axis types ('time', 'vertical', 'y', 'latitude', 'x', 'longitude') to
273            coordinates of this DataArray. Coordinates can either be specified directly or by
274            their name. If ``None``, clears the `_metpy_axis` attribute on all coordinates,
275            which will trigger reparsing of all coordinates on next access.
276
277        """
278        coord_updates = {}
279        if coordinates:
280            # Assign the _metpy_axis attributes according to supplied mapping
281            coordinates = self._fixup_coordinate_map(coordinates)
282            for axis in coordinates:
283                if coordinates[axis] is not None:
284                    coord_updates[coordinates[axis].name] = (
285                        coordinates[axis].assign_attrs(
286                            _assign_axis(coordinates[axis].attrs.copy(), axis)
287                        )
288                    )
289        else:
290            # Clear _metpy_axis attribute on all coordinates
291            for coord_name, coord_var in self._data_array.coords.items():
292                coord_updates[coord_name] = coord_var.copy(deep=False)
293
294                # Some coordinates remained linked in old form under other coordinates. We
295                # need to remove from these.
296                sub_coords = coord_updates[coord_name].coords
297                for sub_coord in sub_coords:
298                    coord_updates[coord_name].coords[sub_coord].attrs.pop('_metpy_axis', None)
299
300                # Now we can remove the _metpy_axis attr from the coordinate itself
301                coord_updates[coord_name].attrs.pop('_metpy_axis', None)
302
303        return self._data_array.assign_coords(coord_updates)
304
305    def _generate_coordinate_map(self):
306        """Generate a coordinate map via CF conventions and other methods."""
307        coords = self._data_array.coords.values()
308        # Parse all the coordinates, attempting to identify x, longitude, y, latitude,
309        # vertical, time
310        coord_lists = {'time': [], 'vertical': [], 'y': [], 'latitude': [], 'x': [],
311                       'longitude': []}
312        for coord_var in coords:
313            # Identify the coordinate type using check_axis helper
314            for axis in coord_lists:
315                if check_axis(coord_var, axis):
316                    coord_lists[axis].append(coord_var)
317
318        # Fill in x/y with longitude/latitude if x/y not otherwise present
319        for geometric, graticule in (('y', 'latitude'), ('x', 'longitude')):
320            if len(coord_lists[geometric]) == 0 and len(coord_lists[graticule]) > 0:
321                coord_lists[geometric] = coord_lists[graticule]
322
323        # Filter out multidimensional coordinates where not allowed
324        require_1d_coord = ['time', 'vertical', 'y', 'x']
325        for axis in require_1d_coord:
326            coord_lists[axis] = [coord for coord in coord_lists[axis] if coord.ndim <= 1]
327
328        # Resolve any coordinate type duplication
329        axis_duplicates = [axis for axis in coord_lists if len(coord_lists[axis]) > 1]
330        for axis in axis_duplicates:
331            self._resolve_axis_duplicates(axis, coord_lists)
332
333        # Collapse the coord_lists to a coord_map
334        return {axis: (coord_lists[axis][0] if len(coord_lists[axis]) > 0 else None)
335                for axis in coord_lists}
336
337    def _resolve_axis_duplicates(self, axis, coord_lists):
338        """Handle coordinate duplication for an axis type if it arises."""
339        # If one and only one of the possible axes is a dimension, use it
340        dimension_coords = [coord_var for coord_var in coord_lists[axis] if
341                            coord_var.name in coord_var.dims]
342        if len(dimension_coords) == 1:
343            coord_lists[axis] = dimension_coords
344            return
345
346        # Ambiguous axis, raise warning and do not parse
347        varname = (' "' + self._data_array.name + '"'
348                   if self._data_array.name is not None else '')
349        warnings.warn('More than one ' + axis + ' coordinate present for variable'
350                      + varname + '.')
351        coord_lists[axis] = []
352
353    def _metpy_axis_search(self, metpy_axis):
354        """Search for cached _metpy_axis attribute on the coordinates, otherwise parse."""
355        # Search for coord with proper _metpy_axis
356        coords = self._data_array.coords.values()
357        for coord_var in coords:
358            if metpy_axis in coord_var.attrs.get('_metpy_axis', '').split(','):
359                return coord_var
360
361        # Opportunistically parse all coordinates, and assign if not already assigned
362        # Note: since this is generally called by way of the coordinate properties, to cache
363        # the coordinate parsing results in coord_map on the coordinates means modifying the
364        # DataArray in-place (an exception to the usual behavior of MetPy's accessor). This is
365        # considered safe because it only effects the "_metpy_axis" attribute on the
366        # coordinates, and nothing else.
367        coord_map = self._generate_coordinate_map()
368        for axis, coord_var in coord_map.items():
369            if (coord_var is not None
370                and not any(axis in coord.attrs.get('_metpy_axis', '').split(',')
371                            for coord in coords)):
372
373                _assign_axis(coord_var.attrs, axis)
374
375        # Return parsed result (can be None if none found)
376        return coord_map[metpy_axis]
377
378    def _axis(self, axis):
379        """Return the coordinate variable corresponding to the given individual axis type."""
380        if axis in metpy_axes:
381            coord_var = self._metpy_axis_search(axis)
382            if coord_var is not None:
383                return coord_var
384            else:
385                raise AttributeError(axis + ' attribute is not available.')
386        else:
387            raise AttributeError("'" + axis + "' is not an interpretable axis.")
388
389    def coordinates(self, *args):
390        """Return the coordinate variables corresponding to the given axes types.
391
392        Parameters
393        ----------
394        args : str
395            Strings describing the axes type(s) to obtain. Currently understood types are
396            'time', 'vertical', 'y', 'latitude', 'x', and 'longitude'.
397
398        Notes
399        -----
400        This method is designed for use with multiple coordinates; it returns a generator. To
401        access a single coordinate, use the appropriate attribute on the accessor, or use tuple
402        unpacking.
403
404        """
405        for arg in args:
406            yield self._axis(arg)
407
408    @property
409    def time(self):
410        """Return the time coordinate."""
411        return self._axis('time')
412
413    @property
414    def vertical(self):
415        """Return the vertical coordinate."""
416        return self._axis('vertical')
417
418    @property
419    def y(self):
420        """Return the y coordinate."""
421        return self._axis('y')
422
423    @property
424    def latitude(self):
425        """Return the latitude coordinate (if it exists)."""
426        return self._axis('latitude')
427
428    @property
429    def x(self):
430        """Return the x coordinate."""
431        return self._axis('x')
432
433    @property
434    def longitude(self):
435        """Return the longitude coordinate (if it exists)."""
436        return self._axis('longitude')
437
438    def coordinates_identical(self, other):
439        """Return whether or not the coordinates of other match this DataArray's."""
440        return (len(self._data_array.coords) == len(other.coords)
441                and all(coord_name in other.coords and other[coord_name].identical(coord_var)
442                        for coord_name, coord_var in self._data_array.coords.items()))
443
444    @property
445    def time_deltas(self):
446        """Return the time difference of the data in seconds (to microsecond precision)."""
447        us_diffs = np.diff(self._data_array.values).astype('timedelta64[us]').astype('int64')
448        return units.Quantity(us_diffs / 1e6, 's')
449
450    def find_axis_name(self, axis):
451        """Return the name of the axis corresponding to the given identifier.
452
453        Parameters
454        ----------
455        axis : str or int
456            Identifier for an axis. Can be an axis number (integer), dimension coordinate
457            name (string) or a standard axis type (string).
458
459        """
460        if isinstance(axis, int):
461            # If an integer, use the corresponding dimension
462            return self._data_array.dims[axis]
463        elif axis not in self._data_array.dims and axis in metpy_axes:
464            # If not a dimension name itself, but a valid axis type, get the name of the
465            # coordinate corresponding to that axis type
466            return self._axis(axis).name
467        elif axis in self._data_array.dims and axis in self._data_array.coords:
468            # If this is a dimension coordinate name, use it directly
469            return axis
470        else:
471            # Otherwise, not valid
472            raise ValueError(_axis_identifier_error)
473
474    def find_axis_number(self, axis):
475        """Return the dimension number of the axis corresponding to the given identifier.
476
477        Parameters
478        ----------
479        axis : str or int
480            Identifier for an axis. Can be an axis number (integer), dimension coordinate
481            name (string) or a standard axis type (string).
482
483        """
484        if isinstance(axis, int):
485            # If an integer, use it directly
486            return axis
487        elif axis in self._data_array.dims:
488            # Simply index into dims
489            return self._data_array.dims.index(axis)
490        elif axis in metpy_axes:
491            # If not a dimension name itself, but a valid axis type, first determine if this
492            # standard axis type is present as a dimension coordinate
493            try:
494                name = self._axis(axis).name
495                return self._data_array.dims.index(name)
496            except AttributeError as exc:
497                # If x, y, or vertical requested, but not available, attempt to interpret dim
498                # names using regular expressions from coordinate parsing to allow for
499                # multidimensional lat/lon without y/x dimension coordinates, and basic
500                # vertical dim recognition
501                if axis in ('vertical', 'y', 'x'):
502                    for i, dim in enumerate(self._data_array.dims):
503                        if re.match(coordinate_criteria['regular_expression'][axis],
504                                    dim.lower()):
505                            return i
506                raise exc
507            except ValueError:
508                # Intercept ValueError when axis type found but not dimension coordinate
509                raise AttributeError(f'Requested {axis} dimension coordinate but {axis} '
510                                     f'coordinate {name} is not a dimension')
511        else:
512            # Otherwise, not valid
513            raise ValueError(_axis_identifier_error)
514
515    class _LocIndexer:
516        """Provide the unit-wrapped .loc indexer for data arrays."""
517
518        def __init__(self, data_array):
519            self.data_array = data_array
520
521        def expand(self, key):
522            """Parse key using xarray utils to ensure we have dimension names."""
523            if not is_dict_like(key):
524                labels = expanded_indexer(key, self.data_array.ndim)
525                key = dict(zip(self.data_array.dims, labels))
526            return key
527
528        def __getitem__(self, key):
529            key = _reassign_quantity_indexer(self.data_array, self.expand(key))
530            return self.data_array.loc[key]
531
532        def __setitem__(self, key, value):
533            key = _reassign_quantity_indexer(self.data_array, self.expand(key))
534            self.data_array.loc[key] = value
535
536    @property
537    def loc(self):
538        """Wrap DataArray.loc with an indexer to handle units and coordinate types."""
539        return self._LocIndexer(self._data_array)
540
541    def sel(self, indexers=None, method=None, tolerance=None, drop=False, **indexers_kwargs):
542        """Wrap DataArray.sel to handle units and coordinate types."""
543        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
544        indexers = _reassign_quantity_indexer(self._data_array, indexers)
545        return self._data_array.sel(indexers, method=method, tolerance=tolerance, drop=drop)
546
547    def assign_crs(self, cf_attributes=None, **kwargs):
548        """Assign a CRS to this DataArray based on CF projection attributes.
549
550        Specify a coordinate reference system/grid mapping following the Climate and
551        Forecasting (CF) conventions (see `Appendix F: Grid Mappings
552        <http://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/cf-conventions.html#appendix-grid-mappings>`_
553        ) and store in the ``metpy_crs`` coordinate.
554
555        This method is only required if your data do not come from a dataset that follows CF
556        conventions with respect to grid mappings (in which case the ``.parse_cf`` method will
557        parse for the CRS metadata automatically).
558
559        Parameters
560        ----------
561        cf_attributes : dict, optional
562            Dictionary of CF projection attributes
563        kwargs : optional
564            CF projection attributes specified as keyword arguments
565
566        Returns
567        -------
568        `xarray.DataArray`
569            New xarray DataArray with CRS coordinate assigned
570
571        Notes
572        -----
573        CF projection arguments should be supplied as a dictionary or collection of kwargs,
574        but not both.
575
576        """
577        return _assign_crs(self._data_array, cf_attributes, kwargs)
578
579    def assign_latitude_longitude(self, force=False):
580        """Assign 2D latitude and longitude coordinates derived from 1D y and x coordinates.
581
582        Parameters
583        ----------
584        force : bool, optional
585            If force is true, overwrite latitude and longitude coordinates if they exist,
586            otherwise, raise a RuntimeError if such coordinates exist.
587
588        Returns
589        -------
590        `xarray.DataArray`
591            New xarray DataArray with latitude and longtiude auxiliary coordinates assigned.
592
593        Notes
594        -----
595        A valid CRS coordinate must be present (as assigned by ``.parse_cf`` or
596        ``.assign_crs``). PyProj is used for the coordinate transformations.
597
598        """
599        # Check for existing latitude and longitude coords
600        if (not force and (self._metpy_axis_search('latitude') is not None
601                           or self._metpy_axis_search('longitude'))):
602            raise RuntimeError('Latitude/longitude coordinate(s) are present. If you wish to '
603                               'overwrite these, specify force=True.')
604
605        # Build new latitude and longitude DataArrays
606        latitude, longitude = _build_latitude_longitude(self._data_array)
607
608        # Assign new coordinates, refresh MetPy's parsed axis attribute, and return result
609        new_dataarray = self._data_array.assign_coords(latitude=latitude, longitude=longitude)
610        return new_dataarray.metpy.assign_coordinates(None)
611
612    def assign_y_x(self, force=False, tolerance=None):
613        """Assign 1D y and x dimension coordinates derived from 2D latitude and longitude.
614
615        Parameters
616        ----------
617        force : bool, optional
618            If force is true, overwrite y and x coordinates if they exist, otherwise, raise a
619            RuntimeError if such coordinates exist.
620        tolerance : `pint.Quantity`
621            Maximum range tolerated when collapsing projected y and x coordinates from 2D to
622            1D. Defaults to 1 meter.
623
624        Returns
625        -------
626        `xarray.DataArray`
627            New xarray DataArray with y and x dimension coordinates assigned.
628
629        Notes
630        -----
631        A valid CRS coordinate must be present (as assigned by ``.parse_cf`` or
632        ``.assign_crs``) for the y/x projection space. PyProj is used for the coordinate
633        transformations.
634
635        """
636        # Check for existing latitude and longitude coords
637        if (not force and (self._metpy_axis_search('y') is not None
638                           or self._metpy_axis_search('x'))):
639            raise RuntimeError('y/x coordinate(s) are present. If you wish to overwrite '
640                               'these, specify force=True.')
641
642        # Build new y and x DataArrays
643        y, x = _build_y_x(self._data_array, tolerance)
644
645        # Assign new coordinates, refresh MetPy's parsed axis attribute, and return result
646        new_dataarray = self._data_array.assign_coords(**{y.name: y, x.name: x})
647        return new_dataarray.metpy.assign_coordinates(None)
648
649
650@xr.register_dataset_accessor('metpy')
651class MetPyDatasetAccessor:
652    """Provide custom attributes and methods on XArray Datasets for MetPy functionality.
653
654    This accessor provides parsing of CF grid mapping metadata, generating missing coordinate
655    types, and unit-/coordinate-type-aware operations.
656
657        >>> import xarray as xr
658        >>> from metpy.cbook import get_test_data
659        >>> ds = xr.open_dataset(get_test_data('narr_example.nc', False)).metpy.parse_cf()
660        >>> print(ds['metpy_crs'].item())
661        Projection: lambert_conformal_conic
662
663    """
664
665    def __init__(self, dataset):  # noqa: D107
666        # Initialize accessor with a Dataset. (Do not use directly).
667        self._dataset = dataset
668
669    def parse_cf(self, varname=None, coordinates=None):
670        """Parse dataset for coordinate system metadata according to CF conventions.
671
672        Interpret the grid mapping metadata in the dataset according to the Climate and
673        Forecasting (CF) conventions (see `Appendix F: Grid Mappings
674        <http://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/cf-conventions.html#appendix-grid-mappings>`_
675        ) and store in the ``metpy_crs`` coordinate. Also, gives option to manually specify
676        coordinate types with the ``coordinates`` keyword argument.
677
678        If your dataset does not follow the CF conventions, you can manually supply the grid
679        mapping metadata with the ``.assign_crs`` method.
680
681        This method operates on individual data variables within the dataset, so do not be
682        surprised if information not associated with individual data variables is not
683        preserved.
684
685        Parameters
686        ----------
687        varname : str or iterable of str, optional
688            Name of the variable(s) to extract from the dataset while parsing for CF metadata.
689            Defaults to all variables.
690        coordinates : dict, optional
691            Dictionary mapping CF axis types to coordinates of the variable(s). Only specify
692            if you wish to override MetPy's automatic parsing of some axis type(s).
693
694        Returns
695        -------
696        `xarray.DataArray` or `xarray.Dataset`
697            Parsed DataArray (if varname is a string) or Dataset
698
699        See Also
700        --------
701        assign_crs
702
703        """
704        from .plots.mapping import CFProjection
705
706        if varname is None:
707            # If no varname is given, parse all variables in the dataset
708            varname = list(self._dataset.data_vars)
709
710        if np.iterable(varname) and not isinstance(varname, str):
711            # If non-string iterable is given, apply recursively across the varnames
712            subset = xr.merge([self.parse_cf(single_varname, coordinates=coordinates)
713                               for single_varname in varname])
714            subset.attrs = self._dataset.attrs
715            return subset
716
717        var = self._dataset[varname]
718
719        # Check for crs conflict
720        if varname == 'metpy_crs':
721            warnings.warn(
722                'Attempting to parse metpy_crs as a data variable. Unexpected merge conflicts '
723                'may occur.'
724            )
725        elif 'metpy_crs' in var.coords and (var.coords['metpy_crs'].size > 1 or not isinstance(
726                var.coords['metpy_crs'].item(), CFProjection)):
727            warnings.warn(
728                'metpy_crs already present as a non-CFProjection coordinate. Unexpected '
729                'merge conflicts may occur.'
730            )
731
732        # Assign coordinates if the coordinates argument is given
733        if coordinates is not None:
734            var = var.metpy.assign_coordinates(coordinates)
735
736        # Attempt to build the crs coordinate
737        crs = None
738        if 'grid_mapping' in var.attrs:
739            # Use given CF grid_mapping
740            proj_name = var.attrs['grid_mapping']
741            try:
742                proj_var = self._dataset.variables[proj_name]
743            except KeyError:
744                log.warning(
745                    'Could not find variable corresponding to the value of '
746                    f'grid_mapping: {proj_name}')
747            else:
748                crs = CFProjection(proj_var.attrs)
749
750        if crs is None:
751            # This isn't a lat or lon coordinate itself, so determine if we need to fall back
752            # to creating a latitude_longitude CRS. We do so if there exists valid *at most
753            # 1D* coordinates for latitude and longitude (usually dimension coordinates, but
754            # that is not strictly required, for example, for DSG's). What is required is that
755            # x == latitude and y == latitude (so that all assumptions about grid coordinates
756            # and CRS line up).
757            try:
758                latitude, y, longitude, x = var.metpy.coordinates(
759                    'latitude',
760                    'y',
761                    'longitude',
762                    'x'
763                )
764            except AttributeError:
765                # This means that we don't even have sufficient coordinates, so skip
766                pass
767            else:
768                if latitude.identical(y) and longitude.identical(x):
769                    crs = CFProjection({'grid_mapping_name': 'latitude_longitude'})
770                    log.warning('Found valid latitude/longitude coordinates, assuming '
771                                'latitude_longitude for projection grid_mapping variable')
772
773        # Rebuild the coordinates of the dataarray, and return quantified DataArray
774        var = self._rebuild_coords(var, crs)
775        if crs is not None:
776            var = var.assign_coords(coords={'metpy_crs': crs})
777        return var
778
779    def _rebuild_coords(self, var, crs):
780        """Clean up the units on the coordinate variables."""
781        for coord_name, coord_var in var.coords.items():
782            if (check_axis(coord_var, 'x', 'y')
783                    and not check_axis(coord_var, 'longitude', 'latitude')):
784                try:
785                    var = var.metpy.convert_coordinate_units(coord_name, 'meters')
786                except DimensionalityError:
787                    # Radians! Attempt to use perspective point height conversion
788                    if crs is not None:
789                        height = crs['perspective_point_height']
790                        new_coord_var = coord_var.copy(
791                            data=(
792                                coord_var.metpy.unit_array
793                                * units.Quantity(height, 'meter')
794                            ).m_as('meter')
795                        )
796                        new_coord_var.attrs['units'] = 'meter'
797                        var = var.assign_coords(coords={coord_name: new_coord_var})
798
799        return var
800
801    class _LocIndexer:
802        """Provide the unit-wrapped .loc indexer for datasets."""
803
804        def __init__(self, dataset):
805            self.dataset = dataset
806
807        def __getitem__(self, key):
808            parsed_key = _reassign_quantity_indexer(self.dataset, key)
809            return self.dataset.loc[parsed_key]
810
811    @property
812    def loc(self):
813        """Wrap Dataset.loc with an indexer to handle units and coordinate types."""
814        return self._LocIndexer(self._dataset)
815
816    def sel(self, indexers=None, method=None, tolerance=None, drop=False, **indexers_kwargs):
817        """Wrap Dataset.sel to handle units."""
818        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
819        indexers = _reassign_quantity_indexer(self._dataset, indexers)
820        return self._dataset.sel(indexers, method=method, tolerance=tolerance, drop=drop)
821
822    def assign_crs(self, cf_attributes=None, **kwargs):
823        """Assign a CRS to this Dataset based on CF projection attributes.
824
825        Specify a coordinate reference system/grid mapping following the Climate and
826        Forecasting (CF) conventions (see `Appendix F: Grid Mappings
827        <http://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/cf-conventions.html#appendix-grid-mappings>`_
828        ) and store in the ``metpy_crs`` coordinate.
829
830        This method is only required if your dataset does not already follow CF conventions
831        with respect to grid mappings (in which case the ``.parse_cf`` method will parse for
832        the CRS metadata automatically).
833
834        Parameters
835        ----------
836        cf_attributes : dict, optional
837            Dictionary of CF projection attributes
838        kwargs : optional
839            CF projection attributes specified as keyword arguments
840
841        Returns
842        -------
843        `xarray.Dataset`
844            New xarray Dataset with CRS coordinate assigned
845
846        Notes
847        -----
848        CF projection arguments should be supplied as a dictionary or collection of kwargs,
849        but not both.
850
851        See Also
852        --------
853        parse_cf
854
855        """
856        return _assign_crs(self._dataset, cf_attributes, kwargs)
857
858    def assign_latitude_longitude(self, force=False):
859        """Assign latitude and longitude coordinates derived from y and x coordinates.
860
861        Parameters
862        ----------
863        force : bool, optional
864            If force is true, overwrite latitude and longitude coordinates if they exist,
865            otherwise, raise a RuntimeError if such coordinates exist.
866
867        Returns
868        -------
869        `xarray.Dataset`
870            New xarray Dataset with latitude and longitude coordinates assigned to all
871            variables with y and x coordinates.
872
873        Notes
874        -----
875        A valid CRS coordinate must be present (as assigned by ``.parse_cf`` or
876        ``.assign_crs``). PyProj is used for the coordinate transformations.
877
878        """
879        # Determine if there is a valid grid prototype from which to compute the coordinates,
880        # while also checking for existing lat/lon coords
881        grid_prototype = None
882        for data_var in self._dataset.data_vars.values():
883            if hasattr(data_var.metpy, 'y') and hasattr(data_var.metpy, 'x'):
884                if grid_prototype is None:
885                    grid_prototype = data_var
886                if (not force and (hasattr(data_var.metpy, 'latitude')
887                                   or hasattr(data_var.metpy, 'longitude'))):
888                    raise RuntimeError('Latitude/longitude coordinate(s) are present. If you '
889                                       'wish to overwrite these, specify force=True.')
890
891        # Calculate latitude and longitude from grid_prototype, if it exists, and assign
892        if grid_prototype is None:
893            warnings.warn('No latitude and longitude assigned since horizontal coordinates '
894                          'were not found')
895            return self._dataset
896        else:
897            latitude, longitude = _build_latitude_longitude(grid_prototype)
898            return self._dataset.assign_coords(latitude=latitude, longitude=longitude)
899
900    def assign_y_x(self, force=False, tolerance=None):
901        """Assign y and x dimension coordinates derived from 2D latitude and longitude.
902
903        Parameters
904        ----------
905        force : bool, optional
906            If force is true, overwrite y and x coordinates if they exist, otherwise, raise a
907            RuntimeError if such coordinates exist.
908        tolerance : `pint.Quantity`
909            Maximum range tolerated when collapsing projected y and x coordinates from 2D to
910            1D. Defaults to 1 meter.
911
912        Returns
913        -------
914        `xarray.Dataset`
915            New xarray Dataset with y and x dimension coordinates assigned to all variables
916            with valid latitude and longitude coordinates.
917
918        Notes
919        -----
920        A valid CRS coordinate must be present (as assigned by ``.parse_cf`` or
921        ``.assign_crs``). PyProj is used for the coordinate transformations.
922
923        """
924        # Determine if there is a valid grid prototype from which to compute the coordinates,
925        # while also checking for existing y and x coords
926        grid_prototype = None
927        for data_var in self._dataset.data_vars.values():
928            if hasattr(data_var.metpy, 'latitude') and hasattr(data_var.metpy, 'longitude'):
929                if grid_prototype is None:
930                    grid_prototype = data_var
931                if (not force and (hasattr(data_var.metpy, 'y')
932                                   or hasattr(data_var.metpy, 'x'))):
933                    raise RuntimeError('y/x coordinate(s) are present. If you wish to '
934                                       'overwrite these, specify force=True.')
935
936        # Calculate y and x from grid_prototype, if it exists, and assign
937        if grid_prototype is None:
938            warnings.warn('No y and x coordinates assigned since horizontal coordinates '
939                          'were not found')
940            return self._dataset
941        else:
942            y, x = _build_y_x(grid_prototype, tolerance)
943            return self._dataset.assign_coords(**{y.name: y, x.name: x})
944
945    def update_attribute(self, attribute, mapping):
946        """Return new Dataset with specified attribute updated on all Dataset variables.
947
948        Parameters
949        ----------
950        attribute : str,
951            Name of attribute to update
952        mapping : dict or callable
953            Either a dict, with keys as variable names and values as attribute values to set,
954            or a callable, which must accept one positional argument (variable name) and
955            arbitrary keyword arguments (all existing variable attributes). If a variable name
956            is not present/the callable returns None, the attribute will not be updated.
957
958        Returns
959        -------
960        `xarray.Dataset`
961            New Dataset with attribute updated
962
963        """
964        # Make mapping uniform
965        if not callable(mapping):
966            old_mapping = mapping
967
968            def mapping(varname, **kwargs):
969                return old_mapping.get(varname, None)
970
971        # Define mapping function for Dataset.map
972        def mapping_func(da):
973            new_value = mapping(da.name, **da.attrs)
974            if new_value is None:
975                return da
976            else:
977                return da.assign_attrs(**{attribute: new_value})
978
979        # Apply across all variables and coordinates
980        return (
981            self._dataset
982            .map(mapping_func)
983            .assign_coords({
984                coord_name: mapping_func(coord_var)
985                for coord_name, coord_var in self._dataset.coords.items()
986            })
987        )
988
989    def quantify(self):
990        """Return new dataset with all numeric variables quantified and cached data loaded.
991
992        Notes
993        -----
994        Any cached/lazy-loaded data (except that in a Dask array) will be loaded into memory
995        by this operation. Do not utilize on moderate- to large-sized remote datasets before
996        subsetting!
997        """
998        return self._dataset.map(lambda da: da.metpy.quantify())
999
1000    def dequantify(self):
1001        """Return new dataset with variables cast to magnitude and units on attribute."""
1002        return self._dataset.map(lambda da: da.metpy.dequantify())
1003
1004
1005def _assign_axis(attributes, axis):
1006    """Assign the given axis to the _metpy_axis attribute."""
1007    existing_axes = attributes.get('_metpy_axis', '').split(',')
1008    if ((axis == 'y' and 'latitude' in existing_axes)
1009            or (axis == 'latitude' and 'y' in existing_axes)):
1010        # Special case for combined y/latitude handling
1011        attributes['_metpy_axis'] = 'y,latitude'
1012    elif ((axis == 'x' and 'longitude' in existing_axes)
1013            or (axis == 'longitude' and 'x' in existing_axes)):
1014        # Special case for combined x/longitude handling
1015        attributes['_metpy_axis'] = 'x,longitude'
1016    else:
1017        # Simply add it/overwrite past value
1018        attributes['_metpy_axis'] = axis
1019    return attributes
1020
1021
1022def check_axis(var, *axes):
1023    """Check if the criteria for any of the given axes are satisfied.
1024
1025    Parameters
1026    ----------
1027    var : `xarray.DataArray`
1028        DataArray belonging to the coordinate to be checked
1029    axes : str
1030        Axis type(s) to check for. Currently can check for 'time', 'vertical', 'y', 'latitude',
1031        'x', and 'longitude'.
1032
1033    """
1034    for axis in axes:
1035        # Check for
1036        #   - standard name (CF option)
1037        #   - _CoordinateAxisType (from THREDDS)
1038        #   - axis (CF option)
1039        #   - positive (CF standard for non-pressure vertical coordinate)
1040        if any(var.attrs.get(criterion, 'absent')
1041               in coordinate_criteria[criterion].get(axis, set())
1042               for criterion in ('standard_name', '_CoordinateAxisType', 'axis', 'positive')):
1043            return True
1044
1045        # Check for units, either by dimensionality or name
1046        with contextlib.suppress(UndefinedUnitError):
1047            if (axis in coordinate_criteria['units'] and (
1048                    (
1049                        coordinate_criteria['units'][axis]['match'] == 'dimensionality'
1050                        and (units.get_dimensionality(var.metpy.units)
1051                             == units.get_dimensionality(
1052                                 coordinate_criteria['units'][axis]['units']))
1053                    ) or (
1054                        coordinate_criteria['units'][axis]['match'] == 'name'
1055                        and str(var.metpy.units)
1056                        in coordinate_criteria['units'][axis]['units']
1057                    ))):
1058                return True
1059
1060        # Check if name matches regular expression (non-CF failsafe)
1061        if re.match(coordinate_criteria['regular_expression'][axis], var.name.lower()):
1062            return True
1063
1064    # If no match has been made, return False (rather than None)
1065    return False
1066
1067
1068def _assign_crs(xarray_object, cf_attributes, cf_kwargs):
1069    from .plots.mapping import CFProjection
1070
1071    # Handle argument options
1072    if cf_attributes is not None and len(cf_kwargs) > 0:
1073        raise ValueError('Cannot specify both attribute dictionary and kwargs.')
1074    elif cf_attributes is None and len(cf_kwargs) == 0:
1075        raise ValueError('Must specify either attribute dictionary or kwargs.')
1076    attrs = cf_attributes if cf_attributes is not None else cf_kwargs
1077
1078    # Assign crs coordinate to xarray object
1079    return xarray_object.assign_coords(metpy_crs=CFProjection(attrs))
1080
1081
1082def _build_latitude_longitude(da):
1083    """Build latitude/longitude coordinates from DataArray's y/x coordinates."""
1084    y, x = da.metpy.coordinates('y', 'x')
1085    xx, yy = np.meshgrid(x.values, y.values)
1086    lonlats = np.stack(Proj(da.metpy.pyproj_crs)(xx, yy, inverse=True, radians=False), axis=-1)
1087    longitude = xr.DataArray(lonlats[..., 0], dims=(y.name, x.name),
1088                             coords={y.name: y, x.name: x},
1089                             attrs={'units': 'degrees_east', 'standard_name': 'longitude'})
1090    latitude = xr.DataArray(lonlats[..., 1], dims=(y.name, x.name),
1091                            coords={y.name: y, x.name: x},
1092                            attrs={'units': 'degrees_north', 'standard_name': 'latitude'})
1093    return latitude, longitude
1094
1095
1096def _build_y_x(da, tolerance):
1097    """Build y/x coordinates from DataArray's latitude/longitude coordinates."""
1098    # Initial sanity checks
1099    latitude, longitude = da.metpy.coordinates('latitude', 'longitude')
1100    if latitude.dims != longitude.dims:
1101        raise ValueError('Latitude and longitude must have same dimensionality')
1102    elif latitude.ndim != 2:
1103        raise ValueError('To build 1D y/x coordinates via assign_y_x, latitude/longitude '
1104                         'must be 2D')
1105
1106    # Convert to projected y/x
1107    xxyy = np.stack(Proj(da.metpy.pyproj_crs)(
1108        longitude.values,
1109        latitude.values,
1110        inverse=False,
1111        radians=False
1112    ), axis=-1)
1113
1114    # Handle tolerance
1115    tolerance = 1 if tolerance is None else tolerance.m_as('m')
1116
1117    # If within tolerance, take median to collapse to 1D
1118    try:
1119        y_dim = latitude.metpy.find_axis_number('y')
1120        x_dim = latitude.metpy.find_axis_number('x')
1121    except AttributeError:
1122        warnings.warn('y and x dimensions unable to be identified. Assuming [..., y, x] '
1123                      'dimension order.')
1124        y_dim, x_dim = 0, 1
1125    if (np.all(np.ptp(xxyy[..., 0], axis=y_dim) < tolerance)
1126            and np.all(np.ptp(xxyy[..., 1], axis=x_dim) < tolerance)):
1127        x = np.median(xxyy[..., 0], axis=y_dim)
1128        y = np.median(xxyy[..., 1], axis=x_dim)
1129        x = xr.DataArray(x, name=latitude.dims[x_dim], dims=(latitude.dims[x_dim],),
1130                         coords={latitude.dims[x_dim]: x},
1131                         attrs={'units': 'meter', 'standard_name': 'projection_x_coordinate'})
1132        y = xr.DataArray(y, name=latitude.dims[y_dim], dims=(latitude.dims[y_dim],),
1133                         coords={latitude.dims[y_dim]: y},
1134                         attrs={'units': 'meter', 'standard_name': 'projection_y_coordinate'})
1135        return y, x
1136    else:
1137        raise ValueError('Projected y and x coordinates cannot be collapsed to 1D within '
1138                         'tolerance. Verify that your latitude and longitude coordinates '
1139                         'correspond to your CRS coordinate.')
1140
1141
1142def preprocess_and_wrap(broadcast=None, wrap_like=None, match_unit=False, to_magnitude=False):
1143    """Return decorator to wrap array calculations for type flexibility.
1144
1145    Assuming you have a calculation that works internally with `pint.Quantity` or
1146    `numpy.ndarray`, this will wrap the function to be able to handle `xarray.DataArray` and
1147    `pint.Quantity` as well (assuming appropriate match to one of the input arguments).
1148
1149    Parameters
1150    ----------
1151    broadcast : iterable of str or None
1152        Iterable of string labels for arguments to broadcast against each other using xarray,
1153        assuming they are supplied as `xarray.DataArray`. No automatic broadcasting will occur
1154        with default of None.
1155    wrap_like : str or array-like or tuple of str or tuple of array-like or None
1156        Wrap the calculation output following a particular input argument (if str) or data
1157        object (if array-like). If tuple, will assume output is in the form of a tuple,
1158        and wrap iteratively according to the str or array-like contained within. If None,
1159        will not wrap output.
1160    match_unit : bool
1161        If true, force the unit of the final output to be that of wrapping object (as
1162        determined by wrap_like), no matter the original calculation output. Defaults to
1163        False.
1164    to_magnitude : bool
1165        If true, downcast xarray and Pint arguments to their magnitude. If false, downcast
1166        xarray arguments to Quantity, and do not change other array-like arguments.
1167    """
1168    def decorator(func):
1169        @functools.wraps(func)
1170        def wrapper(*args, **kwargs):
1171            bound_args = signature(func).bind(*args, **kwargs)
1172
1173            # Auto-broadcast select xarray arguments, and update bound_args
1174            if broadcast is not None:
1175                arg_names_to_broadcast = tuple(
1176                    arg_name for arg_name in broadcast
1177                    if arg_name in bound_args.arguments
1178                    and isinstance(
1179                        bound_args.arguments[arg_name],
1180                        (xr.DataArray, xr.Variable)
1181                    )
1182                )
1183                broadcasted_args = xr.broadcast(
1184                    *(bound_args.arguments[arg_name] for arg_name in arg_names_to_broadcast)
1185                )
1186                for i, arg_name in enumerate(arg_names_to_broadcast):
1187                    bound_args.arguments[arg_name] = broadcasted_args[i]
1188
1189            # Cast all Variables to their data and warn
1190            # (need to do before match finding, since we don't want to rewrap as Variable)
1191            def cast_variables(arg, arg_name):
1192                warnings.warn(
1193                    f'Argument {arg_name} given as xarray Variable...casting to its data. '
1194                    'xarray DataArrays are recommended instead.'
1195                )
1196                return arg.data
1197            _mutate_arguments(bound_args, xr.Variable, cast_variables)
1198
1199            # Obtain proper match if referencing an input
1200            match = list(wrap_like) if isinstance(wrap_like, tuple) else wrap_like
1201            if isinstance(wrap_like, str):
1202                match = bound_args.arguments[wrap_like]
1203            elif isinstance(wrap_like, tuple):
1204                for i, arg in enumerate(wrap_like):
1205                    if isinstance(arg, str):
1206                        match[i] = bound_args.arguments[arg]
1207
1208            # Cast all DataArrays to Pint Quantities
1209            _mutate_arguments(bound_args, xr.DataArray, lambda arg, _: arg.metpy.unit_array)
1210
1211            # Optionally cast all Quantities to their magnitudes
1212            if to_magnitude:
1213                _mutate_arguments(bound_args, units.Quantity, lambda arg, _: arg.m)
1214
1215            # Evaluate inner calculation
1216            result = func(*bound_args.args, **bound_args.kwargs)
1217
1218            # Wrap output based on match and match_unit
1219            if match is None:
1220                return result
1221            else:
1222                if match_unit:
1223                    wrapping = _wrap_output_like_matching_units
1224                else:
1225                    wrapping = _wrap_output_like_not_matching_units
1226
1227                if isinstance(match, list):
1228                    return tuple(wrapping(*args) for args in zip(result, match))
1229                else:
1230                    return wrapping(result, match)
1231        return wrapper
1232    return decorator
1233
1234
1235def _mutate_arguments(bound_args, check_type, mutate_arg):
1236    """Handle adjusting bound arguments.
1237
1238    Calls ``mutate_arg`` on every argument, including those passed as ``*args``, if they are
1239    of type ``check_type``.
1240    """
1241    for arg_name, arg_val in bound_args.arguments.items():
1242        if isinstance(arg_val, check_type):
1243            bound_args.arguments[arg_name] = mutate_arg(arg_val, arg_name)
1244
1245    if isinstance(bound_args.arguments.get('args'), tuple):
1246        bound_args.arguments['args'] = tuple(
1247            mutate_arg(arg_val, '(unnamed)') if isinstance(arg_val, check_type) else arg_val
1248            for arg_val in bound_args.arguments['args'])
1249
1250
1251def _wrap_output_like_matching_units(result, match):
1252    """Convert result to be like match with matching units for output wrapper."""
1253    output_xarray = isinstance(match, xr.DataArray)
1254    match_units = str(match.metpy.units if output_xarray else getattr(match, 'units', ''))
1255
1256    if isinstance(result, xr.DataArray):
1257        result = result.metpy.convert_units(match_units)
1258        return result if output_xarray else result.metpy.unit_array
1259    else:
1260        result = (
1261            result.to(match_units) if isinstance(result, units.Quantity)
1262            else units.Quantity(result, match_units)
1263        )
1264        return (
1265            xr.DataArray(result, coords=match.coords, dims=match.dims) if output_xarray
1266            else result
1267        )
1268
1269
1270def _wrap_output_like_not_matching_units(result, match):
1271    """Convert result to be like match without matching units for output wrapper."""
1272    output_xarray = isinstance(match, xr.DataArray)
1273    if isinstance(result, xr.DataArray):
1274        return result if output_xarray else result.metpy.unit_array
1275    else:
1276        # Determine if need to upcast to Quantity
1277        if (
1278            not isinstance(result, units.Quantity)
1279            and (
1280                isinstance(match, units.Quantity)
1281                or (output_xarray and isinstance(match.data, units.Quantity))
1282            )
1283        ):
1284            result = units.Quantity(result)
1285        return (
1286            xr.DataArray(result, coords=match.coords, dims=match.dims) if output_xarray
1287            else result
1288        )
1289
1290
1291def check_matching_coordinates(func):
1292    """Decorate a function to make sure all given DataArrays have matching coordinates."""
1293    @functools.wraps(func)
1294    def wrapper(*args, **kwargs):
1295        data_arrays = ([a for a in args if isinstance(a, xr.DataArray)]
1296                       + [a for a in kwargs.values() if isinstance(a, xr.DataArray)])
1297        if len(data_arrays) > 1:
1298            first = data_arrays[0]
1299            for other in data_arrays[1:]:
1300                if not first.metpy.coordinates_identical(other):
1301                    raise ValueError('Input DataArray arguments must be on same coordinates.')
1302        return func(*args, **kwargs)
1303    return wrapper
1304
1305
1306def _reassign_quantity_indexer(data, indexers):
1307    """Reassign a units.Quantity indexer to units of relevant coordinate."""
1308    def _to_magnitude(val, unit):
1309        try:
1310            return val.m_as(unit)
1311        except AttributeError:
1312            return val
1313
1314    # Update indexers keys for axis type -> coord name replacement
1315    indexers = {(key if not isinstance(data, xr.DataArray) or key in data.dims
1316                 or key not in metpy_axes else
1317                 next(data.metpy.coordinates(key)).name): indexers[key]
1318                for key in indexers}
1319
1320    # Update indexers to handle quantities and slices of quantities
1321    reassigned_indexers = {}
1322    for coord_name in indexers:
1323        coord_units = data[coord_name].metpy.units
1324        if isinstance(indexers[coord_name], slice):
1325            # Handle slices of quantities
1326            start = _to_magnitude(indexers[coord_name].start, coord_units)
1327            stop = _to_magnitude(indexers[coord_name].stop, coord_units)
1328            step = _to_magnitude(indexers[coord_name].step, coord_units)
1329            reassigned_indexers[coord_name] = slice(start, stop, step)
1330        else:
1331            # Handle quantities
1332            reassigned_indexers[coord_name] = _to_magnitude(indexers[coord_name], coord_units)
1333
1334    return reassigned_indexers
1335
1336
1337def grid_deltas_from_dataarray(f, kind='default'):
1338    """Calculate the horizontal deltas between grid points of a DataArray.
1339
1340    Calculate the signed delta distance between grid points of a DataArray in the horizontal
1341    directions, using actual (real distance) or nominal (in projection space) deltas.
1342
1343    Parameters
1344    ----------
1345    f : `xarray.DataArray`
1346        Parsed DataArray (``metpy_crs`` coordinate must be available for kind="actual")
1347    kind : str
1348        Type of grid delta to calculate. "actual" returns true distances as calculated from
1349        longitude and latitude via `lat_lon_grid_deltas`. "nominal" returns horizontal
1350        differences in the data's coordinate space, either in degrees (for lat/lon CRS) or
1351        meters (for y/x CRS). "default" behaves like "actual" for datasets with a lat/lon CRS
1352        and like "nominal" for all others. Defaults to "default".
1353
1354    Returns
1355    -------
1356    dx, dy:
1357        arrays of signed deltas between grid points in the x and y directions with dimensions
1358        matching those of `f`.
1359
1360    See Also
1361    --------
1362    lat_lon_grid_deltas
1363
1364    """
1365    from metpy.calc import lat_lon_grid_deltas
1366
1367    # Determine behavior
1368    if (
1369        kind == 'default'
1370        and (
1371            not hasattr(f.metpy, 'crs')
1372            or f.metpy.crs['grid_mapping_name'] == 'latitude_longitude'
1373        )
1374    ):
1375        # Use actual grid deltas by default with latitude_longitude or missing CRS
1376        kind = 'actual'
1377    elif kind == 'default':
1378        # Otherwise, use grid deltas in projected grid space by default
1379        kind = 'nominal'
1380    elif kind not in ('actual', 'nominal'):
1381        raise ValueError('"kind" argument must be specified as "default", "actual", or '
1382                         '"nominal"')
1383
1384    if kind == 'actual':
1385        # Get latitude/longitude coordinates and find dim order
1386        latitude, longitude = xr.broadcast(*f.metpy.coordinates('latitude', 'longitude'))
1387        try:
1388            y_dim = latitude.metpy.find_axis_number('y')
1389            x_dim = latitude.metpy.find_axis_number('x')
1390        except AttributeError:
1391            warnings.warn('y and x dimensions unable to be identified. Assuming [..., y, x] '
1392                          'dimension order.')
1393            y_dim, x_dim = -2, -1
1394
1395        # Get geod if it exists, otherwise fall back to PyProj default
1396        try:
1397            geod = f.metpy.pyproj_crs.get_geod()
1398        except AttributeError:
1399            geod = CRS.from_cf({'grid_mapping_name': 'latitude_longitude'}).get_geod()
1400        # Obtain grid deltas as xarray Variables
1401        (dx_var, dx_units), (dy_var, dy_units) = (
1402            (xr.Variable(dims=latitude.dims, data=deltas.magnitude), deltas.units)
1403            for deltas in lat_lon_grid_deltas(longitude, latitude, x_dim=x_dim, y_dim=y_dim,
1404                                              geod=geod))
1405    else:
1406        # Obtain y/x coordinate differences
1407        y, x = f.metpy.coordinates('y', 'x')
1408        dx_var = x.diff(x.dims[0]).variable
1409        dx_units = units(x.attrs.get('units'))
1410        dy_var = y.diff(y.dims[0]).variable
1411        dy_units = units(y.attrs.get('units'))
1412
1413    # Broadcast to input and attach units
1414    dx_var = dx_var.set_dims(f.dims, shape=[dx_var.sizes[dim] if dim in dx_var.dims else 1
1415                                            for dim in f.dims])
1416    dx = units.Quantity(dx_var.data, dx_units)
1417    dy_var = dy_var.set_dims(f.dims, shape=[dy_var.sizes[dim] if dim in dy_var.dims else 1
1418                                            for dim in f.dims])
1419    dy = units.Quantity(dy_var.data, dy_units)
1420
1421    return dx, dy
1422
1423
1424def dataarray_arguments(bound_args):
1425    """Get any dataarray arguments in the bound function arguments."""
1426    for value in chain(bound_args.args, bound_args.kwargs.values()):
1427        if isinstance(value, xr.DataArray):
1428            yield value
1429
1430
1431def add_grid_arguments_from_xarray(func):
1432    """Fill in optional arguments like dx/dy from DataArray arguments."""
1433    @functools.wraps(func)
1434    def wrapper(*args, **kwargs):
1435        bound_args = signature(func).bind(*args, **kwargs)
1436        bound_args.apply_defaults()
1437
1438        # Search for DataArray with valid latitude and longitude coordinates to find grid
1439        # deltas and any other needed parameter
1440        grid_prototype = None
1441        for da in dataarray_arguments(bound_args):
1442            if hasattr(da.metpy, 'latitude') and hasattr(da.metpy, 'longitude'):
1443                grid_prototype = da
1444                break
1445
1446        # Fill in x_dim/y_dim
1447        if (
1448            grid_prototype is not None
1449            and 'x_dim' in bound_args.arguments
1450            and 'y_dim' in bound_args.arguments
1451        ):
1452            try:
1453                bound_args.arguments['x_dim'] = grid_prototype.metpy.find_axis_number('x')
1454                bound_args.arguments['y_dim'] = grid_prototype.metpy.find_axis_number('y')
1455            except AttributeError:
1456                # If axis number not found, fall back to default but warn.
1457                warnings.warn('Horizontal dimension numbers not found. Defaulting to '
1458                              '(..., Y, X) order.')
1459
1460        # Fill in vertical_dim
1461        if (
1462            grid_prototype is not None
1463            and 'vertical_dim' in bound_args.arguments
1464        ):
1465            try:
1466                bound_args.arguments['vertical_dim'] = (
1467                    grid_prototype.metpy.find_axis_number('vertical')
1468                )
1469            except AttributeError:
1470                # If axis number not found, fall back to default but warn.
1471                warnings.warn(
1472                    'Vertical dimension number not found. Defaulting to (..., Z, Y, X) order.'
1473                )
1474
1475        # Fill in dz
1476        if (
1477            grid_prototype is not None
1478            and 'dz' in bound_args.arguments
1479            and bound_args.arguments['dz'] is None
1480        ):
1481            try:
1482                vertical_coord = grid_prototype.metpy.vertical
1483                bound_args.arguments['dz'] = np.diff(vertical_coord.metpy.unit_array)
1484            except (AttributeError, ValueError):
1485                # Skip, since this only comes up in advection, where dz is optional (may not
1486                # need vertical at all)
1487                pass
1488
1489        # Fill in dx/dy
1490        if (
1491            'dx' in bound_args.arguments and bound_args.arguments['dx'] is None
1492            and 'dy' in bound_args.arguments and bound_args.arguments['dy'] is None
1493        ):
1494            if grid_prototype is not None:
1495                bound_args.arguments['dx'], bound_args.arguments['dy'] = (
1496                    grid_deltas_from_dataarray(grid_prototype, kind='actual')
1497                )
1498            elif 'dz' in bound_args.arguments:
1499                # Handle advection case, allowing dx/dy to be None but dz to not be None
1500                if bound_args.arguments['dz'] is None:
1501                    raise ValueError(
1502                        'Must provide dx, dy, and/or dz arguments or input DataArray with '
1503                        'proper coordinates.'
1504                    )
1505            else:
1506                raise ValueError('Must provide dx/dy arguments or input DataArray with '
1507                                 'latitude/longitude coordinates.')
1508
1509        # Fill in latitude
1510        if 'latitude' in bound_args.arguments and bound_args.arguments['latitude'] is None:
1511            if grid_prototype is not None:
1512                bound_args.arguments['latitude'] = (
1513                    grid_prototype.metpy.latitude
1514                )
1515            else:
1516                raise ValueError('Must provide latitude argument or input DataArray with '
1517                                 'latitude/longitude coordinates.')
1518
1519        return func(*bound_args.args, **bound_args.kwargs)
1520    return wrapper
1521
1522
1523def add_vertical_dim_from_xarray(func):
1524    """Fill in optional vertical_dim from DataArray argument."""
1525    @functools.wraps(func)
1526    def wrapper(*args, **kwargs):
1527        bound_args = signature(func).bind(*args, **kwargs)
1528        bound_args.apply_defaults()
1529
1530        # Fill in vertical_dim
1531        if 'vertical_dim' in bound_args.arguments:
1532            a = next(dataarray_arguments(bound_args), None)
1533            if a is not None:
1534                try:
1535                    bound_args.arguments['vertical_dim'] = a.metpy.find_axis_number('vertical')
1536                except AttributeError:
1537                    # If axis number not found, fall back to default but warn.
1538                    warnings.warn(
1539                        'Vertical dimension number not found. Defaulting to initial dimension.'
1540                    )
1541
1542        return func(*bound_args.args, **bound_args.kwargs)
1543    return wrapper
1544