xref: /freebsd/sbin/hastd/proto.c (revision 069ac184)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  * All rights reserved.
6  *
7  * This software was developed by Pawel Jakub Dawidek under sponsorship from
8  * the FreeBSD Foundation.
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29  * SUCH DAMAGE.
30  */
31 
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
44 
/*
 * Value stored in pc_magic of every handle created by proto_alloc(); every
 * entry point in this file asserts it.  proto_free() zeroes the structure,
 * so a released handle no longer carries it.
 */
#define	PROTO_CONN_MAGIC	0x907041c

/* Role of a connection handle, stored in proto_conn.pc_side. */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/* Protocol-independent connection handle. */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* owning protocol implementation */
	void		*pc_ctx;	/* protocol-private context */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registered protocols.  proto_register() puts ordinary protocols at the
 * head and the (single) default protocol at the tail, so the default is
 * tried last when an address is being matched.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
57 
58 void
59 proto_register(struct proto *proto, bool isdefault)
60 {
61 	static bool seen_default = false;
62 
63 	if (!isdefault)
64 		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
65 	else {
66 		PJDLOG_ASSERT(!seen_default);
67 		seen_default = true;
68 		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
69 	}
70 }
71 
72 static struct proto_conn *
73 proto_alloc(struct proto *proto, int side)
74 {
75 	struct proto_conn *conn;
76 
77 	PJDLOG_ASSERT(proto != NULL);
78 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
79 	    side == PROTO_SIDE_SERVER_LISTEN ||
80 	    side == PROTO_SIDE_SERVER_WORK);
81 
82 	conn = malloc(sizeof(*conn));
83 	if (conn != NULL) {
84 		conn->pc_proto = proto;
85 		conn->pc_side = side;
86 		conn->pc_magic = PROTO_CONN_MAGIC;
87 	}
88 	return (conn);
89 }
90 
91 static void
92 proto_free(struct proto_conn *conn)
93 {
94 
95 	PJDLOG_ASSERT(conn != NULL);
96 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
97 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
98 	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
99 	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
100 	PJDLOG_ASSERT(conn->pc_proto != NULL);
101 
102 	bzero(conn, sizeof(*conn));
103 	free(conn);
104 }
105 
106 static int
107 proto_common_setup(const char *srcaddr, const char *dstaddr,
108     struct proto_conn **connp, int side)
109 {
110 	struct proto *proto;
111 	struct proto_conn *conn;
112 	void *ctx;
113 	int ret;
114 
115 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
116 	    side == PROTO_SIDE_SERVER_LISTEN);
117 
118 	TAILQ_FOREACH(proto, &protos, prt_next) {
119 		if (side == PROTO_SIDE_CLIENT) {
120 			if (proto->prt_client == NULL)
121 				ret = -1;
122 			else
123 				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
124 		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
125 			if (proto->prt_server == NULL)
126 				ret = -1;
127 			else
128 				ret = proto->prt_server(dstaddr, &ctx);
129 		}
130 		/*
131 		 * ret == 0  - success
132 		 * ret == -1 - dstaddr is not for this protocol
133 		 * ret > 0   - right protocol, but an error occurred
134 		 */
135 		if (ret >= 0)
136 			break;
137 	}
138 	if (proto == NULL) {
139 		/* Unrecognized address. */
140 		errno = EINVAL;
141 		return (-1);
142 	}
143 	if (ret > 0) {
144 		/* An error occurred. */
145 		errno = ret;
146 		return (-1);
147 	}
148 	conn = proto_alloc(proto, side);
149 	if (conn == NULL) {
150 		if (proto->prt_close != NULL)
151 			proto->prt_close(ctx);
152 		errno = ENOMEM;
153 		return (-1);
154 	}
155 	conn->pc_ctx = ctx;
156 	*connp = conn;
157 
158 	return (0);
159 }
160 
161 int
162 proto_client(const char *srcaddr, const char *dstaddr,
163     struct proto_conn **connp)
164 {
165 
166 	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
167 }
168 
169 int
170 proto_connect(struct proto_conn *conn, int timeout)
171 {
172 	int ret;
173 
174 	PJDLOG_ASSERT(conn != NULL);
175 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
176 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
177 	PJDLOG_ASSERT(conn->pc_proto != NULL);
178 	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
179 	PJDLOG_ASSERT(timeout >= -1);
180 
181 	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
182 	if (ret != 0) {
183 		errno = ret;
184 		return (-1);
185 	}
186 
187 	return (0);
188 }
189 
190 int
191 proto_connect_wait(struct proto_conn *conn, int timeout)
192 {
193 	int ret;
194 
195 	PJDLOG_ASSERT(conn != NULL);
196 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
197 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
198 	PJDLOG_ASSERT(conn->pc_proto != NULL);
199 	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
200 	PJDLOG_ASSERT(timeout >= 0);
201 
202 	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
203 	if (ret != 0) {
204 		errno = ret;
205 		return (-1);
206 	}
207 
208 	return (0);
209 }
210 
211 int
212 proto_server(const char *addr, struct proto_conn **connp)
213 {
214 
215 	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
216 }
217 
218 int
219 proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
220 {
221 	struct proto_conn *newconn;
222 	int ret;
223 
224 	PJDLOG_ASSERT(conn != NULL);
225 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
226 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
227 	PJDLOG_ASSERT(conn->pc_proto != NULL);
228 	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
229 
230 	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
231 	if (newconn == NULL)
232 		return (-1);
233 
234 	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
235 	if (ret != 0) {
236 		proto_free(newconn);
237 		errno = ret;
238 		return (-1);
239 	}
240 
241 	*newconnp = newconn;
242 
243 	return (0);
244 }
245 
246 int
247 proto_send(const struct proto_conn *conn, const void *data, size_t size)
248 {
249 	int ret;
250 
251 	PJDLOG_ASSERT(conn != NULL);
252 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
253 	PJDLOG_ASSERT(conn->pc_proto != NULL);
254 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
255 
256 	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
257 	if (ret != 0) {
258 		errno = ret;
259 		return (-1);
260 	}
261 	return (0);
262 }
263 
264 int
265 proto_recv(const struct proto_conn *conn, void *data, size_t size)
266 {
267 	int ret;
268 
269 	PJDLOG_ASSERT(conn != NULL);
270 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
271 	PJDLOG_ASSERT(conn->pc_proto != NULL);
272 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
273 
274 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
275 	if (ret != 0) {
276 		errno = ret;
277 		return (-1);
278 	}
279 	return (0);
280 }
281 
282 int
283 proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
284 {
285 	const char *protoname;
286 	int ret, fd;
287 
288 	PJDLOG_ASSERT(conn != NULL);
289 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
290 	PJDLOG_ASSERT(conn->pc_proto != NULL);
291 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
292 	PJDLOG_ASSERT(mconn != NULL);
293 	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
294 	PJDLOG_ASSERT(mconn->pc_proto != NULL);
295 	fd = proto_descriptor(mconn);
296 	PJDLOG_ASSERT(fd >= 0);
297 	protoname = mconn->pc_proto->prt_name;
298 	PJDLOG_ASSERT(protoname != NULL);
299 
300 	ret = conn->pc_proto->prt_send(conn->pc_ctx,
301 	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
302 	proto_close(mconn);
303 	if (ret != 0) {
304 		errno = ret;
305 		return (-1);
306 	}
307 	return (0);
308 }
309 
310 int
311 proto_connection_recv(const struct proto_conn *conn, bool client,
312     struct proto_conn **newconnp)
313 {
314 	char protoname[128];
315 	struct proto *proto;
316 	struct proto_conn *newconn;
317 	int ret, fd;
318 
319 	PJDLOG_ASSERT(conn != NULL);
320 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
321 	PJDLOG_ASSERT(conn->pc_proto != NULL);
322 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
323 	PJDLOG_ASSERT(newconnp != NULL);
324 
325 	bzero(protoname, sizeof(protoname));
326 
327 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
328 	    sizeof(protoname) - 1, &fd);
329 	if (ret != 0) {
330 		errno = ret;
331 		return (-1);
332 	}
333 
334 	PJDLOG_ASSERT(fd >= 0);
335 
336 	TAILQ_FOREACH(proto, &protos, prt_next) {
337 		if (strcmp(proto->prt_name, protoname) == 0)
338 			break;
339 	}
340 	if (proto == NULL) {
341 		errno = EINVAL;
342 		return (-1);
343 	}
344 
345 	newconn = proto_alloc(proto,
346 	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
347 	if (newconn == NULL)
348 		return (-1);
349 	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
350 	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
351 	if (ret != 0) {
352 		proto_free(newconn);
353 		errno = ret;
354 		return (-1);
355 	}
356 
357 	*newconnp = newconn;
358 
359 	return (0);
360 }
361 
362 int
363 proto_descriptor(const struct proto_conn *conn)
364 {
365 
366 	PJDLOG_ASSERT(conn != NULL);
367 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
368 	PJDLOG_ASSERT(conn->pc_proto != NULL);
369 	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
370 
371 	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
372 }
373 
374 bool
375 proto_address_match(const struct proto_conn *conn, const char *addr)
376 {
377 
378 	PJDLOG_ASSERT(conn != NULL);
379 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
380 	PJDLOG_ASSERT(conn->pc_proto != NULL);
381 	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
382 
383 	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
384 }
385 
386 void
387 proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
388 {
389 
390 	PJDLOG_ASSERT(conn != NULL);
391 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
392 	PJDLOG_ASSERT(conn->pc_proto != NULL);
393 	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
394 
395 	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
396 }
397 
398 void
399 proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
400 {
401 
402 	PJDLOG_ASSERT(conn != NULL);
403 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
404 	PJDLOG_ASSERT(conn->pc_proto != NULL);
405 	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
406 
407 	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
408 }
409 
410 int
411 proto_timeout(const struct proto_conn *conn, int timeout)
412 {
413 	struct timeval tv;
414 	int fd;
415 
416 	PJDLOG_ASSERT(conn != NULL);
417 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
418 	PJDLOG_ASSERT(conn->pc_proto != NULL);
419 
420 	fd = proto_descriptor(conn);
421 	if (fd == -1)
422 		return (-1);
423 
424 	tv.tv_sec = timeout;
425 	tv.tv_usec = 0;
426 	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
427 		return (-1);
428 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
429 		return (-1);
430 
431 	return (0);
432 }
433 
434 void
435 proto_close(struct proto_conn *conn)
436 {
437 
438 	PJDLOG_ASSERT(conn != NULL);
439 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
440 	PJDLOG_ASSERT(conn->pc_proto != NULL);
441 	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
442 
443 	conn->pc_proto->prt_close(conn->pc_ctx);
444 	proto_free(conn);
445 }
446