1#
2# !!! WARNING !!!
3#
4# This example uses some private APIs.
5#
6
7import argparse
8import asyncio
9import json
10import logging
11import ssl
12import time
13from dataclasses import dataclass, field
14from enum import Flag
15from typing import Optional, cast
16
17import requests
18import urllib3
19from http3_client import HttpClient
20
21from aioquic.asyncio import connect
22from aioquic.h0.connection import H0_ALPN
23from aioquic.h3.connection import H3_ALPN, H3Connection
24from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived
25from aioquic.quic.configuration import QuicConfiguration
26from aioquic.quic.logger import QuicLogger
27
28
class Result(Flag):
    """Bit flags recording which interoperability checks a server passed.

    Every member occupies a single bit.  ``str()`` renders the whole set as
    a fixed-width row with one slot per member.
    """

    # bits 0-7
    V = 1 << 0
    H = 1 << 1
    D = 1 << 2
    C = 1 << 3
    R = 1 << 4
    Z = 1 << 5
    S = 1 << 6
    Q = 1 << 7

    # bits 8-15
    M = 1 << 8
    B = 1 << 9
    A = 1 << 10
    U = 1 << 11
    P = 1 << 12
    E = 1 << 13
    L = 1 << 14
    T = 1 << 15

    # bits 16-18
    three = 1 << 16
    d = 1 << 17
    p = 1 << 18

    def __str__(self):
        """Return one slot per member, in ascending value order.

        A member that is set is shown by its name, an unset one by ``-``.
        """
        ordered = sorted(Result, key=lambda member: member.value)
        return "".join(
            member.name if self & member else "-" for member in ordered
        )
67
68
@dataclass
class Server:
    """One interop endpoint together with the result flags gathered for it."""

    # short label: printed by print_result, matched against --server and
    # used in the qlog file name
    name: str
    # hostname handed to connect() and used to build request URLs
    host: str
    # port used by the tests unless a more specific port field applies
    port: int = 4433
    # test_throughput selects H3_ALPN when true, H0_ALPN otherwise
    http3: bool = True
    # port test_stateless_retry connects to
    retry_port: Optional[int] = 4434
    # request path for test_http_0 / test_http_3 (both return early on None)
    path: str = "/"
    # path requested by the push check in test_http_3; None skips that check
    push_path: Optional[str] = None
    # flags accumulated as tests succeed; starts out empty
    result: Result = field(default_factory=lambda: Result(0))
    # used instead of `port` by test_session_resumption when set
    session_resumption_port: Optional[int] = None
    # run() sets Result.L when this is true
    structured_logging: bool = False
    # appended to the "/<size>" path requested by test_throughput
    throughput_file_suffix: str = ""
    # forwarded to QuicConfiguration(verify_mode=...)
    verify_mode: Optional[int] = None
83
84
# Endpoints exercised by default; the --server option narrows this list to
# the entry whose name matches.  Fields not given here keep the Server
# defaults.
SERVERS = [
    Server("akamaiquic", "ietf.akaquic.com", port=443, verify_mode=ssl.CERT_NONE),
    Server(
        "aioquic", "quic.aiortc.org", port=443, push_path="/", structured_logging=True
    ),
    Server("ats", "quic.ogre.com"),
    Server("f5", "f5quic.com", retry_port=4433),
    Server("haskell", "mew.org", retry_port=4433),
    Server("gquic", "quic.rocks", retry_port=None),
    Server("lsquic", "http3-test.litespeedtech.com", push_path="/200?push=/100"),
    Server(
        "msquic",
        "quic.westus.cloudapp.azure.com",
        port=4433,
        session_resumption_port=4433,
        structured_logging=True,
        throughput_file_suffix=".txt",
        verify_mode=ssl.CERT_NONE,
    ),
    Server(
        "mvfst", "fb.mvfst.net", port=443, push_path="/push", structured_logging=True
    ),
    Server("ngtcp2", "nghttp2.org", push_path="/?push=/100"),
    Server("ngx_quic", "cloudflare-quic.com", port=443, retry_port=443),
    Server("pandora", "pandora.cm.in.tum.de", verify_mode=ssl.CERT_NONE),
    Server("picoquic", "test.privateoctopus.com", structured_logging=True),
    Server("quant", "quant.eggert.org", http3=False),
    Server("quic-go", "quic.seemann.io", port=443, retry_port=443),
    Server("quiche", "quic.tech", port=8443, retry_port=8444),
    Server("quicly", "quic.examp1e.net"),
    Server("quinn", "ralith.com"),
]
117
118
async def test_version_negotiation(server: Server, configuration: QuicConfiguration):
    """Offer an unknown QUIC version first and look for a negotiation packet.

    Sets ``Result.V`` when the QUIC log of the connection contains a received
    packet whose type is ``version_negotiation``.
    """
    # an unknown version at the head of the list forces version negotiation
    configuration.supported_versions.insert(0, 0x1A2A3A4A)

    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()

        # gather the type of every transport packet the log says we received
        trace = configuration.quic_logger.to_dict()["traces"][0]
        received_types = [
            data["packet_type"]
            for _stamp, category, event, data in trace["events"]
            if category == "transport" and event == "packet_received"
        ]
        if "version_negotiation" in received_types:
            server.result |= Result.V
138
139
async def test_handshake_and_close(server: Server, configuration: QuicConfiguration):
    """Connect, ping once, then leave the connection context.

    ``Result.H`` is set after the ping completes inside the context and
    ``Result.C`` once the ``async with`` block has been exited without an
    exception.
    """
    session = connect(server.host, server.port, configuration=configuration)
    async with session as client:
        await client.ping()
        server.result |= Result.H
    # reaching this line means leaving the context did not raise
    server.result |= Result.C
147
148
async def test_stateless_retry(server: Server, configuration: QuicConfiguration):
    """Connect to the server's retry port and look for a Retry packet.

    Sets ``Result.S`` when the QUIC log of the connection contains a received
    packet whose type is ``retry``.  Servers declared with
    ``retry_port=None`` are skipped.
    """
    # retry_port is Optional: skip instead of connecting to port None
    if server.retry_port is None:
        return

    async with connect(
        server.host, server.retry_port, configuration=configuration
    ) as protocol:
        await protocol.ping()

        # check log
        events = configuration.quic_logger.to_dict()["traces"][0]["events"]
        for _stamp, category, event, data in events:
            if (
                category == "transport"
                and event == "packet_received"
                and data["packet_type"] == "retry"
            ):
                server.result |= Result.S
165
166
async def test_quantum_readiness(server: Server, configuration: QuicConfiguration):
    """Connect with ``quantum_readiness_test`` switched on and ping once.

    Sets ``Result.Q`` when the ping completes.
    """
    configuration.quantum_readiness_test = True

    session = connect(server.host, server.port, configuration=configuration)
    async with session as client:
        await client.ping()
        server.result |= Result.Q
174
175
async def test_http_0(server: Server, configuration: QuicConfiguration):
    """Issue one GET with the H0 ALPN list and check that headers come back.

    Does nothing when ``server.path`` is ``None``.  Sets ``Result.D`` when the
    first event returned for the request is a ``HeadersReceived``.
    """
    if server.path is None:
        return

    url = "https://{}:{}{}".format(server.host, server.port, server.path)
    configuration.alpn_protocols = H0_ALPN

    async with connect(
        server.host,
        server.port,
        configuration=configuration,
        create_protocol=HttpClient,
    ) as protocol:
        client = cast(HttpClient, protocol)

        # perform HTTP request
        events = await client.get(url)
        got_headers = bool(events) and isinstance(events[0], HeadersReceived)
        if got_headers:
            server.result |= Result.D
195
196
async def test_http_3(server: Server, configuration: QuicConfiguration):
    """Exercise HTTP/3: a basic request, QPACK traffic and server push.

    Does nothing when ``server.path`` is ``None``.  Flags set:

    * ``Result.D`` and ``Result.three`` when the first request yields headers;
    * ``Result.d`` when, after two further requests, all four QPACK
      encoder/decoder byte counters on the HTTP connection are non-zero;
    * ``Result.p`` when a request for ``server.push_path`` produces a push
      whose first three events are promise, headers and data.
    """
    if server.path is None:
        return

    def url_for(path):
        # absolute URL for ``path`` on this server
        return "https://{}:{}{}".format(server.host, server.port, path)

    configuration.alpn_protocols = H3_ALPN
    async with connect(
        server.host,
        server.port,
        configuration=configuration,
        create_protocol=HttpClient,
    ) as protocol:
        client = cast(HttpClient, protocol)

        # perform HTTP request
        events = await client.get(url_for(server.path))
        if events and isinstance(events[0], HeadersReceived):
            server.result |= Result.D | Result.three

        # perform more HTTP requests to use QPACK dynamic tables
        for _ in range(2):
            events = await client.get(url_for(server.path))
        if events and isinstance(events[0], HeadersReceived):
            http = cast(H3Connection, client._http)
            quic_log = client._quic._logger
            quic_log.info(
                "QPACK decoder bytes RX %d TX %d",
                http._decoder_bytes_received,
                http._decoder_bytes_sent,
            )
            quic_log.info(
                "QPACK encoder bytes RX %d TX %d",
                http._encoder_bytes_received,
                http._encoder_bytes_sent,
            )
            counters = (
                http._decoder_bytes_received,
                http._decoder_bytes_sent,
                http._encoder_bytes_received,
                http._encoder_bytes_sent,
            )
            if all(counters):
                server.result |= Result.d

        # check push support
        if server.push_path is not None:
            client.pushes.clear()
            await client.get(url_for(server.push_path))
            # leave some time for pushed events to arrive
            await asyncio.sleep(0.5)
            for push_id, push_events in client.pushes.items():
                if len(push_events) < 3:
                    continue
                well_formed = (
                    isinstance(push_events[0], PushPromiseReceived)
                    and isinstance(push_events[1], HeadersReceived)
                    and isinstance(push_events[2], DataReceived)
                )
                if not well_formed:
                    continue

                promised_path = dict(push_events[0].headers)[b":path"]
                status = dict(push_events[1].headers)[b":status"]
                client._quic._logger.info(
                    "Push promise %d for %s received (status %s)",
                    push_id,
                    promised_path.decode("ascii"),
                    int(status),
                )

                server.result |= Result.p
265
266
async def test_session_resumption(server: Server, configuration: QuicConfiguration):
    """Connect twice, reusing the session ticket from the first connection.

    The port is ``server.session_resumption_port`` when set, otherwise
    ``server.port``.  If no ticket was received the second connection is not
    attempted.  Sets ``Result.R`` when the TLS layer reports the session as
    resumed and ``Result.Z`` when it reports early data as accepted.
    """
    port = server.session_resumption_port or server.port

    # every ticket handed to the handler is appended; the latest one is used
    tickets = []

    # connect a first time, receive a ticket
    async with connect(
        server.host,
        port,
        configuration=configuration,
        session_ticket_handler=tickets.append,
    ) as first:
        await first.ping()

        # some servers don't send the ticket immediately
        await asyncio.sleep(1)

    saved_ticket = tickets[-1] if tickets else None
    if saved_ticket is None:
        return

    # connect a second time, with the ticket
    configuration.session_ticket = saved_ticket
    async with connect(server.host, port, configuration=configuration) as second:
        await second.ping()

        tls = second._quic.tls

        # check session was resumed
        if tls.session_resumed:
            server.result |= Result.R

        # check early data was accepted
        if tls.early_data_accepted:
            server.result |= Result.Z
300
301
async def test_key_update(server: Server, configuration: QuicConfiguration):
    """Ping, request a key update, then ping again.

    Sets ``Result.U`` when the ping issued after the key update request
    completes.
    """
    session = connect(server.host, server.port, configuration=configuration)
    async with session as client:
        # traffic before the update
        await client.ping()

        # ask for new keys, then generate traffic again
        client.request_key_update()
        await client.ping()

        server.result |= Result.U
316
317
async def test_migration(server: Server, configuration: QuicConfiguration):
    """Change connection ID and local socket mid-connection, then check the log.

    After a first ping the connection ID is changed, the current transport is
    closed and a new datagram endpoint bound to ``("::", 0)`` is attached to
    the same protocol.  Sets ``Result.M`` when the received ``1RTT`` packets
    recorded in the QUIC log carry exactly two distinct destination
    connection IDs.
    """
    # Obtain the loop here instead of relying on a module-level ``loop``
    # name, which is only bound when the file is executed as a script.
    loop = asyncio.get_event_loop()

    async with connect(
        server.host, server.port, configuration=configuration
    ) as protocol:
        # cause some traffic
        await protocol.ping()

        # change connection ID and replace transport
        protocol.change_connection_id()
        protocol._transport.close()
        await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

        # cause more traffic
        await protocol.ping()

        # check log
        events = configuration.quic_logger.to_dict()["traces"][0]["events"]
        dcids = {
            data["header"]["dcid"]
            for _stamp, category, event, data in events
            if category == "transport"
            and event == "packet_received"
            and data["packet_type"] == "1RTT"
        }
        if len(dcids) == 2:
            server.result |= Result.M
346
347
async def test_rebinding(server: Server, configuration: QuicConfiguration):
    """Swap the local socket mid-connection and check traffic still flows.

    After a first ping the current transport is closed and a new datagram
    endpoint bound to ``("::", 0)`` is attached to the same protocol.  Sets
    ``Result.B`` when a ping over the new transport completes.
    """
    # Obtain the loop here instead of relying on a module-level ``loop``
    # name, which is only bound when the file is executed as a script.
    loop = asyncio.get_event_loop()

    async with connect(
        server.host, server.port, configuration=configuration
    ) as protocol:
        # cause some traffic
        await protocol.ping()

        # replace transport
        protocol._transport.close()
        await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

        # cause more traffic
        await protocol.ping()

        server.result |= Result.B
363
364
async def test_spin_bit(server: Server, configuration: QuicConfiguration):
    """Ping five times and check that the logged spin bit took two values.

    Sets ``Result.P`` when the ``spin_bit_updated`` connectivity events in
    the QUIC log report exactly two distinct states.
    """
    session = connect(server.host, server.port, configuration=configuration)
    async with session as client:
        for _ in range(5):
            await client.ping()

        # distinct spin bit states recorded in the log
        logged = configuration.quic_logger.to_dict()["traces"][0]["events"]
        spin_bits = {
            data["state"]
            for _stamp, category, event, data in logged
            if category == "connectivity" and event == "spin_bit_updated"
        }
        if len(spin_bits) == 2:
            server.result |= Result.P
381
382
async def test_throughput(server: Server, configuration: QuicConfiguration):
    """Compare download time over QUIC with download time over TCP.

    For each size a ``/<size><suffix>`` resource is fetched once with
    ``requests`` and once over QUIC (H3 ALPN when ``server.http3`` is true,
    H0 otherwise).  A size counts as a failure when the QUIC download took
    more than 1.1 times as long as the TCP one.  Sets ``Result.T`` when no
    size failed.

    Raises:
        AssertionError: if either download does not have the requested size.
    """
    failures = 0

    for size in (5000000, 10000000):
        print("Testing %d bytes download" % size)
        path = "/%d%s" % (size, server.throughput_file_suffix)

        # perform HTTP request over TCP
        tcp_started = time.time()
        response = requests.get("https://" + server.host + path, verify=False)
        tcp_octets = len(response.content)
        tcp_elapsed = time.time() - tcp_started
        assert tcp_octets == size, "HTTP/TCP response size mismatch"

        # perform HTTP request over QUIC
        configuration.alpn_protocols = H3_ALPN if server.http3 else H0_ALPN
        quic_started = time.time()
        async with connect(
            server.host,
            server.port,
            configuration=configuration,
            create_protocol=HttpClient,
        ) as protocol:
            client = cast(HttpClient, protocol)

            url = "https://{}:{}{}".format(server.host, server.port, path)
            http_events = await client.get(url)
            # stop the clock before counting the received bytes
            quic_elapsed = time.time() - quic_started
            quic_octets = sum(
                len(http_event.data)
                for http_event in http_events
                if isinstance(http_event, DataReceived)
            )
        assert quic_octets == size, "HTTP/QUIC response size mismatch"

        print(" - HTTP/TCP  completed in %.3f s" % tcp_elapsed)
        print(" - HTTP/QUIC completed in %.3f s" % quic_elapsed)

        too_slow = quic_elapsed > 1.1 * tcp_elapsed
        if too_slow:
            failures += 1
        print(" => FAIL" if too_slow else " => PASS")

    if not failures:
        server.result |= Result.T
432
433
def print_result(server: Server) -> None:
    """Print one line: the server name padded to 20 columns, then its flags.

    The flag string is ``str(server.result)`` with ``three`` shortened to
    ``3``, split into groups of 8, 8 and the remaining characters separated
    by single spaces.
    """
    flags = str(server.result).replace("three", "3")
    grouped = " ".join((flags[:8], flags[8:16], flags[16:]))
    print(server.name.ljust(20) + grouped)
438
439
440async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
441    for server in servers:
442        if server.structured_logging:
443            server.result |= Result.L
444        for test_name, test_func in tests:
445            print("\n=== %s %s ===\n" % (server.name, test_name))
446            configuration = QuicConfiguration(
447                alpn_protocols=H3_ALPN + H0_ALPN,
448                is_client=True,
449                quic_logger=QuicLogger(),
450                secrets_log_file=secrets_log_file,
451                verify_mode=server.verify_mode,
452            )
453            if test_name == "test_throughput":
454                timeout = 60
455            else:
456                timeout = 10
457            try:
458                await asyncio.wait_for(
459                    test_func(server, configuration), timeout=timeout
460                )
461            except Exception as exc:
462                print(exc)
463
464            if quic_log:
465                with open("%s-%s.qlog" % (server.name, test_name), "w") as logger_fp:
466                    json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)
467
468        print("")
469        print_result(server)
470
471    # print summary
472    if len(servers) > 1:
473        print("SUMMARY")
474        for server in servers:
475            print_result(server)
476
477
# Script entry point: parse the command line, select servers and tests, run
# them on the event loop and release the secrets log file afterwards.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="QUIC interop client")
    parser.add_argument(
        "-q",
        "--quic-log",
        action="store_true",
        help="log QUIC events to a file in QLOG format",
    )
    parser.add_argument(
        "--server", type=str, help="only run against the specified server."
    )
    parser.add_argument("--test", type=str, help="only run the specified test.")
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # open SSL log file
    if args.secrets_log:
        secrets_log_file = open(args.secrets_log, "a")
    else:
        secrets_log_file = None

    # determine what to run: every module-level name starting with "test_"
    # is a test, optionally narrowed by --server / --test
    servers = SERVERS
    tests = list(filter(lambda x: x[0].startswith("test_"), globals().items()))
    if args.server:
        servers = list(filter(lambda x: x.name == args.server, servers))
    if args.test:
        tests = list(filter(lambda x: x[0] == args.test, tests))

    # disable requests SSL warnings
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # `loop` stays a module-level name so existing references to it keep
    # working
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(
            run(
                servers=servers,
                tests=tests,
                quic_log=args.quic_log,
                secrets_log_file=secrets_log_file,
            )
        )
    finally:
        # the secrets log was opened above; close it even if the run fails
        if secrets_log_file is not None:
            secrets_log_file.close()
533