import pathlib
import shutil
import sys
import tempfile
import time

import attr
import pytest
import salt.ext.tornado.gen
import salt.transport.client
import salt.transport.ipc
import salt.transport.server
import salt.utils.platform
from salt.ext.tornado import locks
12
# Module-wide skips: Windows does not support the POSIX IPC used here, and on
# interpreters older than 3.6 the IOLoop blocks on these tests.
pytestmark = [
    pytest.mark.skip_on_windows,
    pytest.mark.skipif(
        sys.version_info[:2] < (3, 6),
        reason="The IOLoop blocks under Py3.5 on these tests",
    ),
]
20
21
@attr.s(frozen=True, slots=True)
class PayloadHandler:
    """Record every payload it is handed and echo it back via ``reply_func``.

    Usable as a context manager; leaving the ``with`` block empties the list of
    recorded payloads.
    """

    # Payloads passed to ``handle_payload``, in arrival order. The list object
    # is never rebound, so mutating it in place is fine on a frozen instance.
    payloads = attr.ib(init=False, default=attr.Factory(list))

    async def handle_payload(self, payload, reply_func):
        """Store *payload*, then await ``reply_func(payload)``."""
        received = self.payloads
        received.append(payload)
        await reply_func(payload)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Drop everything recorded; returning None lets exceptions propagate.
        del self.payloads[:]
35
36
@attr.s(frozen=True, slots=True)
class IPCTester:
    """Pair an IPC publisher and subscriber on one socket path for tests.

    Entering the tester starts the publisher and schedules the subscriber's
    ``connect`` on ``io_loop``; exiting closes both.
    """

    io_loop = attr.ib()
    # Path of the IPC socket shared by the publisher and the subscriber.
    socket_path = attr.ib()
    publisher = attr.ib()
    subscriber = attr.ib()
    # Payloads seen by ``handle_payload``; shared with testers made by
    # ``new_client``.
    payloads = attr.ib(default=attr.Factory(list))
    # Notified after each handled payload has been replied to.
    payload_ack = attr.ib(default=attr.Factory(locks.Condition))

    @subscriber.default
    def _subscriber_default(self):
        return salt.transport.ipc.IPCMessageSubscriber(
            self.socket_path,
            io_loop=self.io_loop,
        )

    @publisher.default
    def _publisher_default(self):
        return salt.transport.ipc.IPCMessagePublisher(
            {"ipc_write_buffer": 0},
            self.socket_path,
            io_loop=self.io_loop,
        )

    async def handle_payload(self, payload, reply_func):
        """Record *payload*, await ``reply_func(payload)``, then notify waiters
        on ``payload_ack``."""
        self.payloads.append(payload)
        await reply_func(payload)
        self.payload_ack.notify()

    def new_client(self):
        """Return another tester that shares this tester's publisher.

        The returned tester gets its own default subscriber on the same socket
        path. It shares ``payloads`` and ``payload_ack`` with this tester. It
        does not start or connect anything by itself.
        """
        # IPCTester has no ``server`` attribute, so reuse the already-running
        # publisher rather than constructing a second one on the same path.
        # NOTE(review): entering the returned tester as a context manager would
        # call ``start()`` on the shared publisher again and close it on exit.
        return IPCTester(
            io_loop=self.io_loop,
            socket_path=self.socket_path,
            publisher=self.publisher,
            payloads=self.payloads,
            payload_ack=self.payload_ack,
        )

    async def publish(self, payload, timeout=60):
        """Publish *payload* to connected subscribers.

        ``timeout`` is accepted for symmetry with ``read`` but is not used.
        """
        self.publisher.publish(payload)

    async def read(self, timeout=60):
        """Return the next message read by the subscriber within *timeout*."""
        return await self.subscriber.read(timeout)

    def __enter__(self):
        self.publisher.start()
        # Connect asynchronously once the IOLoop runs; callers must wait for
        # the subscriber to report ``connected()`` before relying on it.
        self.io_loop.add_callback(self.subscriber.connect)
        return self

    def __exit__(self, *args):
        # Always release the publisher, even if closing the subscriber fails.
        try:
            self.subscriber.close()
        finally:
            self.publisher.close()
90
91
@pytest.fixture
def ipc_socket_path(tmp_path):
    """Yield a filesystem path for the IPC socket and remove it afterwards.

    On macOS ``tmp_path`` can exceed the ``AF_UNIX`` path length limit, so a
    short directory under ``/tmp`` is used instead. That directory is unique
    per test, so concurrent runs do not clobber each other's socket. It is
    removed on teardown.
    """
    short_dir = None
    if salt.utils.platform.is_darwin():
        # A shorter path so that we don't hit the AF_UNIX path too long
        short_dir = pathlib.Path(
            tempfile.mkdtemp(prefix="ipc-", dir=str(pathlib.Path("/tmp").resolve()))
        )
        tmp_path = short_dir
    _socket_path = tmp_path / "ipc-test.ipc"
    try:
        yield _socket_path
    finally:
        # Unlink directly instead of exists()+unlink() to avoid a check/use race.
        try:
            _socket_path.unlink()
        except FileNotFoundError:
            pass
        if short_dir is not None:
            shutil.rmtree(str(short_dir), ignore_errors=True)
103
104
@pytest.fixture
def channel(io_loop, ipc_socket_path):
    """Provide an entered ``IPCTester`` bound to ``ipc_socket_path``.

    The tester is entered for the duration of the test and exited (closing its
    subscriber and publisher) at teardown.
    """
    tester = IPCTester(io_loop=io_loop, socket_path=str(ipc_socket_path))
    with tester as entered:
        yield entered
110
111
async def test_basic_send(channel):
    """A payload published over IPC is read back unchanged by the subscriber."""
    msg = {"foo": "bar", "stop": True}
    subscriber = channel.subscriber
    # Bound the connection wait so a broken connect fails the test instead of
    # hanging the whole run.
    deadline = time.monotonic() + 30

    def _connect_done():
        # The connect is scheduled via add_callback, so tolerate the private
        # future not having been created yet.
        future = getattr(subscriber, "_connecting_future", None)
        return future is not None and future.done()

    # XXX: IPCClient connect and connected methods need to be cleaned up as
    # this should not be needed.
    while not _connect_done() and time.monotonic() < deadline:
        await salt.ext.tornado.gen.sleep(0.01)
    while not subscriber.connected() and time.monotonic() < deadline:
        await salt.ext.tornado.gen.sleep(0.01)
    assert subscriber.connected()
    await channel.publish(msg)
    ret = await channel.read()
    assert ret == msg
124