1# coding: utf-8
2import datetime
3import decimal
4import importlib
5import operator
6import os
7
8import sqlalchemy as sa
9from sqlalchemy import and_
10from sqlalchemy import ARRAY
11from sqlalchemy import BigInteger
12from sqlalchemy import bindparam
13from sqlalchemy import BLOB
14from sqlalchemy import BOOLEAN
15from sqlalchemy import Boolean
16from sqlalchemy import cast
17from sqlalchemy import CHAR
18from sqlalchemy import CLOB
19from sqlalchemy import DATE
20from sqlalchemy import Date
21from sqlalchemy import DATETIME
22from sqlalchemy import DateTime
23from sqlalchemy import DECIMAL
24from sqlalchemy import dialects
25from sqlalchemy import distinct
26from sqlalchemy import Enum
27from sqlalchemy import exc
28from sqlalchemy import FLOAT
29from sqlalchemy import Float
30from sqlalchemy import func
31from sqlalchemy import inspection
32from sqlalchemy import INTEGER
33from sqlalchemy import Integer
34from sqlalchemy import Interval
35from sqlalchemy import JSON
36from sqlalchemy import LargeBinary
37from sqlalchemy import literal
38from sqlalchemy import MetaData
39from sqlalchemy import NCHAR
40from sqlalchemy import NUMERIC
41from sqlalchemy import Numeric
42from sqlalchemy import NVARCHAR
43from sqlalchemy import PickleType
44from sqlalchemy import REAL
45from sqlalchemy import select
46from sqlalchemy import SMALLINT
47from sqlalchemy import SmallInteger
48from sqlalchemy import String
49from sqlalchemy import testing
50from sqlalchemy import Text
51from sqlalchemy import text
52from sqlalchemy import TIME
53from sqlalchemy import Time
54from sqlalchemy import TIMESTAMP
55from sqlalchemy import type_coerce
56from sqlalchemy import TypeDecorator
57from sqlalchemy import types
58from sqlalchemy import Unicode
59from sqlalchemy import util
60from sqlalchemy import VARCHAR
61import sqlalchemy.dialects.mysql as mysql
62import sqlalchemy.dialects.oracle as oracle
63import sqlalchemy.dialects.postgresql as pg
64from sqlalchemy.engine import default
65from sqlalchemy.schema import AddConstraint
66from sqlalchemy.schema import CheckConstraint
67from sqlalchemy.sql import column
68from sqlalchemy.sql import ddl
69from sqlalchemy.sql import elements
70from sqlalchemy.sql import null
71from sqlalchemy.sql import operators
72from sqlalchemy.sql import sqltypes
73from sqlalchemy.sql import table
74from sqlalchemy.sql import visitors
75from sqlalchemy.sql.sqltypes import TypeEngine
76from sqlalchemy.testing import assert_raises
77from sqlalchemy.testing import assert_raises_message
78from sqlalchemy.testing import AssertsCompiledSQL
79from sqlalchemy.testing import AssertsExecutionResults
80from sqlalchemy.testing import engines
81from sqlalchemy.testing import eq_
82from sqlalchemy.testing import expect_deprecated_20
83from sqlalchemy.testing import expect_raises
84from sqlalchemy.testing import expect_warnings
85from sqlalchemy.testing import fixtures
86from sqlalchemy.testing import is_
87from sqlalchemy.testing import is_not
88from sqlalchemy.testing import mock
89from sqlalchemy.testing import pickleable
90from sqlalchemy.testing.schema import Column
91from sqlalchemy.testing.schema import pep435_enum
92from sqlalchemy.testing.schema import Table
93from sqlalchemy.testing.util import picklers
94from sqlalchemy.testing.util import round_decimal
95from sqlalchemy.util import u
96
97
def _all_dialect_modules():
    """Import and return each public dialect package.

    Every name in ``sqlalchemy.dialects.__all__`` that does not start with
    an underscore is imported as ``sqlalchemy.dialects.<name>``.
    """
    modules = []
    for name in dialects.__all__:
        if name.startswith("_"):
            continue
        module_path = "sqlalchemy.dialects.%s" % name
        modules.append(importlib.import_module(module_path))
    return modules
104
105
def _all_dialects():
    """Return one default-constructed dialect instance per dialect
    package, built from each package's ``base.dialect`` class."""
    dialect_instances = []
    for module in _all_dialect_modules():
        dialect_cls = module.base.dialect
        dialect_instances.append(dialect_cls())
    return dialect_instances
108
109
def _types_for_mod(mod):
    """Yield every attribute of ``mod`` that is a class deriving from
    ``types.TypeEngine``, in ``dir(mod)`` order."""
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if isinstance(candidate, type) and issubclass(
            candidate, types.TypeEngine
        ):
            yield candidate
116
117
def _all_types(omit_special_types=False):
    """Yield each distinct TypeEngine subclass found in ``sqlalchemy.types``
    and then in every dialect package, without repeats.

    :param omit_special_types: when True, the special base classes
     ``TypeDecorator``, ``TypeEngine`` and ``Variant`` are skipped.  The
     filter applies to every scanned module, not only ``sqlalchemy.types``,
     so a dialect module re-exporting one of them does not leak it back in.
    """
    if omit_special_types:
        special = (types.TypeDecorator, types.TypeEngine, types.Variant)
    else:
        special = ()

    def _scanned_modules():
        # generic types first; dialect packages are only imported once the
        # generic types have been consumed, as before
        yield types
        for dialect_mod in _all_dialect_modules():
            yield dialect_mod

    seen = set()
    for mod in _scanned_modules():
        for typ in _types_for_mod(mod):
            if typ in special or typ in seen:
                continue
            seen.add(typ)
            yield typ
138
139
class AdaptTest(fixtures.TestBase):
    """Checks that apply across type classes: public importability,
    uppercase-type DDL rendering, ``adapt()``, ``python_type``, ``repr()``
    and ``str()``."""

    # instance attributes that test_adapt_method does not compare between
    # a type and its adapted copy
    _adapt_unchecked_attrs = frozenset(
        (
            "impl",
            "_is_oracle_number",
            "_create_events",
            "create_constraint",
            "inherit_schema",
            "schema",
            "metadata",
            "name",
        )
    )

    @staticmethod
    def _instance(typ):
        """Construct ``typ`` with default arguments; ARRAY types are given
        ``String`` as their item type since they require one."""
        if issubclass(typ, ARRAY):
            return typ(String)
        return typ()

    @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
    def test_uppercase_importable(self, typ):
        """UPPERCASE types in sqlalchemy.types are importable from the
        top-level ``sqlalchemy`` package and listed in ``types.__all__``."""
        name = typ.__name__
        if name != name.upper():
            return
        assert getattr(sa, name) is typ
        assert name in types.__all__

    @testing.combinations(
        ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia"
    )
    @testing.combinations(
        (REAL(), "REAL"),
        (FLOAT(), "FLOAT"),
        (NUMERIC(), "NUMERIC"),
        (DECIMAL(), "DECIMAL"),
        (INTEGER(), "INTEGER"),
        (SMALLINT(), "SMALLINT"),
        (TIMESTAMP(), ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")),
        (DATETIME(), "DATETIME"),
        (DATE(), "DATE"),
        (TIME(), ("TIME", "TIME WITHOUT TIME ZONE")),
        (CLOB(), "CLOB"),
        (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
        (
            NVARCHAR(10),
            ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
        ),
        (CHAR(), "CHAR"),
        (NCHAR(), ("NCHAR", "NATIONAL CHAR")),
        (BLOB(), ("BLOB", "BLOB SUB_TYPE 0")),
        (BOOLEAN(), ("BOOLEAN", "BOOL", "INTEGER")),
        argnames="type_, expected",
        id_="ra",
    )
    def test_uppercase_rendering(self, dialect, type_, expected):
        """Test that uppercase types from types.py always render as their
        type.

        As of SQLA 0.6, using an uppercase type means you want specifically
        that type. If the database in use doesn't support that DDL, it (the DB
        backend) should raise an error - it means you should be using a
        lowercased (genericized) type.

        """

        if isinstance(expected, str):
            expected = (expected,)

        try:
            compiled = type_.compile(dialect=dialect)
        except NotImplementedError:
            # this dialect declines to render the type; nothing to check
            return

        assert compiled in expected, "%r matches none of %r for dialect %s" % (
            compiled,
            expected,
            dialect.name,
        )

        assert (
            str(types.to_instance(type_)) in expected
        ), "default str() of type %r not expected, %r" % (type_, expected)

    def _adaptions():
        """Generate ``(id, is_down_adaption, typ, target_adaptions)`` rows
        for test_adapt_method.

        Each type yields one "up" row targeting itself plus its direct
        subclasses, and one "down" row per direct subclass whose module name
        contains "sqlalchemy", targeting the parent type.

        This is a plain function invoked while the class body executes,
        not a method.
        """
        for typ in _all_types(omit_special_types=True):

            # up adapt from LowerCase to UPPERCASE,
            # as well as to all non-sqltypes
            up_adaptions = [typ] + typ.__subclasses__()
            yield "%s.%s" % (
                typ.__module__,
                typ.__name__,
            ), False, typ, up_adaptions
            for subcl in typ.__subclasses__():
                if (
                    subcl is not typ
                    and typ is not TypeDecorator
                    and "sqlalchemy" in subcl.__module__
                ):
                    yield "%s.%s" % (
                        subcl.__module__,
                        subcl.__name__,
                    ), True, subcl, [typ]

    @testing.combinations(_adaptions(), id_="iaaa")
    def test_adapt_method(self, is_down_adaption, typ, target_adaptions):
        """ensure all types have a working adapt() method,
        which creates a distinct copy.

        The distinct copy ensures that when we cache
        the adapted() form of a type against the original
        in a weak key dictionary, a cycle is not formed.

        This test doesn't test type-specific arguments of
        adapt() beyond their defaults.

        """

        original = self._instance(typ)
        for cls in target_adaptions:
            if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
                not is_down_adaption and issubclass(cls, sqltypes.Emulated)
            ):
                continue

            adapted = original.adapt(cls)
            assert original is not adapted

            # up-adaption: the original's state must carry over to the copy.
            # down-adaption: the (less specific) copy's state must be present
            # on the original.  Distinct names are used rather than swapping
            # in place so that later loop iterations and the final check
            # below always see the originally constructed type.
            if is_down_adaption:
                source, target = adapted, original
            else:
                source, target = original, adapted

            for k in source.__dict__:
                if k in self._adapt_unchecked_attrs:
                    continue
                # assert each value was copied, or that
                # the adapted type has a more specific
                # value than the original (i.e. SQL Server
                # applies precision=24 for REAL)
                assert (
                    getattr(target, k) == source.__dict__[k]
                    or source.__dict__[k] is None
                )

        eq_(original.evaluates_none().should_evaluate_none, True)

    def test_python_type(self):
        eq_(types.Integer().python_type, int)
        eq_(types.Numeric().python_type, decimal.Decimal)
        eq_(types.Numeric(asdecimal=False).python_type, float)
        eq_(types.LargeBinary().python_type, util.binary_type)
        eq_(types.Float().python_type, float)
        eq_(types.Interval().python_type, datetime.timedelta)
        eq_(types.Date().python_type, datetime.date)
        eq_(types.DateTime().python_type, datetime.datetime)
        eq_(types.String().python_type, str)
        eq_(types.Unicode().python_type, util.text_type)
        eq_(types.Enum("one", "two", "three").python_type, str)

        assert_raises(
            NotImplementedError, lambda: types.TypeEngine().python_type
        )

    @testing.uses_deprecated()
    @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
    def test_repr(self, typ):
        repr(self._instance(typ))

    @testing.uses_deprecated()
    @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
    def test_str(self, typ):
        str(self._instance(typ))

    def test_str_third_party(self):
        class TINYINT(types.TypeEngine):
            __visit_name__ = "TINYINT"

        eq_(str(TINYINT()), "TINYINT")

    def test_str_third_party_uppercase_no_visit_name(self):
        class TINYINT(types.TypeEngine):
            pass

        eq_(str(TINYINT()), "TINYINT")

    def test_str_third_party_camelcase_no_visit_name(self):
        class TinyInt(types.TypeEngine):
            pass

        eq_(str(TinyInt()), "TinyInt()")

    def test_adapt_constructor_copy_override_kw(self):
        """test that adapt() can accept kw args that override
        the state of the original object.

        This essentially is testing the behavior of util.constructor_copy().

        """
        t1 = String(length=50)

        # without overrides, state is copied from the original
        t2 = t1.adapt(Text)
        eq_(t2.length, 50)

        # an explicit keyword wins over the copied state, and the original
        # is left untouched
        t3 = t1.adapt(Text, length=100)
        eq_(t3.length, 100)
        eq_(t1.length, 50)

    def test_convert_unicode_text_type(self):
        with testing.expect_deprecated(
            "The String.convert_unicode parameter is deprecated"
        ):
            eq_(types.String(convert_unicode=True).python_type, util.text_type)
346
347
class TypeAffinityTest(fixtures.TestBase):
    """Tests for ``_type_affinity`` and ``_compare_type_affinity()``."""

    @testing.combinations(
        (String(), String),
        (VARCHAR(), String),
        (Date(), Date),
        (LargeBinary(), types._Binary),
        id_="rn",
    )
    def test_type_affinity(self, type_, affin):
        actual_affinity = type_._type_affinity
        eq_(actual_affinity, affin)

    @testing.combinations(
        (Integer(), SmallInteger(), True),
        (Integer(), String(), False),
        (Integer(), Integer(), True),
        (Text(), String(), True),
        (Text(), Unicode(), True),
        (LargeBinary(), Integer(), False),
        (LargeBinary(), PickleType(), True),
        (PickleType(), LargeBinary(), True),
        (PickleType(), PickleType(), True),
        id_="rra",
    )
    def test_compare_type_affinity(self, t1, t2, comp):
        outcome = t1._compare_type_affinity(t2)
        eq_(outcome, comp, "%s %s" % (t1, t2))

    def test_decorator_doesnt_cache(self):
        """A TypeDecorator reports its impl's affinity generically, and its
        dialect impl reports the affinity of what load_dialect_impl()
        returned for that dialect."""
        from sqlalchemy.dialects import postgresql

        class MyType(TypeDecorator):
            impl = CHAR
            cache_ok = True

            def load_dialect_impl(self, dialect):
                if dialect.name != "postgresql":
                    return dialect.type_descriptor(CHAR(32))
                return dialect.type_descriptor(postgresql.UUID())

        decorated = MyType()
        pg_dialect = postgresql.dialect()
        assert decorated._type_affinity is String
        pg_impl = decorated.dialect_impl(pg_dialect)
        assert pg_impl._type_affinity is postgresql.UUID
391
392
class AsGenericTest(fixtures.TestBase):
    """Tests for ``TypeEngine.as_generic()``."""

    @staticmethod
    def _make(type_):
        """Instantiate ``type_``; ARRAY types need an item type, so they
        receive ``String``."""
        if issubclass(type_, ARRAY):
            return type_(String)
        return type_()

    @testing.combinations(
        (String(), String()),
        (VARCHAR(length=100), String(length=100)),
        (NVARCHAR(length=100), Unicode(length=100)),
        (DATE(), Date()),
        (pg.JSON(), sa.JSON()),
        (pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
        (Enum("a", "b", "c"), Enum("a", "b", "c")),
        (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
        (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
        (pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)),
        (
            oracle.INTERVAL(second_precision=5, day_precision=5),
            Interval(native=True, day_precision=5, second_precision=5),
        ),
    )
    def test_as_generic(self, t1, t2):
        generic = t1.as_generic(allow_nulltype=False)
        assert repr(generic) == repr(t2)

    @testing.combinations(
        *[
            (t,)
            for t in _all_types(omit_special_types=True)
            if not util.method_is_overridden(t, TypeEngine.as_generic)
        ]
    )
    def test_as_generic_all_types_heuristic(self, type_):
        instance = self._make(type_)

        # the default heuristic may decline with NotImplementedError;
        # when it does produce a type, that type must be a base of ours
        try:
            strict_generic = instance.as_generic()
        except NotImplementedError:
            pass
        else:
            assert isinstance(instance, strict_generic.__class__)
            assert isinstance(strict_generic, TypeEngine)

        lenient_generic = instance.as_generic(allow_nulltype=True)
        if isinstance(lenient_generic, types.NULLTYPE.__class__):
            return
        assert isinstance(instance, lenient_generic.__class__)
        assert isinstance(lenient_generic, TypeEngine)

    @testing.combinations(
        *[
            (t,)
            for t in _all_types(omit_special_types=True)
            if util.method_is_overridden(t, TypeEngine.as_generic)
        ]
    )
    def test_as_generic_all_types_custom(self, type_):
        instance = self._make(type_)

        generic = instance.as_generic(allow_nulltype=False)
        assert isinstance(generic, TypeEngine)
454
455
class PickleTypesTest(fixtures.TestBase):
    """A Column using a common type, and the MetaData holding its Table,
    survive a round trip through every pickler from ``picklers()``."""

    @testing.combinations(
        ("Boo", Boolean()),
        ("Str", String()),
        ("Tex", Text()),
        ("Uni", Unicode()),
        ("Int", Integer()),
        ("Sma", SmallInteger()),
        ("Big", BigInteger()),
        ("Num", Numeric()),
        ("Flo", Float()),
        ("Dat", DateTime()),
        ("Dat", Date()),
        ("Tim", Time()),
        ("Lar", LargeBinary()),
        ("Pic", PickleType()),
        ("Int", Interval()),
        id_="ar",
    )
    def test_pickle_types(self, name, type_):
        col = Column(name, type_)
        meta = MetaData()
        Table("foo", meta, col)

        for loads, dumps in picklers():
            # beyond not raising, the unpickled objects must keep their
            # identifying state: the column its name, the MetaData its table
            eq_(loads(dumps(col)).name, name)
            assert "foo" in loads(dumps(meta)).tables
482            loads(dumps(meta))
483
484
class _UserDefinedTypeFixture(object):
    """Mixin defining a ``users`` table whose columns use user-defined and
    TypeDecorator-based types with observable bind/result processing."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the ``users`` table on ``metadata``.

        The string-valued types prepend ``"BIND_IN"`` when binding (binding
        None as ``"<null value>"``) and append ``"BIND_OUT"`` to results.
        The integer decorators multiply on the way in and out so the round
        trip is observable.
        """

        class MyType(types.UserDefinedType):
            # fully custom type; renders as VARCHAR(100) in DDL
            def get_col_spec(self):
                return "VARCHAR(100)"

            def bind_processor(self, dialect):
                def process(value):
                    if value is None:
                        value = "<null value>"
                    return "BIND_IN" + value

                return process

            def result_processor(self, dialect, coltype):
                def process(value):
                    return value + "BIND_OUT"

                return process

            def adapt(self, typeobj):
                return typeobj()

        class MyDecoratedType(types.TypeDecorator):
            # overrides bind_processor/result_processor directly, chaining
            # to the String impl's own processors when it has any
            impl = String
            cache_ok = True

            def bind_processor(self, dialect):
                impl_processor = super(MyDecoratedType, self).bind_processor(
                    dialect
                ) or (lambda value: value)

                def process(value):
                    if value is None:
                        value = "<null value>"
                    return "BIND_IN" + impl_processor(value)

                return process

            def result_processor(self, dialect, coltype):
                impl_processor = super(MyDecoratedType, self).result_processor(
                    dialect, coltype
                ) or (lambda value: value)

                def process(value):
                    return impl_processor(value) + "BIND_OUT"

                return process

            def copy(self):
                # keep the impl's length, consistent with the Unicode-based
                # decorators in this fixture
                return MyDecoratedType(self.impl.length)

        class MyNewUnicodeType(types.TypeDecorator):
            # same transformation as MyDecoratedType, but via the
            # process_bind_param / process_result_value hooks
            impl = Unicode
            cache_ok = True

            def process_bind_param(self, value, dialect):
                if value is None:
                    value = u"<null value>"
                return "BIND_IN" + value

            def process_result_value(self, value, dialect):
                return value + "BIND_OUT"

            def copy(self):
                return MyNewUnicodeType(self.impl.length)

        class MyNewIntType(types.TypeDecorator):
            # x10 on bind (None is bound as 29) and x10 again on result
            impl = Integer
            cache_ok = True

            def process_bind_param(self, value, dialect):
                if value is None:
                    value = 29
                return value * 10

            def process_result_value(self, value, dialect):
                return value * 10

            def copy(self):
                return MyNewIntType()

        class MyNewIntSubClass(MyNewIntType):
            # inherits the x10 bind step; results are multiplied by 15
            def process_result_value(self, value, dialect):
                return value * 15

            def copy(self):
                return MyNewIntSubClass()

        class MyUnicodeType(types.TypeDecorator):
            impl = Unicode
            cache_ok = True

            def bind_processor(self, dialect):
                impl_processor = super(MyUnicodeType, self).bind_processor(
                    dialect
                ) or (lambda value: value)

                def process(value):
                    if value is None:
                        value = u"<null value>"

                    return "BIND_IN" + impl_processor(value)

                return process

            def result_processor(self, dialect, coltype):
                impl_processor = super(MyUnicodeType, self).result_processor(
                    dialect, coltype
                ) or (lambda value: value)

                def process(value):
                    return impl_processor(value) + "BIND_OUT"

                return process

            def copy(self):
                return MyUnicodeType(self.impl.length)

        class MyDecOfDec(types.TypeDecorator):
            # a TypeDecorator whose impl is itself a TypeDecorator
            impl = MyNewIntType
            cache_ok = True

        Table(
            "users",
            metadata,
            Column("user_id", Integer, primary_key=True),
            # totally custom type
            Column("goofy", MyType, nullable=False),
            # decorated type with an argument, so it's a String
            Column("goofy2", MyDecoratedType(50), nullable=False),
            Column("goofy4", MyUnicodeType(50), nullable=False),
            Column("goofy7", MyNewUnicodeType(50), nullable=False),
            Column("goofy8", MyNewIntType, nullable=False),
            Column("goofy9", MyNewIntSubClass, nullable=False),
            Column("goofy10", MyDecOfDec, nullable=False),
        )
623
624
class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
    """Round-trip values through the fixture's user-defined types on a real
    backend, checking that bind and result processing are both applied."""

    __backend__ = True

    def _data_fixture(self, connection):
        """Insert four rows into ``users``; the last row is all None so the
        types' None handling during bind processing is exercised."""
        users = self.tables.users
        connection.execute(
            users.insert(),
            dict(
                user_id=2,
                goofy="jack",
                goofy2="jack",
                goofy4=util.u("jack"),
                goofy7=util.u("jack"),
                goofy8=12,
                goofy9=12,
                goofy10=12,
            ),
        )
        connection.execute(
            users.insert(),
            dict(
                user_id=3,
                goofy="lala",
                goofy2="lala",
                goofy4=util.u("lala"),
                goofy7=util.u("lala"),
                goofy8=15,
                goofy9=15,
                goofy10=15,
            ),
        )
        connection.execute(
            users.insert(),
            dict(
                user_id=4,
                goofy="fred",
                goofy2="fred",
                goofy4=util.u("fred"),
                goofy7=util.u("fred"),
                goofy8=9,
                goofy9=9,
                goofy10=9,
            ),
        )
        connection.execute(
            users.insert(),
            dict(
                user_id=5,
                goofy=None,
                goofy2=None,
                goofy4=None,
                goofy7=None,
                goofy8=None,
                goofy9=None,
                goofy10=None,
            ),
        )

    def test_processing(self, connection):
        users = self.tables.users
        self._data_fixture(connection)

        result = connection.execute(
            users.select().order_by(users.c.user_id)
        ).fetchall()
        # string columns: "BIND_IN" + value + "BIND_OUT", None bound as
        # "<null value>"; goofy8 / goofy10: x10 in, x10 out;
        # goofy9: x10 in, x15 out; None integers are bound as 29.
        eq_(
            result,
            [
                (
                    2,
                    "BIND_INjackBIND_OUT",
                    "BIND_INjackBIND_OUT",
                    "BIND_INjackBIND_OUT",
                    "BIND_INjackBIND_OUT",
                    1200,
                    1800,
                    1200,
                ),
                (
                    3,
                    "BIND_INlalaBIND_OUT",
                    "BIND_INlalaBIND_OUT",
                    "BIND_INlalaBIND_OUT",
                    "BIND_INlalaBIND_OUT",
                    1500,
                    2250,
                    1500,
                ),
                (
                    4,
                    "BIND_INfredBIND_OUT",
                    "BIND_INfredBIND_OUT",
                    "BIND_INfredBIND_OUT",
                    "BIND_INfredBIND_OUT",
                    900,
                    1350,
                    900,
                ),
                (
                    5,
                    "BIND_IN<null value>BIND_OUT",
                    "BIND_IN<null value>BIND_OUT",
                    "BIND_IN<null value>BIND_OUT",
                    "BIND_IN<null value>BIND_OUT",
                    2900,
                    4350,
                    2900,
                ),
            ],
        )

    def _assert_in_roundtrip(self, connection, colname, in_arg, params=None):
        """Load the fixture data, select ``user_id`` and ``colname`` where
        ``colname`` IN ``in_arg``, and assert rows 3 and 4 come back with
        their result-processed values.

        ``params`` is passed to ``execute()`` only when given, i.e. for the
        expanding-bindparam variants that actually need it.
        """
        users = self.tables.users
        self._data_fixture(connection)

        col = users.c[colname]
        stmt = (
            select(users.c.user_id, col)
            .where(col.in_(in_arg))
            .order_by(users.c.user_id)
        )
        if params is None:
            result = connection.execute(stmt)
        else:
            result = connection.execute(stmt, params)
        eq_(result.fetchall(), [(3, 1500), (4, 900)])

    def test_plain_in_typedec(self, connection):
        self._assert_in_roundtrip(connection, "goofy8", [15, 9])

    def test_plain_in_typedec_of_typedec(self, connection):
        self._assert_in_roundtrip(connection, "goofy10", [15, 9])

    def test_expanding_in_typedec(self, connection):
        self._assert_in_roundtrip(
            connection,
            "goofy8",
            bindparam("goofy", expanding=True),
            {"goofy": [15, 9]},
        )

    def test_expanding_in_typedec_of_typedec(self, connection):
        self._assert_in_roundtrip(
            connection,
            "goofy10",
            bindparam("goofy", expanding=True),
            {"goofy": [15, 9]},
        )
783
784
class BindProcessorInsertValuesTest(UserDefinedRoundTripTest):
    """related to #6770, test that insert().values() applies to
    bound parameter handlers including the None value."""

    __backend__ = True

    def _data_fixture(self, connection):
        """Insert the same four rows as the parent fixture, but one
        ``insert().values(...)`` statement per row instead of passing
        parameter dictionaries to ``execute()``."""
        users = self.tables.users
        fixture_rows = [
            (2, "jack", 12),
            (3, "lala", 15),
            (4, "fred", 9),
            (5, None, None),
        ]
        for user_id, word, number in fixture_rows:
            unicode_word = None if word is None else util.u(word)
            stmt = users.insert().values(
                user_id=user_id,
                goofy=word,
                goofy2=word,
                goofy4=unicode_word,
                goofy7=unicode_word,
                goofy8=number,
                goofy9=number,
                goofy10=number,
            )
            connection.execute(stmt)
841
842
843class UserDefinedTest(
844    _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
845):
846
847    run_create_tables = None
848    run_inserts = None
849    run_deletes = None
850
851    """tests user-defined types."""
852
853    def test_typedecorator_literal_render(self):
854        class MyType(types.TypeDecorator):
855            impl = String
856            cache_ok = True
857
858            def process_literal_param(self, value, dialect):
859                return "HI->%s<-THERE" % value
860
861        self.assert_compile(
862            select(literal("test", MyType)),
863            "SELECT 'HI->test<-THERE' AS anon_1",
864            dialect="default",
865            literal_binds=True,
866        )
867
868    def test_kw_colspec(self):
869        class MyType(types.UserDefinedType):
870            def get_col_spec(self, **kw):
871                return "FOOB %s" % kw["type_expression"].name
872
873        class MyOtherType(types.UserDefinedType):
874            def get_col_spec(self):
875                return "BAR"
876
877        t = Table("t", MetaData(), Column("bar", MyType, nullable=False))
878
879        self.assert_compile(ddl.CreateColumn(t.c.bar), "bar FOOB bar NOT NULL")
880
881        t = Table("t", MetaData(), Column("bar", MyOtherType, nullable=False))
882        self.assert_compile(ddl.CreateColumn(t.c.bar), "bar BAR NOT NULL")
883
884    def test_typedecorator_literal_render_fallback_bound(self):
885        # fall back to process_bind_param for literal
886        # value rendering.
887        class MyType(types.TypeDecorator):
888            impl = String
889            cache_ok = True
890
891            def process_bind_param(self, value, dialect):
892                return "HI->%s<-THERE" % value
893
894        self.assert_compile(
895            select(literal("test", MyType)),
896            "SELECT 'HI->test<-THERE' AS anon_1",
897            dialect="default",
898            literal_binds=True,
899        )
900
901    def test_typedecorator_impl(self):
902        for impl_, exp, kw in [
903            (Float, "FLOAT", {}),
904            (Float, "FLOAT(2)", {"precision": 2}),
905            (Float(2), "FLOAT(2)", {"precision": 4}),
906            (Numeric(19, 2), "NUMERIC(19, 2)", {}),
907        ]:
908            for dialect_ in (
909                dialects.postgresql,
910                dialects.mssql,
911                dialects.mysql,
912            ):
913                dialect_ = dialect_.dialect()
914
915                raw_impl = types.to_instance(impl_, **kw)
916
917                class MyType(types.TypeDecorator):
918                    impl = impl_
919                    cache_ok = True
920
921                dec_type = MyType(**kw)
922
923                eq_(dec_type.impl.__class__, raw_impl.__class__)
924
925                raw_dialect_impl = raw_impl.dialect_impl(dialect_)
926                dec_dialect_impl = dec_type.dialect_impl(dialect_)
927                eq_(dec_dialect_impl.__class__, MyType)
928                eq_(
929                    raw_dialect_impl.__class__, dec_dialect_impl.impl.__class__
930                )
931
932                self.assert_compile(MyType(**kw), exp, dialect=dialect_)
933
934    def test_user_defined_typedec_impl(self):
935        class MyType(types.TypeDecorator):
936            impl = Float
937            cache_ok = True
938
939            def load_dialect_impl(self, dialect):
940                if dialect.name == "sqlite":
941                    return String(50)
942                else:
943                    return super(MyType, self).load_dialect_impl(dialect)
944
945        sl = dialects.sqlite.dialect()
946        pg = dialects.postgresql.dialect()
947        t = MyType()
948        self.assert_compile(t, "VARCHAR(50)", dialect=sl)
949        self.assert_compile(t, "FLOAT", dialect=pg)
950        eq_(
951            t.dialect_impl(dialect=sl).impl.__class__,
952            String().dialect_impl(dialect=sl).__class__,
953        )
954        eq_(
955            t.dialect_impl(dialect=pg).impl.__class__,
956            Float().dialect_impl(pg).__class__,
957        )
958
959    @testing.combinations((Boolean,), (Enum,))
960    def test_typedecorator_schematype_constraint(self, typ):
961        class B(TypeDecorator):
962            impl = typ
963            cache_ok = True
964
965        t1 = Table("t1", MetaData(), Column("q", B(create_constraint=True)))
966        eq_(
967            len([c for c in t1.constraints if isinstance(c, CheckConstraint)]),
968            1,
969        )
970
971    def test_type_decorator_repr(self):
972        class MyType(TypeDecorator):
973            impl = VARCHAR
974
975            cache_ok = True
976
977        eq_(repr(MyType(45)), "MyType(length=45)")
978
979    def test_user_defined_typedec_impl_bind(self):
980        class TypeOne(types.TypeEngine):
981            def bind_processor(self, dialect):
982                def go(value):
983                    return value + " ONE"
984
985                return go
986
987        class TypeTwo(types.TypeEngine):
988            def bind_processor(self, dialect):
989                def go(value):
990                    return value + " TWO"
991
992                return go
993
994        class MyType(types.TypeDecorator):
995            impl = TypeOne
996            cache_ok = True
997
998            def load_dialect_impl(self, dialect):
999                if dialect.name == "sqlite":
1000                    return TypeOne()
1001                else:
1002                    return TypeTwo()
1003
1004            def process_bind_param(self, value, dialect):
1005                return "MYTYPE " + value
1006
1007        sl = dialects.sqlite.dialect()
1008        pg = dialects.postgresql.dialect()
1009        t = MyType()
1010        eq_(t._cached_bind_processor(sl)("foo"), "MYTYPE foo ONE")
1011        eq_(t._cached_bind_processor(pg)("foo"), "MYTYPE foo TWO")
1012
1013    def test_user_defined_dialect_specific_args(self):
1014        class MyType(types.UserDefinedType):
1015            def __init__(self, foo="foo", **kwargs):
1016                super(MyType, self).__init__()
1017                self.foo = foo
1018                self.dialect_specific_args = kwargs
1019
1020            def adapt(self, cls):
1021                return cls(foo=self.foo, **self.dialect_specific_args)
1022
1023        t = MyType(bar="bar")
1024        a = t.dialect_impl(testing.db.dialect)
1025        eq_(a.foo, "foo")
1026        eq_(a.dialect_specific_args["bar"], "bar")
1027
1028
class StringConvertUnicodeTest(fixtures.TestBase):
    """Check when String / Unicode return a decoding result processor."""

    @testing.combinations((Unicode,), (String,), argnames="datatype")
    @testing.combinations((True,), (False,), argnames="convert_unicode")
    @testing.combinations(
        (String.RETURNS_CONDITIONAL,),
        (String.RETURNS_BYTES,),
        # one-element tuple, like every other combination entry
        (String.RETURNS_UNICODE,),
        argnames="returns_unicode_strings",
    )
    def test_convert_unicode(
        self, datatype, convert_unicode, returns_unicode_strings
    ):
        s1 = datatype()
        # stand-in dialect configured with just the settings under test
        dialect = mock.Mock(
            returns_unicode_strings=returns_unicode_strings,
            encoding="utf-8",
            convert_unicode=convert_unicode,
        )

        proc = s1.result_processor(dialect, None)

        string = u("méil")
        bytestring = string.encode("utf-8")

        # a processor is expected only when decoding is requested
        # (Unicode, or convert_unicode) and the driver may hand back bytes
        if (
            datatype is Unicode or convert_unicode
        ) and returns_unicode_strings in (
            String.RETURNS_CONDITIONAL,
            String.RETURNS_BYTES,
        ):
            eq_(proc(bytestring), string)

            if returns_unicode_strings is String.RETURNS_CONDITIONAL:
                # conditional mode passes already-decoded text through
                eq_(proc(string), string)
            else:
                # RETURNS_BYTES mode always decodes, so text input fails
                if util.py3k:
                    # trying to decode a unicode
                    assert_raises(TypeError, proc, string)
                else:
                    assert_raises(UnicodeEncodeError, proc, string)
        else:
            is_(proc, None)
1071
1072
class TypeCoerceCastTest(fixtures.TablesTest):
    """Round-trip tests comparing cast() with type_coerce().

    Most scenarios have a ``*_cast`` and a ``*_type_coerce`` entry point
    sharing one ``_test_*`` helper.  ``MyType`` prefixes bound values with
    "BIND_IN" and suffixes fetched values with "BIND_OUT", so the expected
    rows show exactly which processors ran.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        class MyType(types.TypeDecorator):
            impl = String(50)
            cache_ok = True

            def process_bind_param(self, value, dialect):
                return "BIND_IN" + str(value)

            def process_result_value(self, value, dialect):
                return value + "BIND_OUT"

        # exposed on the class so the test helpers can reach it
        cls.MyType = MyType

        # the column is a plain String; MyType only comes into play
        # through cast() / type_coerce()
        Table("t", metadata, Column("data", String(50)))

    def test_insert_round_trip_cast(self, connection):
        self._test_insert_round_trip(cast, connection)

    def test_insert_round_trip_type_coerce(self, connection):
        self._test_insert_round_trip(type_coerce, connection)

    def _test_insert_round_trip(self, coerce_fn, conn):
        """A coerced INSERT value and a coerced SELECT column both go
        through MyType's processors."""
        MyType = self.MyType
        t = self.tables.t

        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        eq_(
            conn.execute(select(coerce_fn(t.c.data, MyType))).fetchall(),
            [("BIND_INd1BIND_OUT",)],
        )

    def test_coerce_from_nulltype_cast(self, connection):
        self._test_coerce_from_nulltype(cast, connection)

    def test_coerce_from_nulltype_type_coerce(self, connection):
        self._test_coerce_from_nulltype(type_coerce, connection)

    def _test_coerce_from_nulltype(self, coerce_fn, conn):
        """A value of an unrecognized Python type still gets MyType's
        bind processing once coerced."""
        MyType = self.MyType

        # test coerce from nulltype - e.g. use an object that
        # doesn't match to a known type
        class MyObj(object):
            def __str__(self):
                return "THISISMYOBJ"

        t = self.tables.t

        conn.execute(t.insert().values(data=coerce_fn(MyObj(), MyType)))

        eq_(
            conn.execute(select(coerce_fn(t.c.data, MyType))).fetchall(),
            [("BIND_INTHISISMYOBJBIND_OUT",)],
        )

    def test_vs_non_coerced_cast(self, connection):
        self._test_vs_non_coerced(cast, connection)

    def test_vs_non_coerced_type_coerce(self, connection):
        self._test_vs_non_coerced(type_coerce, connection)

    def _test_vs_non_coerced(self, coerce_fn, conn):
        """Only the coerced column gets result processing; the plain
        column returns the stored value as-is."""
        MyType = self.MyType
        t = self.tables.t

        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(t.c.data, MyType))
            ).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

    def test_vs_non_coerced_alias_cast(self, connection):
        self._test_vs_non_coerced_alias(cast, connection)

    def test_vs_non_coerced_alias_type_coerce(self, connection):
        self._test_vs_non_coerced_alias(type_coerce, connection)

    def _test_vs_non_coerced_alias(self, coerce_fn, conn):
        """Same as _test_vs_non_coerced, selected through an alias."""
        MyType = self.MyType
        t = self.tables.t

        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        eq_(
            conn.execute(
                select(t.c.data.label("x"), coerce_fn(t.c.data, MyType))
                .alias()
                .select()
            ).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

    def test_vs_non_coerced_where_cast(self, connection):
        self._test_vs_non_coerced_where(cast, connection)

    def test_vs_non_coerced_where_type_coerce(self, connection):
        self._test_vs_non_coerced_where(type_coerce, connection)

    def _test_vs_non_coerced_where(self, coerce_fn, conn):
        """Coercion on either side of a WHERE comparison makes "d1" match
        the stored "BIND_INd1" row."""
        MyType = self.MyType

        t = self.tables.t
        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        # coerce on left side
        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(t.c.data, MyType)).where(
                    coerce_fn(t.c.data, MyType) == "d1"
                )
            ).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

        # coerce on right side
        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(t.c.data, MyType)).where(
                    t.c.data == coerce_fn("d1", MyType)
                )
            ).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

    def test_coerce_none_cast(self, connection):
        self._test_coerce_none(cast, connection)

    def test_coerce_none_type_coerce(self, connection):
        self._test_coerce_none(type_coerce, connection)

    def _test_coerce_none(self, coerce_fn, conn):
        """Comparisons against a coerced None return no rows."""
        MyType = self.MyType

        t = self.tables.t
        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))
        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(t.c.data, MyType)).where(
                    t.c.data == coerce_fn(None, MyType)
                )
            ).fetchall(),
            [],
        )

        # ``== None`` is deliberate here: it is the SQL expression under
        # test, so the E711 suppression belongs on this line
        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(t.c.data, MyType)).where(
                    coerce_fn(t.c.data, MyType) == None  # noqa: E711
                )
            ).fetchall(),
            [],
        )

    def test_resolve_clause_element_cast(self, connection):
        self._test_resolve_clause_element(cast, connection)

    def test_resolve_clause_element_type_coerce(self, connection):
        self._test_resolve_clause_element(type_coerce, connection)

    def _test_resolve_clause_element(self, coerce_fn, conn):
        """An object exposing __clause_element__ is accepted as the
        expression to coerce."""
        MyType = self.MyType

        t = self.tables.t
        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        class MyFoob(object):
            def __clause_element__(self):
                return t.c.data

        eq_(
            conn.execute(
                select(t.c.data, coerce_fn(MyFoob(), MyType))
            ).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

    def test_cast_replace_col_w_bind(self, connection):
        self._test_replace_col_w_bind(cast, connection)

    def test_type_coerce_replace_col_w_bind(self, connection):
        self._test_replace_col_w_bind(type_coerce, connection)

    def _test_replace_col_w_bind(self, coerce_fn, conn):
        """Swapping the column for a bound parameter via
        replacement_traverse leaves the original statement intact."""
        MyType = self.MyType

        t = self.tables.t
        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        stmt = select(t.c.data, coerce_fn(t.c.data, MyType))

        def col_to_bind(col):
            if col is t.c.data:
                return bindparam(None, "x", type_=col.type, unique=True)
            return None

        # ensure we evaluate the expression so that we can see
        # the clone resets this info
        stmt.compile()

        new_stmt = visitors.replacement_traverse(stmt, {}, col_to_bind)

        # original statement
        eq_(
            conn.execute(stmt).fetchall(),
            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

        # replaced with binds; CAST can't affect the bound parameter
        # on the way in here
        eq_(
            conn.execute(new_stmt).fetchall(),
            [("x", "BIND_INxBIND_OUT")]
            if coerce_fn is type_coerce
            else [("x", "xBIND_OUT")],
        )

    def test_cast_bind(self, connection):
        self._test_bind(cast, connection)

    def test_type_bind(self, connection):
        self._test_bind(type_coerce, connection)

    def _test_bind(self, coerce_fn, conn):
        """type_coerce applies MyType's bind processing to a coerced
        bound parameter; cast does not."""
        MyType = self.MyType

        t = self.tables.t
        conn.execute(t.insert().values(data=coerce_fn("d1", MyType)))

        stmt = select(
            bindparam(None, "x", String(50), unique=True),
            coerce_fn(bindparam(None, "x", String(50), unique=True), MyType),
        )

        eq_(
            conn.execute(stmt).fetchall(),
            [("x", "BIND_INxBIND_OUT")]
            if coerce_fn is type_coerce
            else [("x", "xBIND_OUT")],
        )

    def test_cast_existing_typed(self, connection):
        MyType = self.MyType
        coerce_fn = cast

        # when cast() is given an already typed value,
        # the type does not take effect on the value itself.
        eq_(
            connection.scalar(select(coerce_fn(literal("d1"), MyType))),
            "d1BIND_OUT",
        )

    def test_type_coerce_existing_typed(self, connection):
        MyType = self.MyType
        coerce_fn = type_coerce
        t = self.tables.t

        # type_coerce does upgrade the given expression to the
        # given type.

        connection.execute(
            t.insert().values(data=coerce_fn(literal("d1"), MyType))
        )

        eq_(
            connection.execute(select(coerce_fn(t.c.data, MyType))).fetchall(),
            [("BIND_INd1BIND_OUT",)],
        )
1348
1349
class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
    """Round-trip and type-compilation tests for TypeDecorator combined
    with with_variant()."""

    __backend__ = True

    @testing.fixture
    def variant_roundtrip(self, metadata, connection):
        """Return a callable that creates a one-column table of the given
        datatype, inserts ``data``, then checks both an ordered SELECT and
        an IN-filtered SELECT against ``assert_data``."""

        def run(datatype, data, assert_data):
            table = Table("t", metadata, Column("data", datatype))
            table.create(connection)

            connection.execute(
                table.insert(), [dict(data=value) for value in data]
            )

            expected_rows = [(value,) for value in assert_data]
            all_rows = connection.execute(
                select(table).order_by(table.c.data)
            ).all()
            eq_(all_rows, expected_rows)

            # IN uses an "expanding" parameter as of 1.4
            in_rows = connection.execute(
                select(table)
                .where(table.c.data.in_(data))
                .order_by(table.c.data)
            ).all()
            eq_(in_rows, expected_rows)

        return run

    def test_type_decorator_variant_one_roundtrip(self, variant_roundtrip):
        class Foo(TypeDecorator):
            impl = String(50)
            cache_ok = True

        # PostgreSQL gets the Integer variant; everything else uses Foo
        on_pg = testing.against("postgresql")
        data = [5, 6, 10] if on_pg else ["five", "six", "ten"]
        variant_roundtrip(
            Foo().with_variant(Integer, "postgresql"), data, data
        )

    def test_type_decorator_variant_two(self, variant_roundtrip):
        class UTypeOne(types.UserDefinedType):
            def get_col_spec(self):
                return "VARCHAR(50)"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UONE"

                return process

        class UTypeTwo(types.UserDefinedType):
            def get_col_spec(self):
                return "VARCHAR(50)"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UTWO"

                return process

        inner = UTypeOne()
        for backend_name in ("postgresql", "mysql", "mariadb"):
            inner = inner.with_variant(UTypeTwo(), backend_name)

        class Foo(TypeDecorator):
            impl = inner
            cache_ok = True

        words = ["five", "six", "ten"]
        if testing.against("postgresql"):
            # the outer Integer variant takes precedence on PostgreSQL
            data = assert_data = [5, 6, 10]
        elif testing.against("mysql") or testing.against("mariadb"):
            data = words
            assert_data = [word + "UTWO" for word in words]
        else:
            data = words
            assert_data = [word + "UONE" for word in words]

        variant_roundtrip(
            Foo().with_variant(Integer, "postgresql"), data, assert_data
        )

    def test_type_decorator_variant_three(self, variant_roundtrip):
        class Foo(TypeDecorator):
            impl = String
            cache_ok = True

        # the TypeDecorator is itself the variant, used only on PostgreSQL
        if not testing.against("postgresql"):
            data = [5, 6, 10]
        else:
            data = ["five", "six", "ten"]

        variant_roundtrip(
            Integer().with_variant(Foo(), "postgresql"), data, data
        )

    def test_type_decorator_compile_variant_one(self):
        class Foo(TypeDecorator):
            impl = String
            cache_ok = True

        # Integer on SQLite, the decorator's VARCHAR everywhere else
        for dialect_module, expected in (
            (dialects.sqlite, "INTEGER"),
            (dialects.postgresql, "VARCHAR"),
        ):
            self.assert_compile(
                Foo().with_variant(Integer, "sqlite"),
                expected,
                dialect=dialect_module.dialect(),
            )

    def test_type_decorator_compile_variant_two(self):
        class UTypeOne(types.UserDefinedType):
            def get_col_spec(self):
                return "UTYPEONE"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UONE"

                return process

        class UTypeTwo(types.UserDefinedType):
            def get_col_spec(self):
                return "UTYPETWO"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UTWO"

                return process

        inner = UTypeOne().with_variant(UTypeTwo(), "postgresql")

        class Foo(TypeDecorator):
            impl = inner
            cache_ok = True

        # SQLite takes the outer Integer variant; PostgreSQL resolves
        # through Foo's impl to the inner UTypeTwo variant
        for dialect_module, expected in (
            (dialects.sqlite, "INTEGER"),
            (dialects.postgresql, "UTYPETWO"),
        ):
            self.assert_compile(
                Foo().with_variant(Integer, "sqlite"),
                expected,
                dialect=dialect_module.dialect(),
            )

    def test_type_decorator_compile_variant_three(self):
        class Foo(TypeDecorator):
            impl = String
            cache_ok = True

        for dialect_module, expected in (
            (dialects.sqlite, "INTEGER"),
            (dialects.postgresql, "VARCHAR"),
        ):
            self.assert_compile(
                Integer().with_variant(Foo(), "postgresql"),
                expected,
                dialect=dialect_module.dialect(),
            )
1520
1521
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
    """Compilation, bind processing and round trips for Variant types
    built from UserDefinedType instances via with_variant()."""

    def setup_test(self):
        class UTypeOne(types.UserDefinedType):
            def get_col_spec(self):
                return "UTYPEONE"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UONE"

                return process

        class UTypeTwo(types.UserDefinedType):
            def get_col_spec(self):
                return "UTYPETWO"

            def bind_processor(self, dialect):
                def process(value):
                    return value + "UTWO"

                return process

        # no bind_processor defined on this one
        class UTypeThree(types.UserDefinedType):
            def get_col_spec(self):
                return "UTYPETHREE"

        self.UTypeOne = UTypeOne
        self.UTypeTwo = UTypeTwo
        self.UTypeThree = UTypeThree
        # UTypeOne by default, UTypeTwo on PostgreSQL
        self.variant = UTypeOne().with_variant(UTypeTwo(), "postgresql")
        # as above, plus UTypeThree on MySQL
        self.composite = self.variant.with_variant(UTypeThree(), "mysql")

    def test_illegal_dupe(self):
        v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql")

        def add_second_pg_variant():
            return v.with_variant(self.UTypeThree(), "postgresql")

        assert_raises_message(
            exc.ArgumentError,
            "Dialect 'postgresql' is already present "
            "in the mapping for this Variant",
            add_second_pg_variant,
        )

    def test_compile(self):
        self.assert_compile(self.variant, "UTYPEONE", use_default_dialect=True)
        for dialect_module, expected in (
            (dialects.mysql, "UTYPEONE"),
            (dialects.postgresql, "UTYPETWO"),
        ):
            self.assert_compile(
                self.variant, expected, dialect=dialect_module.dialect()
            )

    def test_to_instance(self):
        # with_variant() accepts a type class as well as an instance
        variant = self.UTypeOne().with_variant(self.UTypeTwo, "postgresql")
        self.assert_compile(
            variant, "UTYPETWO", dialect=dialects.postgresql.dialect()
        )

    def test_compile_composite(self):
        self.assert_compile(
            self.composite, "UTYPEONE", use_default_dialect=True
        )
        for dialect_module, expected in (
            (dialects.mysql, "UTYPETHREE"),
            (dialects.postgresql, "UTYPETWO"),
        ):
            self.assert_compile(
                self.composite, expected, dialect=dialect_module.dialect()
            )

    def test_bind_process(self):
        cases = [
            (dialects.mysql.dialect, "fooUONE"),
            (default.DefaultDialect, "fooUONE"),
            (dialects.postgresql.dialect, "fooUTWO"),
        ]
        for make_dialect, expected in cases:
            processor = self.variant._cached_bind_processor(make_dialect())
            eq_(processor("foo"), expected)

    def test_bind_process_composite(self):
        # UTypeThree has no bind_processor, so MySQL gets none at all
        mysql_processor = self.composite._cached_bind_processor(
            dialects.mysql.dialect()
        )
        assert mysql_processor is None

        for make_dialect, expected in (
            (default.DefaultDialect, "fooUONE"),
            (dialects.postgresql.dialect, "fooUTWO"),
        ):
            processor = self.composite._cached_bind_processor(make_dialect())
            eq_(processor("foo"), expected)

    def test_comparator_variant(self):
        # the literal on the right-hand side takes on the Variant type
        right_type = (column("x", self.variant) == "bar").right.type
        is_(right_type, self.variant)

    @testing.only_on("sqlite")
    @testing.provide_metadata
    def test_round_trip(self, connection):
        sqlite_variant = self.UTypeOne().with_variant(
            self.UTypeTwo(), "sqlite"
        )

        tbl = Table("t", self.metadata, Column("x", sqlite_variant))
        tbl.create(connection)

        connection.execute(tbl.insert(), dict(x="foo"))

        # the comparison value passes through the same UTWO bind
        # processor as the inserted one, so the row matches
        stmt = select(tbl.c.x).where(tbl.c.x == "foo")
        eq_(connection.scalar(stmt), "fooUTWO")

    @testing.only_on("sqlite")
    @testing.provide_metadata
    def test_round_trip_sqlite_datetime(self, connection):
        base_type = DateTime()
        variant = base_type.with_variant(
            dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite"
        )

        tbl = Table("t", self.metadata, Column("x", variant))
        tbl.create(connection)

        stored = datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
        connection.execute(tbl.insert(), dict(x=stored))

        # microseconds are truncated on the stored value and on the
        # comparison value, so a different microsecond still matches
        probe = datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
        eq_(
            connection.scalar(select(tbl.c.x).where(tbl.c.x == probe)),
            datetime.datetime(2015, 4, 18, 10, 15, 17),
        )
1669
1670
class UnicodeTest(fixtures.TestBase):
    """Bind-processing behavior of the Unicode type.

    Round-trip tests for unicode data live in
    sqlalchemy/testing/suite/test_types.py.
    """

    __backend__ = True

    data = util.u(
        "Alors vous imaginez ma surprise, au lever du jour, quand "
        "une drôle de petite voix m’a réveillé. "
        "Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
    )

    @staticmethod
    def _unicode_bind_processor(supports_unicode_binds):
        """Return Unicode's bind processor for a DefaultDialect with the
        given ``supports_unicode_binds`` setting."""
        unicode_type = Unicode()
        dialect = default.DefaultDialect()
        dialect.supports_unicode_binds = supports_unicode_binds
        return unicode_type.dialect_impl(dialect).bind_processor(dialect)

    def test_unicode_warnings_typelevel_native_unicode(self):
        process = self._unicode_bind_processor(True)
        if not util.py3k:
            assert_raises(exc.SAWarning, process, "x")
            assert isinstance(process(self.data), unicode)  # noqa
        else:
            assert_raises(exc.SAWarning, process, b"x")
            assert isinstance(process(self.data), str)

    def test_unicode_warnings_typelevel_sqla_unicode(self):
        process = self._unicode_bind_processor(False)
        assert_raises(exc.SAWarning, process, util.b("x"))

        # without native unicode binds the text is encoded to bytes
        assert isinstance(process(self.data), util.binary_type)
        eq_(process(self.data), self.data.encode("utf-8"))

    def test_unicode_warnings_totally_wrong_type(self):
        process = self._unicode_bind_processor(False)
        # a non-string value warns but is passed through unchanged
        with expect_warnings(
            "Unicode type received non-unicode bind param value 5."
        ):
            eq_(process(5), 5)
1722
1723
1724class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
1725    __backend__ = True
1726
1727    SomeEnum = pep435_enum("SomeEnum")
1728
1729    one = SomeEnum("one", 1)
1730    two = SomeEnum("two", 2)
1731    three = SomeEnum("three", 3, "four")
1732    a_member = SomeEnum("AMember", "a")
1733    b_member = SomeEnum("BMember", "b")
1734
1735    SomeOtherEnum = pep435_enum("SomeOtherEnum")
1736
1737    other_one = SomeOtherEnum("one", 1)
1738    other_two = SomeOtherEnum("two", 2)
1739    other_three = SomeOtherEnum("three", 3)
1740    other_a_member = SomeOtherEnum("AMember", "a")
1741    other_b_member = SomeOtherEnum("BMember", "b")
1742
1743    @staticmethod
1744    def get_enum_string_values(some_enum):
1745        return [str(v.value) for v in some_enum.__members__.values()]
1746
1747    @classmethod
1748    def define_tables(cls, metadata):
1749        # note create_constraint has changed in 1.4 as of #5367
1750        Table(
1751            "enum_table",
1752            metadata,
1753            Column("id", Integer, primary_key=True),
1754            Column(
1755                "someenum",
1756                Enum(
1757                    "one",
1758                    "two",
1759                    "three",
1760                    name="myenum",
1761                    create_constraint=True,
1762                ),
1763            ),
1764        )
1765
1766        Table(
1767            "non_native_enum_table",
1768            metadata,
1769            Column("id", Integer, primary_key=True, autoincrement=False),
1770            Column(
1771                "someenum",
1772                Enum(
1773                    "one",
1774                    "two",
1775                    "three",
1776                    native_enum=False,
1777                    create_constraint=True,
1778                ),
1779            ),
1780            Column(
1781                "someotherenum",
1782                Enum(
1783                    "one",
1784                    "two",
1785                    "three",
1786                    native_enum=False,
1787                    validate_strings=True,
1788                ),
1789            ),
1790        )
1791
1792        Table(
1793            "stdlib_enum_table",
1794            metadata,
1795            Column("id", Integer, primary_key=True),
1796            Column(
1797                "someenum",
1798                Enum(cls.SomeEnum, create_constraint=True, omit_aliases=False),
1799            ),
1800        )
1801        Table(
1802            "stdlib_enum_table_no_alias",
1803            metadata,
1804            Column("id", Integer, primary_key=True),
1805            Column(
1806                "someenum",
1807                Enum(
1808                    cls.SomeEnum,
1809                    create_constraint=True,
1810                    omit_aliases=True,
1811                    name="someenum_no_alias",
1812                ),
1813            ),
1814        )
1815
1816        Table(
1817            "stdlib_enum_table2",
1818            metadata,
1819            Column("id", Integer, primary_key=True),
1820            Column(
1821                "someotherenum",
1822                Enum(
1823                    cls.SomeOtherEnum,
1824                    values_callable=EnumTest.get_enum_string_values,
1825                    create_constraint=True,
1826                ),
1827            ),
1828        )
1829
1830    def test_python_type(self):
1831        eq_(types.Enum(self.SomeOtherEnum).python_type, self.SomeOtherEnum)
1832
1833    def test_pickle_types(self):
1834        global SomeEnum
1835        SomeEnum = self.SomeEnum
1836        for loads, dumps in picklers():
1837            column_types = [
1838                Column("Enu", Enum("x", "y", "z", name="somename")),
1839                Column("En2", Enum(self.SomeEnum, omit_aliases=False)),
1840            ]
1841            for column_type in column_types:
1842                meta = MetaData()
1843                Table("foo", meta, column_type)
1844                loads(dumps(column_type))
1845                loads(dumps(meta))
1846
1847    def test_validators_pep435(self):
1848        type_ = Enum(self.SomeEnum, omit_aliases=False)
1849        validate_type = Enum(
1850            self.SomeEnum, validate_strings=True, omit_aliases=False
1851        )
1852
1853        bind_processor = type_.bind_processor(testing.db.dialect)
1854        bind_processor_validates = validate_type.bind_processor(
1855            testing.db.dialect
1856        )
1857        eq_(bind_processor("one"), "one")
1858        eq_(bind_processor(self.one), "one")
1859        eq_(bind_processor("foo"), "foo")
1860        assert_raises_message(
1861            LookupError,
1862            "'5' is not among the defined enum values. Enum name: someenum. "
1863            "Possible values: one, two, three, ..., BMember",
1864            bind_processor,
1865            5,
1866        )
1867
1868        assert_raises_message(
1869            LookupError,
1870            "'foo' is not among the defined enum values. Enum name: someenum. "
1871            "Possible values: one, two, three, ..., BMember",
1872            bind_processor_validates,
1873            "foo",
1874        )
1875
1876        result_processor = type_.result_processor(testing.db.dialect, None)
1877
1878        eq_(result_processor("one"), self.one)
1879        assert_raises_message(
1880            LookupError,
1881            "'foo' is not among the defined enum values. Enum name: someenum. "
1882            "Possible values: one, two, three, ..., BMember",
1883            result_processor,
1884            "foo",
1885        )
1886
1887        literal_processor = type_.literal_processor(testing.db.dialect)
1888        validate_literal_processor = validate_type.literal_processor(
1889            testing.db.dialect
1890        )
1891        eq_(literal_processor("one"), "'one'")
1892
1893        eq_(literal_processor("foo"), "'foo'")
1894
1895        assert_raises_message(
1896            LookupError,
1897            "'5' is not among the defined enum values. Enum name: someenum. "
1898            "Possible values: one, two, three, ..., BMember",
1899            literal_processor,
1900            5,
1901        )
1902
1903        assert_raises_message(
1904            LookupError,
1905            "'foo' is not among the defined enum values. Enum name: someenum. "
1906            "Possible values: one, two, three, ..., BMember",
1907            validate_literal_processor,
1908            "foo",
1909        )
1910
1911    def test_validators_plain(self):
1912        type_ = Enum("one", "two")
1913        validate_type = Enum("one", "two", validate_strings=True)
1914
1915        bind_processor = type_.bind_processor(testing.db.dialect)
1916        bind_processor_validates = validate_type.bind_processor(
1917            testing.db.dialect
1918        )
1919        eq_(bind_processor("one"), "one")
1920        eq_(bind_processor("foo"), "foo")
1921        assert_raises_message(
1922            LookupError,
1923            "'5' is not among the defined enum values. Enum name: None. "
1924            "Possible values: one, two",
1925            bind_processor,
1926            5,
1927        )
1928
1929        assert_raises_message(
1930            LookupError,
1931            "'foo' is not among the defined enum values. Enum name: None. "
1932            "Possible values: one, two",
1933            bind_processor_validates,
1934            "foo",
1935        )
1936
1937        result_processor = type_.result_processor(testing.db.dialect, None)
1938
1939        eq_(result_processor("one"), "one")
1940        assert_raises_message(
1941            LookupError,
1942            "'foo' is not among the defined enum values. Enum name: None. "
1943            "Possible values: one, two",
1944            result_processor,
1945            "foo",
1946        )
1947
1948        literal_processor = type_.literal_processor(testing.db.dialect)
1949        validate_literal_processor = validate_type.literal_processor(
1950            testing.db.dialect
1951        )
1952        eq_(literal_processor("one"), "'one'")
1953        eq_(literal_processor("foo"), "'foo'")
1954        assert_raises_message(
1955            LookupError,
1956            "'5' is not among the defined enum values. Enum name: None. "
1957            "Possible values: one, two",
1958            literal_processor,
1959            5,
1960        )
1961
1962        assert_raises_message(
1963            LookupError,
1964            "'foo' is not among the defined enum values. Enum name: None. "
1965            "Possible values: one, two",
1966            validate_literal_processor,
1967            "foo",
1968        )
1969
1970    def test_enum_raise_lookup_ellipses(self):
1971        type_ = Enum("one", "twothreefourfivesix", "seven", "eight")
1972        bind_processor = type_.bind_processor(testing.db.dialect)
1973
1974        eq_(bind_processor("one"), "one")
1975        assert_raises_message(
1976            LookupError,
1977            "'5' is not among the defined enum values. Enum name: None. "
1978            "Possible values: one, twothreefou.., seven, eight",
1979            bind_processor,
1980            5,
1981        )
1982
1983    def test_enum_raise_lookup_none(self):
1984        type_ = Enum()
1985        bind_processor = type_.bind_processor(testing.db.dialect)
1986
1987        assert_raises_message(
1988            LookupError,
1989            "'5' is not among the defined enum values. Enum name: None. "
1990            "Possible values: None",
1991            bind_processor,
1992            5,
1993        )
1994
1995    def test_validators_not_in_like_roundtrip(self, connection):
1996        enum_table = self.tables["non_native_enum_table"]
1997
1998        connection.execute(
1999            enum_table.insert(),
2000            [
2001                {"id": 1, "someenum": "two"},
2002                {"id": 2, "someenum": "two"},
2003                {"id": 3, "someenum": "one"},
2004            ],
2005        )
2006
2007        eq_(
2008            connection.execute(
2009                enum_table.select()
2010                .where(enum_table.c.someenum.like("%wo%"))
2011                .order_by(enum_table.c.id)
2012            ).fetchall(),
2013            [(1, "two", None), (2, "two", None)],
2014        )
2015
2016    def test_validators_not_in_concatenate_roundtrip(self, connection):
2017        enum_table = self.tables["non_native_enum_table"]
2018
2019        connection.execute(
2020            enum_table.insert(),
2021            [
2022                {"id": 1, "someenum": "two"},
2023                {"id": 2, "someenum": "two"},
2024                {"id": 3, "someenum": "one"},
2025            ],
2026        )
2027
2028        eq_(
2029            connection.execute(
2030                select("foo" + enum_table.c.someenum).order_by(enum_table.c.id)
2031            ).fetchall(),
2032            [("footwo",), ("footwo",), ("fooone",)],
2033        )
2034
2035    def test_round_trip(self, connection):
2036        enum_table = self.tables["enum_table"]
2037
2038        connection.execute(
2039            enum_table.insert(),
2040            [
2041                {"id": 1, "someenum": "two"},
2042                {"id": 2, "someenum": "two"},
2043                {"id": 3, "someenum": "one"},
2044            ],
2045        )
2046
2047        eq_(
2048            connection.execute(
2049                enum_table.select().order_by(enum_table.c.id)
2050            ).fetchall(),
2051            [(1, "two"), (2, "two"), (3, "one")],
2052        )
2053
2054    def test_null_round_trip(self, connection):
2055        enum_table = self.tables.enum_table
2056        non_native_enum_table = self.tables.non_native_enum_table
2057
2058        connection.execute(enum_table.insert(), {"id": 1, "someenum": None})
2059        eq_(connection.scalar(select(enum_table.c.someenum)), None)
2060
2061        connection.execute(
2062            non_native_enum_table.insert(), {"id": 1, "someenum": None}
2063        )
2064        eq_(connection.scalar(select(non_native_enum_table.c.someenum)), None)
2065
2066    @testing.requires.enforces_check_constraints
2067    def test_check_constraint(self, connection):
2068        assert_raises(
2069            (
2070                exc.IntegrityError,
2071                exc.ProgrammingError,
2072                exc.OperationalError,
2073                # PyMySQL raising InternalError until
2074                # https://github.com/PyMySQL/PyMySQL/issues/607 is resolved
2075                exc.InternalError,
2076            ),
2077            connection.exec_driver_sql,
2078            "insert into non_native_enum_table "
2079            "(id, someenum) values(1, 'four')",
2080        )
2081
    @testing.requires.enforces_check_constraints
    @testing.provide_metadata
    def test_variant_we_are_default(self):
        """Base Enum is active; its variant is registered for another dialect.

        Both Enums add a CheckConstraint object to the Table, but only the
        base Enum's values ("one", "two", "three") are enforced by the
        database under test: inserting "four" fails, inserting "two"
        succeeds.
        """
        # the variant Enum is attached to "some_other_db", which is not the
        # backend under test, so its CHECK must not take effect here
        t = Table(
            "my_table",
            self.metadata,
            Column(
                "data",
                Enum(
                    "one",
                    "two",
                    "three",
                    native_enum=False,
                    name="e1",
                    create_constraint=True,
                ).with_variant(
                    Enum(
                        "four",
                        "five",
                        "six",
                        native_enum=False,
                        name="e2",
                        create_constraint=True,
                    ),
                    "some_other_db",
                ),
            ),
            mysql_engine="InnoDB",
        )

        # both Enums contribute a CheckConstraint object to the Table; the
        # inserts below show which one the database actually enforces
        eq_(
            len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
            2,
        )

        self.metadata.create_all(testing.db)

        # not using the connection fixture because we need to rollback and
        # start again in the middle
        with testing.db.connect() as connection:
            # postgresql needs this in order to continue after the exception
            trans = connection.begin()
            # "four" belongs only to the inactive variant -> rejected
            assert_raises(
                (exc.DBAPIError,),
                connection.exec_driver_sql,
                "insert into my_table " "(data) values('four')",
            )
            trans.rollback()

            # "two" belongs to the active base Enum -> accepted
            with connection.begin():
                connection.exec_driver_sql(
                    "insert into my_table (data) values ('two')"
                )
                eq_(connection.execute(select(t.c.data)).scalar(), "two")
2137
    @testing.requires.enforces_check_constraints
    @testing.provide_metadata
    def test_variant_we_are_not_default(self):
        """The variant registered for the current dialect replaces the base.

        Both Enums add a CheckConstraint object to the Table, but only the
        variant's values ("four", "five", "six") are enforced by the
        database under test: inserting "two" fails, inserting "four"
        succeeds.
        """
        # the variant Enum is attached to testing.db.dialect.name, i.e. the
        # backend under test, so the base Enum's CHECK must not take effect
        t = Table(
            "my_table",
            self.metadata,
            Column(
                "data",
                Enum(
                    "one",
                    "two",
                    "three",
                    native_enum=False,
                    name="e1",
                    create_constraint=True,
                ).with_variant(
                    Enum(
                        "four",
                        "five",
                        "six",
                        native_enum=False,
                        name="e2",
                        create_constraint=True,
                    ),
                    testing.db.dialect.name,
                ),
            ),
        )

        # ensure Variant isn't exploding the constraints: exactly one
        # CheckConstraint object per Enum is attached to the Table
        eq_(
            len([c for c in t.constraints if isinstance(c, CheckConstraint)]),
            2,
        )

        self.metadata.create_all(testing.db)

        # not using the connection fixture because we need to rollback and
        # start again in the middle
        with testing.db.connect() as connection:
            # postgresql needs this in order to continue after the exception
            trans = connection.begin()
            # "two" belongs only to the inactive base Enum -> rejected
            assert_raises(
                (exc.DBAPIError,),
                connection.exec_driver_sql,
                "insert into my_table (data) values('two')",
            )
            trans.rollback()

            # "four" belongs to the active variant -> accepted
            with connection.begin():
                connection.exec_driver_sql(
                    "insert into my_table (data) values ('four')"
                )
                eq_(connection.execute(select(t.c.data)).scalar(), "four")
2193
2194    def test_skip_check_constraint(self, connection):
2195        connection.exec_driver_sql(
2196            "insert into non_native_enum_table "
2197            "(id, someotherenum) values(1, 'four')"
2198        )
2199        eq_(
2200            connection.exec_driver_sql(
2201                "select someotherenum from non_native_enum_table"
2202            ).scalar(),
2203            "four",
2204        )
2205        assert_raises_message(
2206            LookupError,
2207            "'four' is not among the defined enum values. "
2208            "Enum name: None. Possible values: one, two, three",
2209            connection.scalar,
2210            select(self.tables.non_native_enum_table.c.someotherenum),
2211        )
2212
2213    def test_non_native_round_trip(self, connection):
2214        non_native_enum_table = self.tables["non_native_enum_table"]
2215
2216        connection.execute(
2217            non_native_enum_table.insert(),
2218            [
2219                {"id": 1, "someenum": "two"},
2220                {"id": 2, "someenum": "two"},
2221                {"id": 3, "someenum": "one"},
2222            ],
2223        )
2224
2225        eq_(
2226            connection.execute(
2227                select(
2228                    non_native_enum_table.c.id,
2229                    non_native_enum_table.c.someenum,
2230                ).order_by(non_native_enum_table.c.id)
2231            ).fetchall(),
2232            [(1, "two"), (2, "two"), (3, "one")],
2233        )
2234
2235    def test_pep435_default_sort_key(self):
2236        one, two, a_member, b_member = (
2237            self.one,
2238            self.two,
2239            self.a_member,
2240            self.b_member,
2241        )
2242        typ = Enum(self.SomeEnum, omit_aliases=False)
2243
2244        is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
2245
2246        eq_(
2247            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
2248            [a_member, b_member, one, two],
2249        )
2250
2251    def test_pep435_custom_sort_key(self):
2252        one, two, a_member, b_member = (
2253            self.one,
2254            self.two,
2255            self.a_member,
2256            self.b_member,
2257        )
2258
2259        def sort_enum_key_value(value):
2260            return str(value.value)
2261
2262        typ = Enum(
2263            self.SomeEnum,
2264            sort_key_function=sort_enum_key_value,
2265            omit_aliases=False,
2266        )
2267        is_(typ.sort_key_function, sort_enum_key_value)
2268
2269        eq_(
2270            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
2271            [one, two, a_member, b_member],
2272        )
2273
2274    def test_pep435_no_sort_key(self):
2275        typ = Enum(self.SomeEnum, sort_key_function=None, omit_aliases=False)
2276        is_(typ.sort_key_function, None)
2277
2278    def test_pep435_enum_round_trip(self, connection):
2279        stdlib_enum_table = self.tables["stdlib_enum_table"]
2280
2281        connection.execute(
2282            stdlib_enum_table.insert(),
2283            [
2284                {"id": 1, "someenum": self.SomeEnum.two},
2285                {"id": 2, "someenum": self.SomeEnum.two},
2286                {"id": 3, "someenum": self.SomeEnum.one},
2287                {"id": 4, "someenum": self.SomeEnum.three},
2288                {"id": 5, "someenum": self.SomeEnum.four},
2289                {"id": 6, "someenum": "three"},
2290                {"id": 7, "someenum": "four"},
2291            ],
2292        )
2293
2294        eq_(
2295            connection.execute(
2296                stdlib_enum_table.select().order_by(stdlib_enum_table.c.id)
2297            ).fetchall(),
2298            [
2299                (1, self.SomeEnum.two),
2300                (2, self.SomeEnum.two),
2301                (3, self.SomeEnum.one),
2302                (4, self.SomeEnum.three),
2303                (5, self.SomeEnum.three),
2304                (6, self.SomeEnum.three),
2305                (7, self.SomeEnum.three),
2306            ],
2307        )
2308
2309    def test_pep435_enum_values_callable_round_trip(self, connection):
2310        stdlib_enum_table_custom_values = self.tables["stdlib_enum_table2"]
2311
2312        connection.execute(
2313            stdlib_enum_table_custom_values.insert(),
2314            [
2315                {"id": 1, "someotherenum": self.SomeOtherEnum.AMember},
2316                {"id": 2, "someotherenum": self.SomeOtherEnum.BMember},
2317                {"id": 3, "someotherenum": self.SomeOtherEnum.AMember},
2318            ],
2319        )
2320
2321        eq_(
2322            connection.execute(
2323                stdlib_enum_table_custom_values.select().order_by(
2324                    stdlib_enum_table_custom_values.c.id
2325                )
2326            ).fetchall(),
2327            [
2328                (1, self.SomeOtherEnum.AMember),
2329                (2, self.SomeOtherEnum.BMember),
2330                (3, self.SomeOtherEnum.AMember),
2331            ],
2332        )
2333
2334    def test_pep435_enum_expanding_in(self, connection):
2335        stdlib_enum_table_custom_values = self.tables["stdlib_enum_table2"]
2336
2337        connection.execute(
2338            stdlib_enum_table_custom_values.insert(),
2339            [
2340                {"id": 1, "someotherenum": self.SomeOtherEnum.one},
2341                {"id": 2, "someotherenum": self.SomeOtherEnum.two},
2342                {"id": 3, "someotherenum": self.SomeOtherEnum.three},
2343            ],
2344        )
2345
2346        stmt = (
2347            stdlib_enum_table_custom_values.select()
2348            .where(
2349                stdlib_enum_table_custom_values.c.someotherenum.in_(
2350                    bindparam("member", expanding=True)
2351                )
2352            )
2353            .order_by(stdlib_enum_table_custom_values.c.id)
2354        )
2355        eq_(
2356            connection.execute(
2357                stmt,
2358                {"member": [self.SomeOtherEnum.one, self.SomeOtherEnum.three]},
2359            ).fetchall(),
2360            [(1, self.SomeOtherEnum.one), (3, self.SomeOtherEnum.three)],
2361        )
2362
2363    def test_adapt(self):
2364        from sqlalchemy.dialects.postgresql import ENUM
2365
2366        e1 = Enum("one", "two", "three", native_enum=False)
2367
2368        false_adapt = e1.adapt(ENUM)
2369        eq_(false_adapt.native_enum, False)
2370        assert not isinstance(false_adapt, ENUM)
2371
2372        e1 = Enum("one", "two", "three", native_enum=True)
2373        true_adapt = e1.adapt(ENUM)
2374        eq_(true_adapt.native_enum, True)
2375        assert isinstance(true_adapt, ENUM)
2376
2377        e1 = Enum(
2378            "one",
2379            "two",
2380            "three",
2381            name="foo",
2382            schema="bar",
2383            metadata=MetaData(),
2384        )
2385        eq_(e1.adapt(ENUM).name, "foo")
2386        eq_(e1.adapt(ENUM).schema, "bar")
2387        is_(e1.adapt(ENUM).metadata, e1.metadata)
2388        eq_(e1.adapt(Enum).name, "foo")
2389        eq_(e1.adapt(Enum).schema, "bar")
2390        is_(e1.adapt(Enum).metadata, e1.metadata)
2391        e1 = Enum(self.SomeEnum, omit_aliases=False)
2392        eq_(e1.adapt(ENUM).name, "someenum")
2393        eq_(
2394            e1.adapt(ENUM).enums,
2395            ["one", "two", "three", "four", "AMember", "BMember"],
2396        )
2397
2398        e1_vc = Enum(
2399            self.SomeOtherEnum, values_callable=EnumTest.get_enum_string_values
2400        )
2401        eq_(e1_vc.adapt(ENUM).name, "someotherenum")
2402        eq_(e1_vc.adapt(ENUM).enums, ["1", "2", "3", "a", "b"])
2403
2404    def test_adapt_length(self):
2405        from sqlalchemy.dialects.postgresql import ENUM
2406
2407        e1 = Enum("one", "two", "three", length=50, native_enum=False)
2408        eq_(e1.adapt(ENUM).length, 50)
2409        eq_(e1.adapt(Enum).length, 50)
2410
2411        e1 = Enum("one", "two", "three")
2412        eq_(e1.length, 5)
2413        eq_(e1.adapt(ENUM).length, 5)
2414        eq_(e1.adapt(Enum).length, 5)
2415
2416    @testing.provide_metadata
2417    def test_create_metadata_bound_no_crash(self):
2418        m1 = self.metadata
2419        Enum("a", "b", "c", metadata=m1, name="ncenum")
2420
2421        m1.create_all(testing.db)
2422
2423    def test_non_native_constraint_custom_type(self):
2424        class Foob(object):
2425            def __init__(self, name):
2426                self.name = name
2427
2428        class MyEnum(TypeDecorator):
2429            cache_ok = True
2430
2431            def __init__(self, values):
2432                self.impl = Enum(
2433                    *[v.name for v in values],
2434                    name="myenum",
2435                    native_enum=False,
2436                    create_constraint=True
2437                )
2438
2439            # future method
2440            def process_literal_param(self, value, dialect):
2441                return value.name
2442
2443            def process_bind_param(self, value, dialect):
2444                return value.name
2445
2446        m = MetaData()
2447        t1 = Table("t", m, Column("x", MyEnum([Foob("a"), Foob("b")])))
2448        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][
2449            0
2450        ]
2451
2452        self.assert_compile(
2453            AddConstraint(const),
2454            "ALTER TABLE t ADD CONSTRAINT myenum CHECK (x IN ('a', 'b'))",
2455            dialect="default",
2456        )
2457
2458    def test_lookup_failure(self, connection):
2459        assert_raises(
2460            exc.StatementError,
2461            connection.execute,
2462            self.tables["non_native_enum_table"].insert(),
2463            {"id": 4, "someotherenum": "four"},
2464        )
2465
2466    def test_mock_engine_no_prob(self):
2467        """ensure no 'checkfirst' queries are run when enums
2468        are created with checkfirst=False"""
2469
2470        e = engines.mock_engine()
2471        t = Table(
2472            "t1",
2473            MetaData(),
2474            Column("x", Enum("x", "y", name="pge", create_constraint=True)),
2475        )
2476        t.create(e, checkfirst=False)
2477        # basically looking for the start of
2478        # the constraint, or the ENUM def itself,
2479        # depending on backend.
2480        assert "('x'," in e.print_sql()
2481
2482    @testing.uses_deprecated(".*convert_unicode")
2483    def test_repr(self):
2484        e = Enum(
2485            "x",
2486            "y",
2487            name="somename",
2488            convert_unicode=True,
2489            quote=True,
2490            inherit_schema=True,
2491            native_enum=False,
2492        )
2493        eq_(
2494            repr(e),
2495            "Enum('x', 'y', name='somename', "
2496            "inherit_schema=True, native_enum=False)",
2497        )
2498
2499    def test_length_native(self):
2500        e = Enum("x", "y", "long", length=42)
2501
2502        eq_(e.length, len("long"))
2503
2504        # no error is raised
2505        e = Enum("x", "y", "long", length=1)
2506        eq_(e.length, len("long"))
2507
2508    def test_length_raises(self):
2509        assert_raises_message(
2510            ValueError,
2511            "When provided, length must be larger or equal.*",
2512            Enum,
2513            "x",
2514            "y",
2515            "long",
2516            native_enum=False,
2517            length=1,
2518        )
2519
2520    def test_no_length_non_native(self):
2521        e = Enum("x", "y", "long", native_enum=False)
2522        eq_(e.length, len("long"))
2523
2524    def test_length_non_native(self):
2525        e = Enum("x", "y", "long", native_enum=False, length=42)
2526        eq_(e.length, 42)
2527
2528    def test_omit_aliases(self, connection):
2529        table0 = self.tables["stdlib_enum_table"]
2530        type0 = table0.c.someenum.type
2531        eq_(type0.enums, ["one", "two", "three", "four", "AMember", "BMember"])
2532
2533        table = self.tables["stdlib_enum_table_no_alias"]
2534
2535        type_ = table.c.someenum.type
2536        eq_(type_.enums, ["one", "two", "three", "AMember", "BMember"])
2537
2538        connection.execute(
2539            table.insert(),
2540            [
2541                {"id": 1, "someenum": self.SomeEnum.three},
2542                {"id": 2, "someenum": self.SomeEnum.four},
2543            ],
2544        )
2545        eq_(
2546            connection.execute(table.select().order_by(table.c.id)).fetchall(),
2547            [(1, self.SomeEnum.three), (2, self.SomeEnum.three)],
2548        )
2549
2550    def test_omit_warn(self):
2551        with expect_deprecated_20(
2552            r"The provided enum someenum contains the aliases \['four'\]"
2553        ):
2554            Enum(self.SomeEnum)
2555
    @testing.combinations(
        (True, "native"), (False, "non_native"), id_="ai", argnames="native"
    )
    @testing.combinations(
        (True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit"
    )
    @testing.provide_metadata
    @testing.skip_if("mysql < 8")
    def test_duplicate_values_accepted(self, native, omit):
        """Alias names are allowed in the database unless omit_aliases=True.

        ``foo_enum`` registers "two" and "four" (third constructor argument)
        as aliases of "one" and "three".  With omit_aliases=True the
        database-level constraint/type rejects the alias "four".  With
        omit_aliases=False both "four" and "three" are stored as raw
        strings and both read back as ``foo_enum.three`` through the
        Enum-typed table.
        """
        foo_enum = pep435_enum("foo_enum")
        foo_enum("one", 1, "two")
        foo_enum("three", 3, "four")
        tbl = sa.Table(
            "foo_table",
            self.metadata,
            sa.Column("id", sa.Integer),
            sa.Column(
                "data",
                sa.Enum(
                    foo_enum,
                    native_enum=native,
                    omit_aliases=omit,
                    create_constraint=True,
                ),
            ),
        )
        # untyped lightweight table over the same name: inserts and selects
        # raw strings without any Enum processing
        t = sa.table("foo_table", sa.column("id"), sa.column("data"))

        self.metadata.create_all(testing.db)
        if omit:
            # "four" is not in the DB-level value list -> backend error
            with expect_raises(
                (
                    exc.IntegrityError,
                    exc.DataError,
                    exc.OperationalError,
                    exc.DBAPIError,
                )
            ):
                with testing.db.begin() as conn:
                    conn.execute(
                        t.insert(),
                        [
                            {"id": 1, "data": "four"},
                            {"id": 2, "data": "three"},
                        ],
                    )
        else:
            with testing.db.begin() as conn:
                conn.execute(
                    t.insert(),
                    [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
                )

                # raw view: the strings are stored exactly as written
                eq_(
                    conn.execute(t.select().order_by(t.c.id)).fetchall(),
                    [(1, "four"), (2, "three")],
                )
                # typed view: the alias resolves to the canonical member
                eq_(
                    conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
                    [(1, foo_enum.three), (2, foo_enum.three)],
                )
2617
2618
# Module-level placeholder; BinaryTest.define_tables rebinds it (via
# ``global MyPickleType``) to the TypeDecorator class defined there.
MyPickleType = None
2620
2621
class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
    """Round-trip, comparison and literal tests for LargeBinary and
    PickleType columns."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # rebinds the module-level MyPickleType placeholder, presumably so
        # the type is reachable by name outside this method (e.g. pickling)
        global MyPickleType

        class MyPickleType(types.TypeDecorator):
            """PickleType that stamps ``stuff`` on the way in and out.

            The bind side mutates the object being persisted; the result
            side overwrites ``stuff`` again, so a completed round trip is
            observable as "this is the right stuff".  Falsy values
            (including None) pass through untouched.
            """

            impl = PickleType
            cache_ok = True

            def process_bind_param(self, value, dialect):
                if value:
                    value.stuff = "this is modified stuff"
                return value

            def process_result_value(self, value, dialect):
                if value:
                    value.stuff = "this is the right stuff"
                return value

        Table(
            "binary_table",
            metadata,
            Column(
                "primary_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("data", LargeBinary),
            Column("data_slice", LargeBinary(100)),
            Column("misc", String(30)),
            Column("pickled", PickleType),
            Column("mypickle", MyPickleType),
        )

    @testing.requires.non_broken_binary
    def test_round_trip(self, connection):
        """Binary, pickled and NULL values survive an insert/select trip.

        Every check runs twice: once through the Table's own column types
        and once through a textual SELECT whose column types are supplied
        via ``TextClause.columns()``.
        """
        binary_table = self.tables.binary_table

        testobj1 = pickleable.Foo("im foo 1")
        testobj2 = pickleable.Foo("im foo 2")
        testobj3 = pickleable.Foo("im foo 3")

        stream1 = self.load_stream("binary_data_one.dat")
        stream2 = self.load_stream("binary_data_two.dat")
        connection.execute(
            binary_table.insert(),
            dict(
                primary_id=1,
                misc="binary_data_one.dat",
                data=stream1,
                data_slice=stream1[0:100],
                pickled=testobj1,
                mypickle=testobj3,
            ),
        )
        connection.execute(
            binary_table.insert(),
            dict(
                primary_id=2,
                misc="binary_data_two.dat",
                data=stream2,
                data_slice=stream2[0:99],
                pickled=testobj2,
            ),
        )
        # row 3 exercises NULL handling of the binary and pickle types
        connection.execute(
            binary_table.insert(),
            dict(
                primary_id=3,
                misc="binary_data_two.dat",
                data=None,
                data_slice=stream2[0:99],
                pickled=None,
            ),
        )

        for stmt in (
            binary_table.select().order_by(binary_table.c.primary_id),
            text(
                "select * from binary_table order by binary_table.primary_id",
            ).columns(
                **{
                    "pickled": PickleType,
                    "mypickle": MyPickleType,
                    "data": LargeBinary,
                    "data_slice": LargeBinary,
                }
            ),
        ):
            result = connection.execute(stmt).fetchall()
            eq_(len(result), 3)
            eq_(stream1, result[0]._mapping["data"])
            eq_(stream1[0:100], result[0]._mapping["data_slice"])
            eq_(stream2, result[1]._mapping["data"])
            eq_(stream2[0:99], result[1]._mapping["data_slice"])
            eq_(testobj1, result[0]._mapping["pickled"])
            eq_(testobj2, result[1]._mapping["pickled"])
            eq_(testobj3.moredata, result[0]._mapping["mypickle"].moredata)
            eq_(
                result[0]._mapping["mypickle"].stuff, "this is the right stuff"
            )
            # NULLs written for row 3 (and the omitted mypickle of rows 2/3)
            # must come back as None
            eq_(None, result[2]._mapping["data"])
            eq_(None, result[2]._mapping["pickled"])
            eq_(None, result[1]._mapping["mypickle"])
            eq_(None, result[2]._mapping["mypickle"])

    @testing.requires.binary_comparisons
    def test_comparison(self, connection):
        """test that type coercion occurs on comparison for binary"""
        binary_table = self.tables.binary_table

        # the plain-string right side is coerced to the column's type
        expr = binary_table.c.data == "foo"
        assert isinstance(expr.right.type, LargeBinary)

        data = os.urandom(32)
        connection.execute(binary_table.insert(), dict(data=data))
        eq_(
            connection.scalar(
                select(func.count("*"))
                .select_from(binary_table)
                .where(binary_table.c.data == data)
            ),
            1,
        )

    @testing.requires.binary_literals
    def test_literal_roundtrip(self, connection):
        """A bytes literal rendered inline (literal_binds) selects back
        unchanged."""
        compiled = select(cast(literal(util.b("foo")), LargeBinary)).compile(
            dialect=testing.db.dialect, compile_kwargs={"literal_binds": True}
        )
        result = connection.execute(compiled)
        eq_(result.scalar(), util.b("foo"))

    def test_bind_processor_no_dbapi(self):
        """LargeBinary has no bind processor on a bare DefaultDialect."""
        b = LargeBinary()
        eq_(b.bind_processor(default.DefaultDialect()), None)

    def load_stream(self, name):
        """Return the raw bytes of data file *name*, which is looked up in
        the directory one level above this module."""
        f = os.path.join(os.path.dirname(__file__), "..", name)
        with open(f, mode="rb") as o:
            return o.read()
2760
2761
class JSONTest(fixtures.TestBase):
    """Bind, result and index-literal processing of the generic JSON type
    against a DefaultDialect whose JSON serializer/deserializer are unset.
    """

    def setup_test(self):
        metadata = MetaData()
        self.test_table = Table(
            "test_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("test_column", JSON),
        )
        self.jsoncol = self.test_table.c.test_column

        self.dialect = default.DefaultDialect()
        self.dialect._json_serializer = None
        self.dialect._json_deserializer = None

    def _bind_proc(self):
        """Return the JSON column's bind processor for ``self.dialect``."""
        return self.jsoncol.type._cached_bind_processor(self.dialect)

    def _result_proc(self):
        """Return the JSON column's result processor for ``self.dialect``."""
        return self.jsoncol.type._cached_result_processor(self.dialect, None)

    def test_bind_serialize_default(self):
        proc = self._bind_proc()
        eq_(
            proc({"A": [1, 2, 3, True, False]}),
            '{"A": [1, 2, 3, true, false]}',
        )

    def test_bind_serialize_None(self):
        # by default Python None is serialized as the JSON string 'null'
        proc = self._bind_proc()
        eq_(proc(None), "null")

    def test_bind_serialize_none_as_null(self):
        proc = JSON(none_as_null=True)._cached_bind_processor(self.dialect)
        eq_(proc(None), None)
        eq_(proc(null()), None)

    def test_bind_serialize_null(self):
        # the SQL null() construct binds as None, not as JSON 'null'
        proc = self._bind_proc()
        eq_(proc(null()), None)

    def test_result_deserialize_default(self):
        proc = self._result_proc()
        eq_(
            proc('{"A": [1, 2, 3, true, false]}'),
            {"A": [1, 2, 3, True, False]},
        )

    def test_result_deserialize_null(self):
        proc = self._result_proc()
        eq_(proc("null"), None)

    def test_result_deserialize_None(self):
        proc = self._result_proc()
        eq_(proc(None), None)

    def _dialect_index_fixture(self, int_processor, str_processor):
        """Return a DefaultDialect whose ``colspecs`` optionally replace
        Integer and/or String with subclasses that have recognizable
        processors.

        The Integer replacement adds 10 on bind and renders ``value + 15``
        as a literal; the String replacement appends "10" on bind and "15"
        on literal rendering.
        """

        class MyInt(Integer):
            def bind_processor(self, dialect):
                return lambda value: value + 10

            def literal_processor(self, dialect):
                return lambda value: str(value + 15)

        class MyString(String):
            def bind_processor(self, dialect):
                return lambda value: value + "10"

            def literal_processor(self, dialect):
                return lambda value: value + "15"

        class MyDialect(default.DefaultDialect):
            colspecs = {}
            if int_processor:
                colspecs[Integer] = MyInt
            if str_processor:
                colspecs[String] = MyString

        return MyDialect()

    def test_index_bind_proc_int(self):
        expr = self.test_table.c.test_column[5]

        int_dialect = self._dialect_index_fixture(True, True)
        non_int_dialect = self._dialect_index_fixture(False, True)

        bindproc = expr.right.type._cached_bind_processor(int_dialect)
        eq_(bindproc(expr.right.value), 15)

        bindproc = expr.right.type._cached_bind_processor(non_int_dialect)
        eq_(bindproc(expr.right.value), 5)

    def test_index_literal_proc_int(self):
        expr = self.test_table.c.test_column[5]

        int_dialect = self._dialect_index_fixture(True, True)
        non_int_dialect = self._dialect_index_fixture(False, True)

        literalproc = expr.right.type._cached_literal_processor(int_dialect)
        eq_(literalproc(expr.right.value), "20")

        literalproc = expr.right.type._cached_literal_processor(
            non_int_dialect
        )
        eq_(literalproc(expr.right.value), "5")

    def test_index_bind_proc_str(self):
        expr = self.test_table.c.test_column["five"]

        str_dialect = self._dialect_index_fixture(True, True)
        non_str_dialect = self._dialect_index_fixture(False, False)

        bindproc = expr.right.type._cached_bind_processor(str_dialect)
        eq_(bindproc(expr.right.value), "five10")

        bindproc = expr.right.type._cached_bind_processor(non_str_dialect)
        eq_(bindproc(expr.right.value), "five")

    def test_index_literal_proc_str(self):
        expr = self.test_table.c.test_column["five"]

        str_dialect = self._dialect_index_fixture(True, True)
        non_str_dialect = self._dialect_index_fixture(False, False)

        literalproc = expr.right.type._cached_literal_processor(str_dialect)
        eq_(literalproc(expr.right.value), "five15")

        literalproc = expr.right.type._cached_literal_processor(
            non_str_dialect
        )
        eq_(literalproc(expr.right.value), "'five'")
2895
2896
class ArrayTest(fixtures.TestBase):
    """Result types of index and slice expressions on ARRAY columns."""

    def _myarray_fixture(self):
        """Return a trivial ARRAY subclass, standing in for a dialect-level
        ARRAY implementation."""

        class MyArray(ARRAY):
            pass

        return MyArray

    def _arrtable(self, array_type):
        """Build the ``arrtable`` Table with an Integer array column and a
        String array column, both of *array_type*."""
        return Table(
            "arrtable",
            MetaData(),
            Column("intarr", array_type(Integer)),
            Column("strarr", array_type(String)),
        )

    def test_array_index_map_dimensions(self):
        col = column("x", ARRAY(Integer, dimensions=3))

        # each index operation peels off one dimension
        first = col[5]
        second = first[6]
        is_(first.type._type_affinity, ARRAY)
        eq_(first.type.dimensions, 2)
        is_(second.type._type_affinity, ARRAY)
        eq_(second.type.dimensions, 1)
        is_(second[7].type._type_affinity, Integer)

    def test_array_getitem_single_type(self):
        arrtable = self._arrtable(ARRAY)
        is_(arrtable.c.intarr[1].type._type_affinity, Integer)
        is_(arrtable.c.strarr[1].type._type_affinity, String)

    def test_array_getitem_slice_type(self):
        arrtable = self._arrtable(ARRAY)
        is_(arrtable.c.intarr[1:3].type._type_affinity, ARRAY)
        is_(arrtable.c.strarr[1:3].type._type_affinity, ARRAY)

    def test_array_getitem_slice_type_dialect_level(self):
        MyArray = self._myarray_fixture()
        arrtable = self._arrtable(MyArray)
        array_cols = (arrtable.c.intarr, arrtable.c.strarr)

        for array_col in array_cols:
            is_(array_col[1:3].type._type_affinity, ARRAY)

        # but the slice returns the actual type
        for array_col in array_cols:
            assert isinstance(array_col[1:3].type, MyArray)
2949
2950
# Module-level placeholders; ExpressionTest.define_tables rebinds both names
# with ``global`` so that tests can refer to the generated type classes.
MyCustomType = None
MyTypeDec = None
2952
2953
class ExpressionTest(
    fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL
):
    """Type coercion of SQL expressions, including user-defined types and
    TypeDecorators, and round trips through their processors."""

    __dialect__ = "default"

    # The single row written by insert_data(), as read back through the type
    # layer.  MyCustomType's result processor divides the stored 250 back to
    # 25; MyTypeDec (and MyDecOfDec, which wraps it) add BIND_IN/BIND_OUT.
    _expected_row = (
        1,
        "somedata",
        datetime.date(2007, 10, 15),
        25,
        "BIND_INfooBIND_OUT",
        "BIND_INfooBIND_OUT",
    )

    @classmethod
    def define_tables(cls, metadata):
        # rebind the module-level placeholders so tests can refer to these
        # classes, e.g. in _type_affinity / __class__ assertions
        global MyCustomType, MyTypeDec

        class MyCustomType(types.UserDefinedType):
            def get_col_spec(self):
                return "INT"

            def bind_processor(self, dialect):
                def process(value):
                    return value * 10

                return process

            def result_processor(self, dialect, coltype):
                def process(value):
                    return value / 10

                return process

        class MyTypeDec(types.TypeDecorator):
            impl = String

            cache_ok = True

            def process_bind_param(self, value, dialect):
                return "BIND_IN" + str(value)

            def process_result_value(self, value, dialect):
                return value + "BIND_OUT"

        # a TypeDecorator whose impl is itself a TypeDecorator
        class MyDecOfDec(types.TypeDecorator):
            impl = MyTypeDec

            cache_ok = True

        Table(
            "test",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(30)),
            Column("atimestamp", Date),
            Column("avalue", MyCustomType),
            Column("bvalue", MyTypeDec(50)),
            Column("cvalue", MyDecOfDec(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        test_table = cls.tables.test
        connection.execute(
            test_table.insert(),
            {
                "id": 1,
                "data": "somedata",
                "atimestamp": datetime.date(2007, 10, 15),
                "avalue": 25,
                "bvalue": "foo",
                "cvalue": "foo",
            },
        )

    def test_control(self, connection):
        test_table = self.tables.test

        # raw driver SQL bypasses result processing, so the bind-side
        # "* 10" applied by MyCustomType is visible here
        eq_(
            connection.exec_driver_sql("select avalue from test").scalar(),
            250,
        )

        eq_(
            connection.execute(test_table.select()).fetchall(),
            [self._expected_row],
        )

    def test_bind_adapt(self, connection):
        # test an untyped bind gets the left side's type

        test_table = self.tables.test

        expr = test_table.c.atimestamp == bindparam("thedate")
        eq_(expr.right.type._type_affinity, Date)

        eq_(
            connection.execute(
                select(
                    test_table.c.id,
                    test_table.c.data,
                    test_table.c.atimestamp,
                ).where(expr),
                {"thedate": datetime.date(2007, 10, 15)},
            ).fetchall(),
            [(1, "somedata", datetime.date(2007, 10, 15))],
        )

        expr = test_table.c.avalue == bindparam("somevalue")
        eq_(expr.right.type._type_affinity, MyCustomType)

        eq_(
            connection.execute(
                test_table.select().where(expr), {"somevalue": 25}
            ).fetchall(),
            [self._expected_row],
        )

        expr = test_table.c.bvalue == bindparam("somevalue")
        eq_(expr.right.type._type_affinity, String)

        eq_(
            connection.execute(
                test_table.select().where(expr), {"somevalue": "foo"}
            ).fetchall(),
            [self._expected_row],
        )

    def test_grouped_bind_adapt(self):
        test_table = self.tables.test

        expr = test_table.c.atimestamp == elements.Grouping(
            bindparam("thedate")
        )
        eq_(expr.right.type._type_affinity, Date)
        eq_(expr.right.element.type._type_affinity, Date)

        expr = test_table.c.atimestamp == elements.Grouping(
            elements.Grouping(bindparam("thedate"))
        )
        eq_(expr.right.type._type_affinity, Date)
        eq_(expr.right.element.type._type_affinity, Date)
        eq_(expr.right.element.element.type._type_affinity, Date)

    def test_bind_adapt_update(self):
        test_table = self.tables.test

        bp = bindparam("somevalue")
        stmt = test_table.update().values(avalue=bp)
        compiled = stmt.compile()
        eq_(bp.type._type_affinity, types.NullType)
        eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)

    def test_bind_adapt_insert(self):
        test_table = self.tables.test
        bp = bindparam("somevalue")

        stmt = test_table.insert().values(avalue=bp)
        compiled = stmt.compile()
        eq_(bp.type._type_affinity, types.NullType)
        eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)

    def test_bind_adapt_expression(self):
        test_table = self.tables.test

        bp = bindparam("somevalue")
        stmt = test_table.c.avalue == bp
        eq_(bp.type._type_affinity, types.NullType)
        eq_(stmt.right.type._type_affinity, MyCustomType)

    def test_literal_adapt(self):
        # literals get typed based on the types dictionary, unless
        # compatible with the left side type

        expr = column("foo", String) == 5
        eq_(expr.right.type._type_affinity, Integer)

        expr = column("foo", String) == "asdf"
        eq_(expr.right.type._type_affinity, String)

        expr = column("foo", CHAR) == 5
        eq_(expr.right.type._type_affinity, Integer)

        expr = column("foo", CHAR) == "asdf"
        eq_(expr.right.type.__class__, CHAR)

    @testing.combinations(
        (5, Integer),
        (2.65, Float),
        (True, Boolean),
        (decimal.Decimal("2.65"), Numeric),
        (datetime.date(2015, 7, 20), Date),
        (datetime.time(10, 15, 20), Time),
        (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime),
        (datetime.timedelta(seconds=5), Interval),
        (None, types.NullType),
    )
    def test_actual_literal_adapters(self, data, expected):
        is_(literal(data).type.__class__, expected)

    def test_typedec_operator_adapt(self, connection):
        test_table = self.tables.test

        expr = test_table.c.bvalue + "hi"

        assert expr.type.__class__ is MyTypeDec
        assert expr.right.type.__class__ is MyTypeDec

        eq_(
            connection.execute(select(expr.label("foo"))).scalar(),
            "BIND_INfooBIND_INhiBIND_OUT",
        )

    def test_typedec_is_adapt(self):
        class CoerceNothing(TypeDecorator):
            coerce_to_is_types = ()
            impl = Integer
            cache_ok = True

        class CoerceBool(TypeDecorator):
            coerce_to_is_types = (bool,)
            impl = Boolean
            cache_ok = True

        class CoerceNone(TypeDecorator):
            coerce_to_is_types = (type(None),)
            impl = Integer
            cache_ok = True

        c1 = column("x", CoerceNothing())
        c2 = column("x", CoerceBool())
        c3 = column("x", CoerceNone())

        self.assert_compile(
            and_(c1 == None, c2 == None, c3 == None),  # noqa
            "x = :x_1 AND x = :x_2 AND x IS NULL",
        )
        self.assert_compile(
            and_(c1 == True, c2 == True, c3 == True),  # noqa
            "x = :x_1 AND x = true AND x = :x_2",
            dialect=default.DefaultDialect(supports_native_boolean=True),
        )
        self.assert_compile(
            and_(c1 == 3, c2 == 3, c3 == 3),
            "x = :x_1 AND x = :x_2 AND x = :x_3",
            dialect=default.DefaultDialect(supports_native_boolean=True),
        )
        self.assert_compile(
            and_(c1.is_(True), c2.is_(True), c3.is_(True)),
            "x IS :x_1 AND x IS true AND x IS :x_2",
            dialect=default.DefaultDialect(supports_native_boolean=True),
        )

    def test_typedec_righthand_coercion(self, connection):
        class MyTypeDec(types.TypeDecorator):
            impl = String
            cache_ok = True

            def process_bind_param(self, value, dialect):
                return "BIND_IN" + str(value)

            def process_result_value(self, value, dialect):
                return value + "BIND_OUT"

        tab = table("test", column("bvalue", MyTypeDec))
        expr = tab.c.bvalue + 6

        self.assert_compile(
            expr, "test.bvalue || :bvalue_1", use_default_dialect=True
        )

        is_(expr.right.type.__class__, MyTypeDec)
        is_(expr.type.__class__, MyTypeDec)

        eq_(
            connection.execute(select(expr.label("foo"))).scalar(),
            "BIND_INfooBIND_IN6BIND_OUT",
        )

    def test_variant_righthand_coercion_honors_wrapped(self):
        my_json_normal = JSON()
        my_json_variant = JSON().with_variant(String(), "sqlite")

        tab = table(
            "test",
            column("avalue", my_json_normal),
            column("bvalue", my_json_variant),
        )
        expr = tab.c.avalue["foo"] == "bar"

        is_(expr.right.type._type_affinity, String)
        is_not(expr.right.type, my_json_normal)

        expr = tab.c.bvalue["foo"] == "bar"

        is_(expr.right.type._type_affinity, String)
        is_not(expr.right.type, my_json_variant)

    def test_variant_righthand_coercion_returns_self(self):
        my_datetime_normal = DateTime()
        my_datetime_variant = DateTime().with_variant(
            dialects.sqlite.DATETIME(truncate_microseconds=False), "sqlite"
        )

        tab = table(
            "test",
            column("avalue", my_datetime_normal),
            column("bvalue", my_datetime_variant),
        )
        expr = tab.c.avalue == datetime.datetime(2015, 10, 14, 15, 17, 18)

        is_(expr.right.type._type_affinity, DateTime)
        is_(expr.right.type, my_datetime_normal)

        expr = tab.c.bvalue == datetime.datetime(2015, 10, 14, 15, 17, 18)

        is_(expr.right.type, my_datetime_variant)

    def test_bind_typing(self):
        from sqlalchemy.sql import column

        class MyFoobarType(types.UserDefinedType):
            pass

        class Foo(object):
            pass

        # unknown type + integer, right hand bind
        # coerces to given type
        expr = column("foo", MyFoobarType) + 5
        assert expr.right.type._type_affinity is MyFoobarType

        # untyped bind - it gets assigned MyFoobarType
        bp = bindparam("foo")
        expr = column("foo", MyFoobarType) + bp
        assert bp.type._type_affinity is types.NullType  # noqa
        assert expr.right.type._type_affinity is MyFoobarType

        expr = column("foo", MyFoobarType) + bindparam("foo", type_=Integer)
        assert expr.right.type._type_affinity is types.Integer

        # unknown type + unknown, right hand bind
        # coerces to the left
        expr = column("foo", MyFoobarType) + Foo()
        assert expr.right.type._type_affinity is MyFoobarType

        # including for non-commutative ops
        expr = column("foo", MyFoobarType) - Foo()
        assert expr.right.type._type_affinity is MyFoobarType

        expr = column("foo", MyFoobarType) - datetime.date(2010, 8, 25)
        assert expr.right.type._type_affinity is MyFoobarType

    def test_date_coercion(self):
        expr = column("bar", types.NULLTYPE) - column("foo", types.TIMESTAMP)
        eq_(expr.type._type_affinity, types.NullType)

        expr = func.sysdate() - column("foo", types.TIMESTAMP)
        eq_(expr.type._type_affinity, types.Interval)

        expr = func.current_date() - column("foo", types.TIMESTAMP)
        eq_(expr.type._type_affinity, types.Interval)

    def test_interval_coercion(self):
        expr = column("bar", types.Interval) + column("foo", types.Date)
        eq_(expr.type._type_affinity, types.DateTime)

        expr = column("bar", types.Interval) * column("foo", types.Numeric)
        eq_(expr.type._type_affinity, types.Interval)

    @testing.combinations(
        (operator.add,),
        (operator.mul,),
        (operator.truediv,),
        (operator.sub,),
        argnames="op",
        id_="n",
    )
    @testing.combinations(
        (Numeric(10, 2),), (Integer(),), argnames="other", id_="r"
    )
    def test_numerics_coercion(self, op, other):
        expr = op(column("bar", types.Numeric(10, 2)), column("foo", other))
        assert isinstance(expr.type, types.Numeric)
        expr = op(column("foo", other), column("bar", types.Numeric(10, 2)))
        assert isinstance(expr.type, types.Numeric)

    def test_asdecimal_int_to_numeric(self):
        expr = column("a", Integer) * column("b", Numeric(asdecimal=False))
        is_(expr.type.asdecimal, False)

        expr = column("a", Integer) * column("b", Numeric())
        is_(expr.type.asdecimal, True)

        expr = column("a", Integer) * column("b", Float())
        is_(expr.type.asdecimal, False)
        assert isinstance(expr.type, Float)

    def test_asdecimal_numeric_to_int(self):
        expr = column("a", Numeric(asdecimal=False)) * column("b", Integer)
        is_(expr.type.asdecimal, False)

        expr = column("a", Numeric()) * column("b", Integer)
        is_(expr.type.asdecimal, True)

        expr = column("a", Float()) * column("b", Integer)
        is_(expr.type.asdecimal, False)
        assert isinstance(expr.type, Float)

    def test_null_comparison(self):
        eq_(
            str(column("a", types.NullType()) + column("b", types.NullType())),
            "a + b",
        )

    def test_expression_typing(self):
        expr = column("bar", Integer) - 3

        eq_(expr.type._type_affinity, Integer)

        expr = bindparam("bar") + bindparam("foo")
        eq_(expr.type, types.NULLTYPE)

    def test_distinct(self, connection):
        test_table = self.tables.test

        s = select(distinct(test_table.c.avalue))
        eq_(connection.execute(s).scalar(), 25)

        s = select(test_table.c.avalue.distinct())
        eq_(connection.execute(s).scalar(), 25)

        assert distinct(test_table.c.data).type == test_table.c.data.type
        assert test_table.c.data.distinct().type == test_table.c.data.type

    def test_detect_coercion_of_builtins(self):
        @inspection._self_inspects
        class SomeSQLAThing(object):
            def __repr__(self):
                return "some_sqla_thing()"

        class SomeOtherThing(object):
            pass

        assert_raises_message(
            exc.ArgumentError,
            r"SQL expression element or literal value expected, got "
            r"some_sqla_thing\(\).",
            lambda: column("a", String) == SomeSQLAThing(),
        )

        is_(bindparam("x", SomeOtherThing()).type, types.NULLTYPE)

    def test_detect_coercion_not_fooled_by_mock(self):
        m1 = mock.Mock()
        is_(bindparam("x", m1).type, types.NULLTYPE)
3435
3436
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
    """DDL rendering of generic and dialect-specific type objects, compiled
    against the default dialect unless stated otherwise."""

    __dialect__ = "default"

    @testing.requires.unbounded_varchar
    def test_string_plain(self):
        self.assert_compile(String(), "VARCHAR")

    def test_string_length(self):
        self.assert_compile(String(length=50), "VARCHAR(50)")

    def test_string_collation(self):
        self.assert_compile(
            String(length=50, collation="FOO"), 'VARCHAR(50) COLLATE "FOO"'
        )

    def test_char_plain(self):
        self.assert_compile(CHAR(), "CHAR")

    def test_char_length(self):
        self.assert_compile(CHAR(length=50), "CHAR(50)")

    def test_char_collation(self):
        self.assert_compile(
            CHAR(length=50, collation="FOO"), 'CHAR(50) COLLATE "FOO"'
        )

    def test_text_plain(self):
        self.assert_compile(Text(), "TEXT")

    def test_text_length(self):
        self.assert_compile(Text(length=50), "TEXT(50)")

    def test_text_collation(self):
        self.assert_compile(Text(collation="FOO"), 'TEXT COLLATE "FOO"')

    def test_default_compile_pg_inet(self):
        inet_type = dialects.postgresql.INET()
        self.assert_compile(inet_type, "INET", allow_dialect_select=True)

    def test_default_compile_pg_float(self):
        float_type = dialects.postgresql.FLOAT()
        self.assert_compile(float_type, "FLOAT", allow_dialect_select=True)

    def test_default_compile_mysql_integer(self):
        # on the default dialect display_width is not rendered ...
        self.assert_compile(
            dialects.mysql.INTEGER(display_width=5),
            "INTEGER",
            allow_dialect_select=True,
        )

        # ... whereas the MySQL dialect renders it
        self.assert_compile(
            dialects.mysql.INTEGER(display_width=5),
            "INTEGER(5)",
            dialect="mysql",
        )

    def test_numeric_plain(self):
        self.assert_compile(types.NUMERIC(), "NUMERIC")

    def test_numeric_precision(self):
        self.assert_compile(types.NUMERIC(precision=2), "NUMERIC(2)")

    def test_numeric_scale(self):
        self.assert_compile(
            types.NUMERIC(precision=2, scale=4), "NUMERIC(2, 4)"
        )

    def test_decimal_plain(self):
        self.assert_compile(types.DECIMAL(), "DECIMAL")

    def test_decimal_precision(self):
        self.assert_compile(types.DECIMAL(precision=2), "DECIMAL(2)")

    def test_decimal_scale(self):
        self.assert_compile(
            types.DECIMAL(precision=2, scale=4), "DECIMAL(2, 4)"
        )

    def test_kwarg_legacy_typecompiler(self):
        from sqlalchemy.sql import compiler

        class LegacyTypeCompiler(compiler.GenericTypeCompiler):
            # no **kw in the signature; transparently decorated w/ kw
            # decorator
            def visit_VARCHAR(self, type_):
                return "MYVARCHAR"

            # already accepts **kw; not affected
            def visit_INTEGER(self, type_, **kw):
                return "MYINTEGER %s" % kw["type_expression"].name

        legacy_dialect = default.DefaultDialect()
        legacy_dialect.type_compiler = LegacyTypeCompiler(legacy_dialect)

        cases = (
            (Column("bar", VARCHAR(50)), "bar MYVARCHAR"),
            (Column("bar", INTEGER), "bar MYINTEGER bar"),
        )
        for col, expected_ddl in cases:
            self.assert_compile(
                ddl.CreateColumn(col), expected_ddl, dialect=legacy_dialect
            )
3537
3538
class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase):
    __backend__ = True

    def test_user_defined(self):
        """test that dialects pass the column through on DDL."""

        class MyType(types.UserDefinedType):
            def get_col_spec(self, **kw):
                # the Column being rendered arrives as "type_expression"
                return "FOOB %s" % kw["type_expression"].name

        tbl = Table("t", MetaData(), Column("bar", MyType, nullable=False))
        create_col = ddl.CreateColumn(tbl.c.bar)
        self.assert_compile(create_col, "bar FOOB bar NOT NULL")
3552
3553
class NumericRawSQLTest(fixtures.TestBase):

    """Test what DBAPIs and dialects return without any typing
    information supplied at the SQLA level.

    """

    __backend__ = True

    def _fixture(self, connection, metadata, type_, data):
        """Create table ``t`` holding one ``val`` column of *type_* and
        insert a single row containing *data*."""
        tbl = Table("t", metadata, Column("val", type_))
        metadata.create_all(connection)
        connection.execute(tbl.insert(), dict(val=data))

    def _raw_val(self, connection):
        """Fetch ``val`` with driver-level SQL, bypassing SQLA typing."""
        return connection.exec_driver_sql("select val from t").scalar()

    @testing.fails_on("sqlite", "Doesn't provide Decimal results natively")
    @testing.provide_metadata
    def test_decimal_fp(self, connection):
        expected = decimal.Decimal("45.5")
        self._fixture(connection, self.metadata, Numeric(10, 5), expected)
        fetched = self._raw_val(connection)
        assert isinstance(fetched, decimal.Decimal)
        eq_(fetched, expected)

    @testing.fails_on("sqlite", "Doesn't provide Decimal results natively")
    @testing.provide_metadata
    def test_decimal_int(self, connection):
        expected = decimal.Decimal("45")
        self._fixture(connection, self.metadata, Numeric(10, 5), expected)
        fetched = self._raw_val(connection)
        assert isinstance(fetched, decimal.Decimal)
        eq_(fetched, expected)

    @testing.provide_metadata
    def test_ints(self, connection):
        self._fixture(connection, self.metadata, Integer, 45)
        fetched = self._raw_val(connection)
        assert isinstance(fetched, util.int_types)
        eq_(fetched, 45)

    @testing.provide_metadata
    def test_float(self, connection):
        self._fixture(connection, self.metadata, Float, 46.583)
        fetched = self._raw_val(connection)
        assert isinstance(fetched, float)

        # some DBAPIs have unusual float handling
        needs_rounding = testing.against(
            "oracle+cx_oracle", "mysql+oursql", "firebird"
        )
        eq_(
            round_decimal(fetched, 3) if needs_rounding else fetched,
            46.583,
        )
3610
3611
class IntervalTest(fixtures.TablesTest, AssertsExecutionResults):
    """Round-trip and NULL handling for :class:`.Interval` columns in
    native form, native form with precision arguments, and non-native
    form, against the current backend."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # one column per Interval flavor exercised below; the non-native
        # flavor adapts to DATETIME (see test_non_native_adapt)
        Table(
            "intervals",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("native_interval", Interval()),
            Column(
                "native_interval_args",
                Interval(day_precision=3, second_precision=6),
            ),
            Column("non_native_interval", Interval(native=False)),
        )

    def test_non_native_adapt(self):
        """Interval(native=False) stays an Interval for the dialect and
        renders as DATETIME."""
        interval = Interval(native=False)
        adapted = interval.dialect_impl(testing.db.dialect)
        assert isinstance(adapted, Interval)
        assert adapted.native is False
        eq_(str(adapted), "DATETIME")

    def test_roundtrip(self, connection):
        """timedelta values survive an insert/select round trip."""
        interval_table = self.tables.intervals

        small_delta = datetime.timedelta(days=15, seconds=5874)
        delta = datetime.timedelta(14)
        connection.execute(
            interval_table.insert(),
            dict(
                native_interval=small_delta,
                native_interval_args=delta,
                non_native_interval=delta,
            ),
        )
        row = connection.execute(interval_table.select()).first()
        eq_(row.native_interval, small_delta)
        eq_(row.native_interval_args, delta)
        eq_(row.non_native_interval, delta)

    def test_null(self, connection):
        """None bound to every Interval column reads back as None."""
        interval_table = self.tables.intervals

        # pass None explicitly for every Interval column (keys must match
        # the column names exactly, otherwise they bind nothing) so each
        # flavor's bind path sees a NULL value
        connection.execute(
            interval_table.insert(),
            dict(
                id=1,
                native_interval=None,
                native_interval_args=None,
                non_native_interval=None,
            ),
        )
        row = connection.execute(interval_table.select()).first()
        eq_(row.native_interval, None)
        eq_(row.native_interval_args, None)
        eq_(row.non_native_interval, None)
3672
3673
class IntegerTest(fixtures.TestBase):
    __backend__ = True

    def test_integer_literal_processor(self):
        """Integer renders int values as SQL literals and rejects
        non-integer input with ValueError."""
        render = Integer()._cached_literal_processor(testing.db.dialect)

        eq_(render(5), "5")
        assert_raises(ValueError, render, "notanint")
3686
3687
class BooleanTest(
    fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL
):

    """test edge cases for booleans.  Note that the main boolean test suite
    is now in testing/suite/test_types.py

    the default value of create_constraint was changed to False in
    version 1.4 with #5367.

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # "value" opts in to create_constraint=True; "unconstrained_value"
        # keeps the default, so raw SQL may store arbitrary ints such as 5
        Table(
            "boolean_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("value", Boolean(create_constraint=True)),
            Column("unconstrained_value", Boolean()),
        )

    @staticmethod
    def _boolean_bind_proc(native):
        """Return ``Boolean().bind_processor`` for a mock dialect whose
        ``supports_native_boolean`` is ``native``."""
        return Boolean().bind_processor(
            mock.Mock(supports_native_boolean=native)
        )

    @staticmethod
    def _boolean_literal_proc(native):
        """Return ``Boolean().literal_processor`` for a DefaultDialect
        whose ``supports_native_boolean`` is ``native``."""
        return Boolean().literal_processor(
            default.DefaultDialect(supports_native_boolean=native)
        )

    @testing.requires.enforces_check_constraints
    @testing.requires.non_native_boolean_unconstrained
    def test_constraint(self, connection):
        # 5 in the constrained column must be rejected; the exception
        # class raised differs between DBAPIs
        assert_raises(
            (
                exc.IntegrityError,
                exc.ProgrammingError,
                exc.OperationalError,
                exc.InternalError,  # older pymysql's do this
            ),
            connection.exec_driver_sql,
            "insert into boolean_table (id, value) values(1, 5)",
        )

    @testing.skip_if(lambda: testing.db.dialect.supports_native_boolean)
    def test_unconstrained(self, connection):
        # adjacent literals concatenate with no separator, so the first
        # one carries the trailing space before "values"
        connection.exec_driver_sql(
            "insert into boolean_table (id, unconstrained_value) "
            "values (1, 5)"
        )

    def test_non_native_constraint_custom_type(self):
        """A TypeDecorator over Boolean(create_constraint=True) still
        produces the 0/1 CHECK constraint."""

        class Foob(object):
            def __init__(self, value):
                self.value = value

        class MyBool(TypeDecorator):
            impl = Boolean(create_constraint=True)
            cache_ok = True

            # future method
            def process_literal_param(self, value, dialect):
                return value.value

            def process_bind_param(self, value, dialect):
                return value.value

        m = MetaData()
        t1 = Table("t", m, Column("x", MyBool()))
        const = next(
            c for c in t1.constraints if isinstance(c, CheckConstraint)
        )

        self.assert_compile(
            AddConstraint(const),
            "ALTER TABLE t ADD CHECK (x IN (0, 1))",
            dialect="sqlite",
        )

    @testing.skip_if(lambda: testing.db.dialect.supports_native_boolean)
    def test_nonnative_processor_coerces_to_onezero(self):
        # the bind processor refuses an out-of-range int before it
        # reaches the database
        boolean_table = self.tables.boolean_table
        with testing.db.connect() as conn:
            assert_raises_message(
                exc.StatementError,
                "Value 5 is not None, True, or False",
                conn.execute,
                boolean_table.insert(),
                {"id": 1, "unconstrained_value": 5},
            )

    @testing.requires.non_native_boolean_unconstrained
    def test_nonnative_processor_coerces_integer_to_boolean(self, connection):
        # 5 stored through raw SQL reads back raw as 5, but as True once
        # the Boolean result processor is applied
        boolean_table = self.tables.boolean_table
        connection.exec_driver_sql(
            "insert into boolean_table (id, unconstrained_value) "
            "values (1, 5)"
        )

        eq_(
            connection.exec_driver_sql(
                "select unconstrained_value from boolean_table"
            ).scalar(),
            5,
        )

        eq_(
            connection.scalar(select(boolean_table.c.unconstrained_value)),
            True,
        )

    # --- bind processor, native boolean dialect -------------------------

    def test_bind_processor_coercion_native_true(self):
        proc = self._boolean_bind_proc(True)
        is_(proc(True), True)

    def test_bind_processor_coercion_native_false(self):
        proc = self._boolean_bind_proc(True)
        is_(proc(False), False)

    def test_bind_processor_coercion_native_none(self):
        proc = self._boolean_bind_proc(True)
        is_(proc(None), None)

    def test_bind_processor_coercion_native_0(self):
        proc = self._boolean_bind_proc(True)
        is_(proc(0), False)

    def test_bind_processor_coercion_native_1(self):
        proc = self._boolean_bind_proc(True)
        is_(proc(1), True)

    def test_bind_processor_coercion_native_str(self):
        proc = self._boolean_bind_proc(True)
        assert_raises_message(
            TypeError, "Not a boolean value: 'foo'", proc, "foo"
        )

    def test_bind_processor_coercion_native_int_out_of_range(self):
        proc = self._boolean_bind_proc(True)
        assert_raises_message(
            ValueError, "Value 15 is not None, True, or False", proc, 15
        )

    # --- bind processor, non-native boolean dialect ---------------------

    def test_bind_processor_coercion_nonnative_true(self):
        proc = self._boolean_bind_proc(False)
        eq_(proc(True), 1)

    def test_bind_processor_coercion_nonnative_false(self):
        proc = self._boolean_bind_proc(False)
        eq_(proc(False), 0)

    def test_bind_processor_coercion_nonnative_none(self):
        proc = self._boolean_bind_proc(False)
        is_(proc(None), None)

    def test_bind_processor_coercion_nonnative_0(self):
        proc = self._boolean_bind_proc(False)
        eq_(proc(0), 0)

    def test_bind_processor_coercion_nonnative_1(self):
        proc = self._boolean_bind_proc(False)
        eq_(proc(1), 1)

    def test_bind_processor_coercion_nonnative_str(self):
        proc = self._boolean_bind_proc(False)
        assert_raises_message(
            TypeError, "Not a boolean value: 'foo'", proc, "foo"
        )

    def test_bind_processor_coercion_nonnative_int_out_of_range(self):
        proc = self._boolean_bind_proc(False)
        assert_raises_message(
            ValueError, "Value 15 is not None, True, or False", proc, 15
        )

    # --- literal processor, native boolean dialect ----------------------

    def test_literal_processor_coercion_native_true(self):
        proc = self._boolean_literal_proc(True)
        eq_(proc(True), "true")

    def test_literal_processor_coercion_native_false(self):
        proc = self._boolean_literal_proc(True)
        eq_(proc(False), "false")

    def test_literal_processor_coercion_native_1(self):
        proc = self._boolean_literal_proc(True)
        eq_(proc(1), "true")

    def test_literal_processor_coercion_native_0(self):
        proc = self._boolean_literal_proc(True)
        eq_(proc(0), "false")

    def test_literal_processor_coercion_native_str(self):
        proc = self._boolean_literal_proc(True)
        assert_raises_message(
            TypeError, "Not a boolean value: 'foo'", proc, "foo"
        )

    def test_literal_processor_coercion_native_int_out_of_range(self):
        proc = self._boolean_literal_proc(True)
        assert_raises_message(
            ValueError, "Value 15 is not None, True, or False", proc, 15
        )

    # --- literal processor, non-native boolean dialect ------------------

    def test_literal_processor_coercion_nonnative_true(self):
        proc = self._boolean_literal_proc(False)
        eq_(proc(True), "1")

    def test_literal_processor_coercion_nonnative_false(self):
        proc = self._boolean_literal_proc(False)
        eq_(proc(False), "0")

    def test_literal_processor_coercion_nonnative_1(self):
        proc = self._boolean_literal_proc(False)
        eq_(proc(1), "1")

    def test_literal_processor_coercion_nonnative_0(self):
        proc = self._boolean_literal_proc(False)
        eq_(proc(0), "0")

    def test_literal_processor_coercion_nonnative_str(self):
        proc = self._boolean_literal_proc(False)
        assert_raises_message(
            TypeError, "Not a boolean value: 'foo'", proc, "foo"
        )
3956
3957
class PickleTest(fixtures.TestBase):
    """In-Python value semantics of :class:`.PickleType`; no database."""

    @staticmethod
    def _sample_values():
        # fresh sample objects: a plain dict plus two pickleable classes
        return (
            {"1": "2"},
            pickleable.Bar(5, 6),
            pickleable.OldSchool(10, 11),
        )

    def _assert_copies_compare_equal(self, pickle_type):
        # every copy produced by copy_value() must compare equal to its
        # source under compare_values()
        for original in self._sample_values():
            duplicate = pickle_type.copy_value(original)
            assert pickle_type.compare_values(duplicate, original)

    def test_eq_comparison(self):
        p1 = PickleType()
        self._assert_copies_compare_equal(p1)

        # the NotImplementedError raised while comparing BrokenComparable
        # instances must propagate out of compare_values()
        assert_raises(
            NotImplementedError,
            p1.compare_values,
            pickleable.BrokenComparable("foo"),
            pickleable.BrokenComparable("foo"),
        )

    def test_nonmutable_comparison(self):
        self._assert_copies_compare_equal(PickleType())

    @testing.combinations(
        None, mysql.LONGBLOB, LargeBinary, mysql.LONGBLOB(), LargeBinary()
    )
    def test_customized_impl(self, impl):
        """test #6646"""

        if impl is None:
            # no impl given: falls back to LargeBinary
            assert isinstance(PickleType().impl, LargeBinary)
            return

        p1 = PickleType(impl=impl)
        # impl may be supplied either as a type class or as an instance
        expected_cls = impl if isinstance(impl, type) else type(impl)
        assert isinstance(p1.impl, expected_cls)
4002
4003
class CallableTest(fixtures.TestBase):
    """Column types produced by calling a ``util.partial`` of
    :class:`.Unicode`, passed positionally or via ``type_=``."""

    @testing.provide_metadata
    def test_callable_as_arg(self, connection):
        make_unicode = util.partial(Unicode)

        table = Table(
            "thing", self.metadata, Column("name", make_unicode(20))
        )
        assert isinstance(table.c.name.type, Unicode)
        table.create(connection)

    @testing.provide_metadata
    def test_callable_as_kwarg(self, connection):
        make_unicode = util.partial(Unicode)

        name_col = Column("name", type_=make_unicode(20), primary_key=True)
        table = Table("thang", self.metadata, name_col)
        assert isinstance(table.c.name.type, Unicode)
        table.create(connection)
4024
4025
class LiteralTest(fixtures.TestBase):
    __backend__ = True

    @testing.combinations(
        ("datetime", datetime.datetime.now()),
        ("date", datetime.date.today()),
        ("time", datetime.time()),
        argnames="value",
        id_="ia",
    )
    @testing.skip_if(testing.requires.datetime_literals)
    def test_render_datetime(self, value):
        """On backends without datetime-literal support, compiling a
        date/time literal with literal_binds raises NotImplementedError."""
        literal_binds_kw = {"literal_binds": True}

        assert_raises_message(
            NotImplementedError,
            "Don't know how to literal-quote value.*",
            literal(value).compile,
            dialect=testing.db.dialect,
            compile_kwargs=literal_binds_kw,
        )
4047