1 /*
2 * Copyright (C) 2016 Red Hat, Inc.
3 * Copyright (C) 2013-2016 Nikos Mavrogiannopoulos
4 *
5 * This file is part of GnuTLS.
6 *
7 * GnuTLS is free software; you can redistribute it and/or modify it
8 * under the terms of the GNU General Public License as published by
9 * the Free Software Foundation; either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * GnuTLS is distributed in the hope that it will be useful, but
13 * WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with GnuTLS; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
20 */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <stdio.h>
27 #include <stdlib.h>
28
29 #if defined(_WIN32) || !defined(ENABLE_ALPN)
30
/* Stub entry point built when the real test is compiled out (see the
 * surrounding #if). It does nothing and reports exit status 77.
 * NOTE(review): 77 is presumably the "test skipped" status expected by
 * the test harness -- confirm against the build system.
 */
int main(int argc, char **argv)
{
	(void)argc;
	(void)argv;

	return 77;
}
35
36 #else
37
38 #include <string.h>
39 #include <sys/types.h>
40 #include <netinet/in.h>
41 #include <sys/socket.h>
42 #include <sys/wait.h>
43 #include <arpa/inet.h>
44 #include <unistd.h>
45 #include <gnutls/gnutls.h>
46 #include <gnutls/dtls.h>
47
48 #include "utils.h"
49
50 static void terminate(void);
51
52 /* This program tests whether the gnutls_record_get_state() works as
53 * expected.
54 */
55
/* Log callback installed by server(): writes the message to stderr,
 * tagged with "server" and the log level. No newline is appended.
 */
static void server_log_func(int level, const char *str)
{
	FILE *out = stderr;

	fprintf(out, "server|<%d>| %s", level, str);
}
60
/* Log callback installed by client(): writes the message to stderr,
 * tagged with "client" and the log level. No newline is appended.
 */
static void client_log_func(int level, const char *str)
{
	FILE *out = stderr;

	fprintf(out, "client|<%d>| %s", level, str);
}
65
66 /* These are global */
67 static pid_t child;
68
/* Helpers and a very basic DTLS client, with anonymous authentication,
 * that checks the record state reported by gnutls_record_get_state().
 */
71
/* Print a labelled hex dump to stderr.
 *
 * name:      label printed verbatim before the bytes
 * data:      bytes to print
 * data_size: number of bytes in data
 *
 * Each byte is written as two lowercase hex digits with no separator,
 * followed by a single newline.
 */
static void dump(const char *name, uint8_t *data, unsigned data_size)
{
	const uint8_t *cur = data;
	unsigned remaining = data_size;

	fprintf(stderr, "%s", name);
	while (remaining-- > 0)
		fprintf(stderr, "%.2x", (unsigned)*cur++);
	fputc('\n', stderr);
}
81
terminate(void)82 static void terminate(void)
83 {
84 int status = 0;
85
86 kill(child, SIGTERM);
87 wait(&status);
88 exit(1);
89 }
90
client(int fd)91 static void client(int fd)
92 {
93 gnutls_session_t session;
94 int ret;
95 gnutls_anon_client_credentials_t anoncred;
96 gnutls_datum_t mac_key, iv, cipher_key;
97 gnutls_datum_t read_mac_key, read_iv, read_cipher_key;
98 unsigned char rseq_number[8];
99 unsigned char wseq_number[8];
100 unsigned char key_material[512], *p;
101 unsigned i;
102 unsigned block_size, hash_size, key_size, iv_size;
103 const char *err;
104 /* Need to enable anonymous KX specifically. */
105
106 global_init();
107
108 if (debug) {
109 gnutls_global_set_log_function(client_log_func);
110 gnutls_global_set_log_level(4711);
111 }
112
113 gnutls_anon_allocate_client_credentials(&anoncred);
114
115 /* Initialize TLS session
116 */
117 gnutls_init(&session, GNUTLS_CLIENT|GNUTLS_DATAGRAM);
118
119 /* Use default priorities */
120 ret = gnutls_priority_set_direct(session,
121 "NONE:+VERS-DTLS1.0:+AES-128-CBC:+SHA1:+SIGN-ALL:+COMP-NULL:+ANON-DH:+ANON-ECDH:+CURVE-ALL",
122 &err);
123 if (ret < 0) {
124 fail("client: priority set failed (%s): %s\n",
125 gnutls_strerror(ret), err);
126 exit(1);
127 }
128
129 /* put the anonymous credentials to the current session
130 */
131 gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
132
133 gnutls_transport_set_int(session, fd);
134
135 /* Perform the TLS handshake
136 */
137 do {
138 ret = gnutls_handshake(session);
139 }
140 while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
141
142 if (ret < 0) {
143 fail("client: Handshake failed: %s\n", strerror(ret));
144 terminate();
145 } else {
146 if (debug)
147 success("client: Handshake was completed\n");
148 }
149
150 if (debug)
151 success("client: TLS version is: %s\n",
152 gnutls_protocol_get_name
153 (gnutls_protocol_get_version(session)));
154
155 ret = gnutls_cipher_get(session);
156 if (ret != GNUTLS_CIPHER_AES_128_CBC) {
157 fprintf(stderr, "negotiated unexpected cipher: %s\n", gnutls_cipher_get_name(ret));
158 terminate();
159 }
160
161 ret = gnutls_mac_get(session);
162 if (ret != GNUTLS_MAC_SHA1) {
163 fprintf(stderr, "negotiated unexpected mac: %s\n", gnutls_mac_get_name(ret));
164 terminate();
165 }
166
167 iv_size = 16;
168 hash_size = 20;
169 key_size = 16;
170 block_size = 2*hash_size + 2*key_size + 2 *iv_size;
171
172 ret = gnutls_prf(session, 13, "key expansion", 1, 0, NULL, block_size,
173 (void*)key_material);
174 if (ret < 0) {
175 fprintf(stderr, "error in %d\n", __LINE__);
176 gnutls_perror(ret);
177 terminate();
178 }
179 p = key_material;
180
181 /* check whether the key material matches our calculations */
182 ret = gnutls_record_get_state(session, 0, &mac_key, &iv, &cipher_key, wseq_number);
183 if (ret < 0) {
184 fprintf(stderr, "error in %d\n", __LINE__);
185 gnutls_perror(ret);
186 terminate();
187 }
188
189 if (memcmp(wseq_number, "\x00\x01\x00\x00\x00\x00\x00\x01", 8) != 0) {
190 dump("wseq:", wseq_number, 8);
191 fprintf(stderr, "error in %d\n", __LINE__);
192 terminate();
193 }
194
195 ret = gnutls_record_get_state(session, 1, &read_mac_key, &read_iv, &read_cipher_key, rseq_number);
196 if (ret < 0) {
197 fprintf(stderr, "error in %d\n", __LINE__);
198 gnutls_perror(ret);
199 terminate();
200 }
201
202 if (memcmp(rseq_number, "\x00\x01\x00\x00\x00\x00\x00\x01", 8) != 0) {
203 dump("rseq:", rseq_number, 8);
204 fprintf(stderr, "error in %d\n", __LINE__);
205 terminate();
206 }
207
208 if (hash_size != mac_key.size || memcmp(p, mac_key.data, hash_size) != 0) {
209 dump("MAC:", mac_key.data, mac_key.size);
210 dump("Block:", key_material, block_size);
211 fprintf(stderr, "error in %d\n", __LINE__);
212 terminate();
213 }
214 p+= hash_size;
215
216 if (hash_size != read_mac_key.size || memcmp(p, read_mac_key.data, hash_size) != 0) {
217 dump("MAC:", read_mac_key.data, read_mac_key.size);
218 dump("Block:", key_material, block_size);
219 fprintf(stderr, "error in %d\n", __LINE__);
220 terminate();
221 }
222 p+= hash_size;
223
224 if (key_size != cipher_key.size || memcmp(p, cipher_key.data, key_size) != 0) {
225 fprintf(stderr, "error in %d\n", __LINE__);
226 terminate();
227 }
228 p+= key_size;
229
230 if (key_size != read_cipher_key.size || memcmp(p, read_cipher_key.data, key_size) != 0) {
231 fprintf(stderr, "error in %d\n", __LINE__);
232 terminate();
233 }
234 p+= key_size;
235
236 /* check sequence numbers */
237 for (i=0;i<5;i++) {
238 ret = gnutls_record_send(session, "hello", 5);
239 if (ret < 0) {
240 fail("gnutls_record_send: %s\n", gnutls_strerror(ret));
241 }
242 }
243
244 memset(wseq_number, 0xAA, sizeof(wseq_number));
245 ret = gnutls_record_get_state(session, 0, NULL, NULL, NULL, wseq_number);
246 if (ret < 0) {
247 fprintf(stderr, "error in %d\n", __LINE__);
248 gnutls_perror(ret);
249 terminate();
250 }
251
252 if (memcmp(wseq_number, "\x00\x01\x00\x00\x00\x00\x00\x06", 8) != 0) {
253 dump("wseq:", wseq_number, 8);
254 fprintf(stderr, "error in %d\n", __LINE__);
255 terminate();
256 }
257
258 memset(rseq_number, 0xAA, sizeof(rseq_number));
259 ret = gnutls_record_get_state(session, 1, NULL, NULL, NULL, rseq_number);
260 if (ret < 0) {
261 fprintf(stderr, "error in %d\n", __LINE__);
262 gnutls_perror(ret);
263 terminate();
264 }
265
266 if (memcmp(rseq_number, "\x00\x01\x00\x00\x00\x00\x00\x01", 8) != 0) {
267 dump("rseq:", rseq_number, 8);
268 fprintf(stderr, "error in %d\n", __LINE__);
269 terminate();
270 }
271 gnutls_bye(session, GNUTLS_SHUT_WR);
272
273 close(fd);
274
275 gnutls_deinit(session);
276
277 gnutls_anon_free_client_credentials(anoncred);
278
279 gnutls_global_deinit();
280 }
281
server(int fd)282 static void server(int fd)
283 {
284 int ret;
285 gnutls_session_t session;
286 gnutls_anon_server_credentials_t anoncred;
287 gnutls_dh_params_t dh_params;
288 char buf[128];
289 const gnutls_datum_t p3 =
290 { (unsigned char *) pkcs3, strlen(pkcs3) };
291
292 /* this must be called once in the program
293 */
294 global_init();
295
296 if (debug) {
297 gnutls_global_set_log_function(server_log_func);
298 gnutls_global_set_log_level(4711);
299 }
300
301 gnutls_anon_allocate_server_credentials(&anoncred);
302 gnutls_dh_params_init(&dh_params);
303 gnutls_dh_params_import_pkcs3(dh_params, &p3, GNUTLS_X509_FMT_PEM);
304 gnutls_anon_set_server_dh_params(anoncred, dh_params);
305
306 gnutls_init(&session, GNUTLS_SERVER|GNUTLS_DATAGRAM);
307
308 /* avoid calling all the priority functions, since the defaults
309 * are adequate.
310 */
311 ret = gnutls_priority_set_direct(session,
312 "NORMAL:+VERS-DTLS1.0:+ANON-DH:+ANON-ECDH", NULL);
313 if (ret < 0) {
314 fail("server: priority set failed (%s)\n\n",
315 gnutls_strerror(ret));
316 terminate();
317 }
318
319 gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
320
321 gnutls_transport_set_int(session, fd);
322
323 do {
324 ret = gnutls_handshake(session);
325 }
326 while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
327 if (ret < 0) {
328 close(fd);
329 gnutls_deinit(session);
330 fail("server: Handshake has failed (%s)\n\n",
331 gnutls_strerror(ret));
332 terminate();
333 }
334 if (debug)
335 success("server: Handshake was completed\n");
336
337 if (debug)
338 success("server: TLS version is: %s\n",
339 gnutls_protocol_get_name
340 (gnutls_protocol_get_version(session)));
341
342 do {
343 ret = gnutls_record_recv(session, buf, sizeof(buf));
344 } while(ret > 0 || ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
345
346 if (ret < 0) {
347 fail("error: %s\n", gnutls_strerror(ret));
348 }
349
350 /* do not wait for the peer to close the connection.
351 */
352 gnutls_bye(session, GNUTLS_SHUT_WR);
353
354 close(fd);
355 gnutls_deinit(session);
356
357 gnutls_anon_free_server_credentials(anoncred);
358 gnutls_dh_params_deinit(dh_params);
359
360 gnutls_global_deinit();
361
362 if (debug)
363 success("server: finished\n");
364 }
365
start(void)366 static void start(void)
367 {
368 int fd[2];
369 int ret;
370
371 ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fd);
372 if (ret < 0) {
373 perror("socketpair");
374 exit(1);
375 }
376
377 child = fork();
378 if (child < 0) {
379 perror("fork");
380 fail("fork");
381 exit(1);
382 }
383
384 if (child) {
385 int status;
386 /* parent */
387
388 server(fd[0]);
389 wait(&status);
390 check_wait_status(status);
391 } else {
392 close(fd[0]);
393 client(fd[1]);
394 exit(0);
395 }
396 }
397
/* Entry point of the test: it simply runs start() once.
 * NOTE(review): presumably invoked by the shared test driver behind
 * "utils.h" -- confirm.
 */
void doit(void)
{
	start();
}
402
403 #endif /* _WIN32 */
404