1# -*- coding: utf-8 -*-
2
3from __future__ import absolute_import, print_function, unicode_literals
4
5__copyright__ = """
6Copyright (C) 2009-2017 Andreas Kloeckner
7Copyright (C) 2014-2017 Aaron Meurer
8Copyright (C) 2020-2020 Son Geon
9"""
10
11__license__ = """
12Permission is hereby granted, free of charge, to any person obtaining a copy
13of this software and associated documentation files (the "Software"), to deal
14in the Software without restriction, including without limitation the rights
15to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16copies of the Software, and to permit persons to whom the Software is
17furnished to do so, subject to the following conditions:
18
19The above copyright notice and this permission notice shall be included in
20all copies or substantial portions of the Software.
21
22THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28THE SOFTWARE.
29"""
30
31
32# mostly stolen from celery.contrib.rdb
33
34
35import errno
36import os
37import socket
38import sys
39import fcntl
40import termios
41import struct
42
43from pudb.debugger import Debugger
44
__all__ = [
    "PUDB_RDB_HOST",
    "PUDB_RDB_PORT",
    "default_port",
    "debugger",
    "set_trace",
]

# Port used when the PUDB_RDB_PORT environment variable is unset or empty.
default_port = 6899

# Connection defaults.  Each one is read from the environment variable of the
# same name; an unset or empty variable falls back to the built-in value.
PUDB_RDB_HOST = os.getenv("PUDB_RDB_HOST") or "127.0.0.1"
PUDB_RDB_PORT = int(os.getenv("PUDB_RDB_PORT") or default_port)

#: Holds the currently active debugger.
# A one-element list, so the slot can be rebound in place by ``debugger()``.
_current = [None]

# Shortcut to the interpreter's frame accessor, used to find the caller's frame.
_frame = sys._getframe
56
# User-facing message templates.  Each one is rendered with
# ``TEMPLATE.format(self=<debugger>)`` and therefore only works once the
# attributes it names (``ident``, ``host``, ``port``, ``remote_addr``) exist on
# the debugger instance.

# Raised as an exception message when no port could be bound.
NO_AVAILABLE_PORT = """\
{self.ident}: Couldn't find an available port.

Please specify one using the PUDB_RDB_PORT environment variable.
"""

# Printed while the listening socket waits for a telnet client to connect.
BANNER = """\
{self.ident}: Please telnet into {self.host} {self.port}.
{self.ident}: Waiting for client...
"""

SESSION_STARTED = "{self.ident}: Now in session with {self.remote_addr}."
SESSION_ENDED = "{self.ident}: Session with {self.remote_addr} ended."

# Raised as an exception message when a reverse connection is refused.  The
# suggested ``nc`` commands use the port that was actually tried, so the hint
# stays correct when a non-default port is configured.
CONN_REFUSED = """\
Cannot connect to the reverse telnet client {self.host} {self.port}.

Try to open reverse client by running
stty -echo -icanon && nc -l -p {self.port}  # Linux
stty -echo -icanon && nc -l {self.port}  # BSD/MacOS

Please specify one using the PUDB_RDB_PORT environment variable.
"""
80
81
class RemoteDebugger(Debugger):
    """Debugger whose input and output are carried over a TCP connection.

    In the default mode the debugger binds a listening socket and waits for a
    telnet client to connect.  With ``reverse=True`` it instead connects out
    to a client that is already listening on ``host``/``port`` (the
    ``CONN_REFUSED`` message suggests ``nc -l``).

    While a session is open, ``sys.stdin`` and ``sys.stdout`` are replaced by
    the socket stream; ``_close_session`` restores the previous handles.
    """

    # Prefix used to build ``ident`` ("<me>:<port>") for status messages.
    me = "pudb"
    _prev_outs = None
    # Listening socket; stays ``None`` in reverse mode, where nothing listens.
    _sock = None

    # Telnet protocol bytes (RFC 854 / 857 / 858).  They are spelled out here
    # because the standard-library ``telnetlib`` module that used to provide
    # them was removed in Python 3.13.
    _TELNET_IAC = b"\xff"
    _TELNET_DO = b"\xfd"
    _TELNET_WILL = b"\xfb"
    _TELNET_ECHO = b"\x01"
    _TELNET_SGA = b"\x03"

    def __init__(
        self,
        host=PUDB_RDB_HOST,
        port=PUDB_RDB_PORT,
        port_search_limit=100,
        out=sys.stdout,
        term_size=None,
        reverse=False,
    ):
        """Establish the remote connection and initialize the debugger on it.

        :arg host: address to bind to, or to connect to when *reverse* is set.
        :arg port: first port to try (the exact port when *reverse* is set).
        :arg port_search_limit: number of consecutive ports, starting at
            *port*, that are tried when binding the listening socket.
        :arg out: file-like object that status messages are printed to.  The
            default is ``sys.stdout`` as it was bound when this class was
            defined.
        :arg term_size: forwarded unchanged to ``Debugger.__init__``.
        :arg reverse: connect out to a listening client instead of accepting
            an incoming telnet connection.
        """
        self.active = True
        self.out = out

        # Remembered so that _close_session can put the real streams back.
        self._prev_handles = sys.stdin, sys.stdout
        self._client, (address, port) = self.get_client(
            host=host, port=port, search_limit=port_search_limit, reverse=reverse
        )
        self.remote_addr = ":".join(str(v) for v in address)

        self.say(SESSION_STARTED.format(self=self))

        # makefile ignores encoding if there's no buffering.
        raw_sock_file = self._client.makefile("rwb", 0)
        import codecs

        if sys.version_info[0] < 3:
            sock_file = codecs.StreamRecoder(
                raw_sock_file,
                codecs.getencoder("utf-8"),
                codecs.getdecoder("utf-8"),
                codecs.getreader("utf-8"),
                codecs.getwriter("utf-8"),
            )
        else:
            sock_file = codecs.StreamReaderWriter(
                raw_sock_file, codecs.getreader("utf-8"), codecs.getwriter("utf-8")
            )

        self._handle = sys.stdin = sys.stdout = sock_file

        # nc negotiation doesn't support telnet options
        if not reverse:
            self._negotiate_telnet(raw_sock_file)

        Debugger.__init__(
            self, stdin=self._handle, stdout=self._handle, term_size=term_size
        )

    def _negotiate_telnet(self, raw_sock_file):
        """Offer the SGA and ECHO telnet options and check the client's reply.

        For each option, ``IAC WILL <option>`` is written to *raw_sock_file*
        and the next three bytes read back are expected to be
        ``IAC DO <option>``.
        """
        for option in (self._TELNET_SGA, self._TELNET_ECHO):
            raw_sock_file.write(self._TELNET_IAC + self._TELNET_WILL + option)
            resp = raw_sock_file.read(3)
            assert resp == self._TELNET_IAC + self._TELNET_DO + option

    def get_client(self, host, port, search_limit=100, reverse=False):
        """Open the connection to the client.

        Sets ``self.host``, ``self.port`` and ``self.ident`` (and
        ``self._sock`` when listening), then returns
        ``(client_socket, (peer_address, self.port))``.
        """
        # Publish the requested endpoint before any connection attempt: the
        # error templates used on failure (NO_AVAILABLE_PORT, CONN_REFUSED)
        # format themselves from these attributes.
        self.host, self.port = host, port
        self.ident = "{0}:{1}".format(self.me, self.port)

        if reverse:
            client, address = self.get_reverse_socket_client(host, port)
        else:
            self._sock, conn_info = self.get_socket_client(
                host, port, search_limit=search_limit,
            )
            # The port search may have settled on a later port than requested.
            self.host, self.port = conn_info
            self.ident = "{0}:{1}".format(self.me, self.port)
            self.say(BANNER.format(self=self))
            client, address = self._sock.accept()
        client.setblocking(1)
        return client, (address, self.port)

    def get_reverse_socket_client(self, host, port):
        """Connect out to a client listening on *host*/*port*.

        Returns ``(socket, peer_address)``.  Raises :exc:`ValueError` carrying
        the ``CONN_REFUSED`` message when the connection is refused; any other
        socket error is re-raised unchanged.  The socket is closed on failure.
        """
        _sock = socket.socket()
        try:
            _sock.connect((host, port))
            _sock.setblocking(1)
        except socket.error as exc:
            # Do not leak the descriptor of a socket that never connected.
            _sock.close()
            if exc.errno == errno.ECONNREFUSED:
                raise ValueError(CONN_REFUSED.format(self=self))
            raise
        return _sock, _sock.getpeername()

    def get_socket_client(self, host, port, search_limit):
        """Bind a listening socket and return ``(socket, (host, bound_port))``."""
        _sock, this_port = self.get_avail_port(host, port, search_limit)
        _sock.setblocking(1)
        _sock.listen(1)
        return _sock, (host, this_port)

    def get_avail_port(self, host, port, search_limit=100, skew=+0):
        """Bind a socket to the first usable port in the search range.

        Ports ``port``, ``port + 1``, ... are tried, *search_limit* of them in
        total.  A port is skipped when ``bind`` fails with ``EADDRINUSE`` or
        ``EINVAL``; any other socket error is re-raised.  Sockets that fail to
        bind are closed before moving on.

        *skew* is accepted for interface compatibility and is not used.

        Returns ``(bound_socket, port)``.  Raises :exc:`Exception` carrying the
        ``NO_AVAILABLE_PORT`` message when no port in the range could be bound.
        """
        for offset in range(search_limit):
            candidate = port + offset
            _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                _sock.bind((host, candidate))
            except socket.error as exc:
                # This socket is unusable either way; release it now.
                _sock.close()
                if exc.errno in (errno.EADDRINUSE, errno.EINVAL):
                    continue
                raise
            return _sock, candidate

        raise Exception(NO_AVAILABLE_PORT.format(self=self))

    def say(self, m):
        """Print the status message *m* to ``self.out``."""
        print(m, file=self.out)

    def _close_session(self):
        """Restore the original stdio handles and close the connection."""
        self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
        self._handle.close()
        self._client.close()
        # There is no listening socket in reverse mode.
        if self._sock is not None:
            self._sock.close()
        self.active = False
        self.say(SESSION_ENDED.format(self=self))

    def do_continue(self, arg):
        """Close the remote session, then call ``set_continue()``."""
        self._close_session()
        self.set_continue()
        return 1

    do_c = do_cont = do_continue

    def do_quit(self, arg):
        """Close the remote session, then call ``set_quit()``."""
        self._close_session()
        self.set_quit()
        return 1

    def set_quit(self):
        """Stop tracing by clearing the trace function."""
        # this raises a BdbQuit exception that we are unable to catch.
        sys.settrace(None)
218
219
def debugger(term_size=None, host=PUDB_RDB_HOST, port=PUDB_RDB_PORT, reverse=False):
    """Return the active remote debugger, creating one when needed.

    The instance stored in ``_current`` is reused while its ``active`` flag
    is set.  Otherwise a new :class:`RemoteDebugger` is built from the given
    arguments, stored in ``_current`` and returned.
    """
    existing = _current[0]
    if existing is not None and existing.active:
        return existing

    # Only published once construction has succeeded.
    fresh = RemoteDebugger(host=host, port=port, term_size=term_size, reverse=reverse)
    _current[0] = fresh
    return fresh
229
230
def set_trace(
    frame=None, term_size=None, host=PUDB_RDB_HOST, port=PUDB_RDB_PORT, reverse=False
):
    """Break into the remote debugger at the caller, or at *frame* if given.

    When *term_size* is not supplied it is queried from file descriptor 1
    with the ``TIOCGWINSZ`` ioctl, falling back to ``(80, 24)`` if that query
    fails for any reason.
    """
    if frame is None:
        # Must be evaluated directly in this function so that f_back is the
        # caller's frame.
        frame = _frame().f_back

    if term_size is None:
        try:
            winsize = fcntl.ioctl(1, termios.TIOCGWINSZ, "1234")
            rows, cols = struct.unpack("hh", winsize)
        except Exception:
            term_size = (80, 24)
        else:
            # The ioctl reports rows first; term_size is (columns, rows).
            term_size = (cols, rows)

    rdb = debugger(term_size=term_size, host=host, port=port, reverse=reverse)
    return rdb.set_trace(frame)
248