1import io
2import socket
3import ssl
4
5from urllib3.exceptions import ProxySchemeUnsupported
6from urllib3.packages import six
7
# Number of bytes requested from the wrapped socket per recv() call whenever
# the SSL object reports SSL_ERROR_WANT_READ (see SSLTransport._ssl_io_loop).
# NOTE(review): presumably chosen to match the maximum TLS record plaintext
# size (2**14 bytes) — confirm.
SSL_BLOCKSIZE = 16384
9
10
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    All TLS processing happens in memory: ``incoming`` buffers bytes received
    from the wrapped ``socket`` that the SSL object has not consumed yet, and
    ``outgoing`` buffers bytes the SSL object produced that still have to be
    sent. ``_ssl_io_loop`` shuttles data between these buffers and the socket.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context):
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            if six.PY2:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "supported on Python 2"
                )
            else:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "available on non-native SSLContext"
                )

    def __init__(
        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
    ):
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: The transport to run TLS over. It is only used through
            its socket-like methods (sendall, recv, close, fileno, timeouts),
            so it may itself be another SSLTransport.
        :param ssl_context: Context whose ``wrap_bio`` creates the SSL object.
        :param server_hostname: Passed through to ``wrap_bio``.
        :param suppress_ragged_eofs: If true, an SSL_ERROR_EOF during a read
            is reported as end-of-stream instead of being raised.

        The TLS handshake is performed before the constructor returns.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def fileno(self):
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len=1024, buffer=None):
        """
        Read up to ``len`` decrypted bytes.

        Without ``buffer`` the data is returned as bytes; with ``buffer`` it is
        written into the buffer and the number of bytes written is returned.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, len=1024, flags=0):
        """Socket-style read of up to ``len`` decrypted bytes (no flags)."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(len)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """
        Socket-style read into ``buffer``; returns the number of bytes read.

        ``nbytes`` defaults to the buffer's length (or 1024 for an empty
        buffer), mirroring ``ssl.SSLSocket.recv_into``.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        return self.read(nbytes, buffer)

    def sendall(self, data, flags=0):
        """Encrypt and send all of ``data`` (any bytes-like object)."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to a flat byte view so slicing/len() count bytes even for
        # multi-byte-item buffers.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data, flags=0):
        """Encrypt and send ``data``; returns the number of bytes consumed."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        response = self._ssl_io_loop(self.sslobj.write, data)
        return response

    def makefile(
        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
    ):
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)
        # The file object keeps the wrapped socket alive; SocketIO.close()
        # undoes this through _decref_socketios().
        self.socket._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    def unwrap(self):
        """Perform the TLS shutdown over the wrapped socket (left open)."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self):
        """Close the wrapped socket without a TLS shutdown (see unwrap())."""
        self.socket.close()

    def getpeercert(self, binary_form=False):
        return self.sslobj.getpeercert(binary_form)

    def version(self):
        return self.sslobj.version()

    def cipher(self):
        return self.sslobj.cipher()

    def selected_alpn_protocol(self):
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self):
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self):
        return self.sslobj.shared_ciphers()

    def compression(self):
        return self.sslobj.compression()

    def settimeout(self, value):
        self.socket.settimeout(value)

    def gettimeout(self):
        return self.socket.gettimeout()

    def _decref_socketios(self):
        """Hook called by socket.SocketIO (see makefile); forwarded to socket."""
        self.socket._decref_socketios()

    def _wrap_ssl_read(self, len, buffer=None):
        """
        Read through the SSL object, translating a suppressed ragged EOF.

        Returns bytes when ``buffer`` is None, otherwise the number of bytes
        written into ``buffer``.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # EOF in violation of the TLS protocol ("ragged" EOF). Mirror
                # ssl.SSLSocket.read(): report end-of-stream as 0 bytes read
                # when filling a caller's buffer, and as b"" otherwise, so
                # recv()/read() keep returning bytes.
                return 0 if buffer is not None else b""
            raise

    def _ssl_io_loop(self, func, *args):
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        Calls ``func(*args)`` until it completes without WANT_READ/WANT_WRITE.
        After every attempt, whatever the SSL object wrote to ``outgoing`` is
        sent on the socket. On WANT_READ, up to SSL_BLOCKSIZE bytes are received
        from the socket into ``incoming``. An empty recv() marks ``incoming``
        as EOF. Returns func's result; any other SSLError propagates.
        """
        while True:
            errno = None
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush pending TLS output (handshake records, app data, alerts).
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                return ret
            if errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
222