import io
import socket
import ssl

from urllib3.exceptions import ProxySchemeUnsupported
from urllib3.packages import six

# Maximum number of bytes requested from the wrapped socket in a single
# recv() call while feeding TLS data into the incoming MemoryBIO.
SSL_BLOCKSIZE = 16384


class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    TLS records are produced and consumed through a pair of ``ssl.MemoryBIO``
    buffers driven by an ``SSLObject`` (see ``_ssl_io_loop``); the I/O loop
    itself only uses the wrapped socket's ``sendall()`` and ``recv()``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context):
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            if six.PY2:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "supported on Python 2"
                )
            else:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "available on non-native SSLContext"
                )

    def __init__(
        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
    ):
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: the connected socket-like object TLS data is sent over
            and received from.
        :param ssl_context: context whose ``wrap_bio()`` creates the SSL object.
        :param server_hostname: forwarded to ``ssl_context.wrap_bio()``.
        :param suppress_ragged_eofs: when True, an ``SSL_ERROR_EOF`` raised
            while reading is reported as end-of-stream instead of propagating.

        The TLS handshake is performed before the constructor returns.
        """
        # In-memory buffers: "incoming" holds TLS bytes received from the
        # socket, "outgoing" holds TLS bytes the SSL object wants sent.
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def fileno(self):
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len=1024, buffer=None):
        """
        Read up to ``len`` bytes of decrypted data.

        Returns the data as bytes, or the number of bytes stored when
        ``buffer`` is given. End-of-stream is ``b""`` / ``0`` respectively.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, len=1024, flags=0):
        """Socket-style alias for ``read(len)``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(len)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """
        Read decrypted data into ``buffer`` and return the number of bytes read.

        When ``nbytes`` is None it defaults to ``len(buffer)`` (1024 if the
        buffer is empty/falsy), matching ``ssl.SSLSocket.recv_into``.
        ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        return self.read(nbytes, buffer)

    def sendall(self, data, flags=0):
        """Encrypt and send all of ``data``, calling ``send()`` until done."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to unsigned bytes so len() and slicing count bytes regardless
        # of the item format of the object passed in.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data, flags=0):
        """Encrypt and send ``data``; return what ``SSLObject.write`` returns."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        response = self._ssl_io_loop(self.sslobj.write, data)
        return response

    def makefile(
        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
    ):
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)
        # The reference is counted on the wrapped socket; _decref_socketios()
        # forwards the matching decrement to it.
        self.socket._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    def unwrap(self):
        """Run ``SSLObject.unwrap()`` over the socket; the socket stays open."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self):
        """Close the wrapped socket (call ``unwrap()`` first for a TLS shutdown)."""
        self.socket.close()

    def getpeercert(self, binary_form=False):
        return self.sslobj.getpeercert(binary_form)

    def version(self):
        return self.sslobj.version()

    def cipher(self):
        return self.sslobj.cipher()

    def selected_alpn_protocol(self):
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self):
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self):
        return self.sslobj.shared_ciphers()

    def compression(self):
        return self.sslobj.compression()

    def settimeout(self, value):
        self.socket.settimeout(value)

    def gettimeout(self):
        return self.socket.gettimeout()

    def _decref_socketios(self):
        # Counterpart of the _io_refs increment in makefile(); forwarded to
        # the wrapped socket, which owns the counter.
        self.socket._decref_socketios()

    def _wrap_ssl_read(self, len, buffer=None):
        """
        Read via ``SSLObject.read``, translating a suppressed ragged EOF.

        Returns bytes when ``buffer`` is None, otherwise the number of bytes
        written into ``buffer``.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Report end-of-stream like ssl.SSLSocket.read(): an empty
                # bytes object when returning data, 0 when filling a buffer.
                return 0 if buffer is not None else b""
            raise

    def _ssl_io_loop(self, func, *args):
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        Calls ``func(*args)`` until it completes. After every attempt, any
        TLS bytes queued in ``self.outgoing`` are flushed to the socket with
        ``sendall``. On ``SSL_ERROR_WANT_READ`` up to SSL_BLOCKSIZE bytes are
        read from the socket into ``self.incoming`` (an empty read marks it
        EOF) and the call is retried; on ``SSL_ERROR_WANT_WRITE`` the call is
        retried after the flush. Any other ``ssl.SSLError`` propagates.

        Returns the value of the first call to ``func`` that does not raise.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
        return ret