1#
2# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License").
5# You may not use this file except in compliance with the License.
6# A copy of the License is located at
7#
8#  http://aws.amazon.com/apache2.0
9#
10# or in the "license" file accompanying this file. This file is distributed
11# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12# express or implied. See the License for the specific language governing
13# permissions and limitations under the License.
14#
15
16"""
17Simple handshake tests using gnutls-cli
18"""
19
20import argparse
21import collections
22import os
23import sys
24import ssl
25import socket
26import subprocess
27import itertools
28import multiprocessing
29from os import environ
30from multiprocessing.pool import ThreadPool
31from s2n_test_constants import *
32
# Result of one handshake attempt: whether it succeeded, plus the gnutls-cli output
# captured at the start of the session (used later to find the negotiated parameters).
HANDSHAKE_RC = collections.namedtuple('HANDSHAKE_RC', ['handshake_success', 'gnutls_stdout'])

# Libcrypto builds assumed to support X25519 (not referenced elsewhere in this view).
LIBCRYPTO_SUPPORT_X25519 = ['openssl-1.1.1']
37
def sigalg_str_from_list(sigalgs):
    """Return the colon-joined digest portions of GnuTLS signature-algorithm names.

    "SIGN-RSA-SHA256" becomes "SHA256". Names starting with "SIGN-RSA" lose their first
    9 characters ("SIGN-RSA-"); every other name loses its first 11 ("SIGN-ECDSA-").
    """
    trimmed = []
    for sigalg in sigalgs:
        if sigalg.startswith("SIGN-RSA"):
            prefix_len = len("SIGN-RSA-")
        else:
            prefix_len = len("SIGN-ECDSA-")
        trimmed.append(sigalg[prefix_len:])
    return ":".join(trimmed)
42
def _stop_process(proc):
    """Kill ``proc``, wait for it to exit, and close whichever of its pipes were opened."""
    proc.kill()
    proc.wait()
    for pipe in (proc.stdin, proc.stdout, proc.stderr):
        if pipe is None:
            continue
        try:
            pipe.close()
        except (IOError, OSError):
            # Closing stdin flushes any buffered bytes; the child is already dead,
            # so a broken-pipe error here carries no information.
            pass


def _send_line(proc, text):
    """Write ``text`` plus a newline to ``proc``'s stdin; return False if the pipe is broken."""
    try:
        proc.stdin.write((text + "\n").encode("utf-8"))
        proc.stdin.flush()
    except (IOError, OSError):
        # e.g. the child has already exited; report this as a failed handshake
        # instead of raising out of the test.
        return False
    return True


def try_gnutls_handshake(endpoint, port, priority_str, mfl_extension_test, ssl_version, enter_fips_mode=False):
    """Run one s2nd <-> gnutls-cli handshake and echo a probe string in both directions.

    s2nd listens on ``endpoint``/``port``. It uses the "test_all_fips" cipher list and
    --enter-fips-mode when ``enter_fips_mode`` is set, otherwise "test_all_tls12". It uses
    the ECDSA cert/key when ``priority_str`` contains "ECDSA", and --enable-mfl when
    ``mfl_extension_test`` is truthy. gnutls-cli connects with ``priority_str``, adding
    --recordsize=``mfl_extension_test`` when that is set.

    The attempt succeeds only if all three of these hold:
      * the probe written by gnutls-cli reaches s2nd;
      * some earlier s2nd output line contains the expected protocol version
        (``ssl_version``, or S2N_TLS12 when ``ssl_version`` is falsy);
      * the probe echoed back by s2nd reaches gnutls-cli.

    Both child processes are killed on every exit path, including exceptions, so a failed
    attempt never leaves s2nd holding the port.

    Returns:
        HANDSHAKE_RC(handshake_success, gnutls_stdout). ``gnutls_stdout`` holds gnutls-cli's
        output up to and including the "Simple Client Mode" line, capped at 100 lines.
    """
    # Fire up s2nd
    s2nd_cmd = ["../../bin/s2nd", str(endpoint), str(port)]
    s2nd_ciphers = "test_all_tls12"
    if enter_fips_mode:
        s2nd_ciphers = "test_all_fips"
        s2nd_cmd.append("--enter-fips-mode")
    s2nd_cmd.extend(["-c", s2nd_ciphers])
    if "ECDSA" in priority_str:
        # Serve the ECDSA certificate/key pair.
        # NOTE(review): the "-c" cipher list above is deliberately left as-is. Switching it
        # to an ECDSA-specific list (e.g. "test_all_ecdsa") would change server preferences;
        # confirm whether that was intended.
        s2nd_cmd.extend(["--cert", TEST_ECDSA_CERT])
        s2nd_cmd.extend(["--key", TEST_ECDSA_KEY])
    if mfl_extension_test:
        s2nd_cmd.append("--enable-mfl")

    s2nd = subprocess.Popen(s2nd_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    gnutls_cli = None
    try:
        # Make sure it's running: block until s2nd prints its first line.
        s2nd.stdout.readline()

        # --insecure because s2nd is using a dummy cert
        gnutls_cmd = ["gnutls-cli", "--priority=" + priority_str, "--insecure", "-p " + str(port), str(endpoint)]
        if mfl_extension_test:
            gnutls_cmd.append("--recordsize=" + str(mfl_extension_test))

        gnutls_cli = subprocess.Popen(gnutls_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE)

        # Save the initial output of gnutls-cli to parse the negotiated handshake parameters later
        gnutls_initial_stdout_str = ""
        for _ in range(100):
            output = gnutls_cli.stdout.readline().decode("utf-8")
            gnutls_initial_stdout_str += output + "\n"
            # Once we see this string, we have read enough output to determine which signature algorithm was used
            if "Simple Client Mode" in output:
                break

        # Send the priority string from gnutls-cli to s2nd. Prefix it with "s2n" so it cannot
        # accidentally match something in the gnutls-cli handshake output.
        written_str = "s2n" + priority_str
        if not _send_line(gnutls_cli, written_str):
            return HANDSHAKE_RC(False, gnutls_initial_stdout_str)

        # Look for the probe in s2nd's output, noting whether the expected protocol version
        # was reported on an earlier line.
        expected_version = ACTUAL_VERSION_STR.format(ssl_version or S2N_TLS12)
        found = False
        right_version = False
        for _ in range(50):
            output = s2nd.stdout.readline().decode("utf-8")
            if output.strip() == written_str:
                found = True
                break
            if expected_version in output:
                right_version = True

        if not (found and right_version):
            return HANDSHAKE_RC(False, gnutls_initial_stdout_str)

        # Echo the same probe back from s2nd and look for it in gnutls-cli's output.
        if not _send_line(s2nd, written_str):
            return HANDSHAKE_RC(False, gnutls_initial_stdout_str)
        found = False
        for _ in range(50):
            output = gnutls_cli.stdout.readline().decode("utf-8")
            if output.strip() == written_str:
                found = True
                break

        return HANDSHAKE_RC(found, gnutls_initial_stdout_str)
    finally:
        # Same teardown order as before: client first, then server.
        if gnutls_cli is not None:
            _stop_process(gnutls_cli)
        _stop_process(s2nd)
120
def handshake(endpoint, port, cipher_name, ssl_version, priority_str, digests, curves, mfl_extension_test, fips_mode,
              other_prefix=None):
    """Run try_gnutls_handshake and print a one-line PASSED/FAILED report for it.

    The report label is chosen from the first applicable case:
      * the MFL size, when ``mfl_extension_test`` is set;
      * otherwise the signature algorithms in ``digests``;
      * otherwise the ``curves``;
      * otherwise the cipher name.
    ``other_prefix`` is prepended verbatim. ANSI colors are used only when stdout is a TTY.

    Returns:
        The HANDSHAKE_RC produced by try_gnutls_handshake.
    """
    result = try_gnutls_handshake(endpoint, port, priority_str, mfl_extension_test, ssl_version, fips_mode)

    version_name = S2N_PROTO_VERS_TO_STR[ssl_version]
    if mfl_extension_test:
        label = "MFL: %-10s Cipher: %-10s Vers: %-10s ... " % (mfl_extension_test, cipher_name, version_name)
    elif digests:
        label = "Digests: %-40s Vers: %-10s ... " % (sigalg_str_from_list(digests), version_name)
    elif curves:
        # Drop the leading "CURVE-" (6 characters) from every curve name.
        short_curves = ":".join(curve[6:] for curve in curves)
        label = "Curves: %-40s Vers: %-10s ... " % (short_curves, version_name)
    else:
        label = "Cipher: %-30s Vers: %-10s ... " % (cipher_name, version_name)

    passed = result.handshake_success == True
    status = "PASSED" if passed else "FAILED"
    if sys.stdout.isatty():
        # Bold green for a pass, bold red for a failure.
        color = "32" if passed else "31"
        status = "\033[" + color + ";1m" + status + "\033[0m"

    print((other_prefix or "") + label + status)
    return result
151
def create_thread_pool():
    """Create and announce a ThreadPool with two workers per CPU.

    Doubling the CPU count gives a slight speedup on hyperthreaded CPUs.
    """
    pool_size = 2 * multiprocessing.cpu_count()
    print("\n\tCreating ThreadPool of size: {}".format(pool_size))
    return ThreadPool(processes=pool_size)
157
def _run_handshakes_in_parallel(host, port, jobs):
    """Run one handshake() per job on a fresh thread pool; return True if all succeeded.

    Each job is the tuple of handshake() arguments that follow ``port``:
    (cipher_name, ssl_version, priority_str, digests, curves, mfl_extension_test, fips_mode).
    Job i runs on ``port + i`` so that concurrent s2nd instances use distinct ports.
    Results are checked in submission order after the pool has drained.
    """
    threadpool = create_thread_pool()
    results = [threadpool.apply_async(handshake, (host, port + offset) + tuple(job))
               for offset, job in enumerate(jobs)]
    threadpool.close()
    threadpool.join()
    for async_result in results:
        if async_result.get().handshake_success == False:
            return False
    return True


def _test_sigalg_permutations(host, port, fips_mode, key_type, sigalg_prefs, cipher_names):
    """Handshake with every ordered subset of ``sigalg_prefs``, up to MAX_ITERATION_DEPTH long.

    Each permutation is tried with every cipher in ALL_TEST_CIPHERS whose OpenSSL name is
    in ``cipher_names``. Ciphers that are not FIPS compatible are skipped when ``fips_mode``
    is set. ``key_type`` ("RSA"/"ECDSA") only affects the progress message.

    Returns False as soon as a batch (one subset size) has a failure, else True.
    """
    ciphers = [cipher for cipher in ALL_TEST_CIPHERS
               if cipher.openssl_name in cipher_names
               and not (fips_mode and cipher.openssl_fips_compatible == False)]
    for size in range(1, min(MAX_ITERATION_DEPTH, len(sigalg_prefs)) + 1):
        print("\n\tTesting ciphers using " + key_type + " signature preferences of size: " + str(size))
        jobs = []
        for permutation in itertools.permutations(sigalg_prefs, size):
            for cipher in ciphers:
                complete_priority_str = (cipher.gnutls_priority_str + ":+CURVE-ALL" + ":+VERS-TLS1.2:+"
                                         + ":+".join(permutation))
                jobs.append((cipher.openssl_name, S2N_TLS12, complete_priority_str, permutation, [], 0, fips_mode))
        if not _run_handshakes_in_parallel(host, port, jobs):
            return False
    return True


def _test_sigalg_preferences(host, port, fips_mode, key_type, expected_prefs, cipher_name):
    """Check that s2nd ranks signature algorithms in exactly the order of ``expected_prefs``.

    This is a brittle test that must be kept in sync with the signature algorithm preference
    lists in the core code; updates to those lists are rare.

    To find the i-th preference, the client offers SIGN-ALL minus the i algorithms expected
    to rank higher. s2nd should then negotiate ``expected_prefs[i]``, which is read back from
    the "Server Signature" line of gnutls-cli's output. Handshakes run one at a time on
    ``port``, using the cipher named ``cipher_name`` (skipped in FIPS mode if incompatible).

    Prints a diagnostic and returns False on the first failure or mismatch, else True.
    """
    print("\n\tTesting " + key_type + " Signature Algorithm preferences")
    print("\n\tExpected preference order: " + ",".join(expected_prefs))
    ciphers = [cipher for cipher in ALL_TEST_CIPHERS
               if cipher.openssl_name == cipher_name
               and not (fips_mode and cipher.openssl_fips_compatible == False)]
    for i, expected_sigalg in enumerate(expected_prefs):
        # Everything ranked above position i is removed from the client's offer.
        current_preferences_found = expected_prefs[:i]
        sig_algs_to_remove = ":!".join(current_preferences_found)
        sig_algs = "SIGN-ALL"
        if sig_algs_to_remove:
            sig_algs += ":!" + sig_algs_to_remove
        for cipher in ciphers:
            priority_str = cipher.gnutls_priority_str + ":+CURVE-ALL" + ":+VERS-TLS1.2:+" + sig_algs
            rc = handshake(host, port, cipher.openssl_name, S2N_TLS12, priority_str, [], [], 0, fips_mode,
                           "Preferences found: %-40s " % sigalg_str_from_list(current_preferences_found))
            if rc.handshake_success == False:
                print("Failed to negotiate " + expected_sigalg + " as expected! Priority string: "
                      + priority_str)
                return False
            negotiated_sigalg_lines = [line for line in rc.gnutls_stdout.split('\n') if "Server Signature" in line]
            if not negotiated_sigalg_lines:
                print("Failed to find negotiated sig alg in gnutls-cli output! Priority string: " + priority_str)
                return False

            # gnutls-cli reports "Server Signature : $SIGALG", where $SIGALG is in GnuTLS
            # priority-string form without the leading "SIGN-". Restore that prefix to compare.
            # partition() keeps everything after the first ':' and never raises.
            negotiated_sigalg = "SIGN-" + negotiated_sigalg_lines[0].partition(":")[2].strip()
            if negotiated_sigalg != expected_sigalg:
                print("Failed to negotiate the expected sigalg! Expected " + expected_sigalg
                      + " Got: " + negotiated_sigalg + " at position " + str(i) + " in the preference list" +
                      " Priority string: " + priority_str)
                return False
    return True


def main():
    """Run every gnutls-cli handshake test against s2nd.

    Covers cipher/version handshakes, RSA and ECDSA signature-algorithm permutations,
    signature-algorithm preference order, named-curve permutations, and the Max Fragment
    Length extension. Set S2N_TEST_IN_FIPS_MODE to run s2nd in FIPS mode.

    Returns:
        -1 on the first failing test group, None when everything passes.
    """
    parser = argparse.ArgumentParser(description='Runs TLS server integration tests against s2nd using gnutls-cli')
    parser.add_argument('host', help='The host for s2nd to bind to')
    parser.add_argument('port', type=int, help='The port for s2nd to bind to')
    parser.add_argument('--libcrypto', default='openssl-1.1.1', choices=S2N_LIBCRYPTO_CHOICES,
            help="""The Libcrypto that s2n was built with. s2n supports different cipher suites depending on
                    libcrypto version. Defaults to openssl-1.1.1.""")
    args = parser.parse_args()

    # Retrieve the test ciphers to use based on the libcrypto version s2n was built with
    test_ciphers = S2N_LIBCRYPTO_TO_TEST_CIPHERS[args.libcrypto]
    host = args.host
    port = args.port

    fips_mode = False
    if environ.get("S2N_TEST_IN_FIPS_MODE") is not None:
        fips_mode = True
        print("\nRunning s2nd in FIPS mode.")

    # Close the popen pipe once the version line has been read.
    with os.popen('gnutls-cli --version | grep -w gnutls-cli') as version_pipe:
        gnutls_version = version_pipe.read()
    print("\nRunning GnuTLS handshake tests with: " + gnutls_version)

    for ssl_version in [S2N_SSLv3, S2N_TLS10, S2N_TLS11, S2N_TLS12]:
        if ssl_version == S2N_SSLv3 and fips_mode:
            # FIPS does not permit the use of SSLv3
            continue

        print("\n\tTesting ciphers using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])
        jobs = []
        for cipher in test_ciphers:
            if ssl_version < cipher.min_tls_vers:
                continue

            cipher_priority_str = cipher.gnutls_priority_str
            # gnutls-cli always adds TLS extensions to the ClientHello; NO_EXTENSIONS avoids that for SSLv3.
            if ssl_version == S2N_SSLv3:
                cipher_priority_str += ":%NO_EXTENSIONS"

            # Add the SSL version to make the cipher priority string fully qualified
            complete_priority_str = (cipher_priority_str + ":+" + S2N_PROTO_VERS_TO_GNUTLS[ssl_version]
                                     + ":+SIGN-ALL" + ":+CURVE-ALL")
            # The OpenSSL cipher name is used for printing.
            jobs.append((cipher.openssl_name, ssl_version, complete_priority_str, [], [], 0, fips_mode))

        if not _run_handshakes_in_parallel(host, port, jobs):
            return -1

    # Produce permutations of every accepted signature algorithm in every possible order,
    # trying both an ECDHE and a DHE cipher suite.
    if not _test_sigalg_permutations(host, port, fips_mode, "RSA", EXPECTED_RSA_SIGNATURE_ALGORITHM_PREFS,
                                     ("ECDHE-RSA-AES128-GCM-SHA256", "DHE-RSA-AES128-GCM-SHA256")):
        return -1

    # ECDSA permutations. Once multiple certificates are supported, RSA and ECDSA can be combined.
    if not _test_sigalg_permutations(host, port, fips_mode, "ECDSA", EXPECTED_ECDSA_SIGNATURE_ALGORITHM_PREFS,
                                     ("ECDHE-ECDSA-AES128-SHA",)):
        return -1

    # Check that s2n's server signature algorithm preferences are as expected.
    if not _test_sigalg_preferences(host, port, fips_mode, "RSA", EXPECTED_RSA_SIGNATURE_ALGORITHM_PREFS,
                                    "ECDHE-RSA-AES128-SHA"):
        return -1
    if not _test_sigalg_preferences(host, port, fips_mode, "ECDSA", EXPECTED_ECDSA_SIGNATURE_ALGORITHM_PREFS,
                                    "ECDHE-ECDSA-AES128-SHA"):
        return -1

    # Produce permutations of every curve s2n supports in every possible order,
    # using an arbitrary ECDHE key-exchange cipher.
    curves = ["CURVE-SECP256R1", "CURVE-SECP384R1", "CURVE-SECP521R1"]
    ecdhe_cipher = [x for x in ALL_TEST_CIPHERS if x.openssl_name == "ECDHE-RSA-AES128-GCM-SHA256"][0]
    for size in range(1, len(curves) + 1):
        print("\n\tTesting named curve preferences of size: " + str(size))
        jobs = []
        for permutation in itertools.permutations(curves, size):
            complete_priority_str = (ecdhe_cipher.gnutls_priority_str + ":+SIGN-ALL" + ":+VERS-TLS1.2:+"
                                     + ":+".join(permutation))
            jobs.append((ecdhe_cipher.openssl_name, S2N_TLS12, complete_priority_str, [], permutation, 0, fips_mode))
        if not _run_handshakes_in_parallel(host, port, jobs):
            return -1

    print("\n\tTesting handshakes with Max Fragment Length Extension")
    mfl_cipher = test_ciphers[0]
    for ssl_version in [S2N_TLS10, S2N_TLS11, S2N_TLS12]:
        print("\n\tTesting Max Fragment Length Extension using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])
        complete_priority_str = (mfl_cipher.gnutls_priority_str + ":+CURVE-ALL" + ":+"
                                 + S2N_PROTO_VERS_TO_GNUTLS[ssl_version] + ":+SIGN-ALL")
        jobs = [(mfl_cipher.openssl_name, ssl_version, complete_priority_str, [], [], mfl_extension_test, fips_mode)
                for mfl_extension_test in [512, 1024, 2048, 4096]]
        if not _run_handshakes_in_parallel(host, port, jobs):
            return -1
373
if __name__ == "__main__":
    # main() returns -1 on failure and None on success; sys.exit maps None to status 0.
    exit_status = main()
    sys.exit(exit_status)
376