1import datetime
2import decimal
3
4from sqlalchemy import ARRAY
5from sqlalchemy import bindparam
6from sqlalchemy import Column
7from sqlalchemy import Date
8from sqlalchemy import DateTime
9from sqlalchemy import extract
10from sqlalchemy import func
11from sqlalchemy import Integer
12from sqlalchemy import literal
13from sqlalchemy import literal_column
14from sqlalchemy import MetaData
15from sqlalchemy import Numeric
16from sqlalchemy import select
17from sqlalchemy import Sequence
18from sqlalchemy import sql
19from sqlalchemy import String
20from sqlalchemy import Table
21from sqlalchemy import testing
22from sqlalchemy import types as sqltypes
23from sqlalchemy import util
24from sqlalchemy.dialects import mysql
25from sqlalchemy.dialects import oracle
26from sqlalchemy.dialects import postgresql
27from sqlalchemy.dialects import sqlite
28from sqlalchemy.sql import column
29from sqlalchemy.sql import functions
30from sqlalchemy.sql import table
31from sqlalchemy.sql.compiler import BIND_TEMPLATES
32from sqlalchemy.sql.functions import FunctionElement
33from sqlalchemy.sql.functions import GenericFunction
34from sqlalchemy.testing import assert_raises_message
35from sqlalchemy.testing import AssertsCompiledSQL
36from sqlalchemy.testing import engines
37from sqlalchemy.testing import eq_
38from sqlalchemy.testing import fixtures
39from sqlalchemy.testing import is_
40from sqlalchemy.testing.engines import all_dialects
41
# Lightweight "mytable" construct shared by the FILTER / WITHIN GROUP
# compilation tests (test_assorted shadows it with its own local version).
table1 = table(
    "mytable",
    *(
        column(col_name, col_type)
        for col_name, col_type in (
            ("myid", Integer),
            ("name", String),
            ("description", String),
        )
    )
)
48
49
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
    """Compilation tests for ``func`` expressions and GenericFunction
    subclasses.

    Many tests here declare GenericFunction subclasses; declaring one makes
    it what ``func.<name>`` resolves to (see test_custom_default_namespace),
    by way of ``functions._registry``.  setup()/teardown() snapshot and
    restore that registry so those declarations cannot leak from one test
    into the next.
    """

    __dialect__ = "default"

    def setup(self):
        """Record the current GenericFunction registry before each test."""
        # copy each per-package mapping as well as the outer mapping, so
        # registrations made during the test do not alter the snapshot
        self._registry_snapshot = {
            package: dict(registered)
            for package, registered in functions._registry.items()
        }

    def teardown(self):
        """Put the registry back exactly as setup() found it.

        This replaces the former ``tear_down`` method, which is not a hook
        name the test runner calls, so it never ran.  Restoring a snapshot
        rather than calling ``_registry.clear()`` keeps the built-in generic
        functions (``now``, ``count``, ...) that other tests in this class
        depend on.
        """
        # mutate in place so every existing reference to the registry
        # object sees the restored contents
        functions._registry.clear()
        functions._registry.update(self._registry_snapshot)

    # the old, misspelled hook name; kept so existing references still work
    tear_down = teardown

    def test_compile(self):
        for dialect in all_dialects(exclude=("sybase",)):
            # the expected bind placeholder depends on the dialect paramstyle
            bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
            self.assert_compile(
                func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect
            )
            self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
            # firebird renders an unknown zero-argument function without
            # parentheses; every other dialect renders "name()"
            if dialect.name in ("firebird",):
                self.assert_compile(
                    func.nosuchfunction(), "nosuchfunction", dialect=dialect
                )
            else:
                self.assert_compile(
                    func.nosuchfunction(), "nosuchfunction()", dialect=dialect
                )

            # test generic function compile
            class fake_func(GenericFunction):
                __return_type__ = sqltypes.Integer

                def __init__(self, arg, **kwargs):
                    GenericFunction.__init__(self, arg, **kwargs)

            self.assert_compile(
                fake_func("foo"),
                "fake_func(%s)"
                % bindtemplate
                % {"name": "fake_func_1", "position": 1},
                dialect=dialect,
            )

    def test_use_labels(self):
        self.assert_compile(
            select([func.foo()], use_labels=True), "SELECT foo() AS foo_1"
        )

    def test_underscores(self):
        # a trailing underscore lets a Python keyword name a SQL function
        self.assert_compile(func.if_(), "if()")

    def test_generic_now(self):
        assert isinstance(func.now().type, sqltypes.DateTime)

        for ret, dialect in [
            ("CURRENT_TIMESTAMP", sqlite.dialect()),
            ("now()", postgresql.dialect()),
            ("now()", mysql.dialect()),
            ("CURRENT_TIMESTAMP", oracle.dialect()),
        ]:
            self.assert_compile(func.now(), ret, dialect=dialect)

    def test_generic_random(self):
        assert func.random().type == sqltypes.NULLTYPE
        assert isinstance(func.random(type_=Integer).type, Integer)

        for ret, dialect in [
            ("random()", sqlite.dialect()),
            ("random()", postgresql.dialect()),
            ("rand()", mysql.dialect()),
            ("random()", oracle.dialect()),
        ]:
            self.assert_compile(func.random(), ret, dialect=dialect)

    def test_cube_operators(self):

        t = table(
            "t",
            column("value"),
            column("x"),
            column("y"),
            column("z"),
            column("q"),
        )

        stmt = select([func.sum(t.c.value)])

        self.assert_compile(
            stmt.group_by(func.cube(t.c.x, t.c.y)),
            "SELECT sum(t.value) AS sum_1 FROM t GROUP BY CUBE(t.x, t.y)",
        )

        self.assert_compile(
            stmt.group_by(func.rollup(t.c.x, t.c.y)),
            "SELECT sum(t.value) AS sum_1 FROM t GROUP BY ROLLUP(t.x, t.y)",
        )

        self.assert_compile(
            stmt.group_by(func.grouping_sets(t.c.x, t.c.y)),
            "SELECT sum(t.value) AS sum_1 FROM t "
            "GROUP BY GROUPING SETS(t.x, t.y)",
        )

        self.assert_compile(
            stmt.group_by(
                func.grouping_sets(
                    sql.tuple_(t.c.x, t.c.y), sql.tuple_(t.c.z, t.c.q)
                )
            ),
            "SELECT sum(t.value) AS sum_1 FROM t GROUP BY "
            "GROUPING SETS((t.x, t.y), (t.z, t.q))",
        )

    def test_generic_annotation(self):
        fn = func.coalesce("x", "y")._annotate({"foo": "bar"})
        self.assert_compile(fn, "coalesce(:coalesce_1, :coalesce_2)")

    def test_custom_default_namespace(self):
        class myfunc(GenericFunction):
            pass

        assert isinstance(func.myfunc(), myfunc)

    def test_custom_type(self):
        class myfunc(GenericFunction):
            type = DateTime

        assert isinstance(func.myfunc().type, DateTime)

    def test_custom_legacy_type(self):
        # in case someone was using this system
        class myfunc(GenericFunction):
            __return_type__ = DateTime

        assert isinstance(func.myfunc().type, DateTime)

    def test_custom_w_custom_name(self):
        class myfunc(GenericFunction):
            name = "notmyfunc"

        assert isinstance(func.notmyfunc(), myfunc)
        assert not isinstance(func.myfunc(), myfunc)

    def test_custom_package_namespace(self):
        def cls1(pk_name):
            class myfunc(GenericFunction):
                package = pk_name

            return myfunc

        f1 = cls1("mypackage")
        f2 = cls1("myotherpackage")

        assert isinstance(func.mypackage.myfunc(), f1)
        assert isinstance(func.myotherpackage.myfunc(), f2)

    def test_custom_name(self):
        class MyFunction(GenericFunction):
            name = "my_func"

            def __init__(self, *args):
                args = args + (3,)
                super(MyFunction, self).__init__(*args)

        self.assert_compile(
            func.my_func(1, 2), "my_func(:my_func_1, :my_func_2, :my_func_3)"
        )

    def test_custom_registered_identifier(self):
        # "identifier" is the lookup key on func; "name" is what renders
        class GeoBuffer(GenericFunction):
            type = Integer
            package = "geo"
            name = "BufferOne"
            identifier = "buf1"

        class GeoBuffer2(GenericFunction):
            type = Integer
            name = "BufferTwo"
            identifier = "buf2"

        class BufferThree(GenericFunction):
            type = Integer
            identifier = "buf3"

        self.assert_compile(func.geo.buf1(), "BufferOne()")
        self.assert_compile(func.buf2(), "BufferTwo()")
        self.assert_compile(func.buf3(), "BufferThree()")

    def test_custom_args(self):
        class myfunc(GenericFunction):
            pass

        self.assert_compile(
            myfunc(1, 2, 3), "myfunc(:myfunc_1, :myfunc_2, :myfunc_3)"
        )

    def test_namespacing_conflicts(self):
        self.assert_compile(func.text("foo"), "text(:text_1)")

    def test_generic_count(self):
        assert isinstance(func.count().type, sqltypes.Integer)

        self.assert_compile(func.count(), "count(*)")
        self.assert_compile(func.count(1), "count(:count_1)")
        c = column("abc")
        self.assert_compile(func.count(c), "count(abc)")

    def test_constructor(self):
        # each call passes the wrong number of arguments for that function
        # and must be rejected with TypeError at construction time
        for fn, args in [
            (func.current_timestamp, ("somearg",)),
            (func.char_length, ("a", "b")),
            (func.char_length, ()),
        ]:
            try:
                fn(*args)
            except TypeError:
                pass
            else:
                assert False, "expected TypeError for arguments %r" % (args,)

    def test_return_type_detection(self):

        for fn in [func.coalesce, func.max, func.min, func.sum]:
            for args, type_ in [
                (
                    (datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)),
                    sqltypes.Date,
                ),
                ((3, 5), sqltypes.Integer),
                ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric),
                (("foo", "bar"), sqltypes.String),
                (
                    (
                        datetime.datetime(2007, 10, 5, 8, 3, 34),
                        datetime.datetime(2005, 10, 15, 14, 45, 33),
                    ),
                    sqltypes.DateTime,
                ),
            ]:
                assert isinstance(fn(*args).type, type_), "%s / %r != %s" % (
                    fn(),
                    fn(*args).type,
                    type_,
                )

        assert isinstance(func.concat("foo", "bar").type, sqltypes.String)

    def test_assorted(self):
        table1 = table("mytable", column("myid", Integer))

        table2 = table("myothertable", column("otherid", Integer))

        # test an expression with a function
        self.assert_compile(
            func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid,
            "lala(:lala_1, :lala_2, :param_1, mytable.myid) * "
            "myothertable.otherid",
        )

        # test it in a SELECT
        self.assert_compile(
            select([func.count(table1.c.myid)]),
            "SELECT count(mytable.myid) AS count_1 FROM mytable",
        )

        # test a "dotted" function name
        self.assert_compile(
            select([func.foo.bar.lala(table1.c.myid)]),
            "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable",
        )

        # test the bind parameter name with a "dotted" function name is
        # only the name (limits the length of the bind param name)
        self.assert_compile(
            select([func.foo.bar.lala(12)]),
            "SELECT foo.bar.lala(:lala_2) AS lala_1",
        )

        # test a dotted func off the engine itself
        self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)")

        # test None becomes NULL
        self.assert_compile(
            func.my_func(1, 2, None, 3),
            "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)",
        )

        # test pickling
        self.assert_compile(
            util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))),
            "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)",
        )

        # assert func raises AttributeError for __bases__ attribute, since
        # it's not a class; this keeps pydoc from misbehaving
        try:
            func.__bases__
        except AttributeError:
            pass
        else:
            assert False, "func.__bases__ should raise AttributeError"

    def test_functions_with_cols(self):
        users = table(
            "users", column("id"), column("name"), column("fullname")
        )
        calculate = select(
            [column("q"), column("z"), column("r")],
            from_obj=[
                func.calculate(bindparam("x", None), bindparam("y", None))
            ],
        )

        self.assert_compile(
            select([users], users.c.id > calculate.c.z),
            "SELECT users.id, users.name, users.fullname "
            "FROM users, (SELECT q, z, r "
            "FROM calculate(:x, :y)) "
            "WHERE users.id > z",
        )

        # unique_params() gives each alias its own copies of x / y
        s = select(
            [users],
            users.c.id.between(
                calculate.alias("c1").unique_params(x=17, y=45).c.z,
                calculate.alias("c2").unique_params(x=5, y=12).c.z,
            ),
        )

        self.assert_compile(
            s,
            "SELECT users.id, users.name, users.fullname "
            "FROM users, (SELECT q, z, r "
            "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r "
            "FROM calculate(:x_2, :y_2)) AS c2 "
            "WHERE users.id BETWEEN c1.z AND c2.z",
            checkparams={"y_1": 45, "x_1": 17, "y_2": 12, "x_2": 5},
        )

    def test_non_functions(self):
        # func.cast / func.extract render as the CAST / EXTRACT constructs
        # rather than as ordinary function calls
        expr = func.cast("foo", Integer)
        self.assert_compile(expr, "CAST(:param_1 AS INTEGER)")

        expr = func.extract("year", datetime.date(2010, 12, 5))
        self.assert_compile(expr, "EXTRACT(year FROM :param_1)")

    def test_select_method_one(self):
        expr = func.rows("foo")
        self.assert_compile(expr.select(), "SELECT rows(:rows_2) AS rows_1")

    def test_alias_method_one(self):
        expr = func.rows("foo")
        self.assert_compile(expr.alias(), "rows(:rows_1)")

    def test_select_method_two(self):
        expr = func.rows("foo")
        self.assert_compile(
            select(["*"]).select_from(expr.select()),
            "SELECT * FROM (SELECT rows(:rows_2) AS rows_1)",
        )

    def test_select_method_three(self):
        expr = func.rows("foo")
        self.assert_compile(
            select([column("foo")]).select_from(expr),
            "SELECT foo FROM rows(:rows_1)",
        )

    def test_alias_method_two(self):
        expr = func.rows("foo")
        self.assert_compile(
            select(["*"]).select_from(expr.alias("bar")),
            "SELECT * FROM rows(:rows_1) AS bar",
        )

    def test_alias_method_columns(self):
        expr = func.rows("foo").alias("bar")

        # this isn't very useful but is the old behavior
        # prior to #2974.
        # testing here that the expression exports its column
        # list in a way that at least doesn't break.
        self.assert_compile(
            select([expr]), "SELECT bar.rows_1 FROM rows(:rows_2) AS bar"
        )

    def test_alias_method_columns_two(self):
        expr = func.rows("foo").alias("bar")
        assert len(expr.c)

    def test_funcfilter_empty(self):
        # filter() with no criteria adds no FILTER clause
        self.assert_compile(func.count(1).filter(), "count(:count_1)")

    def test_funcfilter_criterion(self):
        self.assert_compile(
            func.count(1).filter(table1.c.name != None),  # noqa
            "count(:count_1) FILTER (WHERE mytable.name IS NOT NULL)",
        )

    def test_funcfilter_compound_criterion(self):
        self.assert_compile(
            func.count(1).filter(
                table1.c.name == None, table1.c.myid > 0  # noqa
            ),
            "count(:count_1) FILTER (WHERE mytable.name IS NULL AND "
            "mytable.myid > :myid_1)",
        )

    def test_funcfilter_label(self):
        self.assert_compile(
            select(
                [
                    func.count(1)
                    .filter(table1.c.description != None)  # noqa
                    .label("foo")
                ]
            ),
            "SELECT count(:count_1) FILTER (WHERE mytable.description "
            "IS NOT NULL) AS foo FROM mytable",
        )

    def test_funcfilter_fromobj_fromfunc(self):
        # test from_obj generation.
        # from func:
        self.assert_compile(
            select(
                [
                    func.max(table1.c.name).filter(
                        literal_column("description") != None  # noqa
                    )
                ]
            ),
            "SELECT max(mytable.name) FILTER (WHERE description "
            "IS NOT NULL) AS anon_1 FROM mytable",
        )

    def test_funcfilter_fromobj_fromcriterion(self):
        # from criterion:
        self.assert_compile(
            select([func.count(1).filter(table1.c.name == "name")]),
            "SELECT count(:count_1) FILTER (WHERE mytable.name = :name_1) "
            "AS anon_1 FROM mytable",
        )

    def test_funcfilter_chaining(self):
        # test chaining: successive filter() calls are ANDed together
        self.assert_compile(
            select(
                [
                    func.count(1)
                    .filter(table1.c.name == "name")
                    .filter(table1.c.description == "description")
                ]
            ),
            "SELECT count(:count_1) FILTER (WHERE "
            "mytable.name = :name_1 AND mytable.description = :description_1) "
            "AS anon_1 FROM mytable",
        )

    def test_funcfilter_windowing_orderby(self):
        # test filtered windowing:
        self.assert_compile(
            select(
                [
                    func.rank()
                    .filter(table1.c.name > "foo")
                    .over(order_by=table1.c.name)
                ]
            ),
            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
            "OVER (ORDER BY mytable.name) AS anon_1 FROM mytable",
        )

    def test_funcfilter_windowing_orderby_partitionby(self):
        self.assert_compile(
            select(
                [
                    func.rank()
                    .filter(table1.c.name > "foo")
                    .over(order_by=table1.c.name, partition_by=["description"])
                ]
            ),
            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
            "OVER (PARTITION BY mytable.description ORDER BY mytable.name) "
            "AS anon_1 FROM mytable",
        )

    def test_funcfilter_windowing_range(self):
        self.assert_compile(
            select(
                [
                    func.rank()
                    .filter(table1.c.name > "foo")
                    .over(range_=(1, 5), partition_by=["description"])
                ]
            ),
            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
            "OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 "
            "FOLLOWING AND :param_2 FOLLOWING) "
            "AS anon_1 FROM mytable",
        )

    def test_funcfilter_windowing_rows(self):
        self.assert_compile(
            select(
                [
                    func.rank()
                    .filter(table1.c.name > "foo")
                    .over(rows=(1, 5), partition_by=["description"])
                ]
            ),
            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
            "OVER (PARTITION BY mytable.description ROWS BETWEEN :param_1 "
            "FOLLOWING AND :param_2 FOLLOWING) "
            "AS anon_1 FROM mytable",
        )

    def test_funcfilter_within_group(self):
        stmt = select(
            [
                table1.c.myid,
                func.percentile_cont(0.5).within_group(table1.c.name),
            ]
        )
        self.assert_compile(
            stmt,
            "SELECT mytable.myid, percentile_cont(:percentile_cont_1) "
            "WITHIN GROUP (ORDER BY mytable.name) "
            "AS anon_1 "
            "FROM mytable",
            {"percentile_cont_1": 0.5},
        )

    def test_funcfilter_within_group_multi(self):
        stmt = select(
            [
                table1.c.myid,
                func.percentile_cont(0.5).within_group(
                    table1.c.name, table1.c.description
                ),
            ]
        )
        self.assert_compile(
            stmt,
            "SELECT mytable.myid, percentile_cont(:percentile_cont_1) "
            "WITHIN GROUP (ORDER BY mytable.name, mytable.description) "
            "AS anon_1 "
            "FROM mytable",
            {"percentile_cont_1": 0.5},
        )

    def test_funcfilter_within_group_desc(self):
        stmt = select(
            [
                table1.c.myid,
                func.percentile_cont(0.5).within_group(table1.c.name.desc()),
            ]
        )
        self.assert_compile(
            stmt,
            "SELECT mytable.myid, percentile_cont(:percentile_cont_1) "
            "WITHIN GROUP (ORDER BY mytable.name DESC) "
            "AS anon_1 "
            "FROM mytable",
            {"percentile_cont_1": 0.5},
        )

    def test_funcfilter_within_group_w_over(self):
        stmt = select(
            [
                table1.c.myid,
                func.percentile_cont(0.5)
                .within_group(table1.c.name.desc())
                .over(partition_by=table1.c.description),
            ]
        )
        self.assert_compile(
            stmt,
            "SELECT mytable.myid, percentile_cont(:percentile_cont_1) "
            "WITHIN GROUP (ORDER BY mytable.name DESC) "
            "OVER (PARTITION BY mytable.description) AS anon_1 "
            "FROM mytable",
            {"percentile_cont_1": 0.5},
        )

    def test_incorrect_none_type(self):
        class MissingType(FunctionElement):
            name = "mt"
            type = None

        assert_raises_message(
            TypeError,
            "Object None associated with '.type' attribute is "
            "not a TypeEngine class or object",
            MissingType().compile,
        )
645
646
class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
    """Checks the ``.type`` inferred for aggregate and ordered-set
    functions from the types of their arguments."""

    def _assert_integer_array(self, type_):
        # the result must be an ARRAY whose elements have Integer affinity
        is_(type_._type_affinity, ARRAY)
        is_(type_.item_type._type_affinity, Integer)

    def test_array_agg(self):
        data = column("data", Integer)
        self._assert_integer_array(func.array_agg(data).type)

    def test_array_agg_array_datatype(self):
        # aggregating an ARRAY(Integer) column is still ARRAY of Integer
        data = column("data", ARRAY(Integer))
        self._assert_integer_array(func.array_agg(data).type)

    def test_array_agg_array_literal_implicit_type(self):
        from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
        from sqlalchemy.dialects.postgresql import array

        literal_array = array(
            [column("data", Integer), column("d2", Integer)]
        )
        assert isinstance(literal_array.type, PG_ARRAY)

        aggregated = func.array_agg(literal_array)
        # the postgresql ARRAY subtype carries through to the aggregate
        assert isinstance(aggregated.type, PG_ARRAY)
        self._assert_integer_array(aggregated.type)

        self.assert_compile(
            aggregated, "array_agg(ARRAY[data, d2])", dialect="postgresql"
        )

    def test_array_agg_array_literal_explicit_type(self):
        from sqlalchemy.dialects.postgresql import array

        literal_array = array(
            [column("data", Integer), column("d2", Integer)]
        )
        # here the return type is supplied explicitly through type_
        aggregated = func.array_agg(literal_array, type_=ARRAY(Integer))
        self._assert_integer_array(aggregated.type)

        self.assert_compile(
            aggregated, "array_agg(ARRAY[data, d2])", dialect="postgresql"
        )

    def test_mode(self):
        ordering = column("data", Integer).desc()
        mode_expr = func.mode(0.5).within_group(ordering)
        is_(mode_expr.type._type_affinity, Integer)

    def test_percentile_cont(self):
        ordering = column("data", Integer)
        cont = func.percentile_cont(0.5).within_group(ordering)
        is_(cont.type._type_affinity, Integer)

    def test_percentile_cont_array(self):
        # several fractions yield an array of results
        ordering = column("data", Integer)
        cont = func.percentile_cont(0.5, 0.7).within_group(ordering)
        self._assert_integer_array(cont.type)

    def test_percentile_cont_array_desc(self):
        ordering = column("data", Integer).desc()
        cont = func.percentile_cont(0.5, 0.7).within_group(ordering)
        self._assert_integer_array(cont.type)

    def test_cume_dist(self):
        ordering = column("data", Integer).desc()
        dist = func.cume_dist(0.5).within_group(ordering)
        is_(dist.type._type_affinity, Numeric)

    def test_percent_rank(self):
        ordering = column("data", Integer)
        rank = func.percent_rank(0.5).within_group(ordering)
        is_(rank.type._type_affinity, Numeric)
716
717
class ExecuteTest(fixtures.TestBase):
    """Executes SQL function expressions against the configured backend
    (``testing.db``) and checks the values that come back."""

    __backend__ = True

    @engines.close_first
    def tearDown(self):
        pass

    def test_conn_execute(self):
        """func.current_date() returns the same value whether executed
        directly, wrapped in a SELECT, via ``conn.scalar()``, or through a
        custom ``@compiles`` FunctionElement that renders the same SQL."""
        from sqlalchemy.sql.expression import FunctionElement
        from sqlalchemy.ext.compiler import compiles

        class myfunc(FunctionElement):
            type = Date()

        @compiles(myfunc)
        def compile_(elem, compiler, **kw):
            # render myfunc exactly as func.current_date() renders
            return compiler.process(func.current_date())

        # the context manager closes the connection even if an execute fails
        with testing.db.connect() as conn:
            x = conn.execute(func.current_date()).scalar()
            y = conn.execute(func.current_date().select()).scalar()
            z = conn.scalar(func.current_date())
            q = conn.scalar(myfunc())
        assert (x == y == z == q) is True

    def test_exec_options(self):
        """execution_options() set on a function propagate to its select()
        and to the execution context when the function is executed."""
        f = func.foo()
        eq_(f._execution_options, {})

        f = f.execution_options(foo="bar")
        eq_(f._execution_options, {"foo": "bar"})
        s = f.select()
        eq_(s._execution_options, {"foo": "bar"})

        ret = testing.db.execute(func.now().execution_options(foo="bar"))
        try:
            eq_(ret.context.execution_options, {"foo": "bar"})
        finally:
            # release the result even when the assertion fails
            ret.close()

    @engines.close_first
    @testing.provide_metadata
    def test_update(self):
        """
        Tests sending functions and SQL expressions to the VALUES and SET
        clauses of INSERT/UPDATE instances, and that column-level defaults
        get overridden.
        """

        meta = self.metadata
        t = Table(
            "t1",
            meta,
            Column(
                "id",
                Integer,
                Sequence("t1idseq", optional=True),
                primary_key=True,
            ),
            Column("value", Integer),
        )
        t2 = Table(
            "t2",
            meta,
            Column(
                "id",
                Integer,
                Sequence("t2idseq", optional=True),
                primary_key=True,
            ),
            Column("value", Integer, default=7),
            Column("stuff", String(20), onupdate="thisisstuff"),
        )
        meta.create_all()
        t.insert(values=dict(value=func.length("one"))).execute()
        assert t.select().execute().first()["value"] == 3
        t.update(values=dict(value=func.length("asfda"))).execute()
        assert t.select().execute().first()["value"] == 5

        r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
        id_ = r.inserted_primary_key[0]
        assert t.select(t.c.id == id_).execute().first()["value"] == 9
        # the unrestricted UPDATE sets every row, so first() is 4 whichever
        # row the database returns
        t.update(values={t.c.value: func.length("asdf")}).execute()
        assert t.select().execute().first()["value"] == 4
        t2.insert().execute()
        t2.insert(values=dict(value=func.length("one"))).execute()
        t2.insert(values=dict(value=func.length("asfda") + -19)).execute(
            stuff="hi"
        )

        res = exec_sorted(select([t2.c.value, t2.c.stuff]))
        eq_(res, [(-14, "hi"), (3, None), (7, None)])

        t2.update(values=dict(value=func.length("asdsafasd"))).execute(
            stuff="some stuff"
        )
        assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [
            (9, "some stuff"),
            (9, "some stuff"),
            (9, "some stuff"),
        ]

        t2.delete().execute()

        t2.insert(values=dict(value=func.length("one") + 8)).execute()
        assert t2.select().execute().first()["value"] == 11

        # "stuff" is not given here, so its onupdate value applies
        t2.update(values=dict(value=func.length("asfda"))).execute()
        eq_(
            select([t2.c.value, t2.c.stuff]).execute().first(),
            (5, "thisisstuff"),
        )

        t2.update(
            values={t2.c.value: func.length("asfdaasdf"), t2.c.stuff: "foo"}
        ).execute()
        eq_(select([t2.c.value, t2.c.stuff]).execute().first(), (9, "foo"))

    @testing.fails_on_everything_except("postgresql")
    def test_as_from(self):
        """A bound function can be executed directly, via select(), via
        scalar(), and used as a FROM element."""
        # TODO: shouldn't this work on oracle too ?
        x = func.current_date(bind=testing.db).execute().scalar()
        y = func.current_date(bind=testing.db).select().execute().scalar()
        z = func.current_date(bind=testing.db).scalar()
        w = select(
            ["*"], from_obj=[func.current_date(bind=testing.db)]
        ).scalar()

        assert x == y == z == w

    def test_extract_bind(self):
        """Basic common denominator execution tests for extract()"""

        date = datetime.date(2010, 5, 1)

        def execute(field):
            return testing.db.execute(select([extract(field, date)])).scalar()

        assert execute("year") == 2010
        assert execute("month") == 5
        assert execute("day") == 1

        # execute() reads ``date`` at call time, so rebinding it repeats the
        # same checks against a datetime value
        date = datetime.datetime(2010, 5, 1, 12, 11, 10)

        assert execute("year") == 2010
        assert execute("month") == 5
        assert execute("day") == 1

    def test_extract_expression(self):
        """extract() applied to DateTime and Date columns of a real table."""
        meta = MetaData(testing.db)
        table = Table("test", meta, Column("dt", DateTime), Column("d", Date))
        meta.create_all()
        try:
            table.insert().execute(
                {
                    "dt": datetime.datetime(2010, 5, 1, 12, 11, 10),
                    "d": datetime.date(2010, 5, 1),
                }
            )
            rs = select(
                [extract("year", table.c.dt), extract("month", table.c.d)]
            ).execute()
            row = rs.first()
            assert row[0] == 2010
            assert row[1] == 5
            rs.close()
        finally:
            meta.drop_all()
887
888
def exec_sorted(statement, *args, **kw):
    """Run *statement* and return all of its rows as sorted plain tuples.

    Positional and keyword arguments are passed through unchanged to
    ``statement.execute()``.  Each fetched row is converted with ``tuple()``
    and the resulting list is sorted, so callers can compare against a
    literal list without depending on the order the database returns rows.
    """
    fetched = statement.execute(*args, **kw).fetchall()
    plain_rows = list(map(tuple, fetched))
    plain_rows.sort()
    return plain_rows
895