import collections
from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union
import uuid
import warnings

from sqlalchemy.util import asbool  # noqa
from sqlalchemy.util import immutabledict  # noqa
from sqlalchemy.util import memoized_property  # noqa
from sqlalchemy.util import to_list  # noqa
from sqlalchemy.util import unique_list  # noqa

from .compat import inspect_getfullargspec
from .compat import string_types


_T = TypeVar("_T")


class _ModuleClsMeta(type):
    """Metaclass that keeps generated module proxies in sync.

    Any attribute assigned on a class using this metaclass is re-published
    to every module namespace registered for that class.
    """

    def __setattr__(cls, key: str, value: Callable) -> None:
        super(_ModuleClsMeta, cls).__setattr__(key, value)
        # Regenerate / re-register the proxy for the newly assigned name.
        cls._update_module_proxies(key)  # type: ignore


class ModuleClsProxy(metaclass=_ModuleClsMeta):
    """Create module level proxy functions for the
    methods on a given class.

    The functions will have a compatible signature
    as the methods.

    """

    # Keyed by class. Each value is a pair of:
    #   - a set of public, non-callable attribute names, which
    #     _install_proxy copies into module globals;
    #   - a list of (globals_, locals_) namespace pairs that receive
    #     proxy functions.
    _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
        lambda: (set(), [])
    )

    @classmethod
    def _update_module_proxies(cls, name: str) -> None:
        """Re-process attribute ``name`` for every registered namespace."""
        attr_names, modules = cls._setups[cls]
        for globals_, locals_ in modules:
            cls._add_proxied_attribute(name, globals_, locals_, attr_names)

    def _install_proxy(self) -> None:
        """Make this instance the target of all registered module proxies.

        Sets ``_proxy`` in each registered module's globals to ``self`` and
        copies the current value of every tracked non-callable attribute
        into those globals.
        """
        attr_names, modules = self._setups[self.__class__]
        for globals_, locals_ in modules:
            globals_["_proxy"] = self
            for attr_name in attr_names:
                globals_[attr_name] = getattr(self, attr_name)

    def _remove_proxy(self) -> None:
        """Detach this instance from all registered module proxies.

        Resets ``_proxy`` to ``None`` and deletes the tracked attribute
        names from each module's globals.

        :raises KeyError: if a tracked attribute name is not present in a
            module's globals (e.g. ``_install_proxy`` was not called).
        """
        attr_names, modules = self._setups[self.__class__]
        for globals_, locals_ in modules:
            globals_["_proxy"] = None
            for attr_name in attr_names:
                del globals_[attr_name]

    @classmethod
    def create_module_class_proxy(cls, globals_, locals_):
        """Register a module namespace and populate it with proxies.

        :param globals_: the target module's globals dict; ``_proxy`` and
            helper names are written here.
        :param locals_: mapping that receives the generated proxy functions.
        """
        attr_names, modules = cls._setups[cls]
        modules.append((globals_, locals_))
        cls._setup_proxy(globals_, locals_, attr_names)

    @classmethod
    def _setup_proxy(cls, globals_, locals_, attr_names):
        """Process every name returned by ``dir(cls)``."""
        for methname in dir(cls):
            cls._add_proxied_attribute(methname, globals_, locals_, attr_names)

    @classmethod
    def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
        """Handle one class attribute; names starting with ``_`` are ignored.

        Callables get a generated proxy function stored in ``locals_``.
        Other values have their name recorded in ``attr_names`` so that
        ``_install_proxy`` can copy them into the module globals.
        """
        if not methname.startswith("_"):
            meth = getattr(cls, methname)
            if callable(meth):
                locals_[methname] = cls._create_method_proxy(
                    methname, globals_, locals_
                )
            else:
                attr_names.add(methname)

    @classmethod
    def _create_method_proxy(cls, name, globals_, locals_):
        """Generate a module-level function forwarding to ``_proxy.<name>``.

        The function's source is built from a template and ``exec``'d with
        ``globals_`` as its globals. It therefore resolves ``_proxy``,
        ``_name_error`` and ``_translate`` from that namespace at call time,
        and it carries the proxied method's docstring.
        """
        fn = getattr(cls, name)

        def _name_error(name, from_):
            # Called when the module's ``_proxy`` global does not exist yet.
            raise NameError(
                "Can't invoke function '%s', as the proxy object has "
                "not yet been "
                "established for the Alembic '%s' class. "
                "Try placing this code inside a callable."
                % (name, cls.__name__)
            ) from from_

        globals_["_name_error"] = _name_error

        # The proxy always accepts and forwards arbitrary arguments; the
        # proxied method performs the real signature check.
        outer_args = inner_args = "*args, **kw"

        translations = getattr(fn, "_legacy_translations", [])
        if translations:
            spec = inspect_getfullargspec(fn)
            positional = list(spec[0])
            if positional and positional[0] == "self":
                positional.pop(0)
            # Parameters without a default are the required positional ones.
            if spec[3]:
                positional = positional[: -len(spec[3])]

            # Only plain strings are embedded into the generated source.
            # The full argspec may hold defaults/annotations whose repr()
            # is not valid Python and would make exec() fail.
            translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
                fn.__name__,
                tuple(positional),
                translations,
            )

            def translate(fn_name, required_positional, translations, args, kw):
                """Rename legacy keyword arguments, warning for each one used.

                Returns ``(args, kw)``. Each required positional parameter
                not supplied by keyword consumes one positional argument;
                leftover positional arguments are appended unchanged.

                :raises TypeError: if a required positional argument is
                    missing.
                """
                return_kw = {}
                return_args = []

                for oldname, newname in translations:
                    if oldname in kw:
                        warnings.warn(
                            "Argument %r is now named %r "
                            "for method %s()." % (oldname, newname, fn_name)
                        )
                        return_kw[newname] = kw.pop(oldname)
                return_kw.update(kw)

                remaining = list(args)
                for arg in required_positional:
                    if arg not in return_kw:
                        if not remaining:
                            raise TypeError(
                                "missing required positional argument: %s"
                                % arg
                            )
                        return_args.append(remaining.pop(0))
                return_args.extend(remaining)

                return return_args, return_kw

            globals_["_translate"] = translate
        else:
            translate_str = ""

        func_text = textwrap.dedent(
            """\
        def %(name)s(%(args)s):
            %(doc)r
            %(translate)s
            try:
                p = _proxy
            except NameError as ne:
                _name_error('%(name)s', ne)
            return p.%(name)s(%(apply_kw)s)
        """
            % {
                "name": name,
                "translate": translate_str,
                "args": outer_args,
                "apply_kw": inner_args,
                "doc": fn.__doc__,
            }
        )
        lcl: Dict[str, Any] = {}

        exec(func_text, globals_, lcl)
        return lcl[name]


def _with_legacy_names(translations):
    """Decorator attaching ``(oldname, newname)`` pairs to a method.

    The pairs are stored as ``fn._legacy_translations``.
    ``ModuleClsProxy._create_method_proxy`` reads them to accept the old
    keyword names, with a warning.
    """

    def decorate(fn):
        fn._legacy_translations = translations
        return fn

    return decorate


def rev_id() -> str:
    """Return the last 12 hex characters of a random UUID4."""
    return uuid.uuid4().hex[-12:]


@overload
def to_tuple(x: Any, default: tuple) -> tuple:
    ...


@overload
def to_tuple(x: None, default: _T = None) -> _T:
    ...


@overload
def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
    ...


def to_tuple(x, default=None):
    """Coerce ``x`` to a tuple.

    ``None`` returns ``default``. A string becomes a 1-tuple, and any other
    iterable is converted with ``tuple()``. A non-iterable value is wrapped
    in a 1-tuple.
    """
    if x is None:
        return default
    elif isinstance(x, string_types):
        return (x,)
    elif isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x,)


def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
    """Return ``tup`` with duplicates removed via ``unique_list``."""
    return tuple(unique_list(tup))


class Dispatcher:
    """Registry of handler functions keyed by ``(target, qualifier)``.

    ``target`` is a class or a string. With ``uselist=True``, each key holds
    a list of functions that are all invoked. Otherwise, each key holds
    exactly one function.
    """

    def __init__(self, uselist: bool = False) -> None:
        self._registry: Dict[tuple, Any] = {}
        self.uselist = uselist

    def dispatch_for(
        self, target: Any, qualifier: str = "default"
    ) -> Callable:
        """Return a decorator registering a function for ``target``.

        In single-function mode a key may only be registered once; this is
        enforced with ``assert``.
        """

        def decorate(fn):
            if self.uselist:
                self._registry.setdefault((target, qualifier), []).append(fn)
            else:
                assert (target, qualifier) not in self._registry
                self._registry[(target, qualifier)] = fn
            return fn

        return decorate

    def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
        """Find the handler for ``obj``.

        Strings are looked up directly. Classes and instances walk the
        class MRO. At each step a non-default ``qualifier`` match is
        preferred over that class's ``"default"`` entry.

        :raises ValueError: if no handler matches.
        """
        if isinstance(obj, string_types):
            targets: Sequence = [obj]
        elif isinstance(obj, type):
            targets = obj.__mro__
        else:
            targets = type(obj).__mro__

        for spcls in targets:
            if qualifier != "default" and (spcls, qualifier) in self._registry:
                return self._fn_or_list(self._registry[(spcls, qualifier)])
            elif (spcls, "default") in self._registry:
                return self._fn_or_list(self._registry[(spcls, "default")])
        # Wrap ``obj`` in a 1-tuple so that a tuple ``obj`` does not break
        # the %-formatting and mask the intended ValueError.
        raise ValueError("no dispatch function for object: %s" % (obj,))

    def _fn_or_list(
        self, fn_or_list: Union[List[Callable], Callable]
    ) -> Callable:
        """Return a callable for a registry entry.

        In list mode, the returned function calls every registered function
        with the same arguments and returns ``None``.
        """
        if self.uselist:

            def go(*arg, **kw):
                for fn in fn_or_list:
                    fn(*arg, **kw)

            return go
        else:
            return fn_or_list  # type: ignore

    def branch(self) -> "Dispatcher":
        """Return a copy of this dispatcher that is independently
        writable."""

        # Propagate ``uselist``. Without it, a branched list-dispatcher
        # would hand back raw lists from dispatch() and reject further
        # registrations for existing keys.
        d = Dispatcher(uselist=self.uselist)
        if self.uselist:
            d._registry.update(
                (key, list(fns)) for key, fns in self._registry.items()
            )
        else:
            d._registry.update(self._registry)
        return d