# sql/lambdas.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

import itertools
import operator
import sys
import types
import weakref

from . import coercions
from . import elements
from . import roles
from . import schema
from . import traversals
from . import type_api
from . import visitors
from .base import _clone
from .base import Options
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
from ..util import collections_abc
from ..util import compat

# default cache of analyzed-function records, used whenever
# LambdaOptions.lambda_cache is None
_closure_per_cache_key = util.LRUCache(1000)


class LambdaOptions(Options):
    """Default option values consulted when analyzing a lambda."""

    enable_tracking = True
    track_closure_variables = True
    track_on = None
    global_track_bound_values = True
    track_bound_values = True
    lambda_cache = None


def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary. The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled. Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned. Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: optional sequence of objects that contribute to the
     cache key of the lambda. When a non-empty sequence is given, automatic
     closure variable tracking is turned off and these entries are used
     instead. Each entry may be a cacheable SQL construct, a tuple of such
     constructs, or any other value, which is placed in the cache key as-is.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda. Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored. Defaults
     to a global LRU cache. This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    return StatementLambdaElement(
        lmb,
        roles.StatementRole,
        LambdaOptions(
            enable_tracking=enable_tracking,
            track_on=track_on,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=global_track_bound_values,
            track_bound_values=track_bound_values,
            lambda_cache=lambda_cache,
        ),
    )


class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # callables accumulated by DeferredLambdaElement._copy_internals
    _transforms = ()

    # the previous link in a chain of lambdas; None for the first one
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        """Construct a new lambda element and locate or build its record.

        :param fn: the user-supplied Python function.
        :param role: the coercion role the lambda's result must satisfy.
        :param opts: a :class:`.LambdaOptions` class or instance.
        :param apply_propagate_attrs: optional object which receives the
         ``_propagate_attrs`` collection of the resulting record; defaults
         to this element itself when ``role`` is ``roles.StatementRole``.

        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            # a NonAnalyzedFunction record (the non-cacheable case) does not
            # necessarily carry propagate_attrs; don't fail on its absence
            propagate_attrs = getattr(rec, "propagate_attrs", None)
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or create the analyzed record for this lambda.

        Builds a cache key from the closure trackers of the lambda's
        :class:`.AnalyzedCode`, looks that key up in the lambda cache and
        creates a new :class:`.AnalyzedFunction` on a miss. When no cache
        key can be produced, the lambda is invoked directly and wrapped in
        a :class:`.NonAnalyzedFunction`. Assigns ``self._rec``,
        ``self.closure_cache_key`` and ``self._resolved_bindparams``.

        :return: the record that was assigned to ``self._rec``.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        # note the ``fn`` argument is superseded by self.fn
        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                self.closure_cache_key = cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                # one of the closure trackers declared itself uncacheable
                cache_key = traversals.NO_CACHE
                rec = None

        else:
            # an uncacheable parent makes the whole chain uncacheable
            cache_key = traversals.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not traversals.NO_CACHE:
                rec = AnalyzedFunction(
                    tracker, self, apply_propagate_attrs, fn
                )
                rec.closure_bindparams = bindparams
                lambda_cache[tracker_key + cache_key] = rec
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: keep the cached parameter objects (and keys) but
            # give them the values found in the current closure
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and all of its parents, letting each
            # record's trackers extract current bound values. distinct
            # local names are used so that ``rec`` / ``tracker`` above are
            # not overwritten by the records of parent lambdas.
            lambda_element = self
            while lambda_element is not None:
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn
                    )
                    for bind_tracker in element_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # attributes not found on the lambda element are proxied to the
        # expression held by the record
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        # NonAnalyzedFunction records may not define is_sequence
        return getattr(self._rec, "is_sequence", False)

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return a dictionary of bound parameter key to current value."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Replace bound parameters in ``expr`` with the resolved ones.

        Each :class:`.BindParameter` in ``expr`` whose key matches an entry
        in ``self._resolved_bindparams`` is swapped for that entry, which
        carries the current value.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        # carry "expanding" state over to the replacement
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if getattr(self._rec, "is_sequence", False):
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        """The record's expression with current bound values applied."""
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Produce a cache key from the code objects and closure cache keys
        of this lambda and all of its parents.

        Returns None and flags ``anon_map`` with ``NO_CACHE`` when this
        element is not cacheable.

        """
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key
351 352 def _invoke_user_fn(self, fn, *arg): 353 return fn() 354 355 356class DeferredLambdaElement(LambdaElement): 357 """A LambdaElement where the lambda accepts arguments and is 358 invoked within the compile phase with special context. 359 360 This lambda doesn't normally produce its real SQL expression outside of the 361 compile phase. It is passed a fixed set of initial arguments 362 so that it can generate a sample expression. 363 364 """ 365 366 def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): 367 self.lambda_args = lambda_args 368 super(DeferredLambdaElement, self).__init__(fn, role, opts) 369 370 def _invoke_user_fn(self, fn, *arg): 371 return fn(*self.lambda_args) 372 373 def _resolve_with_args(self, *lambda_args): 374 tracker_fn = self._rec.tracker_instrumented_fn 375 expr = tracker_fn(*lambda_args) 376 377 expr = coercions.expect(self.role, expr) 378 379 expr = self._setup_binds_for_tracked_expr(expr) 380 381 # this validation is getting very close, but not quite, to achieving 382 # #5767. The problem is if the base lambda uses an unnamed column 383 # as is very common with mixins, the parameter name is different 384 # and it produces a false positive; that is, for the documented case 385 # that is exactly what people will be doing, it doesn't work, so 386 # I'm not really sure how to handle this right now. 
387 # expected_binds = [ 388 # b._orig_key 389 # for b in self._rec.expr._generate_cache_key()[1] 390 # if b.required 391 # ] 392 # got_binds = [ 393 # b._orig_key for b in expr._generate_cache_key()[1] if b.required 394 # ] 395 # if expected_binds != got_binds: 396 # raise exc.InvalidRequestError( 397 # "Lambda callable at %s produced a different set of bound " 398 # "parameters than its original run: %s" 399 # % (self.fn.__code__, ", ".join(got_binds)) 400 # ) 401 402 # TODO: TEST TEST TEST, this is very out there 403 for deferred_copy_internals in self._transforms: 404 expr = deferred_copy_internals(expr) 405 406 return expr 407 408 def _copy_internals( 409 self, clone=_clone, deferred_copy_internals=None, **kw 410 ): 411 super(DeferredLambdaElement, self)._copy_internals( 412 clone=clone, 413 deferred_copy_internals=deferred_copy_internals, # **kw 414 opts=kw, 415 ) 416 417 # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know 418 # our expression yet. so hold onto the replacement 419 if deferred_copy_internals: 420 self._transforms += (deferred_copy_internals,) 421 422 423class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): 424 """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. 425 426 The :class:`_sql.StatementLambdaElement` is constructed using the 427 :func:`_sql.lambda_stmt` function:: 428 429 430 from sqlalchemy import lambda_stmt 431 432 stmt = lambda_stmt(lambda: select(table)) 433 434 Once constructed, additional criteria can be built onto the statement 435 by adding subsequent lambdas, which accept the existing statement 436 object as a single parameter:: 437 438 stmt += lambda s: s.where(table.c.col == parameter) 439 440 441 .. versionadded:: 1.4 442 443 .. 
seealso:: 444 445 :ref:`engine_lambda_caching` 446 447 """ 448 449 def __add__(self, other): 450 return self.add_criteria(other) 451 452 def add_criteria( 453 self, 454 other, 455 enable_tracking=True, 456 track_on=None, 457 track_closure_variables=True, 458 track_bound_values=True, 459 ): 460 """Add new criteria to this :class:`_sql.StatementLambdaElement`. 461 462 E.g.:: 463 464 >>> def my_stmt(parameter): 465 ... stmt = lambda_stmt( 466 ... lambda: select(table.c.x, table.c.y), 467 ... ) 468 ... stmt = stmt.add_criteria( 469 ... lambda: table.c.x > parameter 470 ... ) 471 ... return stmt 472 473 The :meth:`_sql.StatementLambdaElement.add_criteria` method is 474 equivalent to using the Python addition operator to add a new 475 lambda, except that additional arguments may be added including 476 ``track_closure_values`` and ``track_on``:: 477 478 >>> def my_stmt(self, foo): 479 ... stmt = lambda_stmt( 480 ... lambda: select(func.max(foo.x, foo.y)), 481 ... track_closure_variables=False 482 ... ) 483 ... stmt = stmt.add_criteria( 484 ... lambda: self.where_criteria, 485 ... track_on=[self] 486 ... ) 487 ... return stmt 488 489 See :func:`_sql.lambda_stmt` for a description of the parameters 490 accepted. 
491 492 """ 493 494 opts = self.opts + dict( 495 enable_tracking=enable_tracking, 496 track_closure_variables=track_closure_variables, 497 global_track_bound_values=self.opts.global_track_bound_values, 498 track_on=track_on, 499 track_bound_values=track_bound_values, 500 ) 501 502 return LinkedLambdaElement(other, parent_lambda=self, opts=opts) 503 504 def _execute_on_connection( 505 self, connection, multiparams, params, execution_options 506 ): 507 if self._rec.expected_expr.supports_execution: 508 return connection._execute_clauseelement( 509 self, multiparams, params, execution_options 510 ) 511 else: 512 raise exc.ObjectNotExecutableError(self) 513 514 @property 515 def _with_options(self): 516 return self._rec.expected_expr._with_options 517 518 @property 519 def _effective_plugin_target(self): 520 return self._rec.expected_expr._effective_plugin_target 521 522 @property 523 def _execution_options(self): 524 return self._rec.expected_expr._execution_options 525 526 def spoil(self): 527 """Return a new :class:`.StatementLambdaElement` that will run 528 all lambdas unconditionally each time. 529 530 """ 531 return NullLambdaStatement(self.fn()) 532 533 534class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): 535 """Provides the :class:`.StatementLambdaElement` API but does not 536 cache or analyze lambdas. 537 538 the lambdas are instead invoked immediately. 539 540 The intended use is to isolate issues that may arise when using 541 lambda statements. 
542 543 """ 544 545 __visit_name__ = "lambda_element" 546 547 _is_lambda_element = True 548 549 _traverse_internals = [ 550 ("_resolved", visitors.InternalTraversal.dp_clauseelement) 551 ] 552 553 def __init__(self, statement): 554 self._resolved = statement 555 self._propagate_attrs = statement._propagate_attrs 556 557 def __getattr__(self, key): 558 return getattr(self._resolved, key) 559 560 def __add__(self, other): 561 statement = other(self._resolved) 562 563 return NullLambdaStatement(statement) 564 565 def add_criteria(self, other, **kw): 566 statement = other(self._resolved) 567 568 return NullLambdaStatement(statement) 569 570 def _execute_on_connection( 571 self, connection, multiparams, params, execution_options 572 ): 573 if self._resolved.supports_execution: 574 return connection._execute_clauseelement( 575 self, multiparams, params, execution_options 576 ) 577 else: 578 raise exc.ObjectNotExecutableError(self) 579 580 581class LinkedLambdaElement(StatementLambdaElement): 582 """Represent subsequent links of a :class:`.StatementLambdaElement`.""" 583 584 role = None 585 586 def __init__(self, fn, parent_lambda, opts): 587 self.opts = opts 588 self.fn = fn 589 self.parent_lambda = parent_lambda 590 591 self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) 592 self._retrieve_tracker_rec(fn, self, opts) 593 self._propagate_attrs = parent_lambda._propagate_attrs 594 595 def _invoke_user_fn(self, fn, *arg): 596 return fn(self.parent_lambda._resolved) 597 598 599class AnalyzedCode(object): 600 __slots__ = ( 601 "track_closure_variables", 602 "track_bound_values", 603 "bindparam_trackers", 604 "closure_trackers", 605 "build_py_wrappers", 606 ) 607 _fns = weakref.WeakKeyDictionary() 608 609 @classmethod 610 def get(cls, fn, lambda_element, lambda_kw, **kw): 611 try: 612 # TODO: validate kw haven't changed? 
613 return cls._fns[fn.__code__] 614 except KeyError: 615 pass 616 cls._fns[fn.__code__] = analyzed = AnalyzedCode( 617 fn, lambda_element, lambda_kw, **kw 618 ) 619 return analyzed 620 621 def __init__(self, fn, lambda_element, opts): 622 closure = fn.__closure__ 623 624 self.track_bound_values = ( 625 opts.track_bound_values and opts.global_track_bound_values 626 ) 627 enable_tracking = opts.enable_tracking 628 track_on = opts.track_on 629 track_closure_variables = opts.track_closure_variables 630 631 self.track_closure_variables = track_closure_variables and not track_on 632 633 # a list of callables generated from _bound_parameter_getter_* 634 # functions. Each of these uses a PyWrapper object to retrieve 635 # a parameter value 636 self.bindparam_trackers = [] 637 638 # a list of callables generated from _cache_key_getter_* functions 639 # these callables work to generate a cache key for the lambda 640 # based on what's inside its closure variables. 641 self.closure_trackers = [] 642 643 self.build_py_wrappers = [] 644 645 if enable_tracking: 646 if track_on: 647 self._init_track_on(track_on) 648 649 self._init_globals(fn) 650 651 if closure: 652 self._init_closure(fn) 653 654 self._setup_additional_closure_trackers(fn, lambda_element, opts) 655 656 def _init_track_on(self, track_on): 657 self.closure_trackers.extend( 658 self._cache_key_getter_track_on(idx, elem) 659 for idx, elem in enumerate(track_on) 660 ) 661 662 def _init_globals(self, fn): 663 build_py_wrappers = self.build_py_wrappers 664 bindparam_trackers = self.bindparam_trackers 665 track_bound_values = self.track_bound_values 666 667 for name in fn.__code__.co_names: 668 if name not in fn.__globals__: 669 continue 670 671 _bound_value = self._roll_down_to_literal(fn.__globals__[name]) 672 673 if coercions._deep_is_literal(_bound_value): 674 build_py_wrappers.append((name, None)) 675 if track_bound_values: 676 bindparam_trackers.append( 677 self._bound_parameter_getter_func_globals(name) 678 ) 679 
680 def _init_closure(self, fn): 681 build_py_wrappers = self.build_py_wrappers 682 closure = fn.__closure__ 683 684 track_bound_values = self.track_bound_values 685 track_closure_variables = self.track_closure_variables 686 bindparam_trackers = self.bindparam_trackers 687 closure_trackers = self.closure_trackers 688 689 for closure_index, (fv, cell) in enumerate( 690 zip(fn.__code__.co_freevars, closure) 691 ): 692 _bound_value = self._roll_down_to_literal(cell.cell_contents) 693 694 if coercions._deep_is_literal(_bound_value): 695 build_py_wrappers.append((fv, closure_index)) 696 if track_bound_values: 697 bindparam_trackers.append( 698 self._bound_parameter_getter_func_closure( 699 fv, closure_index 700 ) 701 ) 702 else: 703 # for normal cell contents, add them to a list that 704 # we can compare later when we get new lambdas. if 705 # any identities have changed, then we will 706 # recalculate the whole lambda and run it again. 707 708 if track_closure_variables: 709 closure_trackers.append( 710 self._cache_key_getter_closure_variable( 711 fn, fv, closure_index, cell.cell_contents 712 ) 713 ) 714 715 def _setup_additional_closure_trackers(self, fn, lambda_element, opts): 716 # an additional step is to actually run the function, then 717 # go through the PyWrapper objects that were set up to catch a bound 718 # parameter. then if they *didn't* make a param, oh they're another 719 # object in the closure we have to track for our cache key. so 720 # create trackers to catch those. 
721 722 analyzed_function = AnalyzedFunction( 723 self, 724 lambda_element, 725 None, 726 fn, 727 ) 728 729 closure_trackers = self.closure_trackers 730 731 for pywrapper in analyzed_function.closure_pywrappers: 732 if not pywrapper._sa__has_param: 733 closure_trackers.append( 734 self._cache_key_getter_tracked_literal(fn, pywrapper) 735 ) 736 737 @classmethod 738 def _roll_down_to_literal(cls, element): 739 is_clause_element = hasattr(element, "__clause_element__") 740 741 if is_clause_element: 742 while not isinstance( 743 element, (elements.ClauseElement, schema.SchemaItem, type) 744 ): 745 try: 746 element = element.__clause_element__() 747 except AttributeError: 748 break 749 750 if not is_clause_element: 751 insp = inspection.inspect(element, raiseerr=False) 752 if insp is not None: 753 try: 754 return insp.__clause_element__() 755 except AttributeError: 756 return insp 757 758 # TODO: should we coerce consts None/True/False here? 759 return element 760 else: 761 return element 762 763 def _bound_parameter_getter_func_globals(self, name): 764 """Return a getter that will extend a list of bound parameters 765 with new entries from the ``__globals__`` collection of a particular 766 lambda. 767 768 """ 769 770 def extract_parameter_value( 771 current_fn, tracker_instrumented_fn, result 772 ): 773 wrapper = tracker_instrumented_fn.__globals__[name] 774 object.__getattribute__(wrapper, "_extract_bound_parameters")( 775 current_fn.__globals__[name], result 776 ) 777 778 return extract_parameter_value 779 780 def _bound_parameter_getter_func_closure(self, name, closure_index): 781 """Return a getter that will extend a list of bound parameters 782 with new entries from the ``__closure__`` collection of a particular 783 lambda. 
784 785 """ 786 787 def extract_parameter_value( 788 current_fn, tracker_instrumented_fn, result 789 ): 790 wrapper = tracker_instrumented_fn.__closure__[ 791 closure_index 792 ].cell_contents 793 object.__getattribute__(wrapper, "_extract_bound_parameters")( 794 current_fn.__closure__[closure_index].cell_contents, result 795 ) 796 797 return extract_parameter_value 798 799 def _cache_key_getter_track_on(self, idx, elem): 800 """Return a getter that will extend a cache key with new entries 801 from the "track_on" parameter passed to a :class:`.LambdaElement`. 802 803 """ 804 805 if isinstance(elem, tuple): 806 # tuple must contain hascachekey elements 807 def get(closure, opts, anon_map, bindparams): 808 return tuple( 809 tup_elem._gen_cache_key(anon_map, bindparams) 810 for tup_elem in opts.track_on[idx] 811 ) 812 813 elif isinstance(elem, traversals.HasCacheKey): 814 815 def get(closure, opts, anon_map, bindparams): 816 return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) 817 818 else: 819 820 def get(closure, opts, anon_map, bindparams): 821 return opts.track_on[idx] 822 823 return get 824 825 def _cache_key_getter_closure_variable( 826 self, 827 fn, 828 variable_name, 829 idx, 830 cell_contents, 831 use_clause_element=False, 832 use_inspect=False, 833 ): 834 """Return a getter that will extend a cache key with new entries 835 from the ``__closure__`` collection of a particular lambda. 
836 837 """ 838 839 if isinstance(cell_contents, traversals.HasCacheKey): 840 841 def get(closure, opts, anon_map, bindparams): 842 843 obj = closure[idx].cell_contents 844 if use_inspect: 845 obj = inspection.inspect(obj) 846 elif use_clause_element: 847 while hasattr(obj, "__clause_element__"): 848 if not getattr(obj, "is_clause_element", False): 849 obj = obj.__clause_element__() 850 851 return obj._gen_cache_key(anon_map, bindparams) 852 853 elif isinstance(cell_contents, types.FunctionType): 854 855 def get(closure, opts, anon_map, bindparams): 856 return closure[idx].cell_contents.__code__ 857 858 elif isinstance(cell_contents, collections_abc.Sequence): 859 860 def get(closure, opts, anon_map, bindparams): 861 contents = closure[idx].cell_contents 862 863 try: 864 return tuple( 865 elem._gen_cache_key(anon_map, bindparams) 866 for elem in contents 867 ) 868 except AttributeError as ae: 869 self._raise_for_uncacheable_closure_variable( 870 variable_name, fn, from_=ae 871 ) 872 873 else: 874 # if the object is a mapped class or aliased class, or some 875 # other object in the ORM realm of things like that, imitate 876 # the logic used in coercions.expect() to roll it down to the 877 # SQL element 878 element = cell_contents 879 is_clause_element = False 880 while hasattr(element, "__clause_element__"): 881 is_clause_element = True 882 if not getattr(element, "is_clause_element", False): 883 element = element.__clause_element__() 884 else: 885 break 886 887 if not is_clause_element: 888 insp = inspection.inspect(element, raiseerr=False) 889 if insp is not None: 890 return self._cache_key_getter_closure_variable( 891 fn, variable_name, idx, insp, use_inspect=True 892 ) 893 else: 894 return self._cache_key_getter_closure_variable( 895 fn, variable_name, idx, element, use_clause_element=True 896 ) 897 898 self._raise_for_uncacheable_closure_variable(variable_name, fn) 899 900 return get 901 902 def _raise_for_uncacheable_closure_variable( 903 self, variable_name, 
fn, from_=None 904 ): 905 util.raise_( 906 exc.InvalidRequestError( 907 "Closure variable named '%s' inside of lambda callable %s " 908 "does not refer to a cacheable SQL element, and also does not " 909 "appear to be serving as a SQL literal bound value based on " 910 "the default " 911 "SQL expression returned by the function. This variable " 912 "needs to remain outside the scope of a SQL-generating lambda " 913 "so that a proper cache key may be generated from the " 914 "lambda's state. Evaluate this variable outside of the " 915 "lambda, set track_on=[<elements>] to explicitly select " 916 "closure elements to track, or set " 917 "track_closure_variables=False to exclude " 918 "closure variables from being part of the cache key." 919 % (variable_name, fn.__code__), 920 ), 921 from_=from_, 922 ) 923 924 def _cache_key_getter_tracked_literal(self, fn, pytracker): 925 """Return a getter that will extend a cache key with new entries 926 from the ``__closure__`` collection of a particular lambda. 927 928 this getter differs from _cache_key_getter_closure_variable 929 in that these are detected after the function is run, and PyWrapper 930 objects have recorded that a particular literal value is in fact 931 not being interpreted as a bound parameter. 
932 933 """ 934 935 elem = pytracker._sa__to_evaluate 936 closure_index = pytracker._sa__closure_index 937 variable_name = pytracker._sa__name 938 939 return self._cache_key_getter_closure_variable( 940 fn, variable_name, closure_index, elem 941 ) 942 943 944class NonAnalyzedFunction(object): 945 __slots__ = ("expr",) 946 947 closure_bindparams = None 948 bindparam_trackers = None 949 950 def __init__(self, expr): 951 self.expr = expr 952 953 @property 954 def expected_expr(self): 955 return self.expr 956 957 958class AnalyzedFunction(object): 959 __slots__ = ( 960 "analyzed_code", 961 "fn", 962 "closure_pywrappers", 963 "tracker_instrumented_fn", 964 "expr", 965 "bindparam_trackers", 966 "expected_expr", 967 "is_sequence", 968 "propagate_attrs", 969 "closure_bindparams", 970 ) 971 972 def __init__( 973 self, 974 analyzed_code, 975 lambda_element, 976 apply_propagate_attrs, 977 fn, 978 ): 979 self.analyzed_code = analyzed_code 980 self.fn = fn 981 982 self.bindparam_trackers = analyzed_code.bindparam_trackers 983 984 self._instrument_and_run_function(lambda_element) 985 986 self._coerce_expression(lambda_element, apply_propagate_attrs) 987 988 def _instrument_and_run_function(self, lambda_element): 989 analyzed_code = self.analyzed_code 990 991 fn = self.fn 992 self.closure_pywrappers = closure_pywrappers = [] 993 994 build_py_wrappers = analyzed_code.build_py_wrappers 995 996 if not build_py_wrappers: 997 self.tracker_instrumented_fn = tracker_instrumented_fn = fn 998 self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) 999 else: 1000 track_closure_variables = analyzed_code.track_closure_variables 1001 closure = fn.__closure__ 1002 1003 # will form the __closure__ of the function when we rebuild it 1004 if closure: 1005 new_closure = { 1006 fv: cell.cell_contents 1007 for fv, cell in zip(fn.__code__.co_freevars, closure) 1008 } 1009 else: 1010 new_closure = {} 1011 1012 # will form the __globals__ of the function when we rebuild it 1013 new_globals 
= fn.__globals__.copy() 1014 1015 for name, closure_index in build_py_wrappers: 1016 if closure_index is not None: 1017 value = closure[closure_index].cell_contents 1018 new_closure[name] = bind = PyWrapper( 1019 fn, 1020 name, 1021 value, 1022 closure_index=closure_index, 1023 track_bound_values=( 1024 self.analyzed_code.track_bound_values 1025 ), 1026 ) 1027 if track_closure_variables: 1028 closure_pywrappers.append(bind) 1029 else: 1030 value = fn.__globals__[name] 1031 new_globals[name] = bind = PyWrapper(fn, name, value) 1032 1033 # rewrite the original fn. things that look like they will 1034 # become bound parameters are wrapped in a PyWrapper. 1035 self.tracker_instrumented_fn = ( 1036 tracker_instrumented_fn 1037 ) = self._rewrite_code_obj( 1038 fn, 1039 [new_closure[name] for name in fn.__code__.co_freevars], 1040 new_globals, 1041 ) 1042 1043 # now invoke the function. This will give us a new SQL 1044 # expression, but all the places that there would be a bound 1045 # parameter, the PyWrapper in its place will give us a bind 1046 # with a predictable name we can match up later. 1047 1048 # additionally, each PyWrapper will log that it did in fact 1049 # create a parameter, otherwise, it's some kind of Python 1050 # object in the closure and we want to track that, to make 1051 # sure it doesn't change to something else, or if it does, 1052 # that we create a different tracked function with that 1053 # variable. 1054 self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) 1055 1056 def _coerce_expression(self, lambda_element, apply_propagate_attrs): 1057 """Run the tracker-generated expression through coercion rules. 1058 1059 After the user-defined lambda has been invoked to produce a statement 1060 for re-use, run it through coercion rules to both check that it's the 1061 correct type of object and also to coerce it to its useful form. 
1062 1063 """ 1064 1065 parent_lambda = lambda_element.parent_lambda 1066 expr = self.expr 1067 1068 if parent_lambda is None: 1069 if isinstance(expr, collections_abc.Sequence): 1070 self.expected_expr = [ 1071 coercions.expect( 1072 lambda_element.role, 1073 sub_expr, 1074 apply_propagate_attrs=apply_propagate_attrs, 1075 ) 1076 for sub_expr in expr 1077 ] 1078 self.is_sequence = True 1079 else: 1080 self.expected_expr = coercions.expect( 1081 lambda_element.role, 1082 expr, 1083 apply_propagate_attrs=apply_propagate_attrs, 1084 ) 1085 self.is_sequence = False 1086 else: 1087 self.expected_expr = expr 1088 self.is_sequence = False 1089 1090 if apply_propagate_attrs is not None: 1091 self.propagate_attrs = apply_propagate_attrs._propagate_attrs 1092 else: 1093 self.propagate_attrs = util.EMPTY_DICT 1094 1095 def _rewrite_code_obj(self, f, cell_values, globals_): 1096 """Return a copy of f, with a new closure and new globals 1097 1098 yes it works in pypy :P 1099 1100 """ 1101 1102 argrange = range(len(cell_values)) 1103 1104 code = "def make_cells():\n" 1105 if cell_values: 1106 code += " (%s) = (%s)\n" % ( 1107 ", ".join("i%d" % i for i in argrange), 1108 ", ".join("o%d" % i for i in argrange), 1109 ) 1110 code += " def closure():\n" 1111 code += " return %s\n" % ", ".join("i%d" % i for i in argrange) 1112 code += " return closure.__closure__" 1113 vars_ = {"o%d" % i: cell_values[i] for i in argrange} 1114 compat.exec_(code, vars_, vars_) 1115 closure = vars_["make_cells"]() 1116 1117 func = type(f)( 1118 f.__code__, globals_, f.__name__, f.__defaults__, closure 1119 ) 1120 if sys.version_info >= (3,): 1121 func.__annotations__ = f.__annotations__ 1122 func.__kwdefaults__ = f.__kwdefaults__ 1123 func.__doc__ = f.__doc__ 1124 func.__module__ = f.__module__ 1125 1126 return func 1127 1128 1129class PyWrapper(ColumnOperators): 1130 """A wrapper object that is injected into the ``__globals__`` and 1131 ``__closure__`` of a Python function. 
1132 1133 When the function is instrumented with :class:`.PyWrapper` objects, it is 1134 then invoked just once in order to set up the wrappers. We look through 1135 all the :class:`.PyWrapper` objects we made to find the ones that generated 1136 a :class:`.BindParameter` object, e.g. the expression system interpreted 1137 something as a literal. Those positions in the globals/closure are then 1138 ones that we will look at, each time a new lambda comes in that refers to 1139 the same ``__code__`` object. In this way, we keep a single version of 1140 the SQL expression that this lambda produced, without calling upon the 1141 Python function that created it more than once, unless its other closure 1142 variables have changed. The expression is then transformed to have the 1143 new bound values embedded into it. 1144 1145 """ 1146 1147 def __init__( 1148 self, 1149 fn, 1150 name, 1151 to_evaluate, 1152 closure_index=None, 1153 getter=None, 1154 track_bound_values=True, 1155 ): 1156 self.fn = fn 1157 self._name = name 1158 self._to_evaluate = to_evaluate 1159 self._param = None 1160 self._has_param = False 1161 self._bind_paths = {} 1162 self._getter = getter 1163 self._closure_index = closure_index 1164 self.track_bound_values = track_bound_values 1165 1166 def __call__(self, *arg, **kw): 1167 elem = object.__getattribute__(self, "_to_evaluate") 1168 value = elem(*arg, **kw) 1169 if ( 1170 self._sa_track_bound_values 1171 and coercions._deep_is_literal(value) 1172 and not isinstance( 1173 # TODO: coverage where an ORM option or similar is here 1174 value, 1175 traversals.HasCacheKey, 1176 ) 1177 ): 1178 name = object.__getattribute__(self, "_name") 1179 raise exc.InvalidRequestError( 1180 "Can't invoke Python callable %s() inside of lambda " 1181 "expression argument at %s; lambda SQL constructs should " 1182 "not invoke functions from closure variables to produce " 1183 "literal values since the " 1184 "lambda SQL system normally extracts bound values without " 1185 
"actually " 1186 "invoking the lambda or any functions within it. Call the " 1187 "function outside of the " 1188 "lambda and assign to a local variable that is used in the " 1189 "lambda as a closure variable, or set " 1190 "track_bound_values=False if the return value of this " 1191 "function is used in some other way other than a SQL bound " 1192 "value." % (name, self._sa_fn.__code__) 1193 ) 1194 else: 1195 return value 1196 1197 def operate(self, op, *other, **kwargs): 1198 elem = object.__getattribute__(self, "__clause_element__")() 1199 return op(elem, *other, **kwargs) 1200 1201 def reverse_operate(self, op, other, **kwargs): 1202 elem = object.__getattribute__(self, "__clause_element__")() 1203 return op(other, elem, **kwargs) 1204 1205 def _extract_bound_parameters(self, starting_point, result_list): 1206 param = object.__getattribute__(self, "_param") 1207 if param is not None: 1208 param = param._with_value(starting_point, maintain_key=True) 1209 result_list.append(param) 1210 for pywrapper in object.__getattribute__(self, "_bind_paths").values(): 1211 getter = object.__getattribute__(pywrapper, "_getter") 1212 element = getter(starting_point) 1213 pywrapper._sa__extract_bound_parameters(element, result_list) 1214 1215 def __clause_element__(self): 1216 param = object.__getattribute__(self, "_param") 1217 to_evaluate = object.__getattribute__(self, "_to_evaluate") 1218 if param is None: 1219 name = object.__getattribute__(self, "_name") 1220 self._param = param = elements.BindParameter( 1221 name, required=False, unique=True 1222 ) 1223 self._has_param = True 1224 param.type = type_api._resolve_value_to_type(to_evaluate) 1225 return param._with_value(to_evaluate, maintain_key=True) 1226 1227 def __bool__(self): 1228 to_evaluate = object.__getattribute__(self, "_to_evaluate") 1229 return bool(to_evaluate) 1230 1231 def __nonzero__(self): 1232 to_evaluate = object.__getattribute__(self, "_to_evaluate") 1233 return bool(to_evaluate) 1234 1235 def 
__getattribute__(self, key): 1236 if key.startswith("_sa_"): 1237 return object.__getattribute__(self, key[4:]) 1238 elif key in ( 1239 "__clause_element__", 1240 "operate", 1241 "reverse_operate", 1242 "__class__", 1243 "__dict__", 1244 ): 1245 return object.__getattribute__(self, key) 1246 1247 if key.startswith("__"): 1248 elem = object.__getattribute__(self, "_to_evaluate") 1249 return getattr(elem, key) 1250 else: 1251 return self._sa__add_getter(key, operator.attrgetter) 1252 1253 def __iter__(self): 1254 elem = object.__getattribute__(self, "_to_evaluate") 1255 return iter(elem) 1256 1257 def __getitem__(self, key): 1258 elem = object.__getattribute__(self, "_to_evaluate") 1259 if not hasattr(elem, "__getitem__"): 1260 raise AttributeError("__getitem__") 1261 1262 if isinstance(key, PyWrapper): 1263 # TODO: coverage 1264 raise exc.InvalidRequestError( 1265 "Dictionary keys / list indexes inside of a cached " 1266 "lambda must be Python literals only" 1267 ) 1268 return self._sa__add_getter(key, operator.itemgetter) 1269 1270 def _add_getter(self, key, getter_fn): 1271 1272 bind_paths = object.__getattribute__(self, "_bind_paths") 1273 1274 bind_path_key = (key, getter_fn) 1275 if bind_path_key in bind_paths: 1276 return bind_paths[bind_path_key] 1277 1278 getter = getter_fn(key) 1279 elem = object.__getattribute__(self, "_to_evaluate") 1280 value = getter(elem) 1281 1282 rolled_down_value = AnalyzedCode._roll_down_to_literal(value) 1283 1284 if coercions._deep_is_literal(rolled_down_value): 1285 wrapper = PyWrapper(self._sa_fn, key, value, getter=getter) 1286 bind_paths[bind_path_key] = wrapper 1287 return wrapper 1288 else: 1289 return value 1290 1291 1292@inspection._inspects(LambdaElement) 1293def insp(lmb): 1294 return inspection.inspect(lmb._resolved) 1295