1# coding: utf-8
2from sqlalchemy.testing.assertions import eq_, assert_raises, \
3    assert_raises_message, is_, AssertsExecutionResults, \
4    AssertsCompiledSQL, ComparesTables
5from sqlalchemy.testing import engines, fixtures
6from sqlalchemy import testing
7import datetime
8from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \
9    func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \
10    Text, null, text
11from sqlalchemy.sql import operators
12from sqlalchemy import types
13import sqlalchemy as sa
14from sqlalchemy.dialects.postgresql import base as postgresql
15from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
16    INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
17    JSON, JSONB
18import decimal
19from sqlalchemy import util
20from sqlalchemy.testing.util import round_decimal
21from sqlalchemy import inspect
22from sqlalchemy import event
23
# Module-level handles populated by TimezoneTest.setup_class(); initialized
# to None so the module imports cleanly before any fixture runs.
tztable = notztable = metadata = table = None
25
26
class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
    """Result-type coercion tests for float/numeric on PostgreSQL.

    Covers coercion of aggregate-function results via an explicit
    ``type_=`` and via ``cast()``, plus ARRAY round trips for the
    float-family types.
    """

    __only_on__ = 'postgresql'
    __dialect__ = postgresql.dialect()
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # Table() registers itself on ``metadata``; the original bound the
        # result to an unused local, removed here (F841).
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('data', Integer)
              )

    @classmethod
    def insert_data(cls):
        # Fixture rows whose population stddev is ~140.381230939, used by
        # test_float_coercion below.
        data_table = cls.tables.data_table

        data_table.insert().execute(
            {'data': 3},
            {'data': 5},
            {'data': 7},
            {'data': 2},
            {'data': 15},
            {'data': 12},
            {'data': 6},
            {'data': 478},
            {'data': 52},
            {'data': 9},
        )

    @testing.fails_on(
        'postgresql+zxjdbc',
        'XXX: postgresql+zxjdbc currently returns a Decimal result for Float')
    def test_float_coercion(self):
        # Each type_ should control whether a float or a Decimal comes
        # back, both for ``type_=`` on the function and for cast().
        data_table = self.tables.data_table

        for type_, result in [
            (Numeric, decimal.Decimal('140.381230939')),
            (Float, 140.381230939),
            (Float(asdecimal=True), decimal.Decimal('140.381230939')),
            (Numeric(asdecimal=False), 140.381230939),
        ]:
            ret = testing.db.execute(
                select([
                    func.stddev_pop(data_table.c.data, type_=type_)
                ])
            ).scalar()

            eq_(round_decimal(ret, 9), result)

            ret = testing.db.execute(
                select([
                    cast(func.stddev_pop(data_table.c.data), type_)
                ])
            ).scalar()
            eq_(round_decimal(ret, 9), result)

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc has no support for PG arrays')
    @testing.provide_metadata
    def test_arrays(self):
        # Round-trip one row through ARRAY columns of each float-family
        # type; the values should come back unchanged.
        metadata = self.metadata
        t1 = Table('t', metadata,
                   Column('x', postgresql.ARRAY(Float)),
                   Column('y', postgresql.ARRAY(REAL)),
                   Column('z', postgresql.ARRAY(postgresql.DOUBLE_PRECISION)),
                   Column('q', postgresql.ARRAY(Numeric))
                   )
        metadata.create_all()
        t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")])
        row = t1.select().execute().first()
        eq_(
            row,
            ([5], [5], [6], [decimal.Decimal("6.4")])
        )
101
102
class EnumTest(fixtures.TestBase, AssertsExecutionResults):
    """DDL, round-trip and reflection tests for ENUM on PostgreSQL.

    Covers native ENUM create/drop sequencing (table-scoped and
    metadata-scoped), the non-native VARCHAR + CHECK fallback, unicode
    labels, and reflection of enum labels, type names and schemas.
    """

    __backend__ = True

    __only_on__ = 'postgresql > 8.3'

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc fails on ENUM: column "XXX" is of type '
                      'XXX but expression is of type character varying')
    def test_create_table(self):
        # Creating the table also creates the dependent ENUM type; the
        # second create(checkfirst=True) must be a no-op.
        metadata = MetaData(testing.db)
        t1 = Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three', name='onetwothreetype')))
        t1.create()
        t1.create(checkfirst=True)  # check the create
        try:
            t1.insert().execute(value='two')
            t1.insert().execute(value='three')
            t1.insert().execute(value='three')
            eq_(t1.select().order_by(t1.c.id).execute().fetchall(),
                [(1, 'two'), (2, 'three'), (3, 'three')])
        finally:
            # drop_all() is invoked twice; the second call exercises
            # dropping when the objects are already gone.
            metadata.drop_all()
            metadata.drop_all()

    def test_name_required(self):
        # A native ENUM cannot be emitted without a type name.
        metadata = MetaData(testing.db)
        etype = Enum('four', 'five', 'six', metadata=metadata)
        assert_raises(exc.CompileError, etype.create)
        assert_raises(exc.CompileError, etype.compile,
                      dialect=postgresql.dialect())

    @testing.fails_on('postgresql+zxjdbc',
                      'zxjdbc fails on ENUM: column "XXX" is of type '
                      'XXX but expression is of type character varying')
    @testing.provide_metadata
    def test_unicode_labels(self):
        # Non-ASCII enum labels must round-trip and reflect intact.
        metadata = self.metadata
        t1 = Table(
            'table',
            metadata,
            Column(
                'id',
                Integer,
                primary_key=True),
            Column(
                'value',
                Enum(
                    util.u('réveillé'),
                    util.u('drôle'),
                    util.u('S’il'),
                    name='onetwothreetype')))
        metadata.create_all()
        t1.insert().execute(value=util.u('drôle'))
        t1.insert().execute(value=util.u('réveillé'))
        t1.insert().execute(value=util.u('S’il'))
        eq_(t1.select().order_by(t1.c.id).execute().fetchall(),
            [(1, util.u('drôle')), (2, util.u('réveillé')),
             (3, util.u('S’il'))]
            )
        m2 = MetaData(testing.db)
        t2 = Table('table', m2, autoload=True)
        eq_(
            t2.c.value.type.enums,
            (util.u('réveillé'), util.u('drôle'), util.u('S’il'))
        )

    @testing.provide_metadata
    def test_non_native_enum(self):
        # native_enum=False renders VARCHAR + CHECK instead of CREATE TYPE.
        metadata = self.metadata
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum(
                    'one',
                    'two',
                    'three',
                    name='myenum',
                    native_enum=False)))

        def go():
            t1.create(testing.db)

        self.assert_sql(
            testing.db, go, [
                ("CREATE TABLE foo (\tbar "
                 "VARCHAR(5), \tCONSTRAINT myenum CHECK "
                 "(bar IN ('one', 'two', 'three')))", {})])
        with testing.db.begin() as conn:
            conn.execute(
                t1.insert(), {'bar': 'two'}
            )
            eq_(
                conn.scalar(select([t1.c.bar])), 'two'
            )

    @testing.provide_metadata
    def test_non_native_enum_w_unicode(self):
        # Same as above, with a non-ASCII label in the CHECK constraint.
        metadata = self.metadata
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum('B', util.u('Ü'), name='myenum', native_enum=False)))

        def go():
            t1.create(testing.db)

        self.assert_sql(
            testing.db,
            go,
            [
                (
                    util.u(
                        "CREATE TABLE foo (\tbar "
                        "VARCHAR(1), \tCONSTRAINT myenum CHECK "
                        "(bar IN ('B', 'Ü')))"
                    ),
                    {}
                )
            ])

        with testing.db.begin() as conn:
            conn.execute(
                t1.insert(), {'bar': util.u('Ü')}
            )
            eq_(
                conn.scalar(select([t1.c.bar])), util.u('Ü')
            )

    @testing.provide_metadata
    def test_disable_create(self):
        metadata = self.metadata

        # create_type=False: table DDL must not try to emit CREATE TYPE.
        e1 = postgresql.ENUM('one', 'two', 'three',
                             name="myenum",
                             create_type=False)

        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )
        # table can be created separately
        # without conflict
        e1.create(bind=testing.db)
        t1.create(testing.db)
        t1.drop(testing.db)
        e1.drop(bind=testing.db)

    @testing.provide_metadata
    def test_generate_multiple(self):
        """Test that the same enum twice only generates once
        for the create_all() call, without using checkfirst.

        A 'memo' collection held by the DDL runner
        now handles this.

        """
        metadata = self.metadata

        # The same Enum object is shared by two tables; t1/t2 register
        # themselves on ``metadata`` as a side effect of construction.
        e1 = Enum('one', 'two', 'three',
                  name="myenum")
        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )

        t2 = Table('e2', metadata,
                   Column('c1', e1)
                   )

        metadata.create_all(checkfirst=False)
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_generate_alone_on_metadata(self):
        """Test that the same enum twice only generates once
        for the create_all() call, without using checkfirst.

        A 'memo' collection held by the DDL runner
        now handles this.

        """
        metadata = self.metadata

        # Metadata-bound enum with no table at all: created/dropped with
        # the MetaData itself.
        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=self.metadata)

        metadata.create_all(checkfirst=False)
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_generate_multiple_on_metadata(self):
        # Metadata-bound enum shared by two tables; create_all with
        # checkfirst=False must not emit a duplicate CREATE TYPE.
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=metadata)

        t1 = Table('e1', metadata,
                   Column('c1', e1)
                   )

        t2 = Table('e2', metadata,
                   Column('c1', e1)
                   )

        metadata.create_all(checkfirst=False)
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all(checkfirst=False)
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

        e1.create()  # creates ENUM
        t1.create()  # does not create ENUM
        t2.create()  # does not create ENUM

    @testing.provide_metadata
    def test_drops_on_table(self):
        # An enum owned only by a table travels with table create/drop.
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum")
        table = Table(
            'e1', metadata,
            Column('c1', e1)
        )

        table.create()
        table.drop()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]
        table.create()
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        table.drop()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    @testing.provide_metadata
    def test_remain_on_table_metadata_wide(self):
        # A metadata-bound enum is NOT dropped by table.drop(); only
        # metadata.drop_all() removes it.
        metadata = self.metadata

        e1 = Enum('one', 'two', 'three',
                  name="myenum", metadata=metadata)
        table = Table(
            'e1', metadata,
            Column('c1', e1)
        )

        # need checkfirst here, otherwise enum will not be created
        assert_raises_message(
            sa.exc.ProgrammingError,
            '.*type "myenum" does not exist',
            table.create,
        )
        table.create(checkfirst=True)
        table.drop()
        table.create(checkfirst=True)
        table.drop()
        assert 'myenum' in [
            e['name'] for e in inspect(testing.db).get_enums()]
        metadata.drop_all()
        assert 'myenum' not in [
            e['name'] for e in inspect(testing.db).get_enums()]

    def test_non_native_dialect(self):
        # Force the dialect into non-native mode; DDL should fall back to
        # VARCHAR + CHECK.
        engine = engines.testing_engine()
        engine.connect()
        engine.dialect.supports_native_enum = False
        metadata = MetaData()
        t1 = Table(
            'foo',
            metadata,
            Column(
                'bar',
                Enum(
                    'one',
                    'two',
                    'three',
                    name='myenum')))

        def go():
            t1.create(engine)

        try:
            self.assert_sql(
                engine, go, [
                    ("CREATE TABLE foo (bar "
                     "VARCHAR(5), CONSTRAINT myenum CHECK "
                     "(bar IN ('one', 'two', 'three')))", {})])
        finally:
            metadata.drop_all(engine)

    def test_standalone_enum(self):
        # An enum with no table can be created/dropped directly or via
        # its MetaData; has_type() tracks its existence.
        metadata = MetaData(testing.db)
        etype = Enum('four', 'five', 'six', name='fourfivesixtype',
                     metadata=metadata)
        etype.create()
        try:
            assert testing.db.dialect.has_type(testing.db,
                                               'fourfivesixtype')
        finally:
            etype.drop()
            assert not testing.db.dialect.has_type(testing.db,
                                                   'fourfivesixtype')
        metadata.create_all()
        try:
            assert testing.db.dialect.has_type(testing.db,
                                               'fourfivesixtype')
        finally:
            metadata.drop_all()
            assert not testing.db.dialect.has_type(testing.db,
                                                   'fourfivesixtype')

    def test_no_support(self):
        # Patch the reported server version below 8.3: after connecting
        # (which initializes the dialect) native enum support goes away.
        def server_version_info(self):
            return (8, 2)

        e = engines.testing_engine()
        dialect = e.dialect
        dialect._get_server_version_info = server_version_info

        assert dialect.supports_native_enum
        e.connect()
        assert not dialect.supports_native_enum

        # initialize is called again on new pool
        e.dispose()
        e.connect()
        assert not dialect.supports_native_enum

    def test_reflection(self):
        # Reflected columns report the enum's labels and type name.
        metadata = MetaData(testing.db)
        etype = Enum('four', 'five', 'six', name='fourfivesixtype',
                     metadata=metadata)
        t1 = Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three', name='onetwothreetype')),
            Column('value2', etype))
        metadata.create_all()
        try:
            m2 = MetaData(testing.db)
            t2 = Table('table', m2, autoload=True)
            assert t2.c.value.type.enums == ('one', 'two', 'three')
            assert t2.c.value.type.name == 'onetwothreetype'
            assert t2.c.value2.type.enums == ('four', 'five', 'six')
            assert t2.c.value2.type.name == 'fourfivesixtype'
        finally:
            metadata.drop_all()

    def test_schema_reflection(self):
        # As test_reflection, but the enums live in a non-default schema,
        # which must also be reflected.
        metadata = MetaData(testing.db)
        etype = Enum(
            'four',
            'five',
            'six',
            name='fourfivesixtype',
            schema='test_schema',
            metadata=metadata,
        )
        t1 = Table(
            'table', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'value', Enum(
                    'one', 'two', 'three',
                    name='onetwothreetype', schema='test_schema')),
            Column('value2', etype))
        metadata.create_all()
        try:
            m2 = MetaData(testing.db)
            t2 = Table('table', m2, autoload=True)
            assert t2.c.value.type.enums == ('one', 'two', 'three')
            assert t2.c.value.type.name == 'onetwothreetype'
            assert t2.c.value2.type.enums == ('four', 'five', 'six')
            assert t2.c.value2.type.name == 'fourfivesixtype'
            assert t2.c.value2.type.schema == 'test_schema'
        finally:
            metadata.drop_all()
499
500
class OIDTest(fixtures.TestBase):
    """Verify that the PostgreSQL OID type survives reflection."""

    __only_on__ = 'postgresql'
    __backend__ = True

    @testing.provide_metadata
    def test_reflection(self):
        Table('table', self.metadata,
              Column('x', Integer),
              Column('y', postgresql.OID))
        self.metadata.create_all()
        reflected = Table(
            'table', MetaData(), autoload_with=testing.db, autoload=True)
        assert isinstance(reflected.c.y.type, postgresql.OID)
514
515
class NumericInterpretationTest(fixtures.TestBase):
    """Checks for Numeric/Float result handling across pg drivers."""

    __only_on__ = 'postgresql'
    __backend__ = True

    def test_numeric_codes(self):
        from sqlalchemy.dialects.postgresql import psycopg2cffi, pg8000, \
            psycopg2, base

        # Every numeric-ish OID code must yield either the raw float or
        # an equivalent Decimal, regardless of driver.
        all_codes = base._INT_TYPES + base._FLOAT_TYPES + base._DECIMAL_TYPES
        for dialect in (pg8000.dialect(), psycopg2.dialect(),
                        psycopg2cffi.dialect()):
            impl = Numeric().dialect_impl(dialect)
            for code in all_codes:
                processor = impl.result_processor(dialect, code)
                value = 23.7
                if processor is not None:
                    value = processor(value)
                assert value in (23.7, decimal.Decimal("23.7"))

    @testing.provide_metadata
    def test_numeric_default(self):
        # pg8000 appears to fail when the value is 0,
        # returns an int instead of decimal.
        t = Table('t', self.metadata,
                  Column('id', Integer, primary_key=True),
                  Column('nd', Numeric(asdecimal=True), default=1),
                  Column('nf', Numeric(asdecimal=False), default=1),
                  Column('fd', Float(asdecimal=True), default=1),
                  Column('ff', Float(asdecimal=False), default=1),
                  )
        self.metadata.create_all()
        t.insert().execute()

        row = t.select().execute().first()
        # asdecimal flips each column between Decimal and float.
        for index, expected_cls in ((1, decimal.Decimal), (2, float),
                                    (3, decimal.Decimal), (4, float)):
            assert isinstance(row[index], expected_cls)
        eq_(
            row,
            (1, decimal.Decimal("1"), 1, decimal.Decimal("1"), 1)
        )
560
561
class PythonTypeTest(fixtures.TestBase):
    """Checks the declared python_type of PG-specific types."""

    def test_interval(self):
        # INTERVAL maps to datetime.timedelta on the Python side.
        interval_type = postgresql.INTERVAL()
        is_(interval_type.python_type, datetime.timedelta)
568
569
class TimezoneTest(fixtures.TestBase):
    __backend__ = True

    """Test timezone-aware datetimes.

    psycopg will return a datetime with a tzinfo attached to it, if
    postgresql returns it.  python then will not let you compare a
    datetime with a tzinfo to a datetime that doesn't have one.  this
    test illustrates two ways to have datetime types with and without
    timezone info. """

    __only_on__ = 'postgresql'

    @classmethod
    def setup_class(cls):
        # Tables are published through module-level globals so the test
        # methods and teardown_class can reach them without instance state.
        global tztable, notztable, metadata
        metadata = MetaData(testing.db)

        # current_timestamp() in postgresql is assumed to return
        # TIMESTAMP WITH TIMEZONE

        tztable = Table(
            'tztable', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'date', DateTime(
                    timezone=True), onupdate=func.current_timestamp()),
            Column('name', String(20)))
        notztable = Table(
            'notztable', metadata,
            Column(
                'id', Integer, primary_key=True),
            Column(
                'date', DateTime(
                    timezone=False), onupdate=cast(
                    func.current_timestamp(), DateTime(
                        timezone=False))),
            Column('name', String(20)))
        metadata.create_all()

    @classmethod
    def teardown_class(cls):
        # Drops both tables created in setup_class.
        metadata.drop_all()

    @testing.fails_on('postgresql+zxjdbc',
                      "XXX: postgresql+zxjdbc doesn't give a tzinfo back")
    def test_with_timezone(self):

        # get a date with a tzinfo

        somedate = \
            testing.db.connect().scalar(func.current_timestamp().select())
        assert somedate.tzinfo
        tztable.insert().execute(id=1, name='row1', date=somedate)
        row = select([tztable.c.date], tztable.c.id
                     == 1).execute().first()
        eq_(row[0], somedate)
        # the UTC offset must survive the round trip through TIMESTAMPTZ
        eq_(somedate.tzinfo.utcoffset(somedate),
            row[0].tzinfo.utcoffset(row[0]))
        result = tztable.update(tztable.c.id
                                == 1).returning(tztable.c.date).\
            execute(name='newname'
                    )
        row = result.first()
        # onupdate runs current_timestamp(), so the new value is >= old
        assert row[0] >= somedate

    def test_without_timezone(self):

        # get a date without a tzinfo

        somedate = datetime.datetime(2005, 10, 20, 11, 52, 0, )
        assert not somedate.tzinfo
        notztable.insert().execute(id=1, name='row1', date=somedate)
        row = select([notztable.c.date], notztable.c.id
                     == 1).execute().first()
        eq_(row[0], somedate)
        # TIMESTAMP WITHOUT TIME ZONE must come back naive
        eq_(row[0].tzinfo, None)
        result = notztable.update(notztable.c.id
                                  == 1).returning(notztable.c.date).\
            execute(name='newname'
                    )
        row = result.first()
        assert row[0] >= somedate
654
655
class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL):
    """DDL compilation and reflection of TIME/TIMESTAMP precision."""

    __dialect__ = postgresql.dialect()
    __prefer__ = 'postgresql'
    __backend__ = True

    def test_compile(self):
        cases = [
            (postgresql.TIME(), 'TIME WITHOUT TIME ZONE'),
            (postgresql.TIME(precision=5), 'TIME(5) WITHOUT TIME ZONE'),
            (postgresql.TIME(timezone=True, precision=5),
             'TIME(5) WITH TIME ZONE'),
            (postgresql.TIMESTAMP(), 'TIMESTAMP WITHOUT TIME ZONE'),
            (postgresql.TIMESTAMP(precision=5),
             'TIMESTAMP(5) WITHOUT TIME ZONE'),
            (postgresql.TIMESTAMP(timezone=True, precision=5),
             'TIMESTAMP(5) WITH TIME ZONE'),
        ]
        for type_, ddl in cases:
            self.assert_compile(type_, ddl)

    @testing.only_on('postgresql', 'DB specific feature')
    @testing.provide_metadata
    def test_reflection(self):
        t1 = Table(
            't1',
            self.metadata,
            Column('c1', postgresql.TIME()),
            Column('c2', postgresql.TIME(precision=5)),
            Column('c3', postgresql.TIME(timezone=True, precision=5)),
            Column('c4', postgresql.TIMESTAMP()),
            Column('c5', postgresql.TIMESTAMP(precision=5)),
            Column('c6', postgresql.TIMESTAMP(timezone=True,
                                              precision=5)),
        )
        t1.create()
        reflected = Table('t1', MetaData(testing.db), autoload=True)
        # (precision, timezone) expected for each reflected column.
        expected = [
            ('c1', None, False),
            ('c2', 5, False),
            ('c3', 5, True),
            ('c4', None, False),
            ('c5', 5, False),
            ('c6', 5, True),
        ]
        for name, precision, timezone in expected:
            eq_(reflected.c[name].type.precision, precision)
            eq_(reflected.c[name].type.timezone, timezone)
707
708
709class ArrayTest(fixtures.TablesTest, AssertsExecutionResults):
710    __only_on__ = 'postgresql'
711    __backend__ = True
712    __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'
713
714    @classmethod
715    def define_tables(cls, metadata):
716
717        class ProcValue(TypeDecorator):
718            impl = postgresql.ARRAY(Integer, dimensions=2)
719
720            def process_bind_param(self, value, dialect):
721                if value is None:
722                    return None
723                return [
724                    [x + 5 for x in v]
725                    for v in value
726                ]
727
728            def process_result_value(self, value, dialect):
729                if value is None:
730                    return None
731                return [
732                    [x - 7 for x in v]
733                    for v in value
734                ]
735
736        Table('arrtable', metadata,
737              Column('id', Integer, primary_key=True),
738              Column('intarr', postgresql.ARRAY(Integer)),
739              Column('strarr', postgresql.ARRAY(Unicode())),
740              Column('dimarr', ProcValue)
741              )
742
743        Table('dim_arrtable', metadata,
744              Column('id', Integer, primary_key=True),
745              Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
746              Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
747              Column('dimarr', ProcValue)
748              )
749
750    def _fixture_456(self, table):
751        testing.db.execute(
752            table.insert(),
753            intarr=[4, 5, 6]
754        )
755
756    def test_reflect_array_column(self):
757        metadata2 = MetaData(testing.db)
758        tbl = Table('arrtable', metadata2, autoload=True)
759        assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
760        assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
761        assert isinstance(tbl.c.intarr.type.item_type, Integer)
762        assert isinstance(tbl.c.strarr.type.item_type, String)
763
764    def test_insert_array(self):
765        arrtable = self.tables.arrtable
766        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'),
767                                                            util.u('def')])
768        results = arrtable.select().execute().fetchall()
769        eq_(len(results), 1)
770        eq_(results[0]['intarr'], [1, 2, 3])
771        eq_(results[0]['strarr'], [util.u('abc'), util.u('def')])
772
773    def test_array_where(self):
774        arrtable = self.tables.arrtable
775        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'),
776                                                            util.u('def')])
777        arrtable.insert().execute(intarr=[4, 5, 6], strarr=util.u('ABC'))
778        results = arrtable.select().where(
779            arrtable.c.intarr == [
780                1,
781                2,
782                3]).execute().fetchall()
783        eq_(len(results), 1)
784        eq_(results[0]['intarr'], [1, 2, 3])
785
786    def test_array_concat(self):
787        arrtable = self.tables.arrtable
788        arrtable.insert().execute(intarr=[1, 2, 3],
789                                  strarr=[util.u('abc'), util.u('def')])
790        results = select([arrtable.c.intarr + [4, 5,
791                                               6]]).execute().fetchall()
792        eq_(len(results), 1)
793        eq_(results[0][0], [1, 2, 3, 4, 5, 6, ])
794
795    def test_array_comparison(self):
796        arrtable = self.tables.arrtable
797        arrtable.insert().execute(id=5, intarr=[1, 2, 3],
798                    strarr=[util.u('abc'), util.u('def')])
799        results = select([arrtable.c.id]).\
800                        where(arrtable.c.intarr < [4, 5, 6]).execute()\
801                        .fetchall()
802        eq_(len(results), 1)
803        eq_(results[0][0], 5)
804
805    def test_contains_override_raises(self):
806        col = Column('x', postgresql.ARRAY(Integer))
807
808        assert_raises_message(
809            NotImplementedError,
810            "Operator 'contains' is not supported on this expression",
811            lambda: 'foo' in col
812        )
813
814    def test_array_subtype_resultprocessor(self):
815        arrtable = self.tables.arrtable
816        arrtable.insert().execute(intarr=[4, 5, 6],
817                                  strarr=[[util.ue('m\xe4\xe4')], [
818                                      util.ue('m\xf6\xf6')]])
819        arrtable.insert().execute(intarr=[1, 2, 3], strarr=[
820            util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')])
821        results = \
822            arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall()
823        eq_(len(results), 2)
824        eq_(results[0]['strarr'], [util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')])
825        eq_(results[1]['strarr'],
826            [[util.ue('m\xe4\xe4')],
827             [util.ue('m\xf6\xf6')]])
828
829    def test_array_literal(self):
830        eq_(
831            testing.db.scalar(
832                select([
833                    postgresql.array([1, 2]) + postgresql.array([3, 4, 5])
834                ])
835            ), [1, 2, 3, 4, 5]
836        )
837
838    def test_array_literal_compare(self):
839        eq_(
840            testing.db.scalar(
841                select([
842                    postgresql.array([1, 2]) < [3, 4, 5]
843                ])
844                ), True
845        )
846
847    def test_array_getitem_single_type(self):
848        arrtable = self.tables.arrtable
849        is_(arrtable.c.intarr[1].type._type_affinity, Integer)
850        is_(arrtable.c.strarr[1].type._type_affinity, String)
851
852    def test_array_getitem_slice_type(self):
853        arrtable = self.tables.arrtable
854        is_(arrtable.c.intarr[1:3].type._type_affinity, postgresql.ARRAY)
855        is_(arrtable.c.strarr[1:3].type._type_affinity, postgresql.ARRAY)
856
857    def test_array_getitem_single_exec(self):
858        arrtable = self.tables.arrtable
859        self._fixture_456(arrtable)
860        eq_(
861            testing.db.scalar(select([arrtable.c.intarr[2]])),
862            5
863        )
864        testing.db.execute(
865            arrtable.update().values({arrtable.c.intarr[2]: 7})
866        )
867        eq_(
868            testing.db.scalar(select([arrtable.c.intarr[2]])),
869            7
870        )
871
872    def test_undim_array_empty(self):
873        arrtable = self.tables.arrtable
874        self._fixture_456(arrtable)
875        eq_(
876            testing.db.scalar(
877                select([arrtable.c.intarr]).
878                where(arrtable.c.intarr.contains([]))
879            ),
880            [4, 5, 6]
881        )
882
883    def test_array_getitem_slice_exec(self):
884        arrtable = self.tables.arrtable
885        testing.db.execute(
886            arrtable.insert(),
887            intarr=[4, 5, 6],
888            strarr=[util.u('abc'), util.u('def')]
889        )
890        eq_(
891            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
892            [5, 6]
893        )
894        testing.db.execute(
895            arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]})
896        )
897        eq_(
898            testing.db.scalar(select([arrtable.c.intarr[2:3]])),
899            [7, 8]
900        )
901
902    def _test_undim_array_contains_typed_exec(self, struct):
903        arrtable = self.tables.arrtable
904        self._fixture_456(arrtable)
905        eq_(
906            testing.db.scalar(
907                select([arrtable.c.intarr]).
908                where(arrtable.c.intarr.contains(struct([4, 5])))
909            ),
910            [4, 5, 6]
911        )
912
913    def test_undim_array_contains_set_exec(self):
914        self._test_undim_array_contains_typed_exec(set)
915
916    def test_undim_array_contains_list_exec(self):
917        self._test_undim_array_contains_typed_exec(list)
918
919    def test_undim_array_contains_generator_exec(self):
920        self._test_undim_array_contains_typed_exec(
921            lambda elem: (x for x in elem))
922
923    def _test_dim_array_contains_typed_exec(self, struct):
924        dim_arrtable = self.tables.dim_arrtable
925        self._fixture_456(dim_arrtable)
926        eq_(
927            testing.db.scalar(
928                select([dim_arrtable.c.intarr]).
929                where(dim_arrtable.c.intarr.contains(struct([4, 5])))
930            ),
931            [4, 5, 6]
932        )
933
934    def test_dim_array_contains_set_exec(self):
935        self._test_dim_array_contains_typed_exec(set)
936
937    def test_dim_array_contains_list_exec(self):
938        self._test_dim_array_contains_typed_exec(list)
939
940    def test_dim_array_contains_generator_exec(self):
941        self._test_dim_array_contains_typed_exec(
942            lambda elem: (
943                x for x in elem))
944
945    def test_array_contained_by_exec(self):
946        arrtable = self.tables.arrtable
947        with testing.db.connect() as conn:
948            conn.execute(
949                arrtable.insert(),
950                intarr=[6, 5, 4]
951            )
952            eq_(
953                conn.scalar(
954                    select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
955                ),
956                True
957            )
958
959    def test_array_overlap_exec(self):
960        arrtable = self.tables.arrtable
961        with testing.db.connect() as conn:
962            conn.execute(
963                arrtable.insert(),
964                intarr=[4, 5, 6]
965            )
966            eq_(
967                conn.scalar(
968                    select([arrtable.c.intarr]).
969                    where(arrtable.c.intarr.overlap([7, 6]))
970                ),
971                [4, 5, 6]
972            )
973
974    def test_array_any_exec(self):
975        arrtable = self.tables.arrtable
976        with testing.db.connect() as conn:
977            conn.execute(
978                arrtable.insert(),
979                intarr=[4, 5, 6]
980            )
981            eq_(
982                conn.scalar(
983                    select([arrtable.c.intarr]).
984                    where(postgresql.Any(5, arrtable.c.intarr))
985                ),
986                [4, 5, 6]
987            )
988
989    def test_array_all_exec(self):
990        arrtable = self.tables.arrtable
991        with testing.db.connect() as conn:
992            conn.execute(
993                arrtable.insert(),
994                intarr=[4, 5, 6]
995            )
996            eq_(
997                conn.scalar(
998                    select([arrtable.c.intarr]).
999                    where(arrtable.c.intarr.all(4, operator=operators.le))
1000                ),
1001                [4, 5, 6]
1002            )
1003
1004    @testing.provide_metadata
1005    def test_tuple_flag(self):
1006        metadata = self.metadata
1007
1008        t1 = Table(
1009            't1', metadata,
1010            Column('id', Integer, primary_key=True),
1011            Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
1012            Column(
1013                'data2',
1014                postgresql.ARRAY(
1015                    Numeric(asdecimal=False), as_tuple=True)
1016            )
1017        )
1018        metadata.create_all()
1019        testing.db.execute(
1020            t1.insert(), id=1, data=[
1021                "1", "2", "3"], data2=[
1022                5.4, 5.6])
1023        testing.db.execute(
1024            t1.insert(),
1025            id=2,
1026            data=[
1027                "4",
1028                "5",
1029                "6"],
1030            data2=[1.0])
1031        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]],
1032                           data2=[[5.4, 5.6], [1.0, 1.1]])
1033
1034        r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall()
1035        eq_(
1036            r,
1037            [
1038                (1, ('1', '2', '3'), (5.4, 5.6)),
1039                (2, ('4', '5', '6'), (1.0,)),
1040                (3, (('4', '5'), ('6', '7')), ((5.4, 5.6), (1.0, 1.1)))
1041            ]
1042        )
1043        # hashable
1044        eq_(
1045            set(row[1] for row in r),
1046            set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))])
1047        )
1048
    def test_dimension(self):
        """Round-trip a two-dimensional array value through the dimarr
        column.

        NOTE(review): the expected result differs from the inserted value
        by -2 per element — presumably the dimarr column uses a
        value-shifting TypeDecorator defined in the table fixture outside
        this view; confirm against the arrtable definition.
        """
        arrtable = self.tables.arrtable
        testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]])
        eq_(
            testing.db.scalar(select([arrtable.c.dimarr])),
            [[-1, 0, 1], [2, 3, 4]]
        )
1056
1057
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
    """Driver-level checks for TIMESTAMP and INTERVAL literal results."""

    __only_on__ = 'postgresql'
    __backend__ = True

    def test_timestamp(self):
        """A timestamp literal comes back as datetime.datetime."""
        # use a context manager so the connection is released instead of
        # being leaked for the remainder of the test run (matches the
        # ``with testing.db.connect()`` pattern used elsewhere in this file)
        with testing.db.connect() as connection:
            s = select([text("timestamp '2007-12-25'")])
            result = connection.execute(s).first()
            eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0))

    def test_interval_arithmetic(self):
        """Subtracting two timestamps yields a datetime.timedelta.

        This is basically testing that we get timedelta back for an
        INTERVAL result — more of a driver assertion than a SQL test.
        """
        with testing.db.connect() as connection:
            s = select(
                [text("timestamp '2007-12-25' - timestamp '2007-11-15'")])
            result = connection.execute(s).first()
            eq_(result[0], datetime.timedelta(40))
1079
1080
class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):

    """test DDL and reflection of PG-specific types """

    # trailing comma makes this a one-element tuple, which the testing
    # framework accepts for a version-qualified backend specification
    __only_on__ = 'postgresql >= 8.3.0',
    __backend__ = True

    @classmethod
    def setup_class(cls):
        # metadata/table are the module-level globals declared at the top
        # of this file; teardown_class drops through the same global
        global metadata, table
        metadata = MetaData(testing.db)

        # create these types so that we can issue
        # special SQL92 INTERVAL syntax
        class y2m(types.UserDefinedType, postgresql.INTERVAL):

            def get_col_spec(self):
                return "INTERVAL YEAR TO MONTH"

        class d2s(types.UserDefinedType, postgresql.INTERVAL):

            def get_col_spec(self):
                return "INTERVAL DAY TO SECOND"

        table = Table(
            'sometable', metadata,
            Column(
                'id', postgresql.UUID, primary_key=True),
            Column(
                'flag', postgresql.BIT),
            Column(
                'bitstring', postgresql.BIT(4)),
            Column('addr', postgresql.INET),
            Column('addr2', postgresql.MACADDR),
            Column('addr3', postgresql.CIDR),
            Column('doubleprec', postgresql.DOUBLE_PRECISION),
            Column('plain_interval', postgresql.INTERVAL),
            Column('year_interval', y2m()),
            Column('month_interval', d2s()),
            Column('precision_interval', postgresql.INTERVAL(
                precision=3)),
            Column('tsvector_document', postgresql.TSVECTOR))

        metadata.create_all()

        # cheat so that the "strict type check"
        # works
        table.c.year_interval.type = postgresql.INTERVAL()
        table.c.month_interval.type = postgresql.INTERVAL()

    @classmethod
    def teardown_class(cls):
        metadata.drop_all()

    def test_reflection(self):
        # reflect the table and compare against the original definition,
        # requiring exact type matches (strict_types=True)
        m = MetaData(testing.db)
        t = Table('sometable', m, autoload=True)

        self.assert_tables_equal(table, t, strict_types=True)
        assert t.c.plain_interval.type.precision is None
        assert t.c.precision_interval.type.precision == 3
        assert t.c.bitstring.type.length == 4

    def test_bit_compile(self):
        # each BIT variant should render the expected DDL string
        pairs = [(postgresql.BIT(), 'BIT(1)'),
                 (postgresql.BIT(5), 'BIT(5)'),
                 (postgresql.BIT(varying=True), 'BIT VARYING'),
                 (postgresql.BIT(5, varying=True), 'BIT VARYING(5)'),
                 ]
        for type_, expected in pairs:
            self.assert_compile(type_, expected)

    @testing.provide_metadata
    def test_tsvector_round_trip(self):
        # a plain-text insert comes back normalized into tsvector lexemes
        t = Table('t1', self.metadata, Column('data', postgresql.TSVECTOR))
        t.create()
        testing.db.execute(t.insert(), data="a fat cat sat")
        eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'")

        testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'")

        eq_(testing.db.scalar(select([t.c.data])),
            "'a' 'cat' 'fat' 'mat' 'sat'")

    @testing.provide_metadata
    def test_bit_reflection(self):
        metadata = self.metadata
        t1 = Table('t1', metadata,
                   Column('bit1', postgresql.BIT()),
                   Column('bit5', postgresql.BIT(5)),
                   Column('bitvarying', postgresql.BIT(varying=True)),
                   Column('bitvarying5', postgresql.BIT(5, varying=True)),
                   )
        t1.create()
        m2 = MetaData(testing.db)
        t2 = Table('t1', m2, autoload=True)
        # BIT defaults to length 1; BIT VARYING without a length reflects
        # back with length None
        eq_(t2.c.bit1.type.length, 1)
        eq_(t2.c.bit1.type.varying, False)
        eq_(t2.c.bit5.type.length, 5)
        eq_(t2.c.bit5.type.varying, False)
        eq_(t2.c.bitvarying.type.length, None)
        eq_(t2.c.bitvarying.type.varying, True)
        eq_(t2.c.bitvarying5.type.length, 5)
        eq_(t2.c.bitvarying5.type.varying, True)
1185
1186
class UUIDTest(fixtures.TestBase):

    """Test the bind/return values of the UUID type."""

    __only_on__ = 'postgresql >= 8.3'
    __backend__ = True

    @testing.fails_on(
        'postgresql+zxjdbc',
        'column "data" is of type uuid but expression '
        'is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_string(self):
        # as_uuid=False: values are bound and returned as plain strings
        import uuid
        self._test_round_trip(
            Table('utable', MetaData(),
                  Column('data', postgresql.UUID(as_uuid=False))
                  ),
            str(uuid.uuid4()),
            str(uuid.uuid4())
        )

    @testing.fails_on(
        'postgresql+zxjdbc',
        'column "data" is of type uuid but expression is '
        'of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_uuid(self):
        # as_uuid=True: values round-trip as uuid.UUID objects
        import uuid
        self._test_round_trip(
            Table('utable', MetaData(),
                  Column('data', postgresql.UUID(as_uuid=True))
                  ),
            uuid.uuid4(),
            uuid.uuid4()
        )

    @testing.fails_on('postgresql+zxjdbc',
                      'column "data" is of type uuid[] but '
                      'expression is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_array(self):
        # ARRAY of uuid.UUID objects
        import uuid
        self._test_round_trip(
            Table(
                'utable', MetaData(),
                Column('data', postgresql.ARRAY(postgresql.UUID(as_uuid=True)))
            ),
            [uuid.uuid4(), uuid.uuid4()],
            [uuid.uuid4(), uuid.uuid4()],
        )

    @testing.fails_on('postgresql+zxjdbc',
                      'column "data" is of type uuid[] but '
                      'expression is of type character varying')
    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
    def test_uuid_string_array(self):
        # ARRAY of UUIDs represented as strings
        import uuid
        self._test_round_trip(
            Table(
                'utable', MetaData(),
                Column(
                    'data',
                    postgresql.ARRAY(postgresql.UUID(as_uuid=False)))
            ),
            [str(uuid.uuid4()), str(uuid.uuid4())],
            [str(uuid.uuid4()), str(uuid.uuid4())],
        )

    def test_no_uuid_available(self):
        # simulate the stdlib uuid module being unavailable: constructing
        # UUID(as_uuid=True) must then raise NotImplementedError
        from sqlalchemy.dialects.postgresql import base
        uuid_type = base._python_UUID
        base._python_UUID = None
        try:
            assert_raises(
                NotImplementedError,
                postgresql.UUID, as_uuid=True
            )
        finally:
            # always restore the module-level reference
            base._python_UUID = uuid_type

    def setup(self):
        self.conn = testing.db.connect()
        # NOTE(review): the Transaction object is not retained; presumably
        # closing the connection in teardown rolls the work back — confirm
        trans = self.conn.begin()

    def teardown(self):
        self.conn.close()

    def _test_round_trip(self, utable, value1, value2, exp_value2=None):
        """Create *utable*, insert both values, and verify that selecting
        rows != value1 returns exactly value2 (or *exp_value2* if given)
        and nothing else."""
        utable.create(self.conn)
        self.conn.execute(utable.insert(), {'data': value1})
        self.conn.execute(utable.insert(), {'data': value2})
        r = self.conn.execute(
            select([utable.c.data]).
            where(utable.c.data != value1)
        )
        if exp_value2:
            eq_(r.fetchone()[0], exp_value2)
        else:
            eq_(r.fetchone()[0], value2)
        # no further rows expected
        eq_(r.fetchone(), None)
1288
1289
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
    """Compile-time and bind/result-processor tests for the HSTORE type.

    No database is used; expressions are compiled against the PostgreSQL
    dialect and compared to exact SQL strings.
    """

    __dialect__ = 'postgresql'

    def setup(self):
        # a throwaway table whose "hash" column drives all expressions
        metadata = MetaData()
        self.test_table = Table('test_table', metadata,
                                Column('id', Integer, primary_key=True),
                                Column('hash', HSTORE)
                                )
        self.hashcol = self.test_table.c.hash

    def _test_where(self, whereclause, expected):
        """Assert that *whereclause* compiles to the *expected* SQL text
        within a full SELECT."""
        stmt = select([self.test_table]).where(whereclause)
        self.assert_compile(
            stmt,
            "SELECT test_table.id, test_table.hash FROM test_table "
            "WHERE %s" % expected
        )

    def _test_cols(self, colclause, expected, from_=True):
        """Assert that *colclause* compiles to *expected*; *from_* controls
        whether a FROM clause is expected in the rendered SELECT."""
        stmt = select([colclause])
        self.assert_compile(
            stmt,
            (
                "SELECT %s" +
                (" FROM test_table" if from_ else "")
            ) % expected
        )

    def test_bind_serialize_default(self):
        # default (non-native) bind processor serializes dicts to the
        # textual hstore format
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
            '"key1"=>"value1", "key2"=>"value2"'
        )

    def test_bind_serialize_with_slashes_and_quotes(self):
        # backslashes and double quotes must be escaped on the way in
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc({'\\"a': '\\"1'}),
            '"\\\\\\"a"=>"\\\\\\"1"'
        )

    def test_parse_error(self):
        # unparseable residue in an hstore string raises ValueError with a
        # position-annotated message
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        assert_raises_message(
            ValueError,
            r'''After u?'\[\.\.\.\], "key1"=>"value1", ', could not parse '''
            '''residual at position 36: u?'crapcrapcrap, "key3"\[\.\.\.\]''',
            proc,
            '"key2"=>"value2", "key1"=>"value1", '
            'crapcrapcrap, "key3"=>"value3"'
        )

    def test_result_deserialize_default(self):
        # default result processor parses the textual format back to a dict
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"key2"=>"value2", "key1"=>"value1"'),
            {"key1": "value1", "key2": "value2"}
        )

    def test_result_deserialize_with_slashes_and_quotes(self):
        # escaped backslashes/quotes are unescaped on the way out
        dialect = postgresql.dialect()
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"\\\\\\"a"=>"\\\\\\"1"'),
            {'\\"a': '\\"1'}
        )

    def test_bind_serialize_psycopg2(self):
        # with native hstore support the bind processor is skipped (None);
        # without it, the textual serializer is used
        from sqlalchemy.dialects.postgresql import psycopg2

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = True
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        is_(proc, None)

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = False
        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
        eq_(
            proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
            '"key1"=>"value1", "key2"=>"value2"'
        )

    def test_result_deserialize_psycopg2(self):
        # mirror of the bind-processor test for the result side
        from sqlalchemy.dialects.postgresql import psycopg2

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = True
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        is_(proc, None)

        dialect = psycopg2.PGDialect_psycopg2()
        dialect._has_native_hstore = False
        proc = self.test_table.c.hash.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('"key2"=>"value2", "key1"=>"value1"'),
            {"key1": "value1", "key2": "value2"}
        )

    # --- WHERE-clause operator rendering ---

    def test_where_has_key(self):
        self._test_where(
            # hide from 2to3
            getattr(self.hashcol, 'has_key')('foo'),
            "test_table.hash ? %(hash_1)s"
        )

    def test_where_has_all(self):
        self._test_where(
            self.hashcol.has_all(postgresql.array(['1', '2'])),
            "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]"
        )

    def test_where_has_any(self):
        self._test_where(
            self.hashcol.has_any(postgresql.array(['1', '2'])),
            "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]"
        )

    def test_where_defined(self):
        self._test_where(
            self.hashcol.defined('foo'),
            "defined(test_table.hash, %(param_1)s)"
        )

    def test_where_contains(self):
        self._test_where(
            self.hashcol.contains({'foo': '1'}),
            "test_table.hash @> %(hash_1)s"
        )

    def test_where_contained_by(self):
        self._test_where(
            self.hashcol.contained_by({'foo': '1', 'bar': None}),
            "test_table.hash <@ %(hash_1)s"
        )

    def test_where_getitem(self):
        self._test_where(
            self.hashcol['bar'] == None,
            "(test_table.hash -> %(hash_1)s) IS NULL"
        )

    # --- column-clause function/operator rendering ---

    def test_cols_get(self):
        self._test_cols(
            self.hashcol['foo'],
            "test_table.hash -> %(hash_1)s AS anon_1",
            True
        )

    def test_cols_delete_single_key(self):
        self._test_cols(
            self.hashcol.delete('foo'),
            "delete(test_table.hash, %(param_1)s) AS delete_1",
            True
        )

    def test_cols_delete_array_of_keys(self):
        self._test_cols(
            self.hashcol.delete(postgresql.array(['foo', 'bar'])),
            ("delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
             "AS delete_1"),
            True
        )

    def test_cols_delete_matching_pairs(self):
        self._test_cols(
            self.hashcol.delete(hstore('1', '2')),
            ("delete(test_table.hash, hstore(%(param_1)s, %(param_2)s)) "
             "AS delete_1"),
            True
        )

    def test_cols_slice(self):
        self._test_cols(
            self.hashcol.slice(postgresql.array(['1', '2'])),
            ("slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
             "AS slice_1"),
            True
        )

    def test_cols_hstore_pair_text(self):
        self._test_cols(
            hstore('foo', '3')['foo'],
            "hstore(%(param_1)s, %(param_2)s) -> %(hstore_1)s AS anon_1",
            False
        )

    def test_cols_hstore_pair_array(self):
        self._test_cols(
            hstore(postgresql.array(['1', '2']),
                   postgresql.array(['3', None]))['1'],
            ("hstore(ARRAY[%(param_1)s, %(param_2)s], "
             "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"),
            False
        )

    def test_cols_hstore_single_array(self):
        self._test_cols(
            hstore(postgresql.array(['1', '2', '3', None]))['3'],
            ("hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) "
             "-> %(hstore_1)s AS anon_1"),
            False
        )

    def test_cols_concat(self):
        self._test_cols(
            self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), '3')),
            ("test_table.hash || hstore(CAST(test_table.id AS TEXT), "
             "%(param_1)s) AS anon_1"),
            True
        )

    def test_cols_concat_op(self):
        self._test_cols(
            hstore('foo', 'bar') + self.hashcol,
            "hstore(%(param_1)s, %(param_2)s) || test_table.hash AS anon_1",
            True
        )

    def test_cols_concat_get(self):
        self._test_cols(
            (self.hashcol + self.hashcol)['foo'],
            "test_table.hash || test_table.hash -> %(param_1)s AS anon_1"
        )

    def test_cols_keys(self):
        self._test_cols(
            # hide from 2to3
            getattr(self.hashcol, 'keys')(),
            "akeys(test_table.hash) AS akeys_1",
            True
        )

    def test_cols_vals(self):
        self._test_cols(
            self.hashcol.vals(),
            "avals(test_table.hash) AS avals_1",
            True
        )

    def test_cols_array(self):
        self._test_cols(
            self.hashcol.array(),
            "hstore_to_array(test_table.hash) AS hstore_to_array_1",
            True
        )

    def test_cols_matrix(self):
        self._test_cols(
            self.hashcol.matrix(),
            "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1",
            True
        )
1555
1556
class HStoreRoundTripTest(fixtures.TablesTest):
    """Round-trip HSTORE data through a live database, exercising both the
    native psycopg2 hstore support and the pure-Python fallback."""

    __requires__ = 'hstore',
    __dialect__ = 'postgresql'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30), nullable=False),
              Column('data', HSTORE)
              )

    def _fixture_data(self, engine):
        """Populate data_table with five rows of two-key hstore data."""
        data_table = self.tables.data_table
        engine.execute(
            data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
            {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
            {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
            {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
            {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
        )

    def _assert_data(self, compare):
        """Assert the 'data' column, ordered by name, equals *compare*."""
        data = testing.db.execute(
            select([self.tables.data_table.c.data]).
            order_by(self.tables.data_table.c.name)
        ).fetchall()
        eq_([d for d, in data], compare)

    def _test_insert(self, engine):
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
        )
        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])

    def _non_native_engine(self):
        """Return an engine with native hstore disabled, falling back to
        testing.db when native support isn't available anyway."""
        if testing.requires.psycopg2_native_hstore.enabled:
            engine = engines.testing_engine(
                options=dict(
                    use_native_hstore=False))
        else:
            engine = testing.db
        # verify connectivity before handing the engine back
        engine.connect().close()
        return engine

    def test_reflect(self):
        # reflection should identify the third column as HSTORE
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        assert isinstance(cols[2]['type'], HSTORE)

    @testing.requires.psycopg2_native_hstore
    def test_insert_native(self):
        engine = testing.db
        self._test_insert(engine)

    def test_insert_python(self):
        engine = self._non_native_engine()
        self._test_insert(engine)

    @testing.requires.psycopg2_native_hstore
    def test_criterion_native(self):
        engine = testing.db
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_criterion_python(self):
        engine = self._non_native_engine()
        self._fixture_data(engine)
        self._test_criterion(engine)

    def _test_criterion(self, engine):
        """Filter rows via the hstore '->' (getitem) operator."""
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data['k1'] == 'r3v1')).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

    def _test_fixed_round_trip(self, engine):
        # build an hstore from parallel key/value arrays server-side
        s = select([
            hstore(
                array(['key1', 'key2', 'key3']),
                array(['value1', 'value2', 'value3'])
            )
        ])
        eq_(
            engine.scalar(s),
            {"key1": "value1", "key2": "value2", "key3": "value3"}
        )

    def test_fixed_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_fixed_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_fixed_round_trip_native(self):
        engine = testing.db
        self._test_fixed_round_trip(engine)

    def _test_unicode_round_trip(self, engine):
        # non-ASCII keys and values must survive the round trip intact
        s = select([
            hstore(
                array([util.u('réveillé'), util.u('drôle'), util.u('S’il')]),
                array([util.u('réveillé'), util.u('drôle'), util.u('S’il')])
            )
        ])
        eq_(
            engine.scalar(s),
            {
                util.u('réveillé'): util.u('réveillé'),
                util.u('drôle'): util.u('drôle'),
                util.u('S’il'): util.u('S’il')
            }
        )

    @testing.requires.psycopg2_native_hstore
    def test_unicode_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_unicode_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_unicode_round_trip_native(self):
        engine = testing.db
        self._test_unicode_round_trip(engine)

    def test_escaped_quotes_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_escaped_quotes_round_trip(engine)

    @testing.requires.psycopg2_native_hstore
    def test_escaped_quotes_round_trip_native(self):
        engine = testing.db
        self._test_escaped_quotes_round_trip(engine)

    def _test_escaped_quotes_round_trip(self, engine):
        # backslashes and embedded quotes must escape and unescape cleanly
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {r'key \"foo\"': r'value \"bar"\ xyz'}}
        )
        self._assert_data([{r'key \"foo\"': r'value \"bar"\ xyz'}])

    def test_orm_round_trip(self):
        from sqlalchemy import orm

        class Data(object):

            def __init__(self, name, data):
                self.name = name
                self.data = data
        orm.mapper(Data, self.tables.data_table)
        s = orm.Session(testing.db)
        d = Data(name='r1', data={"key1": "value1", "key2": "value2",
                                  "key3": "value3"})
        s.add(d)
        # querying autoflushes the pending object and returns it back
        eq_(
            s.query(Data.data, Data).all(),
            [(d.data, d)]
        )
1717
1718
class _RangeTypeMixin(object):
    """Common test suite for the PostgreSQL range types.

    Subclasses provide ``_col_type`` (the SQLAlchemy range type class),
    ``_col_str`` (its expected string form), ``_data_str`` (a textual
    range literal) and ``_data_obj()`` (the equivalent psycopg2 range
    object).
    """

    __requires__ = 'range_types', 'psycopg2_compatibility'
    __backend__ = True

    def extras(self):
        """Return the psycopg2 (or psycopg2cffi) ``extras`` module.

        Done as a lazy import so we don't get ImportErrors with
        older psycopg2 versions.
        """
        if testing.against("postgresql+psycopg2cffi"):
            from psycopg2cffi import extras
        else:
            from psycopg2 import extras
        return extras

    @classmethod
    def define_tables(cls, metadata):
        # no reason ranges shouldn't be primary keys,
        # so lets just use them as such
        table = Table('data_table', metadata,
                      Column('range', cls._col_type, primary_key=True),
                      )
        # expose the column for the operator-compilation tests below
        cls.col = table.c.range

    def test_actual_type(self):
        eq_(str(self._col_type()), self._col_str)

    def test_reflect(self):
        # 'inspect' is already imported at module level; no need to
        # re-import it locally here.
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        assert isinstance(cols[0]['type'], self._col_type)

    def _assert_data(self):
        """Assert the table contains exactly the expected range object."""
        data = testing.db.execute(
            select([self.tables.data_table.c.range])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    def test_insert_obj(self):
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_obj()}
        )
        self._assert_data()

    def test_insert_text(self):
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        self._assert_data()

    # operator tests

    def _test_clause(self, colclause, expected):
        """Assert *colclause* compiles to *expected* under the PG dialect."""
        dialect = postgresql.dialect()
        compiled = str(colclause.compile(dialect=dialect))
        eq_(compiled, expected)

    def test_where_equal(self):
        self._test_clause(
            self.col == self._data_str,
            "data_table.range = %(range_1)s"
        )

    def test_where_not_equal(self):
        self._test_clause(
            self.col != self._data_str,
            "data_table.range <> %(range_1)s"
        )

    def test_where_less_than(self):
        self._test_clause(
            self.col < self._data_str,
            "data_table.range < %(range_1)s"
        )

    def test_where_greater_than(self):
        self._test_clause(
            self.col > self._data_str,
            "data_table.range > %(range_1)s"
        )

    def test_where_less_than_or_equal(self):
        self._test_clause(
            self.col <= self._data_str,
            "data_table.range <= %(range_1)s"
        )

    def test_where_greater_than_or_equal(self):
        self._test_clause(
            self.col >= self._data_str,
            "data_table.range >= %(range_1)s"
        )

    def test_contains(self):
        self._test_clause(
            self.col.contains(self._data_str),
            "data_table.range @> %(range_1)s"
        )

    def test_contained_by(self):
        self._test_clause(
            self.col.contained_by(self._data_str),
            "data_table.range <@ %(range_1)s"
        )

    def test_overlaps(self):
        self._test_clause(
            self.col.overlaps(self._data_str),
            "data_table.range && %(range_1)s"
        )

    def test_strictly_left_of(self):
        # both the << operator and its named form compile identically
        self._test_clause(
            self.col << self._data_str,
            "data_table.range << %(range_1)s"
        )
        self._test_clause(
            self.col.strictly_left_of(self._data_str),
            "data_table.range << %(range_1)s"
        )

    def test_strictly_right_of(self):
        # both the >> operator and its named form compile identically
        self._test_clause(
            self.col >> self._data_str,
            "data_table.range >> %(range_1)s"
        )
        self._test_clause(
            self.col.strictly_right_of(self._data_str),
            "data_table.range >> %(range_1)s"
        )

    def test_not_extend_right_of(self):
        self._test_clause(
            self.col.not_extend_right_of(self._data_str),
            "data_table.range &< %(range_1)s"
        )

    def test_not_extend_left_of(self):
        self._test_clause(
            self.col.not_extend_left_of(self._data_str),
            "data_table.range &> %(range_1)s"
        )

    def test_adjacent_to(self):
        self._test_clause(
            self.col.adjacent_to(self._data_str),
            "data_table.range -|- %(range_1)s"
        )

    def test_union(self):
        self._test_clause(
            self.col + self.col,
            "data_table.range + data_table.range"
        )

    def test_union_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select; local renamed so it doesn't shadow the builtin ``range``
        range_col = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range_col + range_col])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    def test_intersection(self):
        self._test_clause(
            self.col * self.col,
            "data_table.range * data_table.range"
        )

    def test_intersection_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select; local renamed so it doesn't shadow the builtin ``range``
        range_col = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range_col * range_col])
        ).fetchall()
        eq_(data, [(self._data_obj(), )])

    def test_difference(self):
        # renamed from "test_different" for consistency with
        # test_difference_result below
        self._test_clause(
            self.col - self.col,
            "data_table.range - data_table.range"
        )

    def test_difference_result(self):
        # insert
        testing.db.engine.execute(
            self.tables.data_table.insert(),
            {'range': self._data_str}
        )
        # select; local renamed so it doesn't shadow the builtin ``range``
        range_col = self.tables.data_table.c.range
        data = testing.db.execute(
            select([range_col - range_col])
        ).fetchall()
        # a range minus itself yields the empty range
        eq_(data, [(self._data_obj().__class__(empty=True), )])
1925
1926
class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against INT4RANGE."""

    _col_type = INT4RANGE
    _col_str = 'INT4RANGE'
    _data_str = '[1,2)'

    def _data_obj(self):
        extras = self.extras()
        return extras.NumericRange(1, 2)
1935
1936
class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against INT8RANGE, using values
    near the top of the 64-bit integer range."""

    _col_type = INT8RANGE
    _col_str = 'INT8RANGE'
    _data_str = '[9223372036854775806,9223372036854775807)'

    def _data_obj(self):
        extras = self.extras()
        return extras.NumericRange(9223372036854775806, 9223372036854775807)
1947
1948
class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against NUMRANGE."""

    _col_type = NUMRANGE
    _col_str = 'NUMRANGE'
    _data_str = '[1.0,2.0)'

    def _data_obj(self):
        extras = self.extras()
        lower = decimal.Decimal('1.0')
        upper = decimal.Decimal('2.0')
        return extras.NumericRange(lower, upper)
1959
1960
class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against DATERANGE."""

    _col_type = DATERANGE
    _col_str = 'DATERANGE'
    _data_str = '[2013-03-23,2013-03-24)'

    def _data_obj(self):
        extras = self.extras()
        lower = datetime.date(2013, 3, 23)
        upper = datetime.date(2013, 3, 24)
        return extras.DateRange(lower, upper)
1971
1972
class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against TSRANGE."""

    _col_type = TSRANGE
    _col_str = 'TSRANGE'
    _data_str = '[2013-03-23 14:30,2013-03-23 23:30)'

    def _data_obj(self):
        extras = self.extras()
        lower = datetime.datetime(2013, 3, 23, 14, 30)
        upper = datetime.datetime(2013, 3, 23, 23, 30)
        return extras.DateTimeRange(lower, upper)
1984
1985
class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
    """Exercise the shared range suite against TSTZRANGE."""

    _col_type = TSTZRANGE
    _col_str = 'TSTZRANGE'

    # make sure we use one, steady timestamp-with-timezone pair
    # for all parts of all these tests; memoized on first use
    _tstzs = None

    def tstzs(self):
        if self._tstzs is None:
            lower = testing.db.scalar(func.current_timestamp().select())
            upper = lower + datetime.timedelta(1)
            self._tstzs = (lower, upper)
        return self._tstzs

    @property
    def _data_str(self):
        lower, upper = self.tstzs()
        return '[%s,%s)' % (lower, upper)

    def _data_obj(self):
        lower, upper = self.tstzs()
        return self.extras().DateTimeTZRange(lower, upper)
2010
2011
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
    """Compile-only tests for the PG JSON type: bind/result processors
    and the ``->`` / ``->>`` / ``#>`` / ``#>>`` operator family.

    No database is used; everything is checked against the compiled
    SQL string or the processor callables directly.
    """

    __dialect__ = 'postgresql'

    def setup(self):
        # a fresh single-JSON-column table fixture for each test
        metadata = MetaData()
        self.test_table = Table('test_table', metadata,
                                Column('id', Integer, primary_key=True),
                                Column('test_column', JSON),
                                )
        self.jsoncol = self.test_table.c.test_column

    def _test_where(self, whereclause, expected):
        # assert the compiled WHERE clause for *whereclause*
        stmt = select([self.test_table]).where(whereclause)
        self.assert_compile(
            stmt,
            "SELECT test_table.id, test_table.test_column FROM test_table "
            "WHERE %s" % expected
        )

    def _test_cols(self, colclause, expected, from_=True):
        # assert the compiled columns clause; from_=False drops the
        # FROM portion for standalone expressions
        stmt = select([colclause])
        self.assert_compile(
            stmt,
            (
                "SELECT %s" +
                (" FROM test_table" if from_ else "")
            ) % expected
        )

    def test_bind_serialize_default(self):
        # default bind processor serializes dicts to JSON text
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_bind_processor(
            dialect)
        eq_(
            proc({"A": [1, 2, 3, True, False]}),
            '{"A": [1, 2, 3, true, false]}'
        )

    def test_bind_serialize_None(self):
        # by default, Python None becomes the JSON 'null' value
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_bind_processor(
            dialect)
        eq_(
            proc(None),
            'null'
        )

    def test_bind_serialize_none_as_null(self):
        # with none_as_null=True, both None and null() pass through as
        # None, i.e. SQL NULL
        dialect = postgresql.dialect()
        proc = JSON(none_as_null=True)._cached_bind_processor(
            dialect)
        eq_(
            proc(None),
            None
        )
        eq_(
            proc(null()),
            None
        )

    def test_bind_serialize_null(self):
        # an explicit null() always renders as SQL NULL
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_bind_processor(
            dialect)
        eq_(
            proc(null()),
            None
        )

    def test_result_deserialize_default(self):
        # default result processor parses JSON text back to Python
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('{"A": [1, 2, 3, true, false]}'),
            {"A": [1, 2, 3, True, False]}
        )

    def test_result_deserialize_null(self):
        # the JSON 'null' value comes back as Python None
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_result_processor(
            dialect, None)
        eq_(
            proc('null'),
            None
        )

    def test_result_deserialize_None(self):
        # SQL NULL (None) passes through unchanged
        dialect = postgresql.dialect()
        proc = self.test_table.c.test_column.type._cached_result_processor(
            dialect, None)
        eq_(
            proc(None),
            None
        )

    # This test is a bit misleading -- in real life you will need to cast to
    # do anything
    def test_where_getitem(self):
        self._test_where(
            self.jsoncol['bar'] == None,
            "(test_table.test_column -> %(test_column_1)s) IS NULL"
        )

    def test_where_path(self):
        # a tuple index compiles to the #> path operator
        self._test_where(
            self.jsoncol[("foo", 1)] == None,
            "(test_table.test_column #> %(test_column_1)s) IS NULL"
        )

    def test_where_getitem_as_text(self):
        # .astext switches to the text-returning ->> operator
        self._test_where(
            self.jsoncol['bar'].astext == None,
            "(test_table.test_column ->> %(test_column_1)s) IS NULL"
        )

    def test_where_getitem_as_cast(self):
        # .cast() wraps the text form in a SQL CAST
        self._test_where(
            self.jsoncol['bar'].cast(Integer) == 5,
            "CAST(test_table.test_column ->> %(test_column_1)s AS INTEGER) "
            "= %(param_1)s"
        )

    def test_where_path_as_text(self):
        # path index + .astext compiles to the #>> operator
        self._test_where(
            self.jsoncol[("foo", 1)].astext == None,
            "(test_table.test_column #>> %(test_column_1)s) IS NULL"
        )

    def test_cols_get(self):
        self._test_cols(
            self.jsoncol['foo'],
            "test_table.test_column -> %(test_column_1)s AS anon_1",
            True
        )
2147
2148
class JSONRoundTripTest(fixtures.TablesTest):
    """Round-trip tests for the JSON type against a live PG database.

    Each scenario runs twice where possible: once with psycopg2's
    native json handling ("native") and once with an engine rigged so
    SQLAlchemy's Python-level deserializer does the work ("python").
    Subclassed by JSONBRoundTripTest, which swaps in JSONB.
    """

    __only_on__ = ('postgresql >= 9.3',)
    __backend__ = True

    # the type under test; overridden in the JSONB subclass
    test_type = JSON

    @classmethod
    def define_tables(cls, metadata):
        Table('data_table', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30), nullable=False),
              Column('data', cls.test_type),
              # none_as_null=True: Python None stored as SQL NULL
              Column('nulldata', cls.test_type(none_as_null=True))
              )

    def _fixture_data(self, engine):
        # five rows of baseline data used by the criterion/path tests
        data_table = self.tables.data_table
        engine.execute(
            data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
            {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
            {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
            {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
            {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
        )

    def _assert_data(self, compare, column='data'):
        # fetch the named column ordered by name and compare the values
        col = self.tables.data_table.c[column]

        data = testing.db.execute(
            select([col]).
            order_by(self.tables.data_table.c.name)
        ).fetchall()
        eq_([d for d, in data], compare)

    def _assert_column_is_NULL(self, column='data'):
        # assert exactly one row whose column is SQL NULL
        col = self.tables.data_table.c[column]

        data = testing.db.execute(
            select([col]).
            where(col.is_(null()))
        ).fetchall()
        eq_([d for d, in data], [None])

    def _test_insert(self, engine):
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
        )
        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])

    def _test_insert_nulls(self, engine):
        # null() renders as SQL NULL regardless of none_as_null
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'data': null()}
        )
        self._assert_data([None])

    def _test_insert_none_as_null(self, engine):
        # on the none_as_null column, Python None stores as SQL NULL
        engine.execute(
            self.tables.data_table.insert(),
            {'name': 'r1', 'nulldata': None}
        )
        self._assert_column_is_NULL(column='nulldata')

    def _non_native_engine(self, json_serializer=None, json_deserializer=None):
        """Return an engine whose JSON handling runs at the Python level.

        On psycopg2 >= 2.5, a connect listener disables the dialect's
        native-json flag and registers a pass-through loads() so rows
        come back as raw strings for SQLAlchemy to deserialize.
        """
        if json_serializer is not None or json_deserializer is not None:
            options = {
                "json_serializer": json_serializer,
                "json_deserializer": json_deserializer
            }
        else:
            options = {}

        if testing.against("postgresql+psycopg2") and \
                testing.db.dialect.psycopg2_version >= (2, 5):
            from psycopg2.extras import register_default_json
            engine = engines.testing_engine(options=options)

            @event.listens_for(engine, "connect")
            def connect(dbapi_connection, connection_record):
                # force SQLAlchemy's own deserializer to run
                engine.dialect._has_native_json = False

                def pass_(value):
                    return value
                register_default_json(dbapi_connection, loads=pass_)
        elif options:
            engine = engines.testing_engine(options=options)
        else:
            engine = testing.db
        # ensure the connect listener has fired at least once
        engine.connect().close()
        return engine

    def test_reflect(self):
        insp = inspect(testing.db)
        cols = insp.get_columns('data_table')
        assert isinstance(cols[2]['type'], self.test_type)

    @testing.requires.psycopg2_native_json
    def test_insert_native(self):
        engine = testing.db
        self._test_insert(engine)

    @testing.requires.psycopg2_native_json
    def test_insert_native_nulls(self):
        engine = testing.db
        self._test_insert_nulls(engine)

    @testing.requires.psycopg2_native_json
    def test_insert_native_none_as_null(self):
        engine = testing.db
        self._test_insert_none_as_null(engine)

    def test_insert_python(self):
        engine = self._non_native_engine()
        self._test_insert(engine)

    def test_insert_python_nulls(self):
        engine = self._non_native_engine()
        self._test_insert_nulls(engine)

    def test_insert_python_none_as_null(self):
        engine = self._non_native_engine()
        self._test_insert_none_as_null(engine)

    def _test_custom_serialize_deserialize(self, native):
        """Verify user-supplied json_serializer/json_deserializer run:
        dumps() tags 'x' on the way in, loads() tags it on the way out."""
        import json

        def loads(value):
            value = json.loads(value)
            value['x'] = value['x'] + '_loads'
            return value

        def dumps(value):
            value = dict(value)
            value['x'] = 'dumps_y'
            return json.dumps(value)

        if native:
            engine = engines.testing_engine(options=dict(
                json_serializer=dumps,
                json_deserializer=loads
            ))
        else:
            engine = self._non_native_engine(
                json_serializer=dumps,
                json_deserializer=loads
            )

        s = select([
            cast(
                {
                    "key": "value",
                    "x": "q"
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                "key": "value",
                "x": "dumps_y_loads"
            },
        )

    @testing.requires.psycopg2_native_json
    def test_custom_native(self):
        self._test_custom_serialize_deserialize(True)

    @testing.requires.psycopg2_native_json
    def test_custom_python(self):
        self._test_custom_serialize_deserialize(False)

    @testing.requires.psycopg2_native_json
    def test_criterion_native(self):
        engine = testing.db
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_criterion_python(self):
        engine = self._non_native_engine()
        self._fixture_data(engine)
        self._test_criterion(engine)

    def test_path_query(self):
        # query via the #>> path-as-text operator
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data[('k1',)].astext == 'r3v1'
            )
        ).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

    def test_query_returned_as_text(self):
        # ->> yields text, not a deserialized structure
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data['k1'].astext])
        ).first()
        assert isinstance(result[0], util.text_type)

    def test_query_returned_as_int(self):
        # casting the extracted text to INTEGER returns an int
        engine = testing.db
        self._fixture_data(engine)
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data['k3'].cast(Integer)]).where(
                data_table.c.name == 'r5')
        ).first()
        assert isinstance(result[0], int)

    def _test_criterion(self, engine):
        # filter on an extracted element compared as text
        data_table = self.tables.data_table
        result = engine.execute(
            select([data_table.c.data]).where(
                data_table.c.data['k1'].astext == 'r3v1'
            )
        ).first()
        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))

    def _test_fixed_round_trip(self, engine):
        # a nested structure survives a cast round trip unchanged
        s = select([
            cast(
                {
                    "key": "value",
                    "key2": {"k1": "v1", "k2": "v2"}
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                "key": "value",
                "key2": {"k1": "v1", "k2": "v2"}
            },
        )

    def test_fixed_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_fixed_round_trip(engine)

    @testing.requires.psycopg2_native_json
    def test_fixed_round_trip_native(self):
        engine = testing.db
        self._test_fixed_round_trip(engine)

    def _test_unicode_round_trip(self, engine):
        # non-ASCII keys and values survive the round trip
        s = select([
            cast(
                {
                    util.u('réveillé'): util.u('réveillé'),
                    "data": {"k1": util.u('drôle')}
                },
                self.test_type
            )
        ])
        eq_(
            engine.scalar(s),
            {
                util.u('réveillé'): util.u('réveillé'),
                "data": {"k1": util.u('drôle')}
            },
        )

    def test_unicode_round_trip_python(self):
        engine = self._non_native_engine()
        self._test_unicode_round_trip(engine)

    @testing.requires.psycopg2_native_json
    def test_unicode_round_trip_native(self):
        engine = testing.db
        self._test_unicode_round_trip(engine)
2426
2427
class JSONBTest(JSONTest):
    """Re-runs the JSON compilation tests against JSONB, and adds the
    operators that exist only for JSONB (?, ?&, ?|, @>, <@)."""

    def setup(self):
        meta = MetaData()
        self.test_table = Table(
            'test_table', meta,
            Column('id', Integer, primary_key=True),
            Column('test_column', JSONB)
        )
        self.jsoncol = self.test_table.c.test_column

    # Note - add fixture data for arrays []

    def test_where_has_key(self):
        # getattr() keeps the has_key name hidden from 2to3
        clause = getattr(self.jsoncol, 'has_key')('data')
        self._test_where(
            clause,
            "test_table.test_column ? %(test_column_1)s"
        )

    def test_where_has_all(self):
        clause = self.jsoncol.has_all(
            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}})
        self._test_where(
            clause,
            "test_table.test_column ?& %(test_column_1)s")

    def test_where_has_any(self):
        clause = self.jsoncol.has_any(postgresql.array(['name', 'data']))
        self._test_where(
            clause,
            "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]"
        )

    def test_where_contains(self):
        clause = self.jsoncol.contains({"k1": "r1v1"})
        self._test_where(
            clause,
            "test_table.test_column @> %(test_column_1)s"
        )

    def test_where_contained_by(self):
        clause = self.jsoncol.contained_by({'foo': '1', 'bar': None})
        self._test_where(
            clause,
            "test_table.test_column <@ %(test_column_1)s"
        )
2470
2471
class JSONBRoundTripTest(JSONRoundTripTest):
    """Runs the full JSON round-trip suite against the JSONB type.

    The unicode round trips additionally require a UTF-8 server
    encoding, hence the extra decorators below.
    """

    __only_on__ = ('postgresql >= 9.4',)
    __requires__ = ('postgresql_jsonb', )

    test_type = JSONB

    @testing.requires.postgresql_utf8_server_encoding
    def test_unicode_round_trip_python(self):
        JSONRoundTripTest.test_unicode_round_trip_python(self)

    @testing.requires.postgresql_utf8_server_encoding
    def test_unicode_round_trip_native(self):
        JSONRoundTripTest.test_unicode_round_trip_native(self)
2485