#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager

import _import_local_thrift  # noqa

SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')

TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'


def _caller_test_name():
    """Return the name of the nearest ``test*`` function on the call stack, or None.

    Used to label acceptor threads (and their log messages) with the test
    method that created them.
    """
    frame = inspect.currentframe()
    try:
        while frame is not None:
            name = frame.f_code.co_name
            if name.startswith('test'):
                return name
            frame = frame.f_back
        return None
    finally:
        # Drop the frame reference explicitly to avoid a reference cycle.
        del frame


class ServerAcceptor(threading.Thread):
    """Daemon thread that puts a server socket into listening mode and accepts one client.

    The main thread synchronizes with it through three events: listening
    started, port bound, and accept finished (successfully or not).
    """

    def __init__(self, server, expect_failure=False):
        """
        :param server: server transport exposing ``listen()``, ``accept()``,
            ``close()`` and a ``handle`` socket.
        :param expect_failure: if True, an exception raised by ``accept()`` is
            logged but not re-raised.
        """
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # A fixed stack offset would land inside contextlib, because acceptors
        # are created from a @contextmanager generator; search for the calling
        # test method instead. Keep the default thread name if none is found.
        test_name = _caller_test_name()
        if test_name is not None:
            self.name = test_name

    def run(self):
        try:
            self._server.listen()
            self._listening.set()

            address = self._server.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), but in each case port is in the second
            # slot. AF_UNIX addresses are path strings and carry no port.
            if isinstance(address, tuple) and len(address) > 1:
                self._port = address[1]
            self._port_bound.set()

            try:
                self._client = self._server.accept()
            except Exception:
                logging.exception('error on server side (%s):', self.name)
                if not self._expect_failure:
                    raise
        finally:
            # Release every waiter no matter which step failed, so a broken
            # listen()/getsockname() makes the test fail instead of hanging.
            # Setting an already-set Event is a no-op.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()

    def await_listening(self):
        """Block until run() has called listen() (or failed trying)."""
        self._listening.wait()

    @property
    def port(self):
        """Bound TCP port; None for non-IP sockets or if binding failed. Blocks until known."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport, or None if accept failed. Blocks until accept finishes."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the server socket."""
        if self._client:
            self._client.close()
        self._server.close()


# Python 2.6 compat
class AssertRaises(object):
    """Minimal stand-in for the ``assertRaises`` context manager (Python < 2.7)."""

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', str(self._expected))
            raise AssertionError('%s not raised' % name)
        # Swallow the expected exception; returning False lets any other
        # exception propagate unchanged so its traceback is not lost.
        return issubclass(exc_type, self._expected)


class TSSLSocketTest(unittest.TestCase):
    """Connection tests for the thrift TSSLSocket client and TSSLServerSocket."""

    def _server_socket(self, **kwargs):
        """Create a TSSLServerSocket on an ephemeral port (port=0)."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start an acceptor for *server* and yield ``(acceptor, unopened client)``.

        With *path* set, the client targets that unix socket instead of
        localhost. The acceptor (and thus the server) is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a client built from *client_args* cannot connect to *server*."""
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                try:
                    # We need to wait for a connection failure, but not too long. 20ms is a tunable
                    # compromise between test speed and stability
                    client.setTimeout(20)
                    with self._assert_raises(TTransportException):
                        client.open()
                    self.assertTrue(acc.client is None)
                finally:
                    # Mirror _assert_connection_success: never leak the client
                    # socket, even if open() unexpectedly succeeded.
                    client.close()
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an assertRaises-style context manager usable on Python 2.6+."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that a client built from *client_args* connects to *server*."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            client.open()
            try:
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # mkstemp only reserves a unique path; the file itself is not used.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)


if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    # These names become module globals only when run as a script; the test
    # methods above look them up at call time.
    from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
    from thrift.transport.TTransport import TTransportException

    unittest.main()