1import contextlib
2import ctypes
3import ctypes.wintypes
4import io
5import json
6import os
7import re
8import socket
9import socketserver
10import threading
11import time
12import typing
13
14import click
15import collections
16import collections.abc
17import pydivert
18import pydivert.consts
19
if typing.TYPE_CHECKING:
    # Type-checker-only stub: this block never runs at runtime (TYPE_CHECKING is False then).
    # It declares WindowsError together with the ``winerror`` attribute that Redirect inspects,
    # so static analysis works on platforms where WindowsError is not a builtin.
    class WindowsError(OSError):
        @property
        def winerror(self) -> int:
            # Placeholder value; only the return type matters to the type checker.
            return 42
25
# Address of the redirector's JSON line API: APIServer listens here, Resolver connects here.
REDIRECT_API_HOST = "127.0.0.1"
REDIRECT_API_PORT = 8085
28
29
30##########################
31# Resolver
32
def read(rfile: io.BufferedReader) -> typing.Any:
    """
    Read one newline-terminated JSON message from *rfile* and return the decoded value.

    Raises:
        EOFError: if the stream ends before a message arrives (peer disconnected).
        json.JSONDecodeError: if the received line is not valid JSON.
    """
    line = rfile.readline()
    if not line:
        # readline() returns an empty result only at EOF. json.loads(b"") would raise a
        # JSONDecodeError, which callers (who handle EOFError) do not expect.
        raise EOFError("Connection closed.")
    return json.loads(line.strip())
36
37
def write(data, wfile: io.BufferedWriter) -> None:
    """Encode *data* as one JSON line, write it to *wfile* and flush it out immediately."""
    line = json.dumps(data).encode() + b"\n"
    wfile.write(line)
    # Flush right away: the peer blocks on readline() waiting for this message.
    wfile.flush()
41
42
class Resolver:
    """
    Client side of the redirector API.

    Connects to the API server at REDIRECT_API_HOST:REDIRECT_API_PORT, announces the current
    PID, and then asks for the original destination of client connections.
    """
    sock: typing.Optional[socket.socket]
    lock: threading.RLock

    def __init__(self):
        self.sock = None
        self.rfile = None
        self.wfile = None
        # Serializes request/response pairs on the shared API connection.
        self.lock = threading.RLock()

    def setup(self):
        """Start the redirector if it is not reachable yet, then connect to its API."""
        with self.lock:
            TransparentProxy.setup()
            self._connect()

    def _close(self):
        """Close the current API connection, if any, ignoring errors from a broken connection."""
        # The makefile() wrappers must be closed too: the socket's OS handle is only released
        # once every file object created by makefile() has been closed.
        for resource in (self.wfile, self.rfile, self.sock):
            if resource is not None:
                with contextlib.suppress(OSError):
                    resource.close()
        self.sock = None
        self.rfile = None
        self.wfile = None

    def _connect(self):
        """(Re)connect to the API server and send our PID as the first message."""
        self._close()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.connect((REDIRECT_API_HOST, REDIRECT_API_PORT))

        self.wfile = self.sock.makefile('wb')
        self.rfile = self.sock.makefile('rb')
        write(os.getpid(), self.wfile)

    @staticmethod
    def _normalize_ip(ip: str) -> str:
        """
        Turn an IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) into plain IPv4 and drop any
        ``%scope`` suffix from IPv6 addresses.
        """
        # Dots are escaped so that only a real dotted quad matches (e.g. not "::ffff:1:2:3:4").
        ip = re.sub(r"^::ffff:(?=\d+\.\d+\.\d+\.\d+$)", "", ip)
        return ip.split("%", 1)[0]

    def _query(self, ip: str, port: int):
        """Send one lookup request and return the server's (address, port) answer as a tuple."""
        write((ip, port), self.wfile)
        addr = read(self.rfile)
        if addr is None:
            raise RuntimeError("Cannot resolve original destination.")
        return tuple(addr)

    def original_addr(self, csock: socket.socket):
        """
        Return the original ``(address, port)`` destination of the client connected on *csock*.

        Raises:
            RuntimeError: if the redirector has no mapping for this client.
        """
        ip, port = csock.getpeername()[:2]
        ip = self._normalize_ip(ip)
        with self.lock:
            try:
                return self._query(ip, port)
            except (EOFError, OSError, json.JSONDecodeError):
                # The API connection broke (e.g. the redirector went away). Reconnect and retry
                # exactly once; a second failure propagates instead of recursing without bound.
                self._connect()
                return self._query(ip, port)
80
81
class APIRequestHandler(socketserver.StreamRequestHandler):
    """
    TransparentProxy API: for each received JSON-encoded client (address, port) pair, replies
    with the JSON-encoded original server (address, port) pair, or ``null`` if it is unknown.

    The first message on a connection is the PID of the connecting process. That PID is
    exempted from local redirection for as long as the connection stays open.
    """

    def handle(self):
        proxifier: TransparentProxy = self.server.proxifier
        try:
            pid: int = read(self.rfile)
            with proxifier.exempt(pid):
                while True:
                    client = tuple(read(self.rfile))
                    try:
                        server = proxifier.client_server_map[client]
                    except KeyError:
                        server = None
                    write(server, self.wfile)
        except (EOFError, OSError, json.JSONDecodeError):
            # The peer disconnected or sent something that is not JSON. An empty read at EOF
            # also fails to decode as JSON. Either way, end this API session quietly.
            pass
102
103
class APIServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """Thread-per-connection TCP server for the redirect API.

    Request handlers reach the owning TransparentProxy via ``self.server.proxifier``.
    """

    def __init__(self, proxifier, *args, **kwargs):
        self.proxifier = proxifier
        super().__init__(*args, **kwargs)
        # Handler threads must not keep the interpreter alive on exit.
        self.daemon_threads = True
110
111
112##########################
113# Windows API
114
# From Windows' winerror.h: returned by GetExtendedTcpTable when the supplied buffer is too small.
ERROR_INSUFFICIENT_BUFFER = 0x7A

# Raw address byte arrays used by the MIB_TCP*ROW_OWNER_PID structures
# (16 bytes for an IPv6 address, 4 bytes for an IPv4 address).
IN6_ADDR = ctypes.c_ubyte * 16
IN4_ADDR = ctypes.c_ubyte * 4
120
121
122#
123# IPv6
124#
125
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa366896(v=vs.85).aspx
class MIB_TCP6ROW_OWNER_PID(ctypes.Structure):
    """
    One row of the IPv6 TCP table filled in by GetExtendedTcpTable.

    Field order, names and types mirror the Windows SDK definition and must not be changed.
    """
    _fields_ = [
        ("ucLocalAddr", IN6_ADDR),                    # local IPv6 address (raw bytes)
        ("dwLocalScopeId", ctypes.wintypes.DWORD),
        ("dwLocalPort", ctypes.wintypes.DWORD),       # port, byte-swapped by the reader
        ("ucRemoteAddr", IN6_ADDR),                   # remote IPv6 address (raw bytes)
        ("dwRemoteScopeId", ctypes.wintypes.DWORD),
        ("dwRemotePort", ctypes.wintypes.DWORD),
        ("dwState", ctypes.wintypes.DWORD),
        ("dwOwningPid", ctypes.wintypes.DWORD),       # PID of the owning process
    ]
138
139
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa366905(v=vs.85).aspx
def MIB_TCP6TABLE_OWNER_PID(size):
    """
    Return a zero-initialised MIB_TCP6TABLE_OWNER_PID instance with room for *size* rows.

    A fresh structure type is declared per call because the row array length is part of the type.
    """
    rows = MIB_TCP6ROW_OWNER_PID * size

    class _MIB_TCP6TABLE_OWNER_PID(ctypes.Structure):
        _fields_ = [
            ("dwNumEntries", ctypes.wintypes.DWORD),  # number of valid rows in ``table``
            ("table", rows),
        ]

    instance = _MIB_TCP6TABLE_OWNER_PID()
    return instance
149
150
151#
152# IPv4
153#
154
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa366913(v=vs.85).aspx
class MIB_TCPROW_OWNER_PID(ctypes.Structure):
    """
    One row of the IPv4 TCP table filled in by GetExtendedTcpTable.

    Field order, names and types mirror the Windows SDK definition and must not be changed.
    """
    _fields_ = [
        ("dwState", ctypes.wintypes.DWORD),
        ("ucLocalAddr", IN4_ADDR),                    # local IPv4 address (raw bytes)
        ("dwLocalPort", ctypes.wintypes.DWORD),       # port, byte-swapped by the reader
        ("ucRemoteAddr", IN4_ADDR),                   # remote IPv4 address (raw bytes)
        ("dwRemotePort", ctypes.wintypes.DWORD),
        ("dwOwningPid", ctypes.wintypes.DWORD),       # PID of the owning process
    ]
165
166
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx
def MIB_TCPTABLE_OWNER_PID(size):
    """
    Return a zero-initialised MIB_TCPTABLE_OWNER_PID instance with room for *size* rows.

    A fresh structure type is declared per call because the row array length is part of the type.
    """
    rows = MIB_TCPROW_OWNER_PID * size

    class _MIB_TCPTABLE_OWNER_PID(ctypes.Structure):
        _fields_ = [
            ("dwNumEntries", ctypes.wintypes.DWORD),  # number of valid rows in ``table``
            ("table", rows),
        ]

    instance = _MIB_TCPTABLE_OWNER_PID()
    return instance
176
177
# TCP_TABLE_CLASS value passed to GetExtendedTcpTable: request connection rows
# (MIB_TCP*ROW_OWNER_PID) that include the owning process id.
TCP_TABLE_OWNER_PID_CONNECTIONS = 4
179
180
class TcpConnectionTable(collections.abc.Mapping):
    """
    Read-only mapping from local TCP endpoints ``(ip, port)`` to the PID that owns them.

    The data is a snapshot of Windows' IPv4 and IPv6 TCP connection tables, obtained via
    ``GetExtendedTcpTable``. It only changes when :meth:`refresh` is called.
    """
    # Number of rows initially allocated per address family; buffers grow on demand.
    DEFAULT_TABLE_SIZE = 4096

    def __init__(self):
        self._tcp = MIB_TCPTABLE_OWNER_PID(self.DEFAULT_TABLE_SIZE)
        # GetExtendedTcpTable's size argument is the buffer size in *bytes*, not a row count.
        self._tcp_size = ctypes.wintypes.DWORD(ctypes.sizeof(self._tcp))
        self._tcp6 = MIB_TCP6TABLE_OWNER_PID(self.DEFAULT_TABLE_SIZE)
        self._tcp6_size = ctypes.wintypes.DWORD(ctypes.sizeof(self._tcp6))
        self._map = {}

    def __getitem__(self, item):
        return self._map[item]

    def __iter__(self):
        return iter(self._map)

    def __len__(self):
        return len(self._map)

    def refresh(self):
        """Replace the snapshot with the current IPv4 and IPv6 connection tables."""
        self._map = {}
        self._refresh_ipv4()
        self._refresh_ipv6()

    @staticmethod
    def _rows_needed(size_in_bytes: int, row_size: int) -> int:
        """Return a row count whose table is at least *size_in_bytes* large (overshoots by <= 1 row)."""
        return size_in_bytes // row_size + 1

    @staticmethod
    def _get_table(family, table, size) -> int:
        """Fill *table* with the *family* connection table and return GetExtendedTcpTable's result."""
        # Always advertise the true buffer size. On ERROR_INSUFFICIENT_BUFFER, Windows writes
        # the required size (in bytes) back into *size*.
        size.value = ctypes.sizeof(table)
        return ctypes.windll.iphlpapi.GetExtendedTcpTable(
            ctypes.byref(table),
            ctypes.byref(size),
            False,
            family,
            TCP_TABLE_OWNER_PID_CONNECTIONS,
            0
        )

    def _add_rows(self, family, table) -> None:
        """Record ``(local ip, local port) -> owning pid`` for every valid row in *table*."""
        for row in table.table[:table.dwNumEntries]:
            local_ip = socket.inet_ntop(family, bytes(row.ucLocalAddr))
            # The port occupies the low 16 bits (network byte order). Mask the rest off, since
            # socket.ntohs() rejects values that do not fit into 16 bits.
            local_port = socket.ntohs(row.dwLocalPort & 0xFFFF)
            self._map[(local_ip, local_port)] = row.dwOwningPid

    def _refresh_ipv4(self):
        while True:
            ret = self._get_table(socket.AF_INET, self._tcp, self._tcp_size)
            if ret == 0:
                break
            if ret != ERROR_INSUFFICIENT_BUFFER:
                raise RuntimeError("[IPv4] Unknown GetExtendedTcpTable return code: %s" % ret)
            # Buffer too small: reallocate from the required byte size reported by Windows.
            rows = self._rows_needed(self._tcp_size.value, ctypes.sizeof(MIB_TCPROW_OWNER_PID))
            self._tcp = MIB_TCPTABLE_OWNER_PID(rows)
        self._add_rows(socket.AF_INET, self._tcp)

    def _refresh_ipv6(self):
        while True:
            ret = self._get_table(socket.AF_INET6, self._tcp6, self._tcp6_size)
            if ret == 0:
                break
            if ret != ERROR_INSUFFICIENT_BUFFER:
                raise RuntimeError("[IPv6] Unknown GetExtendedTcpTable return code: %s" % ret)
            # Buffer too small: reallocate from the required byte size reported by Windows.
            rows = self._rows_needed(self._tcp6_size.value, ctypes.sizeof(MIB_TCP6ROW_OWNER_PID))
            self._tcp6 = MIB_TCP6TABLE_OWNER_PID(rows)
        self._add_rows(socket.AF_INET6, self._tcp6)
246
247
def get_local_ip() -> typing.Optional[str]:
    """
    Return this host's outward-facing IPv4 address, or None if it cannot be determined.

    This is needed because re-injecting to 127.0.0.1 does not work.
    "Connecting" a UDP socket sends no packets; it only makes the OS pick a route and a source
    address, which getsockname() then reveals.
    See https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        try:
            probe.connect(("8.8.8.8", 80))
            address = probe.getsockname()[0]
        except OSError:
            address = None
    return address
259
260
def get_local_ip6(reachable: str) -> typing.Optional[str]:
    """
    Return the local IPv6 address used to reach *reachable*, or None if there is none.

    Same trick as get_local_ip(), with the added difficulty that .connect() fails if the
    target network is not reachable.
    """
    try:
        # Create the socket inside the try: on a host without an IPv6 stack, socket() itself
        # raises OSError, and that case should also yield None instead of crashing.
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as probe:
            probe.connect((reachable, 80))
            return probe.getsockname()[0]
    except OSError:
        return None
272
273
class Redirect(threading.Thread):
    """
    Background thread that owns one WinDivert handle.

    Every packet received on the handle is passed to ``handle``. The thread ends quietly once
    the handle has been closed via :meth:`shutdown`.
    """
    daemon = True
    windivert: pydivert.WinDivert

    # Windows' ERROR_OPERATION_ABORTED; recv() below treats it as "the handle was shut down".
    _ERROR_OPERATION_ABORTED = 995

    def __init__(
        self,
        handle: typing.Callable[[pydivert.Packet], None],
        filter: str,
        layer: pydivert.Layer = pydivert.Layer.NETWORK,
        flags: pydivert.Flag = 0
    ) -> None:
        self.handle = handle
        self.windivert = pydivert.WinDivert(filter, layer, flags=flags)
        super().__init__()

    def start(self):
        """Open the WinDivert handle, then start the receive loop in this thread."""
        self.windivert.open()
        super().start()

    def run(self):
        # Receive until recv() signals shutdown with None; handler errors propagate.
        while True:
            packet = self.recv()
            if packet is None:
                return
            self.handle(packet)

    def shutdown(self):
        """Close the WinDivert handle, which makes the receive loop exit."""
        self.windivert.close()

    def recv(self) -> typing.Optional[pydivert.Packet]:
        """
        Convenience function that receives a packet from the passed handler and handles error codes.
        If the process has been shut down, None is returned.
        """
        try:
            packet = self.windivert.recv()
        except WindowsError as err:
            if err.winerror != self._ERROR_OPERATION_ABORTED:
                raise
            packet = None
        return packet
320
321
class RedirectLocal(Redirect):
    """
    Redirect for traffic sent by processes on this host.

    A packet is forwarded to ``redirect_request`` unless its sender, looked up in the TCP
    connection table, is one of ``trusted_pids``. Trusted traffic is re-injected as-is.
    """
    trusted_pids: typing.Set[int]

    def __init__(
        self,
        redirect_request: typing.Callable[[pydivert.Packet], None],
        filter: str
    ) -> None:
        self.redirect_request = redirect_request
        self.trusted_pids = set()
        self.tcp_connections = TcpConnectionTable()
        super().__init__(self.handle, filter)

    def handle(self, packet):
        source = (packet.src_addr, packet.src_port)

        # Unknown endpoint: the connection may be newer than our snapshot, so re-read it.
        if source not in self.tcp_connections:
            self.tcp_connections.refresh()

        # Still unknown after the refresh (owner is None): most likely an external client,
        # whose traffic we always proxy.
        owner = self.tcp_connections.get(source)

        if owner in self.trusted_pids:
            # It's not really clear why we need to recalculate the checksum here,
            # but this was identified as necessary in https://github.com/mitmproxy/mitmproxy/pull/3174.
            self.windivert.send(packet, recalculate_checksum=True)
        else:
            self.redirect_request(packet)
351
352
# An (address, port) endpoint; ClientServerMap maps client endpoints to server endpoints.
TConnection = typing.Tuple[str, int]
354
355
class ClientServerMap:
    """
    A thread-safe, size-bounded LRU dict mapping client endpoints to server endpoints.

    Both reads and writes mark an entry as most recently used. Once more than
    ``connection_cache_size`` entries are stored, the least recently used ones are evicted.
    """
    connection_cache_size: typing.ClassVar[int] = 65536

    def __init__(self):
        self._lock = threading.Lock()
        # Ordered from least to most recently used.
        self._map = collections.OrderedDict()

    def __getitem__(self, item: "TConnection") -> "TConnection":
        with self._lock:
            value = self._map[item]
            # A lookup counts as a use; otherwise this would be FIFO rather than LRU.
            self._map.move_to_end(item)
            return value

    def __setitem__(self, key: "TConnection", value: "TConnection") -> None:
        with self._lock:
            self._map[key] = value
            self._map.move_to_end(key)
            while len(self._map) > self.connection_cache_size:
                self._map.popitem(last=False)
374
375
class TransparentProxy:
    """
    Transparent Windows Proxy for mitmproxy based on WinDivert/PyDivert. This module can be used to
    redirect both traffic that is forwarded by the host and traffic originating from the host itself.

    Requires elevated (admin) privileges. Can be started separately by manually running the file.

    How it works:

    (1) First, we intercept all packets that match our filter.
    We consider both traffic that is forwarded by the OS (WinDivert's NETWORK_FORWARD layer) and
    traffic sent from the local machine (WinDivert's NETWORK layer). In the case of traffic from
    the local machine, we need to exempt packets sent from the proxy to not create a redirect loop.
    To accomplish this, we use Windows' GetExtendedTcpTable API and determine the source
    application's PID.

    For each intercepted packet, we
        1. Store the source -> destination mapping (address and port)
        2. Remove the packet from the network (by not reinjecting it).
        3. Re-inject the packet into the local network stack, but with the destination address
           changed to the proxy.

    (2) Next, the proxy receives the forwarded packet, but does not know the real destination yet
    (which we overwrote with the proxy's address). On Linux, we would now call
    getsockopt(SO_ORIGINAL_DST). We now access the redirect module's API (see APIRequestHandler),
    submit the source information and get the actual destination back (which we stored in 1.1).

    (3) The proxy now establishes the upstream connection as usual.

    (4) Finally, the proxy sends the response back to the client. To make it work, we need to change
    the packet's source address back to the original destination (using the mapping from 1.1),
    which the client believes it is talking to.

    Limitations:

    - We assume that ephemeral TCP ports are not re-used for multiple connections at the same time.
    The proxy will fail if an application connects to example.com and example.org from
    192.168.0.42:4242 simultaneously. This could be mitigated by introducing unique "meta-addresses"
    which mitmproxy sees, but this would remove the correct client info from mitmproxy.
    """
    local: typing.Optional[RedirectLocal] = None
    # really weird linting error here.
    forward: typing.Optional[Redirect] = None  # noqa
    response: Redirect
    icmp: Redirect

    proxy_port: int
    filter: str

    client_server_map: ClientServerMap

    def __init__(
        self,
        local: bool = True,
        forward: bool = True,
        proxy_port: int = 8080,
        filter: typing.Optional[str] = "tcp.DstPort == 80 or tcp.DstPort == 443",
    ) -> None:
        self.proxy_port = proxy_port
        # An empty/None filter falls back to "everything except the proxy port, the API port
        # and destination ports >= 49152".
        self.filter = (
            filter
            or
            f"tcp.DstPort != {proxy_port} and tcp.DstPort != {REDIRECT_API_PORT} and tcp.DstPort < 49152"
        )

        self.ipv4_address = get_local_ip()
        self.ipv6_address = get_local_ip6("2001:4860:4860::8888")
        self.client_server_map = ClientServerMap()

        # Reference counts for exempt(): one PID may hold several API connections at once.
        self._exempt_lock = threading.Lock()
        self._exempt_counts = collections.Counter()

        self.api = APIServer(self, (REDIRECT_API_HOST, REDIRECT_API_PORT), APIRequestHandler)
        self.api_thread = threading.Thread(target=self.api.serve_forever)
        self.api_thread.daemon = True

        if forward:
            self.forward = Redirect(
                self.redirect_request,
                self.filter,
                pydivert.Layer.NETWORK_FORWARD
            )
        if local:
            self.local = RedirectLocal(
                self.redirect_request,
                self.filter
            )

        # The proxy server responds to the client. To the client,
        # this response should look like it has been sent by the real target
        self.response = Redirect(
            self.redirect_response,
            f"outbound and tcp.SrcPort == {proxy_port}",
        )

        # Block all ICMP requests (which are sent on Windows by default).
        # If we don't do this, our proxy machine may send an ICMP redirect to the client,
        # which instructs the client to directly connect to the real gateway
        # if they are on the same network.
        self.icmp = Redirect(
            lambda _: None,
            "icmp",
            flags=pydivert.Flag.DROP
        )

    @classmethod
    def setup(cls):
        """Start a redirector in this process unless one is already serving the API port."""
        # TODO: Make sure that server can be killed cleanly. That's a bit difficult as we don't have access to
        # controller.should_exit when this is called.
        # The probe socket is only used to test reachability, so close it right away.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            server_unavailable = s.connect_ex((REDIRECT_API_HOST, REDIRECT_API_PORT))
        if server_unavailable:
            proxifier = cls()
            proxifier.start()

    def start(self):
        self.api_thread.start()
        self.icmp.start()
        self.response.start()
        if self.forward:
            self.forward.start()
        if self.local:
            self.local.start()

    def shutdown(self):
        if self.local:
            self.local.shutdown()
        if self.forward:
            self.forward.shutdown()
        self.response.shutdown()
        self.icmp.shutdown()
        self.api.shutdown()
        # shutdown() only stops serve_forever(); server_close() releases the listening socket.
        self.api.server_close()

    def redirect_request(self, packet: pydivert.Packet):
        """Record the packet's original destination, then re-inject it towards the proxy."""
        client = (packet.src_addr, packet.src_port)

        self.client_server_map[client] = (packet.dst_addr, packet.dst_port)

        # We do need to inject to an external IP here, 127.0.0.1 does not work.
        # Explicit checks instead of asserts, so they still apply under `python -O`.
        if packet.address_family == socket.AF_INET:
            if not self.ipv4_address:
                raise RuntimeError("No local IPv4 address available to redirect to.")
            packet.dst_addr = self.ipv4_address
        elif packet.address_family == socket.AF_INET6:
            if not self.ipv6_address:
                self.ipv6_address = get_local_ip6(packet.src_addr)
            if not self.ipv6_address:
                raise RuntimeError("No local IPv6 address available to redirect to.")
            packet.dst_addr = self.ipv6_address
        else:
            raise RuntimeError("Unknown address family")
        packet.dst_port = self.proxy_port
        packet.direction = pydivert.consts.Direction.INBOUND

        # We need a handle on the NETWORK layer. the local handle is not guaranteed to exist,
        # so we use the response handle.
        self.response.windivert.send(packet)

    def redirect_response(self, packet: pydivert.Packet):
        """
        If the proxy responds to the client, let the client believe the target server sent the
        packets.
        """
        client = (packet.dst_addr, packet.dst_port)
        try:
            packet.src_addr, packet.src_port = self.client_server_map[client]
        except KeyError:
            print(f"Warning: Previously unseen connection from proxy to {client}")
        else:
            packet.recalculate_checksums()

        self.response.windivert.send(packet, recalculate_checksum=False)

    @contextlib.contextmanager
    def exempt(self, pid: int):
        """
        Exclude traffic from process *pid* from local redirection while the context is active.

        Exemptions are reference-counted, so overlapping exemptions for the same PID (e.g. a
        client that reconnects before its previous API connection has been torn down) do not
        cancel each other out.
        """
        with self._exempt_lock:
            self._exempt_counts[pid] += 1
            if self.local:
                self.local.trusted_pids.add(pid)
        try:
            yield
        finally:
            with self._exempt_lock:
                self._exempt_counts[pid] -= 1
                if self._exempt_counts[pid] <= 0:
                    del self._exempt_counts[pid]
                    if self.local:
                        self.local.trusted_pids.discard(pid)
557
558
@click.group()
def cli():
    # Root click command group; the subcommands below attach via @cli.command().
    # (No docstring on purpose: click would use it as the group's --help text.)
    pass
562
563
@cli.command()
@click.option("--local/--no-local", default=True,
              help="Redirect the host's own traffic.")
@click.option("--forward/--no-forward", default=True,
              help="Redirect traffic that's forwarded by the host.")
@click.option("--filter", type=str, metavar="WINDIVERT_FILTER",
              help="Custom WinDivert interception rule.")
@click.option("-p", "--proxy-port", type=int, metavar="8080", default=8080,
              help="The port mitmproxy is listening on.")
def redirect(**options):
    """Redirect flows to mitmproxy."""
    # The click options (local, forward, filter, proxy_port) map 1:1 to TransparentProxy's kwargs.
    proxy = TransparentProxy(**options)
    proxy.start()
    print(" * Redirection active.")
    # Show the effective filter. When --filter is omitted, TransparentProxy substitutes its
    # fallback rule and stores the result in ``filter``.
    print(f"   Filter: {proxy.filter}")
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print(" * Shutting down...")
        proxy.shutdown()
        print(" * Shut down.")
586
587
@cli.command()
def connections():
    """List all TCP connections and the associated PIDs."""
    # Take one snapshot of the OS connection tables and print each endpoint with its owner.
    table = TcpConnectionTable()
    table.refresh()
    for endpoint, owner_pid in table.items():
        ip, port = endpoint
        print(f"{ip}:{port} -> {owner_pid}")
595
596
if __name__ == "__main__":
    # Script entry point: dispatch to the click group (`redirect` / `connections` subcommands).
    cli()
599