1import collections
2from collections.abc import Iterable
3import textwrap
4from typing import Any
5from typing import Callable
6from typing import Dict
7from typing import List
8from typing import Optional
9from typing import overload
10from typing import Sequence
11from typing import Tuple
12from typing import TypeVar
13from typing import Union
14import uuid
15import warnings
16
17from sqlalchemy.util import asbool  # noqa
18from sqlalchemy.util import immutabledict  # noqa
19from sqlalchemy.util import memoized_property  # noqa
20from sqlalchemy.util import to_list  # noqa
21from sqlalchemy.util import unique_list  # noqa
22
23from .compat import inspect_getfullargspec
24from .compat import string_types
25
26
# Generic type variable; used by the ``to_tuple`` overload whose return type
# is the type of ``default`` (returned unchanged when ``x`` is None).
_T = TypeVar("_T")
28
29
30class _ModuleClsMeta(type):
31    def __setattr__(cls, key: str, value: Callable) -> None:
32        super(_ModuleClsMeta, cls).__setattr__(key, value)
33        cls._update_module_proxies(key)  # type: ignore
34
35
class ModuleClsProxy(metaclass=_ModuleClsMeta):
    """Create module level proxy functions for the
    methods on a given class.

    The functions will have a compatible signature
    as the methods.

    """

    # Per-class registry.  For each class this holds ``(attr_names,
    # modules)``: the set of public non-callable attribute names that
    # ``_install_proxy()`` copies into module globals, and the list of
    # ``(globals_, locals_)`` namespaces registered through
    # ``create_module_class_proxy()``.
    _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
        lambda: (set(), [])
    )

    @classmethod
    def _update_module_proxies(cls, name: str) -> None:
        """(Re)publish attribute ``name`` into every registered namespace.

        Called from :class:`_ModuleClsMeta` whenever an attribute is set on
        the class.
        """
        attr_names, modules = cls._setups[cls]
        for globals_, locals_ in modules:
            cls._add_proxied_attribute(name, globals_, locals_, attr_names)

    def _install_proxy(self) -> None:
        """Make ``self`` the object that generated proxy functions call.

        Sets ``_proxy`` in each registered module's globals and copies the
        current value of every recorded non-callable attribute there too.
        """
        attr_names, modules = self._setups[self.__class__]
        for globals_, locals_ in modules:
            globals_["_proxy"] = self
            for attr_name in attr_names:
                globals_[attr_name] = getattr(self, attr_name)

    def _remove_proxy(self) -> None:
        """Undo :meth:`_install_proxy` in every registered module."""
        attr_names, modules = self._setups[self.__class__]
        for globals_, locals_ in modules:
            globals_["_proxy"] = None
            for attr_name in attr_names:
                # The name may never have been installed, e.g. when it was
                # added to the class after _install_proxy() ran; don't fail.
                globals_.pop(attr_name, None)

    @classmethod
    def create_module_class_proxy(
        cls, globals_: Dict[str, Any], locals_: Dict[str, Any]
    ) -> None:
        """Register a module namespace and publish proxies for ``cls`` in it.

        :param globals_: the module's global namespace; receives ``_proxy``
            and helper names used by the generated functions.
        :param locals_: namespace that receives the generated proxy
            functions.
        """
        attr_names, modules = cls._setups[cls]
        modules.append((globals_, locals_))
        cls._setup_proxy(globals_, locals_, attr_names)

    @classmethod
    def _setup_proxy(cls, globals_, locals_, attr_names) -> None:
        """Publish every attribute listed by ``dir(cls)``."""
        for methname in dir(cls):
            cls._add_proxied_attribute(methname, globals_, locals_, attr_names)

    @classmethod
    def _add_proxied_attribute(
        cls, methname, globals_, locals_, attr_names
    ) -> None:
        """Publish a single attribute of ``cls``.

        Names starting with an underscore are skipped.  Callables get a
        generated forwarding function stored in ``locals_``; any other
        attribute is only recorded in ``attr_names`` so that
        ``_install_proxy()`` copies its instance value later.
        """
        if methname.startswith("_"):
            return
        meth = getattr(cls, methname)
        if callable(meth):
            locals_[methname] = cls._create_method_proxy(
                methname, globals_, locals_
            )
        else:
            attr_names.add(methname)

    @classmethod
    def _create_method_proxy(cls, name, globals_, locals_):
        """Build a module-level function forwarding to ``_proxy.<name>``.

        The function is compiled with ``exec()`` against ``globals_``, so it
        looks up ``_proxy`` in that namespace at call time and raises a
        ``NameError`` with a helpful message if ``_proxy`` is not defined.
        If the method has ``_legacy_translations``, legacy keyword argument
        names are renamed (with a warning) before forwarding.
        """
        fn = getattr(cls, name)

        def _name_error(name, from_):
            raise NameError(
                "Can't invoke function '%s', as the proxy object has "
                "not yet been "
                "established for the Alembic '%s' class.  "
                "Try placing this code inside a callable."
                % (name, cls.__name__)
            ) from from_

        globals_["_name_error"] = _name_error

        outer_args = inner_args = "*args, **kw"
        translate_str = ""

        translations = getattr(fn, "_legacy_translations", [])
        if translations:
            spec = inspect_getfullargspec(fn)
            arg_names = list(spec[0])
            if arg_names and arg_names[0] == "self":
                arg_names.pop(0)
            defaults = spec[3]
            # Parameters without a default; each must be filled from a
            # positional argument unless it was supplied by keyword.
            if defaults:
                required = arg_names[: -len(defaults)]
            else:
                required = arg_names

            # Only plain strings are embedded into the generated source: the
            # full argspec (defaults, annotations) can hold objects whose
            # repr() is not valid Python.
            translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
                fn.__name__,
                tuple(required),
                translations,
            )

            def translate(fn_name, required, translations, args, kw):
                return_kw = {}
                for oldname, newname in translations:
                    if oldname in kw:
                        warnings.warn(
                            "Argument %r is now named %r "
                            "for method %s()." % (oldname, newname, fn_name)
                        )
                        return_kw[newname] = kw.pop(oldname)
                return_kw.update(kw)

                return_args = []
                remaining = iter(args)
                for arg in required:
                    if arg not in return_kw:
                        try:
                            return_args.append(next(remaining))
                        except StopIteration:
                            raise TypeError(
                                "missing required positional argument: %s"
                                % arg
                            ) from None
                return_args.extend(remaining)

                return return_args, return_kw

            globals_["_translate"] = translate

        func_text = textwrap.dedent(
            """\
        def %(name)s(%(args)s):
            %(doc)r
            %(translate)s
            try:
                p = _proxy
            except NameError as ne:
                _name_error('%(name)s', ne)
            return p.%(name)s(%(apply_kw)s)
        """
            % {
                "name": name,
                "translate": translate_str,
                "args": outer_args,
                "apply_kw": inner_args,
                "doc": fn.__doc__,
            }
        )
        lcl: Dict[str, Any] = {}

        exec(func_text, globals_, lcl)
        return lcl[name]
180
181
182def _with_legacy_names(translations):
183    def decorate(fn):
184        fn._legacy_translations = translations
185        return fn
186
187    return decorate
188
189
def rev_id() -> str:
    """Return the last 12 hex digits of a freshly generated random UUID."""
    digits = uuid.uuid4().hex
    return digits[len(digits) - 12 :]
192
193
@overload
def to_tuple(x: Any, default: tuple) -> tuple:
    ...


@overload
def to_tuple(x: None, default: _T = None) -> _T:
    ...


@overload
def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
    ...


def to_tuple(x, default=None):
    """Coerce ``x`` into a tuple.

    ``None`` yields ``default`` unchanged.  Strings (``string_types``) and
    non-iterable values are wrapped as a one-element tuple; any other
    iterable is converted with ``tuple()``.
    """
    if x is None:
        return default
    # Strings are iterable but are treated here as a single value.
    if not isinstance(x, string_types) and isinstance(x, Iterable):
        return tuple(x)
    return (x,)
218
219
def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
    """Return ``tup`` without duplicates, keeping each first occurrence in
    its original position order."""
    # dict keys are unique and preserve insertion order
    return tuple(dict.fromkeys(tup))
222
223
class Dispatcher:
    """Registry of handler functions keyed by ``(target, qualifier)``.

    Handlers are registered with :meth:`dispatch_for` and looked up with
    :meth:`dispatch`.  With ``uselist=True`` each key holds a list of
    handlers, and :meth:`dispatch` returns a callable that invokes all of
    them in registration order.
    """

    def __init__(self, uselist: bool = False) -> None:
        # (target, qualifier) -> handler, or -> list of handlers when
        # ``uselist`` is true
        self._registry: Dict[tuple, Any] = {}
        self.uselist = uselist

    def dispatch_for(
        self, target: Any, qualifier: str = "default"
    ) -> Callable:
        """Return a decorator that registers a handler for ``target``.

        In single-handler mode registering the same ``(target, qualifier)``
        twice fails an assertion.  The decorated function is returned
        unchanged.
        """

        def decorate(fn):
            key = (target, qualifier)
            if self.uselist:
                self._registry.setdefault(key, []).append(fn)
            else:
                assert key not in self._registry, (
                    "duplicate dispatch registration for %r" % (key,)
                )
                self._registry[key] = fn
            return fn

        return decorate

    def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
        """Return the handler (or combined handler list) for ``obj``.

        A string is matched as-is; a class is matched along its own MRO and
        any other object along the MRO of its type.  At each MRO step a
        handler registered for ``qualifier`` wins over the ``"default"`` one.

        :raises ValueError: if no registered handler matches.
        """
        if isinstance(obj, string_types):
            targets: Sequence = [obj]
        elif isinstance(obj, type):
            targets = obj.__mro__
        else:
            targets = type(obj).__mro__

        for spcls in targets:
            if qualifier != "default" and (spcls, qualifier) in self._registry:
                return self._fn_or_list(self._registry[(spcls, qualifier)])
            elif (spcls, "default") in self._registry:
                return self._fn_or_list(self._registry[(spcls, "default")])
        # wrap in a 1-tuple so a tuple ``obj`` doesn't break %-formatting
        raise ValueError("no dispatch function for object: %s" % (obj,))

    def _fn_or_list(
        self, fn_or_list: Union[List[Callable], Callable]
    ) -> Callable:
        """Return a single callable for a registry entry.

        In list mode the returned function calls every handler in the list
        with the same arguments and returns None.
        """
        if self.uselist:

            def go(*arg, **kw):
                for fn in fn_or_list:
                    fn(*arg, **kw)

            return go
        else:
            return fn_or_list  # type: ignore

    def branch(self) -> "Dispatcher":
        """Return a copy of this dispatcher that is independently
        writable."""
        # Preserve the mode; a list-mode registry is meaningless otherwise.
        d = Dispatcher(uselist=self.uselist)
        if self.uselist:
            # copy each handler list so appends on the branch don't leak back
            d._registry.update(
                (key, list(fns)) for key, fns in self._registry.items()
            )
        else:
            d._registry.update(self._registry)
        return d
284