1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto_impl.h"
42 
/* Stamped into every live context; sp_close() resets it to 0. */
#define	SP_CTX_MAGIC	0x50c3741

/*
 * Which end of the socketpair a context owns.  A context made by
 * sp_connect() starts as SP_SIDE_UNDEF holding both descriptors; the first
 * sp_send()/sp_recv() picks a side and closes the other end.
 */
enum {
	SP_SIDE_UNDEF = 0,
	SP_SIDE_CLIENT = 1,
	SP_SIDE_SERVER = 2
};

/* Per-connection state of the socketpair:// transport. */
struct sp_ctx {
	int	sp_magic;	/* SP_CTX_MAGIC while the context is valid. */
	int	sp_fd[2];	/* [0] = client end, [1] = server end, -1 = closed. */
	int	sp_side;	/* One of the SP_SIDE_* values. */
};

static void sp_close(void *ctx);
54 
55 static int
56 sp_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
57 {
58 	struct sp_ctx *spctx;
59 	int error;
60 
61 	PJDLOG_ASSERT(dstaddr != NULL);
62 	PJDLOG_ASSERT(timeout >= -1);
63 
64 	if (strcmp(dstaddr, "socketpair://") != 0)
65 		return (-1);
66 
67 	PJDLOG_ASSERT(srcaddr == NULL);
68 
69 	spctx = malloc(sizeof(*spctx));
70 	if (spctx == NULL)
71 		return (errno);
72 
73 	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
74 		error = errno;
75 		free(spctx);
76 		return (error);
77 	}
78 
79 	spctx->sp_side = SP_SIDE_UNDEF;
80 	spctx->sp_magic = SP_CTX_MAGIC;
81 	*ctxp = spctx;
82 
83 	return (0);
84 }
85 
86 static int
87 sp_wrap(int fd, bool client, void **ctxp)
88 {
89 	struct sp_ctx *spctx;
90 
91 	PJDLOG_ASSERT(fd >= 0);
92 
93 	spctx = malloc(sizeof(*spctx));
94 	if (spctx == NULL)
95 		return (errno);
96 
97 	if (client) {
98 		spctx->sp_side = SP_SIDE_CLIENT;
99 		spctx->sp_fd[0] = fd;
100 		spctx->sp_fd[1] = -1;
101 	} else {
102 		spctx->sp_side = SP_SIDE_SERVER;
103 		spctx->sp_fd[0] = -1;
104 		spctx->sp_fd[1] = fd;
105 	}
106 	spctx->sp_magic = SP_CTX_MAGIC;
107 	*ctxp = spctx;
108 
109 	return (0);
110 }
111 
112 static int
113 sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
114 {
115 	struct sp_ctx *spctx = ctx;
116 	int sock;
117 
118 	PJDLOG_ASSERT(spctx != NULL);
119 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
120 
121 	switch (spctx->sp_side) {
122 	case SP_SIDE_UNDEF:
123 		/*
124 		 * If the first operation done by the caller is proto_send(),
125 		 * we assume this is the client.
126 		 */
127 		/* FALLTHROUGH */
128 		spctx->sp_side = SP_SIDE_CLIENT;
129 		/* Close other end. */
130 		close(spctx->sp_fd[1]);
131 		spctx->sp_fd[1] = -1;
132 	case SP_SIDE_CLIENT:
133 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
134 		sock = spctx->sp_fd[0];
135 		break;
136 	case SP_SIDE_SERVER:
137 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
138 		sock = spctx->sp_fd[1];
139 		break;
140 	default:
141 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
142 	}
143 
144 	/* Someone is just trying to decide about side. */
145 	if (data == NULL)
146 		return (0);
147 
148 	return (proto_common_send(sock, data, size, fd));
149 }
150 
151 static int
152 sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
153 {
154 	struct sp_ctx *spctx = ctx;
155 	int sock;
156 
157 	PJDLOG_ASSERT(spctx != NULL);
158 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
159 
160 	switch (spctx->sp_side) {
161 	case SP_SIDE_UNDEF:
162 		/*
163 		 * If the first operation done by the caller is proto_recv(),
164 		 * we assume this is the server.
165 		 */
166 		/* FALLTHROUGH */
167 		spctx->sp_side = SP_SIDE_SERVER;
168 		/* Close other end. */
169 		close(spctx->sp_fd[0]);
170 		spctx->sp_fd[0] = -1;
171 	case SP_SIDE_SERVER:
172 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
173 		sock = spctx->sp_fd[1];
174 		break;
175 	case SP_SIDE_CLIENT:
176 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
177 		sock = spctx->sp_fd[0];
178 		break;
179 	default:
180 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
181 	}
182 
183 	/* Someone is just trying to decide about side. */
184 	if (data == NULL)
185 		return (0);
186 
187 	return (proto_common_recv(sock, data, size, fdp));
188 }
189 
190 static int
191 sp_descriptor(const void *ctx)
192 {
193 	const struct sp_ctx *spctx = ctx;
194 
195 	PJDLOG_ASSERT(spctx != NULL);
196 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
197 	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
198 	    spctx->sp_side == SP_SIDE_SERVER);
199 
200 	switch (spctx->sp_side) {
201 	case SP_SIDE_CLIENT:
202 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
203 		return (spctx->sp_fd[0]);
204 	case SP_SIDE_SERVER:
205 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
206 		return (spctx->sp_fd[1]);
207 	}
208 
209 	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
210 }
211 
212 static void
213 sp_close(void *ctx)
214 {
215 	struct sp_ctx *spctx = ctx;
216 
217 	PJDLOG_ASSERT(spctx != NULL);
218 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
219 
220 	switch (spctx->sp_side) {
221 	case SP_SIDE_UNDEF:
222 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
223 		close(spctx->sp_fd[0]);
224 		spctx->sp_fd[0] = -1;
225 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
226 		close(spctx->sp_fd[1]);
227 		spctx->sp_fd[1] = -1;
228 		break;
229 	case SP_SIDE_CLIENT:
230 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
231 		close(spctx->sp_fd[0]);
232 		spctx->sp_fd[0] = -1;
233 		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
234 		break;
235 	case SP_SIDE_SERVER:
236 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
237 		close(spctx->sp_fd[1]);
238 		spctx->sp_fd[1] = -1;
239 		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
240 		break;
241 	default:
242 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
243 	}
244 
245 	spctx->sp_magic = 0;
246 	free(spctx);
247 }
248 
249 static struct proto sp_proto = {
250 	.prt_name = "socketpair",
251 	.prt_connect = sp_connect,
252 	.prt_wrap = sp_wrap,
253 	.prt_send = sp_send,
254 	.prt_recv = sp_recv,
255 	.prt_descriptor = sp_descriptor,
256 	.prt_close = sp_close
257 };
258 
259 static __constructor void
260 sp_ctor(void)
261 {
262 
263 	proto_register(&sp_proto, false);
264 }
265