#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager

import _import_local_thrift  # noqa

SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')

TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'


class ServerAcceptor(threading.Thread):
    """Daemon thread that serves a single client for one test.

    It calls ``server.listen()``, publishes the bound port (when the socket
    address carries one), accepts one client, reads 5 bytes from it and
    replies ``b"there"``. Server-side errors are logged and, unless
    ``expect_failure`` is true, re-raised inside this thread.
    """

    def __init__(self, server, expect_failure=False):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after the test that created it, so server-side
        # errors in the log can be attributed to a test. A fixed stack index
        # does not work here: the acceptor is built inside a @contextmanager
        # generator, whose caller is contextlib's __enter__.
        test_name = self._find_test_name()
        if test_name:
            self.name = test_name

    @staticmethod
    def _find_test_name():
        """Return the name of the innermost ``test*`` function on the stack, or None."""
        # context=0: only function names are needed, not source lines.
        frames = inspect.stack(0)
        try:
            for frame_info in frames:
                if frame_info[3].startswith('test'):
                    return frame_info[3]
            return None
        finally:
            # The frame records reference this function's own frame; drop them
            # to avoid keeping a reference cycle alive.
            del frames

    def run(self):
        try:
            self._server.listen()
        except Exception:
            # Release every waiter so the test fails fast instead of blocking
            # forever in await_listening(), .port or .client.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()
            raise
        self._listening.set()

        try:
            address = self._server.handle.getsockname()
            if len(address) > 1:
                # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
                # 4-tuples (host, port, ...), but in each case port is in the second slot.
                self._port = address[1]
        finally:
            self._port_bound.set()

        try:
            self._client = self._server.accept()
            if self._client:
                self._client.read(5)  # hello
                self._client.write(b"there")
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise
        finally:
            self._client_accepted.set()

    def await_listening(self):
        """Block until run() has called listen() on the server."""
        self._listening.wait()

    @property
    def port(self):
        """Bound port, or None if the address has no port slot; blocks until known."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport (may be None); blocks until accept finished."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the listening server."""
        if self._client:
            self._client.close()
        self._server.close()


# Python 2.6 compat
class AssertRaises(object):
    """Minimal stand-in for ``TestCase.assertRaises`` used as a context manager."""

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Fail unless an exception of the expected type was raised; swallow it if so.
        if not exc_type or not issubclass(exc_type, self._expected):
            raise Exception('fail')
        return True


class TSSLSocketTest(unittest.TestCase):
    """Tests for TSSLSocket / TSSLServerSocket over 'localhost' TCP and unix sockets."""

    def _server_socket(self, **kwargs):
        # port=0 lets the OS pick a free port; ServerAcceptor reports it.
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start a ServerAcceptor on ``server`` and yield ``(acceptor, unopened client)``.

        With ``path`` set, the client targets that unix socket; otherwise it
        targets 'localhost' on the port the acceptor bound. Both the client
        and the acceptor are closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        client = None
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            # Callers on the failure path never close the client; do it here so
            # its socket is not leaked. Closing an already-closed client is harmless.
            if client is not None:
                client.close()
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                # We need to wait for a connection failure, but not too long. 20ms is a tunable
                # compromise between test speed and stability
                client.setTimeout(20)
                with self._assert_raises(TTransportException):
                    client.open()
                    client.write(b"hello")
                    client.read(5)  # b"there"
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        # TestCase.assertRaises only works as a context manager from Python 2.7 on.
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                client.open()
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # Reserve a unique temp path, then remove the file so the server can bind there.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The socket file only exists if the server got far enough to bind;
            # an unconditional unlink would raise here and mask the real failure.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)


if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    # NOTE: the test methods use these names as module globals and they are
    # only bound here, so this module is meant to be run as a script.
    from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
    from thrift.transport.TTransport import TTransportException

    unittest.main()