import typing
import inspect
import functools
from . import _uarray  # type: ignore
import copyreg  # type: ignore
import atexit
import pickle

from ._uarray import (  # type: ignore
    BackendNotImplementedError,
    _Function,
    _SkipBackendContext,
    _SetBackendContext,
)

ArgumentExtractorType = typing.Callable[..., typing.Tuple["Dispatchable", ...]]
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple], typing.Tuple[typing.Tuple, typing.Dict]
]

__all__ = [
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "clear_backends",
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    "Dispatchable",
    "wrap_single_convertor",
    "all_of_type",
    "mark_as",
]


def unpickle_function(mod_name, qname):
    """
    Look up an object by module name and qualified name.

    This is the reconstructor that :obj:`pickle_function` hands to pickle for
    :obj:`_Function` objects. ``qname`` may be a dotted qualified name
    (e.g. ``"SomeClass.method"``); each component is resolved in turn.

    Parameters
    ----------
    mod_name : str
        The name of the module to import.
    qname : str
        The ``__qualname__`` of the object inside that module.

    Returns
    -------
    object
        The resolved object.

    Raises
    ------
    pickle.UnpicklingError
        If the module cannot be imported or the name cannot be resolved
        (including when ``mod_name`` or ``qname`` is ``None``).
    """
    import importlib

    try:
        obj = importlib.import_module(mod_name)
        # A ``__qualname__`` of a nested object contains dots, which a single
        # ``getattr`` on the module cannot resolve; walk the chain instead.
        # (``None.split`` raises AttributeError, so a missing qname is caught too.)
        for part in qname.split("."):
            obj = getattr(obj, part)
        return obj
    except (ImportError, AttributeError) as e:
        raise pickle.UnpicklingError(
            "Can't unpickle {}.{}".format(mod_name, qname)
        ) from e


def pickle_function(func):
    """
    Reducer registered with :mod:`copyreg` for :obj:`_Function` objects.

    Functions are pickled by reference: only the module name and qualified
    name are stored, and :obj:`unpickle_function` looks the object up again.

    Raises
    ------
    pickle.PicklingError
        If looking ``func`` up by its module and qualified name does not give
        back the very same object.
    """
    mod_name = getattr(func, "__module__", None)
    qname = getattr(func, "__qualname__", None)

    try:
        test = unpickle_function(mod_name, qname)
    except pickle.UnpicklingError:
        test = None

    # Pickling by reference is only sound if the lookup round-trips exactly.
    if test is not func:
        raise pickle.PicklingError(
            "Can't pickle {}: it's not the same object as {}".format(func, test)
        )

    return unpickle_function, (mod_name, qname)


copyreg.pickle(_Function, pickle_function)
atexit.register(_uarray.clear_all_globals)


def create_multimethod(*args, **kwargs):
    """
    Creates a decorator for generating multimethods.

    This function creates a decorator that can be used with an argument
    extractor in order to generate a multimethod. Other than for the
    argument extractor, all arguments are passed on to
    :obj:`generate_multimethod`.

    See Also
    --------
    generate_multimethod : Generates a multimethod.
    """

    def wrapper(a):
        return generate_multimethod(a, *args, **kwargs)

    return wrapper


def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted arguments
        should be marked by the :obj:`Dispatchable` class. It has the same signature
        as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which should also
        return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None`` (the default) specifies
        there is no default implementation.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while marking it as an ``int``.
    The trailing comma is needed because the args have to be returned as an iterable.

    >>> def override_me(a, b):
    ...   return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables inside args/kwargs with the
    supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.backend.BackendNotImplementedError: ...
    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray :
        See the module documentation for how to override the method by creating backends.
    """
    # The third element (names of parameters with defaults) is not needed here.
    kw_defaults, arg_defaults, _ = get_defaults(argument_extractor)
    ua_func = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    return functools.update_wrapper(ua_func, argument_extractor)


def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    The context object is created once per ``(coerce, only)`` combination and
    cached on ``backend.__ua_cache__``; later calls return the cached object.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    See Also
    --------
    skip_backend : A context manager that allows skipping of backends.
    set_global_backend : Set a single, global backend for a domain.
    """
    try:
        return backend.__ua_cache__["set", coerce, only]
    except AttributeError:
        # First use of this backend: create its cache lazily.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__["set", coerce, only] = ctx
    return ctx


def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from processing
    entirely. This allows one to use another backend's code in a library that
    is also a consumer of the same backend.

    The context object is cached on ``backend.__ua_cache__`` under ``"skip"``;
    later calls return the cached object.

    Parameters
    ----------
    backend
        The backend to skip.

    See Also
    --------
    set_backend : A context manager that allows setting of backends.
    set_global_backend : Set a single, global backend for a domain.
    """
    try:
        return backend.__ua_cache__["skip"]
    except AttributeError:
        # First use of this backend: create its cache lazily.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SkipBackendContext(backend)
    backend.__ua_cache__["skip"] = ctx
    return ctx


def get_defaults(f):
    """
    Collect the default values from the signature of ``f``.

    Returns
    -------
    kw_defaults : dict
        Maps every parameter name that has a default to that default.
    arg_defaults : tuple
        Defaults of positional(-or-keyword) parameters, in signature order.
    opts : set
        Names of all parameters that have a default.
    """
    sig = inspect.signature(f)
    kw_defaults = {}
    arg_defaults = []
    opts = set()
    for k, v in sig.parameters.items():
        if v.default is not inspect.Parameter.empty:
            kw_defaults[k] = v.default
        if v.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            arg_defaults.append(v.default)
        opts.add(k)

    return kw_defaults, tuple(arg_defaults), opts


def set_global_backend(backend, coerce=False, only=False):
    """
    This utility method replaces the default backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend. This will be the first tried
    backend outside the :obj:`set_backend` context manager.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to set as the global backend.
    coerce
        Whether or not to coerce to this backend's types. Implies ``only``
        (as for :obj:`set_backend`).
    only
        Whether or not this should be the last backend to try.

    See Also
    --------
    set_backend : A context manager that allows setting of backends.
    skip_backend : A context manager that allows skipping of backends.
    """
    _uarray.set_global_backend(backend, coerce, only)


def register_backend(backend):
    """
    This utility method registers a backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    _uarray.register_backend(backend)


def clear_backends(domain, registered=True, globals=False):
    """
    This utility method clears registered backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by the users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    _uarray.clear_backends(domain, registered, globals)


class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Attributes
    ----------
    value
        The value of the Dispatchable.

    type
        The type of the Dispatchable.

    coercible
        Whether this value may be coerced. :obj:`wrap_single_convertor`
        passes ``coerce and coercible`` to the single-element convertor.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.

    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value = value
        self.type = dispatch_type
        self.coercible = coercible

    def __getitem__(self, index):
        # Allows unpacking as ``type, value = dispatchable``.
        return (self.type, self.value)[index]

    def __str__(self):
        return "<{0}: type={1!r}, value={2!r}>".format(
            type(self).__name__, self.type, self.value
        )

    __repr__ = __str__


def mark_as(dispatch_type):
    """
    Creates a utility function to mark something as a specific type.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    return functools.partial(Dispatchable, dispatch_type=dispatch_type)


def all_of_type(arg_type):
    """
    Marks all unmarked arguments as a given type.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def outer(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            extracted_args = func(*args, **kwargs)
            return tuple(
                Dispatchable(arg, arg_type)
                if not isinstance(arg, Dispatchable)
                else arg
                for arg in extracted_args
            )

        return inner

    return outer


def wrap_single_convertor(convert_single):
    """
    Wraps a ``__ua_convert__`` defined for a single element to all elements.
    If any of them return ``NotImplemented``, the operation is assumed to be
    undefined.

    Accepts a signature of (value, type, coerce).
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        converted = []
        for d in dispatchables:
            # A value marked non-coercible is never coerced, even if requested.
            c = convert_single(d.value, d.type, coerce and d.coercible)

            if c is NotImplemented:
                return NotImplemented

            converted.append(c)

        return converted

    return __ua_convert__