1import ipaddress
2import socket
3
4import pytest
5from atf_python.sys.net.tools import ToolsHelper
6from atf_python.sys.net.vnet import SingleVnetTestTemplate
7from atf_python.sys.netlink.attrs import NlAttrIp
8from atf_python.sys.netlink.attrs import NlAttrU32
9from atf_python.sys.netlink.base_headers import NlmBaseFlags
10from atf_python.sys.netlink.base_headers import NlmGetFlags
11from atf_python.sys.netlink.base_headers import NlmNewFlags
12from atf_python.sys.netlink.base_headers import NlMsgType
13from atf_python.sys.netlink.netlink import NetlinkTestTemplate
14from atf_python.sys.netlink.netlink_route import NetlinkRtMessage
15from atf_python.sys.netlink.netlink_route import NlRtMsgType
16from atf_python.sys.netlink.netlink_route import RtattrType
17from atf_python.sys.netlink.utils import NlConst
18
19
class TestRtNlRoute(NetlinkTestTemplate, SingleVnetTestTemplate):
    """NETLINK_ROUTE tests that add and dump IPv6 routes inside a vnet.

    IPV6_PREFIXES is presumably consumed by SingleVnetTestTemplate to
    address the test interface (assumption -- verify against the template).
    """

    IPV6_PREFIXES = ["2001:db8::1/64"]

    def setup_method(self, method):
        """Run the parent templates' setup, then set up netlink for NETLINK_ROUTE."""
        super().setup_method(method)
        self.setup_netlink(NlConst.NETLINK_ROUTE)

    @pytest.mark.timeout(5)
    def test_add_route6_ll_gw(self):
        """Add an IPv6 route whose next hop is a link-local gateway.

        A link-local next hop (fe80::1) is scoped to a link, so RTA_OIF is
        sent together with RTA_GATEWAY. The request must be acknowledged
        with an NLMSG_ERROR message carrying error code 0.
        """
        ifname = self.vnet.iface_alias_map["if1"].name
        ifindex = socket.if_nametoindex(ifname)

        req = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_NEWROUTE)
        req.set_request()
        req.add_nlflags([NlmNewFlags.NLM_F_CREATE])
        hdr = req.base_hdr
        hdr.rtm_family = socket.AF_INET6
        hdr.rtm_dst_len = 64
        for attr in (
            NlAttrIp(RtattrType.RTA_DST, "2001:db8:2::"),
            NlAttrIp(RtattrType.RTA_GATEWAY, "fe80::1"),
            NlAttrU32(RtattrType.RTA_OIF, ifindex),
        ):
            req.add_nla(attr)

        reply = self.get_reply(req)
        # Success is reported as an NLMSG_ERROR message whose code is 0.
        assert reply.is_type(NlMsgType.NLMSG_ERROR)
        assert reply.error_code == 0

        # Emit the resulting network state to the test log for debugging.
        ToolsHelper.print_net_debug()
        ToolsHelper.print_output("netstat -6onW")

    @pytest.mark.timeout(20)
    def test_buffer_override(self):
        """Install 1000 IPv6 /65 routes one by one, then dump them all back.

        Each RTM_NEWROUTE request is acknowledged before the next is sent.
        A single RTM_GETROUTE dump must then report every /65 route that was
        added. Per the test name, this presumably targets replies that
        overflow a single read buffer.
        """
        ack_request = (
            NlmBaseFlags.NLM_F_ACK.value | NlmBaseFlags.NLM_F_REQUEST.value
        )
        create_flags = ack_request | NlmNewFlags.NLM_F_CREATE.value
        dump_flags = (
            ack_request
            | NlmGetFlags.NLM_F_ROOT.value
            | NlmGetFlags.NLM_F_MATCH.value
        )

        num_routes = 1000
        # Route N is 2001:db8:ffff:<N>::/65.
        # N is stored big-endian in the fourth 16-bit group (bytes 6 and 7).
        prefix_bytes = bytearray(ipaddress.ip_address("2001:db8:ffff::").packed)
        for route_idx in range(num_routes):
            prefix_bytes[6], prefix_bytes[7] = divmod(route_idx, 256)
            prefix = ipaddress.IPv6Address(bytes(prefix_bytes))

            add_req = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_NEWROUTE.value)
            add_req.nl_hdr.nlmsg_flags = create_flags
            add_req.base_hdr.rtm_family = socket.AF_INET6
            add_req.base_hdr.rtm_dst_len = 65
            add_req.add_nla(NlAttrIp(RtattrType.RTA_DST, str(prefix)))
            add_req.add_nla(NlAttrIp(RtattrType.RTA_GATEWAY, "2001:db8::2"))

            self.write_message(add_req, silent=True)
            ack = self.read_message(silent=True)
            assert ack.is_type(NlMsgType.NLMSG_ERROR)
            assert add_req.nl_hdr.nlmsg_seq == ack.nl_hdr.nlmsg_seq
            assert ack.error_code == 0

        # Request a full IPv6 route dump.
        dump_req = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_GETROUTE.value)
        dump_req.nl_hdr.nlmsg_flags = dump_flags
        dump_req.base_hdr.rtm_family = socket.AF_INET6
        self.write_message(dump_req)
        dump_seq = dump_req.nl_hdr.nlmsg_seq

        found = 0
        while True:
            reply = self.read_message(silent=True)
            if reply.nl_hdr.nlmsg_seq != dump_seq:
                # Not a reply to our dump request; skip it.
                continue
            if reply.is_type(NlMsgType.NLMSG_ERROR) and reply.error_code != 0:
                raise ValueError(f"unable to dump routes: error {reply.error_code}")
            if reply.is_type(NlMsgType.NLMSG_DONE):
                break
            # Count only /65 entries, so any other IPv6 routes are ignored.
            if (
                reply.is_type(NlRtMsgType.RTM_NEWROUTE)
                and reply.base_hdr.rtm_dst_len == 65
            ):
                found += 1
        assert num_routes == found
100