1 /*	$OpenBSD: ssl_set_alpn_protos.c,v 1.2 2022/07/21 03:59:04 tb Exp $ */
2 /*
3  * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdio.h>
20 
21 #include <openssl/ssl.h>
22 
/* Capacity of the fixed-size protocol list buffer in each test case. */
enum {
	ALPN_TEST_PROTOCOLS_MAX = 24,
};

/*
 * One table-driven ALPN test case.
 */
struct alpn_test {
	/* Human-readable name, printed in failure messages. */
	const char *description;
	/*
	 * Protocol list in length-prefixed form: each protocol name is
	 * preceded by a byte holding its length.  Unused trailing bytes are
	 * zero and are not passed to the library.
	 */
	const uint8_t protocols[ALPN_TEST_PROTOCOLS_MAX];
	/*
	 * Number of leading bytes of protocols handed to the setter.
	 * Maintained by hand; must not exceed sizeof(protocols).
	 */
	size_t protocols_len;
	/*
	 * Expected return value of SSL_CTX_set_alpn_protos() and
	 * SSL_set_alpn_protos().  Note these APIs return 0 on success.
	 */
	int ret;
};
29 
30 static const struct alpn_test alpn_tests[] = {
31 	{
32 		.description = "valid protocol list",
33 		.protocols = {
34 			6, 's', 'p', 'd', 'y', '/', '1',
35 			8, 'h', 't', 't', 'p', '/', '1', '.', '1',
36 		},
37 		.protocols_len = 16,
38 		.ret = 0,
39 	},
40 	{
41 		.description = "zero length protocol",
42 		.protocols = {
43 			0,
44 		},
45 		.protocols_len = 1,
46 		.ret = 1,
47 	},
48 	{
49 		.description = "zero length protocol at start",
50 		.protocols = {
51 			0,
52 			8, 'h', 't', 't', 'p', '/', '1', '.', '1',
53 			6, 's', 'p', 'd', 'y', '/', '1',
54 		},
55 		.protocols_len = 17,
56 		.ret = 1,
57 	},
58 	{
59 		.description = "zero length protocol embedded",
60 		.protocols = {
61 			8, 'h', 't', 't', 'p', '/', '1', '.', '1',
62 			0,
63 			6, 's', 'p', 'd', 'y', '/', '1',
64 		},
65 		.protocols_len = 17,
66 		.ret = 1,
67 	},
68 	{
69 		.description = "zero length protocol at end",
70 		.protocols = {
71 			8, 'h', 't', 't', 'p', '/', '1', '.', '1',
72 			6, 's', 'p', 'd', 'y', '/', '1',
73 			0,
74 		},
75 		.protocols_len = 17,
76 		.ret = 1,
77 	},
78 	{
79 		.description = "protocol length too short",
80 		.protocols = {
81 			6, 'h', 't', 't', 'p', '/', '1', '.', '1',
82 		},
83 		.protocols_len = 9,
84 		.ret = 1,
85 	},
86 	{
87 		.description = "protocol length too long",
88 		.protocols = {
89 			8, 's', 'p', 'd', 'y', '/', '1',
90 		},
91 		.protocols_len = 7,
92 		.ret = 1,
93 	},
94 };
95 
96 static const size_t N_ALPN_TESTS = sizeof(alpn_tests) / sizeof(alpn_tests[0]);
97 
98 static int
99 test_ssl_set_alpn_protos(const struct alpn_test *tc)
100 {
101 	SSL_CTX *ctx;
102 	SSL *ssl;
103 	int ret;
104 	int failed = 0;
105 
106 	if ((ctx = SSL_CTX_new(TLS_client_method())) == NULL)
107 		errx(1, "SSL_CTX_new");
108 
109 	ret = SSL_CTX_set_alpn_protos(ctx, tc->protocols, tc->protocols_len);
110 	if (ret != tc->ret) {
111 		warnx("%s: setting on SSL_CTX: want %d, got %d",
112 		    tc->description, tc->ret, ret);
113 		failed = 1;
114 	}
115 
116 	if ((ssl = SSL_new(ctx)) == NULL)
117 		errx(1, "SSL_new");
118 
119 	ret = SSL_set_alpn_protos(ssl, tc->protocols, tc->protocols_len);
120 	if (ret != tc->ret) {
121 		warnx("%s: setting on SSL: want %d, got %d",
122 		    tc->description, tc->ret, ret);
123 		failed = 1;
124 	}
125 
126 	SSL_CTX_free(ctx);
127 	SSL_free(ssl);
128 
129 	return failed;
130 }
131 
132 static int
133 test_ssl_set_alpn_protos_edge_cases(void)
134 {
135 	SSL_CTX *ctx;
136 	SSL *ssl;
137 	const uint8_t valid[] = {
138 		6, 's', 'p', 'd', 'y', '/', '3',
139 		8, 'h', 't', 't', 'p', '/', '1', '.', '1',
140 	};
141 	int failed = 0;
142 
143 	if ((ctx = SSL_CTX_new(TLS_client_method())) == NULL)
144 		errx(1, "SSL_CTX_new");
145 
146 	if (SSL_CTX_set_alpn_protos(ctx, valid, sizeof(valid)) != 0) {
147 		warnx("setting valid protocols on SSL_CTX failed");
148 		failed = 1;
149 	}
150 	if (SSL_CTX_set_alpn_protos(ctx, NULL, 0) != 0) {
151 		warnx("setting 'NULL, 0' on SSL_CTX failed");
152 		failed = 1;
153 	}
154 	if (SSL_CTX_set_alpn_protos(ctx, valid, 0) != 0) {
155 		warnx("setting 'valid, 0' on SSL_CTX failed");
156 		failed = 1;
157 	}
158 	if (SSL_CTX_set_alpn_protos(ctx, NULL, 43) != 0) {
159 		warnx("setting 'NULL, 43' on SSL_CTX failed");
160 		failed = 1;
161 	}
162 
163 	if ((ssl = SSL_new(ctx)) == NULL)
164 		errx(1, "SSL_new");
165 
166 	if (SSL_set_alpn_protos(ssl, valid, sizeof(valid)) != 0) {
167 		warnx("setting valid protocols on SSL failed");
168 		failed = 1;
169 	}
170 	if (SSL_set_alpn_protos(ssl, NULL, 0) != 0) {
171 		warnx("setting 'NULL, 0' on SSL failed");
172 		failed = 1;
173 	}
174 	if (SSL_set_alpn_protos(ssl, valid, 0) != 0) {
175 		warnx("setting 'valid, 0' on SSL failed");
176 		failed = 1;
177 	}
178 	if (SSL_set_alpn_protos(ssl, NULL, 43) != 0) {
179 		warnx("setting 'NULL, 43' on SSL failed");
180 		failed = 1;
181 	}
182 
183 	SSL_CTX_free(ctx);
184 	SSL_free(ssl);
185 
186 	return failed;
187 }
188 
189 int
190 main(void)
191 {
192 	size_t i;
193 	int failed = 0;
194 
195 	for (i = 0; i < N_ALPN_TESTS; i++)
196 		failed |= test_ssl_set_alpn_protos(&alpn_tests[i]);
197 
198 	failed |= test_ssl_set_alpn_protos_edge_cases();
199 
200 	if (!failed)
201 		printf("PASS %s\n", __FILE__);
202 
203 	return failed;
204 }
205