1# sql/lambdas.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8import itertools
9import operator
10import sys
11import types
12import weakref
13
14from . import coercions
15from . import elements
16from . import roles
17from . import schema
18from . import traversals
19from . import type_api
20from . import visitors
21from .base import _clone
22from .base import Options
23from .operators import ColumnOperators
24from .. import exc
25from .. import inspection
26from .. import util
27from ..util import collections_abc
28from ..util import compat
29
# Default cache consulted by LambdaElement._retrieve_tracker_rec() when
# LambdaOptions.lambda_cache is None.  Keys are a lambda's tracker key
# (its code object(s)) plus its closure cache key; values are the
# analyzed-function records.  Created as util.LRUCache with 1000 as its
# size argument.
_closure_per_cache_key = util.LRUCache(1000)
31
32
class LambdaOptions(Options):
    """Options governing how a lambda is scanned and cached.

    Each attribute's default mirrors the keyword argument of the same
    name accepted by :func:`.lambda_stmt`; see that function for details.

    """

    # when False, the lambda's globals and closure are not scanned
    enable_tracking = True
    # when False, non-literal closure variables are not part of the cache
    # key (closure tracking is also turned off whenever track_on is given)
    track_closure_variables = True
    # optional explicit sequence of objects from which to derive the key
    track_on = None
    # statement-wide switch for bound value tracking
    global_track_bound_values = True
    # per-lambda switch for bound value tracking
    track_bound_values = True
    # mapping used to store analyzed lambdas; None selects the module
    # level default cache
    lambda_cache = None
40
41
def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary.   The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned.   Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: an optional sequence of objects to track in place of
     the lambda's closure variables; when given, closure variables are no
     longer scanned to build the cache key.  Each element contributes to
     the cache key as follows: a :class:`.HasCacheKey` object contributes
     its own cache key, a tuple contributes the cache keys of its members
     (which must all be :class:`.HasCacheKey` objects), and any other
     object is used as-is.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda.  Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored.   Defaults
     to a global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    return StatementLambdaElement(
        lmb,
        roles.StatementRole,
        LambdaOptions(
            enable_tracking=enable_tracking,
            track_on=track_on,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=global_track_bound_values,
            track_bound_values=track_bound_values,
            lambda_cache=lambda_cache,
        ),
    )
113
114
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    Attribute access not satisfied by the element itself is forwarded to
    the expression the lambda produced (``self._rec.expected_expr``).

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # callables applied, in order, to an expression produced by
    # DeferredLambdaElement._resolve_with_args(); accumulated by
    # DeferredLambdaElement._copy_internals()
    _transforms = ()

    # the preceding element when this lambda is a link in a chain of
    # lambdas (see LinkedLambdaElement); None otherwise
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        """Analyze ``fn`` (or fetch a cached analysis) for this element.

        :param fn: the Python callable, typically a lambda.
        :param role: the role the lambda's result is expected to play,
         e.g. :class:`.roles.StatementRole`.
        :param opts: a :class:`.LambdaOptions` class or instance.
        :param apply_propagate_attrs: object that receives the
         ``_propagate_attrs`` recorded for the lambda, if any; defaults to
         ``self`` when ``role`` is :class:`.roles.StatementRole`.

        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build the analysis record for this lambda.

        A cache key is computed from the current values of the lambda's
        tracked closure state (prefixed with the parent lambda's key, if
        any) and looked up in the lambda cache together with
        ``self.tracker_key``.  On a miss the lambda is analyzed and the
        record stored; if the state is uncacheable the lambda is simply
        invoked.  The bound parameters for this invocation are collected
        into ``self._resolved_bindparams``.

        Sets ``self.closure_cache_key`` and ``self._rec`` and returns the
        record belonging to this element.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                cache_key = traversals.NO_CACHE
                rec = None

        else:
            cache_key = traversals.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not traversals.NO_CACHE:
                rec = AnalyzedFunction(
                    tracker, self, apply_propagate_attrs, fn
                )
                rec.closure_bindparams = bindparams
                lambda_cache[tracker_key + cache_key] = rec
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: copy the bind parameters the cached record was
            # built with, substituting the values just extracted from the
            # current closure while keeping the original keys
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and its parents, letting each one's
            # bindparam trackers extract current literal values; use
            # distinct names so that ``rec`` still refers to this
            # element's record when returned below
            lambda_element = self
            while lambda_element is not None:
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn
                    )
                    for bind_tracker in element_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # Only reached for attributes not found the normal way; forward to
        # the expression produced by the lambda.  If ``_rec`` itself is
        # missing (e.g. an instance created via __new__ without __init__,
        # as the copy and pickle protocols do), fail with AttributeError
        # rather than recursing back into this method until RecursionError.
        if key == "_rec":
            raise AttributeError(key)
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return a dict of bound parameter key -> current value."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Return ``expr`` with each :class:`.BindParameter` whose key
        matches one of ``self._resolved_bindparams`` replaced by that
        resolved parameter; "expanding" settings of the replaced
        parameter are copied onto the replacement.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        """The expression produced by the lambda, with the current bound
        parameter values applied."""
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Return a cache key made of the code objects and closure cache
        keys of this lambda and all of its parents; marks the key as
        uncacheable in ``anon_map`` when the closure state is not
        cacheable."""
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn, *arg):
        return fn()
354
355
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda takes arguments and is
    invoked during the compile phase with special context.

    Outside of compilation such a lambda does not normally yield its real
    SQL expression; instead it is called with a fixed set of initial
    arguments (``lambda_args``) so that a sample expression can be
    generated for analysis.

    """

    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
        # Must be in place before the base constructor runs: the base
        # constructor may call _invoke_user_fn(), which reads lambda_args.
        self.lambda_args = lambda_args
        super(DeferredLambdaElement, self).__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        # any extra positional args are ignored; the sample args are used
        sample_args = self.lambda_args
        return fn(*sample_args)

    def _resolve_with_args(self, *lambda_args):
        """Produce the real expression by calling the instrumented lambda
        with ``lambda_args``, then applying role coercion, the currently
        resolved bound parameters, and any accumulated transforms.

        """
        instrumented_fn = self._rec.tracker_instrumented_fn
        raw_expr = instrumented_fn(*lambda_args)
        resolved = coercions.expect(self.role, raw_expr)
        resolved = self._setup_binds_for_tracked_expr(resolved)

        # NOTE: validating that this call produced the same set of required
        # bound parameters (compared by ``_orig_key``) as the original
        # sample run gets close to, but does not achieve, #5767: when the
        # base lambda uses an unnamed column -- very common with mixins and
        # exactly the documented use case -- the parameter name differs
        # and the check reports a false positive, so it is not enabled.

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            resolved = transform(resolved)

        return resolved

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # The remaining keyword arguments are passed on nested under a
        # single ``opts`` key rather than splatted; kept as-is.
        # NOTE(review): confirm this forwarding is intended.
        super(DeferredLambdaElement, self)._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # The real expression isn't known until _resolve_with_args(), so
        # remember the replacement and apply it at that point.
        # TODO: A LOT A LOT of tests.
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)
421
422
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(
        self,
        other,
        enable_tracking=True,
        track_on=None,
        track_closure_variables=True,
        track_bound_values=True,
    ):
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given lambda receives the statement produced so far as its
        single argument and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` is always inherited from
        this statement.

        """

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _with_options(self):
        return self._rec.expected_expr._with_options

    @property
    def _effective_plugin_target(self):
        return self._rec.expected_expr._effective_plugin_target

    @property
    def _execution_options(self):
        return self._rec.expected_expr._execution_options

    def spoil(self):
        """Return a :class:`.NullLambdaStatement` equivalent to this
        statement, which invokes all lambdas immediately rather than
        caching or analyzing them.

        """
        if self.parent_lambda is not None:
            # A linked lambda takes the preceding statement as its single
            # argument and cannot be called with no arguments; spoil the
            # chain up to this point, then apply this link to the result.
            return self.parent_lambda.spoil().add_criteria(self.fn)
        return NullLambdaStatement(self.fn())
532
533
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Provides the :class:`.StatementLambdaElement` API but does not
    cache or analyze lambdas.

    The lambdas are instead invoked immediately.

    The intended use is to isolate issues that may arise when using
    lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # Forward unknown attributes to the wrapped statement.  If
        # ``_resolved`` itself is missing (an instance created without
        # __init__), fail with AttributeError instead of recursing back
        # into this method until RecursionError.
        if key == "_resolved":
            raise AttributeError(key)
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        """Call ``other`` with the current statement right away and wrap
        the statement it returns in a new :class:`.NullLambdaStatement`.

        ``**kw`` is accepted for signature compatibility with
        :meth:`.StatementLambdaElement.add_criteria` and is ignored.

        """
        statement = other(self._resolved)

        return NullLambdaStatement(statement)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._resolved.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)
579
580
class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`.

    Each link holds a lambda that receives the statement produced by
    ``parent_lambda`` and returns a new statement.

    """

    role = None

    def __init__(self, fn, parent_lambda, opts):
        # LambdaElement.__init__ is not invoked here; the attributes that
        # _retrieve_tracker_rec() reads are assigned directly instead and
        # must be present before it runs.
        self.fn = fn
        self.parent_lambda = parent_lambda
        self.opts = opts
        self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)

        # the returned record is not needed; the call stores it on _rec
        self._retrieve_tracker_rec(fn, self, opts)

        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
597
598
599class AnalyzedCode(object):
    # per-instance analysis results; populated by __init__
    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # analyses keyed on the lambda's code object (see get()); weak keys,
    # so an entry disappears once its code object is garbage collected
    _fns = weakref.WeakKeyDictionary()
608
609    @classmethod
610    def get(cls, fn, lambda_element, lambda_kw, **kw):
611        try:
612            # TODO: validate kw haven't changed?
613            return cls._fns[fn.__code__]
614        except KeyError:
615            pass
616        cls._fns[fn.__code__] = analyzed = AnalyzedCode(
617            fn, lambda_element, lambda_kw, **kw
618        )
619        return analyzed
620
    def __init__(self, fn, lambda_element, opts):
        """Scan ``fn`` and build the trackers used to derive its cache key
        and its bound parameter values.

        :param fn: the lambda being analyzed.
        :param lambda_element: the :class:`.LambdaElement` owning ``fn``;
         passed on to the trial run performed by
         ``_setup_additional_closure_trackers()``.
        :param opts: the :class:`.LambdaOptions` in effect.

        The order in which trackers are appended below (track_on entries,
        then closure variables, then literals found not to be bound
        parameters) fixes the layout of the generated cache key.

        """
        closure = fn.__closure__

        # bound values are tracked only when enabled both for this lambda
        # and for the statement as a whole
        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on list replaces closure variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        # (name, closure index) pairs for literal-valued globals and
        # closure variables; a closure index of None denotes a global
        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        # runs regardless of enable_tracking
        self._setup_additional_closure_trackers(fn, lambda_element, opts)
655
656    def _init_track_on(self, track_on):
657        self.closure_trackers.extend(
658            self._cache_key_getter_track_on(idx, elem)
659            for idx, elem in enumerate(track_on)
660        )
661
662    def _init_globals(self, fn):
663        build_py_wrappers = self.build_py_wrappers
664        bindparam_trackers = self.bindparam_trackers
665        track_bound_values = self.track_bound_values
666
667        for name in fn.__code__.co_names:
668            if name not in fn.__globals__:
669                continue
670
671            _bound_value = self._roll_down_to_literal(fn.__globals__[name])
672
673            if coercions._deep_is_literal(_bound_value):
674                build_py_wrappers.append((name, None))
675                if track_bound_values:
676                    bindparam_trackers.append(
677                        self._bound_parameter_getter_func_globals(name)
678                    )
679
680    def _init_closure(self, fn):
681        build_py_wrappers = self.build_py_wrappers
682        closure = fn.__closure__
683
684        track_bound_values = self.track_bound_values
685        track_closure_variables = self.track_closure_variables
686        bindparam_trackers = self.bindparam_trackers
687        closure_trackers = self.closure_trackers
688
689        for closure_index, (fv, cell) in enumerate(
690            zip(fn.__code__.co_freevars, closure)
691        ):
692            _bound_value = self._roll_down_to_literal(cell.cell_contents)
693
694            if coercions._deep_is_literal(_bound_value):
695                build_py_wrappers.append((fv, closure_index))
696                if track_bound_values:
697                    bindparam_trackers.append(
698                        self._bound_parameter_getter_func_closure(
699                            fv, closure_index
700                        )
701                    )
702            else:
703                # for normal cell contents, add them to a list that
704                # we can compare later when we get new lambdas.  if
705                # any identities have changed, then we will
706                # recalculate the whole lambda and run it again.
707
708                if track_closure_variables:
709                    closure_trackers.append(
710                        self._cache_key_getter_closure_variable(
711                            fn, fv, closure_index, cell.cell_contents
712                        )
713                    )
714
715    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
716        # an additional step is to actually run the function, then
717        # go through the PyWrapper objects that were set up to catch a bound
718        # parameter.   then if they *didn't* make a param, oh they're another
719        # object in the closure we have to track for our cache key.  so
720        # create trackers to catch those.
721
722        analyzed_function = AnalyzedFunction(
723            self,
724            lambda_element,
725            None,
726            fn,
727        )
728
729        closure_trackers = self.closure_trackers
730
731        for pywrapper in analyzed_function.closure_pywrappers:
732            if not pywrapper._sa__has_param:
733                closure_trackers.append(
734                    self._cache_key_getter_tracked_literal(fn, pywrapper)
735                )
736
737    @classmethod
738    def _roll_down_to_literal(cls, element):
739        is_clause_element = hasattr(element, "__clause_element__")
740
741        if is_clause_element:
742            while not isinstance(
743                element, (elements.ClauseElement, schema.SchemaItem, type)
744            ):
745                try:
746                    element = element.__clause_element__()
747                except AttributeError:
748                    break
749
750        if not is_clause_element:
751            insp = inspection.inspect(element, raiseerr=False)
752            if insp is not None:
753                try:
754                    return insp.__clause_element__()
755                except AttributeError:
756                    return insp
757
758            # TODO: should we coerce consts None/True/False here?
759            return element
760        else:
761            return element
762
763    def _bound_parameter_getter_func_globals(self, name):
764        """Return a getter that will extend a list of bound parameters
765        with new entries from the ``__globals__`` collection of a particular
766        lambda.
767
768        """
769
770        def extract_parameter_value(
771            current_fn, tracker_instrumented_fn, result
772        ):
773            wrapper = tracker_instrumented_fn.__globals__[name]
774            object.__getattribute__(wrapper, "_extract_bound_parameters")(
775                current_fn.__globals__[name], result
776            )
777
778        return extract_parameter_value
779
780    def _bound_parameter_getter_func_closure(self, name, closure_index):
781        """Return a getter that will extend a list of bound parameters
782        with new entries from the ``__closure__`` collection of a particular
783        lambda.
784
785        """
786
787        def extract_parameter_value(
788            current_fn, tracker_instrumented_fn, result
789        ):
790            wrapper = tracker_instrumented_fn.__closure__[
791                closure_index
792            ].cell_contents
793            object.__getattribute__(wrapper, "_extract_bound_parameters")(
794                current_fn.__closure__[closure_index].cell_contents, result
795            )
796
797        return extract_parameter_value
798
799    def _cache_key_getter_track_on(self, idx, elem):
800        """Return a getter that will extend a cache key with new entries
801        from the "track_on" parameter passed to a :class:`.LambdaElement`.
802
803        """
804
805        if isinstance(elem, tuple):
806            # tuple must contain hascachekey elements
807            def get(closure, opts, anon_map, bindparams):
808                return tuple(
809                    tup_elem._gen_cache_key(anon_map, bindparams)
810                    for tup_elem in opts.track_on[idx]
811                )
812
813        elif isinstance(elem, traversals.HasCacheKey):
814
815            def get(closure, opts, anon_map, bindparams):
816                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
817
818        else:
819
820            def get(closure, opts, anon_map, bindparams):
821                return opts.track_on[idx]
822
823        return get
824
825    def _cache_key_getter_closure_variable(
826        self,
827        fn,
828        variable_name,
829        idx,
830        cell_contents,
831        use_clause_element=False,
832        use_inspect=False,
833    ):
834        """Return a getter that will extend a cache key with new entries
835        from the ``__closure__`` collection of a particular lambda.
836
837        """
838
839        if isinstance(cell_contents, traversals.HasCacheKey):
840
841            def get(closure, opts, anon_map, bindparams):
842
843                obj = closure[idx].cell_contents
844                if use_inspect:
845                    obj = inspection.inspect(obj)
846                elif use_clause_element:
847                    while hasattr(obj, "__clause_element__"):
848                        if not getattr(obj, "is_clause_element", False):
849                            obj = obj.__clause_element__()
850
851                return obj._gen_cache_key(anon_map, bindparams)
852
853        elif isinstance(cell_contents, types.FunctionType):
854
855            def get(closure, opts, anon_map, bindparams):
856                return closure[idx].cell_contents.__code__
857
858        elif isinstance(cell_contents, collections_abc.Sequence):
859
860            def get(closure, opts, anon_map, bindparams):
861                contents = closure[idx].cell_contents
862
863                try:
864                    return tuple(
865                        elem._gen_cache_key(anon_map, bindparams)
866                        for elem in contents
867                    )
868                except AttributeError as ae:
869                    self._raise_for_uncacheable_closure_variable(
870                        variable_name, fn, from_=ae
871                    )
872
873        else:
874            # if the object is a mapped class or aliased class, or some
875            # other object in the ORM realm of things like that, imitate
876            # the logic used in coercions.expect() to roll it down to the
877            # SQL element
878            element = cell_contents
879            is_clause_element = False
880            while hasattr(element, "__clause_element__"):
881                is_clause_element = True
882                if not getattr(element, "is_clause_element", False):
883                    element = element.__clause_element__()
884                else:
885                    break
886
887            if not is_clause_element:
888                insp = inspection.inspect(element, raiseerr=False)
889                if insp is not None:
890                    return self._cache_key_getter_closure_variable(
891                        fn, variable_name, idx, insp, use_inspect=True
892                    )
893            else:
894                return self._cache_key_getter_closure_variable(
895                    fn, variable_name, idx, element, use_clause_element=True
896                )
897
898            self._raise_for_uncacheable_closure_variable(variable_name, fn)
899
900        return get
901
902    def _raise_for_uncacheable_closure_variable(
903        self, variable_name, fn, from_=None
904    ):
905        util.raise_(
906            exc.InvalidRequestError(
907                "Closure variable named '%s' inside of lambda callable %s "
908                "does not refer to a cacheable SQL element, and also does not "
909                "appear to be serving as a SQL literal bound value based on "
910                "the default "
911                "SQL expression returned by the function.   This variable "
912                "needs to remain outside the scope of a SQL-generating lambda "
913                "so that a proper cache key may be generated from the "
914                "lambda's state.  Evaluate this variable outside of the "
915                "lambda, set track_on=[<elements>] to explicitly select "
916                "closure elements to track, or set "
917                "track_closure_variables=False to exclude "
918                "closure variables from being part of the cache key."
919                % (variable_name, fn.__code__),
920            ),
921            from_=from_,
922        )
923
924    def _cache_key_getter_tracked_literal(self, fn, pytracker):
925        """Return a getter that will extend a cache key with new entries
926        from the ``__closure__`` collection of a particular lambda.
927
928        this getter differs from _cache_key_getter_closure_variable
929        in that these are detected after the function is run, and PyWrapper
930        objects have recorded that a particular literal value is in fact
931        not being interpreted as a bound parameter.
932
933        """
934
935        elem = pytracker._sa__to_evaluate
936        closure_index = pytracker._sa__closure_index
937        variable_name = pytracker._sa__name
938
939        return self._cache_key_getter_closure_variable(
940            fn, variable_name, closure_index, elem
941        )
942
943
class NonAnalyzedFunction(object):
    """Tracker record for a lambda whose code was not analyzed.

    It simply holds the expression it was given; ``expected_expr`` hands
    that expression back unchanged.  There are no bindparam trackers and
    no closure bind parameters, so both report ``None``.

    """

    __slots__ = ("expr",)

    # class-level defaults; not in __slots__, so read-only per instance
    closure_bindparams = bindparam_trackers = None

    def __init__(self, expr):
        self.expr = expr

    # read-only alias for ``expr``; no coercion is applied here
    expected_expr = property(operator.attrgetter("expr"))
956
957
class AnalyzedFunction(object):
    """Per-callable tracker record for a lambda whose code was analyzed.

    Built from an ``AnalyzedCode`` object and the user's callable ``fn``.
    Construction invokes the callable once through
    ``lambda_element._invoke_user_fn()``.  When
    ``analyzed_code.build_py_wrappers`` is non-empty, the callable is
    first instrumented with :class:`.PyWrapper` objects.  The raw result
    is stored in ``expr`` and its coerced form in ``expected_expr``.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """
        :param analyzed_code: analysis of ``fn``'s code; supplies
         ``bindparam_trackers``, ``build_py_wrappers``,
         ``track_closure_variables`` and ``track_bound_values``.
        :param lambda_element: the lambda element being built; supplies
         ``_invoke_user_fn()``, ``parent_lambda`` and ``role``.
        :param apply_propagate_attrs: passed to ``coercions.expect()``;
         when not None its ``_propagate_attrs`` becomes
         ``self.propagate_attrs``.
        :param fn: the user's Python callable.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        # NonAnalyzedFunction answers ``closure_bindparams`` with None;
        # give the slot the same value so reading it on either kind of
        # tracker never raises AttributeError from an unset __slots__ entry.
        self.closure_bindparams = None

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user callable once, instrumented where needed.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr``.

        When ``analyzed_code.build_py_wrappers`` is empty, the original
        callable is invoked unchanged.  Otherwise it contains
        ``(name, closure_index)`` pairs.  Each named value is replaced by
        a :class:`.PyWrapper`:

        * in a copy of the closure, when ``closure_index`` is not None;
        * otherwise in a copy of ``fn.__globals__``.

        A rebuilt copy of the callable is then invoked.

        """
        analyzed_code = self.analyzed_code
        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to instrument; run the user's callable as-is
            self.tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(fn)
            return

        track_closure_variables = analyzed_code.track_closure_variables
        closure = fn.__closure__

        # will form the __closure__ of the function when we rebuild it,
        # keyed by free-variable name
        if closure:
            new_closure = {
                fv: cell.cell_contents
                for fv, cell in zip(fn.__code__.co_freevars, closure)
            }
        else:
            new_closure = {}

        # will form the __globals__ of the function when we rebuild it
        new_globals = fn.__globals__.copy()

        for name, closure_index in build_py_wrappers:
            if closure_index is not None:
                value = closure[closure_index].cell_contents
                new_closure[name] = bind = PyWrapper(
                    fn,
                    name,
                    value,
                    closure_index=closure_index,
                    track_bound_values=analyzed_code.track_bound_values,
                )
                if track_closure_variables:
                    closure_pywrappers.append(bind)
            else:
                value = fn.__globals__[name]
                new_globals[name] = PyWrapper(fn, name, value)

        # rewrite the original fn.   things that look like they will
        # become bound parameters are wrapped in a PyWrapper.
        tracker_instrumented_fn = self._rewrite_code_obj(
            fn,
            [new_closure[name] for name in fn.__code__.co_freevars],
            new_globals,
        )
        self.tracker_instrumented_fn = tracker_instrumented_fn

        # now invoke the function.  This will give us a new SQL
        # expression, but all the places that there would be a bound
        # parameter, the PyWrapper in its place will give us a bind
        # with a predictable name we can match up later.

        # additionally, each PyWrapper will log that it did in fact
        # create a parameter, otherwise, it's some kind of Python
        # object in the closure and we want to track that, to make
        # sure it doesn't change to something else, or if it does,
        # that we create a different tracked function with that
        # variable.
        self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Sets ``expected_expr``, ``is_sequence`` and ``propagate_attrs``.
        Coercion only happens when ``lambda_element.parent_lambda`` is None.
        A sequence result is coerced element-by-element into a list.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                self.expected_expr = [
                    coercions.expect(
                        lambda_element.role,
                        sub_expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = coercions.expect(
                    lambda_element.role,
                    expr,
                    apply_propagate_attrs=apply_propagate_attrs,
                )
                self.is_sequence = False
        else:
            # a lambda chained onto a parent is stored without coercion
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of ``f`` with a new closure and new globals.

        :param f: the function to copy.  Its ``__code__``, ``__name__``
         and ``__defaults__`` are reused.  On Python 3 its annotations,
         kw-defaults and ``__qualname__`` are also carried over, and its
         docstring and module are carried over everywhere.
        :param cell_values: one value per name in ``f.__code__.co_freevars``,
         in that same order; each becomes the contents of a new cell.
        :param globals_: the dictionary used as the copy's ``__globals__``.
        :return: the new function object.

        """

        def make_cell(value):
            # the lambda's single free variable is ``value``, so its
            # __closure__ is a 1-tuple holding a fresh cell for it
            return (lambda: value).__closure__[0]

        # Cells are built one at a time so that position k of the tuple is
        # exactly cell_values[k].  A single helper function closing over
        # several generated names would not do: CPython orders a function's
        # free variables (and so its __closure__) by *sorted name*, which
        # diverges from numeric order once there are more than ten names
        # ("i10" sorts before "i2").
        if cell_values:
            closure = tuple(make_cell(value) for value in cell_values)
        else:
            # a function with no free variables takes closure=None
            closure = None

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        if sys.version_info >= (3,):
            func.__annotations__ = f.__annotations__
            func.__kwdefaults__ = f.__kwdefaults__
            func.__qualname__ = f.__qualname__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
1127
1128
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers.  We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, i.e. the expression system interpreted
    something as a literal.   Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object.   In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed.   The expression is then transformed to have the
    new bound values embedded into it.

    Implementation note: ``__getattribute__`` is overridden so that ordinary
    attribute reads on a wrapper are routed to the wrapped value (dunder
    names) or re-wrapped through ``_add_getter`` (all other names).  The
    wrapper's own state is therefore read either with
    ``object.__getattribute__(self, name)`` or through the ``_sa_`` prefix,
    which strips the first four characters and reads the remainder
    directly: ``self._sa_fn`` reads ``fn``, ``self._sa__name`` reads
    ``_name``.  Attribute assignment is not overridden in this class;
    ``__init__`` and ``__clause_element__`` store state with plain
    ``self.x = ...``.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        # the user's original callable; its __code__ appears in error
        # messages and it is handed down to nested wrappers
        self.fn = fn
        # variable name (or attribute / item key for nested wrappers);
        # also used as the BindParameter name in __clause_element__
        self._name = name
        # the real value being stood in for
        self._to_evaluate = to_evaluate
        # BindParameter created lazily by __clause_element__
        self._param = None
        self._has_param = False
        # nested wrappers created by _add_getter, keyed by
        # (key, operator.attrgetter / operator.itemgetter)
        self._bind_paths = {}
        # attrgetter/itemgetter that re-derives this value from the parent
        # wrapper's value; None for top-level wrappers
        self._getter = getter
        # position within fn.__closure__, or None (e.g. for globals)
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped callable and return its result.

        If ``track_bound_values`` is set and the result is something
        ``coercions._deep_is_literal()`` accepts (and is not a
        ``HasCacheKey``), raise :class:`.InvalidRequestError` instead: per
        the message below, the lambda system extracts bound values later
        without invoking the lambda or the functions it calls.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                traversals.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it.  Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply an operator with this wrapper's bind parameter on the left."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply an operator with this wrapper's bind parameter on the
        right."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bind parameters re-valued from ``starting_point``.

        ``starting_point`` is a new value for the variable this wrapper
        stands in for.  If a parameter was created, a copy carrying
        ``starting_point`` (same key) is appended to ``result_list``.
        Each nested wrapper then derives its own value from
        ``starting_point`` through its getter and recurses.
        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def __clause_element__(self):
        """Return a BindParameter carrying the current wrapped value.

        On first use a BindParameter named after this wrapper is created
        with ``unique=True``, its type resolved from the wrapped value, and
        ``_has_param`` is set.  Every call returns a ``_with_value()`` copy
        that keeps the original key.
        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name, required=False, unique=True
            )
            self._has_param = True
            param.type = type_api._resolve_value_to_type(to_evaluate)
        return param._with_value(to_evaluate, maintain_key=True)

    # Implicit special-method lookups (bool(w), iter(w), w[k]) go through
    # the type rather than __getattribute__, which is why these dunders are
    # defined explicitly instead of relying on the forwarding below.

    def __bool__(self):
        # truthiness of the wrapped value (Python 3)
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __nonzero__(self):
        # truthiness of the wrapped value (Python 2 spelling)
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute reads.

        * ``_sa_<name>``: read ``<name>`` from the wrapper itself.
        * the operator/clause hooks, ``__class__`` and ``__dict__``: read
          from the wrapper itself.
        * any other name starting with ``__``: read from the wrapped value.
        * everything else: ``_add_getter`` with ``operator.attrgetter``,
          which may return a nested wrapper.
        """
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        # iterates the raw wrapped value; items are not wrapped
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Subscript the wrapped value through ``_add_getter``.

        Raises ``AttributeError`` (not ``TypeError``) when the wrapped
        value has no ``__getitem__``, and ``InvalidRequestError`` when the
        key is itself a :class:`.PyWrapper`.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return ``getter_fn(key)`` applied to the wrapped value.

        If the (rolled-down) result is a literal per
        ``coercions._deep_is_literal()``, return a nested wrapper for it
        instead.  That wrapper is cached in ``_bind_paths`` under
        ``(key, getter_fn)``, so repeated access yields the same wrapper.
        """

        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # NOTE(review): the nested wrapper gets the default
            # track_bound_values=True rather than this wrapper's setting.
            # If the attribute value is callable and passes
            # _deep_is_literal (not verified here), calling it could still
            # raise in __call__ even when tracking was disabled for the
            # parent -- confirm whether that is intended.
            wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1290
1291
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Inspecting a lambda element is answered by inspecting the element's
    ``_resolved`` attribute instead.
    """
    resolved_element = lmb._resolved
    return inspection.inspect(resolved_element)
1295