1import sys
2import asyncio
3import itertools
4import socket
5
6import pytest
7
8from postfix_mta_sts_resolver import netstring
9from postfix_mta_sts_resolver.responder import STSSocketmapResponder
10import postfix_mta_sts_resolver.utils as utils
11
@pytest.fixture
async def responder(event_loop):
    """Start an ``STSSocketmapResponder`` for a test and tear it down afterwards.

    The configuration starts from ``utils.populate_cfg_defaults(None)`` with
    these overrides: port 38461, a 1 second shutdown timeout, zero cache grace,
    and an extra zone ``test2`` that reuses the default zone's settings.

    Yields:
        A ``(responder, host, port)`` tuple for the listening responder.
    """
    # ``utils`` is the module-level import; re-importing it locally only
    # shadowed that name.
    cfg = utils.populate_cfg_defaults(None)
    cfg["port"] = 38461
    cfg["shutdown_timeout"] = 1
    cfg["cache_grace"] = 0
    cfg["zones"]["test2"] = cfg["default_zone"]
    cache = utils.create_cache(cfg['cache']['type'],
                               cfg['cache']['options'])
    await cache.setup()
    # Nested try/finally blocks: the cache is torn down even if the responder
    # fails to start or fails to stop, and a started responder is always stopped.
    try:
        resp = STSSocketmapResponder(cfg, event_loop, cache)
        await resp.start()
        try:
            yield resp, cfg['host'], cfg['port']
        finally:
            await resp.stop()
    finally:
        await cache.teardown()
29
@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_hanging_stop(responder):
    """Stopping the responder with an idle client connection must finish.

    The client never sends a request. After ``stop()`` returns, the client
    must observe EOF without receiving any bytes.
    """
    resp, host, port = responder
    reader, writer = await asyncio.open_connection(host, port)
    try:
        # NOTE: the fixture teardown calls ``stop()`` again, so this test
        # relies on ``stop()`` tolerating a second call.
        await resp.stop()
        # b'' from read() means EOF with no data: the server closed the socket.
        assert await reader.read() == b''
    finally:
        # Close in ``finally`` so a failed assertion does not leak the transport.
        writer.close()
38
@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_inprogress_stop(responder):
    """Stopping the responder while a request is being served must finish.

    One request for ``blackhole.loc`` is sent. A short sleep gives the server
    time to pick it up, and then the responder is stopped. The client must get
    EOF with no response bytes. (Presumably ``blackhole.loc`` is a domain whose
    lookup never completes; confirm against the test environment.)
    """
    resp, host, port = responder
    reader, writer = await asyncio.open_connection(host, port)
    try:
        writer.write(netstring.encode(b'test blackhole.loc'))
        await writer.drain()
        # Give the server a moment to start processing the request.
        await asyncio.sleep(0.2)
        await resp.stop()
        assert await reader.read() == b''
    finally:
        # Close in ``finally`` so a failed assertion does not leak the transport.
        writer.close()
50
@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_extended_stop(responder):
    """Stopping the responder with several pipelined requests must finish.

    Three identical ``blackhole.loc`` requests are written back to back before
    the responder is stopped. The client must get EOF with no response bytes.
    """
    resp, host, port = responder
    reader, writer = await asyncio.open_connection(host, port)
    try:
        request = netstring.encode(b'test blackhole.loc')
        # Pipeline the same request three times on one connection.
        for _ in range(3):
            writer.write(request)
        await writer.drain()
        # Give the server a moment to start processing the requests.
        await asyncio.sleep(0.2)
        await resp.stop()
        assert await reader.read() == b''
    finally:
        # Close in ``finally`` so a failed assertion does not leak the transport.
        writer.close()
64
@pytest.mark.asyncio
@pytest.mark.timeout(7)
async def test_grace_expired(responder):
    """Query the same domain twice, two seconds apart, and compare answers.

    The fixture sets ``cache_grace`` to 0, so the repeated ``good.loc`` lookup
    happens after any grace window has passed. Both lookups must return
    identical responses.
    """
    _, host, port = responder
    reader, writer = await asyncio.open_connection(host, port)
    decoder = netstring.StreamReader()

    async def read_response():
        # Collect one full netstring payload. When the decoder runs out of
        # buffered bytes, pull more from the socket; an empty read means the
        # peer closed early, which fails the test.
        pending = decoder.next_string()
        chunks = []
        while True:
            try:
                chunk = pending.read()
            except netstring.WantRead:
                incoming = await reader.read(4096)
                assert incoming
                decoder.feed(incoming)
                continue
            if not chunk:
                return b''.join(chunks)
            chunks.append(chunk)

    try:
        writer.write(netstring.encode(b'test good.loc'))
        first = await read_response()
        await asyncio.sleep(2)
        writer.write(netstring.encode(b'test good.loc'))
        second = await read_response()
        assert first == second
    finally:
        writer.close()
95
@pytest.mark.asyncio
@pytest.mark.timeout(7)
async def test_fast_expire(responder):
    """Query ``fast-expire.loc`` twice, two seconds apart.

    Both answers must be identical and equal to the expected socketmap
    response. (Presumably the test policy for ``fast-expire.loc`` has a very
    short lifetime, so the second lookup refreshes it; confirm against the
    test fixtures.)
    """
    _, host, port = responder
    reader, writer = await asyncio.open_connection(host, port)
    decoder = netstring.StreamReader()

    async def read_response():
        # Collect one full netstring payload. When the decoder runs out of
        # buffered bytes, pull more from the socket; an empty read means the
        # peer closed early, which fails the test.
        pending = decoder.next_string()
        chunks = []
        while True:
            try:
                chunk = pending.read()
            except netstring.WantRead:
                incoming = await reader.read(4096)
                assert incoming
                decoder.feed(incoming)
                continue
            if not chunk:
                return b''.join(chunks)
            chunks.append(chunk)

    try:
        writer.write(netstring.encode(b'test fast-expire.loc'))
        first = await read_response()
        await asyncio.sleep(2)
        writer.write(netstring.encode(b'test fast-expire.loc'))
        second = await read_response()
        assert first == second
        assert first == b'OK secure match=mail.loc servername=hostname'
    finally:
        writer.close()
126