"""
Copied from urllib3.util.ssltransport
"""
import io
import socket
import ssl


# Maximum number of bytes pulled from the underlying socket per WANT_READ
# iteration of SSLTransport._ssl_io_loop (16 KiB, the TLS record size limit).
SSL_BLOCKSIZE = 16384


class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.
    """

    def __init__(
        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
    ):
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: Object the encrypted TLS records are sent over
            (``sendall``) and received from (``recv``).
        :param ssl_context: Context whose ``wrap_bio`` creates the in-memory
            TLS object driven by this transport.
        :param server_hostname: Passed through to ``wrap_bio``.
        :param suppress_ragged_eofs: If true, an EOF from the peer without a
            TLS close_notify is reported as end-of-stream by the read methods
            instead of raising ``ssl.SSLError``.

        The TLS handshake is performed here, driving I/O on ``socket``.
        """
        # TLS state lives in memory: ciphertext from the peer goes into
        # ``incoming``, ciphertext we must send accumulates in ``outgoing``.
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def fileno(self):
        """Return the file descriptor of the underlying socket."""
        return self.socket.fileno()

    def read(self, len=1024, buffer=None):
        """
        Read up to ``len`` decrypted bytes.

        Returns ``bytes`` when ``buffer`` is None, otherwise the number of
        bytes written into ``buffer``.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, len=1024, flags=0):
        """Socket-style ``recv``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(len)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """
        Socket-style ``recv_into``; ``flags`` must be 0.

        When ``nbytes`` is None it defaults to ``len(buffer)``, or 1024 for an
        empty buffer (same rule as ``ssl.SSLSocket.recv_into``).
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        return self.read(nbytes, buffer)

    def sendall(self, data, flags=0):
        """Encrypt and send all of ``data``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # View the payload as raw bytes so slicing/len are in bytes even for
        # multi-byte-item buffers, without copying.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data, flags=0):
        """Encrypt ``data`` and return the number of plaintext bytes accepted."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        response = self._ssl_io_loop(self.sslobj.write, data)
        return response

    def makefile(
        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
    ):
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)
        # The reference count lives on the wrapped socket; SocketIO.close()
        # releases it through self._decref_socketios().
        self.socket._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    def unwrap(self):
        """Perform the TLS shutdown exchange over the underlying socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self):
        """Close the underlying socket (no TLS close_notify is sent here)."""
        self.socket.close()

    def getpeercert(self, binary_form=False):
        return self.sslobj.getpeercert(binary_form)

    def version(self):
        return self.sslobj.version()

    def cipher(self):
        return self.sslobj.cipher()

    def selected_alpn_protocol(self):
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self):
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self):
        return self.sslobj.shared_ciphers()

    def compression(self):
        return self.sslobj.compression()

    def settimeout(self, value):
        self.socket.settimeout(value)

    def gettimeout(self):
        return self.socket.gettimeout()

    def _decref_socketios(self):
        # Called by socket.SocketIO.close() for files created by makefile().
        self.socket._decref_socketios()

    def _wrap_ssl_read(self, len, buffer=None):
        """
        Read decrypted data, turning a suppressed ragged EOF into end-of-stream.

        Returns the ``bytes`` read, or the count written into ``buffer`` when
        one is given. On an EOF without close_notify and with
        ``suppress_ragged_eofs`` set, returns ``b""`` (or ``0`` when a buffer
        was given), matching ``ssl.SSLSocket.read`` and the socket API.

        :raises ssl.SSLError: for any other TLS error, or a ragged EOF when
            ``suppress_ragged_eofs`` is false.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Keep the return type consistent with a successful read:
                # recv()/read() yield bytes, recv_into() yields a count.
                return 0 if buffer is not None else b""
            raise

    def _ssl_io_loop(self, func, *args):
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        Calls ``func(*args)`` until it completes without WANT_READ/WANT_WRITE.
        After every attempt, pending records in ``outgoing`` are flushed to the
        socket. On WANT_READ, up to ``SSL_BLOCKSIZE`` bytes are read from the
        socket into ``incoming`` (an empty read marks ``incoming`` as EOF)
        before retrying. Returns ``func``'s result; any other ``ssl.SSLError``
        propagates.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever func() produced (handshake messages, encrypted
            # data, alerts) even when it still needs more input.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Peer closed the socket: signal EOF so the next attempt
                    # can report it rather than waiting for more data.
                    self.incoming.write_eof()
        return ret