1import pytest
2from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend
3
4import graphene
5from graphene import relay
6
7from ..fields import BatchSQLAlchemyConnectionField
8from ..types import SQLAlchemyObjectType
9from .models import Article, HairKind, Pet, Reporter
10from .utils import is_sqlalchemy_version_less_than
11
# Per the skip message below, SQL batching requires SQLAlchemy 1.2+; on older
# versions every test in this module is skipped at import time.
if is_sqlalchemy_version_less_than('1.2'):
    pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True)
14
15
def get_schema():
    """Build the Graphene schema exercised by the batching benchmarks.

    Every object type routes its relationship connection fields through
    ``BatchSQLAlchemyConnectionField.from_relationship``. The root ``Query``
    exposes plain lists of articles and reporters. Both are loaded from the
    session stored under the ``'session'`` key of the execution context.

    Note: class names are significant, because Graphene derives GraphQL type
    names from them.
    """
    batch_factory = BatchSQLAlchemyConnectionField.from_relationship

    class ReporterType(SQLAlchemyObjectType):
        class Meta:
            model = Reporter
            interfaces = (relay.Node,)
            connection_field_factory = batch_factory

    class ArticleType(SQLAlchemyObjectType):
        class Meta:
            model = Article
            interfaces = (relay.Node,)
            connection_field_factory = batch_factory

    # Not referenced by Query directly; defining it registers Pet so the
    # Reporter.pets relationship can be resolved.
    class PetType(SQLAlchemyObjectType):
        class Meta:
            model = Pet
            interfaces = (relay.Node,)
            connection_field_factory = batch_factory

    class Query(graphene.ObjectType):
        articles = graphene.Field(graphene.List(ArticleType))
        reporters = graphene.Field(graphene.List(ReporterType))

        def resolve_articles(self, info):
            db_session = info.context.get('session')
            return db_session.query(Article).all()

        def resolve_reporters(self, info):
            db_session = info.context.get('session')
            return db_session.query(Reporter).all()

    schema = graphene.Schema(query=Query)
    return schema
46
47
def benchmark_query(session_factory, benchmark, query):
    """Benchmark executing ``query`` against a freshly built schema.

    The query document is parsed once up front through a cached backend, so
    the benchmarked rounds reuse the cached document instead of re-parsing.
    Each round opens its own session from ``session_factory``. It closes that
    session afterwards so sessions are not leaked from one round to the next.

    :param session_factory: callable returning a new session; it is handed to
        the resolvers via ``context_value['session']``.
    :param benchmark: the ``benchmark`` fixture (pytest-benchmark), used here
        as a decorator.
    :param query: GraphQL query string to execute.
    :raises AssertionError: if executing the query produces any errors.
    """
    schema = get_schema()
    cached_backend = GraphQLCachedBackend(GraphQLCoreBackend())
    cached_backend.document_from_string(schema, query)  # Prime cache

    @benchmark
    def execute_query():
        session = session_factory()
        try:
            result = schema.execute(
                query,
                context_value={"session": session},
                backend=cached_backend,
            )
        finally:
            # Release the per-round session explicitly. Otherwise every round
            # leaves an open session (and its transaction) behind.
            session.close()
        assert not result.errors
61
62
def test_one_to_one(session_factory, benchmark):
    """Benchmark resolving ``favoriteArticle`` for every reporter."""
    session = session_factory()

    reporter_1 = Reporter(first_name='Reporter_1')
    reporter_2 = Reporter(first_name='Reporter_2')
    session.add_all([reporter_1, reporter_2])

    # One article per reporter, attached in the same order as the reporters.
    for headline, author in (('Article_1', reporter_1), ('Article_2', reporter_2)):
        article = Article(headline=headline)
        article.reporter = author
        session.add(article)

    session.commit()
    session.close()

    benchmark_query(session_factory, benchmark, """
      query {
        reporters {
          firstName
          favoriteArticle {
            headline
          }
        }
      }
    """)
96
97
def test_many_to_one(session_factory, benchmark):
    """Benchmark resolving each article's ``reporter``."""
    session = session_factory()

    reporter_1 = Reporter(first_name='Reporter_1')
    reporter_2 = Reporter(first_name='Reporter_2')
    session.add_all([reporter_1, reporter_2])

    pairs = (
        ('Article_1', reporter_1),
        ('Article_2', reporter_2),
    )
    for headline, author in pairs:
        article = Article(headline=headline)
        article.reporter = author
        session.add(article)

    session.commit()
    session.close()

    benchmark_query(session_factory, benchmark, """
      query {
        articles {
          headline
          reporter {
            firstName
          }
        }
      }
    """)
131
132
def test_one_to_many(session_factory, benchmark):
    """Benchmark paginating each reporter's ``articles`` connection."""
    session = session_factory()

    reporter_1 = Reporter(first_name='Reporter_1')
    reporter_2 = Reporter(first_name='Reporter_2')
    session.add_all([reporter_1, reporter_2])

    # Two articles per reporter, so ``articles(first: 2)`` returns a full page.
    assignments = (
        ('Article_1', reporter_1),
        ('Article_2', reporter_1),
        ('Article_3', reporter_2),
        ('Article_4', reporter_2),
    )
    for headline, author in assignments:
        article = Article(headline=headline)
        article.reporter = author
        session.add(article)

    session.commit()
    session.close()

    benchmark_query(session_factory, benchmark, """
      query {
        reporters {
          firstName
          articles(first: 2) {
            edges {
              node {
                headline
              }
            }
          }
        }
      }
    """)
178
179
def test_many_to_many(session_factory, benchmark):
    """Benchmark paginating each reporter's ``pets`` connection."""
    session = session_factory()

    reporter_1 = Reporter(first_name='Reporter_1')
    reporter_2 = Reporter(first_name='Reporter_2')
    session.add_all([reporter_1, reporter_2])

    # Each reporter owns two long-haired cats.
    ownership = (
        (reporter_1, ('Pet_1', 'Pet_2')),
        (reporter_2, ('Pet_3', 'Pet_4')),
    )
    for owner, pet_names in ownership:
        owned = [
            Pet(name=pet_name, pet_kind='cat', hair_kind=HairKind.LONG)
            for pet_name in pet_names
        ]
        session.add_all(owned)
        for pet in owned:
            owner.pets.append(pet)

    session.commit()
    session.close()

    benchmark_query(session_factory, benchmark, """
      query {
        reporters {
          firstName
          pets(first: 2) {
            edges {
              node {
                name
              }
            }
          }
        }
      }
    """)
227