1 /*  Copyright (C) 2018-2020 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
2  *  SPDX-License-Identifier: GPL-3.0-or-later
3  */
4 
5 #include <libknot/packet/pkt.h>
6 
7 #include "lib/defines.h"
8 #include "daemon/session.h"
9 #include "daemon/engine.h"
10 #include "daemon/tls.h"
11 #include "daemon/http.h"
12 #include "daemon/worker.h"
13 #include "daemon/io.h"
14 #include "lib/generic/queue.h"
15 
16 #define TLS_CHUNK_SIZE (16 * 1024)
17 
18 /* Initial max frame size: https://tools.ietf.org/html/rfc7540#section-6.5.2 */
19 #define HTTP_MAX_FRAME_SIZE 16384
20 
/* Per-socket (TCP or UDP) persistent structure.
 *
 * In particular note that for UDP clients it's just one session (per socket)
 * shared for all clients.  For TCP/TLS it's also for the connection-specific socket,
 * i.e one session per connection.
 *
 * LATER(optim.): the memory here is used a bit wastefully.
 */
struct session {
	struct session_flags sflags;  /**< miscellaneous flags. */
	union inaddr peer;            /**< address of peer; not for UDP clients (downstream) */
	union inaddr sockname;        /**< our local address; for UDP it may be a wildcard */
	uv_handle_t *handle;          /**< libuv handle for IO operations. */
	uv_timer_t timeout;           /**< libuv handle for timer. */

	struct tls_ctx *tls_ctx;      /**< server side tls-related data. */
	struct tls_client_ctx *tls_client_ctx;  /**< client side tls-related data. */

#if ENABLE_DOH2
	struct http_ctx *http_ctx;  /**< server side http-related data. */
#endif

	trie_t *tasks;                /**< list of tasks associated with given session;
	                               *   keyed by DNS message ID for outgoing sessions,
	                               *   by task pointer otherwise (see session_tasklist_add). */
	queue_t(struct qr_task *) waiting;  /**< list of tasks waiting for sending to upstream. */

	uint8_t *wire_buf;            /**< Buffer for DNS message, except for XDP.
	                               *   Owned by the session only for TCP (see session_clear);
	                               *   for UDP it aliases the worker's shared buffer. */
	ssize_t wire_buf_size;        /**< Buffer size. */
	ssize_t wire_buf_start_idx;   /**< Data start offset in wire_buf. */
	ssize_t wire_buf_end_idx;     /**< Data end offset in wire_buf. */
	uint64_t last_activity;       /**< Time of last IO activity (if any occurs).
				       *   Otherwise session creation time. */
};
53 
/** libuv close callback for the session's I/O handle: the handle is fully
 * closed at this point, so release it (and the session) via io_free(). */
static void on_session_close(uv_handle_t *handle)
{
	struct session *s = handle->data;
	kr_require(s->handle == handle);
	io_free(handle);
}
60 
/** libuv close callback for the session timer: once the timer is closed,
 * proceed to close the associated I/O handle unless already closing. */
static void on_session_timer_close(uv_handle_t *timer)
{
	struct session *s = timer->data;
	uv_handle_t *io = s->handle;
	kr_require(io && io->data == s);
	/* Only outgoing sessions or TCP connections are expected here. */
	kr_require(s->sflags.outgoing || io->type == UV_TCP);
	if (uv_is_closing(io)) {
		return;
	}
	uv_close(io, on_session_close);
}
71 
/** Tear down and deallocate a session; a NULL session is a no-op. */
void session_free(struct session *session)
{
	if (!session) {
		return;
	}
	session_clear(session);
	free(session);
}
79 
session_clear(struct session * session)80 void session_clear(struct session *session)
81 {
82 	kr_require(session_is_empty(session));
83 	if (session->handle && session->handle->type == UV_TCP) {
84 		free(session->wire_buf);
85 	}
86 #if ENABLE_DOH2
87 	http_free(session->http_ctx);
88 #endif
89 	trie_clear(session->tasks);
90 	trie_free(session->tasks);
91 	queue_deinit(session->waiting);
92 	tls_free(session->tls_ctx);
93 	tls_client_ctx_free(session->tls_client_ctx);
94 	memset(session, 0, sizeof(*session));
95 }
96 
/** Begin closing an (empty) session: stop reads, shut TLS down and close
 * the timeout timer; the I/O handle itself is closed asynchronously from
 * on_session_timer_close().  Idempotent: repeated calls are no-ops. */
void session_close(struct session *session)
{
	kr_require(session_is_empty(session));
	if (session->sflags.closing) {
		return;
	}

	uv_handle_t *handle = session->handle;
	io_stop_read(handle);
	/* Mark first, so re-entrant calls (e.g. from TLS callbacks) bail out. */
	session->sflags.closing = true;

	if (!uv_is_closing((uv_handle_t *)&session->timeout)) {
		uv_timer_stop(&session->timeout);
		/* Shut down both TLS directions, if present;
		 * presumably this sends close_notify — see tls_close(). */
		if (session->tls_client_ctx) {
			tls_close(&session->tls_client_ctx->c);
		}
		if (session->tls_ctx) {
			tls_close(&session->tls_ctx->c);
		}

		session->timeout.data = session;
		/* Closing the timer chains into closing session->handle. */
		uv_close((uv_handle_t *)&session->timeout, on_session_timer_close);
	}
}
121 
session_start_read(struct session * session)122 int session_start_read(struct session *session)
123 {
124 	return io_start_read(session->handle);
125 }
126 
session_stop_read(struct session * session)127 int session_stop_read(struct session *session)
128 {
129 	return io_stop_read(session->handle);
130 }
131 
session_waitinglist_push(struct session * session,struct qr_task * task)132 int session_waitinglist_push(struct session *session, struct qr_task *task)
133 {
134 	queue_push(session->waiting, task);
135 	worker_task_ref(task);
136 	return kr_ok();
137 }
138 
session_waitinglist_get(const struct session * session)139 struct qr_task *session_waitinglist_get(const struct session *session)
140 {
141 	return (queue_len(session->waiting) > 0) ? (queue_head(session->waiting)) : NULL;
142 }
143 
session_waitinglist_pop(struct session * session,bool deref)144 struct qr_task *session_waitinglist_pop(struct session *session, bool deref)
145 {
146 	struct qr_task *t = session_waitinglist_get(session);
147 	queue_pop(session->waiting);
148 	if (deref) {
149 		worker_task_unref(t);
150 	}
151 	return t;
152 }
153 
session_tasklist_add(struct session * session,struct qr_task * task)154 int session_tasklist_add(struct session *session, struct qr_task *task)
155 {
156 	trie_t *t = session->tasks;
157 	uint16_t task_msg_id = 0;
158 	const char *key = NULL;
159 	size_t key_len = 0;
160 	if (session->sflags.outgoing) {
161 		knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
162 		task_msg_id = knot_wire_get_id(pktbuf->wire);
163 		key = (const char *)&task_msg_id;
164 		key_len = sizeof(task_msg_id);
165 	} else {
166 		key = (const char *)&task;
167 		key_len = sizeof(char *);
168 	}
169 	trie_val_t *v = trie_get_ins(t, key, key_len);
170 	if (kr_fails_assert(v))
171 		return kr_error(ENOMEM);
172 	if (*v == NULL) {
173 		*v = task;
174 		worker_task_ref(task);
175 	} else if (kr_fails_assert(*v == task)) {
176 		return kr_error(EINVAL);
177 	}
178 	return kr_ok();
179 }
180 
session_tasklist_del(struct session * session,struct qr_task * task)181 int session_tasklist_del(struct session *session, struct qr_task *task)
182 {
183 	trie_t *t = session->tasks;
184 	uint16_t task_msg_id = 0;
185 	const char *key = NULL;
186 	size_t key_len = 0;
187 	trie_val_t val;
188 	if (session->sflags.outgoing) {
189 		knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
190 		task_msg_id = knot_wire_get_id(pktbuf->wire);
191 		key = (const char *)&task_msg_id;
192 		key_len = sizeof(task_msg_id);
193 	} else {
194 		key = (const char *)&task;
195 		key_len = sizeof(char *);
196 	}
197 	int ret = trie_del(t, key, key_len, &val);
198 	if (ret == KNOT_EOK) {
199 		kr_require(val == task);
200 		worker_task_unref(val);
201 	}
202 	return ret;
203 }
204 
session_tasklist_get_first(struct session * session)205 struct qr_task *session_tasklist_get_first(struct session *session)
206 {
207 	trie_val_t *val = trie_get_first(session->tasks, NULL, NULL);
208 	return val ? (struct qr_task *) *val : NULL;
209 }
210 
session_tasklist_del_first(struct session * session,bool deref)211 struct qr_task *session_tasklist_del_first(struct session *session, bool deref)
212 {
213 	trie_val_t val = NULL;
214 	int res = trie_del_first(session->tasks, NULL, NULL, &val);
215 	if (res != KNOT_EOK) {
216 		val = NULL;
217 	} else if (deref) {
218 		worker_task_unref(val);
219 	}
220 	return (struct qr_task *)val;
221 }
session_tasklist_del_msgid(const struct session * session,uint16_t msg_id)222 struct qr_task* session_tasklist_del_msgid(const struct session *session, uint16_t msg_id)
223 {
224 	if (kr_fails_assert(session->sflags.outgoing))
225 		return NULL;
226 	trie_t *t = session->tasks;
227 	struct qr_task *ret = NULL;
228 	const char *key = (const char *)&msg_id;
229 	size_t key_len = sizeof(msg_id);
230 	trie_val_t val;
231 	int res = trie_del(t, key, key_len, &val);
232 	if (res == KNOT_EOK) {
233 		if (worker_task_numrefs(val) > 1) {
234 			ret = val;
235 		}
236 		worker_task_unref(val);
237 	}
238 	return ret;
239 }
240 
session_tasklist_find_msgid(const struct session * session,uint16_t msg_id)241 struct qr_task* session_tasklist_find_msgid(const struct session *session, uint16_t msg_id)
242 {
243 	if (kr_fails_assert(session->sflags.outgoing))
244 		return NULL;
245 	trie_t *t = session->tasks;
246 	struct qr_task *ret = NULL;
247 	trie_val_t *val = trie_get_try(t, (char *)&msg_id, sizeof(msg_id));
248 	if (val) {
249 		ret = *val;
250 	}
251 	return ret;
252 }
253 
session_flags(struct session * session)254 struct session_flags *session_flags(struct session *session)
255 {
256 	return &session->sflags;
257 }
258 
session_get_peer(struct session * session)259 struct sockaddr *session_get_peer(struct session *session)
260 {
261 	return &session->peer.ip;
262 }
263 
session_get_sockname(struct session * session)264 struct sockaddr *session_get_sockname(struct session *session)
265 {
266 	return &session->sockname.ip;
267 }
268 
session_tls_get_server_ctx(const struct session * session)269 struct tls_ctx *session_tls_get_server_ctx(const struct session *session)
270 {
271 	return session->tls_ctx;
272 }
273 
session_tls_set_server_ctx(struct session * session,struct tls_ctx * ctx)274 void session_tls_set_server_ctx(struct session *session, struct tls_ctx *ctx)
275 {
276 	session->tls_ctx = ctx;
277 }
278 
session_tls_get_client_ctx(const struct session * session)279 struct tls_client_ctx *session_tls_get_client_ctx(const struct session *session)
280 {
281 	return session->tls_client_ctx;
282 }
283 
session_tls_set_client_ctx(struct session * session,struct tls_client_ctx * ctx)284 void session_tls_set_client_ctx(struct session *session, struct tls_client_ctx *ctx)
285 {
286 	session->tls_client_ctx = ctx;
287 }
288 
session_tls_get_common_ctx(const struct session * session)289 struct tls_common_ctx *session_tls_get_common_ctx(const struct session *session)
290 {
291 	struct tls_common_ctx *tls_ctx = session->sflags.outgoing ? &session->tls_client_ctx->c :
292 								    &session->tls_ctx->c;
293 	return tls_ctx;
294 }
295 
296 #if ENABLE_DOH2
session_http_get_server_ctx(const struct session * session)297 struct http_ctx *session_http_get_server_ctx(const struct session *session)
298 {
299 	return session->http_ctx;
300 }
301 
session_http_set_server_ctx(struct session * session,struct http_ctx * ctx)302 void session_http_set_server_ctx(struct session *session, struct http_ctx *ctx)
303 {
304 	session->http_ctx = ctx;
305 }
306 #endif
307 
session_get_handle(struct session * session)308 uv_handle_t *session_get_handle(struct session *session)
309 {
310 	return session->handle;
311 }
312 
session_get(uv_handle_t * h)313 struct session *session_get(uv_handle_t *h)
314 {
315 	return h->data;
316 }
317 
/** Allocate and initialize a session bound to @a handle.
 *
 * @param handle    libuv handle (UV_TCP, UV_UDP, or UV_POLL for XDP); its
 *                  `data` pointer is set to the new session.
 * @param has_tls   reserve extra wire-buffer space for TLS record chunks.
 * @param has_http  reserve extra wire-buffer space for HTTP/2 frames (DoH).
 * @return the new session, or NULL on allocation/initialization failure.
 *
 * Ownership: the returned session owns its wire buffer only for TCP;
 * for UDP it aliases the worker's shared buffer.  Free with session_free().
 */
struct session *session_new(uv_handle_t *handle, bool has_tls, bool has_http)
{
	if (!handle) {
		return NULL;
	}
	struct session *session = calloc(1, sizeof(struct session));
	if (!session) {
		return NULL;
	}

	queue_init(session->waiting);
	session->tasks = trie_create(NULL);
	if (!session->tasks) {
		/* Allocation failure; previously unchecked and would have
		 * crashed later in trie operations. */
		queue_deinit(session->waiting);
		free(session);
		return NULL;
	}
	if (handle->type == UV_TCP) {
		size_t wire_buffer_size = KNOT_WIRE_MAX_PKTSIZE;
		if (has_tls) {
			/* When decoding large packets,
			 * gnutls gives the application chunks of size 16 kb each. */
			wire_buffer_size += TLS_CHUNK_SIZE;
			session->sflags.has_tls = true;
		}
#if ENABLE_DOH2
		if (has_http) {
			/* When decoding large packets,
			 * HTTP/2 frames can be up to 16 KB by default. */
			wire_buffer_size += HTTP_MAX_FRAME_SIZE;
			session->sflags.has_http = true;
		}
#endif
		uint8_t *wire_buf = malloc(wire_buffer_size);
		if (!wire_buf) {
			/* Also release trie and queue (the original leaked them). */
			trie_free(session->tasks);
			queue_deinit(session->waiting);
			free(session);
			return NULL;
		}
		session->wire_buf = wire_buf;
		session->wire_buf_size = wire_buffer_size;
	} else if (handle->type == UV_UDP) {
		/* We use the singleton buffer from worker for all UDP (!)
		 * libuv documentation doesn't really guarantee this is OK,
		 * but the implementation for unix systems does not hold
		 * the buffer (both UDP and TCP) - always makes a NON-blocking
		 * syscall that fills the buffer and immediately calls
		 * the callback, whatever the result of the operation.
		 * We still need to keep in mind to only touch the buffer
		 * in this callback... */
		kr_require(the_worker);
		session->wire_buf = the_worker->wire_buf;
		session->wire_buf_size = sizeof(the_worker->wire_buf);
	} else {
		kr_assert(handle->type == UV_POLL/*XDP*/);
		/* - wire_buf* are left zeroed, as they make no sense
		 * - timer is unused but OK for simplicity (server-side sessions are few)
		 */
	}

	if (uv_timer_init(handle->loop, &session->timeout) != 0) {
		/* Extremely unlikely, but unwind cleanly rather than return
		 * a session with an uninitialized timer. */
		if (handle->type == UV_TCP) {
			free(session->wire_buf);
		}
		trie_free(session->tasks);
		queue_deinit(session->waiting);
		free(session);
		return NULL;
	}

	session->handle = handle;
	handle->data = session;
	session->timeout.data = session;
	session_touch(session);

	return session;
}
381 
session_tasklist_get_len(const struct session * session)382 size_t session_tasklist_get_len(const struct session *session)
383 {
384 	return trie_weight(session->tasks);
385 }
386 
session_waitinglist_get_len(const struct session * session)387 size_t session_waitinglist_get_len(const struct session *session)
388 {
389 	return queue_len(session->waiting);
390 }
391 
session_tasklist_is_empty(const struct session * session)392 bool session_tasklist_is_empty(const struct session *session)
393 {
394 	return session_tasklist_get_len(session) == 0;
395 }
396 
session_waitinglist_is_empty(const struct session * session)397 bool session_waitinglist_is_empty(const struct session *session)
398 {
399 	return session_waitinglist_get_len(session) == 0;
400 }
401 
session_is_empty(const struct session * session)402 bool session_is_empty(const struct session *session)
403 {
404 	return session_tasklist_is_empty(session) &&
405 	       session_waitinglist_is_empty(session);
406 }
407 
session_has_tls(const struct session * session)408 bool session_has_tls(const struct session *session)
409 {
410 	return session->sflags.has_tls;
411 }
412 
session_set_has_tls(struct session * session,bool has_tls)413 void session_set_has_tls(struct session *session, bool has_tls)
414 {
415 	session->sflags.has_tls = has_tls;
416 }
417 
session_waitinglist_retry(struct session * session,bool increase_timeout_cnt)418 void session_waitinglist_retry(struct session *session, bool increase_timeout_cnt)
419 {
420 	while (!session_waitinglist_is_empty(session)) {
421 		struct qr_task *task = session_waitinglist_pop(session, false);
422 		if (increase_timeout_cnt) {
423 			worker_task_timeout_inc(task);
424 		}
425 		worker_task_step(task, &session->peer.ip, NULL);
426 		worker_task_unref(task);
427 	}
428 }
429 
session_waitinglist_finalize(struct session * session,int status)430 void session_waitinglist_finalize(struct session *session, int status)
431 {
432 	while (!session_waitinglist_is_empty(session)) {
433 		struct qr_task *t = session_waitinglist_pop(session, false);
434 		worker_task_finalize(t, status);
435 		worker_task_unref(t);
436 	}
437 }
438 
session_tasklist_finalize(struct session * session,int status)439 void session_tasklist_finalize(struct session *session, int status)
440 {
441 	while (session_tasklist_get_len(session) > 0) {
442 		struct qr_task *t = session_tasklist_del_first(session, false);
443 		kr_require(worker_task_numrefs(t) > 0);
444 		worker_task_finalize(t, status);
445 		worker_task_unref(t);
446 	}
447 }
448 
/** Finalize all tasks whose total resolution time exceeded
 * KR_RESOLVE_TIME_LIMIT; returns the number of tasks expired.
 *
 * Two passes: first collect expired tasks into a temporary queue (taking
 * a reference on each so they outlive trie deletion), then delete each
 * from the tasklist trie and finalize it if not yet finished. */
int session_tasklist_finalize_expired(struct session *session)
{
	int ret = 0;
	queue_t(struct qr_task *) q;
	uint64_t now = kr_now();
	trie_t *t = session->tasks;
	trie_it_t *it;
	queue_init(q);
	/* Pass 1: collect; don't mutate the trie while iterating it. */
	for (it = trie_it_begin(t); !trie_it_finished(it); trie_it_next(it)) {
		trie_val_t *v = trie_it_val(it);
		struct qr_task *task = (struct qr_task *)*v;
		if ((now - worker_task_creation_time(task)) >= KR_RESOLVE_TIME_LIMIT) {
			queue_push(q, task);
			worker_task_ref(task);
		}
	}
	trie_it_free(it);

	/* Trie key: message ID for outgoing sessions, task pointer otherwise
	 * (must mirror session_tasklist_add()). */
	struct qr_task *task = NULL;
	uint16_t msg_id = 0;
	char *key = (char *)&task;
	int32_t keylen = sizeof(struct qr_task *);
	if (session->sflags.outgoing) {
		key = (char *)&msg_id;
		keylen = sizeof(msg_id);
	}
	/* Pass 2: delete from the trie and finalize each collected task. */
	while (queue_len(q) > 0) {
		task = queue_head(q);
		if (session->sflags.outgoing) {
			knot_pkt_t *pktbuf = worker_task_get_pktbuf(task);
			msg_id = knot_wire_get_id(pktbuf->wire);
		}
		int res = trie_del(t, key, keylen, NULL);
		if (!worker_task_finished(task)) {
			/* task->pending_count must be zero,
			 * but there can be followers,
			 * so run worker_task_subreq_finalize() to ensure retrying
			 * for all the followers. */
			worker_task_subreq_finalize(task);
			worker_task_finalize(task, KR_STATE_FAIL);
		}
		if (res == KNOT_EOK) {
			/* Drop the tasklist's own reference. */
			worker_task_unref(task);
		}
		queue_pop(q);
		/* Drop the reference taken in pass 1. */
		worker_task_unref(task);
		++ret;
	}

	queue_deinit(q);
	return ret;
}
501 
session_timer_start(struct session * session,uv_timer_cb cb,uint64_t timeout,uint64_t repeat)502 int session_timer_start(struct session *session, uv_timer_cb cb,
503 			uint64_t timeout, uint64_t repeat)
504 {
505 	uv_timer_t *timer = &session->timeout;
506 	// Session might be closing and get here e.g. through a late on_send callback.
507 	const bool is_closing = uv_is_closing((uv_handle_t *)timer);
508 	if (is_closing || kr_fails_assert(is_closing == session->sflags.closing))
509 		return kr_error(EINVAL);
510 
511 	if (kr_fails_assert(timer->data == session))
512 		return kr_error(EINVAL);
513 	int ret = uv_timer_start(timer, cb, timeout, repeat);
514 	if (ret != 0) {
515 		uv_timer_stop(timer);
516 		return kr_error(ret);
517 	}
518 	return kr_ok();
519 }
520 
session_timer_restart(struct session * session)521 int session_timer_restart(struct session *session)
522 {
523 	kr_require(!uv_is_closing((uv_handle_t *)&session->timeout));
524 	return uv_timer_again(&session->timeout);
525 }
526 
session_timer_stop(struct session * session)527 int session_timer_stop(struct session *session)
528 {
529 	return uv_timer_stop(&session->timeout);
530 }
531 
/** Account for @a len bytes just written at the buffer's free area.
 * @a data must point exactly at the current end of buffered data.
 * Returns @a len, or kr_error(EINVAL) on inconsistent arguments. */
ssize_t session_wirebuf_consume(struct session *session, const uint8_t *data, ssize_t len)
{
	const uint8_t *expected = &session->wire_buf[session->wire_buf_end_idx];
	const bool sane = data == expected
		&& len >= 0
		&& session->wire_buf_end_idx + len <= session->wire_buf_size;
	if (!sane) {
		return kr_error(EINVAL); /* shouldn't happen */
	}
	session->wire_buf_end_idx += len;
	return len;
}
552 
/** Wrap the first complete DNS message in the wire buffer in a fresh
 * knot_pkt_t allocated from @a mm, or return NULL.
 *
 * NULL covers both benign "not enough data yet" cases and errors;
 * callers must distinguish them via session_wirebuf_error().
 * For TCP the 2-byte length prefix is validated and skipped;
 * for datagram transports all buffered data is one message. */
knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm)
{
	session->sflags.wirebuf_error = false;
	if (session->wire_buf_end_idx == 0) {
		/* Buffer empty — nothing to parse, not an error. */
		return NULL;
	}

	if (session->wire_buf_start_idx == session->wire_buf_end_idx) {
		/* All data consumed; reset offsets for the next read. */
		session->wire_buf_start_idx = 0;
		session->wire_buf_end_idx = 0;
		return NULL;
	}

	if (session->wire_buf_start_idx > session->wire_buf_end_idx) {
		/* Inconsistent offsets: flag error and reset. */
		session->sflags.wirebuf_error = true;
		session->wire_buf_start_idx = 0;
		session->wire_buf_end_idx = 0;
		return NULL;
	}

	const uv_handle_t *handle = session->handle;
	uint8_t *msg_start = &session->wire_buf[session->wire_buf_start_idx];
	ssize_t wirebuf_msg_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
	uint16_t msg_size = 0;

	if (!handle) {
		session->sflags.wirebuf_error = true;
		return NULL;
	} else if (handle->type == UV_TCP) {
		if (wirebuf_msg_data_size < 2) {
			/* Length prefix not fully received yet. */
			return NULL;
		}
		msg_size = knot_wire_read_u16(msg_start);
		if (msg_size >= session->wire_buf_size) {
			/* A message this large could never fit the buffer. */
			session->sflags.wirebuf_error = true;
			return NULL;
		}
		if (msg_size + 2 > wirebuf_msg_data_size) {
			/* Message body not fully received yet. */
			return NULL;
		}
		if (msg_size == 0) {
			session->sflags.wirebuf_error = true;
			return NULL;
		}
		msg_start += 2; /* skip the 2-byte length prefix */
	} else if (wirebuf_msg_data_size < UINT16_MAX) {
		msg_size = wirebuf_msg_data_size;
	} else {
		session->sflags.wirebuf_error = true;
		return NULL;
	}


	/* knot_pkt_new() wraps msg_start without copying; parse failure
	 * (NULL) is recorded as a wirebuf error. */
	knot_pkt_t *pkt = knot_pkt_new(msg_start, msg_size, mm);
	session->sflags.wirebuf_error = (pkt == NULL);
	return pkt;
}
610 
/** Mark the packet previously produced by session_produce_packet() as
 * consumed and advance wire_buf_start_idx past it.
 *
 * Cross-checks the packet against the buffer contents; on any mismatch
 * the buffer is reset, wirebuf_error stays set, and kr_error(EINVAL)
 * is returned.  Returns kr_ok() on success. */
int session_discard_packet(struct session *session, const knot_pkt_t *pkt)
{
	uv_handle_t *handle = session->handle;
	/* Pointer to data start in wire_buf */
	uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx];
	/* Number of data bytes in wire_buf */
	size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
	/* Pointer to message start in wire_buf */
	uint8_t *wirebuf_msg_start = wirebuf_data_start;
	/* Number of message bytes in wire_buf.
	 * For UDP it is the same number as wirebuf_data_size. */
	size_t wirebuf_msg_size = wirebuf_data_size;
	/* Wire data from parsed packet. */
	uint8_t *pkt_msg_start = pkt->wire;
	/* Number of bytes in packet wire buffer. */
	size_t pkt_msg_size = pkt->size;
	if (knot_pkt_has_tsig(pkt)) {
		pkt_msg_size += pkt->tsig_wire.len;
	}

	/* Assume failure; cleared again only after all checks pass. */
	session->sflags.wirebuf_error = true;

	if (!handle) {
		return kr_error(EINVAL);
	} else if (handle->type == UV_TCP) {
		/* wire_buf contains TCP DNS message. */
		if (kr_fails_assert(wirebuf_data_size >= 2)) {
			/* TCP message length field isn't in buffer, must not happen. */
			session->wire_buf_start_idx = 0;
			session->wire_buf_end_idx = 0;
			return kr_error(EINVAL);
		}
		wirebuf_msg_size = knot_wire_read_u16(wirebuf_msg_start);
		wirebuf_msg_start += 2;
		if (kr_fails_assert(wirebuf_msg_size + 2 <= wirebuf_data_size)) {
			/* TCP message length field is greater then
			 * number of bytes in buffer, must not happen. */
			session->wire_buf_start_idx = 0;
			session->wire_buf_end_idx = 0;
			return kr_error(EINVAL);
		}
	}

	if (kr_fails_assert(wirebuf_msg_start == pkt_msg_start)) {
		/* packet wirebuf must be located at the beginning
		 * of the session wirebuf, must not happen. */
		session->wire_buf_start_idx = 0;
		session->wire_buf_end_idx = 0;
		return kr_error(EINVAL);
	}

	if (kr_fails_assert(wirebuf_msg_size >= pkt_msg_size)) {
		/* Message length field is lesser then packet size,
		 * must not happen. */
		session->wire_buf_start_idx = 0;
		session->wire_buf_end_idx = 0;
		return kr_error(EINVAL);
	}

	/* Advance past the message: TCP skips the 2-byte prefix too. */
	if (handle->type == UV_TCP) {
		session->wire_buf_start_idx += wirebuf_msg_size + 2;
	} else {
		session->wire_buf_start_idx += pkt_msg_size;
	}
	session->sflags.wirebuf_error = false;

	/* Housekeeping: drop empty tails, compact short leftovers so a
	 * partial header always sits at the buffer's beginning. */
	wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
	if (wirebuf_data_size == 0) {
		session_wirebuf_discard(session);
	} else if (wirebuf_data_size < KNOT_WIRE_HEADER_SIZE) {
		session_wirebuf_compress(session);
	}

	return kr_ok();
}
686 
session_wirebuf_discard(struct session * session)687 void session_wirebuf_discard(struct session *session)
688 {
689 	session->wire_buf_start_idx = 0;
690 	session->wire_buf_end_idx = 0;
691 }
692 
session_wirebuf_compress(struct session * session)693 void session_wirebuf_compress(struct session *session)
694 {
695 	if (session->wire_buf_start_idx == 0) {
696 		return;
697 	}
698 	uint8_t *wirebuf_data_start = &session->wire_buf[session->wire_buf_start_idx];
699 	size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
700 	if (session->wire_buf_start_idx < wirebuf_data_size) {
701 		memmove(session->wire_buf, wirebuf_data_start, wirebuf_data_size);
702 	} else {
703 		memcpy(session->wire_buf, wirebuf_data_start, wirebuf_data_size);
704 	}
705 	session->wire_buf_start_idx = 0;
706 	session->wire_buf_end_idx = wirebuf_data_size;
707 }
708 
session_wirebuf_error(struct session * session)709 bool session_wirebuf_error(struct session *session)
710 {
711 	return session->sflags.wirebuf_error;
712 }
713 
session_wirebuf_get_start(struct session * session)714 uint8_t *session_wirebuf_get_start(struct session *session)
715 {
716 	return session->wire_buf;
717 }
718 
session_wirebuf_get_size(struct session * session)719 size_t session_wirebuf_get_size(struct session *session)
720 {
721 	return session->wire_buf_size;
722 }
723 
session_wirebuf_get_free_start(struct session * session)724 uint8_t *session_wirebuf_get_free_start(struct session *session)
725 {
726 	return &session->wire_buf[session->wire_buf_end_idx];
727 }
728 
session_wirebuf_get_free_size(struct session * session)729 size_t session_wirebuf_get_free_size(struct session *session)
730 {
731 	return session->wire_buf_size - session->wire_buf_end_idx;
732 }
733 
session_poison(struct session * session)734 void session_poison(struct session *session)
735 {
736 	kr_asan_poison(session, sizeof(*session));
737 }
738 
session_unpoison(struct session * session)739 void session_unpoison(struct session *session)
740 {
741 	kr_asan_unpoison(session, sizeof(*session));
742 }
743 
/** Parse and submit to the worker every complete DNS message currently in
 * the wire buffer.
 *
 * Returns the number of messages successfully submitted, or -1 on error
 * (wire-buffer inconsistency, or the session started closing as a side
 * effect of worker_submit()). */
int session_wirebuf_process(struct session *session, const struct sockaddr *peer)
{
	int ret = 0;
	if (session->wire_buf_start_idx == session->wire_buf_end_idx)
		return ret;

	size_t wirebuf_data_size = session->wire_buf_end_idx - session->wire_buf_start_idx;
	/* Upper bound on how many messages the buffered bytes could possibly
	 * contain — guards the loop against a pathological buffer state. */
	uint32_t max_iterations = (wirebuf_data_size /
		(KNOT_WIRE_HEADER_SIZE + KNOT_WIRE_QUESTION_MIN_SIZE)) + 1;
	knot_pkt_t *pkt = NULL;

	while (((pkt = session_produce_packet(session, &the_worker->pkt_pool)) != NULL) &&
	       (ret < max_iterations)) {
		if (kr_fails_assert(!session_wirebuf_error(session)))
			return -1;
		int res = worker_submit(session, peer, NULL, NULL, NULL, pkt);
		/* Errors from worker_submit() are intentionally *not* handled in order to
		 * ensure the entire wire buffer is processed. */
		if (res == kr_ok())
			ret += 1;
		if (session_discard_packet(session, pkt) < 0) {
			/* Packet data isn't stored in memory as expected.
			 * something went wrong, normally should not happen. */
			break;
		}
	}

	/* worker_submit() may cause the session to close (e.g. due to IO
	 * write error when the packet triggers an immediate answer). This is
	 * an error state, as well as any wirebuf error. */
	if (session->sflags.closing || session_wirebuf_error(session))
		ret = -1;

	return ret;
}
779 
session_kill_ioreq(struct session * session,struct qr_task * task)780 void session_kill_ioreq(struct session *session, struct qr_task *task)
781 {
782 	if (!session || session->sflags.closing)
783 		return;
784 	if (kr_fails_assert(session->sflags.outgoing && session->handle))
785 		return;
786 	session_tasklist_del(session, task);
787 	if (session->handle->type == UV_UDP) {
788 		session_close(session);
789 		return;
790 	}
791 }
792 
793 /** Update timestamp */
session_touch(struct session * session)794 void session_touch(struct session *session)
795 {
796 	session->last_activity = kr_now();
797 }
798 
session_last_activity(struct session * session)799 uint64_t session_last_activity(struct session *session)
800 {
801 	return session->last_activity;
802 }
803