xref: /freebsd/sbin/hastd/proto_socketpair.c (revision 3494f7c0)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  * All rights reserved.
6  *
7  * This software was developed by Pawel Jakub Dawidek under sponsorship from
8  * the FreeBSD Foundation.
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29  * SUCH DAMAGE.
30  */
31 
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
41 
42 #include "pjdlog.h"
43 #include "proto_impl.h"
44 
/*
 * Which end of the socketpair a context uses.  A fresh context is
 * undefined; the first sp_send() or sp_recv() on it settles the side.
 */
#define	SP_SIDE_UNDEF		0
#define	SP_SIDE_CLIENT		1
#define	SP_SIDE_SERVER		2

/* Value held in sp_magic for as long as a context is valid. */
#define	SP_CTX_MAGIC	0x50c3741

/* Per-connection state of the socketpair protocol. */
struct sp_ctx {
	/* SP_CTX_MAGIC while valid; reset to 0 by sp_close(). */
	int			sp_magic;
	/* Both descriptors of the pair: [0] client end, [1] server end. */
	int			sp_fd[2];
	/* One of the SP_SIDE_* values above. */
	int			sp_side;
};

/* Prototype only; the definition appears later in this file. */
static void sp_close(void *ctx);
56 
57 static int
58 sp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
59 {
60 	struct sp_ctx *spctx;
61 	int ret;
62 
63 	if (strcmp(dstaddr, "socketpair://") != 0)
64 		return (-1);
65 
66 	PJDLOG_ASSERT(srcaddr == NULL);
67 
68 	spctx = malloc(sizeof(*spctx));
69 	if (spctx == NULL)
70 		return (errno);
71 
72 	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
73 		ret = errno;
74 		free(spctx);
75 		return (ret);
76 	}
77 
78 	spctx->sp_side = SP_SIDE_UNDEF;
79 	spctx->sp_magic = SP_CTX_MAGIC;
80 	*ctxp = spctx;
81 
82 	return (0);
83 }
84 
85 static int
86 sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
87 {
88 	struct sp_ctx *spctx = ctx;
89 	int sock;
90 
91 	PJDLOG_ASSERT(spctx != NULL);
92 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
93 
94 	switch (spctx->sp_side) {
95 	case SP_SIDE_UNDEF:
96 		/*
97 		 * If the first operation done by the caller is proto_send(),
98 		 * we assume this is the client.
99 		 */
100 		/* FALLTHROUGH */
101 		spctx->sp_side = SP_SIDE_CLIENT;
102 		/* Close other end. */
103 		close(spctx->sp_fd[1]);
104 		spctx->sp_fd[1] = -1;
105 	case SP_SIDE_CLIENT:
106 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
107 		sock = spctx->sp_fd[0];
108 		break;
109 	case SP_SIDE_SERVER:
110 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
111 		sock = spctx->sp_fd[1];
112 		break;
113 	default:
114 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
115 	}
116 
117 	/* Someone is just trying to decide about side. */
118 	if (data == NULL)
119 		return (0);
120 
121 	return (proto_common_send(sock, data, size, fd));
122 }
123 
124 static int
125 sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
126 {
127 	struct sp_ctx *spctx = ctx;
128 	int fd;
129 
130 	PJDLOG_ASSERT(spctx != NULL);
131 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
132 
133 	switch (spctx->sp_side) {
134 	case SP_SIDE_UNDEF:
135 		/*
136 		 * If the first operation done by the caller is proto_recv(),
137 		 * we assume this is the server.
138 		 */
139 		/* FALLTHROUGH */
140 		spctx->sp_side = SP_SIDE_SERVER;
141 		/* Close other end. */
142 		close(spctx->sp_fd[0]);
143 		spctx->sp_fd[0] = -1;
144 	case SP_SIDE_SERVER:
145 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
146 		fd = spctx->sp_fd[1];
147 		break;
148 	case SP_SIDE_CLIENT:
149 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
150 		fd = spctx->sp_fd[0];
151 		break;
152 	default:
153 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
154 	}
155 
156 	/* Someone is just trying to decide about side. */
157 	if (data == NULL)
158 		return (0);
159 
160 	return (proto_common_recv(fd, data, size, fdp));
161 }
162 
163 static int
164 sp_descriptor(const void *ctx)
165 {
166 	const struct sp_ctx *spctx = ctx;
167 
168 	PJDLOG_ASSERT(spctx != NULL);
169 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
170 	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
171 	    spctx->sp_side == SP_SIDE_SERVER);
172 
173 	switch (spctx->sp_side) {
174 	case SP_SIDE_CLIENT:
175 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
176 		return (spctx->sp_fd[0]);
177 	case SP_SIDE_SERVER:
178 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
179 		return (spctx->sp_fd[1]);
180 	}
181 
182 	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
183 }
184 
185 static void
186 sp_close(void *ctx)
187 {
188 	struct sp_ctx *spctx = ctx;
189 
190 	PJDLOG_ASSERT(spctx != NULL);
191 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
192 
193 	switch (spctx->sp_side) {
194 	case SP_SIDE_UNDEF:
195 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
196 		close(spctx->sp_fd[0]);
197 		spctx->sp_fd[0] = -1;
198 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
199 		close(spctx->sp_fd[1]);
200 		spctx->sp_fd[1] = -1;
201 		break;
202 	case SP_SIDE_CLIENT:
203 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
204 		close(spctx->sp_fd[0]);
205 		spctx->sp_fd[0] = -1;
206 		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
207 		break;
208 	case SP_SIDE_SERVER:
209 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
210 		close(spctx->sp_fd[1]);
211 		spctx->sp_fd[1] = -1;
212 		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
213 		break;
214 	default:
215 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
216 	}
217 
218 	spctx->sp_magic = 0;
219 	free(spctx);
220 }
221 
222 static struct proto sp_proto = {
223 	.prt_name = "socketpair",
224 	.prt_client = sp_client,
225 	.prt_send = sp_send,
226 	.prt_recv = sp_recv,
227 	.prt_descriptor = sp_descriptor,
228 	.prt_close = sp_close
229 };
230 
231 static __constructor void
232 sp_ctor(void)
233 {
234 
235 	proto_register(&sp_proto, false);
236 }
237