1import pytest
2import sqlalchemy as sa
3
4from sqlalchemy_utils import get_mapper
5from sqlalchemy_utils.functions.orm import _get_query_compile_state
6
7
class TestGetMapper(object):
    """Every supported kind of mapped construct resolves to the class mapper."""

    @pytest.fixture
    def Building(self, Base):
        """Declare a minimal single-table model on the shared ``Base``."""
        class Building(Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)
        return Building

    @staticmethod
    def _assert_resolves_to(obj, model):
        """Assert ``get_mapper(obj)`` equals the mapper inspected from *model*."""
        # Call get_mapper first, mirroring the left-to-right order of the
        # original inline comparisons.
        resolved = get_mapper(obj)
        assert resolved == sa.inspect(model)

    def test_table(self, Building):
        self._assert_resolves_to(Building.__table__, Building)

    def test_declarative_class(self, Building):
        self._assert_resolves_to(Building, Building)

    def test_declarative_object(self, Building):
        instance = Building()
        self._assert_resolves_to(instance, Building)

    def test_mapper(self, Building):
        self._assert_resolves_to(Building.__mapper__, Building)

    def test_class_alias(self, Building):
        class_alias = sa.orm.aliased(Building)
        self._assert_resolves_to(class_alias, Building)

    def test_instrumented_attribute(self, Building):
        self._assert_resolves_to(Building.id, Building)

    def test_table_alias(self, Building):
        table_alias = sa.orm.aliased(Building.__table__)
        self._assert_resolves_to(table_alias, Building)

    def test_column(self, Building):
        id_column = Building.__table__.c.id
        self._assert_resolves_to(id_column, Building)

    def test_column_of_an_alias(self, Building):
        aliased_id = sa.orm.aliased(Building.__table__).c.id
        self._assert_resolves_to(aliased_id, Building)
67
68
class TestGetMapperWithQueryEntities(object):
    """Entities taken from a compiled ORM query resolve to the class mapper."""

    @pytest.fixture
    def Building(self, Base):
        """Declare a minimal single-table model on the shared ``Base``."""
        class Building(Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)
        return Building

    @pytest.fixture
    def init_models(self, Building):
        # Overrides the shared fixture; requesting ``Building`` is enough to
        # register the model, so nothing else needs to happen here.
        pass

    @staticmethod
    def _first_entity(query):
        """Return the first entity recorded in *query*'s compile state."""
        compile_state = _get_query_compile_state(query)
        return compile_state._entities[0]

    def test_mapper_entity_with_mapper(self, session, Building):
        first = self._first_entity(session.query(Building.__mapper__))
        assert get_mapper(first) == sa.inspect(Building)

    def test_mapper_entity_with_class(self, session, Building):
        first = self._first_entity(session.query(Building))
        assert get_mapper(first) == sa.inspect(Building)

    def test_column_entity(self, session, Building):
        first = self._first_entity(session.query(Building.id))
        assert get_mapper(first) == sa.inspect(Building)
96
97
class TestGetMapperWithMultipleMappersFound(object):
    """A table shared by several mappers cannot be resolved unambiguously."""

    @pytest.fixture
    def Building(self, Base):
        """Declare ``Building`` plus a subclass mapped to the same table."""
        class Building(Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)

        # The subclass declares no table of its own, so it is mapped onto
        # ``building`` as well; that makes the table map to two mappers.
        class BigBuilding(Building):
            pass

        return Building

    def test_table(self, Building):
        shared_table = Building.__table__
        with pytest.raises(ValueError):
            get_mapper(shared_table)

    def test_table_alias(self, Building):
        shared_alias = sa.orm.aliased(Building.__table__)
        with pytest.raises(ValueError):
            get_mapper(shared_alias)
119
120
class TestGetMapperForTableWithoutMapper(object):
    """A table (or its alias) that no class maps to raises ``ValueError``."""

    @pytest.fixture
    def building(self):
        """Build a bare ``building`` table on a private, unmapped MetaData."""
        return sa.Table('building', sa.MetaData())

    def test_table(self, building):
        with pytest.raises(ValueError):
            get_mapper(building)

    def test_table_alias(self, building):
        unmapped_alias = sa.orm.aliased(building)
        with pytest.raises(ValueError):
            get_mapper(unmapped_alias)
136