1#!/usr/local/bin/python3
2import os
3import socket
4import sys
5from ctypes import c_int
6from ctypes import c_ubyte
7from ctypes import c_uint
8from ctypes import c_ushort
9from ctypes import sizeof
10from ctypes import Structure
11from enum import auto
12from enum import Enum
13
14from atf_python.sys.netlink.attrs import NlAttr
15from atf_python.sys.netlink.attrs import NlAttrStr
16from atf_python.sys.netlink.attrs import NlAttrU32
17from atf_python.sys.netlink.base_headers import GenlMsgHdr
18from atf_python.sys.netlink.base_headers import NlmBaseFlags
19from atf_python.sys.netlink.base_headers import Nlmsghdr
20from atf_python.sys.netlink.base_headers import NlMsgType
21from atf_python.sys.netlink.message import BaseNetlinkMessage
22from atf_python.sys.netlink.message import NlMsgCategory
23from atf_python.sys.netlink.message import NlMsgProps
24from atf_python.sys.netlink.message import StdNetlinkMessage
25from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType
26from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType
27from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes
28from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes
29from atf_python.sys.netlink.utils import align4
30from atf_python.sys.netlink.utils import AttrDescr
31from atf_python.sys.netlink.utils import build_propmap
32from atf_python.sys.netlink.utils import enum_or_int
33from atf_python.sys.netlink.utils import get_bitmask_map
34from atf_python.sys.netlink.utils import NlConst
35from atf_python.sys.netlink.utils import prepare_attrs_map
36
37
class SockaddrNl(Structure):
    """ctypes layout of a netlink socket address that starts with a length byte.

    The structure carries, in order: a length, an address family, padding,
    a port id (``nl_pid``) and a multicast group mask (``nl_groups``).
    """

    _fields_ = [
        # Leading length and address-family bytes.
        ("nl_len", c_ubyte),
        ("nl_family", c_ubyte),
        # Padding between the family byte and the 32-bit members.
        ("nl_pad", c_ushort),
        # Port id and group bitmask.
        ("nl_pid", c_uint),
        ("nl_groups", c_uint),
    ]
46
47
class Nlmsgdone(Structure):
    """Base header of an NLMSG_DONE message: a single signed error code."""

    _fields_ = [("error", c_int)]
52
53
class Nlmsgerr(Structure):
    """Base header of an NLMSG_ERROR message.

    Holds the signed error code followed by a netlink message header
    (``msg``), which NetlinkErrorMessage prints as part of the error.
    """

    _fields_ = [
        # Signed error code reported by the peer.
        ("error", c_int),
        # Embedded netlink header carried inside the error message.
        ("msg", Nlmsghdr),
    ]
59
60
class NlErrattrType(Enum):
    """Attribute types that can follow the header of an error message."""

    # Values are consecutive integers starting at zero.
    NLMSGERR_ATTR_UNUSED = 0
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
67
68
class AddressFamilyLinux(Enum):
    """Address family numbers used when NlHelper is not running on FreeBSD."""

    # The IP families come from the socket module of the running interpreter.
    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # The netlink family number is hard-coded.
    AF_NETLINK = 16
73
74
class AddressFamilyBsd(Enum):
    """Address family numbers used when NlHelper runs on FreeBSD."""

    # The IP families come from the socket module of the running interpreter.
    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # The netlink family number is hard-coded.
    AF_NETLINK = 38
79
80
class NlHelper:
    """Shared helper for netlink tests: sequence numbers and enum lookups.

    Attributes:
        pid: process id of the creating process (``os.getpid()``).
    """

    def __init__(self):
        # cls -> build_propmap(cls); filled lazily by get_propmap().
        self._pmap = {}
        self._af_cls = self.get_af_cls()
        self._seq_counter = 1
        self.pid = os.getpid()

    def get_seq(self):
        """Return the current sequence number and advance the counter by one."""
        seq = self._seq_counter
        self._seq_counter += 1
        return seq

    def get_af_cls(self):
        """Return the address-family enum matching the running platform."""
        if sys.platform.startswith("freebsd"):
            return AddressFamilyBsd
        return AddressFamilyLinux

    def get_propmap(self, cls):
        """Return ``build_propmap(cls)``, computing it only once per class."""
        if cls not in self._pmap:
            self._pmap[cls] = build_propmap(cls)
        return self._pmap[cls]

    def get_name_propmap(self, cls):
        """Return a ``{name: value}`` dict for the public members of *cls*.

        Every attribute of *cls* whose name does not start with an
        underscore is included; each one must expose a ``.value``.
        """
        return {
            name: getattr(cls, name).value
            for name in dir(cls)
            if not name.startswith("_")
        }

    def get_attr_byval(self, cls, attr_val):
        """Look *attr_val* up in the property map of *cls* (None if absent)."""
        return self.get_propmap(cls).get(attr_val)

    def get_af_name(self, family):
        """Return the mapped name of *family*, or ``"af#<family>"`` if unknown."""
        name = self.get_attr_byval(self._af_cls, family)
        if name is None:
            return "af#{}".format(family)
        return name

    def get_af_value(self, family_str: str) -> int:
        """Return the numeric value of the family named *family_str*.

        Note: returns None when the platform enum has no such member.
        """
        return self.get_name_propmap(self._af_cls).get(family_str)

    def get_bitmask_str(self, cls, val):
        """Render the bits set in *val* as a comma-separated list of names."""
        bmap = get_bitmask_map(self.get_propmap(cls), val)
        return ",".join(name for _, name in bmap.items())

    @staticmethod
    def get_bitmask_str_uncached(cls, val):
        """Same as get_bitmask_str() but without an instance or its cache."""
        # build_propmap/get_bitmask_map are module-level helpers, not
        # attributes of NlHelper; looking them up on the class would raise
        # AttributeError.
        pmap = build_propmap(cls)
        bmap = get_bitmask_map(pmap, val)
        return ",".join(name for _, name in bmap.items())
135
136
# NetlinkDoneMessage is given an empty attribute map.
nldone_attrs = prepare_attrs_map([])

# Attribute map for NetlinkErrorMessage: each error attribute type is paired
# with the class used to represent it.
nlerr_attrs = prepare_attrs_map(
    [
        AttrDescr(attr_type, attr_cls)
        for attr_type, attr_cls in (
            (NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr),
            (NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32),
            (NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr),
        )
    ]
)
146
147
class NetlinkDoneMessage(StdNetlinkMessage):
    """Handler for NLMSG_DONE messages, whose base header is an Nlmsgdone."""

    messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)]
    nl_attrs_map = nldone_attrs

    @property
    def error_code(self):
        """Error code stored in the parsed base header."""
        return self.base_hdr.error

    def parse_base_header(self, data):
        """Decode an Nlmsgdone from the start of *data*.

        Returns:
            ``(header, consumed_size)`` tuple.

        Raises:
            ValueError: *data* is shorter than the header.
        """
        hdr_size = sizeof(Nlmsgdone)
        if len(data) < hdr_size:
            raise ValueError("length less than nlmsgdone header")
        return (Nlmsgdone.from_buffer_copy(data), hdr_size)

    def print_base_header(self, hdr, prepend=""):
        """Print the error code of *hdr*, prefixed with *prepend*."""
        line = "{}error={}".format(prepend, hdr.error)
        print(line)
165
166
class NetlinkErrorMessage(StdNetlinkMessage):
    """Handler for NLMSG_ERROR messages, whose base header is an Nlmsgerr."""

    messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)]
    nl_attrs_map = nlerr_attrs

    @property
    def error_code(self):
        """Error code stored in the parsed base header."""
        return self.base_hdr.error

    @property
    def error_str(self):
        """Text of the NLMSGERR_ATTR_MSG attribute, or None when absent."""
        attr = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG)
        return attr.text if attr else None

    @property
    def error_offset(self):
        """Value of the NLMSGERR_ATTR_OFFS attribute, or None when absent."""
        attr = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS)
        return attr.u32 if attr else None

    @property
    def cookie(self):
        """The NLMSGERR_ATTR_COOKIE attribute as returned by get_nla()."""
        return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE)

    def parse_base_header(self, data):
        """Decode an Nlmsgerr from the start of *data*.

        Returns:
            ``(header, consumed_size)`` tuple.  Unless flag bit 0x100 is set
            on the outer header, the size also covers the aligned payload of
            the embedded message.

        Raises:
            ValueError: *data* is shorter than the header.
        """
        consumed = sizeof(Nlmsgerr)
        if len(data) < consumed:
            raise ValueError("length less than nlmsgerr header")
        err_hdr = Nlmsgerr.from_buffer_copy(data)
        # 0x100 is NLM_F_CAPPED in the Linux netlink headers: when it is
        # clear, the payload of the embedded message follows its header.
        capped = self.nl_hdr.nlmsg_flags & 0x100
        if not capped:
            consumed += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr))
        return (err_hdr, consumed)

    def print_base_header(self, errhdr, prepend=""):
        """Print the error code and the embedded header on a single line."""
        print("{}error={}, ".format(prepend, errhdr.error), end="")
        inner = errhdr.msg
        type_text = "msg#{}".format(inner.nlmsg_type)
        flags_text = self.helper.get_bitmask_str(NlmBaseFlags, inner.nlmsg_flags)
        template = "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}"
        print(
            template.format(
                prepend,
                inner.nlmsg_len,
                type_text,
                flags_text,
                inner.nlmsg_flags,
                inner.nlmsg_seq,
                inner.nlmsg_pid,
            )
        )
216
217
# Handler classes for the message types shared by every netlink family,
# keyed by the family name that Nlsock.parse_message() uses for them.
core_classes = {"netlink_core": [NetlinkDoneMessage, NetlinkErrorMessage]}
224
225
class Nlsock:
    """Netlink socket wrapper that sends messages and parses the replies.

    Incoming bytes are buffered in ``_data`` and split into individual
    messages by read_message(); each message is parsed with the handler class
    registered for its family and type, falling back to BaseNetlinkMessage.
    """

    HANDLER_CLASSES = [core_classes, rt_classes, genl_classes]

    # Socket option level and option names, numerically as defined by the
    # netlink headers.
    SOL_NETLINK = 270
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    def __init__(self, family, helper):
        """Open a netlink socket for protocol *family* using *helper*."""
        self.helper = helper
        self.sock_fd = self._setup_netlink(family)
        self._sock_family = family
        self._data = bytes()
        self.msgmap = self.build_msgmap()
        # Generic netlink family id -> family name; extended by
        # get_genl_family_id().
        self._family_map = {
            NlConst.GENL_ID_CTRL: "nlctrl",
        }

    def build_msgmap(self):
        """Return ``{family_name: {message_type: handler_class}}``."""
        handler_classes = {}
        for d in self.HANDLER_CLASSES:
            handler_classes.update(d)
        xmap = {}
        for family_id, family_classes in handler_classes.items():
            xmap[family_id] = {
                enum_or_int(msg_props.msg): cls
                for cls in family_classes
                for msg_props in cls.messages
            }
        return xmap

    def _setup_netlink(self, netlink_family) -> socket.socket:
        """Create the raw netlink socket and enable capped/extended ACKs.

        Returns:
            The socket object (stored by __init__ as ``sock_fd``).
        """
        family = self.helper.get_af_value("AF_NETLINK")
        s = socket.socket(family, socket.SOCK_RAW, netlink_family)
        try:
            s.setsockopt(self.SOL_NETLINK, self.NETLINK_CAP_ACK, 1)
            s.setsockopt(self.SOL_NETLINK, self.NETLINK_EXT_ACK, 1)
        except OSError:
            # Do not leak the descriptor when the options are rejected.
            s.close()
            raise
        return s

    def set_groups(self, mask: int):
        """Pass the multicast group *mask* to the socket via setsockopt().

        NOTE(review): this uses level SOL_SOCKET with option 1; binding the
        socket with a SockaddrNl carrying ``nl_groups=mask`` was the
        alternative sketched here earlier and is not implemented -- confirm
        which one is intended.
        """
        self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask)

    def join_group(self, group_id: int):
        """Subscribe the socket to the multicast group *group_id*."""
        self.sock_fd.setsockopt(
            self.SOL_NETLINK, self.NETLINK_ADD_MEMBERSHIP, group_id
        )

    def write_message(self, msg, verbose=True):
        """Serialize *msg* and write it to the socket (best effort).

        Write failures and short writes are reported on stdout and are not
        raised.
        """
        if verbose:
            print("vvvvvvvv OUT vvvvvvvv")
            msg.print_message()
        msg_bytes = bytes(msg)
        try:
            ret = os.write(self.sock_fd.fileno(), msg_bytes)
        except OSError as e:
            print("write({}) -> {}".format(len(msg_bytes), e))
            return
        # Explicit check instead of assert, which is stripped under -O.
        if ret != len(msg_bytes):
            print(
                "write({}) -> short write of {} bytes".format(len(msg_bytes), ret)
            )

    def parse_message(self, data: bytes):
        """Parse the single raw message *data* with the matching handler.

        Raises:
            Exception: *data* is too short to hold the required headers.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise Exception("Short read from nl: {} bytes".format(len(data)))
        hdr = Nlmsghdr.from_buffer_copy(data)
        if hdr.nlmsg_type < 16:
            # Types below 16 are handled by the "netlink_core" classes.
            family_name = "netlink_core"
            nlmsg_type = hdr.nlmsg_type
        elif self._sock_family == NlConst.NETLINK_ROUTE:
            family_name = "netlink_route"
            nlmsg_type = hdr.nlmsg_type
        else:
            # Generic netlink: nlmsg_type selects the family and the command
            # from the generic header selects the handler.
            if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr):
                raise Exception("Short read from genl: {} bytes".format(len(data)))
            family_name = self._family_map.get(hdr.nlmsg_type, "")
            ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):])
            nlmsg_type = ghdr.cmd
        cls = self.msgmap.get(family_name, {}).get(nlmsg_type)
        if not cls:
            cls = BaseNetlinkMessage
        return cls.from_bytes(self.helper, data)

    def get_genl_family_id(self, family_name):
        """Resolve the generic netlink family *family_name* to its id.

        The id is also recorded so later messages of that family can be
        parsed by parse_message().

        Raises:
            ValueError: the request failed or the reply carries no family id.
        """
        hdr = Nlmsghdr(
            nlmsg_type=NlConst.GENL_ID_CTRL,
            nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value,
            nlmsg_seq=self.helper.get_seq(),
        )
        ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value)
        nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name)
        hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla))

        msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla)
        self.write_data(msg_bytes)
        while True:
            rx_msg = self.read_message()
            if hdr.nlmsg_seq != rx_msg.nl_hdr.nlmsg_seq:
                # Not a reply to this request.
                continue
            if rx_msg.is_type(NlMsgType.NLMSG_ERROR):
                if rx_msg.error_code != 0:
                    raise ValueError("unable to get family {}".format(family_name))
                # Error code 0: keep waiting for the actual reply.
                continue
            id_attr = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID)
            if id_attr is None:
                raise ValueError("unable to get family {}".format(family_name))
            family_id = id_attr.u16
            self._family_map[family_id] = family_name
            return family_id

    def write_data(self, data: bytes):
        """Send raw *data* on the socket."""
        self.sock_fd.send(data)

    def read_data(self):
        """Receive into the buffer until it holds at least one header."""
        while True:
            data = self.sock_fd.recv(65535)
            self._data += data
            if len(self._data) >= sizeof(Nlmsghdr):
                break

    def read_message(self) -> BaseNetlinkMessage:
        """Return the next complete message from the socket, parsed."""
        if len(self._data) < sizeof(Nlmsghdr):
            self.read_data()
        hdr = Nlmsghdr.from_buffer_copy(self._data)
        while hdr.nlmsg_len > len(self._data):
            self.read_data()
        raw_msg = self._data[: hdr.nlmsg_len]
        self._data = self._data[hdr.nlmsg_len:]
        return self.parse_message(raw_msg)

    def get_reply(self, tx_msg):
        """Send *tx_msg* and return the first message with the same sequence."""
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg
354
355
class NetlinkMultipartIterator:
    """Iterate over the parts of a multipart reply read from *obj*.

    Each step reads one message via ``obj.read_message()``.  Iteration ends
    at an NLMSG_DONE message with error code 0; any unexpected message raises
    ValueError.
    """

    def __init__(self, obj, seq_number: int, msg_type):
        self._obj = obj
        self._seq = seq_number
        self._msg_type = msg_type

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next part of the reply.

        Raises:
            StopIteration: a DONE message with error code 0 was read.
            ValueError: wrong sequence number, an error message, a DONE
                message with a non-zero error code, or an unexpected type.
        """
        part = self._obj.read_message()
        if self._seq != part.nl_hdr.nlmsg_seq:
            raise ValueError("bad sequence number")

        if part.is_type(NlMsgType.NLMSG_ERROR):
            text = "error while handling multipart msg: {}".format(part.error_code)
            raise ValueError(text)

        if part.is_type(NlMsgType.NLMSG_DONE):
            if msg_ok := (part.error_code == 0):
                raise StopIteration
            raise ValueError(
                "error listing some parts of the multipart msg: {}".format(
                    part.error_code
                )
            )

        if part.is_type(self._msg_type):
            return part
        raise ValueError("bad message type: {}".format(part))
384
385
class NetlinkTestTemplate:
    """Mixin with helpers for tests that talk to a netlink socket.

    setup_netlink() must be called first; it creates ``self.helper`` and
    ``self.nlsock``, which the other methods rely on.
    """

    REQUIRED_MODULES = ["netlink"]

    def setup_netlink(self, netlink_family: NlConst):
        """Create the NlHelper and the Nlsock for *netlink_family*."""
        self.helper = NlHelper()
        self.nlsock = Nlsock(netlink_family, self.helper)

    def write_message(self, msg, silent=False):
        """Send *msg*, dumping it to stdout first unless *silent* is set."""
        if silent:
            self.nlsock.write_data(bytes(msg))
            return
        print("")
        print("============= >> TX MESSAGE =============")
        msg.print_message()
        msg.print_as_bytes(bytes(msg), "-- DATA --")
        self.nlsock.write_data(bytes(msg))

    def read_message(self, silent=False):
        """Read one message, dumping it to stdout unless *silent* is set."""
        received = self.nlsock.read_message()
        if silent:
            return received
        print("")
        print("============= << RX MESSAGE =============")
        received.print_message()
        return received

    def get_reply(self, tx_msg):
        """Send *tx_msg* and return the first message with the same sequence."""
        self.write_message(tx_msg)
        while True:
            candidate = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq != candidate.nl_hdr.nlmsg_seq:
                continue
            return candidate

    def read_msg_list(self, seq, msg_type):
        """Collect all parts of the multipart reply numbered *seq* in a list."""
        parts = NetlinkMultipartIterator(self, seq, msg_type)
        return list(parts)
418