1from sqlalchemy.testing import eq_
2import sqlalchemy as sa
3from sqlalchemy import testing
4from sqlalchemy import Integer, String, ForeignKey, MetaData, func
5from sqlalchemy.testing.schema import Table
6from sqlalchemy.testing.schema import Column
7from sqlalchemy.orm import mapper, relationship, create_session
8from sqlalchemy.testing import eq_
9from sqlalchemy.testing import fixtures
10from test.orm import _fixtures
11
12
class GenerativeQueryTest(fixtures.MappedTest):
    """Tests for generative ``Query`` operations on a single mapped table.

    The ``foo`` fixture holds 100 rows where ``bar`` runs from 0 to 99 and
    ``range`` is ``bar % 10``.  Every test below only reads that data.
    """

    # NOTE(review): these flags are consumed by ``fixtures.MappedTest``;
    # presumably they load the rows a single time and skip per-test deletes,
    # which suits these read-only tests -- confirm against the base class.
    run_inserts = 'once'
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        """Declare ``foo``: a sequence-backed integer primary key plus the
        integer columns ``bar`` and ``range``."""
        Table('foo', metadata,
              Column('id', Integer, sa.Sequence('foo_id_seq'),
                     primary_key=True),
              Column('bar', Integer),
              Column('range', Integer))

    @classmethod
    def fixtures(cls):
        """Return the rows for ``foo``.

        The first tuple names the columns; it is followed by 100
        ``(bar, range)`` rows, ``(0, 0)`` through ``(99, 9)``.
        """
        rows = tuple([(i, i % 10) for i in range(100)])
        foo_data = (('bar', 'range'),) + rows
        return dict(foo=foo_data)

    @classmethod
    def setup_mappers(cls):
        """Map a plain ``Foo`` class to the ``foo`` table."""
        foo = cls.tables.foo

        class Foo(cls.Basic):
            pass

        mapper(Foo, foo)

    def test_selectby(self):
        """``filter_by(range=5)`` selects bars 5, 15, ..., 95."""
        Foo = self.classes.Foo

        res = create_session().query(Foo).filter_by(range=5)
        eq_(res.order_by(Foo.bar)[0].bar, 5)
        eq_(res.order_by(sa.desc(Foo.bar))[0].bar, 95)

    def test_slice(self):
        """Indexing/slicing a Query matches slicing the loaded list."""
        Foo = self.classes.Foo

        sess = create_session()
        query = sess.query(Foo).order_by(Foo.id)
        orig = query.all()

        # single-item access, including negative indexes
        eq_(query[1], orig[1])
        eq_(query[-4], orig[-4])
        eq_(query[-1], orig[-1])

        # slices: open-ended, empty, stepped and negative bounds
        eq_(list(query[10:20]), orig[10:20])
        eq_(list(query[10:]), orig[10:])
        eq_(list(query[:10]), orig[:10])
        eq_(list(query[5:5]), orig[5:5])
        eq_(list(query[10:40:3]), orig[10:40:3])
        eq_(list(query[-5:]), orig[-5:])
        eq_(list(query[-2:-5]), orig[-2:-5])
        eq_(list(query[-5:-2]), orig[-5:-2])
        eq_(list(query[:-2]), orig[:-2])

        # indexing into the result of a slice
        eq_(query[10:20][5], orig[10:20][5])

    # NOTE(review): nothing in this test calls ``apply_max``; the warning
    # filter looks like a leftover -- confirm before removing it.
    @testing.uses_deprecated('Call to deprecated function apply_max')
    def test_aggregate(self):
        """count(), min() and max() over the rows with ``bar < 30``."""
        foo, Foo = self.tables.foo, self.classes.Foo

        sess = create_session()
        query = sess.query(Foo)
        eq_(query.count(), 100)
        eq_(
            sess.query(func.min(foo.c.bar)).filter(foo.c.bar < 30).one(),
            (0,))
        eq_(
            sess.query(func.max(foo.c.bar)).filter(foo.c.bar < 30).one(),
            (29,))
        # same aggregate, fetched through Query.values() on an entity query
        max_row = next(query.filter(foo.c.bar < 30).values(
            sa.func.max(foo.c.bar)))
        eq_(max_row[0], 29)

    @testing.fails_if(
        lambda: testing.against('mysql+mysqldb') and
        testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma'),
        "unknown incompatibility")
    def test_aggregate_1(self):
        """sum(bar) for ``bar < 30`` is 0 + 1 + ... + 29 == 435."""
        foo = self.tables.foo

        query = create_session().query(func.sum(foo.c.bar))
        eq_(query.filter(foo.c.bar < 30).one(), (435,))

    @testing.fails_on('firebird', 'FIXME: unknown')
    @testing.fails_on(
        'mssql',
        'AVG produces an average as the original column type on mssql.')
    def test_aggregate_2(self):
        """avg(bar) for ``bar < 30`` is 14.5 (column-only query)."""
        foo = self.tables.foo

        query = create_session().query(func.avg(foo.c.bar))
        avg = query.filter(foo.c.bar < 30).one()[0]
        eq_(float(round(avg, 1)), 14.5)

    @testing.fails_on(
        'mssql',
        'AVG produces an average as the original column type on mssql.')
    def test_aggregate_3(self):
        """avg(bar) for ``bar < 30`` is 14.5 via ``Query.values()``."""
        foo, Foo = self.tables.foo, self.classes.Foo

        query = create_session().query(Foo)

        avg = next(query.filter(foo.c.bar < 30).values(
            sa.func.avg(foo.c.bar)))[0]
        eq_(float(round(avg, 1)), 14.5)

    def test_filter(self):
        """Chained ``filter()`` calls narrow the result set."""
        Foo = self.classes.Foo

        query = create_session().query(Foo)
        eq_(query.count(), 100)
        eq_(query.filter(Foo.bar < 30).count(), 30)
        # 10 < bar < 30 leaves 11..29, i.e. 19 rows
        res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
        eq_(res2.count(), 19)

    def test_order_by(self):
        """Ascending and descending ordering on ``bar``."""
        Foo = self.classes.Foo

        query = create_session().query(Foo)
        eq_(query.order_by(Foo.bar)[0].bar, 0)
        eq_(query.order_by(sa.desc(Foo.bar))[0].bar, 99)

    def test_offset(self):
        """Skipping 10 rows of a ``bar``-ordered query starts at bar == 10."""
        Foo = self.classes.Foo

        query = create_session().query(Foo)
        eq_(list(query.order_by(Foo.bar).offset(10))[0].bar, 10)

    # This method used to be a second ``test_offset``; because a later
    # definition rebinds the name in the class body, it silently replaced
    # the real offset test above.  It checks ``limit()``, so name it that.
    def test_limit(self):
        """``limit(10)`` returns exactly 10 rows."""
        Foo = self.classes.Foo

        query = create_session().query(Foo)
        eq_(len(list(query.limit(10))), 10)
152
153
class GenerativeTest2(fixtures.MappedTest):
    """Row counts for a filtered two-table query, with and without
    ``distinct()``."""

    @classmethod
    def define_tables(cls, metadata):
        """Declare ``table1`` and ``table2``; ``table2.t1id`` references
        ``table1.id`` and forms a composite key with ``num``."""
        Table(
            'table1', metadata,
            Column('id', Integer, primary_key=True),
        )
        Table(
            'table2', metadata,
            Column('t1id', Integer, ForeignKey("table1.id"),
                   primary_key=True),
            Column('num', Integer, primary_key=True),
        )

    @classmethod
    def setup_mappers(cls):
        """Map ``Obj1`` to ``table1`` and ``Obj2`` to ``table2``."""
        tables = cls.tables
        second, first = tables.table2, tables.table1

        class Obj1(cls.Basic):
            pass

        class Obj2(cls.Basic):
            pass

        mapper(Obj1, first)
        mapper(Obj2, second)

    @classmethod
    def fixtures(cls):
        """Return fixture rows; each table's first tuple names its columns."""
        table1_rows = (
            ('id',),
            (1,),
            (2,),
            (3,),
            (4,),
        )
        table2_rows = (
            ('num', 't1id'),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (5, 2),
            (6, 3),
        )
        return {'table1': table1_rows, 'table2': table2_rows}

    def test_distinct_count(self):
        """Three ``table2`` rows point at id 1, so the filtered count is 3
        and collapses to 1 once ``distinct()`` is applied."""
        Obj1 = self.classes.Obj1
        t1, t2 = self.tables.table1, self.tables.table2

        def linked_to_first():
            # build a fresh criterion for every query that needs it
            return sa.and_(t1.c.id == t2.c.t1id, t2.c.t1id == 1)

        base = create_session().query(Obj1)
        eq_(base.count(), 4)

        plain = base.filter(linked_to_first())
        eq_(plain.count(), 3)

        deduped = base.filter(linked_to_first()).distinct()
        eq_(deduped.count(), 1)
208
209
class RelationshipsTest(_fixtures.FixtureTest):
    """Join and outer-join queries across User -> orders -> addresses."""

    run_setup_mappers = 'once'
    run_inserts = 'once'
    run_deletes = None

    @classmethod
    def setup_mappers(cls):
        """Map Address, then Order (with ``addresses``), then User (with
        ``orders``) -- innermost mapper first."""
        tables, classes = cls.tables, cls.classes

        address_mapper = mapper(classes.Address, tables.addresses)
        order_mapper = mapper(
            classes.Order, tables.orders,
            properties={'addresses': relationship(address_mapper)})
        mapper(
            classes.User, tables.users,
            properties={'orders': relationship(order_mapper)})

    def _no_order_or_first_address(self):
        """Criterion shared by the outer-join tests: the order side is NULL
        or the joined address has id 1."""
        Order, Address = self.classes.Order, self.classes.Address
        return sa.or_(Order.id == None, Address.id == 1)  # noqa

    def test_join(self):
        """Query.join"""

        User, Address = self.classes.User, self.classes.Address

        sess = create_session()
        joined = sess.query(User).join('orders', 'addresses')
        q = joined.filter(Address.id == 1)
        eq_([User(id=7)], q.all())

    def test_outer_join(self):
        """Query.outerjoin"""

        User = self.classes.User

        sess = create_session()
        outer = sess.query(User).outerjoin('orders', 'addresses')
        q = outer.filter(self._no_order_or_first_address())
        expected = set([User(id=7), User(id=8), User(id=10)])
        eq_(expected, set(q.all()))

    def test_outer_join_count(self):
        """test the join and outerjoin functions on Query"""

        User = self.classes.User

        sess = create_session()

        outer = sess.query(User).outerjoin('orders', 'addresses')
        q = outer.filter(self._no_order_or_first_address())
        eq_(q.count(), 4)

    def test_from(self):
        """Same outer-join result, with the join built at the table level
        and handed to ``select_from()``."""
        User = self.classes.User
        users = self.tables.users
        orders = self.tables.orders
        addresses = self.tables.addresses

        sess = create_session()

        users_to_orders = users.outerjoin(orders)
        sel = users_to_orders.outerjoin(
            addresses, orders.c.address_id == addresses.c.id)
        q = sess.query(User).select_from(sel).filter(
            self._no_order_or_first_address())
        expected = set([User(id=7), User(id=8), User(id=10)])
        eq_(expected, set(q.all()))
281
282
class CaseSensitiveTest(fixtures.MappedTest):
    """Filtered and distinct counts against tables whose table and column
    names are mixed/upper case."""

    @classmethod
    def define_tables(cls, metadata):
        """Declare ``Table1`` and ``Table2``; ``Table2.T1ID`` references
        ``Table1.ID`` and forms a composite key with ``NUM``."""
        Table(
            'Table1', metadata,
            Column('ID', Integer, primary_key=True),
        )
        Table(
            'Table2', metadata,
            Column('T1ID', Integer, ForeignKey("Table1.ID"),
                   primary_key=True),
            Column('NUM', Integer, primary_key=True),
        )

    @classmethod
    def setup_mappers(cls):
        """Map ``Obj1`` to ``Table1`` and ``Obj2`` to ``Table2``."""
        tables = cls.tables
        second, first = tables.Table2, tables.Table1

        class Obj1(cls.Basic):
            pass

        class Obj2(cls.Basic):
            pass

        mapper(Obj1, first)
        mapper(Obj2, second)

    @classmethod
    def fixtures(cls):
        """Return fixture rows; each table's first tuple names its columns."""
        table1_rows = (
            ('ID',),
            (1,),
            (2,),
            (3,),
            (4,),
        )
        table2_rows = (
            ('NUM', 'T1ID'),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (5, 2),
            (6, 3),
        )
        return {'Table1': table1_rows, 'Table2': table2_rows}

    def test_distinct_count(self):
        """Three ``Table2`` rows point at ID 1, so the filtered count is 3
        and collapses to 1 once ``distinct()`` is applied."""
        Obj1 = self.classes.Obj1
        t1, t2 = self.tables.Table1, self.tables.Table2

        def linked_to_first():
            # build a fresh criterion for every query that needs it
            return sa.and_(t1.c.ID == t2.c.T1ID, t2.c.T1ID == 1)

        q = create_session(bind=testing.db).query(Obj1)
        assert q.count() == 4

        plain = q.filter(linked_to_first())
        assert plain.count() == 3

        deduped = q.filter(linked_to_first()).distinct()
        eq_(deduped.count(), 1)
336