1import random
2import threading
3import time
4
5from sqlalchemy import create_engine
6from sqlalchemy import ForeignKey
7from sqlalchemy import Integer
8from sqlalchemy import MetaData
9from sqlalchemy import String
10from sqlalchemy import testing
11from sqlalchemy.ext.automap import automap_base
12from sqlalchemy.ext.automap import generate_relationship
13from sqlalchemy.orm import configure_mappers
14from sqlalchemy.orm import interfaces
15from sqlalchemy.orm import relationship
16from sqlalchemy.testing import fixtures
17from sqlalchemy.testing.mock import Mock
18from sqlalchemy.testing.mock import patch
19from sqlalchemy.testing.schema import Column
20from sqlalchemy.testing.schema import Table
21from ..orm._fixtures import FixtureTest
22
23
class AutomapTest(fixtures.MappedTest):
    """Automap class, naming and relationship generation tests.

    The tables come from ``FixtureTest.define_tables``; each test builds a
    fresh automap base over ``self.metadata`` and calls ``Base.prepare()``
    so classes and relationships are generated from that ``MetaData``.
    """

    @classmethod
    def define_tables(cls, metadata):
        # Reuse the standard ORM fixture schema (users, addresses, orders,
        # items, nodes, ...) as the source for automap.
        FixtureTest.define_tables(metadata)

    def test_relationship_o2m_default(self):
        """Default naming: the one-to-many side is ``<name>_collection`` and
        the many-to-one side is ``<name>``, where the names come from the
        generated class names (here equal to the table names)."""
        Base = automap_base(metadata=self.metadata)
        Base.prepare()

        User = Base.classes.users
        Address = Base.classes.addresses

        a1 = Address(email_address="e1")
        u1 = User(name="u1", addresses_collection=[a1])
        assert a1.users is u1

    def test_relationship_explicit_override_o2m(self):
        """An explicitly declared one-to-many relationship is used as-is
        instead of the automap-generated one."""
        Base = automap_base(metadata=self.metadata)
        prop = relationship("addresses", collection_class=set)

        class User(Base):
            __tablename__ = "users"

            addresses_collection = prop

        Base.prepare()
        assert User.addresses_collection.property is prop
        Address = Base.classes.addresses

        a1 = Address(email_address="e1")
        u1 = User(name="u1", addresses_collection={a1})
        # The reverse side is named "user" (not "users") because the
        # explicitly declared class is named "User".
        assert a1.user is u1

    def test_relationship_explicit_override_m2o(self):
        """An explicitly declared many-to-one relationship is used as-is
        instead of the automap-generated one."""
        Base = automap_base(metadata=self.metadata)

        prop = relationship("users")

        class Address(Base):
            __tablename__ = "addresses"

            users = prop

        Base.prepare()
        User = Base.classes.users

        assert Address.users.property is prop
        a1 = Address(email_address="e1")
        # The collection is "address_collection" because the declared class
        # is named "Address".
        u1 = User(name="u1", address_collection=[a1])
        assert a1.users is u1

    def test_relationship_self_referential(self):
        """A self-referential FK yields both a collection and a scalar
        relationship on the same class."""
        Base = automap_base(metadata=self.metadata)
        Base.prepare()

        Node = Base.classes.nodes

        n1 = Node()
        n2 = Node()
        n1.nodes_collection.append(n2)
        assert n2.nodes is n1

    def test_prepare_accepts_optional_schema_arg(self):
        """
        The underlying reflect call accepts an optional schema argument.
        This is for determining which database schema to load.
        This test verifies that prepare can accept an optional schema argument
        and pass it to reflect.
        """
        Base = automap_base(metadata=self.metadata)
        engine_mock = Mock()
        with patch.object(Base.metadata, "reflect") as reflect_mock:
            Base.prepare(engine_mock, reflect=True, schema="some_schema")
            reflect_mock.assert_called_once_with(
                engine_mock,
                schema="some_schema",
                extend_existing=True,
                autoload_replace=False,
            )

    def test_prepare_defaults_to_no_schema(self):
        """
        The underlying reflect call accepts an optional schema argument.
        This is for determining which database schema to load.
        This test verifies that prepare passes a default None if no schema is
        provided.
        """
        Base = automap_base(metadata=self.metadata)
        engine_mock = Mock()
        with patch.object(Base.metadata, "reflect") as reflect_mock:
            Base.prepare(engine_mock, reflect=True)
            reflect_mock.assert_called_once_with(
                engine_mock,
                schema=None,
                extend_existing=True,
                autoload_replace=False,
            )

    def test_naming_schemes(self):
        """Custom class-name and relationship-name hooks are honored."""
        Base = automap_base(metadata=self.metadata)

        def classname_for_table(base, tablename, table):
            return str("cls_" + tablename)

        def name_for_scalar_relationship(
            base, local_cls, referred_cls, constraint
        ):
            return "scalar_" + referred_cls.__name__

        def name_for_collection_relationship(
            base, local_cls, referred_cls, constraint
        ):
            return "coll_" + referred_cls.__name__

        Base.prepare(
            classname_for_table=classname_for_table,
            name_for_scalar_relationship=name_for_scalar_relationship,
            name_for_collection_relationship=name_for_collection_relationship,
        )

        User = Base.classes.cls_users
        Address = Base.classes.cls_addresses

        u1 = User()
        a1 = Address()
        # The relationship hooks receive the renamed classes, so the
        # "cls_" prefix shows up in the attribute names as well.
        u1.coll_cls_addresses.append(a1)
        assert a1.scalar_cls_users is u1

    def test_relationship_m2m(self):
        """An association table produces many-to-many collections on both
        sides."""
        Base = automap_base(metadata=self.metadata)

        Base.prepare()

        Order, Item = Base.classes.orders, Base.classes["items"]

        o1 = Order()
        i1 = Item()
        o1.items_collection.append(i1)
        assert o1 in i1.orders_collection

    def test_relationship_explicit_override_forwards_m2m(self):
        """An explicitly declared many-to-many relationship is kept, and
        automap still generates the reverse collection."""
        Base = automap_base(metadata=self.metadata)

        class Order(Base):
            __tablename__ = "orders"

            items_collection = relationship(
                "items", secondary="order_items", collection_class=set
            )

        Base.prepare()

        Item = Base.classes["items"]

        o1 = Order()
        i1 = Item()
        o1.items_collection.add(i1)

        # it is 'order_collection' because the class name is
        # "Order" !
        assert isinstance(i1.order_collection, list)
        assert o1 in i1.order_collection

    def test_relationship_pass_params(self):
        """A custom generate_relationship hook sees the expected
        (base, direction, attribute name) for each generated relationship."""
        Base = automap_base(metadata=self.metadata)

        mock = Mock()

        def _gen_relationship(
            base, direction, return_fn, attrname, local_cls, referred_cls, **kw
        ):
            mock(base, direction, attrname)
            return generate_relationship(
                base,
                direction,
                return_fn,
                attrname,
                local_cls,
                referred_cls,
                **kw
            )

        Base.prepare(generate_relationship=_gen_relationship)
        # Each entry in mock_calls is (name, args, kwargs); c[1] is the
        # positional-args tuple recorded above.
        assert set(tuple(c[1]) for c in mock.mock_calls).issuperset(
            [
                (Base, interfaces.MANYTOONE, "nodes"),
                (Base, interfaces.MANYTOMANY, "keywords_collection"),
                (Base, interfaces.MANYTOMANY, "items_collection"),
                (Base, interfaces.MANYTOONE, "users"),
                (Base, interfaces.ONETOMANY, "addresses_collection"),
            ]
        )
216
217
class CascadeTest(fixtures.MappedTest):
    """Checks how automap derives one-to-many cascade settings from the
    nullability and ON DELETE rule of each foreign key."""

    @classmethod
    def define_tables(cls, metadata):
        Table("a", metadata, Column("id", Integer, primary_key=True))

        # (child table, extra ForeignKey kwargs, nullable) -- every child
        # has an "aid" column pointing at a.id.
        child_specs = [
            ("b", {}, True),
            ("c", {}, False),
            ("d", {"ondelete": "cascade"}, False),
            ("e", {"ondelete": "set null"}, True),
        ]
        for tname, fk_kwargs, is_nullable in child_specs:
            Table(
                tname,
                metadata,
                Column("id", Integer, primary_key=True),
                Column(
                    "aid", ForeignKey("a.id", **fk_kwargs), nullable=is_nullable
                ),
            )

    def test_o2m_relationship_cascade(self):
        """Non-nullable FKs get delete/delete-orphan cascades; FKs with
        ON DELETE CASCADE or SET NULL get passive_deletes."""
        Base = automap_base(metadata=self.metadata)
        Base.prepare()

        configure_mappers()

        # (collection on "a", expect delete + delete-orphan,
        #  expect passive_deletes)
        expectations = [
            ("b_collection", False, False),
            ("c_collection", True, False),
            ("d_collection", True, True),
            ("e_collection", False, True),
        ]
        for collection_name, deletes, passive in expectations:
            rel_prop = getattr(Base.classes.a, collection_name).property
            assert bool(rel_prop.cascade.delete) is deletes
            assert bool(rel_prop.cascade.delete_orphan) is deletes
            assert bool(rel_prop.passive_deletes) is passive

            # save-update is expected on every generated collection
            assert rel_prop.cascade.save_update
284
285
class AutomapInhTest(fixtures.MappedTest):
    """Automap combined with explicitly declared inheritance hierarchies,
    reflecting the schema from ``testing.db``."""

    @classmethod
    def define_tables(cls, metadata):
        # Single-table inheritance target: one table with a "type"
        # discriminator column.
        Table(
            "single",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("type", String(10)),
            test_needs_fk=True,
        )

        # Joined-table inheritance targets: a base table, plus a subclass
        # table whose primary key is also a foreign key to the base.
        Table(
            "joined_base",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("type", String(10)),
            test_needs_fk=True,
        )

        Table(
            "joined_inh",
            metadata,
            Column(
                "id", Integer, ForeignKey("joined_base.id"), primary_key=True
            ),
            test_needs_fk=True,
        )

        FixtureTest.define_tables(metadata)

    def test_single_inheritance_reflect(self):
        """Declared single-table subclasses inherit the reflected base
        mapper."""
        Base = automap_base()

        class Single(Base):
            __tablename__ = "single"

            type = Column(String)

            __mapper_args__ = {
                "polymorphic_identity": "u0",
                "polymorphic_on": type,
            }

        class SubUser1(Single):
            __mapper_args__ = {"polymorphic_identity": "u1"}

        class SubUser2(Single):
            __mapper_args__ = {"polymorphic_identity": "u2"}

        Base.prepare(engine=testing.db, reflect=True)

        assert SubUser2.__mapper__.inherits is Single.__mapper__

    def test_joined_inheritance_reflect(self):
        """A declared joined-table subclass inherits the base mapper, and the
        inheriting FK is not also turned into a relationship."""
        Base = automap_base()

        class Joined(Base):
            __tablename__ = "joined_base"

            type = Column(String)

            __mapper_args__ = {
                "polymorphic_identity": "u0",
                "polymorphic_on": type,
            }

        class SubJoined(Joined):
            __tablename__ = "joined_inh"
            __mapper_args__ = {"polymorphic_identity": "u1"}

        Base.prepare(engine=testing.db, reflect=True)

        assert SubJoined.__mapper__.inherits is Joined.__mapper__

        assert not Joined.__mapper__.relationships
        assert not SubJoined.__mapper__.relationships

    def test_conditional_relationship(self):
        """A generate_relationship hook that returns None suppresses
        relationship generation instead of failing."""
        Base = automap_base()

        def _gen_relationship(*arg, **kw):
            return None

        Base.prepare(
            engine=testing.db,
            reflect=True,
            generate_relationship=_gen_relationship,
        )

        # users/addresses are linked by a foreign key; with every
        # relationship suppressed, neither side should have one mapped.
        assert not Base.classes.users.__mapper__.relationships
        assert not Base.classes.addresses.__mapper__.relationships
374
375
class ConcurrentAutomapTest(fixtures.TestBase):
    """Runs many automap + configure_mappers() cycles concurrently.

    Each thread creates its own in-memory SQLite engine and schema.
    """

    __only_on__ = "sqlite"

    def _make_tables(self, e):
        """(Re)create fifteen tables ``table_0`` .. ``table_14`` on engine
        ``e``; tables 5-14 each carry a foreign key to the preceding table."""
        m = MetaData()
        for i in range(15):
            columns = [
                Column("id", Integer, primary_key=True),
                Column("data", String(50)),
            ]
            if i > 4:
                # chain table_i -> table_(i - 1) so automap has FKs to follow
                columns.append(
                    Column(
                        "t_%d_id" % (i - 1),
                        ForeignKey("table_%d.id" % (i - 1)),
                    )
                )
            Table("table_%d" % i, m, *columns)
        m.drop_all(e)
        m.create_all(e)

    def _automap(self, e):
        """Reflect ``e`` into a fresh automap base, then configure mappers."""
        Base = automap_base()

        Base.prepare(e, reflect=True)

        # presumably a short pause so other threads' prepare/configure
        # steps interleave with this one
        time.sleep(0.01)
        configure_mappers()

    def _chaos(self):
        """Thread target: build a private database and automap it twice.

        Any failure, including during engine/table setup or disposal, sets
        ``self._success`` to False before being re-raised.
        """
        try:
            e = create_engine("sqlite://")
            try:
                self._make_tables(e)
                for _ in range(2):
                    self._automap(e)
                    time.sleep(random.random())
            finally:
                e.dispose()
        except BaseException:
            # An exception raised in a worker thread does not propagate to
            # the thread that join()s it, so record the failure for the
            # assertion in the test method.
            self._success = False
            raise

    def test_concurrent_automaps_w_configure(self):
        self._success = True
        threads = [threading.Thread(target=self._chaos) for _ in range(30)]
        for t in threads:
            t.start()

        for t in threads:
            t.join()

        assert self._success, "One or more threads failed"
428