1import asyncio
2import json
3import logging
4import sys
5from typing import Any, Callable, Coroutine, List, Optional, Mapping
6
7import websockets
8
9logger = logging.getLogger("webdriver.bidi")
10
11
def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the event loop to use for the current context.

    On Python 3.7 and later this delegates to
    ``asyncio.get_running_loop()``; on older interpreters it falls back to
    ``asyncio.get_event_loop()``.
    """
    has_native_lookup = sys.version_info >= (3, 7)
    if not has_native_lookup:
        # Unlike asyncio.get_running_loop(), this fallback will actually
        # create an event loop if there isn't one; hopefully running tests
        # in Python >= 3.7 will allow us to catch any behaviour difference.
        return asyncio.get_event_loop()
    return asyncio.get_running_loop()
19
20
class Transport:
    """Low level message handler for the WebSockets connection.

    Outgoing messages are JSON-encoded before being written to the
    connection. Messages passed to :meth:`send` while no connection is open
    are buffered and delivered by :meth:`start`. Incoming text messages are
    JSON-decoded and passed to ``msg_handler``.
    """
    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """Store the connection parameters; no connection is opened here.

        :param url: WebSocket URL that :meth:`start` connects to.
        :param msg_handler: Coroutine function awaited with each decoded
                            incoming message.
        :param loop: Event loop used to schedule the reader task. Defaults
                     to the result of ``get_running_loop()``.
        """
        self.url = url
        self.connection: Optional[websockets.WebSocketClientProtocol] = None
        self.msg_handler = msg_handler
        # Messages queued by send() while there is no open connection.
        self.send_buf: List[Mapping[str, Any]] = []

        if loop is None:
            loop = get_running_loop()
        self.loop = loop

        # Task running read_messages(); created by start().
        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Open the connection, start the reader task and flush the buffer.

        Buffered messages are sent in the order they were queued and are
        removed from ``send_buf`` once delivered, so a later call to
        ``start()`` (for example after :meth:`end`) does not send them again.
        """
        connection = await websockets.client.connect(self.url)
        self.connection = connection
        self.read_message_task = self.loop.create_task(self.read_messages())

        delivered = 0
        try:
            for msg in self.send_buf:
                await self._send(connection, msg)
                delivered += 1
        finally:
            # Drop only what was actually written: if a send fails part-way
            # the unsent messages stay queued instead of being lost, and the
            # delivered ones are never replayed.
            del self.send_buf[:delivered]

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` now if connected, otherwise queue it for start().

        :param data: JSON-serializable message.
        """
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(
        connection: websockets.WebSocketClientProtocol,
        data: Mapping[str, Any]
    ) -> None:
        """Serialize ``data`` to JSON, log it and write it to ``connection``."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log and JSON-decode one incoming message, then await the handler.

        :param msg: Raw text payload received from the connection.
        """
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection, if there is one, and forget it."""
        if self.connection:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Pass every message received on the connection to :meth:`handle`.

        Runs until iteration over the connection ends.

        :raises ValueError: If a received message is not a ``str``.
        """
        assert self.connection is not None
        async for msg in self.connection:
            if not isinstance(msg, str):
                raise ValueError("Got a binary message")
            await self.handle(msg)
75