xref: /netbsd/external/mpl/bind/dist/lib/dns/tcpmsg.c (revision c0b5d9fb)
1 /*	$NetBSD: tcpmsg.c,v 1.6 2022/09/23 12:15:30 christos Exp $	*/
2 
3 /*
4  * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
5  *
6  * SPDX-License-Identifier: MPL-2.0
7  *
8  * This Source Code Form is subject to the terms of the Mozilla Public
9  * License, v. 2.0.  If a copy of the MPL was not distributed with this
10  * file, you can obtain one at https://mozilla.org/MPL/2.0/.
11  *
12  * See the COPYRIGHT file distributed with this work for additional
13  * information regarding copyright ownership.
14  */
15 
16 /*! \file */
17 
18 #include <inttypes.h>
19 
20 #include <isc/mem.h>
21 #include <isc/print.h>
22 #include <isc/task.h>
23 #include <isc/util.h>
24 
25 #include <dns/events.h>
26 #include <dns/result.h>
27 #include <dns/tcpmsg.h>
28 
29 #ifdef TCPMSG_DEBUG
30 #include <stdio.h> /* Required for printf. */
31 #define XDEBUG(x) printf x
32 #else /* ifdef TCPMSG_DEBUG */
33 #define XDEBUG(x)
34 #endif /* ifdef TCPMSG_DEBUG */
35 
36 #define TCPMSG_MAGIC	  ISC_MAGIC('T', 'C', 'P', 'm')
37 #define VALID_TCPMSG(foo) ISC_MAGIC_VALID(foo, TCPMSG_MAGIC)
38 
39 static void
40 recv_length(isc_task_t *, isc_event_t *);
41 static void
42 recv_message(isc_task_t *, isc_event_t *);
43 
44 static void
recv_length(isc_task_t * task,isc_event_t * ev_in)45 recv_length(isc_task_t *task, isc_event_t *ev_in) {
46 	isc_socketevent_t *ev = (isc_socketevent_t *)ev_in;
47 	isc_event_t *dev;
48 	dns_tcpmsg_t *tcpmsg = ev_in->ev_arg;
49 	isc_region_t region;
50 	isc_result_t result;
51 
52 	INSIST(VALID_TCPMSG(tcpmsg));
53 
54 	dev = &tcpmsg->event;
55 	tcpmsg->address = ev->address;
56 
57 	if (ev->result != ISC_R_SUCCESS) {
58 		tcpmsg->result = ev->result;
59 		goto send_and_free;
60 	}
61 
62 	/*
63 	 * Success.
64 	 */
65 	tcpmsg->size = ntohs(tcpmsg->size);
66 	if (tcpmsg->size == 0) {
67 		tcpmsg->result = ISC_R_UNEXPECTEDEND;
68 		goto send_and_free;
69 	}
70 	if (tcpmsg->size > tcpmsg->maxsize) {
71 		tcpmsg->result = ISC_R_RANGE;
72 		goto send_and_free;
73 	}
74 
75 	region.base = isc_mem_get(tcpmsg->mctx, tcpmsg->size);
76 	region.length = tcpmsg->size;
77 	if (region.base == NULL) {
78 		tcpmsg->result = ISC_R_NOMEMORY;
79 		goto send_and_free;
80 	}
81 	XDEBUG(("Allocated %d bytes\n", tcpmsg->size));
82 
83 	isc_buffer_init(&tcpmsg->buffer, region.base, region.length);
84 	result = isc_socket_recv(tcpmsg->sock, &region, 0, task, recv_message,
85 				 tcpmsg);
86 	if (result != ISC_R_SUCCESS) {
87 		tcpmsg->result = result;
88 		goto send_and_free;
89 	}
90 
91 	isc_event_free(&ev_in);
92 	return;
93 
94 send_and_free:
95 	isc_task_send(tcpmsg->task, &dev);
96 	tcpmsg->task = NULL;
97 	isc_event_free(&ev_in);
98 	return;
99 }
100 
101 static void
recv_message(isc_task_t * task,isc_event_t * ev_in)102 recv_message(isc_task_t *task, isc_event_t *ev_in) {
103 	isc_socketevent_t *ev = (isc_socketevent_t *)ev_in;
104 	isc_event_t *dev;
105 	dns_tcpmsg_t *tcpmsg = ev_in->ev_arg;
106 
107 	(void)task;
108 
109 	INSIST(VALID_TCPMSG(tcpmsg));
110 
111 	dev = &tcpmsg->event;
112 	tcpmsg->address = ev->address;
113 
114 	if (ev->result != ISC_R_SUCCESS) {
115 		tcpmsg->result = ev->result;
116 		goto send_and_free;
117 	}
118 
119 	tcpmsg->result = ISC_R_SUCCESS;
120 	isc_buffer_add(&tcpmsg->buffer, ev->n);
121 
122 	XDEBUG(("Received %u bytes (of %d)\n", ev->n, tcpmsg->size));
123 
124 send_and_free:
125 	isc_task_send(tcpmsg->task, &dev);
126 	tcpmsg->task = NULL;
127 	isc_event_free(&ev_in);
128 }
129 
130 void
dns_tcpmsg_init(isc_mem_t * mctx,isc_socket_t * sock,dns_tcpmsg_t * tcpmsg)131 dns_tcpmsg_init(isc_mem_t *mctx, isc_socket_t *sock, dns_tcpmsg_t *tcpmsg) {
132 	REQUIRE(mctx != NULL);
133 	REQUIRE(sock != NULL);
134 	REQUIRE(tcpmsg != NULL);
135 
136 	tcpmsg->magic = TCPMSG_MAGIC;
137 	tcpmsg->size = 0;
138 	tcpmsg->buffer.base = NULL;
139 	tcpmsg->buffer.length = 0;
140 	tcpmsg->maxsize = 65535; /* Largest message possible. */
141 	tcpmsg->mctx = mctx;
142 	tcpmsg->sock = sock;
143 	tcpmsg->task = NULL;		   /* None yet. */
144 	tcpmsg->result = ISC_R_UNEXPECTED; /* None yet. */
145 
146 	/* Should probably initialize the event here, but it can wait. */
147 }
148 
149 void
dns_tcpmsg_setmaxsize(dns_tcpmsg_t * tcpmsg,unsigned int maxsize)150 dns_tcpmsg_setmaxsize(dns_tcpmsg_t *tcpmsg, unsigned int maxsize) {
151 	REQUIRE(VALID_TCPMSG(tcpmsg));
152 	REQUIRE(maxsize < 65536);
153 
154 	tcpmsg->maxsize = maxsize;
155 }
156 
157 isc_result_t
dns_tcpmsg_readmessage(dns_tcpmsg_t * tcpmsg,isc_task_t * task,isc_taskaction_t action,void * arg)158 dns_tcpmsg_readmessage(dns_tcpmsg_t *tcpmsg, isc_task_t *task,
159 		       isc_taskaction_t action, void *arg) {
160 	isc_result_t result;
161 	isc_region_t region;
162 
163 	REQUIRE(VALID_TCPMSG(tcpmsg));
164 	REQUIRE(task != NULL);
165 	REQUIRE(tcpmsg->task == NULL); /* not currently in use */
166 
167 	if (tcpmsg->buffer.base != NULL) {
168 		isc_mem_put(tcpmsg->mctx, tcpmsg->buffer.base,
169 			    tcpmsg->buffer.length);
170 		tcpmsg->buffer.base = NULL;
171 		tcpmsg->buffer.length = 0;
172 	}
173 
174 	tcpmsg->task = task;
175 	tcpmsg->action = action;
176 	tcpmsg->arg = arg;
177 	tcpmsg->result = ISC_R_UNEXPECTED; /* unknown right now */
178 
179 	ISC_EVENT_INIT(&tcpmsg->event, sizeof(isc_event_t), 0, 0,
180 		       DNS_EVENT_TCPMSG, action, arg, tcpmsg, NULL, NULL);
181 
182 	region.base = (unsigned char *)&tcpmsg->size;
183 	region.length = 2; /* uint16_t */
184 	result = isc_socket_recv(tcpmsg->sock, &region, 0, tcpmsg->task,
185 				 recv_length, tcpmsg);
186 
187 	if (result != ISC_R_SUCCESS) {
188 		tcpmsg->task = NULL;
189 	}
190 
191 	return (result);
192 }
193 
194 void
dns_tcpmsg_cancelread(dns_tcpmsg_t * tcpmsg)195 dns_tcpmsg_cancelread(dns_tcpmsg_t *tcpmsg) {
196 	REQUIRE(VALID_TCPMSG(tcpmsg));
197 
198 	isc_socket_cancel(tcpmsg->sock, NULL, ISC_SOCKCANCEL_RECV);
199 }
200 
201 void
dns_tcpmsg_keepbuffer(dns_tcpmsg_t * tcpmsg,isc_buffer_t * buffer)202 dns_tcpmsg_keepbuffer(dns_tcpmsg_t *tcpmsg, isc_buffer_t *buffer) {
203 	REQUIRE(VALID_TCPMSG(tcpmsg));
204 	REQUIRE(buffer != NULL);
205 
206 	*buffer = tcpmsg->buffer;
207 	tcpmsg->buffer.base = NULL;
208 	tcpmsg->buffer.length = 0;
209 }
210 
#if 0
/*
 * Disabled code: return the held message buffer, if any, to
 * tcpmsg->mctx and clear tcpmsg's reference to it.
 */
void
dns_tcpmsg_freebuffer(dns_tcpmsg_t *tcpmsg) {
	REQUIRE(VALID_TCPMSG(tcpmsg));

	if (tcpmsg->buffer.base != NULL) {
		isc_mem_put(tcpmsg->mctx, tcpmsg->buffer.base,
			    tcpmsg->buffer.length);
		tcpmsg->buffer.base = NULL;
		tcpmsg->buffer.length = 0;
	}
}
#endif /* if 0 */
225 
226 void
dns_tcpmsg_invalidate(dns_tcpmsg_t * tcpmsg)227 dns_tcpmsg_invalidate(dns_tcpmsg_t *tcpmsg) {
228 	REQUIRE(VALID_TCPMSG(tcpmsg));
229 
230 	tcpmsg->magic = 0;
231 
232 	if (tcpmsg->buffer.base != NULL) {
233 		isc_mem_put(tcpmsg->mctx, tcpmsg->buffer.base,
234 			    tcpmsg->buffer.length);
235 		tcpmsg->buffer.base = NULL;
236 		tcpmsg->buffer.length = 0;
237 	}
238 }
239