1 /* $OpenBSD: tls13_handshake_msg.c,v 1.3 2021/05/16 14:19:04 jsing Exp $ */
2 /*
3  * Copyright (c) 2018, 2019 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include "bytestring.h"
19 #include "tls13_internal.h"
20 
21 #define TLS13_HANDSHAKE_MSG_HEADER_LEN	4
22 #define TLS13_HANDSHAKE_MSG_INITIAL_LEN	256
23 #define TLS13_HANDSHAKE_MSG_MAX_LEN	(256 * 1024)
24 
25 struct tls13_handshake_msg {
26 	uint8_t msg_type;
27 	uint32_t msg_len;
28 	uint8_t *data;
29 	size_t data_len;
30 
31 	struct tls13_buffer *buf;
32 	CBS cbs;
33 	CBB cbb;
34 };
35 
36 struct tls13_handshake_msg *
tls13_handshake_msg_new()37 tls13_handshake_msg_new()
38 {
39 	struct tls13_handshake_msg *msg = NULL;
40 
41 	if ((msg = calloc(1, sizeof(struct tls13_handshake_msg))) == NULL)
42 		goto err;
43 	if ((msg->buf = tls13_buffer_new(0)) == NULL)
44 		goto err;
45 
46 	return msg;
47 
48  err:
49 	tls13_handshake_msg_free(msg);
50 
51 	return NULL;
52 }
53 
54 void
tls13_handshake_msg_free(struct tls13_handshake_msg * msg)55 tls13_handshake_msg_free(struct tls13_handshake_msg *msg)
56 {
57 	if (msg == NULL)
58 		return;
59 
60 	tls13_buffer_free(msg->buf);
61 
62 	CBB_cleanup(&msg->cbb);
63 
64 	freezero(msg->data, msg->data_len);
65 	freezero(msg, sizeof(struct tls13_handshake_msg));
66 }
67 
68 void
tls13_handshake_msg_data(struct tls13_handshake_msg * msg,CBS * cbs)69 tls13_handshake_msg_data(struct tls13_handshake_msg *msg, CBS *cbs)
70 {
71 	CBS_init(cbs, msg->data, msg->data_len);
72 }
73 
74 int
tls13_handshake_msg_set_buffer(struct tls13_handshake_msg * msg,CBS * cbs)75 tls13_handshake_msg_set_buffer(struct tls13_handshake_msg *msg, CBS *cbs)
76 {
77 	return tls13_buffer_set_data(msg->buf, cbs);
78 }
79 
80 uint8_t
tls13_handshake_msg_type(struct tls13_handshake_msg * msg)81 tls13_handshake_msg_type(struct tls13_handshake_msg *msg)
82 {
83 	return msg->msg_type;
84 }
85 
86 int
tls13_handshake_msg_content(struct tls13_handshake_msg * msg,CBS * cbs)87 tls13_handshake_msg_content(struct tls13_handshake_msg *msg, CBS *cbs)
88 {
89 	tls13_handshake_msg_data(msg, cbs);
90 
91 	return CBS_skip(cbs, TLS13_HANDSHAKE_MSG_HEADER_LEN);
92 }
93 
94 int
tls13_handshake_msg_start(struct tls13_handshake_msg * msg,CBB * body,uint8_t msg_type)95 tls13_handshake_msg_start(struct tls13_handshake_msg *msg, CBB *body,
96     uint8_t msg_type)
97 {
98 	if (!CBB_init(&msg->cbb, TLS13_HANDSHAKE_MSG_INITIAL_LEN))
99 		return 0;
100 	if (!CBB_add_u8(&msg->cbb, msg_type))
101 		return 0;
102 	if (!CBB_add_u24_length_prefixed(&msg->cbb, body))
103 		return 0;
104 
105 	return 1;
106 }
107 
108 int
tls13_handshake_msg_finish(struct tls13_handshake_msg * msg)109 tls13_handshake_msg_finish(struct tls13_handshake_msg *msg)
110 {
111 	if (!CBB_finish(&msg->cbb, &msg->data, &msg->data_len))
112 		return 0;
113 
114 	CBS_init(&msg->cbs, msg->data, msg->data_len);
115 
116 	return 1;
117 }
118 
/*
 * Read callback handed to tls13_buffer_extend().
 *
 * cb_arg is the record layer. The request for n bytes into buf is forwarded
 * to tls13_read_handshake_data(), and its return value is passed back
 * unchanged.
 */
static ssize_t
tls13_handshake_msg_read_cb(void *buf, size_t n, void *cb_arg)
{
	return tls13_read_handshake_data((struct tls13_record_layer *)cb_arg,
	    buf, n);
}
126 
127 int
tls13_handshake_msg_recv(struct tls13_handshake_msg * msg,struct tls13_record_layer * rl)128 tls13_handshake_msg_recv(struct tls13_handshake_msg *msg,
129     struct tls13_record_layer *rl)
130 {
131 	uint8_t msg_type;
132 	uint32_t msg_len;
133 	CBS cbs;
134 	int ret;
135 
136 	if (msg->data != NULL)
137 		return TLS13_IO_FAILURE;
138 
139 	if (msg->msg_type == 0) {
140 		if ((ret = tls13_buffer_extend(msg->buf,
141 		    TLS13_HANDSHAKE_MSG_HEADER_LEN,
142 		    tls13_handshake_msg_read_cb, rl)) <= 0)
143 			return ret;
144 
145 		tls13_buffer_cbs(msg->buf, &cbs);
146 
147 		if (!CBS_get_u8(&cbs, &msg_type))
148 			return TLS13_IO_FAILURE;
149 		if (!CBS_get_u24(&cbs, &msg_len))
150 			return TLS13_IO_FAILURE;
151 
152 		/* XXX - do we want to make this variable on message type? */
153 		if (msg_len > TLS13_HANDSHAKE_MSG_MAX_LEN)
154 			return TLS13_IO_FAILURE;
155 
156 		msg->msg_type = msg_type;
157 		msg->msg_len = msg_len;
158 	}
159 
160 	if ((ret = tls13_buffer_extend(msg->buf,
161 	    TLS13_HANDSHAKE_MSG_HEADER_LEN + msg->msg_len,
162 	    tls13_handshake_msg_read_cb, rl)) <= 0)
163 		return ret;
164 
165 	if (!tls13_buffer_finish(msg->buf, &msg->data, &msg->data_len))
166 		return TLS13_IO_FAILURE;
167 
168 	return TLS13_IO_SUCCESS;
169 }
170 
171 int
tls13_handshake_msg_send(struct tls13_handshake_msg * msg,struct tls13_record_layer * rl)172 tls13_handshake_msg_send(struct tls13_handshake_msg *msg,
173     struct tls13_record_layer *rl)
174 {
175 	ssize_t ret;
176 
177 	if (msg->data == NULL)
178 		return TLS13_IO_FAILURE;
179 
180 	if (CBS_len(&msg->cbs) == 0)
181 		return TLS13_IO_FAILURE;
182 
183 	while (CBS_len(&msg->cbs) > 0) {
184 		if ((ret = tls13_write_handshake_data(rl, CBS_data(&msg->cbs),
185 		    CBS_len(&msg->cbs))) <= 0)
186 			return ret;
187 
188 		if (!CBS_skip(&msg->cbs, ret))
189 			return TLS13_IO_FAILURE;
190 	}
191 
192 	return TLS13_IO_SUCCESS;
193 }
194