1 /*-
2  * Copyright (c) 2006 Robert N. M. Watson
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  *
26  * $FreeBSD$
27  */
28 
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/stat.h>
32 #include <sys/wait.h>
33 
34 #include <netinet/in.h>
35 
36 #include <err.h>
37 #include <errno.h>
38 #include <fcntl.h>
39 #include <limits.h>
40 #include <md5.h>
41 #include <signal.h>
42 #include <stdint.h>
43 #include <stdio.h>
44 #include <stdlib.h>
45 #include <string.h>
46 #include <unistd.h>
47 
48 /*
49  * Simple regression test for sendfile.  Creates a file sized at four pages
50  * and then proceeds to send it over a series of sockets, exercising a number
51  * of cases and performing limited validation.
52  */
53 
/*
 * Failure helpers used throughout: print a TAP diagnostic line ("# ...")
 * and return -1 from the enclosing function.  FAIL_ERR() also appends
 * strerror(errno), so it must be used right after the failing call.
 *
 * NOTE: these expand to a bare { ... } block rather than
 * do { ... } while (0).  Call sites invoke them both with and without a
 * trailing semicolon, so changing the shape requires updating every
 * caller; "FAIL(...); else" would not compile.
 */
#define FAIL(msg)	{printf("# %s\n", msg); \
			return (-1);}

#define FAIL_ERR(msg)	{printf("# %s: %s\n", msg, strerror(errno)); \
			return (-1);}

/* TCP port on the loopback interface that each child listens on. */
#define	TEST_PORT	5678
/* Value stored (in network byte order) in test_header.th_magic. */
#define	TEST_MAGIC	0x4440f7bb
/* Test file size in pages, used when a test's file_size is 0. */
#define	TEST_PAGES	4
/* Timeout, in seconds, armed by the child around the receive. */
#define	TEST_SECONDS	30
64 
/*
 * Fixed-size header the parent writes ahead of the sendfile() payload and
 * the child reads back.  It is transferred as raw struct bytes, so both
 * ends rely on sharing the same layout (they are the same binary, split by
 * fork()).  Integer fields are in network byte order.
 */
struct test_header {
	uint32_t	th_magic;		/* TEST_MAGIC */
	uint32_t	th_header_length;	/* sf_hdtr header bytes sent */
	uint32_t	th_offset;		/* file offset given to sendfile() */
	uint32_t	th_length;		/* file bytes the receiver expects */
	char		th_md5[33];		/* MD5 string of that chunk + NUL */
};
72 
/*
 * One test case, as listed in run_parent().  A zero length means "send
 * through end of file"; a zero file_size selects TEST_PAGES pages.
 */
struct sendfile_test {
	uint32_t	hdr_length;	/* bytes of sf_hdtr header data */
	uint32_t	offset;		/* file offset for sendfile() */
	uint32_t	length;		/* nbytes for sendfile(); 0 = to EOF */
	uint32_t	file_size;	/* test file size; 0 = default */
};
79 
80 static int	file_fd;
81 static char	path[PATH_MAX];
82 static int	listen_socket;
83 static int	accept_socket;
84 
/* Timeout handling for the receiving child. */
static void	signal_alarm(int signum);
static void	setup_alarm(int seconds);
static void	cancel_alarm(void);

/* Receiving side, run in the forked child. */
static int	test_th(struct test_header *th, uint32_t *header_length,
		    uint32_t *offset, uint32_t *length);
static int	receive_test(void);
static void	run_child(void);

/* Sending side, run in the parent. */
static int	new_test_socket(int *connect_socket);
static void	init_th(struct test_header *th, uint32_t header_length,
		    uint32_t offset, uint32_t length);
static int	send_test(int connect_socket, struct sendfile_test test);
static int	write_test_file(size_t file_size);
static void	run_parent(void);
static void	cleanup(void);
99 
100 
101 static int
102 test_th(struct test_header *th, uint32_t *header_length, uint32_t *offset,
103 		uint32_t *length)
104 {
105 
106 	if (th->th_magic != htonl(TEST_MAGIC))
107 		FAIL("magic number not found in header")
108 	*header_length = ntohl(th->th_header_length);
109 	*offset = ntohl(th->th_offset);
110 	*length = ntohl(th->th_length);
111 	return (0);
112 }
113 
114 static void
115 signal_alarm(int signum)
116 {
117 	(void)signum;
118 
119 	printf("# test timeout\n");
120 
121 	if (accept_socket > 0)
122 		close(accept_socket);
123 	if (listen_socket > 0)
124 		close(listen_socket);
125 
126 	_exit(-1);
127 }
128 
129 static void
130 setup_alarm(int seconds)
131 {
132 	struct itimerval itv;
133 	bzero(&itv, sizeof(itv));
134 	(void)seconds;
135 	itv.it_value.tv_sec = seconds;
136 
137 	signal(SIGALRM, signal_alarm);
138 	setitimer(ITIMER_REAL, &itv, NULL);
139 }
140 
/* Disarm the ITIMER_REAL timer by loading an all-zero itimerval. */
static void
cancel_alarm(void)
{
	static const struct itimerval disarm;	/* zero-initialized */

	setitimer(ITIMER_REAL, &disarm, NULL);
}
148 
149 static int
150 receive_test(void)
151 {
152 	uint32_t header_length, offset, length, counter;
153 	struct test_header th;
154 	ssize_t len;
155 	char buf[10240];
156 	MD5_CTX md5ctx;
157 	char *rxmd5;
158 
159 	len = read(accept_socket, &th, sizeof(th));
160 	if (len < 0 || (size_t)len < sizeof(th))
161 		FAIL_ERR("read")
162 
163 	if (test_th(&th, &header_length, &offset, &length) != 0)
164 		return (-1);
165 
166 	MD5Init(&md5ctx);
167 
168 	counter = 0;
169 	while (1) {
170 		len = read(accept_socket, buf, sizeof(buf));
171 		if (len < 0 || len == 0)
172 			break;
173 		counter += len;
174 		MD5Update(&md5ctx, buf, len);
175 	}
176 
177 	rxmd5 = MD5End(&md5ctx, NULL);
178 
179 	if ((counter != header_length+length) ||
180 			memcmp(th.th_md5, rxmd5, 33) != 0)
181 		FAIL("receive length mismatch")
182 
183 	free(rxmd5);
184 	return (0);
185 }
186 
187 static void
188 run_child(void)
189 {
190 	struct sockaddr_in sin;
191 	int rc = 0;
192 
193 	listen_socket = socket(PF_INET, SOCK_STREAM, 0);
194 	if (listen_socket < 0) {
195 		printf("# socket: %s\n", strerror(errno));
196 		rc = -1;
197 	}
198 
199 	if (!rc) {
200 		bzero(&sin, sizeof(sin));
201 		sin.sin_len = sizeof(sin);
202 		sin.sin_family = AF_INET;
203 		sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
204 		sin.sin_port = htons(TEST_PORT);
205 
206 		if (bind(listen_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
207 			printf("# bind: %s\n", strerror(errno));
208 			rc = -1;
209 		}
210 	}
211 
212 	if (!rc && listen(listen_socket, -1) < 0) {
213 		printf("# listen: %s\n", strerror(errno));
214 		rc = -1;
215 	}
216 
217 	if (!rc) {
218 		accept_socket = accept(listen_socket, NULL, NULL);
219 		setup_alarm(TEST_SECONDS);
220 		if (receive_test() != 0)
221 			rc = -1;
222 	}
223 
224 	cancel_alarm();
225 	if (accept_socket > 0)
226 		close(accept_socket);
227 	if (listen_socket > 0)
228 		close(listen_socket);
229 
230 	_exit(rc);
231 }
232 
233 static int
234 new_test_socket(int *connect_socket)
235 {
236 	struct sockaddr_in sin;
237 	int rc = 0;
238 
239 	*connect_socket = socket(PF_INET, SOCK_STREAM, 0);
240 	if (*connect_socket < 0)
241 		FAIL_ERR("socket")
242 
243 	bzero(&sin, sizeof(sin));
244 	sin.sin_len = sizeof(sin);
245 	sin.sin_family = AF_INET;
246 	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
247 	sin.sin_port = htons(TEST_PORT);
248 
249 	if (connect(*connect_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0)
250 		FAIL_ERR("connect")
251 
252 	return (rc);
253 }
254 
255 static void
256 init_th(struct test_header *th, uint32_t header_length, uint32_t offset,
257 		uint32_t length)
258 {
259 	bzero(th, sizeof(*th));
260 	th->th_magic = htonl(TEST_MAGIC);
261 	th->th_header_length = htonl(header_length);
262 	th->th_offset = htonl(offset);
263 	th->th_length = htonl(length);
264 
265 	MD5FileChunk(path, th->th_md5, offset, length);
266 }
267 
268 static int
269 send_test(int connect_socket, struct sendfile_test test)
270 {
271 	struct test_header th;
272 	struct sf_hdtr hdtr, *hdtrp;
273 	struct iovec headers;
274 	char *header;
275 	ssize_t len;
276 	int length;
277 	off_t off;
278 
279 	len = lseek(file_fd, 0, SEEK_SET);
280 	if (len != 0)
281 		FAIL_ERR("lseek")
282 
283 	struct stat st;
284 	if (fstat(file_fd, &st) < 0)
285 		FAIL_ERR("fstat")
286 	length = st.st_size - test.offset;
287 	if (test.length > 0 && test.length < (uint32_t)length)
288 		length = test.length;
289 
290 	init_th(&th, test.hdr_length, test.offset, length);
291 
292 	len = write(connect_socket, &th, sizeof(th));
293 	if (len != sizeof(th))
294 		return (-1);
295 
296 	if (test.hdr_length != 0) {
297 		header = malloc(test.hdr_length);
298 		if (header == NULL)
299 			FAIL_ERR("malloc")
300 
301 		hdtrp = &hdtr;
302 		bzero(&headers, sizeof(headers));
303 		headers.iov_base = header;
304 		headers.iov_len = test.hdr_length;
305 		bzero(&hdtr, sizeof(hdtr));
306 		hdtr.headers = &headers;
307 		hdtr.hdr_cnt = 1;
308 		hdtr.trailers = NULL;
309 		hdtr.trl_cnt = 0;
310 	} else {
311 		hdtrp = NULL;
312 		header = NULL;
313 	}
314 
315 	if (sendfile(file_fd, connect_socket, test.offset, test.length,
316 				hdtrp, &off, 0) < 0) {
317 		if (header != NULL)
318 			free(header);
319 		FAIL_ERR("sendfile")
320 	}
321 
322 	if (length == 0) {
323 		struct stat sb;
324 
325 		if (fstat(file_fd, &sb) == 0)
326 			length = sb.st_size - test.offset;
327 	}
328 
329 	if (header != NULL)
330 		free(header);
331 
332 	if (off != length)
333 		FAIL("offset != length")
334 
335 	return (0);
336 }
337 
338 static int
339 write_test_file(size_t file_size)
340 {
341 	char *page_buffer;
342 	ssize_t len;
343 	static size_t current_file_size = 0;
344 
345 	if (file_size == current_file_size)
346 		return (0);
347 	else if (file_size < current_file_size) {
348 		if (ftruncate(file_fd, file_size) != 0)
349 			FAIL_ERR("ftruncate");
350 		current_file_size = file_size;
351 		return (0);
352 	}
353 
354 	page_buffer = malloc(file_size);
355 	if (page_buffer == NULL)
356 		FAIL_ERR("malloc")
357 	bzero(page_buffer, file_size);
358 
359 	len = write(file_fd, page_buffer, file_size);
360 	if (len < 0)
361 		FAIL_ERR("write")
362 
363 	len = lseek(file_fd, 0, SEEK_SET);
364 	if (len < 0)
365 		FAIL_ERR("lseek")
366 	if (len != 0)
367 		FAIL("len != 0")
368 
369 	free(page_buffer);
370 	current_file_size = file_size;
371 	return (0);
372 }
373 
374 static void
375 run_parent(void)
376 {
377 	int connect_socket;
378 	int status;
379 	int test_num;
380 	int test_count;
381 	int pid;
382 	size_t desired_file_size = 0;
383 
384 	const int pagesize = getpagesize();
385 
386 	struct sendfile_test tests[] = {
387  		{ .hdr_length = 0, .offset = 0, .length = 1 },
388 		{ .hdr_length = 0, .offset = 0, .length = pagesize },
389 		{ .hdr_length = 0, .offset = 1, .length = 1 },
390 		{ .hdr_length = 0, .offset = 1, .length = pagesize },
391 		{ .hdr_length = 0, .offset = pagesize, .length = pagesize },
392 		{ .hdr_length = 0, .offset = 0, .length = 2*pagesize },
393 		{ .hdr_length = 0, .offset = 0, .length = 0 },
394 		{ .hdr_length = 0, .offset = pagesize, .length = 0 },
395 		{ .hdr_length = 0, .offset = 2*pagesize, .length = 0 },
396 		{ .hdr_length = 0, .offset = TEST_PAGES*pagesize, .length = 0 },
397 		{ .hdr_length = 0, .offset = 0, .length = pagesize,
398 		    .file_size = 1 }
399 	};
400 
401 	test_count = sizeof(tests) / sizeof(tests[0]);
402 	printf("1..%d\n", test_count);
403 
404 	for (test_num = 1; test_num <= test_count; test_num++) {
405 
406 		desired_file_size = tests[test_num - 1].file_size;
407 		if (desired_file_size == 0)
408 			desired_file_size = TEST_PAGES * pagesize;
409 		if (write_test_file(desired_file_size) != 0) {
410 			printf("not ok %d\n", test_num);
411 			continue;
412 		}
413 
414 		pid = fork();
415 		if (pid == -1) {
416 			printf("not ok %d\n", test_num);
417 			continue;
418 		}
419 
420 		if (pid == 0)
421 			run_child();
422 
423 		usleep(250000);
424 
425 		if (new_test_socket(&connect_socket) != 0) {
426 			printf("not ok %d\n", test_num);
427 			kill(pid, SIGALRM);
428 			close(connect_socket);
429 			continue;
430 		}
431 
432 		if (send_test(connect_socket, tests[test_num-1]) != 0) {
433 			printf("not ok %d\n", test_num);
434 			kill(pid, SIGALRM);
435 			close(connect_socket);
436 			continue;
437 		}
438 
439 		close(connect_socket);
440 		if (waitpid(pid, &status, 0) == pid) {
441 			if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
442 				printf("%s %d\n", "ok", test_num);
443 			else
444 				printf("%s %d\n", "not ok", test_num);
445 		}
446 		else {
447 			printf("not ok %d\n", test_num);
448 		}
449 	}
450 }
451 
452 static void
453 cleanup(void)
454 {
455 
456 	unlink(path);
457 }
458 
459 int
460 main(int argc, char *argv[])
461 {
462 
463 	path[0] = '\0';
464 
465 	if (argc == 1) {
466 		snprintf(path, sizeof(path), "sendfile.XXXXXXXXXXXX");
467 		file_fd = mkstemp(path);
468 		if (file_fd == -1)
469 			FAIL_ERR("mkstemp");
470 	} else if (argc == 2) {
471 		(void)strlcpy(path, argv[1], sizeof(path));
472 		file_fd = open(path, O_CREAT | O_TRUNC | O_RDWR, 0600);
473 		if (file_fd == -1)
474 			FAIL_ERR("open");
475 	} else {
476 		FAIL("usage: sendfile [path]");
477 	}
478 
479 	atexit(cleanup);
480 
481 	run_parent();
482 	return (0);
483 }
484