xref: /openbsd/regress/lib/libssl/api/apitest.c (revision 7c0ec4b8)
1 /* $OpenBSD: apitest.c,v 1.3 2024/09/07 16:39:29 tb Exp $ */
2 /*
3  * Copyright (c) 2020, 2021 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 
20 #include <openssl/bio.h>
21 #include <openssl/err.h>
22 #include <openssl/ssl.h>
23 
/* Default certificate directory; may be overridden at build time. */
#if !defined(CERTSDIR)
#define CERTSDIR "."
#endif

/*
 * Directory prefixed to every certificate/key file name that is loaded.
 * Replaced by the optional command line argument in main().
 */
const char *certs_path = CERTSDIR;

/* When non-zero, do_client_server_loop() prints per-pass DEBUG lines. */
int debug = 0;
31 
32 static int
ssl_ctx_use_ca_file(SSL_CTX * ssl_ctx,const char * ca_file)33 ssl_ctx_use_ca_file(SSL_CTX *ssl_ctx, const char *ca_file)
34 {
35 	char *ca_path = NULL;
36 	int ret = 0;
37 
38 	if (asprintf(&ca_path, "%s/%s", certs_path, ca_file) == -1)
39 		goto err;
40 	if (!SSL_CTX_load_verify_locations(ssl_ctx, ca_path, NULL)) {
41 		fprintf(stderr, "load_verify_locations(%s) failed\n", ca_path);
42 		goto err;
43 	}
44 
45 	ret = 1;
46 
47  err:
48 	free(ca_path);
49 
50 	return ret;
51 }
52 
53 static int
ssl_ctx_use_keypair(SSL_CTX * ssl_ctx,const char * chain_file,const char * key_file)54 ssl_ctx_use_keypair(SSL_CTX *ssl_ctx, const char *chain_file,
55     const char *key_file)
56 {
57 	char *chain_path = NULL, *key_path = NULL;
58 	int ret = 0;
59 
60 	if (asprintf(&chain_path, "%s/%s", certs_path, chain_file) == -1)
61 		goto err;
62 	if (SSL_CTX_use_certificate_chain_file(ssl_ctx, chain_path) != 1) {
63 		fprintf(stderr, "FAIL: Failed to load certificates\n");
64 		goto err;
65 	}
66 	if (asprintf(&key_path, "%s/%s", certs_path, key_file) == -1)
67 		goto err;
68 	if (SSL_CTX_use_PrivateKey_file(ssl_ctx, key_path,
69 	    SSL_FILETYPE_PEM) != 1) {
70 		fprintf(stderr, "FAIL: Failed to load private key\n");
71 		goto err;
72 	}
73 
74 	ret = 1;
75 
76  err:
77 	free(chain_path);
78 	free(key_path);
79 
80 	return ret;
81 }
82 
83 static SSL *
tls_client(BIO * rbio,BIO * wbio)84 tls_client(BIO *rbio, BIO *wbio)
85 {
86 	SSL_CTX *ssl_ctx = NULL;
87 	SSL *ssl = NULL;
88 
89 	if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL)
90 		errx(1, "client context");
91 
92 	SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, NULL);
93 
94 	if (!ssl_ctx_use_ca_file(ssl_ctx, "ca-root-rsa.pem"))
95 		goto failure;
96 	if (!ssl_ctx_use_keypair(ssl_ctx, "client1-rsa-chain.pem",
97 	    "client1-rsa.pem"))
98 		goto failure;
99 
100 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
101 		errx(1, "client ssl");
102 
103 	BIO_up_ref(rbio);
104 	BIO_up_ref(wbio);
105 
106 	SSL_set_bio(ssl, rbio, wbio);
107 
108  failure:
109 	SSL_CTX_free(ssl_ctx);
110 
111 	return ssl;
112 }
113 
114 static SSL *
tls_server(BIO * rbio,BIO * wbio)115 tls_server(BIO *rbio, BIO *wbio)
116 {
117 	SSL_CTX *ssl_ctx = NULL;
118 	SSL *ssl = NULL;
119 
120 	if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL)
121 		errx(1, "server context");
122 
123 	SSL_CTX_set_dh_auto(ssl_ctx, 2);
124 
125 	SSL_CTX_set_verify(ssl_ctx,
126 	    SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
127 
128 	if (!ssl_ctx_use_ca_file(ssl_ctx, "ca-root-rsa.pem"))
129 		goto failure;
130 	if (!ssl_ctx_use_keypair(ssl_ctx, "server1-rsa-chain.pem",
131 	    "server1-rsa.pem"))
132 		goto failure;
133 
134 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
135 		errx(1, "server ssl");
136 
137 	BIO_up_ref(rbio);
138 	BIO_up_ref(wbio);
139 
140 	SSL_set_bio(ssl, rbio, wbio);
141 
142  failure:
143 	SSL_CTX_free(ssl_ctx);
144 
145 	return ssl;
146 }
147 
148 static int
ssl_error(SSL * ssl,const char * name,const char * desc,int ssl_ret)149 ssl_error(SSL *ssl, const char *name, const char *desc, int ssl_ret)
150 {
151 	int ssl_err;
152 
153 	ssl_err = SSL_get_error(ssl, ssl_ret);
154 
155 	if (ssl_err == SSL_ERROR_WANT_READ) {
156 		return 1;
157 	} else if (ssl_err == SSL_ERROR_WANT_WRITE) {
158 		return 1;
159 	} else if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) {
160 		/* Yup, this is apparently a thing... */
161 	} else {
162 		fprintf(stderr, "FAIL: %s %s failed - ssl err = %d, errno = %d\n",
163 		    name, desc, ssl_err, errno);
164 		ERR_print_errors_fp(stderr);
165 		return 0;
166 	}
167 
168 	return 1;
169 }
170 
171 static int
do_connect(SSL * ssl,const char * name,int * done)172 do_connect(SSL *ssl, const char *name, int *done)
173 {
174 	int ssl_ret;
175 
176 	if ((ssl_ret = SSL_connect(ssl)) == 1) {
177 		fprintf(stderr, "INFO: %s connect done\n", name);
178 		*done = 1;
179 		return 1;
180 	}
181 
182 	return ssl_error(ssl, name, "connect", ssl_ret);
183 }
184 
185 static int
do_accept(SSL * ssl,const char * name,int * done)186 do_accept(SSL *ssl, const char *name, int *done)
187 {
188 	int ssl_ret;
189 
190 	if ((ssl_ret = SSL_accept(ssl)) == 1) {
191 		fprintf(stderr, "INFO: %s accept done\n", name);
192 		*done = 1;
193 		return 1;
194 	}
195 
196 	return ssl_error(ssl, name, "accept", ssl_ret);
197 }
198 
199 typedef int (*ssl_func)(SSL *ssl, const char *name, int *done);
200 
201 static int
do_client_server_loop(SSL * client,ssl_func client_func,SSL * server,ssl_func server_func)202 do_client_server_loop(SSL *client, ssl_func client_func, SSL *server,
203     ssl_func server_func)
204 {
205 	int client_done = 0, server_done = 0;
206 	int i = 0;
207 
208 	do {
209 		if (!client_done) {
210 			if (debug)
211 				fprintf(stderr, "DEBUG: client loop\n");
212 			if (!client_func(client, "client", &client_done))
213 				return 0;
214 		}
215 		if (!server_done) {
216 			if (debug)
217 				fprintf(stderr, "DEBUG: server loop\n");
218 			if (!server_func(server, "server", &server_done))
219 				return 0;
220 		}
221 	} while (i++ < 100 && (!client_done || !server_done));
222 
223 	if (!client_done || !server_done)
224 		fprintf(stderr, "FAIL: gave up\n");
225 
226 	return client_done && server_done;
227 }
228 
229 static int
ssl_get_peer_cert_chain_test(uint16_t tls_version)230 ssl_get_peer_cert_chain_test(uint16_t tls_version)
231 {
232 	STACK_OF(X509) *peer_chain;
233 	X509 *peer_cert;
234 	BIO *client_wbio = NULL, *server_wbio = NULL;
235 	SSL *client = NULL, *server = NULL;
236 	int failed = 1;
237 
238 	if ((client_wbio = BIO_new(BIO_s_mem())) == NULL)
239 		goto failure;
240 	if (BIO_set_mem_eof_return(client_wbio, -1) <= 0)
241 		goto failure;
242 
243 	if ((server_wbio = BIO_new(BIO_s_mem())) == NULL)
244 		goto failure;
245 	if (BIO_set_mem_eof_return(server_wbio, -1) <= 0)
246 		goto failure;
247 
248 	if ((client = tls_client(server_wbio, client_wbio)) == NULL)
249 		goto failure;
250 	if (tls_version != 0) {
251 		if (!SSL_set_min_proto_version(client, tls_version))
252 			goto failure;
253 		if (!SSL_set_max_proto_version(client, tls_version))
254 			goto failure;
255 	}
256 
257 	if ((server = tls_server(client_wbio, server_wbio)) == NULL)
258 		goto failure;
259 	if (tls_version != 0) {
260 		if (!SSL_set_min_proto_version(server, tls_version))
261 			goto failure;
262 		if (!SSL_set_max_proto_version(server, tls_version))
263 			goto failure;
264 	}
265 
266 	if (!do_client_server_loop(client, do_connect, server, do_accept)) {
267 		fprintf(stderr, "FAIL: client and server handshake failed\n");
268 		goto failure;
269 	}
270 
271 	if (tls_version != 0) {
272 		if (SSL_version(client) != tls_version) {
273 			fprintf(stderr, "FAIL: client got TLS version %x, "
274 			    "want %x\n", SSL_version(client), tls_version);
275 			goto failure;
276 		}
277 		if (SSL_version(server) != tls_version) {
278 			fprintf(stderr, "FAIL: server got TLS version %x, "
279 			    "want %x\n", SSL_version(server), tls_version);
280 			goto failure;
281 		}
282 	}
283 
284 	/*
285 	 * Due to the wonders of API inconsistency, SSL_get_peer_cert_chain()
286 	 * includes the peer's leaf certificate when called by the client,
287 	 * however it does not when called by the server. Furthermore, the
288 	 * certificate returned by SSL_get_peer_certificate() has already
289 	 * had its reference count incremented and must be freed, where as
290 	 * the certificates returned from SSL_get_peer_cert_chain() must
291 	 * not be freed... *sigh*
292 	 */
293 	peer_cert = SSL_get_peer_certificate(client);
294 	peer_chain = SSL_get_peer_cert_chain(client);
295 	X509_free(peer_cert);
296 
297 	if (peer_cert == NULL) {
298 		fprintf(stderr, "FAIL: client got no peer cert\n");
299 		goto failure;
300 	}
301 	if (sk_X509_num(peer_chain) != 2) {
302 		fprintf(stderr, "FAIL: client got peer cert chain with %d "
303 		    "certificates, want 2\n", sk_X509_num(peer_chain));
304 		goto failure;
305 	}
306 	if (X509_cmp(peer_cert, sk_X509_value(peer_chain, 0)) != 0) {
307 		fprintf(stderr, "FAIL: client got peer cert chain without peer "
308 		    "certificate\n");
309 		goto failure;
310 	}
311 
312 	peer_cert = SSL_get_peer_certificate(server);
313 	peer_chain = SSL_get_peer_cert_chain(server);
314 	X509_free(peer_cert);
315 
316 	if (peer_cert == NULL) {
317 		fprintf(stderr, "FAIL: server got no peer cert\n");
318 		goto failure;
319 	}
320 	if (sk_X509_num(peer_chain) != 1) {
321 		fprintf(stderr, "FAIL: server got peer cert chain with %d "
322 		    "certificates, want 1\n", sk_X509_num(peer_chain));
323 		goto failure;
324 	}
325 	if (X509_cmp(peer_cert, sk_X509_value(peer_chain, 0)) == 0) {
326 		fprintf(stderr, "FAIL: server got peer cert chain with peer "
327 		    "certificate\n");
328 		goto failure;
329 	}
330 
331 	fprintf(stderr, "INFO: Done!\n");
332 
333 	failed = 0;
334 
335  failure:
336 	BIO_free(client_wbio);
337 	BIO_free(server_wbio);
338 
339 	SSL_free(client);
340 	SSL_free(server);
341 
342 	return failed;
343 }
344 
345 static int
ssl_get_peer_cert_chain_tests(void)346 ssl_get_peer_cert_chain_tests(void)
347 {
348 	int failed = 0;
349 
350 	fprintf(stderr, "\n== Testing SSL_get_peer_cert_chain()... ==\n");
351 
352 	failed |= ssl_get_peer_cert_chain_test(0);
353 	failed |= ssl_get_peer_cert_chain_test(TLS1_3_VERSION);
354 	failed |= ssl_get_peer_cert_chain_test(TLS1_2_VERSION);
355 
356 	return failed;
357 }
358 
359 int
main(int argc,char ** argv)360 main(int argc, char **argv)
361 {
362 	int failed = 0;
363 
364 	if (argc > 2) {
365 		fprintf(stderr, "usage: %s [certspath]\n", argv[0]);
366 		exit(1);
367 	}
368 	if (argc == 2)
369 		certs_path = argv[1];
370 
371 	failed |= ssl_get_peer_cert_chain_tests();
372 
373 	return failed;
374 }
375