1from datetime import datetime
2
3from sqlalchemy import DateTime
4from sqlalchemy import Float
5from sqlalchemy import ForeignKey
6from sqlalchemy import Integer
7from sqlalchemy import LargeBinary
8from sqlalchemy import String
9from sqlalchemy.orm import backref
10from sqlalchemy.orm import create_session
11from sqlalchemy.orm import deferred
12from sqlalchemy.orm import mapper
13from sqlalchemy.orm import relationship
14from sqlalchemy.testing import fixtures
15from sqlalchemy.testing.schema import Column
16from sqlalchemy.testing.schema import Table
17
18
class InheritTest(fixtures.MappedTest):
    """Round-trip tests for a particular set of polymorphic inheritance
    relationships between products, specification lines and documents.

    ``Detail`` and ``Assembly`` are mapped onto the ``products`` table of
    their base ``Product`` and ``RasterDocument`` onto the ``documents``
    table of ``Document`` (no table of their own).  Each test persists a
    small object graph, flushes, expunges the session, reloads the graph
    and compares ``repr()`` output taken before and after the round trip.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Create the tables and plain Python classes used by every test.

        The tables and classes are bound as module globals so the test
        methods can map and use them directly by name.
        """
        global products_table, specification_table, documents_table
        global Product, Detail, Assembly, SpecLine, Document, RasterDocument

        # Shared by Product and its subclasses; ``product_type`` is used as
        # the polymorphic discriminator column by every test.
        products_table = Table(
            "products",
            metadata,
            Column(
                "product_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("product_type", String(128)),
            Column("name", String(128)),
            Column("mark", String(128)),
        )

        # Two separate foreign keys point at ``products`` (master and
        # slave), so relationships over this table must name their
        # foreign key and join condition explicitly.
        specification_table = Table(
            "specification",
            metadata,
            Column(
                "spec_line_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column(
                "master_id",
                Integer,
                ForeignKey("products.product_id"),
                nullable=True,
            ),
            Column(
                "slave_id",
                Integer,
                ForeignKey("products.product_id"),
                nullable=True,
            ),
            Column("quantity", Float, default=1.0),
        )

        # Shared by Document and RasterDocument; ``document_type`` is the
        # polymorphic discriminator.  Timestamps come from Python-side
        # column defaults / onupdate callables.
        documents_table = Table(
            "documents",
            metadata,
            Column(
                "document_id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("document_type", String(128)),
            Column("product_id", Integer, ForeignKey("products.product_id")),
            Column("create_date", DateTime, default=lambda: datetime.now()),
            Column(
                "last_updated",
                DateTime,
                default=lambda: datetime.now(),
                onupdate=lambda: datetime.now(),
            ),
            Column("name", String(128)),
            Column("data", LargeBinary),
            Column("size", Integer, default=0),
        )

        class Product(object):
            def __init__(self, name, mark=""):
                self.name = name
                self.mark = mark

            def __repr__(self):
                return "<%s %s>" % (self.__class__.__name__, self.name)

        class Detail(Product):
            # NOTE(review): unlike Product.__init__, ``mark`` is left unset
            # here; no test below depends on it.
            def __init__(self, name):
                self.name = name

        class Assembly(Product):
            def __repr__(self):
                # ``specification`` / ``documents`` only exist once the
                # matching relationships are mapped; getattr's default makes
                # an unmapped one render as ``None``.
                return (
                    Product.__repr__(self)
                    + " "
                    + " ".join(
                        [
                            x + "=" + repr(getattr(self, x, None))
                            for x in ["specification", "documents"]
                        ]
                    )
                )

        class SpecLine(object):
            def __init__(self, master=None, slave=None, quantity=1):
                self.master = master
                self.slave = slave
                self.quantity = quantity

            def __repr__(self):
                return "<%s %.01f %s>" % (
                    self.__class__.__name__,
                    self.quantity or 0.0,
                    repr(self.slave),
                )

        class Document(object):
            def __init__(self, name, data=None):
                self.name = name
                self.data = data

            def __repr__(self):
                return "<%s %s>" % (self.__class__.__name__, self.name)

        class RasterDocument(Document):
            pass

    def test_one(self):
        """An Assembly whose ``specification`` (backref of SpecLine.master)
        holds a Product and a Detail survives a flush/reload with lazy
        loading.  ``documents`` is not mapped here, so it renders as None.
        """
        product_mapper = mapper(
            Product,
            products_table,
            polymorphic_on=products_table.c.product_type,
            polymorphic_identity="product",
        )

        # Subclasses without a table of their own share ``products``.
        mapper(
            Detail, inherits=product_mapper, polymorphic_identity="detail"
        )

        mapper(
            Assembly, inherits=product_mapper, polymorphic_identity="assembly"
        )

        mapper(
            SpecLine,
            specification_table,
            properties=dict(
                master=relationship(
                    Assembly,
                    foreign_keys=[specification_table.c.master_id],
                    primaryjoin=specification_table.c.master_id
                    == products_table.c.product_id,
                    lazy="select",
                    backref=backref("specification"),
                    uselist=False,
                ),
                slave=relationship(
                    Product,
                    foreign_keys=[specification_table.c.slave_id],
                    primaryjoin=specification_table.c.slave_id
                    == products_table.c.product_id,
                    lazy="select",
                    uselist=False,
                ),
                quantity=specification_table.c.quantity,
            ),
        )

        session = create_session()

        a1 = Assembly(name="a1")

        p1 = Product(name="p1")
        a1.specification.append(SpecLine(slave=p1))

        d1 = Detail(name="d1")
        a1.specification.append(SpecLine(slave=d1))

        session.add(a1)
        orig = repr(a1)
        session.flush()
        session.expunge_all()

        # Querying the base class must return the Assembly subclass.
        a1 = session.query(Product).filter_by(name="a1").one()
        new = repr(a1)
        print(orig)
        print(new)
        assert (
            orig == new == "<Assembly a1> specification=[<SpecLine 1.0 "
            "<Product p1>>, <SpecLine 1.0 <Detail d1>>] documents=None"
        )

    def test_two(self):
        """With only SpecLine.slave mapped, SpecLines pointing at a Product
        and at a Detail reload with the correct polymorphic slave classes.
        """
        product_mapper = mapper(
            Product,
            products_table,
            polymorphic_on=products_table.c.product_type,
            polymorphic_identity="product",
        )

        mapper(
            Detail, inherits=product_mapper, polymorphic_identity="detail"
        )

        mapper(
            SpecLine,
            specification_table,
            properties=dict(
                slave=relationship(
                    Product,
                    foreign_keys=[specification_table.c.slave_id],
                    primaryjoin=specification_table.c.slave_id
                    == products_table.c.product_id,
                    lazy="select",
                    uselist=False,
                )
            ),
        )

        session = create_session()

        s = SpecLine(slave=Product(name="p1"))
        s2 = SpecLine(slave=Detail(name="d1"))
        session.add(s)
        session.add(s2)
        orig = repr([s, s2])
        session.flush()
        session.expunge_all()
        new = repr(session.query(SpecLine).all())
        print(orig)
        print(new)
        assert (
            orig == new == "[<SpecLine 1.0 <Product p1>>, "
            "<SpecLine 1.0 <Detail d1>>]"
        )

    def test_three(self):
        """Like test_one but with joined eager loading of master/slave, a
        delete-orphan cascade on ``specification``, and a polymorphic
        Document/RasterDocument collection attached to the Assembly.
        """
        product_mapper = mapper(
            Product,
            products_table,
            polymorphic_on=products_table.c.product_type,
            polymorphic_identity="product",
        )
        mapper(
            Detail, inherits=product_mapper, polymorphic_identity="detail"
        )
        mapper(
            Assembly, inherits=product_mapper, polymorphic_identity="assembly"
        )

        mapper(
            SpecLine,
            specification_table,
            properties=dict(
                master=relationship(
                    Assembly,
                    lazy="joined",
                    uselist=False,
                    foreign_keys=[specification_table.c.master_id],
                    primaryjoin=specification_table.c.master_id
                    == products_table.c.product_id,
                    backref=backref(
                        "specification", cascade="all, delete-orphan"
                    ),
                ),
                slave=relationship(
                    Product,
                    lazy="joined",
                    uselist=False,
                    foreign_keys=[specification_table.c.slave_id],
                    primaryjoin=specification_table.c.slave_id
                    == products_table.c.product_id,
                ),
                quantity=specification_table.c.quantity,
            ),
        )

        document_mapper = mapper(
            Document,
            documents_table,
            polymorphic_on=documents_table.c.document_type,
            polymorphic_identity="document",
            properties=dict(
                name=documents_table.c.name,
                # the binary payload is only loaded on attribute access
                data=deferred(documents_table.c.data),
                product=relationship(
                    Product,
                    lazy="select",
                    backref=backref("documents", cascade="all, delete-orphan"),
                ),
            ),
        )
        mapper(
            RasterDocument,
            inherits=document_mapper,
            polymorphic_identity="raster_document",
        )

        session = create_session()

        a1 = Assembly(name="a1")
        a1.specification.append(SpecLine(slave=Detail(name="d1")))
        a1.documents.append(Document("doc1"))
        a1.documents.append(RasterDocument("doc2"))
        session.add(a1)
        orig = repr(a1)
        session.flush()
        session.expunge_all()

        a1 = session.query(Product).filter_by(name="a1").one()
        new = repr(a1)
        print(orig)
        print(new)
        assert (
            orig == new == "<Assembly a1> specification="
            "[<SpecLine 1.0 <Detail d1>>] "
            "documents=[<Document doc1>, <RasterDocument doc2>]"
        )

    def test_four(self):
        """Attach only a RasterDocument (and *not* a base Document) to the
        Assembly, so the flush involves only a "sub-class" task, i.e. one
        corresponding to the inheriting mapper but not the base mapper.

        The document is then removed from the collection and the
        delete-orphan cascade must delete it.
        """

        product_mapper = mapper(
            Product,
            products_table,
            polymorphic_on=products_table.c.product_type,
            polymorphic_identity="product",
        )
        mapper(
            Detail, inherits=product_mapper, polymorphic_identity="detail"
        )
        mapper(
            Assembly, inherits=product_mapper, polymorphic_identity="assembly"
        )

        document_mapper = mapper(
            Document,
            documents_table,
            polymorphic_on=documents_table.c.document_type,
            polymorphic_identity="document",
            properties=dict(
                name=documents_table.c.name,
                data=deferred(documents_table.c.data),
                product=relationship(
                    Product,
                    lazy="select",
                    backref=backref("documents", cascade="all, delete-orphan"),
                ),
            ),
        )
        mapper(
            RasterDocument,
            inherits=document_mapper,
            polymorphic_identity="raster_document",
        )

        session = create_session()

        a1 = Assembly(name="a1")
        a1.documents.append(RasterDocument("doc2"))
        session.add(a1)
        orig = repr(a1)
        session.flush()
        session.expunge_all()

        a1 = session.query(Product).filter_by(name="a1").one()
        new = repr(a1)
        print(orig)
        print(new)
        assert (
            orig == new == "<Assembly a1> specification=None documents="
            "[<RasterDocument doc2>]"
        )

        # Removing the only document orphans it; the cascade deletes it.
        del a1.documents[0]
        session.flush()
        session.expunge_all()

        # Both the reloaded parent's collection and the table are empty.
        a1 = session.query(Product).filter_by(name="a1").one()
        assert a1.documents == []
        assert session.query(Document).all() == []

    def test_five(self):
        """Tests the late compilation of mappers.

        Mappers are declared in an order where relationships refer to
        classes that are not mapped yet (SpecLine before Product/Assembly,
        Product.documents before Document), and inheritance is given as the
        parent class rather than its mapper.  Everything must still resolve
        once the mappings are used.
        """

        mapper(
            SpecLine,
            specification_table,
            properties=dict(
                master=relationship(
                    Assembly,
                    lazy="joined",
                    uselist=False,
                    foreign_keys=[specification_table.c.master_id],
                    primaryjoin=specification_table.c.master_id
                    == products_table.c.product_id,
                    backref=backref("specification"),
                ),
                slave=relationship(
                    Product,
                    lazy="joined",
                    uselist=False,
                    foreign_keys=[specification_table.c.slave_id],
                    primaryjoin=specification_table.c.slave_id
                    == products_table.c.product_id,
                ),
                quantity=specification_table.c.quantity,
            ),
        )

        mapper(
            Product,
            products_table,
            polymorphic_on=products_table.c.product_type,
            polymorphic_identity="product",
            properties={
                "documents": relationship(
                    Document,
                    lazy="select",
                    backref="product",
                    cascade="all, delete-orphan",
                )
            },
        )

        mapper(Detail, inherits=Product, polymorphic_identity="detail")

        mapper(
            Document,
            documents_table,
            polymorphic_on=documents_table.c.document_type,
            polymorphic_identity="document",
            properties=dict(
                name=documents_table.c.name,
                data=deferred(documents_table.c.data),
            ),
        )

        mapper(
            RasterDocument,
            inherits=Document,
            polymorphic_identity="raster_document",
        )

        mapper(Assembly, inherits=Product, polymorphic_identity="assembly")

        session = create_session()

        a1 = Assembly(name="a1")
        a1.specification.append(SpecLine(slave=Detail(name="d1")))
        a1.documents.append(Document("doc1"))
        a1.documents.append(RasterDocument("doc2"))
        session.add(a1)
        orig = repr(a1)
        session.flush()
        session.expunge_all()

        a1 = session.query(Product).filter_by(name="a1").one()
        new = repr(a1)
        print(orig)
        print(new)
        assert (
            orig == new == "<Assembly a1> specification="
            "[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, "
            "<RasterDocument doc2>]"
        )
483