1 /*
2  * Copyright (C) 2017 Free Software Foundation, Inc.
3  *
4  * Author: Ander Juaristi
5  *
6  * This file is part of GnuTLS.
7  *
8  * The GnuTLS is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public License
10  * as published by the Free Software Foundation; either version 2.1 of
11  * the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public License
19  * along with this program.  If not, see <https://www.gnu.org/licenses/>
20  *
21  */
22 
23 #include "gnutls_int.h"
24 #include "ext/psk_ke_modes.h"
25 #include "ext/pre_shared_key.h"
26 #include <assert.h>
27 
/* Wire values of PskKeyExchangeMode as carried, one byte each, in the
 * psk_key_exchange_modes extension (RFC 8446, section 4.2.9). */
enum {
	PSK_KE = 0,	/* PSK-only key establishment */
	PSK_DHE_KE = 1	/* PSK combined with (EC)DHE key establishment */
};
30 
31 static int
psk_ke_modes_send_params(gnutls_session_t session,gnutls_buffer_t extdata)32 psk_ke_modes_send_params(gnutls_session_t session,
33 			 gnutls_buffer_t extdata)
34 {
35 	int ret;
36 	const version_entry_st *vers;
37 	uint8_t data[2];
38 	unsigned pos, i;
39 	unsigned have_dhpsk = 0;
40 	unsigned have_psk = 0;
41 
42 	/* Server doesn't send psk_key_exchange_modes */
43 	if (session->security_parameters.entity == GNUTLS_SERVER)
44 		return 0;
45 
46 	/* If session ticket is disabled and no PSK key exchange is
47 	 * enabled, don't send the extension */
48 	if ((session->internals.flags & GNUTLS_NO_TICKETS) &&
49 	    !session->internals.priorities->have_psk)
50 		return 0;
51 
52 	vers = _gnutls_version_max(session);
53 	if (!vers || !vers->tls13_sem)
54 		return 0;
55 
56 	/* We send the list prioritized according to our preferences as a convention
57 	 * (used throughout the protocol), even if the protocol doesn't mandate that
58 	 * for this particular message. That way we can keep the TLS 1.2 semantics/
59 	 * prioritization when negotiating PSK or DHE-PSK. Receiving servers would
60 	 * very likely respect our prioritization if they parse the message serially. */
61 	pos = 0;
62 	for (i=0;i<session->internals.priorities->_kx.num_priorities;i++) {
63 		if (session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_PSK && !have_psk) {
64 			assert(pos <= 1);
65 			data[pos++] = PSK_KE;
66 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
67 			have_psk = 1;
68 		} else if ((session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_DHE_PSK ||
69 			    session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_ECDHE_PSK) && !have_dhpsk) {
70 			assert(pos <= 1);
71 			data[pos++] = PSK_DHE_KE;
72 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
73 			have_dhpsk = 1;
74 		}
75 
76 		if (have_psk && have_dhpsk)
77 			break;
78 	}
79 
80 	/* For session resumption we need to send at least one */
81 	if (pos == 0) {
82 		if (session->internals.flags & GNUTLS_NO_TICKETS)
83 			return 0;
84 
85 		data[pos++] = PSK_DHE_KE;
86 		data[pos++] = PSK_KE;
87 		session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
88 		session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
89 	}
90 
91 	ret = _gnutls_buffer_append_data_prefix(extdata, 8, data, pos);
92 	if (ret < 0)
93 		return gnutls_assert_val(ret);
94 
95 	session->internals.hsk_flags |= HSK_PSK_KE_MODES_SENT;
96 
97 	return 0;
98 }
99 
/* Sentinel position meaning "mode not present": used for both the
 * server's kx priority list and the client's psk_key_exchange_modes list,
 * and compares greater than any real index. */
enum { MAX_POS = INT_MAX };
101 
102 /*
103  * Since we only support ECDHE-authenticated PSKs, the server
104  * just verifies that a "psk_key_exchange_modes" extension was received,
105  * and that it contains the value one.
106  */
107 static int
psk_ke_modes_recv_params(gnutls_session_t session,const unsigned char * data,size_t len)108 psk_ke_modes_recv_params(gnutls_session_t session,
109 			 const unsigned char *data, size_t len)
110 {
111 	uint8_t ke_modes_len;
112 	const version_entry_st *vers = get_version(session);
113 	gnutls_psk_server_credentials_t cred;
114 	int dhpsk_pos = MAX_POS;
115 	int psk_pos = MAX_POS;
116 	int cli_psk_pos = MAX_POS;
117 	int cli_dhpsk_pos = MAX_POS;
118 	unsigned i;
119 
120 	/* Client doesn't receive psk_key_exchange_modes */
121 	if (session->security_parameters.entity == GNUTLS_CLIENT)
122 		return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_EXTENSION);
123 
124 	/* we set hsk_flags to HSK_PSK_KE_MODE_INVALID on failure to ensure that
125 	 * when we parse the pre-shared key extension we detect PSK_KE_MODES as
126 	 * received. */
127 	if (!vers || !vers->tls13_sem) {
128 		session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
129 		return gnutls_assert_val(0);
130 	}
131 
132 	cred = (gnutls_psk_server_credentials_t)_gnutls_get_cred(session, GNUTLS_CRD_PSK);
133 	if (cred == NULL && (session->internals.flags & GNUTLS_NO_TICKETS)) {
134 		session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
135 		return gnutls_assert_val(0);
136 	}
137 
138 	DECR_LEN(len, 1);
139 	ke_modes_len = *(data++);
140 
141 	for (i=0;i<session->internals.priorities->_kx.num_priorities;i++) {
142 		if (session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_PSK && psk_pos == MAX_POS) {
143 			psk_pos = i;
144 		} else if ((session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_DHE_PSK ||
145 			    session->internals.priorities->_kx.priorities[i] == GNUTLS_KX_ECDHE_PSK) &&
146 			    dhpsk_pos == MAX_POS) {
147 			dhpsk_pos = i;
148 		}
149 
150 		if (dhpsk_pos != MAX_POS && psk_pos != MAX_POS)
151 			break;
152 	}
153 
154 	if (psk_pos == MAX_POS && dhpsk_pos == MAX_POS) {
155 		if (!(session->internals.flags & GNUTLS_NO_TICKETS))
156 			dhpsk_pos = 0;
157 		else if (session->internals.priorities->groups.size == 0)
158 			return gnutls_assert_val(0);
159 	}
160 
161 	for (i=0;i<ke_modes_len;i++) {
162 		DECR_LEN(len, 1);
163 		if (data[i] == PSK_DHE_KE)
164 			cli_dhpsk_pos = i;
165 		else if (data[i] == PSK_KE)
166 			cli_psk_pos = i;
167 
168 		_gnutls_handshake_log("EXT[%p]: PSK KE mode %.2x received\n",
169 				      session, (unsigned)data[i]);
170 		if (cli_psk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS)
171 			break;
172 	}
173 
174 	if (session->internals.priorities->server_precedence) {
175 		if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && dhpsk_pos < psk_pos)
176 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
177 		else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && psk_pos < dhpsk_pos)
178 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
179 	} else {
180 		if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && cli_dhpsk_pos < cli_psk_pos)
181 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
182 		else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && cli_psk_pos < cli_dhpsk_pos)
183 			session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
184 	}
185 
186 	if ((session->internals.hsk_flags & HSK_PSK_KE_MODE_PSK) ||
187 	    (session->internals.hsk_flags & HSK_PSK_KE_MODE_DHE_PSK)) {
188 
189 		return 0;
190 	} else {
191 		session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
192 		return gnutls_assert_val(0);
193 	}
194 }
195 
196 const hello_ext_entry_st ext_mod_psk_ke_modes = {
197 	.name = "PSK Key Exchange Modes",
198 	.tls_id = 45,
199 	.gid = GNUTLS_EXTENSION_PSK_KE_MODES,
200 	.client_parse_point = GNUTLS_EXT_TLS,
201 	.server_parse_point = GNUTLS_EXT_TLS,
202 	.validity = GNUTLS_EXT_FLAG_TLS | GNUTLS_EXT_FLAG_CLIENT_HELLO | GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO,
203 	.send_func = psk_ke_modes_send_params,
204 	.recv_func = psk_ke_modes_recv_params
205 };
206