1import sys, unittest
2from decimal import Decimal
3from datetime import date
4
5from pony.orm import *
6from pony.orm.tests.testutils import *
7from pony.orm.tests import db_params, teardown_database
8
class TestIndexes(unittest.TestCase):
    """Checks how Pony registers primary-key, unique and composite indexes.

    Every test declares its entities on a fresh ``Database`` built from
    ``db_params``; ``tearDown`` passes that database to ``teardown_database``.
    """

    def setUp(self):
        """Create an unmapped database object for the test to declare entities on."""
        self.db = Database(**db_params)

    def tearDown(self):
        """Hand the per-test database to ``teardown_database`` for cleanup."""
        teardown_database(self.db)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted-number part of *version* as a tuple of ints.

        ``'0.7.17'`` gives ``(0, 7, 17)`` and ``'0.8-dev'`` gives ``(0, 8)``.
        A string that does not start with a digit gives ``()``.
        """
        # Local import keeps the helper independent of the file-level imports.
        import re
        match = re.match(r'\d+(?:\.\d+)*', version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group().split('.'))

    @classmethod
    def _is_pre_0_9(cls, version):
        """Tell whether *version* is older than 0.9, comparing numerically.

        A plain string comparison would sort ``'0.10'`` before ``'0.9'``, so
        the version is compared as a tuple of integers instead.
        """
        return cls._version_tuple(version) < (0, 9)

    def _person_table(self):
        """Return the schema table that was generated for the ``Person`` entity."""
        db = self.db
        # The expected table name is capitalised only for SQLite on Pony < 0.9.
        keeps_entity_case = (db.provider.dialect == 'SQLite'
                             and self._is_pre_0_9(pony.__version__))
        return db.schema.tables['Person' if keeps_entity_case else 'person']

    def _check_name_age_index(self, entity, is_unique):
        """Assert that *entity* has a pk index plus a (name, age) index.

        The same expectations are checked on the entity's ``_indexes_`` and
        on the indexes of the generated schema table.

        :param entity: the ``Person`` entity class declared by the test.
        :param is_unique: expected ``is_unique`` flag of the (name, age) index.
        """
        pk_index, name_age_index = entity._indexes_
        self.assertEqual(pk_index.attrs, (entity.id,))
        self.assertEqual(pk_index.is_pk, True)
        self.assertEqual(pk_index.is_unique, True)
        self.assertEqual(name_age_index.attrs, (entity.name, entity.age))
        self.assertEqual(name_age_index.is_pk, False)
        self.assertEqual(name_age_index.is_unique, is_unique)

        table = self._person_table()
        name_column = table.column_dict['name']
        age_column = table.column_dict['age']
        self.assertEqual(len(table.indexes), 2)
        db_index = table.indexes[name_column, age_column]
        self.assertEqual(db_index.is_pk, False)
        self.assertEqual(db_index.is_unique, is_unique)

    def test_1(self):
        """``composite_key`` produces a unique, non-pk index on (name, age)."""
        db = self.db
        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            # One attribute is passed as an object and the other by its name.
            composite_key(name, 'age')
        db.generate_mapping(create_tables=True)

        self._check_name_age_index(Person, is_unique=True)

    def test_2(self):
        """``composite_index`` produces a non-unique index and its CREATE INDEX."""
        db = self.db
        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            # One attribute is passed as an object and the other by its name.
            composite_index(name, 'age')
        db.generate_mapping(create_tables=True)

        self._check_name_age_index(Person, is_unique=False)

        create_script = db.schema.generate_create_script()

        # The expected statement depends on the Pony version and the dialect.
        dialect = db.provider.dialect
        if self._is_pre_0_9(pony.__version__):
            if dialect == 'SQLite':
                index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")'
            else:
                index_sql = 'CREATE INDEX "idx_person__name_age" ON "person" ("name", "age")'
        elif dialect == 'MySQL' or dialect == 'SQLite':
            index_sql = 'CREATE INDEX `idx_person__name__age` ON `person` (`name`, `age`)'
        elif dialect == 'PostgreSQL':
            index_sql = 'CREATE INDEX "idx_person__name__age" ON "person" ("name", "age")'
        elif dialect == 'Oracle':
            index_sql = 'CREATE INDEX "IDX_PERSON__NAME__AGE" ON "PERSON" ("NAME", "AGE")'
        else:
            raise NotImplementedError
        self.assertIn(index_sql, create_script)

    def test_3(self):
        """A ``unique=True`` attribute can be changed and read back afterwards."""
        db = self.db
        class User(db.Entity):
            name = Required(str, unique=True)

        db.generate_mapping(create_tables=True)

        with db_session:
            User(id=1, name='A')

        with db_session:
            u = User[1]
            u.name = 'B'

        with db_session:
            u = User[1]
            self.assertEqual(u.name, 'B')

    def test_4(self):  # issue 321
        """Re-setting composite key attributes to equal values, then deleting."""
        db = self.db
        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            composite_key(name, age)

        db.generate_mapping(create_tables=True)
        with db_session:
            Person(id=1, name='John', age=19)

        # No assertions here: the test passes when set() and delete() do not raise.
        with db_session:
            p1 = Person[1]
            p1.set(name='John', age=19)
            p1.delete()

    def test_5(self):
        """``exists()`` called with only part of a composite key's attributes."""
        db = self.db

        class Table1(db.Entity):
            name = Required(str)
            table2s = Set('Table2')

        class Table2(db.Entity):
            height = Required(int)
            length = Required(int)
            table1 = Optional('Table1')
            composite_key(height, length, table1)

        db.generate_mapping(create_tables=True)

        # The result of exists() is not checked; the test passes when neither
        # call raises.
        with db_session:
            Table2(height=2, length=1)
            Table2.exists(height=2, length=1)
137
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
140