1from datetime import datetime, timedelta
2
3import pytest
4import sqlalchemy as sa
5
6from sqlalchemy_utils import DateRangeType
7
# Optional dependencies.  Placeholders are bound first so this module can
# still be imported when the packages are missing:
#   * ``intervals = None`` is the condition checked by the ``skipif`` marker
#     on ``DateRangeTestCase``.
#   * ``inf = 0`` keeps expressions such as ``-inf`` in the parametrize
#     argument lists evaluable at import time.
intervals = None
inf = 0
try:
    import intervals
    from infinity import inf
except ImportError:
    # Keep whichever placeholders were not replaced before the failure.
    pass
15
16
@pytest.fixture
def Booking(Base):
    """Return a ``Booking`` model class mapped to the ``booking`` table.

    ``Base`` is supplied by another fixture (not defined in this module);
    presumably the declarative base -- confirm against conftest.
    """
    class Booking(Base):
        __tablename__ = 'booking'
        # Integer primary key.
        id = sa.Column(sa.Integer, primary_key=True)
        # Column under test: stored using DateRangeType.
        during = sa.Column(DateRangeType)

    return Booking
25
26
@pytest.fixture
def create_booking(session, Booking):
    """Provide a factory that saves a booking and queries it back.

    The returned callable takes the value to assign to ``during``, adds a
    new ``Booking`` carrying it to ``session``, commits, and returns the
    result of ``session.query(Booking).first()``.
    """
    def _create(date_range):
        session.add(Booking(during=date_range))
        session.commit()
        # Hand back what the query yields, not the instance built above.
        return session.query(Booking).first()

    return _create
37
38
@pytest.fixture
def init_models(Booking):
    """Request the ``Booking`` fixture and do nothing else.

    Presumably this exists so the model is declared before other fixtures
    set up the database -- confirm against conftest.
    """
42
43
@pytest.mark.skipif('intervals is None')
class DateRangeTestCase:
    """Shared checks for a ``DateRangeType`` column.

    The whole class is skipped when the module-level ``intervals`` name is
    still ``None``, i.e. the optional import did not succeed.
    """

    def test_nullify_range(self, create_booking):
        """A booking created with ``None`` has ``during`` equal to None."""
        stored = create_booking(None)
        assert stored.during is None

    @pytest.mark.parametrize(
        'date_range',
        [
            [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()],
            [datetime(2015, 1, 1).date(), inf],
            [-inf, datetime(2015, 1, 1).date()],
        ]
    )
    def test_save_date_range(self, create_booking, date_range):
        """Lower and upper bounds match the saved input, infinite or not."""
        expected_lower, expected_upper = date_range
        stored = create_booking(date_range)
        assert stored.during.lower == expected_lower
        assert stored.during.upper == expected_upper

    def test_nullify_date_range(self, session, Booking):
        """Resetting a saved range to ``None`` is kept after commit."""
        initial_range = intervals.DateInterval(
            [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()]
        )
        session.add(Booking(during=initial_range))
        session.commit()

        stored = session.query(Booking).first()
        stored.during = None
        session.commit()

        reloaded = session.query(Booking).first()
        assert reloaded.during is None

    def test_integer_coercion(self, Booking):
        """A single date assigned to ``during`` becomes a range.

        Despite the test's name, the value being coerced is a date, not an
        integer; both bounds of the result must equal that date.
        """
        day = datetime(2015, 1, 1).date()
        booking = Booking(during=day)
        assert booking.during.lower == day
        assert booking.during.upper == day
84
85
@pytest.mark.usefixtures('postgresql_dsn')
class TestDateRangeOnPostgres(DateRangeTestCase):
    """Shared date-range tests run with the ``postgresql_dsn`` fixture.

    Adds checks for the ``length`` column expression and for rendering a
    comparison with literal binds.
    """

    @pytest.mark.parametrize(
        ('date_range', 'length'),
        [
            (
                [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()],
                timedelta(days=2),
            ),
            (
                [datetime(2015, 1, 1).date(), datetime(2015, 1, 1).date()],
                timedelta(days=0),
            ),
            ([-inf, datetime(2015, 1, 1).date()], None),
            ([datetime(2015, 1, 1).date(), inf], None),
        ]
    )
    def test_length(
        self, session, Booking, create_booking, date_range, length
    ):
        """Querying ``Booking.during.length`` yields the expected value.

        The cases expect a ``timedelta`` for finite ranges and ``None``
        when either bound is infinite.
        """
        create_booking(date_range)
        measured = session.query(Booking.during.length).scalar()
        assert measured == length

    def test_literal_param(self, session, Booking):
        """Comparing the column to a list renders a range literal.

        ``session`` is requested but not used in the body; presumably it is
        needed for its setup side effects -- confirm before removing.
        """
        bounds = [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()]
        clause = Booking.during == bounds
        rendered = clause.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == "booking.during = '[2015-01-01, 2015-01-03]'"
124