1# coding: utf-8
2from sqlalchemy.testing.assertions import eq_, assert_raises, \
3    assert_raises_message, is_, AssertsExecutionResults, \
4    AssertsCompiledSQL, ComparesTables
5from sqlalchemy.testing import engines, fixtures
6from sqlalchemy import testing
7import datetime
8from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \
9    func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \
10    Text, null, text, column, ARRAY, any_, all_
11from sqlalchemy.sql import operators
12from sqlalchemy import types as sqltypes
13import sqlalchemy as sa
14from sqlalchemy.dialects import postgresql
15from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
16    INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
17    JSON, JSONB
18import decimal
19from sqlalchemy import util
20from sqlalchemy.testing.util import round_decimal
21from sqlalchemy import inspect
22from sqlalchemy import event
23from sqlalchemy.ext.declarative import declarative_base
24from sqlalchemy.orm import Session
25
# Module-level handles populated by TimezoneTest.setup_class(); they remain
# None until that fixture runs (legacy global-table test style).
tztable = notztable = metadata = table = None
27
28
class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
    """Tests coercion of float/numeric aggregate results on PostgreSQL.

    Verifies that ``stddev_pop`` results come back as ``float`` or
    ``Decimal`` according to the requested result type, and that float
    array columns round-trip.
    """

    __only_on__ = 'postgresql'
    __dialect__ = postgresql.dialect()
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # The Table constructor attaches the table to ``metadata`` as a
        # side effect; the previous local binding of the result was dead
        # code and has been removed.
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('data', Integer)
              )

    @classmethod
    def insert_data(cls):
        data_table = cls.tables.data_table

        data_table.insert().execute(
            {'data': 3},
            {'data': 5},
            {'data': 7},
            {'data': 2},
            {'data': 15},
            {'data': 12},
            {'data': 6},
            {'data': 478},
            {'data': 52},
            {'data': 9},
        )

    @testing.fails_on(
        'postgresql+zxjdbc',
        'XXX: postgresql+zxjdbc currently returns a Decimal result for Float')
    def test_float_coercion(self):
        """Result type of stddev_pop follows the declared/cast type."""
        data_table = self.tables.data_table

        for type_, result in [
            (Numeric, decimal.Decimal('140.381230939')),
            (Float, 140.381230939),
            (Float(asdecimal=True), decimal.Decimal('140.381230939')),
            (Numeric(asdecimal=False), 140.381230939),
        ]:
            # declare the result type on the function itself...
            ret = testing.db.execute(
                select([
                    func.stddev_pop(data_table.c.data, type_=type_)
                ])
            ).scalar()

            eq_(round_decimal(ret, 9), result)

            # ...and via an explicit CAST; both paths must agree.
            ret = testing.db.execute(
                select([
                    cast(func.stddev_pop(data_table.c.data), type_)
                ])
            ).scalar()
            eq_(round_decimal(ret, 9), result)

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc has no support for PG arrays')
    @testing.provide_metadata
    def test_arrays(self):
        """Float/REAL/DOUBLE_PRECISION/Numeric arrays round-trip."""
        metadata = self.metadata
        t1 = Table('t', metadata,
                   Column('x', postgresql.ARRAY(Float)),
                   Column('y', postgresql.ARRAY(REAL)),
                   Column('z', postgresql.ARRAY(postgresql.DOUBLE_PRECISION)),
                   Column('q', postgresql.ARRAY(Numeric))
                   )
        metadata.create_all()
        t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")])
        row = t1.select().execute().first()
        eq_(
            row,
            ([5], [5], [6], [decimal.Decimal("6.4")])
        )
103
104
class EnumTest(fixtures.TestBase, AssertsExecutionResults):
    """DDL, round-trip and reflection tests for ENUM on PostgreSQL."""

    __backend__ = True

    __only_on__ = 'postgresql > 8.3'

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc fails on ENUM: column "XXX" is of type '
                      'XXX but expression is of type character varying')
    def test_create_table(self):
        """Native ENUM column: create, insert, select round trip."""
        metadata = MetaData(testing.db)
        t1 = Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three', name='onetwothreetype')))
        t1.create()
        t1.create(checkfirst=True)  # check the create
        try:
            t1.insert().execute(value='two')
            t1.insert().execute(value='three')
            t1.insert().execute(value='three')
            eq_(t1.select().order_by(t1.c.id).execute().fetchall(),
                [(1, 'two'), (2, 'three'), (3, 'three')])
        finally:
            metadata.drop_all()
            # NOTE(review): drop_all is deliberately called a second time —
            # presumably to verify dropping is a no-op once everything is
            # gone; confirm before removing.
            metadata.drop_all()

    def test_name_required(self):
        """An unnamed ENUM can be neither created nor compiled."""
        metadata = MetaData(testing.db)
        etype = Enum('four', 'five', 'six', metadata=metadata)
        assert_raises(exc.CompileError, etype.create)
        assert_raises(exc.CompileError, etype.compile,
                      dialect=postgresql.dialect())

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc fails on ENUM: column "XXX" is of type '
                      'XXX but expression is of type character varying')
    @testing.provide_metadata
    def test_unicode_labels(self):
        """Non-ASCII ENUM labels round-trip and reflect intact."""
        metadata = self.metadata
        t1 = Table(
            'table',
            metadata,
            Column(
                'id',
                Integer,
                primary_key=True),
            Column(
                'value',
                Enum(
                    util.u('réveillé'),
                    util.u('drôle'),
                    util.u('S’il'),
                    name='onetwothreetype')))
        metadata.create_all()
        t1.insert().execute(value=util.u('drôle'))
        t1.insert().execute(value=util.u('réveillé'))
        t1.insert().execute(value=util.u('S’il'))
        eq_(t1.select().order_by(t1.c.id).execute().fetchall(),
            [(1, util.u('drôle')), (2, util.u('réveillé')),
             (3, util.u('S’il'))]
            )
        m2 = MetaData(testing.db)
        t2 = Table('table', m2, autoload=True)
        eq_(
            t2.c.value.type.enums,
            [util.u('réveillé'), util.u('drôle'), util.u('S’il')]
        )

    @testing.provide_metadata
    def test_non_native_enum(self):
        """native_enum=False emits VARCHAR + CHECK instead of CREATE TYPE."""
        metadata = self.metadata
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum(
                    'one',
                    'two',
                    'three',
                    name='myenum',
                    native_enum=False)))

        def go():
            t1.create(testing.db)

        self.assert_sql(
            testing.db, go, [
                ("CREATE TABLE foo (\tbar "
                 "VARCHAR(5), \tCONSTRAINT myenum CHECK "
                 "(bar IN ('one', 'two', 'three')))", {})])
        with testing.db.begin() as conn:
            conn.execute(
                t1.insert(), {'bar': 'two'}
            )
            eq_(
                conn.scalar(select([t1.c.bar])), 'two'
            )

    @testing.provide_metadata
    def test_non_native_enum_w_unicode(self):
        """Non-native ENUM with a non-ASCII label in the CHECK clause."""
        metadata = self.metadata
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum('B', util.u('Ü'), name='myenum', native_enum=False)))

        def go():
            t1.create(testing.db)

        self.assert_sql(
            testing.db,
            go,
            [
                (
                    util.u(
                        "CREATE TABLE foo (\tbar "
                        "VARCHAR(1), \tCONSTRAINT myenum CHECK "
                        "(bar IN ('B', 'Ü')))"
                    ),
                    {}
                )
            ])

        with testing.db.begin() as conn:
            conn.execute(
                t1.insert(), {'bar': util.u('Ü')}
            )
            eq_(
                conn.scalar(select([t1.c.bar])), util.u('Ü')
            )

    @testing.provide_metadata
    def test_disable_create(self):
        """create_type=False decouples the type's DDL from the table's."""
        metadata = self.metadata

        e1 = postgresql.ENUM('one', 'two', 'three',
                             name="myenum",
                             create_type=False)

        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )
        # table can be created separately
        # without conflict
        e1.create(bind=testing.db)
        t1.create(testing.db)
        t1.drop(testing.db)
        e1.drop(bind=testing.db)

    @testing.provide_metadata
    def test_generate_multiple(self):
        """Test that the same enum twice only generates once
        for the create_all() call, without using checkfirst.

        A 'memo' collection held by the DDL runner
        now handles this.

        """
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum")
        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )

        t2 = Table('e2', metadata,
                   Column('c1', e1)
                   )

        metadata.create_all(checkfirst=False)
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_generate_alone_on_metadata(self):
        """Test that an enum associated only with a MetaData (no table)
        is created and dropped by create_all() / drop_all() without
        using checkfirst.

        """
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=self.metadata)

        metadata.create_all(checkfirst=False)
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_generate_multiple_on_metadata(self):
        """MetaData-bound enum shared by two tables is emitted only by
        the metadata / explicit type DDL, not by the tables'."""
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=metadata)

        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )

        t2 = Table('e2', metadata,
                   Column('c1', e1)
                   )

        metadata.create_all(checkfirst=False)
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

        e1.create()  # creates ENUM
        t1.create()  # does not create ENUM
        t2.create()  # does not create ENUM

    @testing.provide_metadata
    def test_drops_on_table(self):
        """A table-local (non-MetaData) enum is dropped with its table."""
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum")
        table = Table(
            'e1', metadata,
            Column('c1', e1)
        )

        table.create()
        table.drop()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]
        table.create()
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        table.drop()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_remain_on_table_metadata_wide(self):
        """A MetaData-bound enum survives table.drop(); only
        metadata.drop_all() removes it."""
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=metadata)
        table = Table(
            'e1', metadata,
            Column('c1', e1)
        )

        # need checkfirst here, otherwise enum will not be created
        assert_raises_message(
            sa.exc.ProgrammingError,
            '.*type "myenum" does not exist',
            table.create,
        )
        table.create(checkfirst=True)
        table.drop()
        table.create(checkfirst=True)
        table.drop()
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    def test_non_native_dialect(self):
        """Forcing supports_native_enum=False falls back to VARCHAR+CHECK."""
        engine = engines.testing_engine()
        engine.connect()
        engine.dialect.supports_native_enum = False
        metadata = MetaData()
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum(
                    'one',
                    'two',
                    'three',
                    name='myenum')))

        def go():
            t1.create(engine)

        try:
            self.assert_sql(
                engine, go, [
                    ("CREATE TABLE foo (bar "
                     "VARCHAR(5), CONSTRAINT myenum CHECK "
                     "(bar IN ('one', 'two', 'three')))", {})])
        finally:
            metadata.drop_all(engine)

    def test_standalone_enum(self):
        """Enum.create()/drop() and metadata-level DDL both affect
        pg_type visibility via dialect.has_type()."""
        metadata = MetaData(testing.db)
        etype = Enum('four', 'five', 'six', name='fourfivesixtype',
                     metadata=metadata)
        etype.create()
        try:
            assert testing.db.dialect.has_type(testing.db,
                                               'fourfivesixtype')
        finally:
            etype.drop()
            assert not testing.db.dialect.has_type(testing.db,
                                                   'fourfivesixtype')
        metadata.create_all()
        try:
            assert testing.db.dialect.has_type(testing.db,
                                               'fourfivesixtype')
        finally:
            metadata.drop_all()
            assert not testing.db.dialect.has_type(testing.db,
                                                   'fourfivesixtype')

    def test_no_support(self):
        """supports_native_enum flips off once a pre-8.3 server version
        is reported; re-checked after pool disposal."""
        def server_version_info(self):
            # monkeypatched to report PostgreSQL 8.2 (no native ENUM)
            return (8, 2)

        e = engines.testing_engine()
        dialect = e.dialect
        dialect._get_server_version_info = server_version_info

        assert dialect.supports_native_enum
        e.connect()
        assert not dialect.supports_native_enum

        # initialize is called again on new pool
        e.dispose()
        e.connect()
        assert not dialect.supports_native_enum

    @testing.provide_metadata
    def test_reflection(self):
        """Reflected ENUM columns report their labels and type names."""
        metadata = self.metadata
        etype = Enum('four', 'five', 'six', name='fourfivesixtype',
                     metadata=metadata)
        t1 = Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three', name='onetwothreetype')),
            Column('value2', etype))
        metadata.create_all()
        m2 = MetaData(testing.db)
        t2 = Table('table', m2, autoload=True)
        eq_(t2.c.value.type.enums, ['one', 'two', 'three'])
        eq_(t2.c.value.type.name, 'onetwothreetype')
        eq_(t2.c.value2.type.enums, ['four', 'five', 'six'])
        eq_(t2.c.value2.type.name, 'fourfivesixtype')

    @testing.provide_metadata
    def test_schema_reflection(self):
        """Reflection of schema-qualified ENUM types preserves schema."""
        metadata = self.metadata
        etype = Enum(
            'four',
            'five',
            'six',
            name='fourfivesixtype',
            schema='test_schema',
            metadata=metadata,
        )
        Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three',
                    name='onetwothreetype', schema='test_schema')),
            Column('value2', etype))
        metadata.create_all()
        m2 = MetaData(testing.db)
        t2 = Table('table', m2, autoload=True)
        eq_(t2.c.value.type.enums, ['one', 'two', 'three'])
        eq_(t2.c.value.type.name, 'onetwothreetype')
        eq_(t2.c.value2.type.enums, ['four', 'five', 'six'])
        eq_(t2.c.value2.type.name, 'fourfivesixtype')
        eq_(t2.c.value2.type.schema, 'test_schema')

    @testing.provide_metadata
    def test_custom_subclass(self):
        """TypeDecorator bind/result processors run around a native ENUM."""
        class MyEnum(TypeDecorator):
            impl = Enum('oneHI', 'twoHI', 'threeHI', name='myenum')

            def process_bind_param(self, value, dialect):
                # outbound: append "HI" so 'two' becomes the label 'twoHI'
                if value is not None:
                    value += "HI"
                return value

            def process_result_value(self, value, dialect):
                # inbound: append "THERE" to the stored label
                if value is not None:
                    value += "THERE"
                return value

        t1 = Table(
            'table1', self.metadata,
            Column('data', MyEnum())
        )
        self.metadata.create_all(testing.db)

        with testing.db.connect() as conn:
            conn.execute(t1.insert(), {"data": "two"})
            eq_(
                conn.scalar(select([t1.c.data])),
                "twoHITHERE"
            )

    @testing.provide_metadata
    def test_generic_w_pg_variant(self):
        """with_variant: the PG-specific ENUM wins on PostgreSQL."""
        some_table = Table(
            'some_table', self.metadata,
            Column(
                'data',
                Enum(
                    "one", "two", "three",
                    native_enum=True   # make sure this is True because
                                       # it should *not* take effect due to
                                       # the variant
                ).with_variant(
                    postgresql.ENUM("four", "five", "six", name="my_enum"),
                    "postgresql"
                )
            )
        )

        with testing.db.begin() as conn:
            assert 'my_enum' not in [
                e['name'] for e in inspect(conn).get_enums()]

            self.metadata.create_all(conn)

            assert 'my_enum' in [
                e['name'] for e in inspect(conn).get_enums()]

            conn.execute(
                some_table.insert(), {"data": "five"}
            )

            self.metadata.drop_all(conn)

            assert 'my_enum' not in [
                e['name'] for e in inspect(conn).get_enums()]

    @testing.provide_metadata
    def test_generic_w_some_other_variant(self):
        """with_variant for another dialect leaves the base Enum in
        effect on PostgreSQL."""
        some_table = Table(
            'some_table', self.metadata,
            Column(
                'data',
                Enum(
                    "one", "two", "three",
                    name="my_enum",
                    native_enum=True
                ).with_variant(
                    Enum("four", "five", "six"),
                    "mysql"
                )
            )
        )

        with testing.db.begin() as conn:
            assert 'my_enum' not in [
                e['name'] for e in inspect(conn).get_enums()]

            self.metadata.create_all(conn)

            assert 'my_enum' in [
                e['name'] for e in inspect(conn).get_enums()]

            conn.execute(
                some_table.insert(), {"data": "two"}
            )

            self.metadata.drop_all(conn)

            assert 'my_enum' not in [
                e['name'] for e in inspect(conn).get_enums()]
596
597
class OIDTest(fixtures.TestBase):
    """Round-trip reflection check for the PostgreSQL OID type."""

    __only_on__ = 'postgresql'
    __backend__ = True

    @testing.provide_metadata
    def test_reflection(self):
        """A created OID column reflects back as postgresql.OID."""
        meta = self.metadata
        Table('table', meta, Column('x', Integer),
              Column('y', postgresql.OID))
        meta.create_all()
        reflected = Table(
            'table', MetaData(), autoload_with=testing.db, autoload=True)
        assert isinstance(reflected.c.y.type, postgresql.OID)
611
612
class NumericInterpretationTest(fixtures.TestBase):
    """Checks Numeric result processing across PG DBAPI dialects."""

    __only_on__ = 'postgresql'
    __backend__ = True

    def test_numeric_codes(self):
        """For every int/float/decimal PG type OID, the Numeric result
        processor (when present) turns 23.7 into a float or Decimal."""
        from sqlalchemy.dialects.postgresql import pg8000, pygresql, \
            psycopg2, psycopg2cffi, base

        dialects = (pg8000.dialect(), pygresql.dialect(),
                    psycopg2.dialect(), psycopg2cffi.dialect())
        for dialect in dialects:
            typ = Numeric().dialect_impl(dialect)
            for code in base._INT_TYPES + base._FLOAT_TYPES + \
                    base._DECIMAL_TYPES:
                proc = typ.result_processor(dialect, code)
                val = 23.7
                if proc is not None:
                    val = proc(val)
                assert val in (23.7, decimal.Decimal("23.7"))

    @testing.provide_metadata
    def test_numeric_default(self):
        """``asdecimal`` selects Decimal vs. float for server defaults."""
        metadata = self.metadata
        # pg8000 appears to fail when the value is 0,
        # returns an int instead of decimal.
        t = Table('t', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('nd', Numeric(asdecimal=True), default=1),
                  Column('nf', Numeric(asdecimal=False), default=1),
                  Column('fd', Float(asdecimal=True), default=1),
                  Column('ff', Float(asdecimal=False), default=1),
                  )
        metadata.create_all()
        # the insert result object was never inspected; the dead `r =`
        # binding has been removed
        t.insert().execute()

        row = t.select().execute().first()
        assert isinstance(row[1], decimal.Decimal)
        assert isinstance(row[2], float)
        assert isinstance(row[3], decimal.Decimal)
        assert isinstance(row[4], float)
        eq_(
            row,
            (1, decimal.Decimal("1"), 1, decimal.Decimal("1"), 1)
        )
657
658
class PythonTypeTest(fixtures.TestBase):
    """Checks the ``python_type`` attribute of PG-specific types."""

    def test_interval(self):
        """INTERVAL maps to datetime.timedelta on the Python side."""
        interval_type = postgresql.INTERVAL()
        is_(interval_type.python_type, datetime.timedelta)
665
666
class TimezoneTest(fixtures.TestBase):

    """Test timezone-aware datetimes.

    psycopg will return a datetime with a tzinfo attached to it, if
    postgresql returns it.  python then will not let you compare a
    datetime with a tzinfo to a datetime that doesn't have one.  this
    test illustrates two ways to have datetime types with and without
    timezone info. """

    # NOTE: the docstring above was previously placed *after*
    # ``__backend__``, making it a dead bare-string expression rather than
    # the class docstring (PEP 257: a docstring must be the first
    # statement).  Moved here so it is the actual docstring.

    __backend__ = True

    __only_on__ = 'postgresql'

    @classmethod
    def setup_class(cls):
        global tztable, notztable, metadata
        metadata = MetaData(testing.db)

        # current_timestamp() in postgresql is assumed to return
        # TIMESTAMP WITH TIMEZONE

        tztable = Table(
            'tztable', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'date', DateTime(
                    timezone=True), onupdate=func.current_timestamp()),
            Column('name', String(20)))
        notztable = Table(
            'notztable', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'date', DateTime(
                    timezone=False), onupdate=cast(
                    func.current_timestamp(), DateTime(
                        timezone=False))),
            Column('name', String(20)))
        metadata.create_all()

    @classmethod
    def teardown_class(cls):
        metadata.drop_all()

    @testing.fails_on('postgresql+zxjdbc',
                      "XXX: postgresql+zxjdbc doesn't give a tzinfo back")
    def test_with_timezone(self):
        """tz-aware column round-trips with a matching utcoffset."""

        # get a date with a tzinfo

        somedate = \
            testing.db.connect().scalar(func.current_timestamp().select())
        assert somedate.tzinfo
        tztable.insert().execute(id=1, name='row1', date=somedate)
        row = select([tztable.c.date], tztable.c.id
                     == 1).execute().first()
        eq_(row[0], somedate)
        eq_(somedate.tzinfo.utcoffset(somedate),
            row[0].tzinfo.utcoffset(row[0]))
        result = tztable.update(tztable.c.id
                                == 1).returning(tztable.c.date).\
            execute(name='newname'
                    )
        row = result.first()
        assert row[0] >= somedate

    def test_without_timezone(self):
        """naive column round-trips with tzinfo stripped/absent."""

        # get a date without a tzinfo

        somedate = datetime.datetime(2005, 10, 20, 11, 52, 0, )
        assert not somedate.tzinfo
        notztable.insert().execute(id=1, name='row1', date=somedate)
        row = select([notztable.c.date], notztable.c.id
                     == 1).execute().first()
        eq_(row[0], somedate)
        eq_(row[0].tzinfo, None)
        result = notztable.update(notztable.c.id
                                  == 1).returning(notztable.c.date).\
            execute(name='newname'
                    )
        row = result.first()
        assert row[0] >= somedate
751
752
class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL):
    """Compilation and reflection of TIME/TIMESTAMP precision/timezone."""

    __dialect__ = postgresql.dialect()
    __prefer__ = 'postgresql'
    __backend__ = True

    def test_compile(self):
        """Each precision/timezone combination renders the expected DDL."""
        for type_, expected in [
            (postgresql.TIME(), 'TIME WITHOUT TIME ZONE'),
            (postgresql.TIME(precision=5), 'TIME(5) WITHOUT TIME ZONE'
             ),
            (postgresql.TIME(timezone=True, precision=5),
             'TIME(5) WITH TIME ZONE'),
            (postgresql.TIMESTAMP(), 'TIMESTAMP WITHOUT TIME ZONE'),
            (postgresql.TIMESTAMP(precision=5),
             'TIMESTAMP(5) WITHOUT TIME ZONE'),
            (postgresql.TIMESTAMP(timezone=True, precision=5),
             'TIMESTAMP(5) WITH TIME ZONE'),
            # precision=0 must render "(0)", not be dropped as falsy
            (postgresql.TIME(precision=0),
             'TIME(0) WITHOUT TIME ZONE'),
            (postgresql.TIMESTAMP(precision=0),
             'TIMESTAMP(0) WITHOUT TIME ZONE'),
        ]:
            self.assert_compile(type_, expected)

    @testing.only_on('postgresql', 'DB specific feature')
    @testing.provide_metadata
    def test_reflection(self):
        """Reflected columns report precision and timezone flags."""
        metadata = self.metadata
        t1 = Table(
            't1',
            metadata,
            Column('c1', postgresql.TIME()),
            Column('c2', postgresql.TIME(precision=5)),
            Column('c3', postgresql.TIME(timezone=True, precision=5)),
            Column('c4', postgresql.TIMESTAMP()),
            Column('c5', postgresql.TIMESTAMP(precision=5)),
            Column('c6', postgresql.TIMESTAMP(timezone=True,
                                              precision=5)),
        )
        t1.create()
        m2 = MetaData(testing.db)
        t2 = Table('t1', m2, autoload=True)
        eq_(t2.c.c1.type.precision, None)
        eq_(t2.c.c2.type.precision, 5)
        eq_(t2.c.c3.type.precision, 5)
        eq_(t2.c.c4.type.precision, None)
        eq_(t2.c.c5.type.precision, 5)
        eq_(t2.c.c6.type.precision, 5)
        eq_(t2.c.c1.type.timezone, False)
        eq_(t2.c.c2.type.timezone, False)
        eq_(t2.c.c3.type.timezone, True)
        eq_(t2.c.c4.type.timezone, False)
        eq_(t2.c.c5.type.timezone, False)
        eq_(t2.c.c6.type.timezone, True)
808
809
810class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
811    __dialect__ = 'postgresql'
812
813    def test_array_type_render_str(self):
814        self.assert_compile(
815            postgresql.ARRAY(Unicode(30)),
816            "VARCHAR(30)[]"
817        )
818
819    def test_array_type_render_str_collate(self):
820        self.assert_compile(
821            postgresql.ARRAY(Unicode(30, collation="en_US")),
822            'VARCHAR(30)[] COLLATE "en_US"'
823        )
824
825    def test_array_type_render_str_multidim(self):
826        self.assert_compile(
827            postgresql.ARRAY(Unicode(30), dimensions=2),
828            "VARCHAR(30)[][]"
829        )
830
831        self.assert_compile(
832            postgresql.ARRAY(Unicode(30), dimensions=3),
833            "VARCHAR(30)[][][]"
834        )
835
836    def test_array_type_render_str_collate_multidim(self):
837        self.assert_compile(
838            postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2),
839            'VARCHAR(30)[][] COLLATE "en_US"'
840        )
841
842        self.assert_compile(
843            postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=3),
844            'VARCHAR(30)[][][] COLLATE "en_US"'
845        )
846
847
848    def test_array_int_index(self):
849        col = column('x', postgresql.ARRAY(Integer))
850        self.assert_compile(
851            select([col[3]]),
852            "SELECT x[%(x_1)s] AS anon_1",
853            checkparams={'x_1': 3}
854        )
855
856    def test_array_any(self):
857        col = column('x', postgresql.ARRAY(Integer))
858        self.assert_compile(
859            select([col.any(7, operator=operators.lt)]),
860            "SELECT %(param_1)s < ANY (x) AS anon_1",
861            checkparams={'param_1': 7}
862        )
863
864    def test_array_all(self):
865        col = column('x', postgresql.ARRAY(Integer))
866        self.assert_compile(
867            select([col.all(7, operator=operators.lt)]),
868            "SELECT %(param_1)s < ALL (x) AS anon_1",
869            checkparams={'param_1': 7}
870        )
871
872    def test_array_contains(self):
873        col = column('x', postgresql.ARRAY(Integer))
874        self.assert_compile(
875            select([col.contains(array([4, 5, 6]))]),
876            "SELECT x @> ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] "
877            "AS anon_1",
878            checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5}
879        )
880
881    def test_contains_override_raises(self):
882        col = column('x', postgresql.ARRAY(Integer))
883
884        assert_raises_message(
885            NotImplementedError,
886            "Operator 'contains' is not supported on this expression",
887            lambda: 'foo' in col
888        )
889
890    def test_array_contained_by(self):
891        col = column('x', postgresql.ARRAY(Integer))
892        self.assert_compile(
893            select([col.contained_by(array([4, 5, 6]))]),
894            "SELECT x <@ ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] "
895            "AS anon_1",
896            checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5}
897        )
898
899    def test_array_overlap(self):
900        col = column('x', postgresql.ARRAY(Integer))
901        self.assert_compile(
902            select([col.overlap(array([4, 5, 6]))]),
903            "SELECT x && ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] "
904            "AS anon_1",
905            checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5}
906        )
907
908    def test_array_slice_index(self):
909        col = column('x', postgresql.ARRAY(Integer))
910        self.assert_compile(
911            select([col[5:10]]),
912            "SELECT x[%(x_1)s:%(x_2)s] AS anon_1",
913            checkparams={'x_2': 10, 'x_1': 5}
914        )
915
916    def test_array_dim_index(self):
917        col = column('x', postgresql.ARRAY(Integer, dimensions=2))
918        self.assert_compile(
919            select([col[3][5]]),
920            "SELECT x[%(x_1)s][%(param_1)s] AS anon_1",
921            checkparams={'x_1': 3, 'param_1': 5}
922        )
923
924    def test_array_concat(self):
925        col = column('x', postgresql.ARRAY(Integer))
926        literal = array([4, 5])
927
928        self.assert_compile(
929            select([col + literal]),
930            "SELECT x || ARRAY[%(param_1)s, %(param_2)s] AS anon_1",
931            checkparams={'param_1': 4, 'param_2': 5}
932        )
933
934    def test_array_index_map_dimensions(self):
935        col = column('x', postgresql.ARRAY(Integer, dimensions=3))
936        is_(
937            col[5].type._type_affinity, ARRAY
938        )
939        assert isinstance(
940            col[5].type, postgresql.ARRAY
941        )
942        eq_(
943            col[5].type.dimensions, 2
944        )
945        is_(
946            col[5][6].type._type_affinity, ARRAY
947        )
948        assert isinstance(
949            col[5][6].type, postgresql.ARRAY
950        )
951        eq_(
952            col[5][6].type.dimensions, 1
953        )
954        is_(
955            col[5][6][7].type._type_affinity, Integer
956        )
957
958    def test_array_getitem_single_type(self):
959        m = MetaData()
960        arrtable = Table(
961            'arrtable', m,
962            Column('intarr', postgresql.ARRAY(Integer)),
963            Column('strarr', postgresql.ARRAY(String)),
964        )
965        is_(arrtable.c.intarr[1].type._type_affinity, Integer)
966        is_(arrtable.c.strarr[1].type._type_affinity, String)
967
968    def test_array_getitem_slice_type(self):
969        m = MetaData()
970        arrtable = Table(
971            'arrtable', m,
972            Column('intarr', postgresql.ARRAY(Integer)),
973            Column('strarr', postgresql.ARRAY(String)),
974        )
975
976        # type affinity is Array...
977        is_(arrtable.c.intarr[1:3].type._type_affinity, ARRAY)
978        is_(arrtable.c.strarr[1:3].type._type_affinity, ARRAY)
979
980        # but the slice returns the actual type
981        assert isinstance(arrtable.c.intarr[1:3].type, postgresql.ARRAY)
982        assert isinstance(arrtable.c.strarr[1:3].type, postgresql.ARRAY)
983
984    def test_array_functions_plus_getitem(self):
985        """test parenthesizing of functions plus indexing, which seems
986        to be required by PostgreSQL.
987
988        """
989        stmt = select([
990            func.array_cat(
991                array([1, 2, 3]),
992                array([4, 5, 6]),
993                type_=postgresql.ARRAY(Integer)
994            )[2:5]
995        ])
996        self.assert_compile(
997            stmt,
998            "SELECT (array_cat(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s], "
999            "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))"
1000            "[%(param_7)s:%(param_8)s] AS anon_1"
1001        )
1002
1003        self.assert_compile(
1004            func.array_cat(
1005                array([1, 2, 3]),
1006                array([4, 5, 6]),
1007                type_=postgresql.ARRAY(Integer)
1008            )[3],
1009            "(array_cat(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s], "
1010            "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))[%(array_cat_1)s]"
1011        )
1012
1013    def test_array_agg_generic(self):
1014        expr = func.array_agg(column('q', Integer))
1015        is_(expr.type.__class__, sqltypes.ARRAY)
1016        is_(expr.type.item_type.__class__, Integer)
1017
1018    def test_array_agg_specific(self):
1019        from sqlalchemy.dialects.postgresql import array_agg
1020        expr = array_agg(column('q', Integer))
1021        is_(expr.type.__class__, postgresql.ARRAY)
1022        is_(expr.type.item_type.__class__, Integer)
1023
1024
class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
    """Round-trip tests for postgresql.ARRAY against a live database:
    insert/select, indexing and slicing, containment/overlap operators,
    tuple conversion, multi-dimensional values and TypeDecorator
    integration."""

    __only_on__ = 'postgresql'
    __backend__ = True
    # drivers without usable native-array support for these tests
    __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'

    @classmethod
    def define_tables(cls, metadata):

        class ProcValue(TypeDecorator):
            # two-dimensional INT array with asymmetric bind/result
            # conversions, so round trips through the decorator are
            # observable in the returned values
            impl = postgresql.ARRAY(Integer, dimensions=2)

            def process_bind_param(self, value, dialect):
                # add 5 to every element on the way in
                if value is None:
                    return None
                return [
                    [x + 5 for x in v]
                    for v in value
                ]

            def process_result_value(self, value, dialect):
                # subtract 7 on the way out (net effect per element: -2)
                if value is None:
                    return None
                return [
                    [x - 7 for x in v]
                    for v in value
                ]

        Table('arrtable', metadata,
              Column('id', Integer, primary_key=True),
              Column('intarr', postgresql.ARRAY(Integer)),
              Column('strarr', postgresql.ARRAY(Unicode())),
              Column('dimarr', ProcValue)
              )

        # same layout but with an explicit dimensions=1 on the plain
        # array columns
        Table('dim_arrtable', metadata,
              Column('id', Integer, primary_key=True),
              Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
              Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
              Column('dimarr', ProcValue)
              )

    def _fixture_456(self, table):
        # insert one row whose intarr is [4, 5, 6]
        testing.db.execute(
            table.insert(),
            intarr=[4, 5, 6]
        )

    def test_reflect_array_column(self):
        """Reflection restores postgresql.ARRAY with correct item types."""
        metadata2 = MetaData(testing.db)
        tbl = Table('arrtable', metadata2, autoload=True)
        assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
        assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
        assert isinstance(tbl.c.intarr.type.item_type, Integer)
        assert isinstance(tbl.c.strarr.type.item_type, String)

    @testing.provide_metadata
    def test_array_str_collation(self):
        # DDL smoke test: a collated string type inside an ARRAY creates
        # without error
        m = self.metadata

        t = Table(
            't', m, Column('data',
                           sqltypes.ARRAY(String(50, collation="en_US")))
        )

        t.create()

    @testing.provide_metadata
    def test_array_agg(self):
        """array_agg() round trip, including indexing and slicing the
        aggregate result."""
        values_table = Table('values', self.metadata, Column('value', Integer))
        self.metadata.create_all(testing.db)
        testing.db.execute(
            values_table.insert(),
            [{'value': i} for i in range(1, 10)]
        )

        stmt = select([func.array_agg(values_table.c.value)])
        eq_(
            testing.db.execute(stmt).scalar(),
            list(range(1, 10))
        )

        # PG arrays are one-based, so [3] is the third element
        stmt = select([func.array_agg(values_table.c.value)[3]])
        eq_(
            testing.db.execute(stmt).scalar(),
            3
        )

        # slice bounds are inclusive in PG
        stmt = select([func.array_agg(values_table.c.value)[2:4]])
        eq_(
            testing.db.execute(stmt).scalar(),
            [2, 3, 4]
        )

    def test_array_index_slice_exprs(self):
        """test a variety of expressions that sometimes need parenthesizing"""

        stmt = select([array([1, 2, 3, 4])[2:3]])
        eq_(
            testing.db.execute(stmt).scalar(),
            [2, 3]
        )

        stmt = select([array([1, 2, 3, 4])[2]])
        eq_(
            testing.db.execute(stmt).scalar(),
            2
        )

        stmt = select([(array([1, 2]) + array([3, 4]))[2:3]])
        eq_(
            testing.db.execute(stmt).scalar(),
            [2, 3]
        )

        # slice binds tighter than || here: [3, 4][2:3] is [4]
        stmt = select([array([1, 2]) + array([3, 4])[2:3]])
        eq_(
            testing.db.execute(stmt).scalar(),
            [1, 2, 4]
        )

        stmt = select([array([1, 2])[2:3] + array([3, 4])])
        eq_(
            testing.db.execute(stmt).scalar(),
            [2, 3, 4]
        )

        stmt = select([
            func.array_cat(
                array([1, 2, 3]),
                array([4, 5, 6]),
                type_=postgresql.ARRAY(Integer)
            )[2:5]
        ])
        eq_(
            testing.db.execute(stmt).scalar(), [2, 3, 4, 5]
        )

    def test_any_all_exprs(self):
        """ANY() against a function-returned array evaluates server-side."""
        stmt = select([
            3 == any_(func.array_cat(
                array([1, 2, 3]),
                array([4, 5, 6]),
                type_=postgresql.ARRAY(Integer)
            ))
        ])
        eq_(
            testing.db.execute(stmt).scalar(), True
        )

    def test_insert_array(self):
        arrtable = self.tables.arrtable
        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'),
                                                            util.u('def')])
        results = arrtable.select().execute().fetchall()
        eq_(len(results), 1)
        eq_(results[0]['intarr'], [1, 2, 3])
        eq_(results[0]['strarr'], [util.u('abc'), util.u('def')])

    def test_insert_array_w_null(self):
        # NULL elements inside an array round trip as Python None
        arrtable = self.tables.arrtable
        arrtable.insert().execute(intarr=[1, None, 3], strarr=[util.u('abc'),
                                                            None])
        results = arrtable.select().execute().fetchall()
        eq_(len(results), 1)
        eq_(results[0]['intarr'], [1, None, 3])
        eq_(results[0]['strarr'], [util.u('abc'), None])

    def test_array_where(self):
        # equality comparison of an ARRAY column against a plain list
        arrtable = self.tables.arrtable
        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'),
                                                            util.u('def')])
        arrtable.insert().execute(intarr=[4, 5, 6], strarr=util.u('ABC'))
        results = arrtable.select().where(
            arrtable.c.intarr == [
                1,
                2,
                3]).execute().fetchall()
        eq_(len(results), 1)
        eq_(results[0]['intarr'], [1, 2, 3])

    def test_array_concat(self):
        # || concatenation executed server-side
        arrtable = self.tables.arrtable
        arrtable.insert().execute(intarr=[1, 2, 3],
                                  strarr=[util.u('abc'), util.u('def')])
        results = select([arrtable.c.intarr + [4, 5,
                                               6]]).execute().fetchall()
        eq_(len(results), 1)
        eq_(results[0][0], [1, 2, 3, 4, 5, 6, ])

    def test_array_comparison(self):
        # ordering comparison (<) between array values
        arrtable = self.tables.arrtable
        arrtable.insert().execute(id=5, intarr=[1, 2, 3],
                                  strarr=[util.u('abc'), util.u('def')])
        results = select([arrtable.c.id])\
            .where(arrtable.c.intarr < [4, 5, 6])\
            .execute()\
            .fetchall()
        eq_(len(results), 1)
        eq_(results[0][0], 5)

    def test_array_subtype_resultprocessor(self):
        # the Unicode item type's result processor must be applied to
        # nested (multi-dimensional) values as well
        arrtable = self.tables.arrtable
        arrtable.insert().execute(intarr=[4, 5, 6],
                                  strarr=[[util.ue('m\xe4\xe4')], [
                                      util.ue('m\xf6\xf6')]])
        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[
            util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')])
        results = \
            arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall()
        eq_(len(results), 2)
        eq_(results[0]['strarr'], [util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')])
        eq_(results[1]['strarr'],
            [[util.ue('m\xe4\xe4')],
             [util.ue('m\xf6\xf6')]])

    def test_array_literal(self):
        eq_(
            testing.db.scalar(
                select([
                    postgresql.array([1, 2]) + postgresql.array([3, 4, 5])
                ])
            ), [1, 2, 3, 4, 5]
        )

    def test_array_literal_compare(self):
        eq_(
            testing.db.scalar(
                select([
                    postgresql.array([1, 2]) < [3, 4, 5]
                ])
                ), True
        )

    def test_array_getitem_single_exec(self):
        """Read and UPDATE a single array element by index."""
        arrtable = self.tables.arrtable
        self._fixture_456(arrtable)
        eq_(
            testing.db.scalar(select([arrtable.c.intarr[2]])),
            5
        )
        testing.db.execute(
            arrtable.update().values({arrtable.c.intarr[2]: 7})
        )
        eq_(
            testing.db.scalar(select([arrtable.c.intarr[2]])),
            7
        )

    def test_undim_array_empty(self):
        # every array contains the empty array
        arrtable = self.tables.arrtable
        self._fixture_456(arrtable)
        eq_(
            testing.db.scalar(
                select([arrtable.c.intarr]).
                where(arrtable.c.intarr.contains([]))
            ),
            [4, 5, 6]
        )

    def test_array_getitem_slice_exec(self):
        """Read and UPDATE an array slice."""
        arrtable = self.tables.arrtable
        testing.db.execute(
            arrtable.insert(),
            intarr=[4, 5, 6],
            strarr=[util.u('abc'), util.u('def')]
        )
        eq_(
            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
            [5, 6]
        )
        testing.db.execute(
            arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]})
        )
        eq_(
            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
            [7, 8]
        )

    def _test_undim_array_contains_typed_exec(self, struct):
        # contains() should accept any iterable collection type
        # produced by ``struct``
        arrtable = self.tables.arrtable
        self._fixture_456(arrtable)
        eq_(
            testing.db.scalar(
                select([arrtable.c.intarr]).
                where(arrtable.c.intarr.contains(struct([4, 5])))
            ),
            [4, 5, 6]
        )

    def test_undim_array_contains_set_exec(self):
        self._test_undim_array_contains_typed_exec(set)

    def test_undim_array_contains_list_exec(self):
        self._test_undim_array_contains_typed_exec(list)

    def test_undim_array_contains_generator_exec(self):
        self._test_undim_array_contains_typed_exec(
            lambda elem: (x for x in elem))

    def _test_dim_array_contains_typed_exec(self, struct):
        # same as the undimensioned variant but against the table whose
        # columns declare dimensions=1
        dim_arrtable = self.tables.dim_arrtable
        self._fixture_456(dim_arrtable)
        eq_(
            testing.db.scalar(
                select([dim_arrtable.c.intarr]).
                where(dim_arrtable.c.intarr.contains(struct([4, 5])))
            ),
            [4, 5, 6]
        )

    def test_dim_array_contains_set_exec(self):
        self._test_dim_array_contains_typed_exec(set)

    def test_dim_array_contains_list_exec(self):
        self._test_dim_array_contains_typed_exec(list)

    def test_dim_array_contains_generator_exec(self):
        self._test_dim_array_contains_typed_exec(
            lambda elem: (
                x for x in elem))

    def test_multi_dim_roundtrip(self):
        # ProcValue adds 5 on bind and subtracts 7 on result, so each
        # element comes back 2 less than it went in
        arrtable = self.tables.arrtable
        testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]])
        eq_(
            testing.db.scalar(select([arrtable.c.dimarr])),
            [[-1, 0, 1], [2, 3, 4]]
        )

    def test_array_contained_by_exec(self):
        arrtable = self.tables.arrtable
        with testing.db.connect() as conn:
            conn.execute(
                arrtable.insert(),
                intarr=[6, 5, 4]
            )
            eq_(
                conn.scalar(
                    select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
                ),
                True
            )

    def test_array_overlap_exec(self):
        arrtable = self.tables.arrtable
        with testing.db.connect() as conn:
            conn.execute(
                arrtable.insert(),
                intarr=[4, 5, 6]
            )
            eq_(
                conn.scalar(
                    select([arrtable.c.intarr]).
                    where(arrtable.c.intarr.overlap([7, 6]))
                ),
                [4, 5, 6]
            )

    def test_array_any_exec(self):
        arrtable = self.tables.arrtable
        with testing.db.connect() as conn:
            conn.execute(
                arrtable.insert(),
                intarr=[4, 5, 6]
            )
            eq_(
                conn.scalar(
                    select([arrtable.c.intarr]).
                    where(postgresql.Any(5, arrtable.c.intarr))
                ),
                [4, 5, 6]
            )

    def test_array_all_exec(self):
        arrtable = self.tables.arrtable
        with testing.db.connect() as conn:
            conn.execute(
                arrtable.insert(),
                intarr=[4, 5, 6]
            )
            eq_(
                conn.scalar(
                    select([arrtable.c.intarr]).
                    where(arrtable.c.intarr.all(4, operator=operators.le))
                ),
                [4, 5, 6]
            )

    @testing.provide_metadata
    def test_tuple_flag(self):
        """With as_tuple=True, arrays come back as (hashable) tuples."""
        metadata = self.metadata

        t1 = Table(
            't1', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
            Column(
                'data2',
                postgresql.ARRAY(
                    Numeric(asdecimal=False), as_tuple=True)
            )
        )
        metadata.create_all()
        testing.db.execute(
            t1.insert(), id=1, data=[
                "1", "2", "3"], data2=[
                5.4, 5.6])
        testing.db.execute(
            t1.insert(),
            id=2,
            data=[
                "4",
                "5",
                "6"],
            data2=[1.0])
        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]],
                           data2=[[5.4, 5.6], [1.0, 1.1]])

        r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall()
        eq_(
            r,
            [
                (1, ('1', '2', '3'), (5.4, 5.6)),
                (2, ('4', '5', '6'), (1.0,)),
                (3, (('4', '5'), ('6', '7')), ((5.4, 5.6), (1.0, 1.1)))
            ]
        )
        # hashable
        eq_(
            set(row[1] for row in r),
            set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))])
        )

    def test_array_plus_native_enum_create(self):
        """CREATE of ARRAY-of-ENUM emits the underlying enum types and
        DROP removes them again."""
        m = MetaData()
        t = Table(
            't', m,
            Column(
                'data_1',
                postgresql.ARRAY(
                    postgresql.ENUM('a', 'b', 'c', name='my_enum_1')
                )
            ),
            Column(
                'data_2',
                postgresql.ARRAY(
                    sqltypes.Enum('a', 'b', 'c', name='my_enum_2')
                )
            )
        )

        t.create(testing.db)
        eq_(
            set(e['name'] for e in inspect(testing.db).get_enums()),
            set(['my_enum_1', 'my_enum_2'])
        )
        t.drop(testing.db)
        eq_(inspect(testing.db).get_enums(), [])
1484
1485
class HashableFlagORMTest(fixtures.TestBase):
    """test the various 'collection' types that they flip the 'hashable' flag
    appropriately.  [ticket:3499]"""

    __only_on__ = 'postgresql'

    def _test(self, type_, data):
        """Round-trip *data* through an ORM-mapped column of *type_* and
        assert a mixed entity/column query returns it intact."""
        Base = declarative_base(metadata=self.metadata)

        class A(Base):
            __tablename__ = 'a1'
            id = Column(Integer, primary_key=True)
            data = Column(type_)

        Base.metadata.create_all(testing.db)
        session = Session(testing.db)
        session.add_all([A(data=elem) for elem in data])
        session.commit()

        fetched = [
            (row.A.id, row.data)
            for row in session.query(A, A.data).order_by(A.id)
        ]
        eq_(fetched, list(enumerate(data, 1)))

    @testing.provide_metadata
    def test_array(self):
        self._test(
            postgresql.ARRAY(Text()),
            [['a', 'b', 'c'], ['d', 'e', 'f']],
        )

    @testing.requires.hstore
    @testing.provide_metadata
    def test_hstore(self):
        self._test(
            postgresql.HSTORE(),
            [
                {'a': '1', 'b': '2', 'c': '3'},
                {'d': '4', 'e': '5', 'f': '6'},
            ],
        )

    @testing.provide_metadata
    def test_json(self):
        self._test(
            postgresql.JSON(),
            [
                {'a': '1', 'b': '2', 'c': '3'},
                {'d': '4', 'e': {'e1': '5', 'e2': '6'},
                 'f': {'f1': [9, 10, 11]}},
            ],
        )

    @testing.requires.postgresql_jsonb
    @testing.provide_metadata
    def test_jsonb(self):
        self._test(
            postgresql.JSONB(),
            [
                {'a': '1', 'b': '2', 'c': '3'},
                {'d': '4', 'e': {'e1': '5', 'e2': '6'},
                 'f': {'f1': [9, 10, 11]}},
            ],
        )
1552
1553
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
    __only_on__ = 'postgresql'
    __backend__ = True

    def test_timestamp(self):
        """A PG timestamp literal comes back as datetime.datetime."""
        conn = testing.db.connect()
        row = conn.execute(
            select([text("timestamp '2007-12-25'")])
        ).first()
        eq_(row[0], datetime.datetime(2007, 12, 25, 0, 0))

    def test_interval_arithmetic(self):
        # basically testing that we get timedelta back for an INTERVAL
        # result.  more of a driver assertion.
        conn = testing.db.connect()
        row = conn.execute(
            select([text("timestamp '2007-12-25' - timestamp '2007-11-15'")])
        ).first()
        eq_(row[0], datetime.timedelta(40))
1575
1576
class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):

    """test DDL and reflection of PG-specific types """

    __only_on__ = 'postgresql >= 8.3.0',
    __backend__ = True

    @classmethod
    def setup_class(cls):
        # uses the module-level globals so teardown_class and the tests
        # can share the same MetaData/Table
        global metadata, table
        metadata = MetaData(testing.db)

        # create these types so that we can issue
        # special SQL92 INTERVAL syntax
        class y2m(sqltypes.UserDefinedType, postgresql.INTERVAL):

            def get_col_spec(self):
                return "INTERVAL YEAR TO MONTH"

        class d2s(sqltypes.UserDefinedType, postgresql.INTERVAL):

            def get_col_spec(self):
                return "INTERVAL DAY TO SECOND"

        table = Table(
            'sometable', metadata,
            Column(
                'id', postgresql.UUID, primary_key=True),
            Column(
                'flag', postgresql.BIT),
            Column(
                'bitstring', postgresql.BIT(4)),
            Column('addr', postgresql.INET),
            Column('addr2', postgresql.MACADDR),
            Column('addr3', postgresql.CIDR),
            Column('doubleprec', postgresql.DOUBLE_PRECISION),
            Column('plain_interval', postgresql.INTERVAL),
            Column('year_interval', y2m()),
            Column('month_interval', d2s()),
            Column('precision_interval', postgresql.INTERVAL(
                precision=3)),
            Column('tsvector_document', postgresql.TSVECTOR))

        metadata.create_all()

        # cheat so that the "strict type check"
        # works
        table.c.year_interval.type = postgresql.INTERVAL()
        table.c.month_interval.type = postgresql.INTERVAL()

    @classmethod
    def teardown_class(cls):
        metadata.drop_all()

    def test_reflection(self):
        """Reflecting the table restores the PG-specific column types."""
        m = MetaData(testing.db)
        t = Table('sometable', m, autoload=True)

        self.assert_tables_equal(table, t, strict_types=True)
        assert t.c.plain_interval.type.precision is None
        assert t.c.precision_interval.type.precision == 3
        assert t.c.bitstring.type.length == 4

    def test_bit_compile(self):
        # BIT with/without length, fixed vs. VARYING
        pairs = [(postgresql.BIT(), 'BIT(1)'),
                 (postgresql.BIT(5), 'BIT(5)'),
                 (postgresql.BIT(varying=True), 'BIT VARYING'),
                 (postgresql.BIT(5, varying=True), 'BIT VARYING(5)'),
                 ]
        for type_, expected in pairs:
            self.assert_compile(type_, expected)

    @testing.provide_metadata
    def test_tsvector_round_trip(self):
        # a TSVECTOR value is normalized by the server into sorted,
        # quoted lexemes
        t = Table('t1', self.metadata, Column('data', postgresql.TSVECTOR))
        t.create()
        testing.db.execute(t.insert(), data="a fat cat sat")
        eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'")

        testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'")

        eq_(testing.db.scalar(select([t.c.data])),
            "'a' 'cat' 'fat' 'mat' 'sat'")

    @testing.provide_metadata
    def test_bit_reflection(self):
        """Reflection preserves BIT length and the varying flag."""
        metadata = self.metadata
        t1 = Table('t1', metadata,
                   Column('bit1', postgresql.BIT()),
                   Column('bit5', postgresql.BIT(5)),
                   Column('bitvarying', postgresql.BIT(varying=True)),
                   Column('bitvarying5', postgresql.BIT(5, varying=True)),
                   )
        t1.create()
        m2 = MetaData(testing.db)
        t2 = Table('t1', m2, autoload=True)
        eq_(t2.c.bit1.type.length, 1)
        eq_(t2.c.bit1.type.varying, False)
        eq_(t2.c.bit5.type.length, 5)
        eq_(t2.c.bit5.type.varying, False)
        eq_(t2.c.bitvarying.type.length, None)
        eq_(t2.c.bitvarying.type.varying, True)
        eq_(t2.c.bitvarying5.type.length, 5)
        eq_(t2.c.bitvarying5.type.varying, True)
1681
1682
class UUIDTest(fixtures.TestBase):

    """Test the bind/return values of the UUID type."""

    __only_on__ = 'postgresql >= 8.3'
    __backend__ = True

    @testing.fails_on(
        'postgresql+zxjdbc',
        'column "data" is of type uuid but expression '
        'is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_string(self):
        # as_uuid=False: values round trip as plain strings
        import uuid
        self._test_round_trip(
            Table('utable', MetaData(),
                  Column('data', postgresql.UUID(as_uuid=False))
                  ),
            str(uuid.uuid4()),
            str(uuid.uuid4())
        )

    @testing.fails_on(
        'postgresql+zxjdbc',
        'column "data" is of type uuid but expression is '
        'of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_uuid(self):
        # as_uuid=True: values round trip as uuid.UUID objects
        import uuid
        self._test_round_trip(
            Table('utable', MetaData(),
                  Column('data', postgresql.UUID(as_uuid=True))
                  ),
            uuid.uuid4(),
            uuid.uuid4()
        )

    @testing.fails_on('postgresql+zxjdbc',
                      'column "data" is of type uuid[] but '
                      'expression is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_array(self):
        # ARRAY of UUID objects
        import uuid
        self._test_round_trip(
            Table(
                'utable', MetaData(),
                Column('data', postgresql.ARRAY(postgresql.UUID(as_uuid=True)))
            ),
            [uuid.uuid4(), uuid.uuid4()],
            [uuid.uuid4(), uuid.uuid4()],
        )

    @testing.fails_on('postgresql+zxjdbc',
                      'column "data" is of type uuid[] but '
                      'expression is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_string_array(self):
        # ARRAY of UUID strings
        import uuid
        self._test_round_trip(
            Table(
                'utable', MetaData(),
                Column(
                    'data',
                    postgresql.ARRAY(postgresql.UUID(as_uuid=False)))
            ),
            [str(uuid.uuid4()), str(uuid.uuid4())],
            [str(uuid.uuid4()), str(uuid.uuid4())],
        )

    def test_no_uuid_available(self):
        # simulate stdlib uuid being unavailable; as_uuid=True must then
        # be refused at construction time
        from sqlalchemy.dialects.postgresql import base
        uuid_type = base._python_UUID
        base._python_UUID = None
        try:
            assert_raises(
                NotImplementedError,
                postgresql.UUID, as_uuid=True
            )
        finally:
            # always restore the real UUID class for other tests
            base._python_UUID = uuid_type

    def setup(self):
        self.conn = testing.db.connect()
        # NOTE(review): 'trans' is never committed or referenced again;
        # closing the connection in teardown() discards the work.
        # Confirm the unused binding is intentional.
        trans = self.conn.begin()

    def teardown(self):
        self.conn.close()

    def _test_round_trip(self, utable, value1, value2, exp_value2=None):
        """Insert two values and assert that selecting rows != value1
        yields exactly value2 (or exp_value2 when given)."""
        utable.create(self.conn)
        self.conn.execute(utable.insert(), {'data': value1})
        self.conn.execute(utable.insert(), {'data': value2})
        r = self.conn.execute(
            select([utable.c.data]).
            where(utable.c.data != value1)
        )
        if exp_value2:
            eq_(r.fetchone()[0], exp_value2)
        else:
            eq_(r.fetchone()[0], value2)
        # exactly one row should have matched
        eq_(r.fetchone(), None)
1784
1785
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
    """SQL-compilation tests for the postgresql HSTORE type.

    No database connection is used: each test renders an expression
    against the postgresql dialect and compares the exact SQL string,
    or exercises the python-level bind/result processors directly.
    """

    __dialect__ = 'postgresql'

    def setup(self):
        # single HSTORE column shared by every test in this class
        metadata = MetaData()
        self.test_table = Table('test_table', metadata,
                                Column('id', Integer, primary_key=True),
                                Column('hash', HSTORE)
                                )
        self.hashcol = self.test_table.c.hash

    def _test_where(self, whereclause, expected):
        # compile a SELECT with the given WHERE clause and compare the
        # rendered WHERE fragment against ``expected``
        stmt = select([self.test_table]).where(whereclause)
        self.assert_compile(
            stmt,
            "SELECT test_table.id, test_table.hash FROM test_table "
            "WHERE %s" % expected
        )

    def _test_cols(self, colclause, expected, from_=True):
        # compile a SELECT of the given column expression; ``from_``
        # controls whether a FROM clause is expected in the output
        stmt = select([colclause])
        self.assert_compile(
            stmt,
            (
                "SELECT %s" +
                (" FROM test_table" if from_ else "")
            ) % expected
        )

    def test_bind_serialize_default(self):
        # the non-native bind processor renders hstore key=>value text

        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
            '"key1"=>"value1", "key2"=>"value2"'
        )

    def test_bind_serialize_with_slashes_and_quotes(self):
        # backslashes and double quotes must be escaped on the way in
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc({'\\"a': '\\"1'}),
            '"\\\\\\"a"=>"\\\\\\"1"'
        )

    def test_parse_error(self):
        # malformed hstore text raises ValueError with position info
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        assert_raises_message(
            ValueError,
            r'''After u?'\[\.\.\.\], "key1"=>"value1", ', could not parse '''
            r'''residual at position 36: u?'crapcrapcrap, "key3"\[\.\.\.\]''',
            proc,
            '"key2"=>"value2", "key1"=>"value1", '
            'crapcrapcrap, "key3"=>"value3"'
        )

    def test_result_deserialize_default(self):
        # the non-native result processor parses hstore text to a dict
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"key2"=>"value2", "key1"=>"value1"'),
            {"key1": "value1", "key2": "value2"}
        )

    def test_result_deserialize_with_slashes_and_quotes(self):
        # escaped backslashes/quotes are unescaped on the way out
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"\\\\\\"a"=>"\\\\\\"1"'),
            {'\\"a': '\\"1'}
        )

    def test_bind_serialize_psycopg2(self):
        # with native hstore, no python-level bind processor is used;
        # without it, the textual serializer kicks in
        from sqlalchemy.dialects.postgresql import psycopg2

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = True
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        is_(proc, None)

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = False
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
            '"key1"=>"value1", "key2"=>"value2"'
        )

    def test_result_deserialize_psycopg2(self):
        # same native/non-native split for the result processor
        from sqlalchemy.dialects.postgresql import psycopg2

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = True
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        is_(proc, None)

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = False
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"key2"=>"value2", "key1"=>"value1"'),
            {"key1": "value1", "key2": "value2"}
        )

    def test_ret_type_text(self):
        # indexed access defaults to the Text type
        col = column('x', HSTORE())

        is_(col['foo'].type.__class__, Text)

    def test_ret_type_custom(self):
        # text_type= overrides the type of indexed access
        class MyType(sqltypes.UserDefinedType):
            pass

        col = column('x', HSTORE(text_type=MyType))

        is_(col['foo'].type.__class__, MyType)

    def test_where_has_key(self):
        self._test_where(
            # hide from 2to3
            getattr(self.hashcol, 'has_key')('foo'),
            "test_table.hash ? %(hash_1)s"
        )

    def test_where_has_all(self):
        self._test_where(
            self.hashcol.has_all(postgresql.array(['1', '2'])),
            "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]"
        )

    def test_where_has_any(self):
        self._test_where(
            self.hashcol.has_any(postgresql.array(['1', '2'])),
            "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]"
        )

    def test_where_defined(self):
        self._test_where(
            self.hashcol.defined('foo'),
            "defined(test_table.hash, %(defined_1)s)"
        )

    def test_where_contains(self):
        self._test_where(
            self.hashcol.contains({'foo': '1'}),
            "test_table.hash @> %(hash_1)s"
        )

    def test_where_contained_by(self):
        self._test_where(
            self.hashcol.contained_by({'foo': '1', 'bar': None}),
            "test_table.hash <@ %(hash_1)s"
        )

    def test_where_getitem(self):
        self._test_where(
            self.hashcol['bar'] == None,  # noqa
            "(test_table.hash -> %(hash_1)s) IS NULL"
        )

    def test_cols_get(self):
        self._test_cols(
            self.hashcol['foo'],
            "test_table.hash -> %(hash_1)s AS anon_1",
            True
        )

    def test_cols_delete_single_key(self):
        self._test_cols(
            self.hashcol.delete('foo'),
            "delete(test_table.hash, %(delete_2)s) AS delete_1",
            True
        )

    def test_cols_delete_array_of_keys(self):
        self._test_cols(
            self.hashcol.delete(postgresql.array(['foo', 'bar'])),
            ("delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
             "AS delete_1"),
            True
        )

    def test_cols_delete_matching_pairs(self):
        self._test_cols(
            self.hashcol.delete(hstore('1', '2')),
            ("delete(test_table.hash, hstore(%(hstore_1)s, %(hstore_2)s)) "
             "AS delete_1"),
            True
        )

    def test_cols_slice(self):
        self._test_cols(
            self.hashcol.slice(postgresql.array(['1', '2'])),
            ("slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
             "AS slice_1"),
            True
        )

    def test_cols_hstore_pair_text(self):
        self._test_cols(
            hstore('foo', '3')['foo'],
            "hstore(%(hstore_1)s, %(hstore_2)s) -> %(hstore_3)s AS anon_1",
            False
        )

    def test_cols_hstore_pair_array(self):
        self._test_cols(
            hstore(postgresql.array(['1', '2']),
                   postgresql.array(['3', None]))['1'],
            ("hstore(ARRAY[%(param_1)s, %(param_2)s], "
             "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"),
            False
        )

    def test_cols_hstore_single_array(self):
        self._test_cols(
            hstore(postgresql.array(['1', '2', '3', None]))['3'],
            ("hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) "
             "-> %(hstore_1)s AS anon_1"),
            False
        )

    def test_cols_concat(self):
        self._test_cols(
            self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), '3')),
            ("test_table.hash || hstore(CAST(test_table.id AS TEXT), "
             "%(hstore_1)s) AS anon_1"),
            True
        )

    def test_cols_concat_op(self):
        self._test_cols(
            hstore('foo', 'bar') + self.hashcol,
            "hstore(%(hstore_1)s, %(hstore_2)s) || test_table.hash AS anon_1",
            True
        )

    def test_cols_concat_get(self):
        self._test_cols(
            (self.hashcol + self.hashcol)['foo'],
            "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1"
        )

    def test_cols_against_is(self):
        self._test_cols(
            self.hashcol['foo'] != None,  # noqa
            "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1"
        )

    def test_cols_keys(self):
        self._test_cols(
            # hide from 2to3
            getattr(self.hashcol, 'keys')(),
            "akeys(test_table.hash) AS akeys_1",
            True
        )

    def test_cols_vals(self):
        self._test_cols(
            self.hashcol.vals(),
            "avals(test_table.hash) AS avals_1",
            True
        )

    def test_cols_array(self):
        self._test_cols(
            self.hashcol.array(),
            "hstore_to_array(test_table.hash) AS hstore_to_array_1",
            True
        )

    def test_cols_matrix(self):
        self._test_cols(
            self.hashcol.matrix(),
            "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1",
            True
        )
2070
2071
class HStoreRoundTripTest(fixtures.TablesTest):
    """Round-trip tests for HSTORE against a live postgresql database.

    Most scenarios run twice: once against the default engine (native
    psycopg2 hstore support when available) and once against an engine
    with ``use_native_hstore=False``, exercising the python fallback.
    """

    __requires__ = 'hstore',
    __dialect__ = 'postgresql'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30), nullable=False),
              Column('data', HSTORE)
              )

    def _fixture_data(self, engine):
        """Insert five baseline rows keyed r1..r5."""
        data_table = self.tables.data_table
        engine.execute(
            data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
            {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
            {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
            {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
            {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
        )

    def _assert_data(self, compare):
        """Assert the 'data' column, ordered by name, equals ``compare``."""
        data = testing.db.execute(
            select([self.tables.data_table.c.data]).
            order_by(self.tables.data_table.c.name)
        ).fetchall()
        eq_([d for d, in data], compare)

    def _test_insert(self, engine):
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
        )
        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])

    def _non_native_engine(self):
        """Return an engine with native hstore disabled where possible."""
        if testing.requires.psycopg2_native_hstore.enabled:
            engine = engines.testing_engine(
                options=dict(
                    use_native_hstore=False))
        else:
            # backend has no native support anyway; the default engine
            # already goes through the python serializer
            engine = testing.db
        engine.connect().close()
        return engine

    def test_reflect(self):
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        assert isinstance(cols[2]['type'], HSTORE)

    def test_literal_round_trip(self):
        # in particular, this tests that the array index
        # operator against the function is handled by PG; with some
        # array functions it requires outer parenthezisation on the left and
        # we may not be doing that here
        expr = hstore(
            postgresql.array(['1', '2']),
            postgresql.array(['3', None]))['1']
        eq_(
            testing.db.scalar(
                select([expr])
            ),
            "3"
        )

    @testing.requires.psycopg2_native_hstore
    def test_insert_native(self):
        engine = testing.db
        self._test_insert(engine)

    def test_insert_python(self):
        engine = self._non_native_engine()
        self._test_insert(engine)

    @testing.requires.psycopg2_native_hstore
    def test_criterion_native(self):
        engine = testing.db
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_criterion_python(self):
        engine = self._non_native_engine()
        self._fixture_data(engine)
        self._test_criterion(engine)

    def _test_criterion(self, engine):
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data['k1'] == 'r3v1')).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

    def _test_fixed_round_trip(self, engine):
        s = select([
            hstore(
                array(['key1', 'key2', 'key3']),
                array(['value1', 'value2', 'value3'])
            )
        ])
        eq_(
            engine.scalar(s),
            {"key1": "value1", "key2": "value2", "key3": "value3"}
        )

    def test_fixed_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_fixed_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_fixed_round_trip_native(self):
        engine = testing.db
        self._test_fixed_round_trip(engine)

    def _test_unicode_round_trip(self, engine):
        s = select([
            hstore(
                array([util.u('réveillé'), util.u('drôle'), util.u('S’il')]),
                array([util.u('réveillé'), util.u('drôle'), util.u('S’il')])
            )
        ])
        eq_(
            engine.scalar(s),
            {
                util.u('réveillé'): util.u('réveillé'),
                util.u('drôle'): util.u('drôle'),
                util.u('S’il'): util.u('S’il')
            }
        )

    def test_unicode_round_trip_python(self):
        # fix: this non-native variant previously carried a spurious
        # @testing.requires.psycopg2_native_hstore decorator, unlike every
        # other *_python test in this class; _non_native_engine() already
        # handles backends that lack native hstore support.
        engine = self._non_native_engine()
        self._test_unicode_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_unicode_round_trip_native(self):
        engine = testing.db
        self._test_unicode_round_trip(engine)

    def test_escaped_quotes_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_escaped_quotes_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_escaped_quotes_round_trip_native(self):
        engine = testing.db
        self._test_escaped_quotes_round_trip(engine)

    def _test_escaped_quotes_round_trip(self, engine):
        # backslash-escaped quotes must survive bind + result processing
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {r'key \"foo\"': r'value \"bar"\ xyz'}}
        )
        self._assert_data([{r'key \"foo\"': r'value \"bar"\ xyz'}])

    def test_orm_round_trip(self):
        from sqlalchemy import orm

        class Data(object):

            def __init__(self, name, data):
                self.name = name
                self.data = data
        orm.mapper(Data, self.tables.data_table)
        s = orm.Session(testing.db)
        d = Data(name='r1', data={"key1": "value1", "key2": "value2",
                                  "key3": "value3"})
        s.add(d)
        eq_(
            s.query(Data.data, Data).all(),
            [(d.data, d)]
        )
2247
2248
class _RangeTypeMixin(object):
    """Shared tests for the postgresql range types.

    Subclasses supply ``_col_type`` (the SQLAlchemy type), ``_col_str``
    (its rendered DDL name), ``_data_str`` (a range literal) and
    ``_data_obj()`` (the equivalent psycopg2 range object).
    """

    __requires__ = 'range_types', 'psycopg2_compatibility'
    __backend__ = True

    def extras(self):
        # done this way so we don't get ImportErrors with
        # older psycopg2 versions.
        if testing.against("postgresql+psycopg2cffi"):
            from psycopg2cffi import extras
        else:
            from psycopg2 import extras
        return extras

    @classmethod
    def define_tables(cls, metadata):
        # no reason ranges shouldn't be primary keys,
        # so lets just use them as such
        table = Table('data_table', metadata,
                      Column('range', cls._col_type, primary_key=True),
                      )
        cls.col = table.c.range

    def test_actual_type(self):
        eq_(str(self._col_type()), self._col_str)

    def test_reflect(self):
        from sqlalchemy import inspect
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        assert isinstance(cols[0]['type'], self._col_type)

    def _assert_data(self):
        # the single inserted row must come back as the psycopg2 object
        data = testing.db.execute(
            select([self.tables.data_table.c.range])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    def test_insert_obj(self):
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_obj()}
        )
        self._assert_data()

    def test_insert_text(self):
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        self._assert_data()

    # operator tests

    def _test_clause(self, colclause, expected):
        # compile the expression standalone and compare the SQL string
        dialect = postgresql.dialect()
        compiled = str(colclause.compile(dialect=dialect))
        eq_(compiled, expected)

    def test_where_equal(self):
        self._test_clause(
            self.col == self._data_str,
            "data_table.range = %(range_1)s"
        )

    def test_where_not_equal(self):
        self._test_clause(
            self.col != self._data_str,
            "data_table.range <> %(range_1)s"
        )

    def test_where_less_than(self):
        self._test_clause(
            self.col < self._data_str,
            "data_table.range < %(range_1)s"
        )

    def test_where_greater_than(self):
        self._test_clause(
            self.col > self._data_str,
            "data_table.range > %(range_1)s"
        )

    def test_where_less_than_or_equal(self):
        self._test_clause(
            self.col <= self._data_str,
            "data_table.range <= %(range_1)s"
        )

    def test_where_greater_than_or_equal(self):
        self._test_clause(
            self.col >= self._data_str,
            "data_table.range >= %(range_1)s"
        )

    def test_contains(self):
        self._test_clause(
            self.col.contains(self._data_str),
            "data_table.range @> %(range_1)s"
        )

    def test_contained_by(self):
        self._test_clause(
            self.col.contained_by(self._data_str),
            "data_table.range <@ %(range_1)s"
        )

    def test_overlaps(self):
        self._test_clause(
            self.col.overlaps(self._data_str),
            "data_table.range && %(range_1)s"
        )

    def test_strictly_left_of(self):
        # both the << operator and the named method
        self._test_clause(
            self.col << self._data_str,
            "data_table.range << %(range_1)s"
        )
        self._test_clause(
            self.col.strictly_left_of(self._data_str),
            "data_table.range << %(range_1)s"
        )

    def test_strictly_right_of(self):
        # both the >> operator and the named method
        self._test_clause(
            self.col >> self._data_str,
            "data_table.range >> %(range_1)s"
        )
        self._test_clause(
            self.col.strictly_right_of(self._data_str),
            "data_table.range >> %(range_1)s"
        )

    def test_not_extend_right_of(self):
        self._test_clause(
            self.col.not_extend_right_of(self._data_str),
            "data_table.range &< %(range_1)s"
        )

    def test_not_extend_left_of(self):
        self._test_clause(
            self.col.not_extend_left_of(self._data_str),
            "data_table.range &> %(range_1)s"
        )

    def test_adjacent_to(self):
        self._test_clause(
            self.col.adjacent_to(self._data_str),
            "data_table.range -|- %(range_1)s"
        )

    def test_union(self):
        self._test_clause(
            self.col + self.col,
            "data_table.range + data_table.range"
        )

    def test_union_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select
        range = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range + range])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    def test_intersection(self):
        self._test_clause(
            self.col * self.col,
            "data_table.range * data_table.range"
        )

    def test_intersection_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select
        range = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range * range])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    # NOTE(review): name is inconsistent with test_difference_result
    # below; presumably this should be test_difference — confirm before
    # renaming, as it only affects test discovery.
    def test_different(self):
        self._test_clause(
            self.col - self.col,
            "data_table.range - data_table.range"
        )

    def test_difference_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select
        range = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range - range])
        ).fetchall()
        # a range minus itself is the empty range
        eq_(data, [(self._data_obj().__class__(empty=True), )])
2455
2456
class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the INT4RANGE type via the shared range-type tests."""

    _col_type = INT4RANGE
    _col_str = 'INT4RANGE'
    _data_str = '[1,2)'

    def _data_obj(self):
        # psycopg2 equivalent of the '[1,2)' literal above
        extras = self.extras()
        return extras.NumericRange(1, 2)
2465
2466
class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the INT8RANGE type via the shared range-type tests."""

    _col_type = INT8RANGE
    _col_str = 'INT8RANGE'
    _data_str = '[9223372036854775806,9223372036854775807)'

    def _data_obj(self):
        # psycopg2 equivalent of the near-BIGINT-max literal above
        extras = self.extras()
        return extras.NumericRange(
            9223372036854775806, 9223372036854775807
        )
2477
2478
class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the NUMRANGE type via the shared range-type tests."""

    _col_type = NUMRANGE
    _col_str = 'NUMRANGE'
    _data_str = '[1.0,2.0)'

    def _data_obj(self):
        # psycopg2 equivalent of the '[1.0,2.0)' literal above
        extras = self.extras()
        return extras.NumericRange(
            decimal.Decimal('1.0'), decimal.Decimal('2.0')
        )
2489
2490
class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the DATERANGE type via the shared range-type tests."""

    _col_type = DATERANGE
    _col_str = 'DATERANGE'
    _data_str = '[2013-03-23,2013-03-24)'

    def _data_obj(self):
        # psycopg2 equivalent of the date-range literal above
        extras = self.extras()
        return extras.DateRange(
            datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
        )
2501
2502
class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the TSRANGE type via the shared range-type tests."""

    _col_type = TSRANGE
    _col_str = 'TSRANGE'
    _data_str = '[2013-03-23 14:30,2013-03-23 23:30)'

    def _data_obj(self):
        # psycopg2 equivalent of the timestamp-range literal above
        extras = self.extras()
        return extras.DateTimeRange(
            datetime.datetime(2013, 3, 23, 14, 30),
            datetime.datetime(2013, 3, 23, 23, 30)
        )
2514
2515
class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the TSTZRANGE type via the shared range-type tests."""

    _col_type = TSTZRANGE
    _col_str = 'TSTZRANGE'

    # make sure we use one, steady timestamp with timezone pair
    # for all parts of all these tests
    _tstzs = None

    def tstzs(self):
        """Return a cached (lower, upper) pair of tz-aware timestamps."""
        if self._tstzs is None:
            start = testing.db.scalar(
                func.current_timestamp().select()
            )
            self._tstzs = (start, start + datetime.timedelta(1))
        return self._tstzs

    @property
    def _data_str(self):
        lower, upper = self.tstzs()
        return '[%s,%s)' % (lower, upper)

    def _data_obj(self):
        extras = self.extras()
        return extras.DateTimeTZRange(*self.tstzs())
2540
2541
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
    """SQL-compilation tests for the postgresql JSON type.

    No database connection is used; each test renders an expression
    against the postgresql dialect and compares the exact SQL string.
    """

    __dialect__ = 'postgresql'

    def setup(self):
        # single JSON column shared by every test in this class
        metadata = MetaData()
        self.test_table = Table('test_table', metadata,
                                Column('id', Integer, primary_key=True),
                                Column('test_column', JSON),
                                )
        self.jsoncol = self.test_table.c.test_column

    def _test_where(self, whereclause, expected):
        # compile a SELECT with the given WHERE clause and compare the
        # rendered WHERE fragment against ``expected``
        stmt = select([self.test_table]).where(whereclause)
        self.assert_compile(
            stmt,
            "SELECT test_table.id, test_table.test_column FROM test_table "
            "WHERE %s" % expected
        )

    def _test_cols(self, colclause, expected, from_=True):
        # compile a SELECT of the given column expression; ``from_``
        # controls whether a FROM clause is expected in the output
        stmt = select([colclause])
        self.assert_compile(
            stmt,
            (
                "SELECT %s" +
                (" FROM test_table" if from_ else "")
            ) % expected
        )

    # This test is a bit misleading -- in real life you will need to cast to
    # do anything
    def test_where_getitem(self):
        self._test_where(
            self.jsoncol['bar'] == None,  # noqa
            "(test_table.test_column -> %(test_column_1)s) IS NULL"
        )

    def test_where_path(self):
        # tuple index renders the #> path operator
        self._test_where(
            self.jsoncol[("foo", 1)] == None,  # noqa
            "(test_table.test_column #> %(test_column_1)s) IS NULL"
        )

    def test_path_typing(self):
        # indexed access, single or path, stays JSON-typed
        col = column('x', JSON())
        is_(
            col['q'].type._type_affinity, sqltypes.JSON
        )
        is_(
            col[('q', )].type._type_affinity, sqltypes.JSON
        )
        is_(
            col['q']['p'].type._type_affinity, sqltypes.JSON
        )
        is_(
            col[('q', 'p')].type._type_affinity, sqltypes.JSON
        )

    def test_custom_astext_type(self):
        # astext_type= overrides the type of .astext access
        class MyType(sqltypes.UserDefinedType):
            pass

        col = column('x', JSON(astext_type=MyType))

        is_(
            col['q'].astext.type.__class__, MyType
        )

        is_(
            col[('q', 'p')].astext.type.__class__, MyType
        )

        is_(
            col['q']['p'].astext.type.__class__, MyType
        )

    def test_where_getitem_as_text(self):
        # .astext renders the ->> text-extraction operator
        self._test_where(
            self.jsoncol['bar'].astext == None,  # noqa
            "(test_table.test_column ->> %(test_column_1)s) IS NULL"
        )

    def test_where_getitem_astext_cast(self):
        self._test_where(
            self.jsoncol['bar'].astext.cast(Integer) == 5,
            "CAST((test_table.test_column ->> %(test_column_1)s) AS INTEGER) "
            "= %(param_1)s"
        )

    def test_where_getitem_json_cast(self):
        self._test_where(
            self.jsoncol['bar'].cast(Integer) == 5,
            "CAST((test_table.test_column -> %(test_column_1)s) AS INTEGER) "
            "= %(param_1)s"
        )

    def test_where_path_as_text(self):
        # .astext on a path renders the #>> operator
        self._test_where(
            self.jsoncol[("foo", 1)].astext == None,  # noqa
            "(test_table.test_column #>> %(test_column_1)s) IS NULL"
        )

    def test_cols_get(self):
        self._test_cols(
            self.jsoncol['foo'],
            "test_table.test_column -> %(test_column_1)s AS anon_1",
            True
        )
2650
2651
class JSONRoundTripTest(fixtures.TablesTest):
    """Round-trip tests for the PostgreSQL JSON type against a live
    database: inserts, criterion queries, path/index access, custom
    (de)serializers, and Python ``None`` vs. JSON ``null`` handling,
    exercised both with psycopg2's native JSON support and with a
    plain-Python fallback engine.

    Subclassed as JSONBRoundTripTest (with ``test_type = JSONB``) to run
    the same suite against the JSONB type.
    """

    __only_on__ = ('postgresql >= 9.3',)
    __backend__ = True

    # the type under test; the JSONB subclass overrides this
    test_type = JSON

    @classmethod
    def define_tables(cls, metadata):
        # 'data' uses the type's default None handling (None becomes JSON
        # 'null'); 'nulldata' coerces Python None to SQL NULL via
        # none_as_null=True
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30), nullable=False),
              Column('data', cls.test_type),
              Column('nulldata', cls.test_type(none_as_null=True))
              )

    def _fixture_data(self, engine):
        """Insert six sample rows: flat string dicts (r1-r4), a dict with
        a non-string value (r5's "k3"), and a nested structure (r6)."""
        data_table = self.tables.data_table
        engine.execute(
            data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
            {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
            {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
            {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
            {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
            {'name': 'r6', 'data': {"k1": {"r6v1": {'subr': [1, 2, 3]}}}},
        )

    def _assert_data(self, compare, column='data'):
        """Assert the full contents of *column*, ordered by name,
        equal *compare*."""
        col = self.tables.data_table.c[column]

        data = testing.db.execute(
            select([col]).
            order_by(self.tables.data_table.c.name)
        ).fetchall()
        eq_([d for d, in data], compare)

    def _assert_column_is_NULL(self, column='data'):
        """Assert *column* holds a SQL NULL (found via IS NULL)."""
        col = self.tables.data_table.c[column]

        data = testing.db.execute(
            select([col]).
            where(col.is_(null()))
        ).fetchall()
        eq_([d for d, in data], [None])

    def _assert_column_is_JSON_NULL(self, column='data'):
        """Assert *column* holds a JSON 'null' value (not SQL NULL);
        detected by casting to String and comparing to the text "null"."""
        col = self.tables.data_table.c[column]

        data = testing.db.execute(
            select([col]).
            where(cast(col, String) == "null")
        ).fetchall()
        eq_([d for d, in data], [None])

    def _test_insert(self, engine):
        # plain dict round trip
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
        )
        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])

    def _test_insert_nulls(self, engine):
        # explicit null() construct always produces SQL NULL
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': null()}
        )
        self._assert_data([None])

    def _test_insert_none_as_null(self, engine):
        # Python None on the none_as_null=True column -> SQL NULL
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'nulldata': None}
        )
        self._assert_column_is_NULL(column='nulldata')

    def _test_insert_nulljson_into_none_as_null(self, engine):
        # the JSON.NULL sentinel forces a JSON 'null' even where plain
        # None would become SQL NULL
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'nulldata': JSON.NULL}
        )
        self._assert_column_is_JSON_NULL(column='nulldata')

    def _non_native_engine(self, json_serializer=None, json_deserializer=None):
        """Return an engine whose JSON handling goes through SQLAlchemy's
        Python-level (de)serialization rather than psycopg2's native
        support.

        On psycopg2 >= 2.5 this registers a pass-through loads() on each
        connection and flags the dialect as not having native JSON; on
        other drivers the plain testing engine already behaves that way.
        """
        if json_serializer is not None or json_deserializer is not None:
            options = {
                "json_serializer": json_serializer,
                "json_deserializer": json_deserializer
            }
        else:
            options = {}

        if testing.against("postgresql+psycopg2") and \
                testing.db.dialect.psycopg2_version >= (2, 5):
            from psycopg2.extras import register_default_json
            engine = engines.testing_engine(options=options)

            @event.listens_for(engine, "connect")
            def connect(dbapi_connection, connection_record):
                engine.dialect._has_native_json = False

                def pass_(value):
                    return value
                register_default_json(dbapi_connection, loads=pass_)
        elif options:
            engine = engines.testing_engine(options=options)
        else:
            engine = testing.db
        # connect once up front so the "connect" listener (if any) has
        # taken effect before the test uses the engine
        engine.connect().close()
        return engine

    def test_reflect(self):
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        # cols[2] is the 'data' column; its reflected type must be the
        # type under test (JSON or JSONB)
        assert isinstance(cols[2]['type'], self.test_type)

    @testing.requires.psycopg2_native_json
    def test_insert_native(self):
        engine = testing.db
        self._test_insert(engine)

    @testing.requires.psycopg2_native_json
    def test_insert_native_nulls(self):
        engine = testing.db
        self._test_insert_nulls(engine)

    @testing.requires.psycopg2_native_json
    def test_insert_native_none_as_null(self):
        engine = testing.db
        self._test_insert_none_as_null(engine)

    @testing.requires.psycopg2_native_json
    def test_insert_native_nulljson_into_none_as_null(self):
        engine = testing.db
        self._test_insert_nulljson_into_none_as_null(engine)

    def test_insert_python(self):
        engine = self._non_native_engine()
        self._test_insert(engine)

    def test_insert_python_nulls(self):
        engine = self._non_native_engine()
        self._test_insert_nulls(engine)

    def test_insert_python_none_as_null(self):
        engine = self._non_native_engine()
        self._test_insert_none_as_null(engine)

    def test_insert_python_nulljson_into_none_as_null(self):
        engine = self._non_native_engine()
        self._test_insert_nulljson_into_none_as_null(engine)

    def _test_custom_serialize_deserialize(self, native):
        """Round-trip through user-supplied dumps/loads: dumps rewrites
        'x' to 'dumps_y' on the way in, loads appends '_loads' on the way
        out, so the observed value 'dumps_y_loads' proves both hooks ran."""
        import json

        def loads(value):
            value = json.loads(value)
            value['x'] = value['x'] + '_loads'
            return value

        def dumps(value):
            value = dict(value)
            value['x'] = 'dumps_y'
            return json.dumps(value)

        if native:
            engine = engines.testing_engine(options=dict(
                json_serializer=dumps,
                json_deserializer=loads
            ))
        else:
            engine = self._non_native_engine(
                json_serializer=dumps,
                json_deserializer=loads
            )

        s = select([
            cast(
                {
                    "key": "value",
                    "x": "q"
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                "key": "value",
                "x": "dumps_y_loads"
            },
        )

    @testing.requires.psycopg2_native_json
    def test_custom_native(self):
        self._test_custom_serialize_deserialize(True)

    @testing.requires.psycopg2_native_json
    def test_custom_python(self):
        # NOTE(review): still requires psycopg2_native_json because
        # _non_native_engine itself uses psycopg2.extras on that driver
        self._test_custom_serialize_deserialize(False)

    @testing.requires.psycopg2_native_json
    def test_criterion_native(self):
        engine = testing.db
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_criterion_python(self):
        engine = self._non_native_engine()
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_path_query(self):
        """Path-style access via a tuple index renders the #>> operator
        chain and reaches into r6's nested structure."""
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table

        result = engine.execute(
            select([data_table.c.name]).where(
                data_table.c.data[('k1', 'r6v1', 'subr')].astext == "[1, 2, 3]"
            )
        )
        eq_(result.scalar(), 'r6')

    @testing.fails_on(
        "postgresql < 9.4",
        "Improvement in PostgreSQL behavior?")
    def test_multi_index_query(self):
        """Chained single-key indexing reaches the same nested value as
        the tuple path form."""
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table

        result = engine.execute(
            select([data_table.c.name]).where(
                data_table.c.data['k1']['r6v1']['subr'].astext == "[1, 2, 3]"
            )
        )
        eq_(result.scalar(), 'r6')

    def test_query_returned_as_text(self):
        # .astext should come back as a string type appropriate to the
        # dialect's unicode behavior
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data['k1'].astext])
        ).first()
        if engine.dialect.returns_unicode_strings:
            assert isinstance(result[0], util.text_type)
        else:
            assert isinstance(result[0], util.string_types)

    def test_query_returned_as_int(self):
        # .astext.cast(Integer) on r5's numeric "k3" yields a Python int
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data['k3'].astext.cast(Integer)]).where(
                data_table.c.name == 'r5')
        ).first()
        assert isinstance(result[0], int)

    def _test_criterion(self, engine):
        """Filter rows by a JSON key's text value, both with the bare
        .astext comparison and with an explicit cast to String."""
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data['k1'].astext == 'r3v1'
            )
        ).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data['k1'].astext.cast(String) == 'r3v1'
            )
        ).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

    def _test_fixed_round_trip(self, engine):
        # literal dict cast to the JSON type comes back structurally equal
        s = select([
            cast(
                {
                    "key": "value",
                    "key2": {"k1": "v1", "k2": "v2"}
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                "key": "value",
                "key2": {"k1": "v1", "k2": "v2"}
            },
        )

    def test_fixed_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_fixed_round_trip(engine)

    @testing.requires.psycopg2_native_json
    def test_fixed_round_trip_native(self):
        engine = testing.db
        self._test_fixed_round_trip(engine)

    def _test_unicode_round_trip(self, engine):
        # non-ASCII keys and values survive the round trip unchanged
        s = select([
            cast(
                {
                    util.u('réveillé'): util.u('réveillé'),
                    "data": {"k1": util.u('drôle')}
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                util.u('réveillé'): util.u('réveillé'),
                "data": {"k1": util.u('drôle')}
            },
        )

    def test_unicode_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_unicode_round_trip(engine)

    @testing.requires.psycopg2_native_json
    def test_unicode_round_trip_native(self):
        engine = testing.db
        self._test_unicode_round_trip(engine)

    def test_eval_none_flag_orm(self):
        """Through the ORM (both unit-of-work flush and
        bulk_insert_mappings), Python None becomes JSON 'null' on the
        default 'data' column but SQL NULL on the none_as_null=True
        'nulldata' column."""
        Base = declarative_base()

        class Data(Base):
            __table__ = self.tables.data_table

        s = Session(testing.db)

        d1 = Data(name='d1', data=None, nulldata=None)
        s.add(d1)
        s.commit()

        s.bulk_insert_mappings(
            Data, [{"name": "d2", "data": None, "nulldata": None}]
        )
        eq_(
            s.query(
                cast(self.tables.data_table.c.data, String),
                cast(self.tables.data_table.c.nulldata, String)
            ).filter(self.tables.data_table.c.name == 'd1').first(),
            ("null", None)
        )
        eq_(
            s.query(
                cast(self.tables.data_table.c.data, String),
                cast(self.tables.data_table.c.nulldata, String)
            ).filter(self.tables.data_table.c.name == 'd2').first(),
            ("null", None)
        )
3011
3012
class JSONBTest(JSONTest):
    """Runs the JSONTest expression suite against a JSONB column, adding
    compile tests for the JSONB-only operators ?, ?&, ?|, @> and <@."""

    def setup(self):
        meta = MetaData()
        tbl = Table(
            'test_table', meta,
            Column('id', Integer, primary_key=True),
            Column('test_column', JSONB),
        )
        self.test_table = tbl
        self.jsoncol = tbl.c.test_column

    # Note - add fixture data for arrays []

    def test_where_has_key(self):
        # fetch the operator via getattr so 2to3's has_key fixer
        # leaves it alone
        has_key = getattr(self.jsoncol, 'has_key')
        self._test_where(
            has_key('data'),
            "test_table.test_column ? %(test_column_1)s")

    def test_where_has_all(self):
        criterion = self.jsoncol.has_all(
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}})
        self._test_where(
            criterion,
            "test_table.test_column ?& %(test_column_1)s")

    def test_where_has_any(self):
        keys = postgresql.array(['name', 'data'])
        self._test_where(
            self.jsoncol.has_any(keys),
            "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]")

    def test_where_contains(self):
        self._test_where(
            self.jsoncol.contains({"k1": "r1v1"}),
            "test_table.test_column @> %(test_column_1)s")

    def test_where_contained_by(self):
        self._test_where(
            self.jsoncol.contained_by({'foo': '1', 'bar': None}),
            "test_table.test_column <@ %(test_column_1)s")
3055
3056
class JSONBRoundTripTest(JSONRoundTripTest):
    """Re-runs the full JSON round-trip suite using the JSONB type; the
    unicode round trips additionally need a UTF-8 server encoding."""

    __requires__ = ('postgresql_jsonb', )

    test_type = JSONB

    @testing.requires.postgresql_utf8_server_encoding
    def test_unicode_round_trip_native(self):
        super(JSONBRoundTripTest, self).test_unicode_round_trip_native()

    @testing.requires.postgresql_utf8_server_encoding
    def test_unicode_round_trip_python(self):
        super(JSONBRoundTripTest, self).test_unicode_round_trip_python()
3069