1"""
2    pint.registry_helpers
3    ~~~~~~~~~~~~~~~~~~~~~
4
5    Miscellaneous methods of the registry written as separate functions.
6
    :copyright: 2016 by Pint Authors, see AUTHORS for more details.
8    :license: BSD, see LICENSE for more details.
9"""
10
11import functools
12from inspect import signature
13from itertools import zip_longest
14from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union
15
16from ._typing import F
17from .errors import DimensionalityError
18from .quantity import Quantity
19from .util import UnitsContainer, to_units_container
20
21if TYPE_CHECKING:
22    from .registry import UnitRegistry
23    from .unit import Unit
24
# Type of the value returned by the function decorated with `wraps`.
T = TypeVar("T")
26
27
def _replace_units(original_units, values_by_name):
    """Resolve a units expression written in terms of named values.

    Every name in `original_units` is looked up in `values_by_name`, raised
    to its exponent, and the results are multiplied together.

    Parameters
    ----------
    original_units :
        a mapping from names to exponents (e.g. a UnitsContainer instance).
    values_by_name :
        a map between the names used in `original_units` and their values.

    Returns
    -------
    The `_units` attribute of the resulting product, or an empty
    UnitsContainer when the product has no such attribute.

    """
    factors = (
        values_by_name[name] ** power for name, power in original_units.items()
    )

    # Start from the plain integer 1 so that an empty expression yields a
    # value without `_units`.
    product = 1
    for factor in factors:
        product = product * factor

    fallback = UnitsContainer({})
    return getattr(product, "_units", fallback)
47
48
def _to_units_container(a, registry=None):
    """Convert a unit compatible type to a UnitsContainer and report
    whether it was written as a reference.

    A string containing an equal sign is considered a reference: only the
    text after the first ``=`` is converted, and the registry is not used.

    Parameters
    ----------
    a :
        the value to convert.
    registry :
        forwarded to `to_units_container` for values that are not
        references. (Default value = None)

    Returns
    -------
    UnitsContainer, bool
        the converted value and True if `a` was a reference, False otherwise.

    """
    is_reference = isinstance(a, str) and "=" in a
    if not is_reference:
        return to_units_container(a, registry), False

    # Exactly two pieces: "=" is known to be present and maxsplit is 1.
    _, referenced = a.split("=", 1)
    return to_units_container(referenced), True
70
71
def _parse_wrap_args(args, registry=None):
    """Analyse the unit specifications given to `wraps` and build a converter.

    Each specification is parsed with `_to_units_container` and sorted into
    one of three groups:

    - definitions: a reference made of a single name with exponent 1 that
      has not been seen before;
    - dependent arguments: any other reference;
    - unit arguments: specifications that are not references.

    Specifications that parse to `None` belong to no group and the matching
    values are passed through untouched.

    Parameters
    ----------
    args :
        iterable of unit specifications, one per function parameter.
    registry :
        forwarded to `_to_units_container`. (Default value = None)

    Returns
    -------
    callable
        ``_converter(ureg, values, strict)`` returning the tuple
        ``(new_values, values_by_name)``.

    Raises
    ------
    ValueError
        if a dependent argument uses a name that no argument defines
        (see the review note on that check below).

    """

    # Arguments which contain definitions
    # (i.e. names that appear alone and for the first time)
    defs_args = set()
    defs_args_ndx = set()

    # Arguments which depend on others
    dependent_args_ndx = set()

    # Arguments which have units.
    unit_args_ndx = set()

    # List of (parsed specification, is_reference) pairs.
    args_as_uc = [_to_units_container(arg, registry) for arg in args]

    # Check for references in args, skip None values
    for ndx, (arg, is_ref) in enumerate(args_as_uc):
        if arg is None:
            continue
        elif is_ref:
            if len(arg) == 1:
                [(key, value)] = arg.items()
                if value == 1 and key not in defs_args:
                    # This is the first time that
                    # a variable is used => it is a definition.
                    defs_args.add(key)
                    defs_args_ndx.add(ndx)
                    # Keep only the name: the converter uses it as a key.
                    args_as_uc[ndx] = (key, True)
                else:
                    # The variable was already found elsewhere,
                    # we consider it a dependent variable.
                    dependent_args_ndx.add(ndx)
            else:
                dependent_args_ndx.add(ndx)
        else:
            unit_args_ndx.add(ndx)

    # Check that every dependent argument only uses defined names.
    # NOTE(review): only specifications that are `dict` instances are
    # validated; confirm that the parsed containers satisfy this test,
    # otherwise the check below is always skipped.
    for ndx in dependent_args_ndx:
        arg, is_ref = args_as_uc[ndx]
        if not isinstance(arg, dict):
            continue
        if not set(arg.keys()) <= defs_args:
            raise ValueError(
                "Found a missing token while wrapping a function: "
                "Not all variable referenced in %s are defined using !" % args[ndx]
            )

    def _converter(ureg, values, strict):
        """Convert the call `values` according to the parsed specifications.

        Returns the list of converted values and the mapping from each
        defined name to the original value bound to it.
        """
        new_values = list(values)

        values_by_name = {}

        # first pass: Grab named values
        for ndx in defs_args_ndx:
            value = values[ndx]
            values_by_name[args_as_uc[ndx][0]] = value
            new_values[ndx] = getattr(value, "_magnitude", value)

        # second pass: calculate derived values based on named values
        for ndx in dependent_args_ndx:
            value = values[ndx]
            # Resolve the target units once and reuse the result for both
            # the sanity check and the conversion.
            target_units = _replace_units(args_as_uc[ndx][0], values_by_name)
            assert target_units is not None
            new_values[ndx] = ureg._convert(
                getattr(value, "_magnitude", value),
                getattr(value, "_units", UnitsContainer({})),
                target_units,
            )

        # third pass: convert other arguments
        for ndx in unit_args_ndx:
            value = values[ndx]
            target_units = args_as_uc[ndx][0]

            if isinstance(value, ureg.Quantity):
                new_values[ndx] = ureg._convert(
                    value._magnitude, value._units, target_units
                )
            elif strict:
                if isinstance(value, str):
                    # if the value is a string, we try to parse it
                    tmp_value = ureg.parse_expression(value)
                    new_values[ndx] = ureg._convert(
                        tmp_value._magnitude, tmp_value._units, target_units
                    )
                else:
                    raise ValueError(
                        "A wrapped function using strict=True requires "
                        "quantity or a string for all arguments with not None units. "
                        "(error found for {}, {})".format(
                            target_units, new_values[ndx]
                        )
                    )
            # Non-strict mode: values that are not quantities are kept as is.

        return new_values, values_by_name

    return _converter
169
170
def _apply_defaults(func, args, kwargs):
    """Apply default keyword arguments.

    Binds `args` and `kwargs` to the signature of `func` and fills every
    parameter that was left out with its declared default, so that every
    argument is defined.

    Parameters
    ----------
    func :
        the callable whose signature is used.
    args :
        positional arguments of the call.
    kwargs :
        keyword arguments of the call.

    Returns
    -------
    list, dict
        the value of every parameter, in signature order, and an empty dict.

    """
    sig = signature(func)
    supplied = sig.bind(*args, **kwargs).arguments

    # Parameters missing from the call fall back to their declared default.
    positional = [
        supplied.get(name, param.default) for name, param in sig.parameters.items()
    ]
    return positional, {}
185
186
def wraps(
    ureg: "UnitRegistry",
    ret: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
    args: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
    strict: bool = True,
) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
    """Wraps a function to become pint-aware.

    Use it when a function requires a numerical value but in some specific
    units. The wrapper function will take a pint quantity, convert to the units
    specified in `args` and then call the wrapped function with the resulting
    magnitude.

    The value returned by the wrapped function will be wrapped in a quantity
    with the units specified in `ret`.

    Parameters
    ----------
    ureg : pint.UnitRegistry
        a UnitRegistry instance.
    ret : str, pint.Unit, iterable of str, or iterable of pint.Unit
        Units of each of the return values. Use `None` to return a value
        as is.
    args : str, pint.Unit, iterable of str, or iterable of pint.Unit
        Units of each of the input arguments. Use `None` to skip argument
        conversion. A value that is not a list or tuple is treated as a
        single specification.
    strict : bool
        Indicates that only quantities are accepted. (Default value = True)

    Returns
    -------
    callable
        the wrapper function.

    Raises
    ------
    TypeError
        if the number of given arguments does not match the number of
        function parameters.
        if any of the provided specifications is neither `None`, a str
        nor a Unit.

    """

    def _require_unit_spec(spec, template):
        # `None` is accepted and means "leave the value untouched".
        if spec is None or isinstance(spec, (ureg.Unit, str)):
            return
        raise TypeError(template % (type(spec), spec))

    # A single specification is handled as a one-element tuple.
    arg_specs = args if isinstance(args, (list, tuple)) else (args,)
    for spec in arg_specs:
        _require_unit_spec(
            spec, "wraps arguments must by of type str or Unit, not %s (%s)"
        )

    converter = _parse_wrap_args(arg_specs)

    ret_template = "wraps 'ret' argument must by of type str or Unit, not %s (%s)"
    is_ret_container = isinstance(ret, (list, tuple))
    if is_ret_container:
        for spec in ret:
            _require_unit_spec(spec, ret_template)
        # Same container type as given, holding (units, is_reference) pairs.
        ret_specs = ret.__class__([_to_units_container(spec, ureg) for spec in ret])
    else:
        _require_unit_spec(ret, ret_template)
        ret_specs = _to_units_container(ret, ureg)

    def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]:

        count_params = len(signature(func).parameters)
        if count_params != len(arg_specs):
            raise TypeError(
                "%s takes %i parameters, but %i units were passed"
                % (func.__name__, count_params, len(arg_specs))
            )

        def _present_on_func(names):
            # Only copy the attributes that `func` really has.
            return tuple(name for name in names if hasattr(func, name))

        @functools.wraps(
            func,
            assigned=_present_on_func(functools.WRAPPER_ASSIGNMENTS),
            updated=_present_on_func(functools.WRAPPER_UPDATES),
        )
        def wrapper(*values, **kw) -> Quantity[T]:

            values, kw = _apply_defaults(func, values, kw)

            # The converter returns the values to call `func` with and the
            # original values bound to each named reference.
            new_values, values_by_name = converter(ureg, values, strict)

            result = func(*new_values, **kw)

            def _resolve(units, is_ref):
                # References are resolved against the values of this call.
                return _replace_units(units, values_by_name) if is_ref else units

            if is_ret_container:
                out_units = (_resolve(units, is_ref) for units, is_ref in ret_specs)
                return ret_specs.__class__(
                    res if unit is None else ureg.Quantity(res, unit)
                    for unit, res in zip_longest(out_units, result)
                )

            ret_units, ret_is_ref = ret_specs
            if ret_units is None:
                return result

            return ureg.Quantity(result, _resolve(ret_units, ret_is_ref))

        return wrapper

    return decorator
303
304
def check(
    ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None]
) -> Callable[[F], F]:
    """Decorator for quantity type checking of function inputs.

    Use it to ensure that the decorated function input parameters match
    the expected dimension of pint quantity.

    The wrapper function raises:
      - `pint.DimensionalityError` if an argument doesn't match the required dimensions.

    Parameters
    ----------
    ureg : UnitRegistry
        a UnitRegistry instance.
    args : str or UnitContainer or None
        Dimensions of each of the input arguments.
        Use `None` to skip the check for an argument.

    Returns
    -------
    callable
        the wrapped function.

    Raises
    ------
    TypeError
        If the number of given dimensions does not match the number of function
        parameters.
    ValueError
        If any of the provided dimensions cannot be parsed as a dimension.
    """
    dimensions = [
        None if dim is None else ureg.get_dimensionality(dim) for dim in args
    ]

    def decorator(func):

        count_params = len(signature(func).parameters)
        if count_params != len(dimensions):
            raise TypeError(
                "%s takes %i parameters, but %i dimensions were passed"
                % (func.__name__, count_params, len(dimensions))
            )

        def _present_on_func(names):
            # Only copy the attributes that `func` really has.
            return tuple(name for name in names if hasattr(func, name))

        @functools.wraps(
            func,
            assigned=_present_on_func(functools.WRAPPER_ASSIGNMENTS),
            updated=_present_on_func(functools.WRAPPER_UPDATES),
        )
        def wrapper(*call_args, **call_kwargs):
            # Defaults are filled in only for the check; `func` itself is
            # called with the arguments exactly as they were given.
            positional, _ = _apply_defaults(func, call_args, call_kwargs)

            for expected, value in zip(dimensions, positional):
                if expected is None or ureg.Quantity(value).check(expected):
                    continue
                raise DimensionalityError(
                    value, "a quantity of", ureg.get_dimensionality(value), expected
                )
            return func(*call_args, **call_kwargs)

        return wrapper

    return decorator
372