1 /*-
2  * Copyright (c) 2006 Robert N. M. Watson
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  *
26  * $FreeBSD$
27  */
28 
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>

#include <netinet/in.h>

#include <err.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
41 
42 /*
43  * Simple regression test for sendfile.  Creates a file sized at three pages
44  * and then proceeds to send it over a series of sockets, exercising a number
45  * of cases and performing limited validation.
46  */
47 
/*
 * Test parameters.
 */
enum {
	TEST_PORT = 5678,		/* loopback TCP port the child listens on */
	TEST_MAGIC = 0x4440f7bb,	/* tag identifying a valid test_header */
	TEST_PAGES = 3,			/* size of the scratch file, in pages */
	TEST_SECONDS = 30		/* alarm bound on each receive_test() */
};
52 
/*
 * Fixed-size preamble written on each connection ahead of the sendfile(2)
 * payload.  All fields are stored in network byte order by init_th() and
 * decoded by test_th().
 */
struct test_header {
	u_int32_t	th_magic;		/* TEST_MAGIC when the header is valid */
	u_int32_t	th_header_length;	/* bytes of sf_hdtr header data sent */
	u_int32_t	th_offset;		/* file offset passed to sendfile(2) */
	u_int32_t	th_length;		/* file bytes passed to sendfile(2) */
};
59 
/*
 * Process IDs recorded by main() around fork().  parent_pid is set but not
 * otherwise read in this file.
 */
pid_t	child_pid;
pid_t	parent_pid;

/* Loopback listening socket, created before fork() and shared by both. */
int	listen_socket;

/* Descriptor of the unlinked scratch file that sendfile(2) reads from. */
int	file_fd;
63 
64 static int
65 test_th(struct test_header *th, u_int32_t *header_length, u_int32_t *offset,
66     u_int32_t *length)
67 {
68 
69 	if (th->th_magic != htonl(TEST_MAGIC))
70 		return (0);
71 	*header_length = ntohl(th->th_header_length);
72 	*offset = ntohl(th->th_offset);
73 	*length = ntohl(th->th_length);
74 	return (1);
75 }
76 
/*
 * SIGALRM handler for the per-connection timeout armed by setup_alarm().
 *
 * signal(3) installs handlers with restart semantics on BSD (and glibc).
 * An empty handler therefore just lets the read(2) blocked in receive_test()
 * resume, and the timeout never takes effect.  Instead, report the timeout
 * and terminate with a failure status.  Only async-signal-safe calls
 * (write(2), _exit(2)) are used.
 */
static void
signal_alarm(int signum)
{
	static const char msg[] = "sendfile: receive_test timed out\n";

	(void)signum;
	(void)write(STDERR_FILENO, msg, sizeof(msg) - 1);
	_exit(1);
}
82 
83 static void
84 setup_alarm(int seconds)
85 {
86 
87 	signal(SIGALRM, signal_alarm);
88 	alarm(seconds);
89 }
90 
/*
 * Disarm any pending SIGALRM and restore its default disposition.  The alarm
 * is cleared first, so a late alarm cannot arrive after the handler has been
 * reset to SIG_DFL.
 */
static void
cancel_alarm(void)
{
	unsigned int remaining;

	remaining = alarm(0);	/* seconds left on the old alarm; unused */
	(void)remaining;
	(void)signal(SIGALRM, SIG_DFL);
}
98 
99 static void
100 receive_test(int accept_socket)
101 {
102 	u_int32_t header_length, offset, length, counter;
103 	struct test_header th;
104 	ssize_t len;
105 	char ch;
106 
107 	len = read(accept_socket, &th, sizeof(th));
108 	if (len < 0)
109 		err(-1, "read");
110 	if (len < sizeof(th))
111 		errx(-1, "read: %d", len);
112 
113 	if (test_th(&th, &header_length, &offset, &length) == 0)
114 		errx(-1, "test_th: bad");
115 
116 	counter = 0;
117 	while (1) {
118 		len = read(accept_socket, &ch, sizeof(ch));
119 		if (len < 0)
120 			err(-1, "read");
121 		if (len == 0)
122 			break;
123 		counter++;
124 		/* XXXRW: Validate byte here. */
125 	}
126 	if (counter != header_length + length)
127 		errx(-1, "receive_test: expected (%d, %d) received %d",
128 		    header_length, length, counter);
129 }
130 
131 static void
132 run_child(void)
133 {
134 	int accept_socket;
135 
136 	while (1) {
137 		accept_socket = accept(listen_socket, NULL, NULL);
138 		setup_alarm(TEST_SECONDS);
139 		receive_test(accept_socket);
140 		cancel_alarm();
141 		close(accept_socket);
142 	}
143 }
144 
145 static int
146 new_test_socket(void)
147 {
148 	struct sockaddr_in sin;
149 	int connect_socket;
150 
151 	connect_socket = socket(PF_INET, SOCK_STREAM, 0);
152 	if (connect_socket < 0)
153 		err(-1, "socket");
154 
155 	bzero(&sin, sizeof(sin));
156 	sin.sin_len = sizeof(sin);
157 	sin.sin_family = AF_INET;
158 	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
159 	sin.sin_port = htons(TEST_PORT);
160 
161 	if (connect(connect_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0)
162 		err(-1, "connect");
163 
164 	return (connect_socket);
165 }
166 
167 static void
168 init_th(struct test_header *th, u_int32_t header_length, u_int32_t offset,
169     u_int32_t length)
170 {
171 
172 	bzero(th, sizeof(*th));
173 	th->th_magic = htonl(TEST_MAGIC);
174 	th->th_header_length = htonl(header_length);
175 	th->th_offset = htonl(offset);
176 	th->th_length = htonl(length);
177 }
178 
179 static void
180 send_test(int connect_socket, u_int32_t header_length, u_int32_t offset,
181     u_int32_t length)
182 {
183 	struct test_header th;
184 	struct sf_hdtr hdtr, *hdtrp;
185 	struct iovec headers;
186 	char *header;
187 	ssize_t len;
188 	off_t off;
189 
190 	len = lseek(file_fd, 0, SEEK_SET);
191 	if (len < 0)
192 		err(-1, "lseek");
193 	if (len != 0)
194 		errx(-1, "lseek: %d", len);
195 
196 	init_th(&th, header_length, offset, length);
197 
198 	len = write(connect_socket, &th, sizeof(th));
199 	if (len < 0)
200 		err(-1, "send");
201 	if (len != sizeof(th))
202 		err(-1, "send: %d", len);
203 
204 	if (header_length != 0) {
205 		header = malloc(header_length);
206 		if (header == NULL)
207 			err(-1, "malloc");
208 		hdtrp = &hdtr;
209 		bzero(&headers, sizeof(headers));
210 		headers.iov_base = header;
211 		headers.iov_len = header_length;
212 		bzero(&hdtr, sizeof(hdtr));
213 		hdtr.headers = &headers;
214 		hdtr.hdr_cnt = 1;
215 		hdtr.trailers = NULL;
216 		hdtr.trl_cnt = 0;
217 	} else {
218 		hdtrp = NULL;
219 		header = NULL;
220 	}
221 
222 	if (sendfile(file_fd, connect_socket, offset, length, hdtrp, &off,
223 	    0) < 0)
224 		err(-1, "sendfile");
225 
226 	if (off != length)
227 		errx(-1, "sendfile: off %llu", off);
228 
229 	if (header != NULL)
230 		free(header);
231 }
232 
233 static void
234 run_parent(void)
235 {
236 	int connect_socket;
237 
238 	connect_socket = new_test_socket();
239 	send_test(connect_socket, 0, 0, 1);
240 	close(connect_socket);
241 
242 	sleep(1);
243 
244 	connect_socket = new_test_socket();
245 	send_test(connect_socket, 0, 0, getpagesize());
246 	close(connect_socket);
247 
248 	sleep(1);
249 
250 	connect_socket = new_test_socket();
251 	send_test(connect_socket, 0, 1, 1);
252 	close(connect_socket);
253 
254 	sleep(1);
255 
256 	connect_socket = new_test_socket();
257 	send_test(connect_socket, 0, 1, getpagesize());
258 	close(connect_socket);
259 
260 	sleep(1);
261 
262 	connect_socket = new_test_socket();
263 	send_test(connect_socket, 0, getpagesize(), getpagesize());
264 	close(connect_socket);
265 
266 	sleep(1);
267 
268 	(void)kill(child_pid, SIGKILL);
269 }
270 
271 int
272 main(int argc, char *argv[])
273 {
274 	char path[PATH_MAX], *page_buffer;
275 	struct sockaddr_in sin;
276 	int pagesize;
277 	ssize_t len;
278 
279 	pagesize = getpagesize();
280 	page_buffer = malloc(TEST_PAGES * pagesize);
281 	if (page_buffer == NULL)
282 		err(-1, "malloc");
283 	bzero(page_buffer, TEST_PAGES * pagesize);
284 
285 	listen_socket = socket(PF_INET, SOCK_STREAM, 0);
286 	if (listen_socket < 0)
287 		err(-1, "socket");
288 
289 	bzero(&sin, sizeof(sin));
290 	sin.sin_len = sizeof(sin);
291 	sin.sin_family = AF_INET;
292 	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
293 	sin.sin_port = htons(TEST_PORT);
294 
295 	snprintf(path, PATH_MAX, "/tmp/sendfile.XXXXXXXXXXXX");
296 	file_fd = mkstemp(path);
297 	(void)unlink(path);
298 
299 	len = write(file_fd, page_buffer, TEST_PAGES * pagesize);
300 	if (len < 0)
301 		err(-1, "write");
302 
303 	len = lseek(file_fd, 0, SEEK_SET);
304 	if (len < 0)
305 		err(-1, "lseek");
306 	if (len != 0)
307 		errx(-1, "lseek: %d", len);
308 
309 	if (bind(listen_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0)
310 		err(-1, "bind");
311 
312 	if (listen(listen_socket, -1) < 0)
313 		err(-1, "listen");
314 
315 	parent_pid = getpid();
316 	child_pid = fork();
317 	if (child_pid < 0)
318 		err(-1, "fork");
319 	if (child_pid == 0) {
320 		child_pid = getpid();
321 		run_child();
322 	} else
323 		run_parent();
324 
325 	return (0);
326 }
327