1import pytest
2import sqlalchemy as sa
3
4from sqlalchemy_utils import IntRangeType
5from sqlalchemy_utils.compat import get_scalar_subquery
6
# Optional dependencies: bind placeholder values first so this module still
# imports when they are missing.  ``intervals`` staying ``None`` is what the
# ``skipif`` marker on ``NumberRangeTestCase`` checks; ``inf`` needs *some*
# value because the parametrize lists further down reference it while the
# test classes are being defined.
intervals = None
inf = -1
try:
    import intervals
    from infinity import inf
except ImportError:
    pass
14
15
@pytest.fixture
def Building(Base):
    """Return a ``Building`` model class declared on the ``Base`` fixture.

    The model has an integer primary key and a ``persons_at_night`` column
    of type ``IntRangeType``; that column is what the tests in this module
    exercise.
    """
    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)
        # Column under test.
        persons_at_night = sa.Column(IntRangeType)

        def __repr__(self):
            return 'Building(%r)' % self.id
    return Building
26
27
@pytest.fixture
def init_models(Building):
    """Request the ``Building`` fixture; the body itself does nothing.

    NOTE(review): presumably this overrides a conftest fixture so the model
    class is defined before the tables are created -- confirm in conftest.
    """
    pass
31
32
@pytest.fixture
def create_building(session, Building):
    """Return a factory that persists a ``Building`` and reads one back.

    The returned callable takes the value to assign to ``persons_at_night``,
    adds and commits a new row, and returns the first ``Building`` that the
    session's query yields.
    """
    def create_building(number_range):
        """Store ``number_range`` on a new row and return the queried row."""
        new_row = Building(persons_at_night=number_range)
        session.add(new_row)
        session.commit()
        # Hand back what a query returns rather than the local object.
        return session.query(Building).first()

    return create_building
44
45
@pytest.mark.skipif('intervals is None')
class NumberRangeTestCase(object):
    """Storage and coercion tests for ``IntRangeType``.

    Subclassed by the concrete test classes in this module.  The whole class
    is skipped when the optional ``intervals`` package could not be imported.
    """

    def test_nullify_range(self, create_building):
        """A building created with ``None`` reads back as ``None``."""
        assert create_building(None).persons_at_night is None

    def test_update_with_none(self, session, create_building):
        """Assigning ``None`` sticks both before and after the commit."""
        building = create_building(intervals.IntInterval([None, None]))

        building.persons_at_night = None
        assert building.persons_at_night is None

        session.commit()
        assert building.persons_at_night is None

    @pytest.mark.parametrize('number_range', [[1, 3], (0, 4)])
    def test_save_number_range(self, create_building, number_range):
        """Each parametrized value reads back with bounds 1 and 3."""
        saved = create_building(number_range)
        assert saved.persons_at_night.lower == 1
        assert saved.persons_at_night.upper == 3

    def test_infinite_upper_bound(self, create_building):
        """An ``inf`` upper bound is preserved through the round trip."""
        saved = create_building([1, inf])
        assert saved.persons_at_night.lower == 1
        assert saved.persons_at_night.upper == inf

    def test_infinite_lower_bound(self, create_building):
        """A ``-inf`` lower bound is preserved through the round trip."""
        saved = create_building([-inf, 1])
        assert saved.persons_at_night.lower == -inf
        assert saved.persons_at_night.upper == 1

    def test_nullify_number_range(self, session, Building):
        """A stored range can be replaced by ``None`` and re-read as such."""
        original = Building(persons_at_night=intervals.IntInterval([1, 3]))
        session.add(original)
        session.commit()

        loaded = session.query(Building).first()
        loaded.persons_at_night = None
        session.commit()

        reloaded = session.query(Building).first()
        assert reloaded.persons_at_night is None

    def test_integer_coercion(self, Building):
        """Assigning the integer 15 yields a range with both bounds at 15."""
        coerced = Building(persons_at_night=15)
        assert coerced.persons_at_night.lower == 15
        assert coerced.persons_at_night.upper == 15
102
103
@pytest.mark.usefixtures('postgresql_dsn')
class TestIntRangeTypeOnPostgres(NumberRangeTestCase):
    """Shared range tests plus operator tests, using ``postgresql_dsn``.

    Every operator test stores one building, filters on the
    ``persons_at_night`` column with the operator under test, and asserts
    that the resulting query counts at least one row.
    """

    @pytest.mark.parametrize('number_range', [[1, 3], (0, 4)])
    def test_eq_operator(
        self, session, Building, create_building, number_range
    ):
        """``==`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night == number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize(
        ['number_range', 'length'],
        [
            ([1, 3], 2),
            ([1, 1], 0),
            ([-1, 1], 2),
            ([-inf, 1], None),
            ([0, inf], None),
            ([0, 0], 0),
            ([-3, -1], 2),
        ]
    )
    def test_length(
        self, session, Building, create_building, number_range, length
    ):
        """The ``length`` expression selects the expected scalar value."""
        create_building(number_range)
        selected = session.query(Building.persons_at_night.length).scalar()
        assert selected == length

    @pytest.mark.parametrize(
        'number_range',
        [
            [[1, 3]],
            [(0, 4)],
        ]
    )
    def test_in_operator(
        self, session, Building, create_building, number_range
    ):
        """``in_`` with a one-element list matches the stored ``[1, 3]``."""
        create_building([1, 3])
        condition = Building.persons_at_night.in_(number_range)
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[1, 3], (0, 4)])
    def test_rshift_operator(
        self, session, Building, create_building, number_range
    ):
        """``>>`` matches a stored ``[5, 6]`` for each given value."""
        create_building([5, 6])
        condition = Building.persons_at_night >> number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[1, 3], (0, 4)])
    def test_lshift_operator(
        self, session, Building, create_building, number_range
    ):
        """``<<`` matches a stored ``[-1, 0]`` for each given value."""
        create_building([-1, 0])
        condition = Building.persons_at_night << number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[1, 3], (1, 3), 2])
    def test_contains_operator(
        self, session, Building, create_building, number_range
    ):
        """``contains`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night.contains(number_range)
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[1, 3], (0, 8), (-inf, inf)])
    def test_contained_by_operator(
        self, session, Building, create_building, number_range
    ):
        """``contained_by`` matches the stored ``[1, 3]`` for each value."""
        create_building([1, 3])
        condition = Building.persons_at_night.contained_by(number_range)
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[2, 5], 0])
    def test_not_in_operator(
        self, session, Building, create_building, number_range
    ):
        """Negated ``in_`` still finds the stored ``[1, 3]`` row."""
        create_building([1, 3])
        # The parametrized value is wrapped in a list before being negated.
        condition = ~Building.persons_at_night.in_([number_range])
        assert session.query(Building).filter(condition).count()

    def test_eq_with_query_arg(self, session, Building, create_building):
        """``==`` accepts a scalar subquery as its right-hand side."""
        create_building([1, 3])
        column_query = session.query(Building.persons_at_night)
        subquery = get_scalar_subquery(column_query)
        matches = (
            session.query(Building)
            .filter(Building.persons_at_night == subquery)
            .order_by(Building.persons_at_night)
            .limit(1)
        )
        assert matches.count()

    @pytest.mark.parametrize('number_range', [[1, 2], (0, 4), [0, 3], 0, 1])
    def test_ge_operator(
        self, session, Building, create_building, number_range
    ):
        """``>=`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night >= number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[0, 2], 0, [-inf, 2]])
    def test_gt_operator(
        self, session, Building, create_building, number_range
    ):
        """``>`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night > number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[1, 4], 4, [2, inf]])
    def test_le_operator(
        self, session, Building, create_building, number_range
    ):
        """``<=`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night <= number_range
        assert session.query(Building).filter(condition).count()

    @pytest.mark.parametrize('number_range', [[2, 4], 4, [1, inf]])
    def test_lt_operator(
        self, session, Building, create_building, number_range
    ):
        """``<`` matches the stored ``[1, 3]`` for each given value."""
        create_building([1, 3])
        condition = Building.persons_at_night < number_range
        assert session.query(Building).filter(condition).count()

    def test_literal_param(self, session, Building):
        """A list operand is rendered inline when compiled as a literal."""
        expression = Building.persons_at_night == [1, 3]
        rendered = expression.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == "building.persons_at_night = '[1, 3]'"
386
387
class TestNumberRangeTypeOnSqlite(NumberRangeTestCase):
    """Run the inherited ``NumberRangeTestCase`` tests with no additions.

    NOTE(review): presumably these run against the default SQLite test
    database, as the class name suggests -- confirm in conftest.
    """
390