1 /*
2  * Copyright (C) 2018 Red Hat, Inc.
3  *
4  * Author: Nikos Mavrogiannopoulos
5  *
6  * This file is part of GnuTLS.
7  *
8  * GnuTLS is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * GnuTLS is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public License
19  * along with this program.  If not, see <https://www.gnu.org/licenses/>
20  */
21 
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25 
26 #include <gnutls/gnutls.h>
27 #include <gnutls/gnutlsxx.h>
28 #include <iostream>
29 
30 extern "C" {
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <errno.h>
35 #include <assert.h>
36 #include "cert-common.h"
37 #include <setjmp.h>
38 #include <cmocka.h>
39 #include <minmax.h>
40 }
41 
/* This is a basic test of the GnuTLS C++ API (gnutlsxx). */
tls_log_func(int level,const char * str)43 static void tls_log_func(int level, const char *str)
44 {
45 	std::cerr << level << "| " << str << "\n";
46 }
47 
/*
 * In-memory "wire" between the two sessions under test.
 * to_server holds bytes pushed by the client (client_push) that the server
 * has not yet read (server_pull); to_client carries the opposite direction
 * (server_push -> client_pull). Each *_len counter is the number of queued
 * bytes, which are always kept at the front of the corresponding buffer.
 */
static char to_server[64 * 1024];
static size_t to_server_len = 0;

static char to_client[64 * 1024];
static size_t to_client_len = 0;
53 
54 static ssize_t
client_push(gnutls_transport_ptr_t tr,const void * data,size_t len)55 client_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
56 {
57 	size_t newlen;
58 
59 	len = MIN(len, sizeof(to_server) - to_server_len);
60 
61 	newlen = to_server_len + len;
62 	memcpy(to_server + to_server_len, data, len);
63 	to_server_len = newlen;
64 
65 	return len;
66 }
67 
68 static ssize_t
client_pull(gnutls_transport_ptr_t tr,void * data,size_t len)69 client_pull(gnutls_transport_ptr_t tr, void *data, size_t len)
70 {
71 	if (to_client_len == 0) {
72 		gnutls_transport_set_errno ((gnutls_session_t)tr, EAGAIN);
73 		return -1;
74 	}
75 
76 	len = MIN(len, to_client_len);
77 
78 	memcpy(data, to_client, len);
79 
80 	memmove(to_client, to_client + len, to_client_len - len);
81 	to_client_len -= len;
82 	return len;
83 }
84 
85 static ssize_t
server_pull(gnutls_transport_ptr_t tr,void * data,size_t len)86 server_pull(gnutls_transport_ptr_t tr, void *data, size_t len)
87 {
88 	if (to_server_len == 0) {
89 		gnutls_transport_set_errno ((gnutls_session_t)tr, EAGAIN);
90 		return -1;
91 	}
92 
93 	len = MIN(len, to_server_len);
94 	memcpy(data, to_server, len);
95 
96 	memmove(to_server, to_server + len, to_server_len - len);
97 	to_server_len -= len;
98 
99 	return len;
100 }
101 
102 static ssize_t
server_push(gnutls_transport_ptr_t tr,const void * data,size_t len)103 server_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
104 {
105 	size_t newlen;
106 
107 	len = MIN(len, sizeof(to_client) - to_client_len);
108 
109 	newlen = to_client_len + len;
110 	memcpy(to_client + to_client_len, data, len);
111 	to_client_len = newlen;
112 
113 	return len;
114 }
115 
reset_buffers(void)116 inline static void reset_buffers(void)
117 {
118 	to_server_len = 0;
119 	to_client_len = 0;
120 }
121 
122 #define MSG "test message"
test_handshake(void ** glob_state,const char * prio,gnutls::server_session & server,gnutls::client_session & client)123 static void test_handshake(void **glob_state, const char *prio,
124 			   gnutls::server_session& server, gnutls::client_session& client)
125 {
126         gnutls::certificate_credentials serverx509cred;
127 	int sret, cret;
128 	gnutls::certificate_credentials clientx509cred;
129 	char buffer[64];
130 	int ret;
131 
132 	/* General init. */
133 	reset_buffers();
134 	gnutls_global_set_log_function(tls_log_func);
135 
136 	try {
137 		serverx509cred.set_x509_key(server_cert, server_key, GNUTLS_X509_FMT_PEM);
138 		server.set_credentials(serverx509cred);
139 
140 		server.set_priority(prio, NULL);
141 
142 		server.set_transport_push_function(server_push);
143 		server.set_transport_pull_function(server_pull);
144 		server.set_transport_ptr(server.ptr());
145 
146 		client.set_priority(prio, NULL);
147 		client.set_credentials(clientx509cred);
148 
149 		client.set_transport_push_function(client_push);
150 		client.set_transport_pull_function(client_pull);
151 		client.set_transport_ptr(client.ptr());
152 	}
153 	catch (std::exception &ex) {
154 		std::cerr << "Exception caught: " << ex.what() << std::endl;
155 		fail();
156 	}
157 
158 	sret = cret = GNUTLS_E_AGAIN;
159 
160 	do {
161 		if (cret == GNUTLS_E_AGAIN) {
162 			try {
163 				cret = client.handshake();
164 			} catch(gnutls::exception &ex) {
165 				cret = ex.get_code();
166 				if (cret == GNUTLS_E_INTERRUPTED || cret == GNUTLS_E_AGAIN)
167 					cret = GNUTLS_E_AGAIN;
168 			}
169 		}
170 		if (sret == GNUTLS_E_AGAIN) {
171 			try {
172 				sret = server.handshake();
173 			} catch(gnutls::exception &ex) {
174 				sret = ex.get_code();
175 				if (sret == GNUTLS_E_INTERRUPTED || sret == GNUTLS_E_AGAIN)
176 					sret = GNUTLS_E_AGAIN;
177 			}
178 		}
179 	}
180 	while ((cret == GNUTLS_E_AGAIN || (cret == 0 && sret == GNUTLS_E_AGAIN)) &&
181 	       (sret == GNUTLS_E_AGAIN || (sret == 0 && cret == GNUTLS_E_AGAIN)));
182 
183 	if (sret < 0 || cret < 0) {
184 		fail();
185 	}
186 
187 	try {
188 		client.send(MSG, sizeof(MSG)-1);
189 		ret = server.recv(buffer, sizeof(buffer));
190 
191 		assert(ret == sizeof(MSG)-1);
192 		assert(memcmp(buffer, MSG, sizeof(MSG)-1) == 0);
193 
194 		client.bye(GNUTLS_SHUT_WR);
195 		server.bye(GNUTLS_SHUT_WR);
196 	}
197 	catch (std::exception &ex) {
198 		std::cerr << "Exception caught: " << ex.what() << std::endl;
199 		fail();
200 	}
201 
202 	return;
203 }
204 
tls_handshake(void ** glob_state)205 static void tls_handshake(void **glob_state)
206 {
207         gnutls::server_session server;
208 	gnutls::client_session client;
209 
210 	test_handshake(glob_state, "NORMAL", server, client);
211 }
212 
tls_handshake_alt(void ** glob_state)213 static void tls_handshake_alt(void **glob_state)
214 {
215         gnutls::server_session server(0);
216 	gnutls::client_session client(0);
217 
218 	test_handshake(glob_state, "NORMAL", server, client);
219 }
220 
tls12_handshake(void ** glob_state)221 static void tls12_handshake(void **glob_state)
222 {
223         gnutls::server_session server;
224 	gnutls::client_session client;
225 
226 	test_handshake(glob_state, "NORMAL:-VERS-TLS-ALL:+VERS-TLS1.2", server, client);
227 }
228 
tls13_handshake(void ** glob_state)229 static void tls13_handshake(void **glob_state)
230 {
231         gnutls::server_session server;
232 	gnutls::client_session client;
233 
234 	test_handshake(glob_state, "NORMAL:-VERS-TLS-ALL:+VERS-TLS1.3", server, client);
235 }
236 
main(void)237 int main(void)
238 {
239 	const struct CMUnitTest tests[] = {
240 		cmocka_unit_test(tls_handshake),
241 		cmocka_unit_test(tls_handshake_alt),
242 		cmocka_unit_test(tls13_handshake),
243 		cmocka_unit_test(tls12_handshake)
244 	};
245 	return cmocka_run_group_tests(tests, NULL, NULL);
246 }
247