xref: /freebsd/tests/atf_python/sys/net/rtsock.py (revision 8eb2bee6)
1#!/usr/local/bin/python3
2import os
3import socket
4import struct
5import sys
6from ctypes import c_byte
7from ctypes import c_char
8from ctypes import c_int
9from ctypes import c_long
10from ctypes import c_uint32
11from ctypes import c_ulong
12from ctypes import c_ushort
13from ctypes import sizeof
14from ctypes import Structure
15from typing import Dict
16from typing import List
17from typing import Optional
18from typing import Union
19
20
def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to the nearest multiple of ``num``.

    Uses a bit trick that is only correct when ``num`` is a power of two
    (callers in this module pass ``RtConst.ALIGN``). Values that are
    already a multiple of ``num`` are returned unchanged.
    """
    remainder = val % num
    if not remainder:
        return val
    # Setting all bits below the boundary and adding one carries into the
    # next multiple of ``num``.
    return (val | (num - 1)) + 1
26
27
class RtSockException(OSError):
    """Raised when a sockaddr buffer is too short or otherwise malformed."""
30
31
class RtConst:
    """Routing-socket constants plus helpers that render them symbolically.

    Constants are grouped into families by name prefix (``RTA_``, ``RTM_``,
    ``RTF_``, ``RTV_``, ``AF_``). The helper methods discover the members of
    a family via ``dir()``, so adding a constant here is enough for it to be
    printed by name.
    """

    RTM_VERSION = 5
    # Sockaddrs inside routing messages are padded to this boundary.
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # socket.AF_LINK only exists on BSD-derived hosts; fall back to the
    # FreeBSD value (18) so this class can still be imported elsewhere.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    # Names that match a family prefix without belonging to that family.
    # RTM_VERSION is the protocol version, not a message type; without this
    # exclusion, message type 5 would be reported as "RTM_VERSION".
    _NON_MEMBER_NAMES = frozenset(["RTM_VERSION"])

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the alphabetically sorted constant names in family ``prefix``."""
        return [
            name
            for name in dir(RtConst)
            if name.startswith(prefix) and name not in RtConst._NON_MEMBER_NAMES
        ]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the name of the constant in family ``prefix`` equal to ``value``.

        When several names share a value (e.g. RTF_LLDATA / RTF_LLINFO), the
        alphabetically first one wins. Unknown values are rendered as
        ``"U:<prefix>:<value>"``.
        """
        for prop in RtConst.get_props(prefix):
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Split bitmask ``value`` into ``{bit: name}`` for family ``prefix``.

        Bits are listed in ascending order. A set bit without a matching
        constant maps to its ``hex()`` string.
        """
        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
        if value < 0:
            # Masks decoded from signed c_int header fields (e.g. rtm_flags
            # with RTF_GWFLAG_COMPAT set) arrive negative. View them as
            # unsigned 32-bit; otherwise the loop below never terminates.
            value &= 0xFFFFFFFF
        ret = {}
        bit = 1
        while value:
            if value & bit:
                ret[bit] = propmap.get(bit, hex(bit))
                value &= ~bit
            bit <<= 1
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Return the names of the bits set in ``value``, comma-separated."""
        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
120
121
class RtMetrics(Structure):
    """ctypes layout of the metrics block carried in ``RtMsgHdr.rtm_rmx``.

    Every field is a C ``unsigned long``, so the structure size follows the
    host's ``c_ulong`` width. Field order IS the binary layout. Presumably
    this mirrors ``struct rt_metrics`` from <net/route.h>; confirm against
    the kernel header before changing anything.
    """

    _fields_ = [
        ("rmx_locks", c_ulong),
        ("rmx_mtu", c_ulong),
        ("rmx_hopcount", c_ulong),
        ("rmx_expire", c_ulong),
        ("rmx_recvpipe", c_ulong),
        ("rmx_sendpipe", c_ulong),
        ("rmx_ssthresh", c_ulong),
        ("rmx_rtt", c_ulong),
        ("rmx_rttvar", c_ulong),
        ("rmx_pksent", c_ulong),
        ("rmx_weight", c_ulong),
        ("rmx_nhidx", c_ulong),
        ("rmx_filler", c_ulong * 2),  # trailing words; presumably reserved padding
    ]
137    ]
138
139
class RtMsgHdr(Structure):
    """ctypes layout of the fixed header that starts every routing message.

    RtsockRtMessage packs it in ``__bytes__`` and unpacks it in
    ``from_bytes``. The sockaddrs selected by ``rtm_addrs`` follow directly
    after ``sizeof(RtMsgHdr)`` bytes.
    """

    _fields_ = [
        ("rtm_msglen", c_ushort),  # total length: header plus padded sockaddrs
        ("rtm_version", c_byte),  # set to RtConst.RTM_VERSION when sending
        # NOTE(review): c_byte is signed, so a type above 127 would read back
        # negative. That is fine for the RTM_* codes defined in this file.
        ("rtm_type", c_byte),  # RTM_* message type
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),  # RtsockRtMessage.verify() expects 0
        ("rtm_flags", c_int),  # RTF_* bitmask
        ("rtm_addrs", c_int),  # RTA_* bitmask of the sockaddrs that follow
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),  # Rtsock.read_data() matches replies on this
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]
155    ]
156
157
class SockaddrIn(Structure):
    """ctypes layout of an IPv4 socket address (16 bytes).

    ``sin_addr`` is a native-endian ``c_uint32``. SaHelper.ip4_sa() stores
    the network-order address bytes through it unchanged.
    """

    _fields_ = [
        ("sin_len", c_byte),  # size of the whole sockaddr in bytes
        ("sin_family", c_byte),
        ("sin_port", c_ushort),
        ("sin_addr", c_uint32),
        # ctypes returns c_char arrays as bytes cut at the first NUL, so an
        # all-zero sin_zero reads back as b"".
        ("sin_zero", c_char * 8),
    ]
165    ]
166
167
class SockaddrIn6(Structure):
    """ctypes layout of an IPv6 socket address (28 bytes).

    The address occupies bytes 8..24. ``sin6_scope_id`` is a native-endian
    ``c_uint32``.
    """

    _fields_ = [
        ("sin6_len", c_byte),  # size of the whole sockaddr in bytes
        ("sin6_family", c_byte),
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),
        ("sin6_scope_id", c_uint32),
    ]
176    ]
177
178
class SockaddrDl(Structure):
    """ctypes layout of a link-level sockaddr with an 8-byte data area.

    sizeof() is 16. SaHelper.print_sa_link() reads the interface name
    (``sdl_nlen`` bytes starting at offset 8) from the raw buffer, so it
    also copes with buffers longer than this structure.
    """

    _fields_ = [
        ("sdl_len", c_byte),  # size of the whole sockaddr in bytes
        ("sdl_family", c_byte),
        ("sdl_index", c_ushort),  # printed as "link#<n>" when non-zero
        ("sdl_type", c_byte),
        ("sdl_nlen", c_byte),  # length of the interface name in sdl_data
        ("sdl_alen", c_byte),
        ("sdl_slen", c_byte),
        ("sdl_data", c_byte * 8),
    ]
189    ]
190
191
class SaHelper(object):
    """Build and pretty-print raw sockaddr byte strings for routing messages.

    Builders return ``bytes`` laid out like the ctypes Sockaddr* structures
    in this module. The ``print_sa*`` helpers take such bytes and return a
    human-readable description.
    """

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Return True when ``ip`` contains a colon, i.e. looks like IPv6."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Return an IPv4 or IPv6 sockaddr for ``ip``.

        ``scopeid`` is only used for IPv6 addresses.
        """
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        else:
            return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Return a SockaddrIn (port 0, zero padding) for dotted-quad ``ip``."""
        # sin_addr is a native-endian c_uint32. Reading the network-order
        # bytes in host order lets ctypes write them back out unchanged.
        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Return a SockaddrIn6 for ``ip6`` with port/flowinfo 0 and ``scopeid``."""
        addr_bytes = (c_byte * 16).from_buffer_copy(
            socket.inet_pton(socket.AF_INET6, ip6)
        )
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Return a SockaddrDl carrying only an interface index and type."""
        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """Return an IPv4 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Return the dotted-quad netmask for ``pxlen`` (expected 0..32)."""
        if pxlen == 32:
            return "255.255.255.255"
        else:
            addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
            addr_bytes = struct.pack("!I", addr)
            return socket.inet_ntop(socket.AF_INET, addr_bytes)

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """Return an IPv6 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """Return the textual IPv6 netmask for ``pxlen`` (expected 0..128)."""
        ip6_b = [0] * 16
        start = 0
        # Fill whole 0xFF bytes, then put the remaining 1..8 bits in the
        # next byte.
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Return the dotted-quad address of an IPv4 sockaddr.

        Raises RtSockException if ``sa`` is shorter than 8 bytes.
        """
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Return ``"<addr> scopeid <id>"`` for an IPv6 sockaddr.

        Raises RtSockException if ``sa`` is shorter than SockaddrIn6.
        """
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        sin6 = SockaddrIn6.from_buffer_copy(sa)
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is a native-endian c_uint32 (ip6_sa writes it that
        # way), so decode it through the structure rather than as
        # big-endian.
        return "{} scopeid {}".format(addr, sin6.sin6_scope_id)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Return ``"link#<idx> ifname:<name> "`` parts present in a link sockaddr.

        ``hd`` is accepted but unused. Raises RtSockException if ``sa`` is
        shorter than SockaddrDl or its name length overruns the buffer.
        """
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            # The interface name starts right after the fixed 8-byte prefix.
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            ifname = "ifname:{} ".format(
                bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
            )
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Return a placeholder naming the unrecognised address family."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Return a description of any sockaddr, dispatching on its family byte.

        When ``hd`` is true, the raw bytes are appended as ``" [<repr>]"``.
        Raises Exception if ``sa`` is shorter than 2 bytes or its length
        byte disagrees with ``len(sa)``.
        """
        # Check the length first: sa[0] and sa[1] do not exist on short
        # buffers.
        if len(sa) < 2:
            raise Exception("sa too short: {} bytes".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        if sa[1] == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif sa[1] == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif sa[1] == socket.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        if hd:
            dump = " [{!r}]".format(sa)
        else:
            dump = ""
        return "{}{}".format(text, dump)
318        return "{}{}".format(text, dump)
319
320
class BaseRtsockMessage(object):
    """Base class for routing-socket message wrappers.

    Stores the numeric message type and exposes ``self.sa``, a SaHelper that
    subclasses use to build sockaddr attributes.
    """

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        self.sa = SaHelper()

    @staticmethod
    def print_rtm_type(rtm_type):
        """Map a numeric message type to its RTM_* name via RtConst.get_name()."""
        family_prefix = "RTM_"
        return RtConst.get_name(family_prefix, rtm_type)

    @property
    def rtm_type_str(self):
        """RTM_* name of this message's type."""
        current_type = self.rtm_type
        return self.print_rtm_type(current_type)
332        return self.print_rtm_type(self.rtm_type)
333
334
class RtsockRtMessage(BaseRtsockMessage):
    """Route message (RTM_ADD/DELETE/CHANGE/GET) plus its sockaddr attributes.

    Attributes live in ``self._attrs`` as ``{RTA_* bit: raw sockaddr bytes}``.
    They are serialized in ascending bit order, each padded to
    ``RtConst.ALIGN``.
    """

    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        """Create a message of ``rtm_type`` with optional RTA_DST / RTA_NETMASK."""
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        # Raw bytes this message was parsed from (None when built locally).
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    @staticmethod
    def _sa_space(sa_len: int) -> int:
        """Return how many message bytes a sockaddr of ``sa_len`` occupies.

        Sockaddrs are padded to ``RtConst.ALIGN``. A zero-length sockaddr
        still occupies one ALIGN-sized slot, matching the SA_SIZE() macro in
        FreeBSD's <net/route.h>.
        """
        if sa_len == 0:
            return RtConst.ALIGN
        return roundup2(sa_len, RtConst.ALIGN)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Store raw sockaddr bytes under RTA_* bit ``attr_type`` (replacing any)."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Store an IPv4 or IPv6 sockaddr for ``ip_addr`` under ``attr_type``."""
        if SaHelper.is_ipv6(ip_addr):
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        """Store an IPv4 sockaddr for ``ip`` under ``attr_type``."""
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        """Store an IPv6 sockaddr for ``ip6`` / ``scopeid`` under ``attr_type``."""
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        """Store a link-level sockaddr for ``ifindex`` under ``attr_type``."""
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> bytes:
        """Return the raw sockaddr stored under ``attr_type``, or None."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print a one-line header summary followed by every attribute."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        rtm_addrs = sum(self._attrs)
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs):
            sa_data = SaHelper.print_sa(self._attrs[attr])
            print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        """Print the message framed as an incoming one."""
        print("vvvvvvvv  IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Assert that ``sa_data`` is a sane IPv4 sockaddr (port 0, zeroed sin_zero).

        Raises Exception for buffers shorter than 8 bytes or whose sin_len
        exceeds the buffer.
        """
        if len(sa_data) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        full_len = sizeof(SockaddrIn)
        # Zero-extend shorter sockaddrs so from_buffer_copy() does not raise.
        padded = bytes(sa_data[:full_len]).ljust(full_len, b"\0")
        sin = SockaddrIn.from_buffer_copy(padded)
        assert sin.sin_port == 0
        # ctypes returns c_char arrays as NUL-truncated bytes (b"" when all
        # zero), so check the raw sin_zero bytes rather than the field value.
        zero_field = SockaddrIn.sin_zero
        zero_bytes = padded[zero_field.offset : zero_field.offset + zero_field.size]
        assert zero_bytes == bytes(zero_field.size)

    def compare_sa(self, sa_type, sa_data):
        """Assert that the stored sockaddr of ``sa_type`` equals ``sa_data``.

        Raises Exception if ``sa_data`` is shorter than 4 bytes.
        """
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        # Compare the printable forms first for a readable assertion failure.
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert that a parsed message has ``rtm_type``, no error, and the expected sockaddrs.

        ``rtm_sa`` maps RTA_* bits to expected raw sockaddrs. Requires a
        message created by from_bytes() (it re-reads ``_orig_data``).
        """
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a raw routing message into a new instance.

        Copies the header fields, keeps ``data`` in ``_orig_data``, and
        stores every sockaddr flagged in ``rtm_addrs`` under its RTA_* bit.

        Raises Exception if ``data`` is shorter than the header or a sockaddr
        lies outside ``data``.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)

        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_inits = hdr.rtm_inits
        msg.rtm_rmx = hdr.rtm_rmx
        msg._orig_data = data

        off = sizeof(RtMsgHdr)
        # rtm_addrs is a signed c_int. View it as unsigned so a set bit 31
        # cannot make the loop below run forever.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        v = 1
        while addrs_mask:
            if addrs_mask & v:
                addrs_mask &= ~v
                if off >= len(data):
                    raise Exception(
                        "SA {} starts past message end: {} >= {}".format(
                            RtConst.get_name("RTA_", v), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", v), off, sa_len, len(data)
                        )
                    )
                msg._attrs[v] = data[off : off + sa_len]
                off += cls._sa_space(sa_len)
            v <<= 1
        return msg

    def __bytes__(self):
        """Serialize to wire format.

        The output is an RtMsgHdr followed by each stored sockaddr in
        ascending RTA_* order, each padded as computed by _sa_space().
        """
        attrs = sorted(self._attrs.items())
        sz = sizeof(RtMsgHdr) + sum(self._sa_space(len(sa)) for _, sa in attrs)
        addrs_mask = sum(attr for attr, _ in attrs)
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        off = sizeof(RtMsgHdr)
        buf[0:off] = hdr
        for _, sa in attrs:
            buf[off : off + len(sa)] = sa
            off += self._sa_space(len(sa))
        return bytes(buf)
503
504
class Rtsock:
    """Wrapper around a raw AF_ROUTE socket for sending and receiving route messages."""

    def __init__(self):
        self.socket = self._setup_rtsock()
        self.rtm_seq = 1
        # rtm_type -> message class able to parse it (used by parse_message()).
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Map every rtm_type handled by a known message class to that class."""
        classes = [RtsockRtMessage]
        xmap = {}
        for cls in classes:
            for message in cls.messages:
                xmap[message] = cls
        return xmap

    def get_seq(self):
        """Return the next sequence number and advance the counter."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        """Return ``weight``, or 1 (RT_DEFAULT_WEIGHT) when it is falsy."""
        if weight:
            return weight
        else:
            return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message of ``msg_type`` for ``prefix`` via ``gw``.

        ``prefix`` is ``"addr"`` or ``"addr/pxlen"``; a netmask sockaddr is
        only added when a prefix length is given. ``gw`` is either an IP
        string or an already-built sockaddr (``bytes``).
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            # String
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_ADD message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_DELETE message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_CHANGE message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        """Open a raw routing socket with SO_USELOOPBACK enabled."""
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        # NOTE(review): presumably so this socket also receives copies of the
        # messages it sends; confirm against route(4).
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print ``data`` as a hex dump, 16 bytes per line."""
        width = 16
        print("==========================================")
        for chunk in [data[i : i + width] for i in range(0, len(data), width)]:
            for b in chunk:
                print("0x{:02X} ".format(b), end="")
            print()
        print()

    def write_message(self, msg):
        """Print ``msg`` and write it to the routing socket in a single write.

        os.write() raises OSError on failure, so any return value is a byte
        count; the assertion checks that nothing was truncated.
        """
        print("vvvvvvvv OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        ret = os.write(self.socket.fileno(), msg_bytes)
        assert ret == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Decode one routing message.

        Returns an instance of the class registered in ``self.msgmap`` for
        the message's type, or None for types this module does not model.
        Raises OSError if ``data`` is too short to hold the type byte.
        """
        # rtm_type is the 4th byte, after rtm_msglen (u_short) and
        # rtm_version (1 byte); see the RtMsgHdr field layout.
        type_off = 3
        if len(data) <= type_off:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        msg_cls = self.msgmap.get(data[type_off])
        if msg_cls is None:
            return None
        return msg_cls.from_bytes(data)

    def write_data(self, data: bytes):
        """Send raw ``data`` on the routing socket."""
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Receive one message (up to 4096 bytes).

        When ``seq`` is given, messages are discarded until one with a full
        header and a matching ``rtm_seq`` arrives.
        """
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                break
            # A message may consist of the header alone, so accept an exact fit.
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self) -> "Optional[RtsockRtMessage]":
        """Receive and parse the next message (None for unmodelled types)."""
        data = self.read_data()
        return self.parse_message(data)
604        return self.parse_message(data)
605