"""
    pint.registry_helpers
    ~~~~~~~~~~~~~~~~~~~~~

    Miscellaneous methods of the registry written as separate functions.

    :copyright: 2016 by Pint Authors, see AUTHORS for more details.
    :license: BSD, see LICENSE for more details.
"""

import functools
from inspect import signature
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union

from ._typing import F
from .errors import DimensionalityError
from .quantity import Quantity
from .util import UnitsContainer, to_units_container

if TYPE_CHECKING:
    from .registry import UnitRegistry
    from .unit import Unit

T = TypeVar("T")


def _replace_units(original_units, values_by_name):
    """Substitute the named references of a UnitsContainer by actual values.

    Each name in ``original_units`` is looked up in ``values_by_name``; the
    looked-up values, raised to the corresponding exponents, are multiplied
    together and the units of that product are returned.

    Parameters
    ----------
    original_units :
        a UnitsContainer instance whose keys are reference names.
    values_by_name :
        a map between original names and the new values.

    Returns
    -------
    UnitsContainer
        the units of the product, or an empty UnitsContainer when the product
        has no ``_units`` attribute (e.g. only plain numbers were involved).

    Raises
    ------
    KeyError
        if a name of ``original_units`` is missing from ``values_by_name``.
    """
    q = 1
    for arg_name, exponent in original_units.items():
        q = q * values_by_name[arg_name] ** exponent

    return getattr(q, "_units", UnitsContainer({}))


def _to_units_container(a, registry=None):
    """Convert a unit compatible type to a UnitsContainer.

    A string containing an equal sign (e.g. ``"=A*B"``) is considered a
    reference: only the text after the first ``=`` is parsed, and it is
    parsed without the registry because the names are placeholders rather
    than real units.

    Parameters
    ----------
    a :
        a unit-like value (str, Unit, UnitsContainer, ...), or None.
    registry :
        registry used to parse non-reference values. (Default value = None)

    Returns
    -------
    UnitsContainer, bool
        the converted value and whether ``a`` was a reference.
    """
    if isinstance(a, str) and "=" in a:
        return to_units_container(a.split("=", 1)[1]), True
    return to_units_container(a, registry), False


def _parse_wrap_args(args, registry=None):
    """Analyse the unit specifications given to :func:`wraps`.

    Each entry of ``args`` is classified as one of:

    * ``None``-like: the argument is passed through untouched;
    * a plain unit: the argument is converted to those units;
    * a *definition*: a reference to a single name with exponent 1 seen for
      the first time (e.g. ``"=A"``); the argument's value is recorded under
      that name and its magnitude is passed on;
    * a *dependent* reference: any other reference (e.g. ``"=A**2"`` or a
      repeated ``"=A"``); the argument is converted to the units obtained by
      substituting the defined names.

    Parameters
    ----------
    args :
        sequence of unit specifications, one per wrapped-function parameter.
    registry :
        registry used to parse plain unit strings. (Default value = None)

    Returns
    -------
    callable
        ``_converter(ureg, values, strict)`` returning the list of converted
        argument values and the map from reference names to supplied values.

    Raises
    ------
    ValueError
        if a dependent reference uses a name that no argument defines.
    """

    # Arguments which contain definitions
    # (i.e. names that appear alone and for the first time)
    defs_args = set()
    defs_args_ndx = set()

    # Arguments which depend on others
    dependent_args_ndx = set()

    # Arguments which have units.
    unit_args_ndx = set()

    args_as_uc = [_to_units_container(arg, registry) for arg in args]

    # Classify each argument; entries that parsed to None are skipped.
    for ndx, (arg, is_ref) in enumerate(args_as_uc):
        if arg is None:
            continue
        elif is_ref:
            if len(arg) == 1:
                [(key, value)] = arg.items()
                if value == 1 and key not in defs_args:
                    # This is the first time that
                    # a variable is used => it is a definition.
                    defs_args.add(key)
                    defs_args_ndx.add(ndx)
                    args_as_uc[ndx] = (key, True)
                else:
                    # The variable was already found elsewhere,
                    # we consider it a dependent variable.
                    dependent_args_ndx.add(ndx)
            else:
                dependent_args_ndx.add(ndx)
        else:
            unit_args_ndx.add(ndx)

    # Check that every dependent reference only uses defined names.
    # References are parsed by ``to_units_container`` into UnitsContainer
    # objects, so test for that type; a ``dict`` test would silently skip
    # this validation and defer the failure to a KeyError at call time.
    for ndx in dependent_args_ndx:
        arg, _ = args_as_uc[ndx]
        if not isinstance(arg, UnitsContainer):
            continue
        if not set(arg.keys()) <= defs_args:
            raise ValueError(
                "Found a missing token while wrapping a function: "
                "Not all variables referenced in %s are defined using ="
                % (args[ndx],)
            )

    def _converter(ureg, values, strict):
        new_values = list(values)

        values_by_name = {}

        # first pass: Grab named values
        for ndx in defs_args_ndx:
            value = values[ndx]
            values_by_name[args_as_uc[ndx][0]] = value
            new_values[ndx] = getattr(value, "_magnitude", value)

        # second pass: calculate derived values based on named values
        for ndx in dependent_args_ndx:
            value = values[ndx]
            # Computed once; _replace_units always returns a UnitsContainer.
            dst_units = _replace_units(args_as_uc[ndx][0], values_by_name)
            new_values[ndx] = ureg._convert(
                getattr(value, "_magnitude", value),
                getattr(value, "_units", UnitsContainer({})),
                dst_units,
            )

        # third pass: convert other arguments
        for ndx in unit_args_ndx:
            if isinstance(values[ndx], ureg.Quantity):
                new_values[ndx] = ureg._convert(
                    values[ndx]._magnitude, values[ndx]._units, args_as_uc[ndx][0]
                )
            elif strict:
                if isinstance(values[ndx], str):
                    # if the value is a string, we try to parse it
                    tmp_value = ureg.parse_expression(values[ndx])
                    new_values[ndx] = ureg._convert(
                        tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0]
                    )
                else:
                    raise ValueError(
                        "A wrapped function using strict=True requires "
                        "quantity or a string for all arguments with not None units. "
                        "(error found for {}, {})".format(
                            args_as_uc[ndx][0], new_values[ndx]
                        )
                    )
            # non-strict: non-Quantity values are passed through unchanged

        return new_values, values_by_name

    return _converter


def _apply_defaults(func, args, kwargs):
    """Apply default keyword arguments.

    Named keywords may have been left blank. This function applies the default
    values so that every argument is defined, and returns all of them as a
    positional list ordered like the parameters of ``func`` (so they line up
    with the per-parameter unit specifications), plus an empty kwargs dict.

    Raises
    ------
    TypeError
        if ``args``/``kwargs`` cannot be bound to the signature of ``func``.
    """

    sig = signature(func)
    bound_arguments = sig.bind(*args, **kwargs)
    for param in sig.parameters.values():
        if param.name not in bound_arguments.arguments:
            bound_arguments.arguments[param.name] = param.default
    args = [bound_arguments.arguments[key] for key in sig.parameters.keys()]
    return args, {}


def wraps(
    ureg: "UnitRegistry",
    ret: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
    args: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
    strict: bool = True,
) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
    """Wraps a function to become pint-aware.

    Use it when a function requires a numerical value but in some specific
    units. The wrapper function will take a pint quantity, convert to the units
    specified in `args` and then call the wrapped function with the resulting
    magnitude.

    The value returned by the wrapped function will be converted to the units
    specified in `ret`.

    Parameters
    ----------
    ureg : pint.UnitRegistry
        a UnitRegistry instance.
    ret : str, pint.Unit, iterable of str, or iterable of pint.Unit
        Units of each of the return values. Use `None` to skip conversion of
        the corresponding return value.
    args : str, pint.Unit, iterable of str, or iterable of pint.Unit
        Units of each of the input arguments. Use `None` to skip argument conversion.
        Strings starting with ``=`` are references (e.g. ``"=A"``, ``"=A**2"``).
    strict : bool
        Indicates that only quantities are accepted. (Default value = True)

    Returns
    -------
    callable
        the wrapper function.

    Raises
    ------
    TypeError
        if the number of given arguments does not match the number of function parameters.
        if any of the provided arguments is not a unit, a string or None.
    ValueError
        if a reference in `args` uses a name that no argument defines.
    """

    if not isinstance(args, (list, tuple)):
        args = (args,)

    for arg in args:
        if arg is not None and not isinstance(arg, (ureg.Unit, str)):
            raise TypeError(
                "wraps arguments must be of type str or Unit, not %s (%s)"
                % (type(arg), arg)
            )

    converter = _parse_wrap_args(args)

    is_ret_container = isinstance(ret, (list, tuple))
    if is_ret_container:
        for arg in ret:
            if arg is not None and not isinstance(arg, (ureg.Unit, str)):
                raise TypeError(
                    "wraps 'ret' argument must be of type str or Unit, not %s (%s)"
                    % (type(arg), arg)
                )
        ret = ret.__class__([_to_units_container(arg, ureg) for arg in ret])
    else:
        if ret is not None and not isinstance(ret, (ureg.Unit, str)):
            raise TypeError(
                "wraps 'ret' argument must be of type str or Unit, not %s (%s)"
                % (type(ret), ret)
            )
        ret = _to_units_container(ret, ureg)

    def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]:

        count_params = len(signature(func).parameters)
        if len(args) != count_params:
            raise TypeError(
                "%s takes %i parameters, but %i units were passed"
                % (func.__name__, count_params, len(args))
            )

        # Only copy the attributes func actually has (e.g. some callables
        # lack __name__ or __dict__).
        assigned = tuple(
            attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr)
        )
        updated = tuple(
            attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr)
        )

        @functools.wraps(func, assigned=assigned, updated=updated)
        def wrapper(*values, **kw) -> Quantity[T]:

            values, kw = _apply_defaults(func, values, kw)

            # In principle, the values are used as is
            # When then extract the magnitudes when needed.
            new_values, values_by_name = converter(ureg, values, strict)

            result = func(*new_values, **kw)

            if is_ret_container:
                out_units = (
                    _replace_units(r, values_by_name) if is_ref else r
                    for (r, is_ref) in ret
                )
                return ret.__class__(
                    res if unit is None else ureg.Quantity(res, unit)
                    for unit, res in zip_longest(out_units, result)
                )

            if ret[0] is None:
                return result

            return ureg.Quantity(
                result, _replace_units(ret[0], values_by_name) if ret[1] else ret[0]
            )

        return wrapper

    return decorator


def check(
    ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None]
) -> Callable[[F], F]:
    """Decorator for quantity type checking of function inputs.

    Use it to ensure that the decorated function input parameters match
    the expected dimension of pint quantity.

    The wrapper function raises:
      - `pint.DimensionalityError` if an argument doesn't match the required dimensions.

    Parameters
    ----------
    ureg : UnitRegistry
        a UnitRegistry instance.
    args : str or UnitContainer or None
        Dimensions of each of the input arguments.
        Use `None` to skip argument conversion.

    Returns
    -------
    callable
        the wrapped function.

    Raises
    ------
    TypeError
        If the number of given dimensions does not match the number of function
        parameters.
    ValueError
        If any of the provided dimensions cannot be parsed as a dimension.
    """
    dimensions = [
        ureg.get_dimensionality(dim) if dim is not None else None for dim in args
    ]

    def decorator(func):

        count_params = len(signature(func).parameters)
        if len(dimensions) != count_params:
            raise TypeError(
                "%s takes %i parameters, but %i dimensions were passed"
                % (func.__name__, count_params, len(dimensions))
            )

        assigned = tuple(
            attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr)
        )
        updated = tuple(
            attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr)
        )

        @functools.wraps(func, assigned=assigned, updated=updated)
        def wrapper(*call_args, **call_kwargs):
            # Defaults are applied only for the check; func itself receives
            # exactly the arguments the caller passed.
            list_args, _ = _apply_defaults(func, call_args, call_kwargs)

            for dim, value in zip(dimensions, list_args):

                if dim is None:
                    continue

                if not ureg.Quantity(value).check(dim):
                    val_dim = ureg.get_dimensionality(value)
                    raise DimensionalityError(value, "a quantity of", val_dim, dim)
            return func(*call_args, **call_kwargs)

        return wrapper

    return decorator