1#!/usr/bin/env python3
2"""Protocol parser example."""
3import argparse
4import asyncio
5import collections
6
7import aiohttp
8
9try:
10    import signal
11except ImportError:
12    signal = None
13
14
MSG_TEXT = b"text:"
MSG_PING = b"ping:"
MSG_PONG = b"pong:"
MSG_STOP = b"stop:"

Message = collections.namedtuple("Message", ("tp", "data"))


def my_protocol_parser(out, buf):
    r"""Incrementally parse the line-oriented example protocol.

    Used with StreamParser for incremental protocol parsing. The parser is a
    generator function, not a coroutine: each ``yield from buf.<op>(...)``
    suspends it until the buffer can satisfy the request, so it acts as a
    small state machine driven by incoming data.

    Wire format (5-byte prefix, frame terminated by ``\r\n``):

    * ``ping:\r\n`` / ``pong:\r\n`` / ``stop:\r\n`` -> ``Message(tp, None)``
    * ``text:<utf-8 text>\r\n`` -> ``Message(b"text:", <stripped str>)``

    :param out: sink that receives every parsed ``Message`` via
        ``out.feed_data``.
    :param buf: buffer exposing generator-based ``read(n)``,
        ``skipuntil(sep)`` and ``readuntil(sep)`` operations.
    :raises ValueError: when a frame starts with an unknown prefix.

    More details in asyncio/parsers.py; existing parsers:
      * HTTP protocol parsers asyncio/http/protocol.py
      * websocket parser asyncio/http/websocket.py
    NOTE(review): confirm these paths for the installed aiohttp version.
    """
    while True:
        tp = yield from buf.read(5)
        if tp in (MSG_PING, MSG_PONG, MSG_STOP):
            # Control frames carry no payload, but the writer still terminates
            # them with "\r\n"; consume it so the next read(5) starts exactly
            # at the following frame's prefix (stop: used to skip this).
            yield from buf.skipuntil(b"\r\n")
            out.feed_data(Message(tp, None))
        elif tp == MSG_TEXT:
            # Payload runs up to and including the "\r\n" terminator;
            # strip() removes it (and surrounding whitespace) before decoding.
            text = yield from buf.readuntil(b"\r\n")
            out.feed_data(Message(tp, text.strip().decode("utf-8")))
        else:
            raise ValueError("Unknown protocol prefix.")
47
48
class MyProtocolWriter:
    """Write frames of the example protocol to a transport.

    Every public method emits exactly one complete frame through a single
    ``transport.write`` call; frames end with ``\\r\\n``.
    """

    def __init__(self, transport):
        self.transport = transport

    def _emit(self, frame):
        """Hand one fully encoded frame to the transport."""
        self.transport.write(frame)

    def ping(self):
        """Send a ``ping`` control frame."""
        self._emit(b"ping:\r\n")

    def pong(self):
        """Send a ``pong`` control frame."""
        self._emit(b"pong:\r\n")

    def stop(self):
        """Send a ``stop`` control frame."""
        self._emit(b"stop:\r\n")

    def send_text(self, text):
        """Send *text*, with surrounding whitespace stripped, as a UTF-8 ``text`` frame."""
        line = f"text:{text.strip()}\r\n"
        self._emit(line.encode("utf-8"))
64
65
class EchoServer(asyncio.Protocol):
    """Server-side protocol for the example line protocol.

    Received bytes are fed into an ``aiohttp.StreamParser``; a separate task
    running :meth:`dispatch` reads parsed ``Message`` objects and replies:
    ``ping`` -> ``pong``, ``text`` -> the same text prefixed with ``"Re: "``,
    ``stop`` -> close the connection.

    NOTE(review): confirm ``aiohttp.StreamParser`` / ``aiohttp.ConnectionError``
    exist in the installed aiohttp version.
    """

    def connection_made(self, transport):
        print("Connection made")
        self.transport = transport
        self.stream = aiohttp.StreamParser()
        # Prefer ensure_future over instantiating Task directly, and keep a
        # reference: the event loop holds only weak references to tasks.
        self._dispatch_task = asyncio.ensure_future(self.dispatch())

    def data_received(self, data):
        self.stream.feed_data(data)

    def eof_received(self):
        self.stream.feed_eof()

    def connection_lost(self, exc):
        print("Connection lost")
        # Nothing here notifies the stream when the connection drops without
        # a clean EOF, so dispatch() could otherwise wait on reader.read()
        # forever. cancel() is a no-op if the task already finished.
        self._dispatch_task.cancel()

    async def dispatch(self):
        """Read parsed messages and answer them until ``stop`` or disconnect."""
        reader = self.stream.set_parser(my_protocol_parser)
        writer = MyProtocolWriter(self.transport)

        while True:
            try:
                msg = await reader.read()
            except aiohttp.ConnectionError:
                # client has been disconnected
                break

            print(f"Message received: {msg}")

            # Message is namedtuple("Message", ("tp", "data")): the kind of
            # frame lives in ``tp`` (there is no ``type`` attribute).
            if msg.tp == MSG_PING:
                writer.pong()
            elif msg.tp == MSG_TEXT:
                writer.send_text("Re: " + msg.data)
            elif msg.tp == MSG_STOP:
                self.transport.close()
                break
102
103
async def start_client(loop, host, port):
    """Connect to host:port and run one ping -> text -> stop exchange.

    Sends ``ping``; on ``pong`` sends a text message; on the echoed ``text``
    reply sends ``stop`` and finishes. The transport is always closed on exit.

    :param loop: event loop used to open the connection.
    :param host: server host name or address.
    :param port: server TCP port.
    """
    transport, stream = await loop.create_connection(aiohttp.StreamProtocol, host, port)
    try:
        reader = stream.reader.set_parser(my_protocol_parser)
        writer = MyProtocolWriter(transport)
        writer.ping()

        message = "This is the message. It will be echoed."

        while True:
            try:
                msg = await reader.read()
            except aiohttp.ConnectionError:
                print("Server has been disconnected.")
                break

            print(f"Message received: {msg}")
            # Message's fields are ("tp", "data"); compare on ``tp``.
            if msg.tp == MSG_PONG:
                writer.send_text(message)
                print("data sent:", message)
            elif msg.tp == MSG_TEXT:
                writer.stop()
                print("stop sent")
                break
    finally:
        # Close the connection even if reading or writing raised.
        transport.close()
129
130
def start_server(loop, host, port):
    """Bind an EchoServer on host:port, print its address, then serve.

    Blocks inside ``loop.run_forever()`` until the loop is stopped.
    """
    server = loop.run_until_complete(loop.create_server(EchoServer, host, port))
    first_socket = server.sockets[0]
    print("serving on", first_socket.getsockname())
    loop.run_forever()
137
138
ARGS = argparse.ArgumentParser(description="Protocol parser example.")
# Mode switches (store_true already defaults to False and derives dest).
ARGS.add_argument("--server", action="store_true", help="Run tcp server")
ARGS.add_argument("--client", action="store_true", help="Run tcp client")
# Connection target.
ARGS.add_argument("--host", default="127.0.0.1", help="Host name")
ARGS.add_argument("--port", default=9999, type=int, help="Port number")
152
153
if __name__ == "__main__":
    args = ARGS.parse_args()

    # Accept "--host HOST:PORT" as shorthand for "--host HOST --port PORT".
    if ":" in args.host:
        args.host, port = args.host.split(":", 1)
        args.port = int(port)

    # Both flags are store_true booleans: equal means neither or both given.
    if args.server == args.client:
        print("Please specify --server or --client\n")
        ARGS.print_help()
    else:
        # asyncio.get_event_loop() with no running loop is deprecated (and an
        # error on recent Python versions); create and install one explicitly.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        if signal is not None:
            try:
                loop.add_signal_handler(signal.SIGINT, loop.stop)
            except NotImplementedError:
                # Some event loops (e.g. on Windows) lack signal handlers;
                # Ctrl+C then surfaces as KeyboardInterrupt below.
                pass

        try:
            if args.server:
                start_server(loop, args.host, args.port)
            else:
                loop.run_until_complete(start_client(loop, args.host, args.port))
        except KeyboardInterrupt:
            pass
        finally:
            loop.close()
173