1from __future__ import absolute_import, print_function, division
2from pony.py23compat import PYPY2
3
4import unittest
5from datetime import date
6
7from pony.orm import *
8from pony.orm.tests.testutils import raises_exception
9from pony.orm.tests import setup_database, teardown_database, only_for
10
# Module-level Database shared by the entity and tests in this file. It is
# passed to setup_database()/teardown_database() in TestRawSQL; presumably
# those helpers bind it to a provider and create/drop the tables -- TODO
# confirm against pony.orm.tests.
db = Database()
12
13
class Person(db.Entity):
    """Entity queried by the raw_sql() tests in this module."""
    # Explicit int primary key; the tests look rows up as Person[1], Person[2], ...
    id = PrimaryKey(int)
    name = Required(str)
    age = Required(int)
    # A date column; presumably "date of birth" -- the tests compare it against
    # date literals and read it back as raw SQL text.
    dob = Required(date)
19
20
@only_for('sqlite')
class TestRawSQL(unittest.TestCase):
    """Tests for embedding raw SQL fragments in Pony queries via raw_sql().

    Covers raw_sql() used as a whole condition, as a comparison operand, in
    the result expression of a query, inside lambdas, filter() and
    order_by(), and with `$name` / `$(expr)` placeholders that are evaluated
    as Python expressions in the scope where the query is built.

    Many tests assign locals (x, y) that look unused to a linter: they are
    read by name through the `$x` / `$y` placeholders inside the SQL
    strings, so they must not be removed or renamed.

    Restricted to SQLite by @only_for('sqlite'); for example test_10 and
    test_17 rely on `dob` values coming back from raw SQL as 'YYYY-MM-DD'
    text.
    """

    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class: prepare `db` with the shared test
        # helper, then insert the three rows that every test queries.
        setup_database(db)
        with db_session:
            Person(id=1, name='John', age=30, dob=date(1985, 1, 1))
            Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20))
            Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15))

    @classmethod
    def tearDownClass(cls):
        # Undo setUpClass via the shared test helper.
        teardown_database(db)

    @db_session
    def test_1(self):
        # raw_sql result can be treated as a logical expression;
        # abs(age) > 25 holds for John (30) and Mike (32) only
        persons = select(p for p in Person if raw_sql('abs("p"."age") > 25'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_2(self):
        # raw_sql result can be used as an operand of a Python comparison
        persons = select(p for p in Person if raw_sql('abs("p"."age")') > 25)[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_3(self):
        # raw_sql can accept $parameters taken from the enclosing scope
        x = 25  # read by the query through $x
        persons = select(p for p in Person if raw_sql('abs("p"."age") > $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_4(self):
        # dynamic raw_sql content (1): the SQL text comes from a variable
        # rather than a string literal
        x = 1  # read through the $x placeholder inside `s`
        s = 'p.id > $x'
        persons = select(p for p in Person if raw_sql(s))[:]
        self.assertEqual(set(persons), {Person[2], Person[3]})

    @db_session
    def test_5(self):
        # dynamic raw_sql content (2): a raw_sql object built outside the
        # query is used as the whole condition
        x = 1  # read through the $x placeholder
        cond = raw_sql('p.id > $x')
        persons = select(p for p in Person if cond)[:]
        self.assertEqual(set(persons), {Person[2], Person[3]})

    @db_session
    def test_6(self):
        # the correct converter should be applied according to the
        # raw_sql parameter's Python type (a date here)
        x = date(1990, 1, 1)
        persons = select(p for p in Person if raw_sql('p.dob < $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_7(self):
        # raw_sql argument may be a complex expression (1):
        # $(x + y) is evaluated in Python to 25
        x = 10
        y = 15
        persons = select(p for p in Person if raw_sql('p.age > $(x + y)'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_8(self):
        # raw_sql argument may be a complex expression (2): a function call
        # on a name visible in the module scope; every stored dob is in the past
        persons = select(p for p in Person if raw_sql('p.dob < $date.today()'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2], Person[3]})

    @db_session
    def test_9(self):
        # using raw_sql in the expression (result) part of the generator
        names = select(raw_sql('UPPER(p.name)') for p in Person)[:]
        self.assertEqual(set(names), {'JOHN', 'MIKE', 'MARY'})

    @db_session
    def test_10(self):
        # raw_sql does not know the result type and cannot apply the correct
        # type converter automatically, so the raw stored text is returned
        dates = select(raw_sql('(p.dob)') for p in Person).order_by(lambda: p.id)[:]
        self.assertEqual(dates, ['1985-01-01', '1983-05-20', '1995-02-15'])

    @db_session
    def test_11(self):
        # it is possible to specify the raw_sql result type manually, in
        # which case the values are converted back to `date`
        dates = select(raw_sql('(p.dob)', result_type=date) for p in Person).order_by(lambda: p.id)[:]
        self.assertEqual(dates, [date(1985, 1, 1), date(1983, 5, 20), date(1995, 2, 15)])

    @db_session
    def test_12(self):
        # raw_sql can be used in lambdas
        x = 25  # read through $x
        persons = Person.select(lambda p: p.age > raw_sql('$x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_13(self):
        # raw_sql in filter()
        x = 25  # read through $x
        persons = select(p for p in Person).filter(lambda p: p.age > raw_sql('$x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_14(self):
        # raw_sql in filter() without using a lambda; the fragment refers to
        # the query's iteration variable `p` by name
        x = 25  # read through $x
        persons = Person.select().filter(raw_sql('p.age > $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_15(self):
        # several raw_sql expressions in a single query; both sides evaluate
        # to 'JOHN123' only for John
        x = '123'
        y = 'John'
        persons = Person.select(lambda p: raw_sql("UPPER(p.name) || $x") == raw_sql("UPPER($y || '123')"))[:]
        self.assertEqual(set(persons), {Person[1]})

    @db_session
    def test_16(self):
        # the same param name can be used several times with different types & values
        # Stage 1: x=10, y=31 -> 10 < age < 31 keeps John (30) and Mary (20).
        x = 10
        y = 31
        q = select(p for p in Person if p.age > x and p.age < raw_sql('$y'))
        # Stage 2: x and y are rebound to a date and a string; the filter
        # keeps dob > 1980-01-01 and names starting with UPPER('j') == 'J',
        # leaving only John.
        x = date(1980, 1, 1)
        y = 'j'
        q = q.filter(lambda p: p.dob > x and p.name.startswith(raw_sql('UPPER($y)')))
        persons = q[:]
        self.assertEqual(set(persons), {Person[1]})

    @db_session
    def test_17(self):
        # raw_sql in order_by() section.
        # SUBSTR(dob, 9) takes the 'DD' part of the 'YYYY-MM-DD' text, so rows
        # are ordered by day of month: 01 (John), 15 (Mary), 20 (Mike).
        x = 9  # read through $x
        persons = Person.select().order_by(lambda p: raw_sql('SUBSTR(p.dob, $x)'))[:]
        self.assertEqual(persons, [Person[1], Person[3], Person[2]])

    @db_session
    def test_18(self):
        # raw_sql in order_by() section without using a lambda
        # (same day-of-month ordering as test_17)
        x = 9  # read through $x
        persons = Person.select().order_by(raw_sql('SUBSTR(p.dob, $x)'))[:]
        self.assertEqual(persons, [Person[1], Person[3], Person[2]])

    @db_session
    @raises_exception(TranslationError, "Expression `raw_sql(p.name)` cannot be translated into SQL "
                                        "because raw SQL fragment will be different for each row")
    def test_19(self):
        # raw_sql argument cannot depend on iterator variables
        select(p for p in Person if raw_sql(p.name))[:]

    @db_session
    # The NameError text differs between PyPy 2 ("global name ...") and
    # other interpreters, hence the conditional expected message.
    @raises_exception(ExprEvalError,
                      "`raw_sql('p.dob < $x')` raises NameError: global name 'x' is not defined" if PYPY2 else
                      "`raw_sql('p.dob < $x')` raises NameError: name 'x' is not defined")
    def test_20(self):
        # testing the situation where the parameter variable is missing
        select(p for p in Person if raw_sql('p.dob < $x'))[:]

    @db_session
    def test_21(self):
        # a raw_sql parameter bound to None can be checked with `is None`;
        # the check must evaluate to true so that p.id == 1 alone decides
        # the result
        x = None  # read through $x
        persons = select(p for p in Person if p.id == 1 and raw_sql('$x') is None)[:]
        self.assertEqual(persons, [Person[1]])
184
185
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
188