1from __future__ import absolute_import, print_function, division
2from pony.py23compat import PYPY, PYPY2
3
4import sys, unittest
5from datetime import date
6from decimal import Decimal
7
8from pony.orm.core import *
9from pony.orm.sqltranslation import IncomparableTypesError
10from pony.orm.tests.testutils import *
11from pony.orm.tests import setup_database, teardown_database
12
# Shared Database object that every entity below is declared against. It is
# created without arguments here; setup_database(db) in the test class prepares
# it for use (presumably binding a provider and creating tables -- see
# pony.orm.tests.setup_database).
db = Database()
14
class Student(db.Entity):
    # Fixture entity queried by the translator-exception tests. No explicit
    # PrimaryKey is declared, so queries refer to the implicit `id` attribute.
    name = Required(unicode)
    dob = Optional(date)
    gpa = Optional(float)
    # Presumably precision=7, scale=2 (Pony's positional Decimal options) -- confirm.
    scholarship = Optional(Decimal, 7, 2)
    # Required link to exactly one Group; reverse side is Group.students.
    group = Required('Group')
    # Many-to-many with Course; reverse side is Course.students.
    courses = Set('Course')
22
class Group(db.Entity):
    # Identified by an explicit single-column integer primary key.
    number = PrimaryKey(int)
    # Reverse side of Student.group.
    students = Set(Student)
    # Each group belongs to one Department (reverse: Department.groups).
    dept = Required('Department')
27
class Department(db.Entity):
    # Identified by an explicit integer primary key.
    number = PrimaryKey(int)
    # Reverse side of Group.dept.
    groups = Set(Group)
31
class Course(db.Entity):
    # Identified by the pair (name, semester) via the composite PrimaryKey
    # declared below, so an index lookup such as Course[...] needs two values.
    name = Required(unicode)
    semester = Required(int)
    PrimaryKey(name, semester)
    # Many-to-many with Student; reverse side is Student.courses.
    students = Set(Student)
37
38
class TestSQLTranslatorExceptions(unittest.TestCase):
    """Negative tests for the query translator.

    Each test builds a query that must be rejected and uses
    ``raises_exception`` to assert both the exception class and the exact
    message. Many expected messages quote the source text of the query
    itself (e.g. test2 expects "for i in x"), so renaming variables or
    reformatting a query body also changes the message it has to match.
    """
    @classmethod
    def setUpClass(cls):
        # Prepare the database once for the whole class and insert a minimal
        # fixture: department 44, group 101, and students S1-S3 in that group.
        setup_database(db)
        with db_session:
            d1 = Department(number=44)
            g1 = Group(number=101, dept=d1)
            Student(name='S1', group=g1)
            Student(name='S2', group=g1)
            Student(name='S3', group=g1)
    @classmethod
    def tearDownClass(cls):
        teardown_database(db)
    def setUp(self):
        # Discard leftover transaction state, then enter db_session by hand so
        # each test body runs inside a session without its own `with` block.
        rollback()
        db_session.__enter__()
    def tearDown(self):
        # Roll back uncommitted changes before leaving the session opened in setUp.
        rollback()
        db_session.__exit__()
    @raises_exception(NotImplementedError, 'for x in s.name')
    def test1(self):
        # NOTE(review): this `x` is never read -- the generator below binds its
        # own loop variable `x`. Presumably a leftover; confirm before removing.
        x = 10
        select(s for s in Student for x in s.name)
    # A plain Python list (not an entity or query) used as an iterator is rejected.
    @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. Got: for i in x")
    def test2(self):
        x = [1, 2, 3]
        select(s for s in Student for i in x)
    # So is a collection attribute of an object loaded outside the query.
    @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. Got: for s2 in g.students")
    def test3(self):
        g = Group[101]
        select(s for s in Student for s2 in g.students)
    @raises_exception(NotImplementedError, "*args is not supported")
    def test4(self):
        args = 'abc'
        select(s for s in Student if s.name.upper(*args))

    # Only defined on Python < 3.5 (see the TODO); newer interpreters do not
    # run this **kwargs case.
    if sys.version_info[:2] < (3, 5): # TODO
        @raises_exception(NotImplementedError) # "**{'a': 'b', 'c': 'd'} is not supported
        def test5(self):
            select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'}))

    # Expected text embeds the interpreter's own TypeError message, which
    # differs between PyPy and CPython.
    @raises_exception(ExprEvalError, "`1 in 2` raises TypeError: argument of type 'int' is not iterable" if not PYPY else
                                     "`1 in 2` raises TypeError: 'int' object is not iterable")
    def test6(self):
        select(s for s in Student if 1 in 2)
    # Indexing an entity by a value that depends on the query row is not supported.
    @raises_exception(NotImplementedError, 'Group[s.group.number]')
    def test7(self):
        select(s for s in Student if Group[s.group.number].dept.number == 44)
    # Group has a single-column primary key, so two index values are rejected.
    @raises_exception(ExprEvalError, "`Group[123, 456].dept.number == 44` raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)")
    def test8(self):
        select(s for s in Student if Group[123, 456].dept.number == 44)
    # Course has a two-column primary key, so a single index value is rejected.
    @raises_exception(ExprEvalError, "`Course[123]` raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)")
    def test9(self):
        select(s for s in Student if Course[123] in s.courses)
    # `unicode.__name__` is interpolated so the expected type name follows
    # whatever `unicode` is bound to by the star imports at the top of the file.
    @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name < s.gpa" % unicode.__name__)
    def test10(self):
        select(s for s in Student if s.name < s.gpa)
    @raises_exception(ExprEvalError, "`Group(101)` raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument")
    def test11(self):
        select(s for s in Student if s.group == Group(101))
    @raises_exception(ExprEvalError, "`Group[date(2011, 1, 2)]` raises TypeError: Value type for attribute Group.number must be int. Got: %r" % date)
    def test12(self):
        select(s for s in Student if s.group == Group[date(2011, 1, 2)])
    @raises_exception(TypeError, "Unsupported operand types 'int' and '%s' for operation '+' in expression: s.group.number + s.name" % unicode.__name__)
    def test13(self):
        select(s for s in Student if s.group.number + s.name < 0)
    @raises_exception(TypeError, "Unsupported operand types 'Decimal' and 'float' for operation '+' in expression: s.scholarship + 1.1")
    def test14(self):
        select(s for s in Student if s.scholarship + 1.1 > 10)
    @raises_exception(TypeError, "Unsupported operand types 'Decimal' and '%s' for operation '**' "
                                 "in expression: s.scholarship ** 'abc'" % unicode.__name__)
    def test15(self):
        select(s for s in Student if s.scholarship ** 'abc' > 10)
    @raises_exception(TypeError, "Unsupported operand types '%s' and 'int' for operation '+' in expression: s.name + 2" % unicode.__name__)
    def test16(self):
        select(s for s in Student if s.name + 2 > 10)
    # String slicing/indexing checks (there is no test20 in this class).
    @raises_exception(TypeError, "Step is not supported in s.name[1:3:5]")
    def test17(self):
        select(s for s in Student if s.name[1:3:5] == 'A')
    @raises_exception(TypeError, "Invalid type of start index (expected 'int', got '%s') in string slice s.name['a':1]"
                                 % unicode.__name__)
    def test18(self):
        select(s for s in Student if s.name['a':1] == 'A')
    @raises_exception(TypeError, "Invalid type of stop index (expected 'int', got '%s') in string slice s.name[1:'a']"
                                 % unicode.__name__)
    def test19(self):
        select(s for s in Student if s.name[1:'a'] == 'A')
    @raises_exception(TypeError, "String indices must be integers. Got '%s' in expression s.name['a']" % unicode.__name__)
    def test21(self):
        select(s.name for s in Student if s.name['a'] == 'h')
    @raises_exception(TypeError, "Incomparable types 'int' and '%s' in expression: 1 in s.name" % unicode.__name__)
    def test22(self):
        select(s.name for s in Student if 1 in s.name)
    @raises_exception(TypeError, "Expected '%s' argument but got 'int' in expression s.name.startswith(1)" % unicode.__name__)
    def test23(self):
        select(s.name for s in Student if s.name.startswith(1))
    @raises_exception(TypeError, "Expected '%s' argument but got 'int' in expression s.name.endswith(1)" % unicode.__name__)
    def test24(self):
        select(s.name for s in Student if s.name.endswith(1))
    @raises_exception(TypeError, "'chars' argument must be of '%s' type in s.name.strip(1), got: 'int'" % unicode.__name__)
    def test25(self):
        select(s.name for s in Student if s.name.strip(1))
    @raises_exception(AttributeError, "'%s' object has no attribute 'unknown': s.name.unknown" % unicode.__name__)
    def test26(self):
        result = set(select(s for s in Student if s.name.unknown() == "joe"))
    # Unknown attribute on a query variable: reported by the translator itself.
    @raises_exception(AttributeError, "Entity Group does not have attribute foo: s.group.foo")
    def test27(self):
        select(s.name for s in Student if s.group.foo.bar == 10)
    # Unknown attribute on an outer object: the expression is evaluated in
    # Python, so the AttributeError arrives wrapped in ExprEvalError.
    @raises_exception(ExprEvalError, "`g.dept.foo.bar` raises AttributeError: 'Department' object has no attribute 'foo'")
    def test28(self):
        g = Group[101]
        select(s for s in Student if s.name == g.dept.foo.bar)
    @raises_exception(TypeError, "'year' argument of date(year, month, day) function must be of 'int' type. "
                                 "Got: '%s'" % unicode.__name__)
    def test29(self):
        select(s for s in Student if s.dob < date('2011', 1, 1))
    @raises_exception(NotImplementedError, "date(s.id, 1, 1)")
    def test30(self):
        select(s for s in Student if s.dob < date(s.id, 1, 1))
    # Expected text embeds the interpreter's TypeError for max() with no
    # arguments, whose wording differs by interpreter and Python version.
    @raises_exception(
        ExprEvalError,
        "`max()` raises TypeError: max() expects at least one argument" if PYPY else
        "`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else
        "`max()` raises TypeError: max expected 1 argument, got 0" if sys.version_info[:2] < (3, 9) else
        "`max()` raises TypeError: max expected at least 1 argument, got 0")
    def test31(self):
        select(s for s in Student if s.id < max())
    @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses")
    def test32(self):
        select(s for s in Student if s in s.courses)
    @raises_exception(AttributeError, "s.courses.name.foo")
    def test33(self):
        select(s for s in Student if 'x' in s.courses.name.foo.bar)
    @raises_exception(AttributeError, "s.courses.foo")
    def test34(self):
        select(s for s in Student if 'x' in s.courses.foo.bar)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(s.courses.name)" % unicode.__name__)
    def test35(self):
        select(s for s in Student if sum(s.courses.name) > 10)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(c.name for c in s.courses)" % unicode.__name__)
    def test36(self):
        select(s for s in Student if sum(c.name for c in s.courses) > 10)
    # NOTE(review): identical to test36 (same expected message and same query).
    # Probably meant to cover a different variant -- confirm the intent.
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(c.name for c in s.courses)" % unicode.__name__)
    def test37(self):
        select(s for s in Student if sum(c.name for c in s.courses) > 10)
    @raises_exception(TypeError, "Function avg() expects query or items of numeric type, got '%s' in avg(c.name for c in s.courses)" % unicode.__name__)
    def test38(self):
        select(s for s in Student if avg(c.name for c in s.courses) > 10 and len(s.courses) > 1)
    @raises_exception(TypeError, "strip() takes at most 1 argument (3 given)")
    def test39(self):
        select(s for s in Student if s.name.strip(1, 2, 3))
    # Expected text embeds the interpreter's own len() TypeError, which
    # differs between PyPy2, PyPy3 and CPython.
    @raises_exception(ExprEvalError,
                      "`len(1, 2) == 3` raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else
                      "`len(1, 2) == 3` raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else
                      "`len(1, 2) == 3` raises TypeError: len() takes exactly one argument (2 given)")
    def test40(self):
        select(s for s in Student if len(1, 2) == 3)
    # Aggregates over whole entities (rather than an attribute) inside a query.
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got 'Student' in sum(s for s in Student if s.group == g)")
    def test41(self):
        select(g for g in Group if sum(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function avg() expects query or items of numeric type, got 'Student' in avg(s for s in Student if s.group == g)")
    def test42(self):
        select(g for g in Group if avg(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function min() cannot be applied to type 'Student' in min(s for s in Student if s.group == g)")
    def test43(self):
        select(g for g in Group if min(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function max() cannot be applied to type 'Student' in max(s for s in Student if s.group == g)")
    def test44(self):
        select(g for g in Group if max(s for s in Student if s.group == g) > 1)
    # test45-test49 call aggregates directly on a generator, outside select().
    # The expected messages are not builtin errors, so `max`/`sum` here
    # presumably resolve to Pony's versions via `from pony.orm.core import *`.
    @raises_exception(TypeError, "Attribute should be specified for 'max' aggregate function")
    def test45(self):
        max(s for s in Student)
    @raises_exception(TypeError, "Single attribute should be specified for 'max' aggregate function")
    def test46(self):
        max((s.name, s.gpa) for s in Student)
    @raises_exception(TypeError, "Attribute should be specified for 'sum' aggregate function")
    def test47(self):
        sum(s for s in Student)
    @raises_exception(TypeError, "Single attribute should be specified for 'sum' aggregate function")
    def test48(self):
        sum((s.name, s.gpa) for s in Student)
    @raises_exception(TypeError, "'sum' is valid for numeric attributes only")
    def test49(self):
        sum(s.name for s in Student)
    @raises_exception(TypeError, "Cannot compare whole JSON value, you need to select specific sub-item: s.name == {'a':'b'}")
    def test50(self):
        # cannot compare JSON value to dynamic string,
        # because a database does not provide json.dumps(s.name) functionality
        select(s for s in Student if s.name == {'a': 'b'})
    @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__)
    def test51(self):
        a = 1
        select(s for s in Student if s.name > a & 2)
    # `from __future__ import division` makes `1 / a` true division on every
    # interpreter, hence 'float' here versus 'int' for `//` in test53.
    @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name > 1 / a - 3" % unicode.__name__)
    def test52(self):
        a = 1
        select(s for s in Student if s.name > 1 / a - 3)
    @raises_exception(TypeError, "Incomparable types '%s' and 'int' in expression: s.name > 1 // a - 3" % unicode.__name__)
    def test53(self):
        a = 1
        select(s for s in Student if s.name > 1 // a - 3)
    @raises_exception(TypeError, "Incomparable types '%s' and 'int' in expression: s.name > -a" % unicode.__name__)
    def test54(self):
        a = 1
        select(s for s in Student if s.name > -a)
    @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == [1, (2,)]" % unicode.__name__)
    def test55(self):
        select(s for s in Student if s.name == [1, (2,)])
    # delete() must target a single entity, not a tuple of query variables.
    @raises_exception(TypeError, "Delete query should be applied to a single entity. Got: (s, g)")
    def test56(self):
        delete((s, g) for g in Group for s in g.students if s.gpa > 3)
250
if __name__ == '__main__':
    # Executed directly (not imported): collect and run this module's TestCases.
    # `module='__main__'` is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')
253