1from sqlalchemy import ForeignKey
2from sqlalchemy import Integer
3from sqlalchemy import String
4from sqlalchemy.orm import create_session
5from sqlalchemy.orm import mapper
6from sqlalchemy.orm import polymorphic_union
7from sqlalchemy.orm import relationship
8from sqlalchemy.testing import AssertsCompiledSQL
9from sqlalchemy.testing import config
10from sqlalchemy.testing import fixtures
11from sqlalchemy.testing.schema import Column
12from sqlalchemy.testing.schema import Table
13
14
class Company(fixtures.ComparableEntity):
    """Company test entity; the polymorphic fixtures in this module map it
    to the ``companies`` table with an ordered ``employees`` relationship."""
17
18
class Person(fixtures.ComparableEntity):
    """Root of the Person -> Engineer / Manager -> Boss test hierarchy;
    mapped polymorphically onto the ``people`` table by the fixtures in
    this module."""
21
22
class Engineer(Person):
    """Person subclass mapped with polymorphic identity ``"engineer"`` onto
    the ``engineers`` table by the fixtures in this module."""
25
26
class Manager(Person):
    """Person subclass mapped with polymorphic identity ``"manager"`` onto
    the ``managers`` table by the fixtures in this module."""
29
30
class Boss(Manager):
    """Manager subclass mapped with polymorphic identity ``"boss"`` onto
    the ``boss`` table by the fixtures in this module."""
33
34
class Machine(fixtures.ComparableEntity):
    """Machine test entity; mapped to the ``machines`` table and reached
    through ``Engineer.machines`` in the fixtures of this module."""
37
38
class MachineType(fixtures.ComparableEntity):
    """Comparable test entity; not mapped by the fixtures in this module
    (presumably mapped by individual tests -- verify at the use sites)."""
41
42
class Paperwork(fixtures.ComparableEntity):
    """Paperwork test entity; mapped to the ``paperwork`` table and reached
    through ``Person.paperwork`` in the fixtures of this module."""
45
46
class Page(fixtures.ComparableEntity):
    """Comparable test entity; not mapped by the fixtures in this module
    (presumably mapped by individual tests -- verify at the use sites)."""
49
50
class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL):
    """Shared fixture for the Company / Person / Engineer / Manager / Boss
    inheritance tests.

    Defines the tables, inserts one fixed data set, and maps the classes
    with joined-table inheritance rooted at the ``people`` table.

    Subclasses must provide a ``_get_polymorphics()`` classmethod. It
    returns a 2-tuple of ``with_polymorphic`` arguments, used for the
    Person mapper and the Manager mapper respectively (see
    ``setup_mappers``). This class does not define it itself.
    """

    # Lifecycle flags consumed by fixtures.MappedTest. Presumably they
    # mean: insert the data and configure the mappers once per test class,
    # and never delete rows between tests. Confirm against
    # sqlalchemy.testing.fixtures.
    run_inserts = "once"
    run_setup_mappers = "once"
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        """Create the fixture's Table objects on ``metadata``.

        Each table is also bound to a module-level global of the same
        name, so that ``setup_mappers`` can refer to them directly.
        """
        global people, engineers, managers, boss
        global companies, paperwork, machines

        companies = Table(
            "companies",
            metadata,
            Column(
                "company_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("name", String(50)),
        )

        # Base table of the Person hierarchy. "type" is the discriminator
        # column passed as polymorphic_on in setup_mappers.
        people = Table(
            "people",
            metadata,
            Column(
                "person_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("company_id", Integer, ForeignKey("companies.company_id")),
            Column("name", String(50)),
            Column("type", String(30)),
        )

        # Joined-inheritance table for Engineer; its PK is also an FK to
        # people.
        engineers = Table(
            "engineers",
            metadata,
            Column(
                "person_id",
                Integer,
                ForeignKey("people.person_id"),
                primary_key=True,
            ),
            Column("status", String(30)),
            Column("engineer_name", String(50)),
            Column("primary_language", String(50)),
        )

        # Machines belong to engineers (not to people in general).
        machines = Table(
            "machines",
            metadata,
            Column(
                "machine_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("name", String(50)),
            Column("engineer_id", Integer, ForeignKey("engineers.person_id")),
        )

        # Joined-inheritance table for Manager.
        managers = Table(
            "managers",
            metadata,
            Column(
                "person_id",
                Integer,
                ForeignKey("people.person_id"),
                primary_key=True,
            ),
            Column("status", String(30)),
            Column("manager_name", String(50)),
        )

        # Joined-inheritance table for Boss, extending managers. Note that
        # its PK column is named "boss_id" rather than "person_id".
        boss = Table(
            "boss",
            metadata,
            Column(
                "boss_id",
                Integer,
                ForeignKey("managers.person_id"),
                primary_key=True,
            ),
            Column("golf_swing", String(30)),
        )

        # Paperwork can belong to any person.
        paperwork = Table(
            "paperwork",
            metadata,
            Column(
                "paperwork_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("description", String(50)),
            Column("person_id", Integer, ForeignKey("people.person_id")),
        )

    @classmethod
    def insert_data(cls):
        """Persist the canonical data set and keep references to it on cls.

        Two companies are created:

        * "MegaCorp, Inc.", employing dilbert and wally (Engineer),
          pointy haired boss (Boss) and dogbert (Manager);
        * "Elbonia, Inc.", employing vlad (Engineer).

        Every person has paperwork and every engineer has machines. After
        the flush the session is cleared with ``expunge_all()``, so the
        objects stored on ``cls`` are no longer attached to a session.
        """

        cls.e1 = e1 = Engineer(
            name="dilbert",
            engineer_name="dilbert",
            primary_language="java",
            status="regular engineer",
            paperwork=[
                Paperwork(description="tps report #1"),
                Paperwork(description="tps report #2"),
            ],
            machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")],
        )

        cls.e2 = e2 = Engineer(
            name="wally",
            engineer_name="wally",
            primary_language="c++",
            status="regular engineer",
            paperwork=[
                Paperwork(description="tps report #3"),
                Paperwork(description="tps report #4"),
            ],
            machines=[Machine(name="Commodore 64")],
        )

        cls.b1 = b1 = Boss(
            name="pointy haired boss",
            golf_swing="fore",
            manager_name="pointy",
            status="da boss",
            paperwork=[Paperwork(description="review #1")],
        )

        cls.m1 = m1 = Manager(
            name="dogbert",
            manager_name="dogbert",
            status="regular manager",
            paperwork=[
                Paperwork(description="review #2"),
                Paperwork(description="review #3"),
            ],
        )

        cls.e3 = e3 = Engineer(
            name="vlad",
            engineer_name="vlad",
            primary_language="cobol",
            status="elbonian engineer",
            paperwork=[Paperwork(description="elbonian missive #3")],
            machines=[Machine(name="Commodore 64"), Machine(name="IBM 3270")],
        )

        # NOTE(review): the order of these lists and of the adds below
        # presumably drives flush order, and therefore the generated
        # person_id values. Company.employees is ordered by person_id, so
        # tests likely depend on this order -- confirm before reordering.
        cls.c1 = c1 = Company(name="MegaCorp, Inc.")
        c1.employees = [e1, e2, b1, m1]
        cls.c2 = c2 = Company(name="Elbonia, Inc.")
        c2.employees = [e3]

        sess = create_session()
        sess.add(c1)
        sess.add(c2)
        sess.flush()
        sess.expunge_all()

        # Expected-result lists, in the same order as the data above.
        cls.all_employees = [e1, e2, b1, m1, e3]
        cls.c1_employees = [e1, e2, b1, m1]
        cls.c2_employees = [e3]

    def _company_with_emps_machines_fixture(self):
        """Return ``_company_with_emps_fixture()`` with machines populated.

        Machines are set only for the first two MegaCorp employees
        (dilbert and wally), matching ``insert_data``. Vlad's machines
        are not populated here.
        """
        fixture = self._company_with_emps_fixture()
        fixture[0].employees[0].machines = [
            Machine(name="IBM ThinkPad"),
            Machine(name="IPhone"),
        ]
        fixture[0].employees[1].machines = [Machine(name="Commodore 64")]
        return fixture

    def _company_with_emps_fixture(self):
        """Return new, unpersisted Company objects mirroring insert_data.

        The companies and their employees match ``insert_data``, in the
        same order, but carry no paperwork or machines. They are
        presumably used as expected values for comparison with query
        results.
        """
        return [
            Company(
                name="MegaCorp, Inc.",
                employees=[
                    Engineer(
                        name="dilbert",
                        engineer_name="dilbert",
                        primary_language="java",
                        status="regular engineer",
                    ),
                    Engineer(
                        name="wally",
                        engineer_name="wally",
                        primary_language="c++",
                        status="regular engineer",
                    ),
                    Boss(
                        name="pointy haired boss",
                        golf_swing="fore",
                        manager_name="pointy",
                        status="da boss",
                    ),
                    Manager(
                        name="dogbert",
                        manager_name="dogbert",
                        status="regular manager",
                    ),
                ],
            ),
            Company(
                name="Elbonia, Inc.",
                employees=[
                    Engineer(
                        name="vlad",
                        engineer_name="vlad",
                        primary_language="cobol",
                        status="elbonian engineer",
                    )
                ],
            ),
        ]

    def _emps_wo_relationships_fixture(self):
        """Return new, unpersisted copies of all five employees.

        The order matches ``all_employees`` from ``insert_data``. No
        paperwork or machines are attached.
        """
        return [
            Engineer(
                name="dilbert",
                engineer_name="dilbert",
                primary_language="java",
                status="regular engineer",
            ),
            Engineer(
                name="wally",
                engineer_name="wally",
                primary_language="c++",
                status="regular engineer",
            ),
            Boss(
                name="pointy haired boss",
                golf_swing="fore",
                manager_name="pointy",
                status="da boss",
            ),
            Manager(
                name="dogbert",
                manager_name="dogbert",
                status="regular manager",
            ),
            Engineer(
                name="vlad",
                engineer_name="vlad",
                primary_language="cobol",
                status="elbonian engineer",
            ),
        ]

    @classmethod
    def setup_mappers(cls):
        """Map the fixture classes onto the tables from define_tables.

        ``cls._get_polymorphics()`` must be supplied by a subclass. It
        provides the ``with_polymorphic`` setting for the Person and
        Manager mappers, so each subclass maps the same hierarchy with a
        different polymorphic selectable.
        """
        mapper(
            Company,
            companies,
            properties={
                "employees": relationship(Person, order_by=people.c.person_id)
            },
        )

        mapper(Machine, machines)

        person_with_polymorphic, manager_with_polymorphic = (
            cls._get_polymorphics()
        )

        # Polymorphic base: rows are discriminated by people.type.
        mapper(
            Person,
            people,
            with_polymorphic=person_with_polymorphic,
            polymorphic_on=people.c.type,
            polymorphic_identity="person",
            properties={
                "paperwork": relationship(
                    Paperwork, order_by=paperwork.c.paperwork_id
                )
            },
        )

        mapper(
            Engineer,
            engineers,
            inherits=Person,
            polymorphic_identity="engineer",
            properties={
                "machines": relationship(
                    Machine, order_by=machines.c.machine_id
                )
            },
        )

        # Manager gets its own with_polymorphic setting, covering Boss.
        mapper(
            Manager,
            managers,
            with_polymorphic=manager_with_polymorphic,
            inherits=Person,
            polymorphic_identity="manager",
        )

        mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss")

        mapper(Paperwork, paperwork)
358
359
class _Polymorphic(_PolymorphicFixtureBase):
    """Fixture variant that passes no ``with_polymorphic`` setting to the
    Person and Manager mappers."""

    select_type = ""

    @classmethod
    def _get_polymorphics(cls):
        """Return ``(person_setting, manager_setting)``; both are ``None``
        here, i.e. ``with_polymorphic`` is left unset for both mappers."""
        unset = None
        return unset, unset
366
367
class _PolymorphicPolymorphic(_PolymorphicFixtureBase):
    """Fixture variant that requests ``with_polymorphic="*"`` (all
    subclasses) for both the Person and Manager mappers."""

    select_type = "Polymorphic"

    @classmethod
    def _get_polymorphics(cls):
        """Return ``(person_setting, manager_setting)``, both ``"*"``."""
        all_subclasses = "*"
        return all_subclasses, all_subclasses
374
375
class _PolymorphicUnions(_PolymorphicFixtureBase):
    """Fixture variant loading Person through a ``polymorphic_union`` of the
    engineer and manager joins, and Manager through a plain join to boss."""

    select_type = "Unions"

    @classmethod
    def _get_polymorphics(cls):
        """Return ``(person_setting, manager_setting)``.

        * Person: ``[Person, Manager, Engineer]`` over a UNION, aliased
          ``pjoin``, of ``people JOIN engineers`` and
          ``people JOIN managers``.
        * Manager: ``"*"`` over ``people JOIN managers LEFT OUTER JOIN
          boss``.
        """
        tables = cls.tables
        people, managers = tables.people, tables.managers

        pjoin = polymorphic_union(
            {
                "engineer": people.join(tables.engineers),
                "manager": people.join(managers),
            },
            None,
            "pjoin",
        )
        mjoin = people.join(managers).outerjoin(tables.boss)

        return ([Person, Manager, Engineer], pjoin), ("*", mjoin)
399
400
class _PolymorphicAliasedJoins(_PolymorphicFixtureBase):
    """Fixture variant loading Person and Manager through joins that are
    wrapped as labeled, aliased subqueries."""

    select_type = "AliasedJoins"

    @classmethod
    def _get_polymorphics(cls):
        """Return ``(person_setting, manager_setting)``.

        * Person: ``[Person, Manager, Engineer]`` over the labeled subquery
          ``pjoin`` of ``people LEFT OUTER JOIN engineers LEFT OUTER JOIN
          managers``.
        * Manager: ``"*"`` over the labeled subquery ``mjoin`` of
          ``people JOIN managers LEFT OUTER JOIN boss``.
        """
        tables = cls.tables
        people, managers = tables.people, tables.managers

        pjoin = (
            people.outerjoin(tables.engineers)
            .outerjoin(managers)
            .select(use_labels=True)
            .alias("pjoin")
        )
        mjoin = (
            people.join(managers)
            .outerjoin(tables.boss)
            .select(use_labels=True)
            .alias("mjoin")
        )

        return ([Person, Manager, Engineer], pjoin), ("*", mjoin)
427
428
class _PolymorphicJoins(_PolymorphicFixtureBase):
    """Fixture variant loading Person and Manager through plain
    (non-aliased) join selectables."""

    select_type = "Joins"

    @classmethod
    def _get_polymorphics(cls):
        """Return ``(person_setting, manager_setting)``.

        * Person: ``[Person, Manager, Engineer]`` over ``people LEFT OUTER
          JOIN engineers LEFT OUTER JOIN managers``.
        * Manager: ``"*"`` over ``people JOIN managers LEFT OUTER JOIN
          boss``.
        """
        tables = cls.tables
        people, managers = tables.people, tables.managers

        pjoin = people.outerjoin(tables.engineers).outerjoin(managers)
        mjoin = people.join(managers).outerjoin(tables.boss)

        return ([Person, Manager, Engineer], pjoin), ("*", mjoin)
445
446
class GeometryFixtureBase(fixtures.DeclarativeMappedTest):
    """Provides arbitrary inheritance hierarchies based on a dictionary
    structure.

    Each key of the dictionary names a class. Unless ``"single"`` is set,
    the key also names that class's table. The recognized keys in each
    value dictionary are:

    * ``"subclasses"`` - a nested dictionary of the same form, describing
      classes that inherit from this one;
    * ``"single"`` - when true, the subclass uses single-table inheritance
      and gets no ``__tablename__`` or ``id`` of its own (ignored on a
      root class);
    * ``"polymorphic_load"`` - copied into the class's
      ``__mapper_args__``.

    Every class also gets a ``<name>_data`` string column. Root classes
    additionally mix in ``fixtures.ComparableEntity``.

    e.g.::

        self._fixture_from_geometry(
            {
                "a": {
                    "subclasses": {
                        "b": {"polymorphic_load": "selectin"},
                        "c": {
                            "subclasses": {
                                "d": {
                                    "polymorphic_load": "inline",
                                    "single": True,
                                },
                                "e": {
                                    "polymorphic_load": "inline",
                                    "single": True,
                                },
                            },
                            "polymorphic_load": "selectin",
                        },
                    }
                }
            }
        )

    would provide the equivalent of::

        class a(Base):
            __tablename__ = 'a'

            id = Column(Integer, primary_key=True)
            a_data = Column(String(50))
            type = Column(String(50))
            __mapper_args__ = {
                "polymorphic_on": type,
                "polymorphic_identity": "a"
            }

        class b(a):
            __tablename__ = 'b'

            id = Column(ForeignKey('a.id'), primary_key=True)
            b_data = Column(String(50))

            __mapper_args__ = {
                "polymorphic_identity": "b",
                "polymorphic_load": "selectin"
            }

            # ...

        class c(a):
            __tablename__ = 'c'

        class d(c):
            # ...

        class e(c):
            # ...

    Declarative is used so that we get extra behaviors of declarative,
    such as single-inheritance column masking.

    """

    run_create_tables = "each"
    run_define_tables = "each"
    run_setup_classes = "each"
    run_setup_mappers = "each"

    def _fixture_from_geometry(self, geometry, base=None):
        """Declare the classes described by ``geometry``.

        :param geometry: mapping of class name to option dictionary, as
         described in the class docstring.
        :param base: the parent class for every class in ``geometry``.
         When omitted (``None``), the classes become new hierarchy roots
         built on ``self.DeclarativeBasic``. In that case, once the whole
         structure has been declared, the resulting tables are registered
         in ``self.tables`` and created in the test database.
        """
        # Use an explicit None check for the "no base given" default.
        is_base = base is None
        if is_base:
            base = self.DeclarativeBasic

        for key, value in geometry.items():
            if is_base:
                # A root class owns the table, the primary key and the
                # "type" discriminator column.
                type_ = Column(String(50))
                items = {
                    "__tablename__": key,
                    "id": Column(Integer, primary_key=True),
                    "type": type_,
                    "__mapper_args__": {
                        "polymorphic_on": type_,
                        "polymorphic_identity": key,
                    },
                }
            else:
                items = {"__mapper_args__": {"polymorphic_identity": key}}

                if not value.get("single", False):
                    # Joined-table inheritance: the class gets its own
                    # table, whose primary key references the parent's id.
                    items["__tablename__"] = key
                    items["id"] = Column(
                        ForeignKey("%s.id" % base.__tablename__),
                        primary_key=True,
                    )

            # Every class contributes its own "<name>_data" column.
            items["%s_data" % key] = Column(String(50))

            # add other mapper options to be transferred here as needed.
            for mapper_opt in ("polymorphic_load",):
                if mapper_opt in value:
                    items["__mapper_args__"][mapper_opt] = value[mapper_opt]

            if is_base:
                klass = type(key, (fixtures.ComparableEntity, base), items)
            else:
                klass = type(key, (base,), items)

            if "subclasses" in value:
                self._fixture_from_geometry(value["subclasses"], klass)

        # Only the outermost (root-level) call creates tables, after every
        # class in the hierarchy has been declared by the recursion above.
        if is_base and self.metadata.tables and self.run_create_tables:
            self.tables.update(self.metadata.tables)
            self.metadata.create_all(config.db)
564