1 /* $OpenBSD: tls13_handshake_msg.c,v 1.7 2024/02/04 20:50:23 tb Exp $ */
2 /*
3 * Copyright (c) 2018, 2019 Joel Sing <jsing@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18 #include "bytestring.h"
19 #include "tls13_internal.h"
20
21 #define TLS13_HANDSHAKE_MSG_HEADER_LEN 4
22 #define TLS13_HANDSHAKE_MSG_INITIAL_LEN 256
23 #define TLS13_HANDSHAKE_MSG_MAX_LEN (256 * 1024)
24
25 struct tls13_handshake_msg {
26 uint8_t msg_type;
27 uint32_t msg_len;
28 uint8_t *data;
29 size_t data_len;
30
31 struct tls_buffer *buf;
32 CBS cbs;
33 CBB cbb;
34 };
35
36 struct tls13_handshake_msg *
tls13_handshake_msg_new(void)37 tls13_handshake_msg_new(void)
38 {
39 struct tls13_handshake_msg *msg = NULL;
40
41 if ((msg = calloc(1, sizeof(struct tls13_handshake_msg))) == NULL)
42 goto err;
43 if ((msg->buf = tls_buffer_new(0)) == NULL)
44 goto err;
45
46 return msg;
47
48 err:
49 tls13_handshake_msg_free(msg);
50
51 return NULL;
52 }
53
54 void
tls13_handshake_msg_free(struct tls13_handshake_msg * msg)55 tls13_handshake_msg_free(struct tls13_handshake_msg *msg)
56 {
57 if (msg == NULL)
58 return;
59
60 tls_buffer_free(msg->buf);
61
62 CBB_cleanup(&msg->cbb);
63
64 freezero(msg->data, msg->data_len);
65 freezero(msg, sizeof(struct tls13_handshake_msg));
66 }
67
68 void
tls13_handshake_msg_data(struct tls13_handshake_msg * msg,CBS * cbs)69 tls13_handshake_msg_data(struct tls13_handshake_msg *msg, CBS *cbs)
70 {
71 CBS_init(cbs, msg->data, msg->data_len);
72 }
73
74 uint8_t
tls13_handshake_msg_type(struct tls13_handshake_msg * msg)75 tls13_handshake_msg_type(struct tls13_handshake_msg *msg)
76 {
77 return msg->msg_type;
78 }
79
80 int
tls13_handshake_msg_content(struct tls13_handshake_msg * msg,CBS * cbs)81 tls13_handshake_msg_content(struct tls13_handshake_msg *msg, CBS *cbs)
82 {
83 tls13_handshake_msg_data(msg, cbs);
84
85 return CBS_skip(cbs, TLS13_HANDSHAKE_MSG_HEADER_LEN);
86 }
87
88 int
tls13_handshake_msg_start(struct tls13_handshake_msg * msg,CBB * body,uint8_t msg_type)89 tls13_handshake_msg_start(struct tls13_handshake_msg *msg, CBB *body,
90 uint8_t msg_type)
91 {
92 if (!CBB_init(&msg->cbb, TLS13_HANDSHAKE_MSG_INITIAL_LEN))
93 return 0;
94 if (!CBB_add_u8(&msg->cbb, msg_type))
95 return 0;
96 if (!CBB_add_u24_length_prefixed(&msg->cbb, body))
97 return 0;
98
99 return 1;
100 }
101
102 int
tls13_handshake_msg_finish(struct tls13_handshake_msg * msg)103 tls13_handshake_msg_finish(struct tls13_handshake_msg *msg)
104 {
105 if (!CBB_finish(&msg->cbb, &msg->data, &msg->data_len))
106 return 0;
107
108 CBS_init(&msg->cbs, msg->data, msg->data_len);
109
110 return 1;
111 }
112
/*
 * Read callback handed to tls_buffer_extend(): cb_arg is the record layer,
 * and up to n bytes of handshake data are read into buf.
 */
static ssize_t
tls13_handshake_msg_read_cb(void *buf, size_t n, void *cb_arg)
{
	return tls13_read_handshake_data((struct tls13_record_layer *)cb_arg,
	    buf, n);
}
120
121 int
tls13_handshake_msg_recv(struct tls13_handshake_msg * msg,struct tls13_record_layer * rl)122 tls13_handshake_msg_recv(struct tls13_handshake_msg *msg,
123 struct tls13_record_layer *rl)
124 {
125 uint8_t msg_type;
126 uint32_t msg_len;
127 CBS cbs;
128 int ret;
129
130 if (msg->data != NULL)
131 return TLS13_IO_FAILURE;
132
133 if (msg->msg_type == 0) {
134 if ((ret = tls_buffer_extend(msg->buf,
135 TLS13_HANDSHAKE_MSG_HEADER_LEN,
136 tls13_handshake_msg_read_cb, rl)) <= 0)
137 return ret;
138
139 if (!tls_buffer_data(msg->buf, &cbs))
140 return TLS13_IO_FAILURE;
141
142 if (!CBS_get_u8(&cbs, &msg_type))
143 return TLS13_IO_FAILURE;
144 if (!CBS_get_u24(&cbs, &msg_len))
145 return TLS13_IO_FAILURE;
146
147 /* XXX - do we want to make this variable on message type? */
148 if (msg_len > TLS13_HANDSHAKE_MSG_MAX_LEN)
149 return TLS13_IO_FAILURE;
150
151 msg->msg_type = msg_type;
152 msg->msg_len = msg_len;
153 }
154
155 if ((ret = tls_buffer_extend(msg->buf,
156 TLS13_HANDSHAKE_MSG_HEADER_LEN + msg->msg_len,
157 tls13_handshake_msg_read_cb, rl)) <= 0)
158 return ret;
159
160 if (!tls_buffer_finish(msg->buf, &msg->data, &msg->data_len))
161 return TLS13_IO_FAILURE;
162
163 return TLS13_IO_SUCCESS;
164 }
165
166 int
tls13_handshake_msg_send(struct tls13_handshake_msg * msg,struct tls13_record_layer * rl)167 tls13_handshake_msg_send(struct tls13_handshake_msg *msg,
168 struct tls13_record_layer *rl)
169 {
170 ssize_t ret;
171
172 if (msg->data == NULL)
173 return TLS13_IO_FAILURE;
174
175 if (CBS_len(&msg->cbs) == 0)
176 return TLS13_IO_FAILURE;
177
178 while (CBS_len(&msg->cbs) > 0) {
179 if ((ret = tls13_write_handshake_data(rl, CBS_data(&msg->cbs),
180 CBS_len(&msg->cbs))) <= 0)
181 return ret;
182
183 if (!CBS_skip(&msg->cbs, ret))
184 return TLS13_IO_FAILURE;
185 }
186
187 return TLS13_IO_SUCCESS;
188 }
189