1"""Coders for individual Variable objects."""
2import warnings
3from functools import partial
4from typing import Any, Hashable
5
6import numpy as np
7import pandas as pd
8
9from ..core import dtypes, duck_array_ops, indexing
10from ..core.pycompat import is_duck_dask_array
11from ..core.variable import Variable
12
13
14class SerializationWarning(RuntimeWarning):
15    """Warnings about encoding/decoding issues in serialization."""
16
17
class VariableCoder:
    """Base class for encoding and decoding transformations on variables.

    We use coders for transforming variables between xarray's data model and
    a format suitable for serialization. For example, coders apply CF
    conventions for how data should be represented in netCDF files.

    Subclasses should implement encode() and decode(), which should satisfy
    the identity ``coder.decode(coder.encode(variable)) == variable``. If any
    options are necessary, they should be implemented as arguments to the
    __init__ method.

    The optional name argument to encode() and decode() exists solely for the
    sake of better error messages, and should correspond to the name of
    variables in the underlying store.
    """

    def encode(
        self, variable: Variable, name: Hashable = None
    ) -> Variable:  # pragma: no cover
        """Convert a decoded variable to an encoded variable."""
        raise NotImplementedError()

    def decode(
        self, variable: Variable, name: Hashable = None
    ) -> Variable:  # pragma: no cover
        """Convert an encoded variable to a decoded variable."""
        raise NotImplementedError()

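# Illustrative sketch of the round-trip contract described in the
# VariableCoder docstring above, using the CFMaskCoder defined later in this
# module; ``var`` is a hypothetical, already-decoded Variable.
#
#     coder = CFMaskCoder()
#     encoded = coder.encode(var, name="air_temperature")
#     roundtripped = coder.decode(encoded, name="air_temperature")
#     # ``roundtripped`` should compare equal to ``var``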

class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Lazily computed array holding values of an element-wise function.

    Do not construct this object directly: call lazy_elemwise_func instead.

    Values are computed upon indexing or coercion to a NumPy array.
    """

    def __init__(self, array, func, dtype):
        assert not is_duck_dask_array(array)
        self.array = indexing.as_indexable(array)
        self.func = func
        self._dtype = dtype

    @property
    def dtype(self):
        return np.dtype(self._dtype)

    def __getitem__(self, key):
        return type(self)(self.array[key], self.func, self.dtype)

    def __array__(self, dtype=None):
        return self.func(self.array)

    def __repr__(self):
        return "{}({!r}, func={!r}, dtype={!r})".format(
            type(self).__name__, self.array, self.func, self.dtype
        )

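# Illustrative sketch of how this wrapper defers work: indexing re-wraps
# lazily, and the wrapped function only runs on coercion to NumPy.  ``lazy``
# is a hypothetical _ElementwiseFunctionArray, e.g. as returned by
# lazy_elemwise_func below for an in-memory array.
#
#     subset = lazy[indexing.BasicIndexer((slice(0, 2),))]  # still lazy
#     values = np.asarray(subset)  # ``func`` is applied here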

def lazy_elemwise_func(array, func, dtype):
    """Lazily apply an element-wise function to an array.

    Parameters
    ----------
    array : any valid value of Variable._data
    func : callable
        Function to apply to indexed slices of an array. For use with dask,
        this should be a pickle-able object.
    dtype : coercible to np.dtype
        Dtype for the result of this function.

    Returns
    -------
    Either a dask.array.Array or _ElementwiseFunctionArray.
    """
    if is_duck_dask_array(array):
        import dask.array as da

        return da.map_blocks(func, array, dtype=dtype)
    else:
        return _ElementwiseFunctionArray(array, func, dtype)

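# Illustrative sketch of lazy_elemwise_func on an in-memory array (a dask
# array would instead go through dask.array.map_blocks); the names below are
# hypothetical.
#
#     raw = np.array([1, 2, 4], dtype="i2")
#     deferred = lazy_elemwise_func(raw, lambda x: np.asarray(x) * 0.5, np.float64)
#     np.asarray(deferred)  # -> 0.5, 1.0, 2.0 as float64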

def unpack_for_encoding(var):
    return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def unpack_for_decoding(var):
    return var.dims, var._data, var.attrs.copy(), var.encoding.copy()


def safe_setitem(dest, key, value, name=None):
    if key in dest:
        var_str = f" on variable {name!r}" if name else ""
        raise ValueError(
            "failed to prevent overwriting existing key {} in attrs{}. "
            "This is probably an encoding field used by xarray to describe "
            "how a variable is serialized. To proceed, remove this key from "
            "the variable's attributes manually.".format(key, var_str)
        )
    dest[key] = value


def pop_to(source, dest, key, name=None):
    """
    A convenience function that pops a key from ``source`` and sets it in
    ``dest``. ``None`` values are not passed on. If the key already exists
    in ``dest``, an error is raised.
    """
    value = source.pop(key, None)
    if value is not None:
        safe_setitem(dest, key, value, name=name)
    return value

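# Illustrative sketch of pop_to moving a CF attribute from ``attrs`` into
# ``encoding`` during decoding, as the coders below do.
#
#     attrs = {"_FillValue": -9999, "units": "m"}
#     encoding = {}
#     pop_to(attrs, encoding, "_FillValue", name="elevation")
#     # attrs is now {"units": "m"} and encoding is {"_FillValue": -9999}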

def _apply_mask(
    data: np.ndarray, encoded_fill_values: list, decoded_fill_value: Any, dtype: Any
) -> np.ndarray:
    """Mask all matching values in a NumPy array."""
    data = np.asarray(data, dtype=dtype)
    condition = False
    for fv in encoded_fill_values:
        condition |= data == fv
    return np.where(condition, decoded_fill_value, data)

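# Illustrative sketch of _apply_mask replacing encoded fill values with NaN
# while promoting to a float dtype.
#
#     packed = np.array([1, -9999, 3], dtype="i2")
#     _apply_mask(packed, encoded_fill_values=[-9999],
#                 decoded_fill_value=np.nan, dtype="f8")
#     # -> 1.0, nan, 3.0 as float64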

class CFMaskCoder(VariableCoder):
    """Mask or unmask fill values according to CF conventions."""

    def encode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        dtype = np.dtype(encoding.get("dtype", data.dtype))
        fv = encoding.get("_FillValue")
        mv = encoding.get("missing_value")

        if (
            fv is not None
            and mv is not None
            and not duck_array_ops.allclose_or_equiv(fv, mv)
        ):
            raise ValueError(
                f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data."
            )

        if fv is not None:
            # Ensure _FillValue is cast to the same dtype as the data
            encoding["_FillValue"] = dtype.type(fv)
            fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
            if not pd.isnull(fill_value):
                data = duck_array_ops.fillna(data, fill_value)

        if mv is not None:
            # Ensure missing_value is cast to the same dtype as the data
            encoding["missing_value"] = dtype.type(mv)
            fill_value = pop_to(encoding, attrs, "missing_value", name=name)
            if not pd.isnull(fill_value) and fv is None:
                data = duck_array_ops.fillna(data, fill_value)

        return Variable(dims, data, attrs, encoding)

    def decode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        raw_fill_values = [
            pop_to(attrs, encoding, attr, name=name)
            for attr in ("missing_value", "_FillValue")
        ]
        if raw_fill_values:
            encoded_fill_values = {
                fv
                for option in raw_fill_values
                for fv in np.ravel(option)
                if not pd.isnull(fv)
            }

            if len(encoded_fill_values) > 1:
                warnings.warn(
                    "variable {!r} has multiple fill values {}, "
                    "decoding all values to NaN.".format(name, encoded_fill_values),
                    SerializationWarning,
                    stacklevel=3,
                )

            dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)

            if encoded_fill_values:
                transform = partial(
                    _apply_mask,
                    encoded_fill_values=encoded_fill_values,
                    decoded_fill_value=decoded_fill_value,
                    dtype=dtype,
                )
                data = lazy_elemwise_func(data, transform, dtype)

        return Variable(dims, data, attrs, encoding)

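# Illustrative sketch of CFMaskCoder.decode on a variable whose attributes
# carry a CF _FillValue; the variable below is hypothetical.
#
#     packed = Variable(("x",), np.array([0, -1, 2], dtype="i2"),
#                       attrs={"_FillValue": np.int16(-1)})
#     unpacked = CFMaskCoder().decode(packed, name="t")
#     # "_FillValue" moves from unpacked.attrs into unpacked.encoding, and
#     # the -1 element decodes lazily to NaN in a promoted float dtype.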

def _scale_offset_decoding(data, scale_factor, add_offset, dtype):
    data = np.array(data, dtype=dtype, copy=True)
    if scale_factor is not None:
        data *= scale_factor
    if add_offset is not None:
        data += add_offset
    return data

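# Illustrative sketch of _scale_offset_decoding applying the CF formula
# decoded = encoded * scale_factor + add_offset to a small integer block.
#
#     _scale_offset_decoding(np.array([0, 1, 2], dtype="i2"),
#                            scale_factor=0.5, add_offset=10.0, dtype="f8")
#     # -> 10.0, 10.5, 11.0 as float64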

def _choose_float_dtype(dtype, has_offset):
    """Return a float dtype that can losslessly represent `dtype` values."""
    # Keep float32 as-is.  Upcast half-precision to single-precision,
    # because float16 is "intended for storage but not computation"
    if dtype.itemsize <= 4 and np.issubdtype(dtype, np.floating):
        return np.float32
    # float32 can exactly represent all integers up to 24 bits
    if dtype.itemsize <= 2 and np.issubdtype(dtype, np.integer):
        # A scale factor is entirely safe (vanishing into the mantissa),
        # but a large integer offset could lead to loss of precision.
        # Sensitivity analysis can be tricky, so we just use a float64
        # if there's any offset at all - better unoptimised than wrong!
        if not has_offset:
            return np.float32
    # For all other types and circumstances, we just use float64.
    # (safe because eg. complex numbers are not supported in NetCDF)
    return np.float64

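# Illustrative sketch of the dtype choices made by _choose_float_dtype for a
# few common cases.
#
#     _choose_float_dtype(np.dtype("f4"), has_offset=False)  # -> np.float32
#     _choose_float_dtype(np.dtype("i2"), has_offset=False)  # -> np.float32
#     _choose_float_dtype(np.dtype("i2"), has_offset=True)   # -> np.float64
#     _choose_float_dtype(np.dtype("i4"), has_offset=False)  # -> np.float64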

class CFScaleOffsetCoder(VariableCoder):
    """Scale and offset variables according to CF conventions.

    Follows the formula:
        decode_values = encoded_values * scale_factor + add_offset
    """

    def encode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        if "scale_factor" in encoding or "add_offset" in encoding:
            dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding)
            data = data.astype(dtype=dtype, copy=True)
        if "add_offset" in encoding:
            data -= pop_to(encoding, attrs, "add_offset", name=name)
        if "scale_factor" in encoding:
            data /= pop_to(encoding, attrs, "scale_factor", name=name)

        return Variable(dims, data, attrs, encoding)

    def decode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if "scale_factor" in attrs or "add_offset" in attrs:
            scale_factor = pop_to(attrs, encoding, "scale_factor", name=name)
            add_offset = pop_to(attrs, encoding, "add_offset", name=name)
            dtype = _choose_float_dtype(data.dtype, "add_offset" in attrs)
            if np.ndim(scale_factor) > 0:
                scale_factor = np.asarray(scale_factor).item()
            if np.ndim(add_offset) > 0:
                add_offset = np.asarray(add_offset).item()
            transform = partial(
                _scale_offset_decoding,
                scale_factor=scale_factor,
                add_offset=add_offset,
                dtype=dtype,
            )
            data = lazy_elemwise_func(data, transform, dtype)

        return Variable(dims, data, attrs, encoding)

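# Illustrative sketch of CFScaleOffsetCoder.decode on packed int16 data with
# CF scale/offset attributes; the variable below is hypothetical.
#
#     packed = Variable(("x",), np.array([0, 10, 20], dtype="i2"),
#                       attrs={"scale_factor": 0.1, "add_offset": 5.0})
#     unpacked = CFScaleOffsetCoder().decode(packed, name="t")
#     # both attributes move into unpacked.encoding and the values decode
#     # lazily to approximately 5.0, 6.0 and 7.0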

class UnsignedIntegerCoder(VariableCoder):
    def encode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        # from netCDF best practices
        # https://www.unidata.ucar.edu/software/netcdf/docs/BestPractices.html
        #     "_Unsigned = "true" to indicate that
        #      integer data should be treated as unsigned"
        if encoding.get("_Unsigned", "false") == "true":
            pop_to(encoding, attrs, "_Unsigned")
            signed_dtype = np.dtype(f"i{data.dtype.itemsize}")
            if "_FillValue" in attrs:
                new_fill = signed_dtype.type(attrs["_FillValue"])
                attrs["_FillValue"] = new_fill
            data = duck_array_ops.around(data).astype(signed_dtype)

        return Variable(dims, data, attrs, encoding)

    def decode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if "_Unsigned" in attrs:
            unsigned = pop_to(attrs, encoding, "_Unsigned")

            if data.dtype.kind == "i":
                if unsigned == "true":
                    unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}")
                    transform = partial(np.asarray, dtype=unsigned_dtype)
                    data = lazy_elemwise_func(data, transform, unsigned_dtype)
                    if "_FillValue" in attrs:
                        new_fill = unsigned_dtype.type(attrs["_FillValue"])
                        attrs["_FillValue"] = new_fill
            elif data.dtype.kind == "u":
                if unsigned == "false":
                    signed_dtype = np.dtype(f"i{data.dtype.itemsize}")
                    transform = partial(np.asarray, dtype=signed_dtype)
                    data = lazy_elemwise_func(data, transform, signed_dtype)
                    if "_FillValue" in attrs:
                        new_fill = signed_dtype.type(attrs["_FillValue"])
                        attrs["_FillValue"] = new_fill
            else:
                warnings.warn(
                    f"variable {name!r} has _Unsigned attribute but is not "
                    "of integer type. Ignoring attribute.",
                    SerializationWarning,
                    stacklevel=3,
                )

        return Variable(dims, data, attrs, encoding)

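# Illustrative sketch of UnsignedIntegerCoder.decode on signed bytes stored
# with the netCDF _Unsigned = "true" convention; the variable below is
# hypothetical.
#
#     packed = Variable(("x",), np.array([-1, 1], dtype="i1"),
#                       attrs={"_Unsigned": "true"})
#     unpacked = UnsignedIntegerCoder().decode(packed, name="mask")
#     # "_Unsigned" moves into unpacked.encoding and the values reinterpret
#     # lazily as uint8, i.e. 255 and 1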