1 /*
2 * Copyright (C) 2018 Red Hat, Inc.
3 *
4 * Author: Nikos Mavrogiannopoulos
5 *
6 * This file is part of GnuTLS.
7 *
8 * GnuTLS is free software; you can redistribute it and/or modify it
9 * under the terms of the GNU General Public License as published by
10 * the Free Software Foundation; either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * GnuTLS is distributed in the hope that it will be useful, but
14 * WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 * General Public License for more details.
17 *
18 * You should have received a copy of the GNU Lesser General Public License
19 * along with this program. If not, see <https://www.gnu.org/licenses/>
20 */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <gnutls/gnutls.h>
27 #include <gnutls/gnutlsxx.h>
28 #include <iostream>
29
30 extern "C" {
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <errno.h>
35 #include <assert.h>
36 #include "cert-common.h"
37 #include <setjmp.h>
38 #include <cmocka.h>
39 #include <minmax.h>
40 }
41
/* This is a basic test of the GnuTLS C++ (gnutlsxx) API. */
/*
 * GnuTLS log callback: write "<level>| <message>" followed by a newline
 * to standard error.
 */
static void tls_log_func(int level, const char *str)
{
	std::ostream &out = std::cerr;

	out << level;
	out << "| ";
	out << str;
	out << "\n";
}
47
/*
 * In-memory "wire" used by the transport callbacks in this file.
 * to_server holds bytes the client pushed that the server has not pulled
 * yet; to_client holds the opposite direction. Each *_len counts the
 * valid bytes at the start of its buffer.
 */
static const size_t wire_buffer_size = 64 * 1024;

static char to_server[wire_buffer_size];
static size_t to_server_len = 0;

static char to_client[wire_buffer_size];
static size_t to_client_len = 0;
53
54 static ssize_t
client_push(gnutls_transport_ptr_t tr,const void * data,size_t len)55 client_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
56 {
57 size_t newlen;
58
59 len = MIN(len, sizeof(to_server) - to_server_len);
60
61 newlen = to_server_len + len;
62 memcpy(to_server + to_server_len, data, len);
63 to_server_len = newlen;
64
65 return len;
66 }
67
68 static ssize_t
client_pull(gnutls_transport_ptr_t tr,void * data,size_t len)69 client_pull(gnutls_transport_ptr_t tr, void *data, size_t len)
70 {
71 if (to_client_len == 0) {
72 gnutls_transport_set_errno ((gnutls_session_t)tr, EAGAIN);
73 return -1;
74 }
75
76 len = MIN(len, to_client_len);
77
78 memcpy(data, to_client, len);
79
80 memmove(to_client, to_client + len, to_client_len - len);
81 to_client_len -= len;
82 return len;
83 }
84
85 static ssize_t
server_pull(gnutls_transport_ptr_t tr,void * data,size_t len)86 server_pull(gnutls_transport_ptr_t tr, void *data, size_t len)
87 {
88 if (to_server_len == 0) {
89 gnutls_transport_set_errno ((gnutls_session_t)tr, EAGAIN);
90 return -1;
91 }
92
93 len = MIN(len, to_server_len);
94 memcpy(data, to_server, len);
95
96 memmove(to_server, to_server + len, to_server_len - len);
97 to_server_len -= len;
98
99 return len;
100 }
101
102 static ssize_t
server_push(gnutls_transport_ptr_t tr,const void * data,size_t len)103 server_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
104 {
105 size_t newlen;
106
107 len = MIN(len, sizeof(to_client) - to_client_len);
108
109 newlen = to_client_len + len;
110 memcpy(to_client + to_client_len, data, len);
111 to_client_len = newlen;
112
113 return len;
114 }
115
reset_buffers(void)116 inline static void reset_buffers(void)
117 {
118 to_server_len = 0;
119 to_client_len = 0;
120 }
121
122 #define MSG "test message"
test_handshake(void ** glob_state,const char * prio,gnutls::server_session & server,gnutls::client_session & client)123 static void test_handshake(void **glob_state, const char *prio,
124 gnutls::server_session& server, gnutls::client_session& client)
125 {
126 gnutls::certificate_credentials serverx509cred;
127 int sret, cret;
128 gnutls::certificate_credentials clientx509cred;
129 char buffer[64];
130 int ret;
131
132 /* General init. */
133 reset_buffers();
134 gnutls_global_set_log_function(tls_log_func);
135
136 try {
137 serverx509cred.set_x509_key(server_cert, server_key, GNUTLS_X509_FMT_PEM);
138 server.set_credentials(serverx509cred);
139
140 server.set_priority(prio, NULL);
141
142 server.set_transport_push_function(server_push);
143 server.set_transport_pull_function(server_pull);
144 server.set_transport_ptr(server.ptr());
145
146 client.set_priority(prio, NULL);
147 client.set_credentials(clientx509cred);
148
149 client.set_transport_push_function(client_push);
150 client.set_transport_pull_function(client_pull);
151 client.set_transport_ptr(client.ptr());
152 }
153 catch (std::exception &ex) {
154 std::cerr << "Exception caught: " << ex.what() << std::endl;
155 fail();
156 }
157
158 sret = cret = GNUTLS_E_AGAIN;
159
160 do {
161 if (cret == GNUTLS_E_AGAIN) {
162 try {
163 cret = client.handshake();
164 } catch(gnutls::exception &ex) {
165 cret = ex.get_code();
166 if (cret == GNUTLS_E_INTERRUPTED || cret == GNUTLS_E_AGAIN)
167 cret = GNUTLS_E_AGAIN;
168 }
169 }
170 if (sret == GNUTLS_E_AGAIN) {
171 try {
172 sret = server.handshake();
173 } catch(gnutls::exception &ex) {
174 sret = ex.get_code();
175 if (sret == GNUTLS_E_INTERRUPTED || sret == GNUTLS_E_AGAIN)
176 sret = GNUTLS_E_AGAIN;
177 }
178 }
179 }
180 while ((cret == GNUTLS_E_AGAIN || (cret == 0 && sret == GNUTLS_E_AGAIN)) &&
181 (sret == GNUTLS_E_AGAIN || (sret == 0 && cret == GNUTLS_E_AGAIN)));
182
183 if (sret < 0 || cret < 0) {
184 fail();
185 }
186
187 try {
188 client.send(MSG, sizeof(MSG)-1);
189 ret = server.recv(buffer, sizeof(buffer));
190
191 assert(ret == sizeof(MSG)-1);
192 assert(memcmp(buffer, MSG, sizeof(MSG)-1) == 0);
193
194 client.bye(GNUTLS_SHUT_WR);
195 server.bye(GNUTLS_SHUT_WR);
196 }
197 catch (std::exception &ex) {
198 std::cerr << "Exception caught: " << ex.what() << std::endl;
199 fail();
200 }
201
202 return;
203 }
204
tls_handshake(void ** glob_state)205 static void tls_handshake(void **glob_state)
206 {
207 gnutls::server_session server;
208 gnutls::client_session client;
209
210 test_handshake(glob_state, "NORMAL", server, client);
211 }
212
tls_handshake_alt(void ** glob_state)213 static void tls_handshake_alt(void **glob_state)
214 {
215 gnutls::server_session server(0);
216 gnutls::client_session client(0);
217
218 test_handshake(glob_state, "NORMAL", server, client);
219 }
220
tls12_handshake(void ** glob_state)221 static void tls12_handshake(void **glob_state)
222 {
223 gnutls::server_session server;
224 gnutls::client_session client;
225
226 test_handshake(glob_state, "NORMAL:-VERS-TLS-ALL:+VERS-TLS1.2", server, client);
227 }
228
tls13_handshake(void ** glob_state)229 static void tls13_handshake(void **glob_state)
230 {
231 gnutls::server_session server;
232 gnutls::client_session client;
233
234 test_handshake(glob_state, "NORMAL:-VERS-TLS-ALL:+VERS-TLS1.3", server, client);
235 }
236
main(void)237 int main(void)
238 {
239 const struct CMUnitTest tests[] = {
240 cmocka_unit_test(tls_handshake),
241 cmocka_unit_test(tls_handshake_alt),
242 cmocka_unit_test(tls13_handshake),
243 cmocka_unit_test(tls12_handshake)
244 };
245 return cmocka_run_group_tests(tests, NULL, NULL);
246 }
247