1"""
2Copied from urllib3.util.ssltransport
3"""
4import io
5import socket
6import ssl
7
8
# Number of bytes requested from the wrapped socket each time the TLS engine
# reports it needs more input (see SSLTransport._ssl_io_loop). 16 KiB is
# presumably chosen to match the maximum TLS record plaintext size.
SSL_BLOCKSIZE = 16 * 1024
10
11
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    All TLS processing happens in memory: an ``ssl.SSLObject`` consumes
    ciphertext from the ``incoming`` MemoryBIO and produces ciphertext into the
    ``outgoing`` MemoryBIO, while ``_ssl_io_loop`` moves those bytes to and
    from the wrapped socket. The I/O loop only calls the wrapped object's
    ``sendall`` and ``recv``, which is what allows another SSLTransport to be
    used as the wrapped socket.
    """

    def __init__(
        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
    ):
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: Socket-like object to tunnel TLS through (may itself
            be an SSLTransport).
        :param ssl_context: ``ssl.SSLContext`` whose ``wrap_bio`` creates the
            in-memory TLS object.
        :param server_hostname: Forwarded to ``SSLContext.wrap_bio``.
        :param suppress_ragged_eofs: When true, an unexpected EOF from the
            peer is reported by the read methods as end-of-stream (``b""`` or
            ``0``) instead of raising ``ssl.SSLError``.

        The TLS handshake is performed before the constructor returns; any
        handshake failure propagates from here.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def fileno(self):
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len=1024, buffer=None):
        """
        Read up to *len* decrypted bytes.

        Returns ``bytes`` when *buffer* is None, otherwise reads into
        *buffer* and returns the number of bytes written to it.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, len=1024, flags=0):
        """Receive up to *len* decrypted bytes; *flags* must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(len)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """
        Receive decrypted data into *buffer* and return the byte count.

        *nbytes* defaults to ``len(buffer)`` (or 1024 for an empty buffer);
        *flags* must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        return self.read(nbytes, buffer)

    def sendall(self, data, flags=0):
        """Encrypt and send all of *data*, calling ``send`` until done."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to a flat byte view so lengths and offsets are measured in
        # bytes even when *data* is a buffer with multi-byte items.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                count += self.send(byte_view[count:])

    def send(self, data, flags=0):
        """
        Encrypt *data*, flush the ciphertext to the wrapped socket, and
        return the number of plaintext bytes accepted by the TLS object.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
    ):
        """
        Python's http.client uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        # SocketIO performs its raw I/O through this transport (recv_into /
        # send) and calls _decref_socketios on close, so the reference count
        # is kept on the wrapped socket.
        raw = socket.SocketIO(self, rawmode)
        self.socket._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    def unwrap(self):
        """Run the TLS shutdown via ``SSLObject.unwrap``; the socket stays open."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self):
        """Close the wrapped socket (no TLS shutdown is performed here)."""
        self.socket.close()

    def getpeercert(self, binary_form=False):
        return self.sslobj.getpeercert(binary_form)

    def version(self):
        return self.sslobj.version()

    def cipher(self):
        return self.sslobj.cipher()

    def selected_alpn_protocol(self):
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self):
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self):
        return self.sslobj.shared_ciphers()

    def compression(self):
        return self.sslobj.compression()

    def settimeout(self, value):
        self.socket.settimeout(value)

    def gettimeout(self):
        return self.socket.gettimeout()

    def _decref_socketios(self):
        # Called by socket.SocketIO.close() for files created by makefile();
        # delegate to the wrapped socket, which holds the _io_refs count.
        self.socket._decref_socketios()

    def _wrap_ssl_read(self, len, buffer=None):
        """
        Read through the TLS object, translating a ragged EOF into an
        end-of-stream result when ``suppress_ragged_eofs`` is set.

        The EOF result matches the normal return type of ``SSLObject.read``
        (the same convention as ``ssl.SSLSocket.read``): ``0`` when reading
        into *buffer*, ``b""`` otherwise. Previously an int ``0`` was returned
        even to ``recv()`` callers expecting bytes.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                return 0 if buffer is not None else b""
            raise

    def _ssl_io_loop(self, func, *args):
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        Calls ``func(*args)`` on the TLS object repeatedly. After each
        attempt, any ciphertext it produced is flushed to the wrapped socket.
        On ``SSL_ERROR_WANT_READ``, up to SSL_BLOCKSIZE bytes are fed from the
        socket into the incoming BIO (or EOF is signalled if the socket
        returned nothing). Returns ``func``'s result once it succeeds; any
        other SSL error propagates.
        """
        while True:
            errno = None
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush handshake records / encrypted data / alerts, skipping the
            # pointless empty send when the TLS object produced nothing.
            buf = self.outgoing.read()
            if buf:
                self.socket.sendall(buf)

            if errno is None:
                return ret
            if errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
201