1#
2# !!! WARNING !!!
3#
4# This example uses some private APIs.
5#
6
7import argparse
8import asyncio
9import logging
10import ssl
11import time
12from dataclasses import dataclass, field
13from enum import Flag
14from typing import Optional, cast
15
16import httpx
17from http3_client import HttpClient
18
19from aioquic.asyncio import connect
20from aioquic.h0.connection import H0_ALPN
21from aioquic.h3.connection import H3_ALPN, H3Connection
22from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived
23from aioquic.quic.configuration import QuicConfiguration
24from aioquic.quic.logger import QuicFileLogger, QuicLogger
25
26
class Result(Flag):
    """Bit flags recording which interop checks a server passed.

    ``str()`` renders every member in ascending bit order: the member's name
    when its bit is set in this value, ``-`` otherwise.
    """

    V = 0x000001  # version negotiation packet received
    H = 0x000002  # handshake completed (ping succeeded)
    D = 0x000004  # HTTP request answered with headers
    C = 0x000008  # connection closed without error
    R = 0x000010  # TLS session resumed
    Z = 0x000020  # early data accepted
    S = 0x000040  # retry packet received
    Q = 0x000080  # quantum readiness test connection succeeded

    M = 0x000100  # traffic continued after a connection ID change
    B = 0x000200  # PATH_CHALLENGE seen after rebinding the local socket
    A = 0x000400  # PATH_CHALLENGE seen after rebinding + connection ID change
    U = 0x000800  # traffic continued after a key update
    P = 0x001000  # spin bit observed in two different states
    E = 0x002000  # not set by any check visible in this file
    L = 0x004000  # server is marked as supporting structured logging
    T = 0x008000  # QUIC throughput within 10% of TCP for every size

    three = 0x010000  # HTTP/3 request answered with headers
    d = 0x020000  # all QPACK encoder/decoder byte counters non-zero
    p = 0x040000  # server push received

    def __str__(self):
        ordered = sorted(
            Result.__members__.values(), key=lambda member: member.value
        )
        return "".join(
            member.name if self & member else "-" for member in ordered
        )
65
66
@dataclass
class Server:
    """One interop endpoint plus the ``Result`` flags it has earned so far.

    Entries live in ``SERVERS``; the test coroutines receive an instance and
    OR flags into its ``result`` field in place.
    """

    # label printed in the results table and matched by --server
    name: str
    # hostname used for QUIC connections and for the HTTP/TCP baseline
    host: str
    # default UDP port for QUIC connections
    port: int = 4433
    # throughput test uses H3_ALPN when True, H0_ALPN when False
    http3: bool = True
    # port for HTTP/3 traffic; falls back to ``port`` when None
    http3_port: Optional[int] = None
    # port used by the retry test; None skips that test
    retry_port: Optional[int] = 4434
    # request path for the H0/H3 request tests; None skips them
    path: Optional[str] = "/"
    # path whose response should trigger server push; None skips the push check
    push_path: Optional[str] = None
    # accumulated flags, starts with no bits set
    result: Result = field(default_factory=lambda: Result(0))
    # port for the session resumption test; falls back to ``port`` when None
    session_resumption_port: Optional[int] = None
    # when True, ``run`` sets Result.L for this server
    structured_logging: bool = False
    # %-format template (key ``size``) for the throughput download; None skips it
    throughput_path: Optional[str] = "/%(size)d"
    # passed through as QuicConfiguration(verify_mode=...); presumably None
    # keeps the library's default certificate checking — confirm
    verify_mode: Optional[int] = None
82
83
# Interop endpoints tested by default (``--server`` narrows this to one).
# Each entry overrides the ``Server`` defaults only where that deployment
# differs: ports, request/push/throughput paths, certificate verification.
SERVERS = [
    Server("akamaiquic", "ietf.akaquic.com", port=443, verify_mode=ssl.CERT_NONE),
    Server(
        "aioquic", "quic.aiortc.org", port=443, push_path="/", structured_logging=True
    ),
    Server("ats", "quic.ogre.com"),
    Server("f5", "f5quic.com", retry_port=4433, throughput_path=None),
    Server(
        "haskell", "mew.org", structured_logging=True, throughput_path="/num/%(size)s"
    ),
    Server("gquic", "quic.rocks", retry_port=None),
    Server("lsquic", "http3-test.litespeedtech.com", push_path="/200?push=/100"),
    Server(
        "msquic",
        "quic.westus.cloudapp.azure.com",
        structured_logging=True,
        throughput_path=None,  # throughput disabled; earlier path was "/%(size)d.txt"
        verify_mode=ssl.CERT_NONE,
    ),
    Server(
        "mvfst",
        "fb.mvfst.net",
        port=443,
        push_path="/push",
        retry_port=None,
        structured_logging=True,
    ),
    Server(
        "ngtcp2",
        "nghttp2.org",
        push_path="/?push=/100",
        structured_logging=True,
        throughput_path=None,
    ),
    Server("ngx_quic", "cloudflare-quic.com", port=443, retry_port=None),
    Server("pandora", "pandora.cm.in.tum.de", verify_mode=ssl.CERT_NONE),
    Server("picoquic", "test.privateoctopus.com", structured_logging=True),
    Server("quant", "quant.eggert.org", http3=False, structured_logging=True),
    Server("quic-go", "interop.seemann.io", port=443, retry_port=443),
    Server("quiche", "quic.tech", port=8443, retry_port=8444),
    Server("quicly", "quic.examp1e.net", http3_port=443),
    Server("quinn", "h3.stammw.eu", port=443),
]
127
128
async def test_version_negotiation(server: Server, configuration: QuicConfiguration):
    """Set ``Result.V`` if the server replies with a version negotiation packet.

    A bogus version is advertised first so the server must negotiate.
    """
    configuration.supported_versions.insert(0, 0x1A2A3A4A)

    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()

        # scan the in-memory qlog trace of this connection
        trace_events = configuration.quic_logger.to_dict()["traces"][0]["events"]
        negotiated = any(
            entry["name"] == "transport:packet_received"
            and entry["data"]["header"]["packet_type"] == "version_negotiation"
            for entry in trace_events
        )
        if negotiated:
            server.result |= Result.V
145
146
async def test_handshake_and_close(server: Server, configuration: QuicConfiguration):
    """Set ``Result.H`` after a successful ping, and ``Result.C`` once the
    ``connect`` context has been exited without raising.
    """
    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()
        server.result |= Result.H

    # leaving the context closes the connection; getting here means it worked
    server.result |= Result.C
154
155
async def test_retry(server: Server, configuration: QuicConfiguration):
    """Set ``Result.S`` if connecting to ``server.retry_port`` yields a retry
    packet. Servers without a retry port are skipped.
    """
    if server.retry_port is not None:
        async with connect(
            server.host, server.retry_port, configuration=configuration
        ) as client:
            await client.ping()

            trace_events = configuration.quic_logger.to_dict()["traces"][0]["events"]
            if any(
                entry["name"] == "transport:packet_received"
                and entry["data"]["header"]["packet_type"] == "retry"
                for entry in trace_events
            ):
                server.result |= Result.S
173
174
async def test_quantum_readiness(server: Server, configuration: QuicConfiguration):
    """Set ``Result.Q`` if a ping succeeds with ``quantum_readiness_test`` on."""
    configuration.quantum_readiness_test = True

    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()
        server.result |= Result.Q
182
183
async def test_http_0(server: Server, configuration: QuicConfiguration):
    """Set ``Result.D`` if a GET negotiated with ``H0_ALPN`` returns headers
    as its first event. Skipped when ``server.path`` is None.
    """
    if server.path is not None:
        configuration.alpn_protocols = H0_ALPN
        async with connect(
            server.host,
            server.port,
            configuration=configuration,
            create_protocol=HttpClient,
        ) as raw_client:
            client = cast(HttpClient, raw_client)

            url = "https://{}:{}{}".format(server.host, server.port, server.path)
            received = await client.get(url)
            if received and isinstance(received[0], HeadersReceived):
                server.result |= Result.D
203
204
async def test_http_3(server: Server, configuration: QuicConfiguration):
    """Exercise HTTP/3 requests, QPACK dynamic tables and server push.

    Flags OR'ed into ``server.result``:
    - ``D`` and ``three`` when the first GET returns headers first;
    - ``d`` when, after two more GETs, all QPACK encoder/decoder byte
      counters of the (private) H3Connection are non-zero;
    - ``p`` when a push triggered by ``server.push_path`` yields a
      push-promise, headers, data event sequence.

    Skipped when ``server.path`` is None.
    """
    if server.path is None:
        return

    # HTTP/3 may live on a dedicated port; the request URL must name the
    # port actually connected to (it presumably becomes :authority).
    port = server.http3_port or server.port
    base_url = "https://{}:{}".format(server.host, port)

    configuration.alpn_protocols = H3_ALPN
    async with connect(
        server.host,
        port,
        configuration=configuration,
        create_protocol=HttpClient,
    ) as protocol:
        protocol = cast(HttpClient, protocol)

        # perform HTTP request
        events = await protocol.get(base_url + server.path)
        if events and isinstance(events[0], HeadersReceived):
            server.result |= Result.D
            server.result |= Result.three

        # perform more HTTP requests to use QPACK dynamic tables; only the
        # last response is inspected
        for _ in range(2):
            events = await protocol.get(base_url + server.path)
        if events and isinstance(events[0], HeadersReceived):
            http = cast(H3Connection, protocol._http)
            protocol._quic._logger.info(
                "QPACK decoder bytes RX %d TX %d",
                http._decoder_bytes_received,
                http._decoder_bytes_sent,
            )
            protocol._quic._logger.info(
                "QPACK encoder bytes RX %d TX %d",
                http._encoder_bytes_received,
                http._encoder_bytes_sent,
            )
            if (
                http._decoder_bytes_received
                and http._decoder_bytes_sent
                and http._encoder_bytes_received
                and http._encoder_bytes_sent
            ):
                server.result |= Result.d

        # check push support
        if server.push_path is not None:
            protocol.pushes.clear()
            await protocol.get(base_url + server.push_path)
            # give the server time to deliver the pushed responses
            await asyncio.sleep(0.5)
            for push_id, events in protocol.pushes.items():
                if (
                    len(events) >= 3
                    and isinstance(events[0], PushPromiseReceived)
                    and isinstance(events[1], HeadersReceived)
                    and isinstance(events[2], DataReceived)
                ):
                    protocol._quic._logger.info(
                        "Push promise %d for %s received (status %s)",
                        push_id,
                        dict(events[0].headers)[b":path"].decode("ascii"),
                        int(dict(events[1].headers)[b":status"]),
                    )

                    server.result |= Result.p
274
275
async def test_session_resumption(server: Server, configuration: QuicConfiguration):
    """Reconnect with a session ticket from a first connection.

    Sets ``Result.R`` when TLS reports the session as resumed and
    ``Result.Z`` when it reports early data as accepted. Nothing is set if
    no ticket arrived during the first connection.
    """
    port = server.session_resumption_port or server.port
    tickets = []

    # first connection: collect whatever session tickets the server issues
    async with connect(
        server.host,
        port,
        configuration=configuration,
        session_ticket_handler=tickets.append,
    ) as client:
        await client.ping()

        # some servers don't send the ticket immediately
        await asyncio.sleep(1)

    latest_ticket = tickets[-1] if tickets else None
    if latest_ticket is None:
        return

    # second connection: present the most recent ticket
    configuration.session_ticket = latest_ticket
    async with connect(server.host, port, configuration=configuration) as client:
        await client.ping()

        tls = client._quic.tls
        if tls.session_resumed:
            server.result |= Result.R
        if tls.early_data_accepted:
            server.result |= Result.Z
309
310
async def test_key_update(server: Server, configuration: QuicConfiguration):
    """Set ``Result.U`` if a ping still succeeds after requesting a key update."""
    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()  # exchange traffic before the update
        client.request_key_update()
        await client.ping()  # and again afterwards
        server.result |= Result.U
325
326
async def test_server_cid_change(server: Server, configuration: QuicConfiguration):
    """Set ``Result.M`` if a ping still succeeds after switching connection ID."""
    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        await client.ping()  # exchange traffic on the original connection ID
        client.change_connection_id()
        await client.ping()  # and again on the new one
        server.result |= Result.M
341
342
async def test_nat_rebinding(server: Server, configuration: QuicConfiguration):
    """Set ``Result.B`` if the server challenges a rebound client address.

    The connection's UDP transport is closed and replaced by a new datagram
    endpoint bound to ``("::", 0)`` (OS-chosen port), keeping the same QUIC
    connection. Success means a PATH_CHALLENGE frame shows up in a received
    1-RTT packet in the qlog trace.
    """
    # use the loop running this coroutine rather than a global that only
    # exists when the module is executed as a script
    loop = asyncio.get_running_loop()

    async with connect(
        server.host, server.port, configuration=configuration
    ) as protocol:
        # cause some traffic
        await protocol.ping()

        # replace transport
        protocol._transport.close()
        await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

        # cause more traffic
        await protocol.ping()

        # check log
        path_challenges = 0
        for event in configuration.quic_logger.to_dict()["traces"][0]["events"]:
            if (
                event["name"] == "transport:packet_received"
                and event["data"]["header"]["packet_type"] == "1RTT"
            ):
                for frame in event["data"]["frames"]:
                    if frame["frame_type"] == "path_challenge":
                        path_challenges += 1
        if not path_challenges:
            protocol._quic._logger.warning("No PATH_CHALLENGE received")
        else:
            server.result |= Result.B
371
372
async def test_address_mobility(server: Server, configuration: QuicConfiguration):
    """Set ``Result.A`` if the server challenges a migrated client.

    Like the NAT rebinding check, but the client also switches to a new
    connection ID after moving to a fresh UDP endpoint bound to
    ``("::", 0)``. Success means a PATH_CHALLENGE frame shows up in a
    received 1-RTT packet in the qlog trace.
    """
    # use the loop running this coroutine rather than a global that only
    # exists when the module is executed as a script
    loop = asyncio.get_running_loop()

    async with connect(
        server.host, server.port, configuration=configuration
    ) as protocol:
        # cause some traffic
        await protocol.ping()

        # replace transport
        protocol._transport.close()
        await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

        # change connection ID
        protocol.change_connection_id()

        # cause more traffic
        await protocol.ping()

        # check log
        path_challenges = 0
        for event in configuration.quic_logger.to_dict()["traces"][0]["events"]:
            if (
                event["name"] == "transport:packet_received"
                and event["data"]["header"]["packet_type"] == "1RTT"
            ):
                for frame in event["data"]["frames"]:
                    if frame["frame_type"] == "path_challenge":
                        path_challenges += 1
        if not path_challenges:
            protocol._quic._logger.warning("No PATH_CHALLENGE received")
        else:
            server.result |= Result.A
404
405
async def test_spin_bit(server: Server, configuration: QuicConfiguration):
    """Set ``Result.P`` if the qlog shows the spin bit in two distinct states
    after five pings.
    """
    async with connect(
        server.host, server.port, configuration=configuration
    ) as client:
        for _ in range(5):
            await client.ping()

        trace_events = configuration.quic_logger.to_dict()["traces"][0]["events"]
        observed_states = {
            entry["data"]["state"]
            for entry in trace_events
            if entry["name"] == "connectivity:spin_bit_updated"
        }
        if len(observed_states) == 2:
            server.result |= Result.P
420
421
async def test_throughput(server: Server, configuration: QuicConfiguration):
    """Compare download times of the same resource over HTTP/TCP and QUIC.

    For each payload size, ``server.throughput_path`` (formatted with
    ``size``) is fetched over TCP with httpx and over QUIC with
    ``HttpClient``. A size fails when the QUIC download takes more than
    1.1x as long as the TCP one; ``Result.T`` is set only if no size fails.
    Skipped when ``server.throughput_path`` is None.

    Raises:
        RuntimeError: if either download's body length differs from the
            requested size.
    """
    if server.throughput_path is None:
        return

    # the QUIC application protocol and port do not depend on the size
    if server.http3:
        alpn_protocols = H3_ALPN
        port = server.http3_port or server.port
    else:
        alpn_protocols = H0_ALPN
        port = server.port

    failures = 0
    for size in [5000000, 10000000]:
        path = server.throughput_path % {"size": size}
        print("Testing %d bytes download: %s" % (size, path))

        # perform HTTP request over TCP; the async client keeps the event
        # loop running so the caller's wait_for timeout can still fire
        start = time.perf_counter()
        async with httpx.AsyncClient(verify=False) as client:
            response = await client.get("https://" + server.host + path)
        tcp_octets = len(response.content)
        tcp_elapsed = time.perf_counter() - start
        if tcp_octets != size:
            raise RuntimeError("HTTP/TCP response size mismatch")

        # perform HTTP request over QUIC, naming the port actually connected to
        configuration.alpn_protocols = alpn_protocols
        start = time.perf_counter()
        async with connect(
            server.host,
            port,
            configuration=configuration,
            create_protocol=HttpClient,
        ) as protocol:
            protocol = cast(HttpClient, protocol)

            http_events = await protocol.get(
                "https://{}:{}{}".format(server.host, port, path)
            )
            quic_elapsed = time.perf_counter() - start
            quic_octets = sum(
                len(http_event.data)
                for http_event in http_events
                if isinstance(http_event, DataReceived)
            )
        if quic_octets != size:
            raise RuntimeError("HTTP/QUIC response size mismatch")

        print(" - HTTP/TCP  completed in %.3f s" % tcp_elapsed)
        print(" - HTTP/QUIC completed in %.3f s" % quic_elapsed)

        # QUIC may be at most 10% slower than TCP
        if quic_elapsed > 1.1 * tcp_elapsed:
            failures += 1
            print(" => FAIL")
        else:
            print(" => PASS")

    if failures == 0:
        server.result |= Result.T
475
476
def print_result(server: Server) -> None:
    """Print one results-table row.

    The server name is left-justified to 20 columns. It is followed by the
    flag string (with ``three`` shortened to ``3``), split into groups of 8,
    8 and the remaining characters.
    """
    flags = str(server.result).replace("three", "3")
    grouped = " ".join((flags[:8], flags[8:16], flags[16:]))
    print(server.name.ljust(20) + grouped)
481
482
def _format_error(exc: BaseException) -> str:
    """Return a one-line description of *exc* that always names its type.

    ``str()`` alone is empty for argument-less exceptions such as the
    timeout raised by ``asyncio.wait_for``, which would print a blank line.
    """
    message = str(exc)
    name = type(exc).__name__
    return "%s: %s" % (name, message) if message else name


async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
    """Run every test against every server and print the resulting flags.

    Args:
        servers: ``Server`` instances; their ``result`` is updated in place.
        tests: ``(name, coroutine_function)`` pairs; each function is awaited
            as ``func(server, configuration)`` with a fresh configuration.
        quic_log: directory for qlog files, or a falsy value to keep the
            traces in memory only.
        secrets_log_file: open text file receiving TLS secrets, or None.
    """
    for server in servers:
        if server.structured_logging:
            server.result |= Result.L
        for test_name, test_func in tests:
            print("\n=== %s %s ===\n" % (server.name, test_name))
            configuration = QuicConfiguration(
                alpn_protocols=H3_ALPN + H0_ALPN,
                is_client=True,
                quic_logger=QuicFileLogger(quic_log) if quic_log else QuicLogger(),
                secrets_log_file=secrets_log_file,
                verify_mode=server.verify_mode,
            )
            # the throughput test downloads several megabytes per size
            timeout = 120 if test_name == "test_throughput" else 10
            try:
                await asyncio.wait_for(
                    test_func(server, configuration), timeout=timeout
                )
            except Exception as exc:
                # one failing test must not stop the remaining ones
                print(_format_error(exc))

        print("")
        print_result(server)

    # print summary
    if len(servers) > 1:
        print("SUMMARY")
        for server in servers:
            print_result(server)
515
516
if __name__ == "__main__":
    # test coroutines are discovered by name from the module namespace
    tests = list(filter(lambda x: x[0].startswith("test_"), globals().items()))

    parser = argparse.ArgumentParser(description="QUIC interop client")
    parser.add_argument(
        "-q",
        "--quic-log",
        type=str,
        help="log QUIC events to QLOG files in the specified directory",
    )
    parser.add_argument(
        "--server",
        type=str,
        choices=[server.name for server in SERVERS],
        help="only run against the specified server.",
    )
    parser.add_argument(
        "--test",
        type=str,
        choices=[name for name, _ in tests],
        help="only run the specified test.",
    )
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # determine what to run
    servers = SERVERS
    if args.server:
        servers = [server for server in servers if server.name == args.server]
    if args.test:
        tests = [test for test in tests if test[0] == args.test]

    # open SSL log file; it is closed once all tests have finished
    secrets_log_file = open(args.secrets_log, "a") if args.secrets_log else None

    # create the loop explicitly (get_event_loop() without a running loop is
    # deprecated) and keep it as the module-level ``loop`` name for any code
    # that refers to that global
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            run(
                servers=servers,
                tests=tests,
                quic_log=args.quic_log,
                secrets_log_file=secrets_log_file,
            )
        )
    finally:
        loop.close()
        if secrets_log_file is not None:
            secrets_log_file.close()
569