1# This file is part of Scapy
2# See http://www.secdev.org/projects/scapy for more information
3# Copyright (C) Philippe Biondi <phil@secdev.org>
4# This program is published under a GPLv2 license
5
6"""
7SuperSocket.
8"""
9
10from __future__ import absolute_import
11from select import select, error as select_error
12import ctypes
13import errno
14import socket
15import struct
16import time
17
18from scapy.config import conf
19from scapy.consts import DARWIN, WINDOWS
20from scapy.data import MTU, ETH_P_IP, SOL_PACKET, SO_TIMESTAMPNS
21from scapy.compat import raw
22from scapy.error import warning, log_runtime
23from scapy.interfaces import network_name
24import scapy.modules.six as six
25from scapy.packet import Packet
26import scapy.packet
27from scapy.plist import (
28    PacketList,
29    SndRcvList,
30    _PacketIterable,
31)
32from scapy.utils import PcapReader, tcpdump
33
34# Typing imports
35from scapy.interfaces import _GlobInterfaceType
36from scapy.compat import (
37    Any,
38    Iterator,
39    List,
40    Optional,
41    Tuple,
42    Type,
43    cast,
44)
45
46# Utils
47
48
49class _SuperSocket_metaclass(type):
50    desc = None   # type: Optional[str]
51
52    def __repr__(self):
53        # type: () -> str
54        if self.desc is not None:
55            return "<%s: %s>" % (self.__name__, self.desc)
56        else:
57            return "<%s>" % self.__name__
58
59
# Constants used to request and decode ancillary data on packet sockets.
# Socket option (level SOL_PACKET) enabling the auxdata control message.
PACKET_AUXDATA = 0x8  # type: int
# Tag identifier placed in front of a VLAN TCI when a tag is rebuilt.
ETH_P_8021Q = 33024  # type: int
# Bit of tpacket_auxdata.tp_status meaning "tp_vlan_tci is valid" (1 << 4).
TP_STATUS_VLAN_VALID = 0x10  # type: int
64
65
class tpacket_auxdata(ctypes.Structure):
    """ctypes view of the Linux ``struct tpacket_auxdata``.

    It is filled from the payload of a PACKET_AUXDATA control message
    (``from_buffer_copy``); ``tp_status`` and ``tp_vlan_tci`` are what the
    receive path inspects to rebuild VLAN tags.
    """
    _fields_ = [
        # Status bits (e.g. the VLAN-valid flag).
        ("tp_status", ctypes.c_uint),
        # Original length of the packet.
        ("tp_len", ctypes.c_uint),
        # Captured length of the packet.
        ("tp_snaplen", ctypes.c_uint),
        ("tp_mac", ctypes.c_ushort),
        ("tp_net", ctypes.c_ushort),
        # VLAN Tag Control Information.
        ("tp_vlan_tci", ctypes.c_ushort),
        ("tp_padding", ctypes.c_ushort),
    ]  # type: List[Tuple[str, Any]]
76
77
78# SuperSocket
79
@six.add_metaclass(_SuperSocket_metaclass)
class SuperSocket:
    """Base class of Scapy sockets.

    ``ins`` is the socket used for receiving and ``outs`` the (optional)
    socket used for sending. The default constructor uses one
    ``socket.socket`` for both. Subclasses customize ``send``, ``recv``,
    ``recv_raw``, ``select`` and ``close`` for their transport.
    """
    closed = 0    # type: int
    nonblocking_socket = False  # type: bool
    # When True, _recv_raw() requests and parses ancillary data
    # (PACKET_AUXDATA / SO_TIMESTAMPNS). Subclasses set it once the
    # corresponding socket options were enabled.
    auxdata_available = False   # type: bool

    def __init__(self,
                 family=socket.AF_INET,  # type: int
                 type=socket.SOCK_STREAM,  # type: int
                 proto=0,  # type: int
                 iface=None,  # type: Optional[_GlobInterfaceType]
                 **kwargs  # type: Any
                 ):
        # type: (...) -> None
        """Create a ``socket.socket(family, type, proto)`` used as both
        ``ins`` and ``outs``.

        :param iface: interface, only stored on ``self.iface`` here
        :param kwargs: accepted and ignored by this base implementation
        """
        self.ins = socket.socket(family, type, proto)  # type: socket.socket
        self.outs = self.ins  # type: Optional[socket.socket]
        self.promisc = None
        self.iface = iface

    def send(self, x):
        # type: (Packet) -> int
        """Serialize ``x`` and write it to ``outs``.

        ``x.sent_time`` is set right before sending when ``x`` allows it.

        :returns: the number of bytes sent, or 0 when there is no ``outs``
        """
        sx = raw(x)
        try:
            x.sent_time = time.time()
        except AttributeError:
            pass

        if self.outs:
            return self.outs.send(sx)
        else:
            return 0

    if six.PY2:
        def _recv_raw(self, sock, x):
            # type: (socket.socket, int) -> Tuple[bytes, Any, Optional[float]]
            """Internal function to receive a Packet.

            Python 2 has no recvmsg(), so no ancillary data is processed
            and the timestamp is always None.
            """
            pkt, sa_ll = sock.recvfrom(x)
            return pkt, sa_ll, None
    else:
        def _recv_raw(self, sock, x):
            # type: (socket.socket, int) -> Tuple[bytes, Any, Optional[float]]
            """Internal function to receive a Packet,
            and process ancillary data.

            When ``auxdata_available`` is set, control messages are parsed:
            PACKET_AUXDATA VLAN information is re-inserted into the frame
            as an 802.1Q tag after its first 12 bytes (destination and
            source MAC addresses), and SO_TIMESTAMPNS gives the timestamp.

            :param sock: socket to read from
            :param x: maximum number of bytes to read
            :returns: ``(data, sender address, timestamp or None)``
            """
            timestamp = None
            if not self.auxdata_available:
                pkt, _, _, sa_ll = sock.recvmsg(x)
                return pkt, sa_ll, timestamp
            ancbufsize = socket.CMSG_LEN(4096)
            pkt, ancdata, _, sa_ll = sock.recvmsg(x, ancbufsize)
            if not pkt:
                return pkt, sa_ll, timestamp
            for cmsg_lvl, cmsg_type, cmsg_data in ancdata:
                # Check available ancillary data
                if (cmsg_lvl == SOL_PACKET and cmsg_type == PACKET_AUXDATA):
                    # Parse AUXDATA
                    try:
                        auxdata = tpacket_auxdata.from_buffer_copy(cmsg_data)
                    except ValueError:
                        # Note: according to Python documentation, recvmsg()
                        #       can return a truncated message. A ValueError
                        #       exception likely indicates that Auxiliary
                        #       Data is not supported by the Linux kernel.
                        return pkt, sa_ll, timestamp
                    if auxdata.tp_vlan_tci != 0 or \
                            auxdata.tp_status & TP_STATUS_VLAN_VALID:
                        # Insert VLAN tag
                        tag = struct.pack(
                            "!HH",
                            ETH_P_8021Q,
                            auxdata.tp_vlan_tci
                        )
                        pkt = pkt[:12] + tag + pkt[12:]
                elif cmsg_lvl == socket.SOL_SOCKET and \
                        cmsg_type == SO_TIMESTAMPNS:
                    length = len(cmsg_data)
                    if length == 16:  # __kernel_timespec
                        # Two 64-bit fields: use "qq" rather than native
                        # "ll", since a C long is only 4 bytes on 32-bit
                        # platforms and "ll" would fail on 16 bytes there.
                        tmp = struct.unpack("qq", cmsg_data)
                    elif length == 8:  # timespec
                        tmp = struct.unpack("ii", cmsg_data)
                    else:
                        log_runtime.warning("Unknown timespec format.. ?!")
                        continue
                    timestamp = tmp[0] + tmp[1] * 1e-9
            return pkt, sa_ll, timestamp

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)"""
        return conf.raw_layer, self.ins.recv(x), None

    def recv(self, x=MTU):
        # type: (int) -> Optional[Packet]
        """Receive one packet and dissect it with the class given by
        recv_raw().

        If dissection fails, the data is wrapped in ``conf.raw_layer``.
        When ``conf.debug_dissector`` is set, the error is recorded in
        ``scapy.sendrecv.debug.crashed_on`` and re-raised instead.

        :returns: the packet, or None when no data or class was returned
        """
        cls, val, ts = self.recv_raw(x)
        if not val or not cls:
            return None
        try:
            pkt = cls(val)  # type: Packet
        except KeyboardInterrupt:
            raise
        except Exception:
            if conf.debug_dissector:
                from scapy.sendrecv import debug
                debug.crashed_on = (cls, val)
                raise
            pkt = conf.raw_layer(val)
        if ts:
            pkt.time = ts
        return pkt

    def fileno(self):
        # type: () -> int
        """Return the file descriptor of ``ins`` (used by select())."""
        return self.ins.fileno()

    def close(self):
        # type: () -> None
        """Close ``outs`` (when distinct from ``ins``) and ``ins``.

        Calling it again is a no-op. Outside Windows, sockets whose
        fileno() is already -1 are not closed again.
        """
        if self.closed:
            return
        self.closed = True
        if getattr(self, "outs", None):
            if getattr(self, "ins", None) != self.outs:
                if self.outs and (WINDOWS or self.outs.fileno() != -1):
                    self.outs.close()
        if getattr(self, "ins", None):
            if WINDOWS or self.ins.fileno() != -1:
                self.ins.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and receive packets on this socket (see sendrecv.sndrcv).

        :returns: ``(answered, unanswered)``
        """
        from scapy import sendrecv
        ans, unans = sendrecv.sndrcv(self, *args, **kargs)  # type: SndRcvList, PacketList  # noqa: E501
        return ans, unans

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Like sr(), but return only the first answer (or None)."""
        from scapy import sendrecv
        ans = sendrecv.sndrcv(self, *args, **kargs)[0]  # type: SndRcvList
        if len(ans) > 0:
            pkt = ans[0][1]  # type: Packet
            return pkt
        else:
            return None

    def sniff(self, *args, **kargs):
        # type: (Any, Any) -> PacketList
        """Run sendrecv.sniff() with this socket as ``opened_socket``."""
        from scapy import sendrecv
        pkts = sendrecv.sniff(opened_socket=self, *args, **kargs)  # type: PacketList  # noqa: E501
        return pkts

    def tshark(self, *args, **kargs):
        # type: (Any, Any) -> None
        """Run sendrecv.tshark() with this socket as ``opened_socket``."""
        from scapy import sendrecv
        sendrecv.tshark(opened_socket=self, *args, **kargs)

    # TODO: use 'scapy.ansmachine.AnsweringMachine' when typed
    def am(self,
           cls,  # type: Type[Any]
           *args,  # type: Any
           **kwargs  # type: Any
           ):
        # type: (...) -> Any
        """
        Creates an AnsweringMachine associated with this socket.

        :param cls: A subclass of AnsweringMachine to instantiate
        """
        return cls(*args, opened_socket=self, socket=self, **kwargs)

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """This function is called during sendrecv() routine to select
        the available sockets.

        :param sockets: an array of sockets that need to be selected
        :param remain: timeout in seconds passed to select()
        :returns: an array of sockets that are ready to be read (e.g. with
            recv); empty when the timeout expired or select() was
            interrupted by a signal (EINTR)
        """
        # Default result if select() is interrupted before assigning it.
        inp = []  # type: List[SuperSocket]
        try:
            inp, _, _ = select(sockets, [], [], remain)
        except (IOError, select_error) as exc:
            # select.error has no .errno attribute
            if not exc.args or exc.args[0] != errno.EINTR:
                raise
        return inp

    def __del__(self):
        # type: () -> None
        """Close the socket"""
        self.close()

    def __enter__(self):
        # type: () -> SuperSocket
        """Context-manager entry: return the socket itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()
279
280
class L3RawSocket(SuperSocket):
    """Layer 3 SuperSocket built on two raw sockets.

    Packets are sent through an ``AF_INET``/``IPPROTO_RAW`` socket with
    ``IP_HDRINCL`` enabled, so the IP header comes from the packet itself.
    They are received through an ``AF_PACKET``/``SOCK_RAW`` socket.
    """
    desc = "Layer 3 using Raw sockets (PF_INET/SOCK_RAW)"

    def __init__(self,
                 type=ETH_P_IP,  # type: int
                 filter=None,  # type: Optional[str]
                 iface=None,  # type: Optional[_GlobInterfaceType]
                 promisc=None,  # type: Optional[bool]
                 nofilter=0  # type: int
                 ):
        # type: (...) -> None
        """Open the sending and receiving sockets.

        :param type: protocol number the receiving socket listens for
        :param filter: accepted but not used by this implementation
        :param iface: if given, the receiving socket is bound to it
        :param promisc: accepted but not used by this implementation
        :param nofilter: accepted but not used by this implementation
        """
        # Output side: raw IPv4 socket, the header is part of the packet.
        self.outs = socket.socket(
            socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW
        )
        self.outs.setsockopt(socket.SOL_IP, socket.IP_HDRINCL, 1)
        # Input side: packet socket for the requested protocol.
        self.ins = socket.socket(
            socket.AF_PACKET, socket.SOCK_RAW, socket.htons(type)
        )
        self.iface = iface
        if iface is not None:
            self.ins.bind((network_name(iface), type))
        if six.PY2:
            # Ancillary data is only parsed by the Python 3 _recv_raw().
            return
        try:
            # Receive Auxiliary Data (VLAN tags) and ns timestamps.
            self.ins.setsockopt(SOL_PACKET, PACKET_AUXDATA, 1)
            self.ins.setsockopt(socket.SOL_SOCKET, SO_TIMESTAMPNS, 1)
        except OSError:
            # Note: Auxiliary Data is only supported since Linux 2.6.21
            log_runtime.info(
                "Your Linux Kernel does not support Auxiliary Data!"
            )
        else:
            self.auxdata_available = True

    def recv(self, x=MTU):
        # type: (int) -> Optional[Packet]
        """Receive one incoming packet and return its layer 3 part.

        Outgoing packets seen by the packet socket are ignored (None is
        returned). The dissector is chosen from ``conf.l2types`` by
        hardware type, then from ``conf.l3types`` by protocol, else
        ``conf.default_l2`` is used with a warning. When a layer 2 class
        was used, its payload is returned.
        """
        data, sa_ll, ts = self._recv_raw(self.ins, x)
        if sa_ll[2] == socket.PACKET_OUTGOING:
            return None
        ifname, proto, hatype = sa_ll[0], sa_ll[1], sa_ll[3]

        strip_l2 = False
        if hatype in conf.l2types:
            cls = conf.l2types.num2layer[hatype]  # type: Type[Packet]
            strip_l2 = True
        elif proto in conf.l3types:
            cls = conf.l3types.num2layer[proto]
        else:
            cls = conf.default_l2
            warning("Unable to guess type (interface=%s protocol=%#x family=%i). Using %s", ifname, proto, hatype, cls.name)  # noqa: E501

        try:
            pkt = cls(data)
        except KeyboardInterrupt:
            raise
        except Exception:
            if conf.debug_dissector:
                raise
            pkt = conf.raw_layer(data)

        if strip_l2:
            pkt = pkt.payload

        if pkt is None:
            return None
        if ts is None:
            # No kernel timestamp from ancillary data: ask the socket.
            from scapy.arch.linux import get_last_packet_timestamp
            ts = get_last_packet_timestamp(self.ins)
        pkt.time = ts
        return pkt

    def send(self, x):
        # type: (Packet) -> int
        """Send ``x`` to ``(x.dst, 0)`` through the raw IPv4 socket.

        :returns: bytes sent, or 0 without ``outs`` or on socket errors
            (which are logged)
        :raises ValueError: when ``x`` has no ``dst`` attribute
        """
        try:
            payload = raw(x)
            if not self.outs:
                return 0
            x.sent_time = time.time()
            return self.outs.sendto(payload, (x.dst, 0))
        except AttributeError:
            raise ValueError(
                "Missing 'dst' attribute in the first layer to be "
                "sent using a native L3 socket ! (make sure you passed the "
                "IP layer)"
            )
        except socket.error as msg:
            log_runtime.error(msg)
        return 0
369
370
class SimpleSocket(SuperSocket):
    """Expose an already-created ``socket.socket`` as a SuperSocket.

    The very same socket object is used for receiving and sending.
    """
    desc = "wrapper around a classic socket"
    nonblocking_socket = True

    def __init__(self, sock):
        # type: (socket.socket) -> None
        self.ins = self.outs = sock
379
380
class StreamSocket(SimpleSocket):
    """Read packets of class ``basecls`` from a stream socket.

    Each recv() consumes only the bytes belonging to the dissected packet.
    Trailing bytes that the dissector flags as padding are left in the
    socket for the next call.
    """
    desc = "transforms a stream socket into a layer 2"

    def __init__(self, sock, basecls=None):
        # type: (socket.socket, Optional[Type[Packet]]) -> None
        chosen_cls = conf.raw_layer if basecls is None else basecls
        SimpleSocket.__init__(self, sock)
        self.basecls = chosen_cls

    def recv(self, x=MTU):
        # type: (int) -> Optional[Packet]
        """Dissect one packet from at most ``x`` available bytes.

        :returns: the packet, or None when the peer sent nothing (EOF)
        """
        # Peek first: the data stays queued until we know its real size.
        data = self.ins.recv(x, socket.MSG_PEEK)
        to_consume = len(data)
        if not data:
            return None
        pkt = self.basecls(data)  # type: Packet
        pad = pkt.getlayer(conf.padding_layer)
        if pad is not None and pad.underlayer is not None:
            # Detach the padding; it belongs to the next packet.
            del pad.underlayer.payload
        while pad is not None and \
                not isinstance(pad, scapy.packet.NoPayload):
            to_consume -= len(pad.load)
            pad = pad.payload
        # Actually remove this packet's bytes from the socket buffer.
        self.ins.recv(to_consume)
        return pkt
407
408
class SSLStreamSocket(StreamSocket):
    """StreamSocket variant for SSL-wrapped sockets.

    Incoming data is accumulated in ``self._buf`` and packets are dissected
    from that buffer instead of using MSG_PEEK.
    """
    desc = "similar usage than StreamSocket but specialized for handling SSL-wrapped sockets"  # noqa: E501

    def __init__(self, sock, basecls=None):
        # type: (socket.socket, Optional[Type[Packet]]) -> None
        self._buf = b""
        super(SSLStreamSocket, self).__init__(sock, basecls)

    # 65535, the default value of x, is the largest value a TLS record
    # length field can hold
    def recv(self, x=65535):
        # type: (int) -> Optional[Packet]
        """Return the next ``basecls`` packet taken from the stream.

        Already-buffered data is tried first. More data (at most ``x``
        bytes) is read only when the buffer is empty or cannot be
        dissected. Bytes of the returned packet are removed from the
        buffer, while trailing bytes seen as padding are kept for the next
        call.

        :raises socket.error: when the underlying socket returned no data
        """
        pkt = None  # type: Optional[Packet]
        if self._buf != b"":
            try:
                pkt = self.basecls(self._buf)
            except Exception:
                # We assume that the exception is generated by a buffer underflow  # noqa: E501
                pass

        if not pkt:
            buf = self.ins.recv(x)
            if len(buf) == 0:
                raise socket.error((100, "Underlying stream socket tore down"))
            self._buf += buf
            # The buffer grew: dissect it again (errors propagate here).
            pkt = self.basecls(self._buf)
        # Otherwise the buffer is unchanged, so the packet dissected above
        # is reused instead of dissecting the same bytes a second time.

        x = len(self._buf)
        if pkt is not None:
            pad = pkt.getlayer(conf.padding_layer)

            if pad is not None and pad.underlayer is not None:
                # Detach the padding; it belongs to the next packet.
                del(pad.underlayer.payload)
            while pad is not None and not isinstance(pad, scapy.packet.NoPayload):   # noqa: E501
                x -= len(pad.load)
                pad = pad.payload
            # Keep only the bytes that were not part of this packet.
            self._buf = self._buf[x:]
        return pkt
446
447
class L2ListenTcpdump(SuperSocket):
    """Receive-only socket reading the pcap stream that a tcpdump
    subprocess writes to its standard output (``-w -``)."""
    desc = "read packets at layer 2 using tcpdump"

    def __init__(self,
                 iface=None,  # type: Optional[_GlobInterfaceType]
                 promisc=False,  # type: bool
                 filter=None,  # type: Optional[str]
                 nofilter=False,  # type: bool
                 prog=None,  # type: Optional[str]
                 *arg,  # type: Any
                 **karg  # type: Any
                 ):
        # type: (...) -> None
        """Start tcpdump and read its output with a PcapReader.

        :param iface: interface to capture on; defaults to conf.iface on
            Windows and macOS
        :param promisc: when False, ``-p`` is passed to tcpdump
        :param filter: filter expression passed to tcpdump
        :param nofilter: when True, conf.except_filter is not appended
        :param prog: program to run, forwarded to scapy.utils.tcpdump
        """
        self.outs = None
        args = ['-w', '-', '-s', '65535']
        if iface is None and (WINDOWS or DARWIN):
            iface = conf.iface
        self.iface = iface
        if iface is not None:
            args.extend(['-i', network_name(iface)])
        if not promisc:
            args.append('-p')
        if not nofilter:
            if conf.except_filter:
                if filter:
                    filter = "(%s) and not (%s)" % (filter, conf.except_filter)
                else:
                    filter = "not (%s)" % conf.except_filter
        if filter is not None:
            args.append(filter)
        self.tcpdump_proc = tcpdump(None, prog=prog, args=args, getproc=True)
        self.reader = PcapReader(self.tcpdump_proc.stdout)
        self.ins = self.reader  # type: ignore

    def recv(self, x=MTU):
        # type: (int) -> Optional[Packet]
        """Return the next packet read from tcpdump's pcap output."""
        return self.reader.recv(x)

    def close(self):
        # type: () -> None
        """Close the pcap reader and stop the tcpdump subprocess.

        Repeated calls are no-ops, and it is safe to call it when
        __init__ failed before tcpdump was started (e.g. from __del__).
        """
        if self.closed:
            return
        SuperSocket.close(self)
        proc = getattr(self, "tcpdump_proc", None)
        if proc is not None:
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.wait()

    @staticmethod
    def select(sockets, remain=None):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """On Windows and macOS every socket is reported ready; elsewhere
        this defers to SuperSocket.select()."""
        if (WINDOWS or DARWIN):
            return sockets
        return SuperSocket.select(sockets, remain=remain)
497
498
499# More abstract objects
500
class IterSocket(SuperSocket):
    """Socket-like object replaying packets taken from an iterable.

    Accepted sources are:

    * another IterSocket, whose iterator is shared;
    * a SndRcvList, replayed as sent/received pairs;
    * a list or PacketList of bytes or of (possibly implicit) packets;
    * any other iterable.

    recv() raises EOFError once the source is exhausted.
    """
    desc = "wrapper around an iterable"
    nonblocking_socket = True

    def __init__(self, obj):
        # type: (_PacketIterable) -> None
        if not obj:
            self.iter = iter([])  # type: Iterator[Packet]
            return
        if isinstance(obj, IterSocket):
            self.iter = obj.iter
            return
        if isinstance(obj, SndRcvList):
            def _pairs(pairs=cast(SndRcvList, obj)):
                # type: (SndRcvList) -> Iterator[Packet]
                for sent, received in pairs:
                    # Replay the query with the time it was actually sent.
                    if sent.sent_time:
                        sent.time = sent.sent_time
                    yield sent
                    yield received
            self.iter = _pairs()
            return
        if isinstance(obj, (list, PacketList)):
            if isinstance(obj[0], bytes):  # type: ignore
                self.iter = iter(obj)
                return
            # Expand each (possibly implicit) packet into concrete ones.
            self.iter = (pkt for implicit in obj for pkt in implicit)
            return
        self.iter = obj.__iter__()

    @staticmethod
    def select(sockets, remain=None):
        # type: (List[SuperSocket], Any) -> List[SuperSocket]
        """Iterables never block: every socket is reported ready."""
        return sockets

    def recv(self, *args):
        # type: (*Any) -> Optional[Packet]
        """Return a fresh copy (rebuilt from bytes) of the next item.

        :raises EOFError: when the iterable is exhausted
        """
        try:
            upcoming = next(self.iter)
            return upcoming.__class__(bytes(upcoming))
        except StopIteration:
            raise EOFError

    def close(self):
        # type: () -> None
        """Nothing to release."""
        pass
544