1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 #pragma once
17 #include <errno.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <unistd.h>
22 
23 #include <openssl/crypto.h>
24 
25 #include "error/s2n_errno.h"
26 #include "utils/s2n_safety.h"
27 #include "utils/s2n_result.h"
28 #include "tls/s2n_tls13.h"
29 
/* Number of counting EXPECT_* checks executed so far in this test binary.
 * BEGIN_TEST() resets it to 0, EXPECT_TRUE() (and every macro built on it)
 * increments it, FAIL_MSG_PRINT() reports it as the failing test number, and
 * END_TEST() prints it, treating 0 as "SKIPPED".
 * NOTE(review): this is a tentative definition placed in a header rather than
 * an `extern` declaration; linking several translation units that include this
 * header may depend on -fcommon on toolchains that default to -fno-common. */
int test_count;
31 
/* Macro definitions for calls that occur within BEGIN_TEST() and END_TEST() to preserve the SKIPPED test behavior
 * by ignoring the test_count, keeping it as 0 to indicate that a test was skipped.
 *
 * Each *_WITHOUT_COUNT macro behaves like its counting counterpart (EXPECT_TRUE, EXPECT_FALSE,
 * EXPECT_NOT_EQUAL, EXPECT_SUCCESS) but never increments test_count. They are also used by the
 * fuzz-target setup in S2N_FUZZ_TARGET. On failure they call FAIL_MSG(), which prints a report
 * and exits with status 1. The failure text is produced by stringizing the condition, so the
 * exact token spelling of each wrapper below appears verbatim in that text. */
#define EXPECT_TRUE_WITHOUT_COUNT( condition )    do { if ( !(condition) ) { FAIL_MSG( #condition " is not true "); } } while(0)
#define EXPECT_FALSE_WITHOUT_COUNT( condition )   EXPECT_TRUE_WITHOUT_COUNT( !(condition) )

#define EXPECT_NOT_EQUAL_WITHOUT_COUNT( p1, p2 )  EXPECT_FALSE_WITHOUT_COUNT( (p1) == (p2) )

/* "Success" means the call did not return -1; any other return value is accepted. */
#define EXPECT_SUCCESS_WITHOUT_COUNT( function_call )  EXPECT_NOT_EQUAL_WITHOUT_COUNT( (function_call) ,  -1 )
40 
41 /**
42  * This is a very basic, but functional unit testing framework. All testing should
43  * happen in main() and start with a BEGIN_TEST() and end with an END_TEST();
44  */
45 #define BEGIN_TEST()                                           \
46   do {                                                         \
47     test_count = 0;                                            \
48     EXPECT_SUCCESS_WITHOUT_COUNT(s2n_in_unit_test_set(true));  \
49     S2N_TEST_OPTIONALLY_ENABLE_FIPS_MODE();                    \
50     EXPECT_SUCCESS_WITHOUT_COUNT(s2n_init());                  \
51     fprintf(stdout, "Running %-50s ... ", __FILE__);           \
52   } while(0)
53 
/*
 * END_TEST() finishes a test binary started with BEGIN_TEST(). It:
 *  - clears the unit-test flag and calls s2n_cleanup() (both checked without
 *    counting);
 *  - completes the "Running ..." line with "PASSED <n> tests", or with
 *    "SKIPPED ALL tests" when no counting check ran (test_count == 0);
 *  - returns 0 from the enclosing function (main()).
 * The status word is wrapped in ANSI color codes only when stdout is a tty.
 */
#define END_TEST()                                                        \
    do {                                                                  \
        EXPECT_SUCCESS_WITHOUT_COUNT(s2n_in_unit_test_set(false));        \
        EXPECT_SUCCESS_WITHOUT_COUNT(s2n_cleanup());                      \
        const int s2n_test_stdout_is_tty = isatty(fileno(stdout));        \
        if (test_count == 0) {                                            \
            fprintf(stdout, "%sSKIPPED%s       ALL tests\n",              \
                    s2n_test_stdout_is_tty ? "\033[33;1m" : "",           \
                    s2n_test_stdout_is_tty ? "\033[0m" : "");             \
        } else {                                                          \
            fprintf(stdout, "%sPASSED%s %10d tests\n",                    \
                    s2n_test_stdout_is_tty ? "\033[32;1m" : "",           \
                    s2n_test_stdout_is_tty ? "\033[0m" : "",              \
                    test_count);                                          \
        }                                                                 \
        return 0;                                                         \
    } while (0)
75 
/* FAIL(): fail the current test with an empty message (see FAIL_MSG). */
#define FAIL() FAIL_MSG("")

/*
 * FAIL_MSG(msg): print a failure report for msg via FAIL_MSG_PRINT() and then
 * terminate the whole process with exit status 1. Never returns.
 */
#define FAIL_MSG(msg)        \
    do {                     \
        FAIL_MSG_PRINT(msg); \
        exit(1);             \
    } while (0)
82 
/*
 * FAIL_MSG_PRINT(msg): write a failure report to stderr WITHOUT exiting.
 * The report contains:
 *  - a stack trace (s2n_print_stacktrace);
 *  - the current test_count, msg, and the file/line of the expansion;
 *  - the s2n error string and s2n_debug_str;
 *  - the system errno and its strerror() text.
 * The "FAILED test" header is colored only when stderr is a tty.
 *
 * errno is captured before anything else runs. Printing the stack trace,
 * isatty() and the string lookups may all overwrite it, and the order in which
 * fprintf's arguments are evaluated is unspecified. The captured value is what
 * gets reported, and errno is restored to it afterwards.
 */
#define FAIL_MSG_PRINT(msg)                                                                            \
    do {                                                                                               \
        const int s2n_test_real_errno = errno;                                                         \
        s2n_print_stacktrace(stderr);                                                                  \
        const int s2n_test_stderr_is_tty = isatty(fileno(stderr));                                     \
        fprintf(stderr,                                                                                \
                "%sFAILED test %d%s\n%s (%s:%d)\nError Message: '%s'\n Debug String: '%s'\n System Error: %s (%d)\n", \
                s2n_test_stderr_is_tty ? "\033[31;1m" : "", test_count,                                \
                s2n_test_stderr_is_tty ? "\033[0m" : "",                                               \
                (msg), __FILE__, __LINE__,                                                             \
                s2n_strerror(s2n_errno, "EN"), s2n_debug_str,                                          \
                strerror(s2n_test_real_errno), s2n_test_real_errno);                                   \
        errno = s2n_test_real_errno;                                                                   \
    } while (0)
96 
/*
 * RESET_ERRNO(): clear s2n's error state (s2n_errno, s2n_debug_str) and the C
 * library errno so that the next expectation starts from a clean slate.
 *
 * The expansion intentionally has no trailing semicolon. Callers write
 * `RESET_ERRNO();`, and the macro must then act as exactly one statement,
 * including as the body of an unbraced if/else.
 */
#define RESET_ERRNO()         \
    do {                      \
        s2n_errno = 0;        \
        s2n_debug_str = NULL; \
        errno = 0;            \
    } while (0)
103 
/* EXPECT_TRUE(condition): count one check in test_count; if condition evaluates
 * to zero, fail via FAIL_MSG with the text "<condition> is not true ". */
#define EXPECT_TRUE(condition)                        \
    do {                                              \
        test_count++;                                 \
        if (!(condition)) {                           \
            FAIL_MSG(#condition " is not true ");     \
        }                                             \
    } while (0)
/* EXPECT_FALSE(condition): counting check that condition evaluates to zero.
 * It is built on EXPECT_TRUE, so the failure text is "!(<condition>) is not true ". */
#define EXPECT_FALSE( condition )   EXPECT_TRUE( !(condition) )

/* Equality checks with == after parenthesizing both operands. For pointers this
 * compares addresses, not contents; use EXPECT_BYTEARRAY_* or EXPECT_STRING_* for
 * content comparisons. The stringized "(p1) == (p2)" appears in failure messages,
 * so keep the spelling of these wrappers stable. */
#define EXPECT_EQUAL( p1, p2 )      EXPECT_TRUE( (p1) == (p2) )
#define EXPECT_NOT_EQUAL( p1, p2 )  EXPECT_FALSE( (p1) == (p2) )

/* Null-pointer checks, expressed as (in)equality against NULL. */
#define EXPECT_NULL( ptr )      EXPECT_EQUAL( ptr, NULL )
#define EXPECT_NOT_NULL( ptr )  EXPECT_NOT_EQUAL( ptr, NULL )
112 
/* EXPECT_FAILURE(function_call): expect an int-returning call to return -1 and to
 * have recorded an s2n error (s2n_errno nonzero, s2n_debug_str non-NULL), then
 * clear that error state with RESET_ERRNO(). Adds three to test_count. */
#define EXPECT_FAILURE( function_call ) \
    do { \
        EXPECT_EQUAL( (function_call) ,  -1 ); \
        EXPECT_NOT_EQUAL(s2n_errno, 0); \
        EXPECT_NOT_NULL(s2n_debug_str); \
        RESET_ERRNO(); \
    } while(0)
/* EXPECT_ERROR(function_call): the same checks for a call returning S2N_RESULT,
 * using s2n_result_is_error() instead of comparing against -1. */
#define EXPECT_ERROR( function_call ) \
    do { \
        EXPECT_TRUE( s2n_result_is_error(function_call) ); \
        EXPECT_NOT_EQUAL(s2n_errno, 0); \
        EXPECT_NOT_NULL(s2n_debug_str); \
        RESET_ERRNO(); \
    } while(0)
127 
/* EXPECT_FAILURE_WITH_ERRNO_NO_RESET(function_call, err): expect an int-returning
 * call to return -1 with s2n_errno equal to err and s2n_debug_str set. The error
 * state is left in place so the caller can inspect it further. */
#define EXPECT_FAILURE_WITH_ERRNO_NO_RESET( function_call, err ) \
    do { \
        EXPECT_EQUAL( (function_call), -1 ); \
        EXPECT_EQUAL(s2n_errno, err); \
        EXPECT_NOT_NULL(s2n_debug_str); \
    } while(0)

/* Same checks as EXPECT_FAILURE_WITH_ERRNO_NO_RESET, followed by RESET_ERRNO(). */
#define EXPECT_FAILURE_WITH_ERRNO( function_call, err ) \
    do { \
        EXPECT_FAILURE_WITH_ERRNO_NO_RESET( function_call, err ); \
        RESET_ERRNO(); \
    } while(0)
140 
/* for use with S2N_RESULT */
/* EXPECT_ERROR_WITH_ERRNO_NO_RESET(function_call, err): expect the call's result to
 * satisfy s2n_result_is_error(), with s2n_errno equal to err and s2n_debug_str set.
 * The error state is left in place for further inspection. */
#define EXPECT_ERROR_WITH_ERRNO_NO_RESET( function_call, err ) \
    do { \
        EXPECT_TRUE( s2n_result_is_error(function_call) ); \
        EXPECT_EQUAL(s2n_errno, err); \
        EXPECT_NOT_NULL(s2n_debug_str); \
    } while(0)

/* for use with S2N_RESULT */
/* Same checks as EXPECT_ERROR_WITH_ERRNO_NO_RESET, followed by RESET_ERRNO(). */
#define EXPECT_ERROR_WITH_ERRNO( function_call, err ) \
    do { \
        EXPECT_ERROR_WITH_ERRNO_NO_RESET( function_call, err ); \
        RESET_ERRNO(); \
    } while(0)
155 
/* EXPECT_NULL_WITH_ERRNO_NO_RESET(function_call, err): expect a pointer-returning
 * call to return NULL with s2n_errno equal to err and s2n_debug_str set. The error
 * state is left in place for further inspection. */
#define EXPECT_NULL_WITH_ERRNO_NO_RESET( function_call, err ) \
    do { \
        EXPECT_NULL( (function_call) ); \
        EXPECT_EQUAL(s2n_errno, err); \
        EXPECT_NOT_NULL(s2n_debug_str); \
    } while(0)

/* Same checks as EXPECT_NULL_WITH_ERRNO_NO_RESET, followed by RESET_ERRNO(). */
#define EXPECT_NULL_WITH_ERRNO( function_call, err ) \
    do { \
        EXPECT_NULL_WITH_ERRNO_NO_RESET( function_call, err ); \
        RESET_ERRNO(); \
    } while(0)
168 
/* EXPECT_SUCCESS(function_call): counting check that an int-returning call did not
 * return -1; any other return value is treated as success. */
#define EXPECT_SUCCESS( function_call )  EXPECT_NOT_EQUAL( (function_call) ,  -1 )
/* for use with S2N_RESULT */
#define EXPECT_OK( function_call )  EXPECT_TRUE( s2n_result_is_ok(function_call) )

/* Compare the first l bytes of two buffers with memcmp(); both must be readable for l bytes. */
#define EXPECT_BYTEARRAY_EQUAL( p1, p2, l ) EXPECT_EQUAL( memcmp( (p1), (p2), (l) ), 0 )
#define EXPECT_BYTEARRAY_NOT_EQUAL( p1, p2, l ) EXPECT_NOT_EQUAL( memcmp( (p1), (p2), (l) ), 0 )

/* Compare two NUL-terminated C strings with strcmp(); neither argument may be NULL. */
#define EXPECT_STRING_EQUAL( p1, p2 ) EXPECT_EQUAL( strcmp( (p1), (p2) ), 0 )
#define EXPECT_STRING_NOT_EQUAL( p1, p2 ) EXPECT_NOT_EQUAL( strcmp( (p1), (p2) ), 0 )
178 
/*
 * S2N_TEST_OPTIONALLY_ENABLE_FIPS_MODE():
 *  - When built with S2N_TEST_IN_FIPS_MODE, it switches the crypto library into
 *    FIPS mode. On failure it prints the library error and executes `return 1`
 *    in the enclosing function; on success it prints a confirmation line.
 *  - In every other build it expands to nothing.
 */
#if defined(S2N_TEST_IN_FIPS_MODE)
#include <openssl/err.h>

#define S2N_TEST_OPTIONALLY_ENABLE_FIPS_MODE()                                         \
    do {                                                                               \
        if (!FIPS_mode_set(1)) {                                                       \
            char ssl_error_buf[256];                                                   \
            const unsigned long fips_rc = ERR_get_error();                             \
            const char *fips_reason = ERR_error_string(fips_rc, ssl_error_buf);        \
            fprintf(stderr, "s2nd failed to enter FIPS mode with RC: %lu; String: %s\n", \
                    fips_rc, fips_reason);                                             \
            return 1;                                                                  \
        }                                                                              \
        printf("s2nd entered FIPS mode\n");                                            \
    } while (0)

#else /* !S2N_TEST_IN_FIPS_MODE */
#define S2N_TEST_OPTIONALLY_ENABLE_FIPS_MODE()
#endif /* S2N_TEST_IN_FIPS_MODE */
196 
/*
 * S2N_FUZZ_ENSURE_MIN_LEN(len, min): guard for use inside a fuzz entry point.
 * When len is below min, it returns S2N_SUCCESS from the enclosing function, so
 * inputs too short for the test are skipped instead of being reported as failures.
 */
#define S2N_FUZZ_ENSURE_MIN_LEN(len, min) \
    do {                                  \
        if ((len) < (min)) {              \
            return S2N_SUCCESS;           \
        }                                 \
    } while (0)
199 
/*
 * EXPECT_MEMCPY_SUCCESS(d, s, n): copy n bytes from s to d inside a test.
 *
 * n is evaluated exactly once.
 *  - If n is zero, neither d nor s is evaluated and nothing is copied. Callers
 *    may therefore pass NULL buffers for empty copies; memcpy itself requires
 *    valid pointers even when the size is zero.
 *  - Otherwise the destination is checked BEFORE the copy. A NULL d fails the
 *    test with "<d> is NULL, memcpy() failed" instead of invoking undefined
 *    behavior inside memcpy. memcpy always returns its destination, so testing
 *    its return value could never detect this.
 * Does not increment test_count.
 */
#define EXPECT_MEMCPY_SUCCESS(d, s, n)                                         \
    do {                                                                       \
        __typeof(n) s2n_test_memcpy_len = (n);                                 \
        if (s2n_test_memcpy_len) {                                             \
            void *s2n_test_memcpy_dest = (d);                                  \
            if (s2n_test_memcpy_dest == NULL) {                                \
                FAIL_MSG(#d " is NULL, memcpy() failed");                      \
            } else {                                                           \
                memcpy(s2n_test_memcpy_dest, (s), s2n_test_memcpy_len);        \
            }                                                                  \
        }                                                                      \
    } while (0)
209 
/*
 * TEST_DEBUG_PRINT(...): printf-style debug output to stderr, compiled in only
 * when S2N_TEST_DEBUG is defined. In all other builds the macro expands to
 * nothing, so its arguments are not evaluated at all.
 */
#ifndef S2N_TEST_DEBUG
#define TEST_DEBUG_PRINT(...)
#else /* S2N_TEST_DEBUG */
#define TEST_DEBUG_PRINT(...)                    \
    do {                                         \
        (void) fprintf(stderr, __VA_ARGS__);     \
    } while (0)
#endif /* S2N_TEST_DEBUG */
218 
/*
 * Creates a libFuzzer fuzz target.
 *
 * Expands, at file scope, to three function definitions:
 *  - s2n_test__fuzz_cleanup(): atexit handler. It calls fuzz_cleanup (when
 *    non-NULL) and then s2n_cleanup().
 *  - LLVMFuzzerInitialize(): optionally enters FIPS mode, initializes s2n and
 *    registers the cleanup handler with atexit(). It then runs fuzz_init (when
 *    non-NULL) and returns its result. If that result is not S2N_SUCCESS, a
 *    failure report is printed (without exiting).
 *  - LLVMFuzzerTestOneInput(): passes each input to fuzz_entry(buf, len) and
 *    returns its int result. If that result is not S2N_SUCCESS, a failure report
 *    is printed (without exiting).
 *
 * fuzz_init and fuzz_cleanup may be NULL. When non-NULL they are called through
 * casts to `int (*)(int *, char ***)` and `void (*)()` respectively, so their
 * real signatures must be compatible with those.
 * NOTE(review): libFuzzer documents 0 (and -1 to reject an input from the corpus)
 * as the only supported return values of LLVMFuzzerTestOneInput. Any other value
 * returned by fuzz_entry is passed through unchanged — confirm this is intended.
 */
#define S2N_FUZZ_TARGET(fuzz_init, fuzz_entry, fuzz_cleanup) \
void s2n_test__fuzz_cleanup() \
{ \
    if (fuzz_cleanup) { \
        ((void (*)()) fuzz_cleanup)(); \
    } \
    s2n_cleanup(); \
} \
int LLVMFuzzerInitialize(int *argc, char **argv[]) \
{ \
    S2N_TEST_OPTIONALLY_ENABLE_FIPS_MODE(); \
    EXPECT_SUCCESS_WITHOUT_COUNT(s2n_init()); \
    EXPECT_SUCCESS_WITHOUT_COUNT(atexit(s2n_test__fuzz_cleanup)); \
    if (!fuzz_init) { \
        return S2N_SUCCESS; \
    } \
    int result = ((int (*)(int *argc, char **argv[])) fuzz_init)(argc, argv); \
    if (result != S2N_SUCCESS) { \
        FAIL_MSG_PRINT(#fuzz_init " did not return S2N_SUCCESS"); \
    } \
    return result; \
} \
int LLVMFuzzerTestOneInput(const uint8_t *buf, size_t len) \
{ \
    int result = fuzz_entry(buf, len); \
    if (result != S2N_SUCCESS) { \
        FAIL_MSG_PRINT(#fuzz_entry " did not return S2N_SUCCESS"); \
    } \
    return result; \
}
250