1#!/usr/local/bin/python3
import struct
from ctypes import sizeof
from enum import Enum
from typing import List
from typing import NamedTuple
from typing import Tuple

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.base_headers import NlmAckFlags
from atf_python.sys.netlink.base_headers import NlmNewFlags
from atf_python.sys.netlink.base_headers import NlmGetFlags
from atf_python.sys.netlink.base_headers import NlmDeleteFlags
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import enum_or_int
from atf_python.sys.netlink.utils import get_bitmask_str
20
21
class NlMsgCategory(Enum):
    """Coarse category of a netlink message.

    BaseNetlinkMessage.get_nlm_flags_str() uses the category to choose which
    flag enum decodes the category-specific bits of ``nlmsg_flags``
    (NlmGetFlags / NlmNewFlags / NlmDeleteFlags / NlmAckFlags). UNKNOWN means
    only the generic NlmBaseFlags are decoded.
    """

    UNKNOWN = 0
    GET = 1
    NEW = 2
    DELETE = 3
    ACK = 4  # reported for NLMSG_ERROR messages (see is_reply)
28
29
class NlMsgProps(NamedTuple):
    """Describes one known message type of a netlink family.

    StdNetlinkMessage.msg_props selects the entry whose ``msg.value`` equals
    the header's ``nlmsg_type``, and msg_name then reports ``msg.name``.
    """

    # Enum member naming the message type; its .value is the raw nlmsg_type.
    msg: Enum
    # Category of this message type; presumably used by subclasses to decode
    # nlmsg_flags (not read anywhere in this chunk) -- TODO confirm.
    category: NlMsgCategory
33
34
class BaseNetlinkMessage(object):
    """A netlink message: an ``Nlmsghdr`` plus a list of attributes.

    Messages are either built locally (header seeded from ``helper``) or
    reconstructed from received bytes via :meth:`from_bytes`. This class only
    understands the generic netlink header; subclasses add family-specific
    headers and attribute parsing.
    """

    def __init__(self, helper, nlmsg_type):
        """Create a message of type ``nlmsg_type``.

        :param helper: object supplying ``get_seq()`` (used as ``nlmsg_seq``)
            and ``pid`` (used as ``nlmsg_pid``); kept as ``self.helper``.
        :param nlmsg_type: message type, normalised through ``enum_or_int``
            (presumably an Enum member or a plain int).
        """
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        # Raw bytes this message was parsed from; only set by from_bytes().
        self._orig_data = None
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        # Family-specific header following nl_hdr; None at this level.
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Set NLM_F_REQUEST and, unless ``need_ack`` is false, NLM_F_ACK."""
        flags = [NlmBaseFlags.NLM_F_REQUEST]
        if need_ack:
            flags.append(NlmBaseFlags.NLM_F_ACK)
        self.add_nlflags(flags)

    def add_nlflags(self, flags: List):
        """OR every flag in ``flags`` (Enum members or ints) into nlmsg_flags."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append attribute ``nla`` to the message's attribute list."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in ``nla_list`` of type ``nla_type``, or None."""
        nla_type_raw = enum_or_int(nla_type)
        return next((nla for nla in nla_list if nla.nla_type == nla_type_raw), None)

    def get_nla(self, nla_type):
        """Return the first top-level attribute of type ``nla_type``, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the leading Nlmsghdr of ``data``.

        :returns: ``(header, header_size)``.
        :raises ValueError: if ``data`` is shorter than an Nlmsghdr.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if this message's header type equals ``nlmsg_type``."""
        nlmsg_type_raw = enum_or_int(nlmsg_type)
        return nlmsg_type_raw == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if ``hdr`` is an NLMSG_ERROR (ack/error) header."""
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def _get_msg_type(self):
        """Return the raw nlmsg_type from the header.

        msg_name relies on this. Without it here, base-class instances raised
        AttributeError whenever msg_name or print_nl_header was used.
        """
        return self.nl_hdr.nlmsg_type

    @property
    def msg_name(self):
        """Human-readable message type; generic ``msg#<type>`` at this level."""
        return "msg#{}".format(self._get_msg_type())

    def _get_nl_category(self):
        """Classify the message: ACK for NLMSG_ERROR replies, else UNKNOWN."""
        if self.is_reply(self.nl_hdr):
            return NlMsgCategory.ACK
        return NlMsgCategory.UNKNOWN

    def get_nlm_flags_str(self):
        """Render nlmsg_flags symbolically, using the message category to
        decide which category-specific flag names apply."""
        category = self._get_nl_category()
        flags = self.nl_hdr.nlmsg_flags

        if category == NlMsgCategory.UNKNOWN:
            # Only the generic flags are meaningful without a category.
            return self.helper.get_bitmask_str(NlmBaseFlags, flags)
        flags_enum = {
            NlMsgCategory.GET: NlmGetFlags,
            NlMsgCategory.NEW: NlmNewFlags,
            NlMsgCategory.DELETE: NlmDeleteFlags,
            NlMsgCategory.ACK: NlmAckFlags,
        }[category]
        return get_bitmask_str([NlmBaseFlags, flags_enum], flags)

    def print_nl_header(self, prepend=""):
        """Print a one-line summary of the netlink header, prefixed by ``prepend``."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0  # noqa: E501
        hdr = self.nl_hdr
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                self.msg_name,
                self.get_nlm_flags_str(),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from raw ``data``, decoding only the Nlmsghdr.

        The full input is kept in ``_orig_data``.

        :raises ValueError: if the header cannot be parsed (a hex dump of
            ``data`` is printed first).
        """
        try:
            hdr, _ = BaseNetlinkMessage.parse_nl_header(data)
            msg = cls(helper, hdr.nlmsg_type)
            msg._orig_data = data
            msg.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # Previously called without descr, so a TypeError masked this error.
            cls.print_as_bytes(data, "msg dump")
            raise
        return msg

    def print_message(self):
        """Print the message; only the netlink header at this level."""
        self.print_nl_header()

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = "msg dump"):
        """Hex-dump ``data``, 16 bytes per line, framed by a ``descr`` banner."""
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off : off + step]))
        print("--------------------")
151
152
class StdNetlinkMessage(BaseNetlinkMessage):
    """Netlink message laid out as Nlmsghdr + family base header + attributes.

    Subclasses are expected to provide:
      * ``parse_base_header(data)`` returning ``(header, header_len)``;
      * ``messages``: iterable of NlMsgProps for the known message types;
      * ``nl_attrs_map``: attribute type -> descriptor dict (see parse_attrs).
    """

    nl_attrs_map = {}
    # Known message types; empty default so msg_props/msg_name fall back to
    # the generic name instead of raising AttributeError.
    messages = ()

    # Clears the top two bits of nla_type (the NLA_F_NESTED /
    # NLA_F_NET_BYTEORDER flags in Linux's uapi netlink.h).
    _NLA_TYPE_MASK = 0x3FFF
    # Size of the nla_len + nla_type header preceding each attribute payload.
    _NLA_HDRLEN = 4

    @classmethod
    def from_bytes(cls, helper, data):
        """Parse a complete message: netlink header, base header, attributes.

        :raises ValueError: on any malformed section, or when bytes remain
            after the attribute list. A hex dump is printed before re-raising.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
            msg = cls(helper, hdr.nlmsg_type)
            msg._orig_data = data
            msg.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # descr passed explicitly; omitting it raised TypeError here.
            cls.print_as_bytes(data, "msg dump")
            raise

        offset = align4(hdrlen)
        try:
            base_hdr, hdrlen = msg.parse_base_header(data[offset:])
            msg.base_hdr = base_hdr
            offset += align4(hdrlen)
            # XXX: CAP_ACK
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise

        orig_offset = offset
        try:
            nla_list, nla_len = msg.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )  # noqa: E501
            msg.nla_list = nla_list
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
            )  # noqa: E501
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[orig_offset:], "failed block")
            raise
        return msg

    @staticmethod
    def _iter_nla_headers(data: bytes):
        """Walk a packed attribute list, yielding ``(offset, nla_len, raw_nla_type)``.

        Each attribute has a 4-byte native-endian header (nla_len, nla_type);
        nla_len covers the header plus payload. The next attribute starts at
        the 4-byte-aligned end of the current one. Iteration stops once fewer
        than 4 bytes remain.

        :raises ValueError: if nla_len is shorter than the header itself
            (nla_len == 0 used to loop forever) or overruns ``data``.
        """
        hdrlen = StdNetlinkMessage._NLA_HDRLEN
        off = 0
        while len(data) - off >= hdrlen:
            nla_len, raw_nla_type = struct.unpack_from("@HH", data, off)
            if nla_len < hdrlen:
                raise ValueError(
                    "attr length {} < attr header length {}".format(nla_len, hdrlen)
                )
            if nla_len + off > len(data):
                raise ValueError(
                    "attr length {} > than the remaining length {}".format(
                        nla_len, len(data) - off
                    )
                )
            yield off, nla_len, raw_nla_type
            off += align4(nla_len)

    def parse_child(self, data: bytes, attr_key, attr_map):
        """Parse ``data`` as attributes described by ``attr_map`` and wrap
        them in an NlAttrNested keyed by ``attr_key``."""
        attrs, _ = self.parse_attrs(data, attr_map)
        return NlAttrNested(attr_key, attrs)

    def parse_child_array(self, data: bytes, attr_key, attr_map):
        """Parse an array of nested attributes.

        Each element is itself a nested attribute parsed with ``attr_map``.
        The result is an NlAttrNested keyed by ``attr_key``.
        """
        children = [
            self.parse_child(
                data[off + self._NLA_HDRLEN : off + nla_len],
                raw_nla_type & self._NLA_TYPE_MASK,
                attr_map,
            )
            for off, nla_len, raw_nla_type in self._iter_nla_headers(data)
        ]
        return NlAttrNested(attr_key, children)

    def parse_attrs(self, data: bytes, attr_map):
        """Parse a packed attribute list.

        ``attr_map`` maps masked attribute types to descriptors. A descriptor
        holds ``"ad"`` (with ``.cls`` and ``.val``) and optionally
        ``"child"`` (a nested attr map) and ``"is_array"``. Attribute types
        missing from ``attr_map`` become plain NlAttr objects carrying the
        raw type.

        :returns: ``(attrs, consumed)`` where ``consumed`` is the aligned
            length of all parsed attributes.
        :raises ValueError: on malformed attribute headers.
        """
        ret = []
        consumed = 0
        for off, nla_len, raw_nla_type in self._iter_nla_headers(data):
            nla_type = raw_nla_type & self._NLA_TYPE_MASK
            payload = data[off + self._NLA_HDRLEN : off + nla_len]
            if nla_type not in attr_map:
                # unknown attribute
                val = NlAttr(raw_nla_type, payload)
            else:
                v = attr_map[nla_type]
                if "child" not in v:
                    val = v["ad"].cls.from_bytes(data[off : off + nla_len], v["ad"].val)
                elif v.get("is_array", False):
                    # Array of nested attributes
                    val = self.parse_child_array(payload, v["ad"].val, v["child"])
                else:
                    # nested
                    val = self.parse_child(payload, v["ad"].val, v["child"])
            ret.append(val)
            consumed = off + align4(nla_len)
        return ret, consumed

    def parse_nla_list(self, data: bytes) -> Tuple[List[NlAttr], int]:
        """Parse the top-level attributes using ``nl_attrs_map``.

        :returns: ``(attrs, consumed_len)``, as parse_attrs does.
        """
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialise the message and update ``nl_hdr.nlmsg_len`` to match."""
        payload = bytes(self.base_hdr) + b"".join(bytes(nla) for nla in self.nla_list)
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    def _get_msg_type(self):
        """Return the raw nlmsg_type from the header."""
        return self.nl_hdr.nlmsg_type

    @property
    def msg_props(self):
        """Return the NlMsgProps entry matching this message's type, or None."""
        msg_type = self._get_msg_type()
        return next((p for p in self.messages if p.msg.value == msg_type), None)

    @property
    def msg_name(self):
        """Symbolic message-type name if known, else the generic fallback."""
        msg_props = self.msg_props
        if msg_props is not None:
            return msg_props.msg.name
        return super().msg_name

    def print_base_header(self, hdr, prepend=""):
        """Print the family base header; no-op here, subclasses override."""
        pass

    def print_message(self):
        """Print the netlink header, base header and every attribute."""
        self.print_nl_header()
        self.print_base_header(self.base_hdr, " ")
        for nla in self.nla_list:
            nla.print_attr("  ")
287