1#!/usr/local/bin/python3
import struct
from ctypes import sizeof
from typing import List
from typing import Tuple

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrNested
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import enum_or_int
13
14
class BaseNetlinkMessage(object):
    """Common state and helpers for a single netlink message.

    An instance holds the netlink header (``nl_hdr``), an optional
    family-specific base header (``base_hdr``, ``None`` here) and a list of
    attributes (``nla_list``).  The ``helper`` object supplies the sequence
    number, the pid and the name/flag lookups used when printing.
    """

    def __init__(self, helper, nlmsg_type):
        """Create an empty message of type *nlmsg_type*.

        The header gets a new sequence number from ``helper.get_seq()`` and
        its pid from ``helper.pid``.
        """
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        # Raw bytes the message was parsed from (set by from_bytes).
        self._orig_data = None
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        self.base_hdr = None

    def set_request(self, need_ack=True):
        """Mark the message as a request, optionally also asking for an ack."""
        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
        if need_ack:
            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])

    def add_nlflags(self, flags: List):
        """OR every flag in *flags* (enum members or ints) into nlmsg_flags."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append attribute *nla* to the message's attribute list."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in *nla_list* whose nla_type matches.

        Returns ``None`` when no attribute matches.
        """
        nla_type_raw = enum_or_int(nla_type)
        for nla in nla_list:
            if nla.nla_type == nla_type_raw:
                return nla
        return None

    def get_nla(self, nla_type):
        """Return the first top-level attribute of *nla_type*, or ``None``."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the leading Nlmsghdr from *data*.

        Returns ``(header, header_size)``.  Raises ValueError if *data* is
        shorter than the header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if this message's header has type *nlmsg_type*."""
        nlmsg_type_raw = enum_or_int(nlmsg_type)
        return nlmsg_type_raw == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if *hdr* has type NLMSG_ERROR.

        print_nl_header passes this result to the flag formatter.
        """
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def print_nl_header(self, hdr, prepend=""):
        """Print a one-line, human-readable summary of netlink header *hdr*."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0  # noqa: E501
        is_reply = self.is_reply(hdr)
        msg_name = self.helper.get_nlmsg_name(hdr.nlmsg_type)
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                msg_name,
                self.helper.get_nlm_flags_str(
                    msg_name, is_reply, hdr.nlmsg_flags
                ),  # noqa: E501
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message from *data*, parsing only the netlink header.

        The full input is kept in ``_orig_data``.  On a header parse failure
        the error and a hex dump are printed, and the ValueError is re-raised.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
            self = cls(helper, hdr.nlmsg_type)
            self._orig_data = data
            self.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # Pass a description: print_as_bytes used to require one, and
            # omitting it raised TypeError here, masking the ValueError.
            cls.print_as_bytes(data, "msg dump")
            raise
        return self

    def print_message(self):
        """Print this message's netlink header."""
        self.print_nl_header(self.nl_hdr)

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = "msg dump"):
        """Hex-dump *data* to stdout, 16 bytes per line, framed by *descr*.

        *descr* is optional (default ``"msg dump"``), so callers that pass
        only the data no longer fail with TypeError.
        """
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off:off + step]))
        print("--------------------")
109
110
class StdNetlinkMessage(BaseNetlinkMessage):
    """Netlink message laid out as header + base header + attribute list.

    ``from_bytes`` calls ``self.parse_base_header(data)``, which is not
    defined in this class.  NOTE(review): presumably each subclass provides
    it, returning ``(base_hdr, length)``, and sets ``nl_attrs_map`` — confirm
    against the concrete message classes.
    """

    # Attribute type -> descriptor dict.  ``v["ad"]`` must expose ``.cls``
    # (with a ``from_bytes(data, val)`` constructor) and ``.val``.  An
    # optional ``v["child"]`` is the map used to parse a nested payload.
    nl_attrs_map = {}

    # Mask for the 16-bit nla_type field that clears the two flag bits
    # NLA_F_NESTED (1 << 15) and NLA_F_NET_BYTEORDER (1 << 14).
    _NLA_TYPE_MASK = 0x3FFF

    @classmethod
    def from_bytes(cls, helper, data):
        """Parse *data*: netlink header, base header, then all attributes.

        At each stage, a ValueError is reported (message plus hex dump) and
        re-raised.  A ValueError is also raised if any bytes remain after the
        attribute list.
        """
        try:
            hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
            self = cls(helper, hdr.nlmsg_type)
            self._orig_data = data
            self.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # descr is passed explicitly: omitting it raised TypeError here.
            cls.print_as_bytes(data, "msg dump")
            raise

        offset = align4(hdrlen)
        try:
            base_hdr, hdrlen = self.parse_base_header(data[offset:])
            self.base_hdr = base_hdr
            offset += align4(hdrlen)
            # XXX: CAP_ACK
        except ValueError as e:
            print("Failed to parse nl rt header: {}".format(e))
            cls.print_as_bytes(data, "msg dump")
            raise

        orig_offset = offset
        try:
            nla_list, nla_len = self.parse_nla_list(data[offset:])
            offset += nla_len
            if offset != len(data):
                raise ValueError(
                    "{} bytes left at the end of the packet".format(len(data) - offset)
                )  # noqa: E501
            self.nla_list = nla_list
        except ValueError as e:
            print(
                "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
            )  # noqa: E501
            cls.print_as_bytes(data, "msg dump")
            cls.print_as_bytes(data[orig_offset:], "failed block")
            raise
        return self

    def parse_attrs(self, data: bytes, attr_map) -> Tuple[List[NlAttr], int]:
        """Parse consecutive netlink attributes from *data* using *attr_map*.

        Known types are built from their ``attr_map`` descriptor.  Entries
        with a ``"child"`` map are parsed recursively into NlAttrNested.
        Unknown types become plain NlAttr carrying the raw type value.

        Returns ``(attrs, consumed)``.  *consumed* is the 4-byte aligned
        offset just past the last attribute parsed.

        Raises ValueError for an attribute whose length is shorter than its
        own 4-byte header or runs past the end of *data*.
        """
        ret = []
        off = 0
        while len(data) - off >= 4:
            nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4])
            # A length below the header size is malformed.  Length 0 would
            # also never advance `off`, so the loop would spin forever.
            if nla_len < 4:
                raise ValueError(
                    "attr length {} < than the attr header length 4".format(nla_len)
                )
            if nla_len + off > len(data):
                raise ValueError(
                    "attr length {} > than the remaining length {}".format(
                        nla_len, len(data) - off
                    )
                )
            # Strip only the NESTED / NET_BYTEORDER flag bits.  The previous
            # 0x3F mask truncated any type >= 64 to its low 6 bits.
            nla_type = raw_nla_type & self._NLA_TYPE_MASK
            if nla_type in attr_map:
                v = attr_map[nla_type]
                val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
                if "child" in v:
                    # nested: parse the payload (after the 4-byte header)
                    # with the child map
                    attrs, _ = self.parse_attrs(
                        data[off + 4:off + nla_len], v["child"]
                    )
                    val = NlAttrNested(v["ad"].val, attrs)
            else:
                # unknown attribute: keep the raw type (flags included)
                val = NlAttr(raw_nla_type, data[off + 4:off + nla_len])
            ret.append(val)
            off += align4(nla_len)
        return ret, off

    def parse_nla_list(self, data: bytes) -> Tuple[List[NlAttr], int]:
        """Parse top-level attributes with ``nl_attrs_map``.

        Returns ``(attrs, consumed)``, the same pair as ``parse_attrs``.
        """
        return self.parse_attrs(data, self.nl_attrs_map)

    def __bytes__(self):
        """Serialize netlink header + base header + attributes.

        Side effect: sets ``nl_hdr.nlmsg_len`` to the total serialized length.
        """
        # Attributes are serialized before the base header, as before.
        attrs = b"".join(bytes(nla) for nla in self.nla_list)
        payload = bytes(self.base_hdr) + attrs
        self.nl_hdr.nlmsg_len = len(payload) + sizeof(Nlmsghdr)
        return bytes(self.nl_hdr) + payload

    def print_base_header(self, hdr, prepend=""):
        """Print the base header.  Does nothing here; subclasses may override it."""
        pass

    def print_message(self):
        """Print the netlink header, the base header and every attribute."""
        self.print_nl_header(self.nl_hdr)
        self.print_base_header(self.base_hdr, " ")
        for nla in self.nla_list:
            nla.print_attr("  ")
202