1#!/usr/bin/env python3
2import os
3import socket
4import struct
5import subprocess
6import sys
7from ctypes import c_byte
8from ctypes import c_char
9from ctypes import c_int
10from ctypes import c_long
11from ctypes import c_uint32
12from ctypes import c_uint8
13from ctypes import c_ulong
14from ctypes import c_ushort
15from ctypes import sizeof
16from ctypes import Structure
17from enum import Enum
18from typing import Any
19from typing import Dict
20from typing import List
21from typing import NamedTuple
22from typing import Optional
23from typing import Union
24
25from atf_python.sys.netpfil.ipfw.insn_headers import IpFwOpcode
26from atf_python.sys.netpfil.ipfw.insn_headers import IcmpRejectCode
27from atf_python.sys.netpfil.ipfw.insn_headers import Icmp6RejectCode
28from atf_python.sys.netpfil.ipfw.utils import AttrDescr
29from atf_python.sys.netpfil.ipfw.utils import enum_or_int
30from atf_python.sys.netpfil.ipfw.utils import enum_from_int
31from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map
32
33
# Raw integer values of the opcodes that BaseInsn flags as actions
# (``is_action``); every other opcode is treated as a non-action.
insn_actions = tuple(
    action.value
    for action in (
        IpFwOpcode.O_CHECK_STATE,
        IpFwOpcode.O_REJECT,
        IpFwOpcode.O_UNREACH6,
        IpFwOpcode.O_ACCEPT,
        IpFwOpcode.O_DENY,
        IpFwOpcode.O_COUNT,
        IpFwOpcode.O_NAT,
        IpFwOpcode.O_QUEUE,
        IpFwOpcode.O_PIPE,
        IpFwOpcode.O_SKIPTO,
        IpFwOpcode.O_NETGRAPH,
        IpFwOpcode.O_NGTEE,
        IpFwOpcode.O_DIVERT,
        IpFwOpcode.O_TEE,
        IpFwOpcode.O_CALLRETURN,
        IpFwOpcode.O_FORWARD_IP,
        IpFwOpcode.O_FORWARD_IP6,
        IpFwOpcode.O_SETFIB,
        IpFwOpcode.O_SETDSCP,
        IpFwOpcode.O_REASS,
        IpFwOpcode.O_SETMARK,
        IpFwOpcode.O_EXTERNAL_ACTION,
    )
)
58
59
class IpFwInsn(Structure):
    """ctypes view of the common instruction header.

    ``length`` is a packed byte: the low six bits (0x3F) hold the
    instruction size in 32-bit words, 0x40 is the OR flag and 0x80 is
    the NOT flag.  ``arg1`` is a 16-bit opcode-specific argument.
    """

    _fields_ = [("opcode", c_uint8), ("length", c_uint8), ("arg1", c_ushort)]
66
67
class BaseInsn(object):
    """Base wrapper around a single ipfw instruction.

    Holds the decoded header fields (opcode, OR/NOT flags, ``arg1``) and
    provides the validate / parse / print scaffolding shared by all
    instruction classes.  Subclasses are expected to implement
    ``__bytes__`` and ``_print_obj_value``.
    """

    obj_enum_class = IpFwOpcode

    def __init__(self, opcode, is_or, is_not, arg1):
        """Store the header fields of an instruction.

        ``opcode`` may be an ``Enum`` member or a raw integer; raw
        integers are resolved via ``enum_from_int`` against
        ``obj_enum_class``.
        """
        if isinstance(opcode, Enum):
            self.obj_type = opcode.value
            self._enum = opcode
        else:
            self.obj_type = opcode
            self._enum = enum_from_int(self.obj_enum_class, self.obj_type)
        self.is_or = is_or
        self.is_not = is_not
        self.arg1 = arg1
        self.is_action = self.obj_type in insn_actions
        # Instruction length in 32-bit words; a bare header is one word.
        self.ilen = 1
        self.obj_list = []

    @property
    def obj_name(self):
        """Enum member name, or ``opcode#<n>`` when no enum is attached."""
        if self._enum is not None:
            return self._enum.name
        else:
            return "opcode#{}".format(self.obj_type)

    @staticmethod
    def get_insn_len(data: bytes) -> int:
        """Return the length (in 32-bit words) stored in the header of *data*."""
        (opcode_len,) = struct.unpack("@B", data[1:2])
        # The two high bits of the length byte are the NOT/OR flags.
        return opcode_len & 0x3F

    @classmethod
    def _validate_len(cls, data, valid_options=None):
        """Check that *data* is exactly one well-sized instruction.

        Raises ``ValueError`` if *data* is shorter than a header, if its
        byte length disagrees with the length encoded in the header, or
        if *valid_options* is given and ``len(data)`` is not one of them.
        """
        if len(data) < 4:
            raise ValueError("opcode too short")
        if len(data) != cls.get_insn_len(data) * 4:
            raise ValueError("wrong length")
        if valid_options and len(data) not in valid_options:
            raise ValueError(
                "len {} not in {} for {}".format(
                    len(data), valid_options,
                    enum_from_int(cls.obj_enum_class, data[0])
                )
            )

    @classmethod
    def _validate(cls, data):
        """Validate raw instruction bytes; subclasses tighten this check."""
        cls._validate_len(data)

    @classmethod
    def _parse(cls, data):
        """Build an instance of *cls* from the 4-byte header in *data*."""
        insn = IpFwInsn.from_buffer_copy(data[:4])
        is_or = (insn.length & 0x40) != 0
        is_not = (insn.length & 0x80) != 0
        return cls(opcode=insn.opcode, is_or=is_or, is_not=is_not, arg1=insn.arg1)

    @classmethod
    def from_bytes(cls, data, attr_type_enum):
        """Validate and parse *data*, tagging the result with *attr_type_enum*."""
        cls._validate(data)
        opcode = cls._parse(data)
        opcode._enum = attr_type_enum
        return opcode

    def __bytes__(self):
        raise NotImplementedError()

    def print_obj(self, prepend=""):
        """Print a one-line description of the instruction to stdout."""
        is_or = ""
        if self.is_or:
            is_or = " [OR]\\"
        is_not = ""
        if self.is_not:
            is_not = "[!] "
        print(
            "{}{}len={} type={}({}){}{}".format(
                prepend,
                is_not,
                len(bytes(self)),
                self.obj_name,
                self.obj_type,
                self._print_obj_value(),
                is_or,
            )
        )

    def _print_obj_value(self):
        raise NotImplementedError()

    def print_obj_hex(self, prepend=""):
        """Print *prepend*, a blank line, then the serialized bytes as hex."""
        print(prepend)
        print()
        print(" ".join(["x{:02X}".format(b) for b in bytes(self)]))

    @staticmethod
    def parse_insns(data, attr_map):
        """Split *data* into consecutive instructions and parse each one.

        *attr_map* maps a raw opcode to a dict whose ``"ad"`` entry
        provides ``.cls`` (the class to parse with) and ``.val`` (the
        enum member to attach).  Opcodes missing from the map are parsed
        as ``InsnUnknown``.

        Raises ``ValueError`` when an instruction runs past the end of
        *data* or when trailing bytes remain after the last instruction.
        """
        ret = []
        off = 0
        hdr_size = sizeof(IpFwInsn)
        while off + hdr_size <= len(data):
            hdr = IpFwInsn.from_buffer_copy(data[off : off + hdr_size])
            insn_len = (hdr.length & 0x3F) * 4
            if off + insn_len > len(data):
                raise ValueError("wrong length")
            attr = attr_map.get(hdr.opcode, None)
            if attr is None:
                cls = InsnUnknown
                type_enum = enum_from_int(BaseInsn.obj_enum_class, hdr.opcode)
            else:
                cls = attr["ad"].cls
                type_enum = attr["ad"].val
            # A zero encoded length is rejected by the class' own validation
            # ("opcode too short"), so the offset always advances.
            insn = cls.from_bytes(data[off : off + insn_len], type_enum)
            ret.append(insn)
            off += insn_len

        if off != len(data):
            raise ValueError("empty space")
        return ret
184
185
class Insn(BaseInsn):
    """Plain single-word instruction: header only, value carried in ``arg1``."""

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)

    @classmethod
    def _validate(cls, data):
        """Accept exactly one 4-byte header."""
        cls._validate_len(data, [4])

    def __bytes__(self):
        """Serialize the header, folding the OR/NOT flags into the length byte."""
        length = self.ilen
        if self.is_or:
            length |= 0x40
        if self.is_not:
            # Must be an in-place OR: a bare ``length | 0x80`` expression
            # would discard the NOT flag.
            length |= 0x80
        insn = IpFwInsn(opcode=self.obj_type, length=length, arg1=enum_or_int(self.arg1))
        return bytes(insn)

    def _print_obj_value(self):
        return " arg1={}".format(self.arg1)
205
206
class InsnUnknown(Insn):
    """Fallback for opcodes without a dedicated class; keeps the raw bytes."""

    @classmethod
    def _validate(cls, data):
        # Any length is fine as long as it matches the header.
        cls._validate_len(data)

    @classmethod
    def _parse(cls, data):
        parsed = super()._parse(data)
        # Remember the original bytes so serialization is a pure echo.
        parsed._data = data
        return parsed

    def __bytes__(self):
        return self._data

    def _print_obj_value(self):
        hex_octets = ("x{:02X}".format(octet) for octet in self._data)
        return " " + " ".join(hex_octets)
223
224
class InsnEmpty(Insn):
    """Header-only instruction whose ``arg1`` must be zero."""

    @classmethod
    def _validate(cls, data):
        cls._validate_len(data, [4])
        header = IpFwInsn.from_buffer_copy(data[:4])
        if header.arg1:
            raise ValueError("arg1 should be empty")

    def _print_obj_value(self):
        # Nothing beyond the opcode name is worth printing.
        return ""
235
236
class InsnComment(Insn):
    """Instruction carrying a NUL-terminated comment after the header.

    The comment is stored as UTF-8, NUL-terminated and padded with NULs
    to a multiple of four bytes.  The whole instruction may not exceed
    88 bytes.
    """

    # Maximum size of the whole instruction (header + padded comment).
    _MAX_BYTES = 88

    def __init__(self, opcode=IpFwOpcode.O_NOP, is_or=False, is_not=False, arg1=0, comment=""):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)
        # Normalise None / other falsy values to an empty string.
        self.comment = comment if comment else ""

    @classmethod
    def _validate(cls, data):
        """Check the header length and the overall size limit."""
        cls._validate_len(data)
        if len(data) > cls._MAX_BYTES:
            raise ValueError("comment too long")

    @classmethod
    def _parse(cls, data):
        """Parse the header and extract the comment text.

        Raises ``UnicodeDecodeError`` if the comment is not valid UTF-8.
        """
        self = super()._parse(data)
        # Comment encoding can be anything,
        # use utf-8 to ease debugging.
        # Everything after the header up to the first NUL is the comment;
        # if no NUL is present the whole payload is taken.
        raw_comment = bytes(data[4:]).split(b"\0", 1)[0]
        self.comment = raw_comment.decode("utf-8")
        return self

    def __bytes__(self):
        """Serialize header plus NUL-terminated, 4-byte padded comment.

        Raises ``ValueError`` if the result would exceed the size limit
        enforced by ``_validate``.
        """
        comment_bytes = self.comment.encode("utf-8") + b"\0"
        if len(comment_bytes) % 4 > 0:
            comment_bytes += b"\0" * (4 - (len(comment_bytes) % 4))
        if 4 + len(comment_bytes) > self._MAX_BYTES:
            raise ValueError("comment too long")
        # Keep the header length (in 32-bit words) in sync with the payload
        # so the serialized instruction is self-consistent.
        self.ilen = 1 + len(comment_bytes) // 4
        return super().__bytes__() + comment_bytes

    def _print_obj_value(self):
        return " comment='{}'".format(self.comment)
274
275
class InsnProto(Insn):
    """Protocol match; the protocol number is carried in ``arg1``."""

    def __init__(self, opcode=IpFwOpcode.O_PROTO, is_or=False, is_not=False, arg1=0):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)

    def _print_obj_value(self):
        # Show a symbolic name for the few protocols we know about,
        # otherwise fall back to the raw number.
        names = {6: "TCP", 17: "UDP", 41: "IPV6"}
        number = self.arg1
        if number not in names:
            return " proto=#{}".format(number)
        return " proto={}".format(names[number])
287
288
class InsnU32(Insn):
    """Two-word instruction: header followed by one native-order uint32."""

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0, u32=0):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)
        # Header word plus one payload word.
        self.ilen = 2
        self.u32 = u32

    @classmethod
    def _validate(cls, data):
        cls._validate_len(data, [8])

    @classmethod
    def _parse(cls, data):
        parsed = super()._parse(data[:4])
        (parsed.u32,) = struct.unpack("@I", data[4:8])
        return parsed

    def __bytes__(self):
        header = super().__bytes__()
        payload = struct.pack("@I", self.u32)
        return header + payload

    def _print_obj_value(self):
        return " arg1={} u32={}".format(self.arg1, self.u32)
310
311
class InsnProb(InsnU32):
    """Probability match stored as a uint32 scaled by 0x7FFFFFFF.

    ``prob`` is a float view of ``u32``: ``u32 == int(prob * 0x7FFFFFFF)``.
    """

    def __init__(
        self,
        opcode=IpFwOpcode.O_PROB,
        is_or=False,
        is_not=False,
        arg1=0,
        u32=0,
        prob=0.0,
    ):
        # Forward arg1/u32 so they are not silently dropped (the parse path
        # constructs instances with arg1 taken from the wire header).
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1, u32=u32)
        # Only let ``prob`` override ``u32`` when it was actually supplied;
        # the default 0.0 must not clobber an explicit u32.
        if prob:
            self.prob = prob

    @property
    def prob(self):
        """Probability as a float derived from ``u32``."""
        return 1.0 * self.u32 / 0x7FFFFFFF

    @prob.setter
    def prob(self, prob: float):
        self.u32 = int(prob * 0x7FFFFFFF)

    def _print_obj_value(self):
        return " prob={}".format(round(self.prob, 5))
335
336
class InsnIp(InsnU32):
    """IPv4 address match; the address lives in the ``u32`` payload word.

    ``ip`` is a dotted-quad string view of ``u32``.
    """

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0, u32=0, ip=None):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1, u32=u32)
        # A textual address, when given, takes precedence over ``u32``.
        if ip:
            self.ip = ip

    @property
    def ip(self):
        """Dotted-quad representation of ``u32``."""
        return socket.inet_ntop(socket.AF_INET, struct.pack("@I", self.u32))

    @ip.setter
    def ip(self, ip: str):
        ip_bin = socket.inet_pton(socket.AF_INET, ip)
        # Same native-order format as the getter so the bytes round-trip.
        self.u32 = struct.unpack("@I", ip_bin)[0]

    def _print_obj_value(self):
        # This is the hook print_obj() calls; it must carry this exact name
        # to override the generic InsnU32 output.
        return " ip={}".format(self.ip)

    # Backward-compatible alias for the previous (misnamed) hook.
    _print_opcode_value = _print_obj_value
354
355
class InsnTable(Insn):
    """Table lookup: table id in ``arg1`` plus an optional uint32 value."""

    @classmethod
    def _validate(cls, data):
        # Either a bare header or a header followed by one value word.
        cls._validate_len(data, [4, 8])

    @classmethod
    def _parse(cls, data):
        parsed = super()._parse(data)

        if len(data) != 8:
            parsed.val = None
            return parsed
        parsed.val = struct.unpack("@I", data[4:8])[0]
        parsed.ilen = 2
        return parsed

    def __bytes__(self):
        header = super().__bytes__()
        # ``val`` only exists on parsed instances (or when set by hand).
        value = getattr(self, "val", None)
        if value is None:
            return header
        return header + struct.pack("@I", value)

    def _print_obj_value(self):
        value = getattr(self, "val", None)
        if value is None:
            return " table={}".format(self.arg1)
        return " table={} value={}".format(self.arg1, value)
383
384
class InsnReject(Insn):
    """Reject action: reject code in ``arg1`` plus an optional uint32 MTU."""

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0, mtu=None):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)
        self.mtu = mtu
        if mtu is not None:
            # An MTU adds one payload word after the header.
            self.ilen = 2

    @classmethod
    def _validate(cls, data):
        cls._validate_len(data, [4, 8])

    @classmethod
    def _parse(cls, data):
        parsed = super()._parse(data)

        if len(data) != 8:
            parsed.mtu = None
            return parsed
        parsed.mtu = struct.unpack("@I", data[4:8])[0]
        parsed.ilen = 2
        return parsed

    def __bytes__(self):
        header = super().__bytes__()
        mtu = getattr(self, "mtu", None)
        if mtu is None:
            return header
        return header + struct.pack("@I", mtu)

    def _print_obj_value(self):
        text = " code={}".format(enum_from_int(IcmpRejectCode, self.arg1))
        mtu = getattr(self, "mtu", None)
        if mtu is not None:
            text += " mtu={}".format(mtu)
        return text
419
420
class InsnPorts(Insn):
    """Port match: header followed by one ``(low, high)`` uint16 pair per word.

    ``port_pairs`` is a list of ``(low, high)`` tuples; a single port is
    expressed as a pair with equal bounds.
    """

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0, port_pairs=None):
        # Forward arg1 so the value read from the wire header is kept.
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)
        self.port_pairs = port_pairs if port_pairs else []

    @classmethod
    def _validate(cls, data):
        """Require at least one port pair and a consistent header length."""
        if len(data) < 8:
            raise ValueError("no ports specified")
        cls._validate_len(data)

    @classmethod
    def _parse(cls, data):
        """Parse the header and every complete 4-byte port pair after it."""
        self = super()._parse(data)

        # Walk the payload one word at a time; a trailing partial word
        # (fewer than 4 bytes) is ignored.
        self.port_pairs = [
            struct.unpack("@HH", data[off : off + 4])
            for off in range(4, len(data) - 3, 4)
        ]
        return self

    def __bytes__(self):
        """Serialize the header followed by all port pairs.

        Note: with no port pairs only the bare header is produced, which
        ``_validate`` would reject on re-parse.
        """
        # One header word plus one word per pair; keep the header length
        # in sync with the payload actually emitted.
        self.ilen = 1 + len(self.port_pairs)
        payload = b"".join(
            struct.pack("@HH", low, high) for low, high in self.port_pairs
        )
        return super().__bytes__() + payload

    def _print_obj_value(self):
        ret = []
        for p in self.port_pairs:
            if p[0] == p[1]:
                ret.append(str(p[0]))
            else:
                ret.append("{}-{}".format(p[0], p[1]))
        return " ports={}".format(",".join(ret))
461
462
class IpFwInsnIp6(Structure):
    """ctypes layout: instruction header, 16-byte address, 16-byte mask."""

    _fields_ = [("o", IpFwInsn), ("addr6", c_byte * 16), ("mask6", c_byte * 16)]
469
470
class InsnIp6(Insn):
    """IPv6 address match with an optional mask.

    ``ip6`` and ``mask6`` are textual IPv6 addresses; the instruction is
    5 words long without a mask and 9 words long with one.
    """

    def __init__(self, opcode, is_or=False, is_not=False, arg1=0, ip6=None, mask6=None):
        super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1)
        self.ip6 = ip6
        self.mask6 = mask6
        self.ilen = 5 if mask6 is None else 9

    @classmethod
    def _validate(cls, data):
        # Header + address, or header + address + mask.
        cls._validate_len(data, [20, 36])

    @classmethod
    def _parse(cls, data):
        parsed = super()._parse(data)
        parsed.ip6 = socket.inet_ntop(socket.AF_INET6, data[4:20])

        if len(data) != 36:
            parsed.mask6 = None
            parsed.ilen = 5
            return parsed
        parsed.mask6 = socket.inet_ntop(socket.AF_INET6, data[20:36])
        parsed.ilen = 9
        return parsed

    def __bytes__(self):
        chunks = [super().__bytes__(), socket.inet_pton(socket.AF_INET6, self.ip6)]
        if self.mask6 is not None:
            chunks.append(socket.inet_pton(socket.AF_INET6, self.mask6))
        return b"".join(chunks)

    def _print_obj_value(self):
        if not self.mask6:
            return " ip6={}".format(self.ip6)
        return " ip6={}/{}".format(self.ip6, self.mask6)
509
510
# Table associating each known opcode with the class used to parse it.
# Presumably this is the ``attr_map`` handed to BaseInsn.parse_insns(),
# which falls back to InsnUnknown for opcodes not listed -- verify against
# callers.
insn_attrs = prepare_attrs_map(
    [
        AttrDescr(IpFwOpcode.O_CHECK_STATE, Insn),
        AttrDescr(IpFwOpcode.O_ACCEPT, InsnEmpty),
        AttrDescr(IpFwOpcode.O_COUNT, InsnEmpty),

        AttrDescr(IpFwOpcode.O_REJECT, InsnReject),
        AttrDescr(IpFwOpcode.O_UNREACH6, Insn),
        AttrDescr(IpFwOpcode.O_DENY, InsnEmpty),
        AttrDescr(IpFwOpcode.O_DIVERT, Insn),
        # NOTE(review): O_COUNT is already listed above with the same class.
        AttrDescr(IpFwOpcode.O_COUNT, InsnEmpty),
        AttrDescr(IpFwOpcode.O_QUEUE, Insn),
        AttrDescr(IpFwOpcode.O_PIPE, Insn),
        AttrDescr(IpFwOpcode.O_SKIPTO, Insn),
        AttrDescr(IpFwOpcode.O_NETGRAPH, Insn),
        AttrDescr(IpFwOpcode.O_NGTEE, Insn),
        # NOTE(review): O_DIVERT is already listed above with the same class.
        AttrDescr(IpFwOpcode.O_DIVERT, Insn),
        AttrDescr(IpFwOpcode.O_TEE, Insn),
        AttrDescr(IpFwOpcode.O_CALLRETURN, Insn),
        AttrDescr(IpFwOpcode.O_SETFIB, Insn),
        AttrDescr(IpFwOpcode.O_SETDSCP, Insn),
        AttrDescr(IpFwOpcode.O_REASS, InsnEmpty),
        AttrDescr(IpFwOpcode.O_SETMARK, Insn),



        # Opcodes below are not part of ``insn_actions``.
        AttrDescr(IpFwOpcode.O_NOP, InsnComment),
        AttrDescr(IpFwOpcode.O_PROTO, InsnProto),
        AttrDescr(IpFwOpcode.O_PROB, InsnProb),
        AttrDescr(IpFwOpcode.O_IP_DST_ME, InsnEmpty),
        AttrDescr(IpFwOpcode.O_IP_SRC_ME, InsnEmpty),
        AttrDescr(IpFwOpcode.O_IP6_DST_ME, InsnEmpty),
        AttrDescr(IpFwOpcode.O_IP6_SRC_ME, InsnEmpty),
        AttrDescr(IpFwOpcode.O_IP_SRC, InsnIp),
        AttrDescr(IpFwOpcode.O_IP_DST, InsnIp),
        AttrDescr(IpFwOpcode.O_IP6_DST, InsnIp6),
        AttrDescr(IpFwOpcode.O_IP6_SRC, InsnIp6),
        AttrDescr(IpFwOpcode.O_IP_SRC_LOOKUP, InsnTable),
        AttrDescr(IpFwOpcode.O_IP_DST_LOOKUP, InsnTable),
        AttrDescr(IpFwOpcode.O_IP_SRCPORT, InsnPorts),
        AttrDescr(IpFwOpcode.O_IP_DSTPORT, InsnPorts),
        AttrDescr(IpFwOpcode.O_PROBE_STATE, Insn),
        AttrDescr(IpFwOpcode.O_KEEP_STATE, Insn),
    ]
)
556