import re
import sys, unittest
from decimal import Decimal
from datetime import date

from pony.orm import *
from pony.orm.tests.testutils import *
from pony.orm.tests import setup_database, teardown_database
8
# Module-level Database object shared by every entity and test below.
# It is created unbound here; setUpClass hands it to setup_database(db)
# (presumably the place where it gets bound to a provider and mapped -- confirm).
db = Database()
10
11
class Student(db.Entity):
    """Student entity: the main subject of the prefetch tests."""

    # NOTE: declaration order matters. test_12 asserts the generated SELECT
    # column list (id, name, scholarship, gpa, dob, group, mentor, biography),
    # which follows the order of the attributes below.
    # No PrimaryKey is declared here; the fixtures pass ``id=`` explicitly.
    name = Required(str)
    scholarship = Optional(int)
    gpa = Optional(Decimal, 3, 1)
    dob = Optional(date)
    # Every student belongs to exactly one group (reverse side: Group.students).
    group = Required('Group')
    # Set on both sides (see Course.students), i.e. a many-to-many relation.
    courses = Set('Course')
    # Optional link to a Teacher (reverse side: Teacher.students).
    mentor = Optional('Teacher')
    # LongStr attribute; test_11 expects it NOT to be loaded by a plain
    # select(), so it has to be prefetched to be readable after the session.
    biography = Optional(LongStr)
21
22
class Group(db.Entity):
    """Student group, identified by an explicit integer number."""

    number = PrimaryKey(int)
    # lazy=True: test_2 and test_6 expect reading ``major`` after the session
    # has ended to fail unless it was prefetched via Group.major.
    major = Required(str, lazy=True)
    # Reverse side of Student.group.
    students = Set(Student)
27
28
class Course(db.Entity):
    """Course that students can attend."""

    name = Required(str, unique=True)
    # Set on both sides (see Student.courses), i.e. a many-to-many relation.
    students = Set(Student)
32
33
class Teacher(db.Entity):
    """Teacher who may mentor any number of students."""

    name = Required(str)
    # Reverse side of Student.mentor.
    students = Set(Student)
37
38
class TestPrefetching(unittest.TestCase):
    """Tests for ``Query.prefetch()``.

    Two families of tests live here:

    * test_1 .. test_12 check which attributes/collections are readable after
      the ``db_session`` has ended, with and without prefetching.
    * test_13 .. test_19 check how many database queries are issued, read from
      ``db.local_stats[None].db_count``.
    """

    @classmethod
    def setUpClass(cls):
        """Set up the database and insert 2 groups, 3 courses, 2 teachers and 5 students."""
        setup_database(db)
        with db_session:
            g1 = Group(number=1, major='Math')
            g2 = Group(number=2, major='Computer Sciense')
            c1 = Course(name='Math')
            c2 = Course(name='Physics')
            c3 = Course(name='Computer Science')
            t1 = Teacher(name='T1')
            t2 = Teacher(name='T2')
            Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1)
            Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio')
            Student(id=3, name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3])
            Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2)
            Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3])

    @classmethod
    def tearDownClass(cls):
        """Tear the test database down again."""
        teardown_database(db)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted-numeric part of *version* as a tuple of ints.

        ``'0.7.17'`` -> ``(0, 7, 17)``; ``'0.10rc1'`` -> ``(0, 10)``.
        Returns ``()`` when the string does not start with a digit.  Used
        instead of comparing version strings directly, because string
        comparison is lexicographic (``'0.10' < '0.9'`` is true).
        """
        found = re.match(r'\d+(?:\.\d+)*', version)
        if found is None:
            return ()
        return tuple(int(part) for part in found.group().split('.'))

    def test_1(self):
        """A lazy attribute can be loaded on demand while the session is open."""
        with db_session:
            s1 = Student.select().first()
            g = s1.group
            self.assertEqual(g.major, 'Math')

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over')
    def test_2(self):
        """Without prefetch, reading lazy Group.major after the session fails."""
        with db_session:
            s1 = Student.select().first()
            g = s1.group
        g.major

    def test_3(self):
        """prefetch(Group, Group.major) makes the lazy attribute readable after the session."""
        with db_session:
            s1 = Student.select().prefetch(Group, Group.major).first()
            g = s1.group
        self.assertEqual(g.major, 'Math')

    def test_4(self):
        """Same as test_3, but the related entity is named via the Student.group attribute."""
        with db_session:
            s1 = Student.select().prefetch(Student.group, Group.major).first()
            g = s1.group
        self.assertEqual(g.major, 'Math')

    @raises_exception(TypeError, 'Argument of prefetch() query method must be entity class or attribute. Got: 111')
    def test_5(self):
        """prefetch() rejects arguments that are neither an entity class nor an attribute."""
        with db_session:
            Student.select().prefetch(111)

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over')
    def test_6(self):
        """Like test_2, but the group comes from a tuple-returning query."""
        with db_session:
            name, group = select((s.name, s.group) for s in Student).first()
        group.major

    def test_7(self):
        """prefetch() also applies to entities returned inside result tuples."""
        with db_session:
            name, group = select((s.name, s.group) for s in Student).prefetch(Group, Group.major).first()
        self.assertEqual(group.major, 'Math')

    @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over')
    def test_8(self):
        """Without prefetch, iterating a collection after the session fails."""
        with db_session:
            s1 = Student.select().first()
        set(s1.courses)

    @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over')
    def test_9(self):
        """Prefetching the Course entity alone does not load Student.courses."""
        with db_session:
            s1 = Student.select().prefetch(Course).first()
        set(s1.courses)

    def test_10(self):
        """Prefetching the Student.courses attribute loads the collection and its items."""
        with db_session:
            s1 = Student.select().prefetch(Student.courses).first()
        self.assertEqual(set(s1.courses.name), {'Math', 'Physics'})

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].biography: the database session is over')
    def test_11(self):
        """An unrelated prefetch does not load Student.biography."""
        with db_session:
            s1 = Student.select().prefetch(Course).first()
        s1.biography

    def test_12(self):
        """Prefetching Student.biography puts the column into the main SELECT."""
        with db_session:
            s1 = Student.select().prefetch(Student.biography).first()
        self.assertEqual(s1.biography, 'S1 bio')

        # Table-name case and identifier quoting depend on the SQLite dialect
        # and on whether Pony is at least 0.9.  The version is compared
        # numerically: as strings, '0.10' would sort before '0.9'.
        is_sqlite = db.provider.dialect == 'SQLite'
        at_least_0_9 = self._version_tuple(pony.__version__) >= (0, 9)

        table_name = 'Student' if is_sqlite and not at_least_0_9 else 'student'
        expected_sql = '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography"
FROM "%s" "s"
ORDER BY 1
LIMIT 1''' % table_name
        if is_sqlite and at_least_0_9:
            expected_sql = expected_sql.replace('"', '`')
        self.assertEqual(db.last_sql, expected_sql)

    def test_13(self):
        """Baseline without prefetch: 1 + 2 + 5 = 8 queries."""
        # Called before each counting test so db_count below reflects only this
        # test (presumably folds the thread-local stats away -- confirm).
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group)
            for g in q:  # 1 query
                for s in g.students:  # 2 queries
                    b = s.biography  # 5 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 8)

    def test_14(self):
        """Prefetching Group.students saves one query: 1 + 1 + 5 = 7."""
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group).prefetch(Group.students)
            for g in q:  # 1 query
                for s in g.students:  # 1 query
                    b = s.biography  # 5 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 7)

    def test_15(self):
        """Prefetching students and their biography needs only 2 queries."""
        # A first session runs a prefetch query before the stats are merged,
        # so only the second session below is counted.
        with db_session:
            q = select(g for g in Group).prefetch(Group.students)
            q[:]
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group).prefetch(Group.students, Student.biography)
            for g in q:  # 1 query
                for s in g.students:  # 1 query
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 2)

    def test_16(self):
        """Many-to-many prefetch (Course.students) plus biography: 3 queries."""
        db.merge_local_stats()
        with db_session:
            q = select(c for c in Course).prefetch(Course.students, Student.biography)
            for c in q:  # 1 query
                for s in c.students:  # 2 queries (as it is many-to-many relationship)
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)

    def test_17(self):
        """As test_16, additionally prefetching Group and Group.major: 4 queries."""
        db.merge_local_stats()
        with db_session:
            q = select(c for c in Course).prefetch(Course.students, Student.biography, Group, Group.major)
            for c in q:  # 1 query
                for s in c.students:  # 2 queries (as it is many-to-many relationship)
                    m = s.group.major  # 1 query
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 4)

    def test_18(self):
        """Reading Student.mentor (without touching its attributes) costs no extra query."""
        db.merge_local_stats()
        with db_session:
            q = Group.select().prefetch(Group.students, Student.biography)
            for g in q:  # 2 queries
                for s in g.students:
                    m = s.mentor  # 0 queries
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 2)

    def test_19(self):
        """Prefetching Student.mentor loads the teachers up front: 3 queries in total."""
        db.merge_local_stats()
        with db_session:
            q = Group.select().prefetch(Group.students, Student.biography, Student.mentor)
            mentors = set()
            for g in q:  # 3 queries
                for s in g.students:
                    m = s.mentor  # 0 queries
                    if m is not None:
                        mentors.add(m)
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)

            # The mentors were prefetched, so reading their names must not
            # trigger any further query.
            for m in mentors:
                n = m.name  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)
221
222
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
225