1 /*
2  * Test client to test the NBD server. Doesn't do anything useful, except
3  * checking that the server does, actually, work.
4  *
5  * Note that the only 'real' test is to check the client against a kernel. If
6  * it works here but does not work in the kernel, then that's most likely a bug
7  * in this program and/or in nbd-server.
8  *
9  * Copyright(c) 2006  Wouter Verhelst
10  *
11  * This program is Free Software; you can redistribute it and/or modify it
12  * under the terms of the GNU General Public License as published by the Free
13  * Software Foundation, in version 2.
14  *
15  * This program is distributed in the hope that it will be useful, but WITHOUT
16  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
17  * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
18  * more details.
19  *
20  * You should have received a copy of the GNU General Public License along with
21  * this program; if not, write to the Free Software Foundation, Inc., 51
22  * Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
23  */
#include <stdlib.h>
#include <stdio.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/un.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <syslog.h>
#include <unistd.h>
#include "config.h"
#include "lfs.h"
#include <netinet/in.h>
#include <glib.h>
42 
43 #define MY_NAME "nbd-tester-client"
44 #include "cliserv.h"
45 
46 #if HAVE_GNUTLS
47 #include "crypto-gnutls.h"
48 #endif
49 
50 static gchar errstr[1024];
51 const static int errstr_len = 1023;
52 
53 static uint64_t size;
54 
55 static int looseordering = 0;
56 
57 static gchar *transactionlog = "nbd-tester-client.tr";
58 static gchar *certfile = NULL;
59 static gchar *keyfile = NULL;
60 static gchar *cacertfile = NULL;
61 static gchar *tlshostname = NULL;
62 
/* How far setup_connection_common() takes the handshake before it returns.
 * Each later value also performs the steps of the earlier ones. */
typedef enum {
	CONNECTION_TYPE_INIT_PASSWD,	/* read and check INIT_PASSWD only */
	CONNECTION_TYPE_CLISERV,	/* ... then read and check the magic */
	CONNECTION_TYPE_FULL,		/* ... then complete the negotiation */
} CONNECTION_TYPE;

/* How close_connection() shuts a connection down. */
typedef enum {
	CONNECTION_CLOSE_PROPERLY,	/* send NBD_CMD_DISC, then close() */
	CONNECTION_CLOSE_FAST,		/* just close() the socket */
} CLOSE_TYPE;

/* A request kept on an rclist (doubly linked through next/prev).
 * NOTE(review): seq, orighandle and req are not used in the visible part of
 * the file; presumably a sequence number, the handle as originally sent and
 * the request itself — confirm. */
struct reqcontext {
	uint64_t seq;
	char orighandle[8];
	struct nbd_request req;
	struct reqcontext *next;
	struct reqcontext *prev;
};

/* Doubly-linked list of reqcontext, maintained by rclist_addtail() and
 * rclist_unlink(); numitems counts the linked entries. */
struct rclist {
	struct reqcontext *head;
	struct reqcontext *tail;
	int numitems;
};

/* One buffer of queued output bytes (see addbuffer()/writebuffer()). */
struct chunk {
	char *buffer;		/* start of the allocation; freed with the chunk */
	char *readptr;		/* next byte writebuffer() will send */
	char *writeptr;		/* where addbuffer() appends the next byte */
	uint64_t space;		/* free bytes remaining after writeptr */
	uint64_t length;	/* pending bytes between readptr and writeptr */
	struct chunk *next;
	struct chunk *prev;
};

/* List of chunks. addbuffer() appends at the tail and writebuffer() drains
 * from the head. numitems counts the linked chunks. */
struct chunklist {
	struct chunk *head;
	struct chunk *tail;
	int numitems;
};

/* NOTE(review): not used in the visible part of the file; presumably per-block
 * bookkeeping of a sequence number and in-flight read/write counts — confirm. */
struct blkitem {
	uint32_t seq;
	int32_t inflightr;
	int32_t inflightw;
};
109 
rclist_unlink(struct rclist * l,struct reqcontext * p)110 void rclist_unlink(struct rclist *l, struct reqcontext *p)
111 {
112 	if (p && l) {
113 		struct reqcontext *prev = p->prev;
114 		struct reqcontext *next = p->next;
115 
116 		/* Fix link to previous */
117 		if (prev)
118 			prev->next = next;
119 		else
120 			l->head = next;
121 
122 		if (next)
123 			next->prev = prev;
124 		else
125 			l->tail = prev;
126 
127 		p->prev = NULL;
128 		p->next = NULL;
129 		l->numitems--;
130 	}
131 }
132 
133 /* Add a new list item to the tail */
rclist_addtail(struct rclist * l,struct reqcontext * p)134 void rclist_addtail(struct rclist *l, struct reqcontext *p)
135 {
136 	if (!p || !l)
137 		return;
138 	if (l->tail) {
139 		if (l->tail->next)
140 			g_warning("addtail found list tail has a next pointer");
141 		l->tail->next = p;
142 		p->next = NULL;
143 		p->prev = l->tail;
144 		l->tail = p;
145 	} else {
146 		if (l->head)
147 			g_warning("addtail found no list tail but a list head");
148 		l->head = p;
149 		l->tail = p;
150 		p->prev = NULL;
151 		p->next = NULL;
152 	}
153 	l->numitems++;
154 }
155 
chunklist_unlink(struct chunklist * l,struct chunk * p)156 void chunklist_unlink(struct chunklist *l, struct chunk *p)
157 {
158 	if (p && l) {
159 		struct chunk *prev = p->prev;
160 		struct chunk *next = p->next;
161 
162 		/* Fix link to previous */
163 		if (prev)
164 			prev->next = next;
165 		else
166 			l->head = next;
167 
168 		if (next)
169 			next->prev = prev;
170 		else
171 			l->tail = prev;
172 
173 		p->prev = NULL;
174 		p->next = NULL;
175 		l->numitems--;
176 	}
177 }
178 
179 /* Add a new list item to the tail */
chunklist_addtail(struct chunklist * l,struct chunk * p)180 void chunklist_addtail(struct chunklist *l, struct chunk *p)
181 {
182 	if (!p || !l)
183 		return;
184 	if (l->tail) {
185 		if (l->tail->next)
186 			g_warning("addtail found list tail has a next pointer");
187 		l->tail->next = p;
188 		p->next = NULL;
189 		p->prev = l->tail;
190 		l->tail = p;
191 	} else {
192 		if (l->head)
193 			g_warning("addtail found no list tail but a list head");
194 		l->head = p;
195 		l->tail = p;
196 		p->prev = NULL;
197 		p->next = NULL;
198 	}
199 	l->numitems++;
200 }
201 
202 /* Add some new bytes to a chunklist */
addbuffer(struct chunklist * l,void * data,uint64_t len)203 void addbuffer(struct chunklist *l, void *data, uint64_t len)
204 {
205 	void *buf;
206 	uint64_t size = 64 * 1024;
207 	struct chunk *pchunk;
208 
209 	while (len > 0) {
210 		/* First see if there is a current chunk, and if it has space */
211 		if (l->tail && l->tail->space) {
212 			uint64_t towrite = len;
213 			if (towrite > l->tail->space)
214 				towrite = l->tail->space;
215 			memcpy(l->tail->writeptr, data, towrite);
216 			l->tail->length += towrite;
217 			l->tail->space -= towrite;
218 			l->tail->writeptr += towrite;
219 			len -= towrite;
220 			data += towrite;
221 		}
222 
223 		if (len > 0) {
224 			/* We still need to write more, so prepare a new chunk */
225 			if ((NULL == (buf = malloc(size)))
226 			    || (NULL ==
227 				(pchunk = calloc(1, sizeof(struct chunk))))) {
228 				g_critical("Out of memory");
229 				exit(1);
230 			}
231 
232 			pchunk->buffer = buf;
233 			pchunk->readptr = buf;
234 			pchunk->writeptr = buf;
235 			pchunk->space = size;
236 			chunklist_addtail(l, pchunk);
237 		}
238 	}
239 
240 }
241 
242 /* returns 0 on success, -1 on failure */
writebuffer(int fd,struct chunklist * l)243 int writebuffer(int fd, struct chunklist *l)
244 {
245 
246 	struct chunk *pchunk = NULL;
247 	int res;
248 	if (!l)
249 		return 0;
250 
251 	while (!pchunk) {
252 		pchunk = l->head;
253 		if (!pchunk)
254 			return 0;
255 		if (!(pchunk->length) || !(pchunk->readptr)) {
256 			chunklist_unlink(l, pchunk);
257 			free(pchunk->buffer);
258 			free(pchunk);
259 			pchunk = NULL;
260 		}
261 	}
262 
263 	/* OK we have a chunk with some data in */
264 	res = write(fd, pchunk->readptr, pchunk->length);
265 	if (res == 0)
266 		errno = EAGAIN;
267 	if (res <= 0)
268 		return -1;
269 	pchunk->length -= res;
270 	pchunk->readptr += res;
271 	if (!pchunk->length) {
272 		chunklist_unlink(l, pchunk);
273 		free(pchunk->buffer);
274 		free(pchunk);
275 	}
276 	return 0;
277 }
278 
/* Bit flags passed as 'testflags' to the test routines and to
 * setup_connection_common(). */
#define TEST_WRITE (1<<0)	/* issue NBD_CMD_WRITE instead of NBD_CMD_READ */
#define TEST_FLUSH (1<<1)	/* also send FUA/FLUSH requests; throughput_test drops it without TEST_WRITE */
#define TEST_EXPECT_ERROR (1<<2)	/* a server error is the expected outcome */
#define TEST_HANDSHAKE (1<<3)	/* setup_connection_common() stops after the flags exchange */
283 
timeval_subtract(struct timeval * result,struct timeval * x,struct timeval * y)284 int timeval_subtract(struct timeval *result, struct timeval *x,
285 		     struct timeval *y)
286 {
287 	if (x->tv_usec < y->tv_usec) {
288 		int nsec = (y->tv_usec - x->tv_usec) / 1000000 + 1;
289 		y->tv_usec -= 1000000 * nsec;
290 		y->tv_sec += nsec;
291 	}
292 
293 	if (x->tv_usec - y->tv_usec > 1000000) {
294 		int nsec = (x->tv_usec - y->tv_usec) / 1000000;
295 		y->tv_usec += 1000000 * nsec;
296 		y->tv_sec -= nsec;
297 	}
298 
299 	result->tv_sec = x->tv_sec - y->tv_sec;
300 	result->tv_usec = x->tv_usec - y->tv_usec;
301 
302 	return x->tv_sec < y->tv_sec;
303 }
304 
timeval_diff_to_double(struct timeval * x,struct timeval * y)305 double timeval_diff_to_double(struct timeval *x, struct timeval *y)
306 {
307 	struct timeval r;
308 	timeval_subtract(&r, x, y);
309 	return r.tv_sec * 1.0 + r.tv_usec / 1000000.0;
310 }
311 
read_all(int f,void * buf,size_t len)312 static inline int read_all(int f, void *buf, size_t len)
313 {
314 	ssize_t res;
315 	size_t retval = 0;
316 
317 	while (len > 0) {
318 		if ((res = read(f, buf, len)) <= 0) {
319 			if (!res)
320 				errno = EAGAIN;
321 			snprintf(errstr, errstr_len, "Read failed: %s",
322 				 strerror(errno));
323 			return -1;
324 		}
325 		len -= res;
326 		buf += res;
327 		retval += res;
328 	}
329 	return retval;
330 }
331 
write_all(int f,void * buf,size_t len)332 static inline int write_all(int f, void *buf, size_t len)
333 {
334 	ssize_t res;
335 	size_t retval = 0;
336 
337 	while (len > 0) {
338 		if ((res = write(f, buf, len)) <= 0) {
339 			if (!res)
340 				errno = EAGAIN;
341 			snprintf(errstr, errstr_len, "Write failed: %s",
342 				 strerror(errno));
343 			return -1;
344 		}
345 		len -= res;
346 		buf += res;
347 		retval += res;
348 	}
349 	return retval;
350 }
351 
tlserrout(void * opaque,const char * format,va_list ap)352 static int tlserrout (void *opaque, const char *format, va_list ap) {
353 	return vfprintf(stderr, format, ap);
354 }
355 
/*
 * Error-checking wrappers around read_all() and write_all(), for functions
 * that use goto-based cleanup.
 *
 * If the call returns <= 0, the trailing printf-style arguments are
 * formatted into errstr and control jumps to the label 'whereto'. The
 * *_ERR_RT variants first assign 'rval' to the caller's local 'retval'.
 *
 * Each expansion is wrapped in do { } while (0) so it behaves as a single
 * statement, e.g. as the body of an un-braced if/else.
 */
#define READ_ALL_ERRCHK(f, buf, len, whereto, errmsg...) \
	do { if((read_all(f, buf, len))<=0) { snprintf(errstr, errstr_len, ##errmsg); goto whereto; } } while (0)
#define READ_ALL_ERR_RT(f, buf, len, whereto, rval, errmsg...) \
	do { if((read_all(f, buf, len))<=0) { snprintf(errstr, errstr_len, ##errmsg); retval = rval; goto whereto; } } while (0)

#define WRITE_ALL_ERRCHK(f, buf, len, whereto, errmsg...) \
	do { if((write_all(f, buf, len))<=0) { snprintf(errstr, errstr_len, ##errmsg); goto whereto; } } while (0)
#define WRITE_ALL_ERR_RT(f, buf, len, whereto, rval, errmsg...) \
	do { if((write_all(f, buf, len))<=0) { snprintf(errstr, errstr_len, ##errmsg); retval = rval; goto whereto; } } while (0)
361 
setup_connection_common(int sock,char * name,CONNECTION_TYPE ctype,int * serverflags,int testflags)362 int setup_connection_common(int sock, char *name, CONNECTION_TYPE ctype,
363 			    int *serverflags, int testflags)
364 {
365 	char buf[256];
366 	u64 tmp64;
367 	uint64_t mymagic = (name ? opts_magic : cliserv_magic);
368 	uint32_t tmp32 = 0;
369 	uint16_t handshakeflags = 0;
370 	uint32_t negotiationflags = 0;
371 
372 	if (ctype < CONNECTION_TYPE_INIT_PASSWD)
373 		goto end;
374 	READ_ALL_ERRCHK(sock, buf, strlen(INIT_PASSWD), err,
375 			"Could not read INIT_PASSWD: %s", strerror(errno));
376 	buf[strlen(INIT_PASSWD)] = 0;
377 	if (strlen(buf) == 0) {
378 		snprintf(errstr, errstr_len, "Server closed connection");
379 		goto err;
380 	}
381 	if (strncmp(buf, INIT_PASSWD, strlen(INIT_PASSWD))) {
382 		snprintf(errstr, errstr_len, "INIT_PASSWD does not match");
383 		goto err;
384 	}
385 	if (ctype < CONNECTION_TYPE_CLISERV)
386 		goto end;
387 	READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
388 			"Could not read cliserv_magic: %s", strerror(errno));
389 	tmp64 = ntohll(tmp64);
390 	if (tmp64 != mymagic) {
391 		strncpy(errstr, "mymagic does not match", errstr_len);
392 		goto err;
393 	}
394 	if (ctype < CONNECTION_TYPE_FULL)
395 		goto end;
396 	if (!name) {
397 		READ_ALL_ERRCHK(sock, &size, sizeof(size), err,
398 				"Could not read size: %s", strerror(errno));
399 		size = ntohll(size);
400 		uint32_t flags;
401 		READ_ALL_ERRCHK(sock, &flags, sizeof(uint32_t), err,
402 				"Could not read flags: %s", strerror(errno));
403 		flags = ntohl(flags);
404 		*serverflags = flags;
405 		READ_ALL_ERRCHK(sock, buf, 124, err, "Could not read data: %s",
406 				strerror(errno));
407 		goto end;
408 	}
409 	/* handshake flags */
410 	READ_ALL_ERRCHK(sock, &handshakeflags, sizeof(handshakeflags), err,
411 			"Could not read reserved field: %s", strerror(errno));
412 	handshakeflags = ntohs(handshakeflags);
413 	/* negotiation flags */
414 	if (handshakeflags & NBD_FLAG_FIXED_NEWSTYLE)
415 		negotiationflags |= NBD_FLAG_C_FIXED_NEWSTYLE;
416 	else if (keyfile) {
417 		snprintf(errstr, errstr_len, "Cannot negotiate TLS without NBD_FLAG_FIXED_NEWSTYLE");
418 		goto err;
419 	}
420 	negotiationflags = htonl(negotiationflags);
421 	WRITE_ALL_ERRCHK(sock, &negotiationflags, sizeof(negotiationflags), err,
422 			 "Could not write reserved field: %s", strerror(errno));
423 	if (testflags & TEST_HANDSHAKE) {
424 		/* Server must support newstyle for this test */
425 		if (!(handshakeflags & NBD_FLAG_FIXED_NEWSTYLE)) {
426 			strncpy(errstr, "server does not support handshake", errstr_len);
427 			goto err;
428 		}
429 		goto end;
430 	}
431 #if HAVE_GNUTLS
432 	/* TLS */
433 	if (keyfile) {
434 		int plainfd[2]; // [0] is used by the proxy, [1] is used by NBD
435 		tlssession_t *s = NULL;
436 		int ret;
437 
438 		/* magic */
439 		tmp64 = htonll(opts_magic);
440 		WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
441 				 "Could not write magic: %s", strerror(errno));
442 		/* starttls */
443 		tmp32 = htonl(NBD_OPT_STARTTLS);
444 		WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
445 			 "Could not write option: %s", strerror(errno));
446 		/* length of data */
447 		tmp32 = htonl(0);
448 		WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
449 			 "Could not write option length: %s", strerror(errno));
450 
451 		READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
452 				"Could not read cliserv_magic: %s", strerror(errno));
453 		tmp64 = ntohll(tmp64);
454 		if (tmp64 != NBD_OPT_REPLY_MAGIC) {
455 			strncpy(errstr, "reply magic does not match", errstr_len);
456 			goto err;
457 		}
458 		READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
459 				"Could not read option type: %s", strerror(errno));
460 		tmp32 = ntohl(tmp32);
461 		if (tmp32 != NBD_OPT_STARTTLS) {
462 			strncpy(errstr, "Reply to wrong option", errstr_len);
463 			goto err;
464 		}
465 		READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
466 				"Could not read option reply type: %s", strerror(errno));
467 		tmp32 = ntohl(tmp32);
468 		if (tmp32 != NBD_REP_ACK) {
469 			if(tmp32 & NBD_REP_FLAG_ERROR) {
470 				snprintf(errstr, errstr_len, "Received error %d", tmp32 & ~NBD_REP_FLAG_ERROR);
471 			} else {
472 				snprintf(errstr, errstr_len, "Option reply type %d != NBD_REP_ACK", tmp32);
473 			}
474 			goto err;
475 		}
476 		READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
477 				"Could not read option data length: %s", strerror(errno));
478 		tmp32 = ntohl(tmp32);
479 		if (tmp32 != 0) {
480 			strncpy(errstr, "Option reply data length != 0", errstr_len);
481 			goto err;
482 		}
483 
484 		s = tlssession_new(FALSE,
485 				   keyfile,
486 				   certfile,
487 				   cacertfile,
488 				   tlshostname,
489 				   !cacertfile || !tlshostname, // insecure flag
490 #ifdef DODBG
491 				   1, // debug
492 #else
493 				   0, // debug
494 #endif
495 				   NULL, // quitfn
496 				   tlserrout, // erroutfn
497 				   NULL // opaque
498 			);
499 		if (!s) {
500 			strncpy(errstr, "Cannot establish TLS session", errstr_len);
501 			goto err;
502 		}
503 
504 		if (socketpair(AF_UNIX, SOCK_STREAM, 0, plainfd) < 0) {
505 			strncpy(errstr, "Cannot get socket pair", errstr_len);
506 			goto err;
507 		}
508 
509 		if (set_nonblocking(plainfd[0], 0) <0 ||
510 		    set_nonblocking(plainfd[1], 0) <0 ||
511 		    set_nonblocking(sock, 0) <0) {
512 			close(plainfd[0]);
513 			close(plainfd[1]);
514 			strncpy(errstr, "Cannot set socket options", errstr_len);
515 			goto err;
516 		}
517 
518 		ret = fork();
519 		if (ret < 0)
520 			err("Could not fork");
521 		else if (ret == 0) {
522 			// we are the child
523 			signal (SIGPIPE, SIG_IGN);
524 			close(plainfd[1]);
525 			tlssession_mainloop(sock, plainfd[0], s);
526 			close(sock);
527 			close(plainfd[0]);
528 			exit(0);
529 		}
530 		close(plainfd[0]);
531 		close(sock);
532 		sock = plainfd[1]; /* use the decrypted FD from now on */
533 	}
534 #else
535 	if (keyfile) {
536 		strncpy(errstr, "TLS requested but support not compiled in", errstr_len);
537 		goto err;
538 	}
539 #endif
540 	if(testflags & TEST_EXPECT_ERROR) {
541 		struct sigaction act;
542 		memset(&act, 0, sizeof act);
543 		act.sa_handler = SIG_IGN;
544 		sigaction(SIGPIPE, &act, NULL);
545 	}
546 	/* magic */
547 	tmp64 = htonll(opts_magic);
548 	WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
549 			 "Could not write magic: %s", strerror(errno));
550 	/* name */
551 	tmp32 = htonl(NBD_OPT_EXPORT_NAME);
552 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
553 			 "Could not write option: %s", strerror(errno));
554 	tmp32 = htonl((uint32_t) strlen(name));
555 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
556 			 "Could not write name length: %s", strerror(errno));
557 	WRITE_ALL_ERRCHK(sock, name, strlen(name), err,
558 			 "Could not write name:: %s", strerror(errno));
559 	READ_ALL_ERRCHK(sock, &size, sizeof(size), err,
560 			"Could not read size: %s", strerror(errno));
561 	size = ntohll(size);
562 	uint16_t flags;
563 	READ_ALL_ERRCHK(sock, &flags, sizeof(uint16_t), err,
564 			"Could not read flags: %s", strerror(errno));
565 	flags = ntohs(flags);
566 	*serverflags = flags;
567 	READ_ALL_ERRCHK(sock, buf, 124, err,
568 			"Could not read reserved zeroes: %s", strerror(errno));
569 	goto end;
570 err:
571 	close(sock);
572 	sock = -1;
573 end:
574 	return sock;
575 }
576 
setup_unix_connection(gchar * unixsock)577 int setup_unix_connection(gchar * unixsock)
578 {
579 	struct sockaddr_un addr;
580 	int sock;
581 
582 	sock = 0;
583 	if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
584 		strncpy(errstr, strerror(errno), errstr_len);
585 		goto err;
586 	}
587 
588 	setmysockopt(sock);
589 	memset(&addr, 0, sizeof(struct sockaddr_un));
590 	addr.sun_family = AF_UNIX;
591 	strncpy(addr.sun_path, unixsock, sizeof addr.sun_path);
592 	addr.sun_path[sizeof(addr.sun_path)-1] = '\0';
593 	if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
594 		strncpy(errstr, strerror(errno), errstr_len);
595 		goto err_open;
596 	}
597 	goto end;
598 err_open:
599 	close(sock);
600 err:
601 	sock = -1;
602 end:
603 	return sock;
604 }
605 
setup_inet_connection(gchar * hostname,int port)606 int setup_inet_connection(gchar * hostname, int port)
607 {
608 	int sock;
609 	struct hostent *host;
610 	struct sockaddr_in addr;
611 
612 	sock = 0;
613 	if ((sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) {
614 		strncpy(errstr, strerror(errno), errstr_len);
615 		goto err;
616 	}
617 	setmysockopt(sock);
618 	if (!(host = gethostbyname(hostname))) {
619 		strncpy(errstr, hstrerror(h_errno), errstr_len);
620 		goto err_open;
621 	}
622 	addr.sin_family = AF_INET;
623 	addr.sin_port = htons(port);
624 	addr.sin_addr.s_addr = *((int *)host->h_addr);
625 	memset(&addr.sin_zero, 0, sizeof(addr.sin_zero));
626 	if ((connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0)) {
627 		strncpy(errstr, strerror(errno), errstr_len);
628 		goto err_open;
629 	}
630 	goto end;
631 err_open:
632 	close(sock);
633 err:
634 	sock = -1;
635 end:
636 	return sock;
637 }
638 
setup_inetd_connection(gchar ** argv)639 int setup_inetd_connection(gchar **argv)
640 {
641 	int sv[2], status;
642 	pid_t child;
643 
644 	if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == -1) {
645 		strncpy(errstr, strerror(errno), errstr_len);
646 		return -1;
647 	}
648 
649 	child = vfork();
650 	if (child == 0) {
651 		dup2(sv[0], 0);
652 		close(sv[0]);
653 		close(sv[1]);
654 		execvp(argv[0], argv);
655 		perror("execvp");
656 		_exit(-1);
657 	} else if (child == -1) {
658 		close(sv[0]);
659 		close(sv[1]);
660 		strncpy(errstr, strerror(errno), errstr_len);
661 		return -1;
662 	}
663 
664 	close(sv[0]);
665 	if (waitpid(child, &status, WNOHANG)) {
666 		close(sv[1]);
667 		return -1;
668 	}
669 
670 	setmysockopt(sv[1]);
671 	return sv[1];
672 }
673 
close_connection(int sock,CLOSE_TYPE type)674 int close_connection(int sock, CLOSE_TYPE type)
675 {
676 	struct nbd_request req;
677 	u64 counter = 0;
678 
679 	switch (type) {
680 	case CONNECTION_CLOSE_PROPERLY:
681 		req.magic = htonl(NBD_REQUEST_MAGIC);
682 		req.type = htonl(NBD_CMD_DISC);
683 		memcpy(&(req.handle), &(counter), sizeof(counter));
684 		counter++;
685 		req.from = 0;
686 		req.len = 0;
687 		if (write(sock, &req, sizeof(req)) < 0) {
688 			snprintf(errstr, errstr_len,
689 				 "Could not write to socket: %s",
690 				 strerror(errno));
691 			return -1;
692 		}
693 		/* falls through */
694 	case CONNECTION_CLOSE_FAST:
695 		if (close(sock) < 0) {
696 			snprintf(errstr, errstr_len,
697 				 "Could not close socket: %s", strerror(errno));
698 			return -1;
699 		}
700 		break;
701 	default:
702 		g_critical("Your compiler is on crack!");	/* or I am buggy */
703 		return -1;
704 	}
705 	return 0;
706 }
707 
read_packet_check_header(int sock,size_t datasize,long long int curhandle)708 int read_packet_check_header(int sock, size_t datasize, long long int curhandle)
709 {
710 	struct nbd_reply rep;
711 	int retval = 0;
712 	char buf[datasize];
713 
714 	READ_ALL_ERR_RT(sock, &rep, sizeof(rep), end, -1,
715 			"Could not read reply header: %s", strerror(errno));
716 	rep.magic = ntohl(rep.magic);
717 	rep.error = ntohl(rep.error);
718 	if (rep.magic != NBD_REPLY_MAGIC) {
719 		snprintf(errstr, errstr_len,
720 			 "Received package with incorrect reply_magic. Index of sent packages is %lld (0x%llX), received handle is %lld (0x%llX). Received magic 0x%lX, expected 0x%lX",
721 			 (long long int)curhandle,
722 			 (long long unsigned int)curhandle,
723 			 (long long int)*((u64 *) rep.handle),
724 			 (long long unsigned int)*((u64 *) rep.handle),
725 			 (long unsigned int)rep.magic,
726 			 (long unsigned int)NBD_REPLY_MAGIC);
727 		retval = -1;
728 		goto end;
729 	}
730 	if (rep.error) {
731 		snprintf(errstr, errstr_len,
732 			 "Received error from server: %ld (0x%lX). Handle is %lld (0x%llX).",
733 			 (long int)rep.error, (long unsigned int)rep.error,
734 			 (long long int)(*((u64 *) rep.handle)),
735 			 (long long unsigned int)*((u64 *) rep.handle));
736 		retval = -2;
737 		goto end;
738 	}
739 	if (datasize)
740 		READ_ALL_ERR_RT(sock, &buf, datasize, end, -1,
741 				"Could not read data: %s", strerror(errno));
742 
743 end:
744 	return retval;
745 }
746 
oversize_test(char * name,int sock,char close_sock,int testflags)747 int oversize_test(char *name, int sock, char close_sock, int testflags)
748 {
749 	int retval = 0;
750 	struct nbd_request req;
751 	struct nbd_reply rep;
752 	int i = 0;
753 	int serverflags = 0;
754 	pid_t G_GNUC_UNUSED mypid = getpid();
755 	char buf[((1024 * 1024) + sizeof(struct nbd_request) / 2) << 1];
756 	bool got_err;
757 
758 	/* This should work */
759 	if ((sock =
760 		 setup_connection_common(sock, name,
761 				  CONNECTION_TYPE_FULL,
762 				  &serverflags, testflags)) < 0) {
763 		g_warning("Could not open socket: %s", errstr);
764 		retval = -1;
765 		goto err;
766 	}
767 	req.magic = htonl(NBD_REQUEST_MAGIC);
768 	req.type = htonl(NBD_CMD_READ);
769 	req.len = htonl(1024 * 1024);
770 	memcpy(&(req.handle), &i, sizeof(i));
771 	req.from = htonll(i);
772 	WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
773 			 "Could not write request: %s", strerror(errno));
774 	printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
775 	READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
776 			"Could not read reply header: %s", strerror(errno));
777 	READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
778 			"Could not read data: %s", strerror(errno));
779 	if (rep.error) {
780 		snprintf(errstr, errstr_len, "Received unexpected error: %d",
781 			 rep.error);
782 		retval = -1;
783 		goto err;
784 	} else {
785 		printf("OK\n");
786 	}
787 	/* This probably should not work */
788 	i++;
789 	req.from = htonll(i);
790 	req.len = htonl(ntohl(req.len) + sizeof(struct nbd_request) / 2);
791 	WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
792 			 "Could not write request: %s", strerror(errno));
793 	printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
794 	READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
795 			"Could not read reply header: %s", strerror(errno));
796 	READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
797 			"Could not read data: %s", strerror(errno));
798 	if (rep.error) {
799 		printf("Received expected error\n");
800 		got_err = true;
801 	} else {
802 		printf("OK\n");
803 		got_err = false;
804 	}
805 	/* ... unless this works, too */
806 	i++;
807 	req.from = htonll(i);
808 	req.len = htonl(ntohl(req.len) << 1);
809 	WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
810 			 "Could not write request: %s", strerror(errno));
811 	printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
812 	READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
813 			"Could not read reply header: %s", strerror(errno));
814 	READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
815 			"Could not read data: %s", strerror(errno));
816 	if (rep.error) {
817 		printf("error\n");
818 	} else {
819 		printf("OK\n");
820 	}
821 	if ((rep.error && !got_err) || (!rep.error && got_err)) {
822 		printf("Received unexpected error\n");
823 		retval = -1;
824 	}
825 err:
826 	return retval;
827 }
828 
handshake_test(char * name,int sock,char close_sock,int testflags)829 int handshake_test(char *name, int sock, char close_sock, int testflags)
830 {
831 	int retval = -1;
832 	int serverflags = 0;
833 	u64 tmp64;
834 	uint32_t tmp32 = 0;
835 
836 	/* This should work */
837 	if ((sock =
838 		 setup_connection_common(sock, name,
839 				  CONNECTION_TYPE_FULL,
840 				  &serverflags, testflags)) < 0) {
841 		g_warning("Could not open socket: %s", errstr);
842 		goto err;
843 	}
844 
845 	/* Intentionally throw an unknown option at the server */
846 	tmp64 = htonll(opts_magic);
847 	WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
848 			 "Could not write magic: %s", strerror(errno));
849 	tmp32 = htonl(0x7654321);
850 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
851 			 "Could not write option: %s", strerror(errno));
852 	tmp32 = htonl((uint32_t) sizeof(tmp32));
853 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
854 			 "Could not write option length: %s", strerror(errno));
855 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
856 			 "Could not write option payload: %s", strerror(errno));
857 	/* Expect proper error from server */
858 	READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
859 			"Could not read magic: %s", strerror(errno));
860 	tmp64 = ntohll(tmp64);
861 	if (tmp64 != 0x3e889045565a9LL) {
862 		strncpy(errstr, "magic does not match", errstr_len);
863 		goto err;
864 	}
865 	READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
866 			"Could not read option: %s", strerror(errno));
867 	tmp32 = ntohl(tmp32);
868 	if (tmp32 != 0x7654321) {
869 		strncpy(errstr, "option does not match", errstr_len);
870 		goto err;
871 	}
872 	READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
873 			"Could not read status: %s", strerror(errno));
874 	tmp32 = ntohl(tmp32);
875 	if (tmp32 != NBD_REP_ERR_UNSUP) {
876 		strncpy(errstr, "status does not match", errstr_len);
877 		goto err;
878 	}
879 	READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
880 			"Could not read length: %s", strerror(errno));
881 	tmp32 = ntohl(tmp32);
882 	while (tmp32) {
883 		char buf[1024];
884 		size_t len = tmp32 < sizeof(buf) ? tmp32 : sizeof(buf);
885 		READ_ALL_ERRCHK(sock, buf, len, err,
886 				"Could not read payload: %s", strerror(errno));
887 		tmp32 -= len;
888 	}
889 
890 
891 	/* Send NBD_OPT_ABORT to close the connection */
892 	tmp64 = htonll(opts_magic);
893 	WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
894 			 "Could not write magic: %s", strerror(errno));
895 	tmp32 = htonl(NBD_OPT_ABORT);
896 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
897 			 "Could not write option: %s", strerror(errno));
898 	tmp32 = htonl((uint32_t) 0);
899 	WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
900 			 "Could not write option length: %s", strerror(errno));
901 
902 	retval = 0;
903 
904 	g_message("Handshake test completed. No errors encountered.");
905 err:
906 	return retval;
907 }
908 
throughput_test(char * name,int sock,char close_sock,int testflags)909 int throughput_test(char *name, int sock, char close_sock, int testflags)
910 {
911 	long long int i;
912 	char writebuf[1024];
913 	struct nbd_request req;
914 	int requests = 0;
915 	fd_set set;
916 	struct timeval tv;
917 	struct timeval start;
918 	struct timeval stop;
919 	double timespan;
920 	double speed;
921 	char speedchar[2] = { '\0', '\0' };
922 	int retval = 0;
923 	int serverflags = 0;
924 	signed int do_write = TRUE;
925 	pid_t mypid = getpid();
926 	char *print = getenv("NBD_TEST_SILENT");
927 
928 	if (!(testflags & TEST_WRITE))
929 		testflags &= ~TEST_FLUSH;
930 
931 	memset(writebuf, 'X', 1024);
932 	size = 0;
933 	if ((sock =
934 		 setup_connection_common(sock, name,
935 				  CONNECTION_TYPE_FULL,
936 				  &serverflags, testflags)) < 0) {
937 		g_warning("Could not open socket: %s", errstr);
938 		if(testflags & TEST_EXPECT_ERROR) {
939 			g_message("Test failed, as expected");
940 			retval = 0;
941 		} else {
942 			retval = -1;
943 		}
944 		goto err;
945 	}
946 	if ((testflags & TEST_FLUSH)
947 	    && ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
948 		!= (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))) {
949 		snprintf(errstr, errstr_len,
950 			 "Server did not supply flush capability flags");
951 		retval = -1;
952 		goto err_open;
953 	}
954 	req.magic = htonl(NBD_REQUEST_MAGIC);
955 	req.len = htonl(1024);
956 	if (gettimeofday(&start, NULL) < 0) {
957 		retval = -1;
958 		snprintf(errstr, errstr_len, "Could not measure start time: %s",
959 			 strerror(errno));
960 		goto err_open;
961 	}
962 	for (i = 0; i + 1024 <= size; i += 1024) {
963 		if (do_write) {
964 			int sendfua = (testflags & TEST_FLUSH)
965 			    && (((i >> 10) & 15) == 3);
966 			int sendflush = (testflags & TEST_FLUSH)
967 			    && (((i >> 10) & 15) == 11);
968 			req.type =
969 			    htonl((testflags & TEST_WRITE) ? NBD_CMD_WRITE :
970 				  NBD_CMD_READ);
971 			if (sendfua)
972 				req.type =
973 				    htonl(NBD_CMD_WRITE | NBD_CMD_FLAG_FUA);
974 			memcpy(&(req.handle), &i, sizeof(i));
975 			req.from = htonll(i);
976 			if (write_all(sock, &req, sizeof(req)) < 0) {
977 				retval = -1;
978 				goto err_open;
979 			}
980 			if (testflags & TEST_WRITE) {
981 				if (write_all(sock, writebuf, 1024) < 0) {
982 					retval = -1;
983 					goto err_open;
984 				}
985 			}
986 			++requests;
987 			if (sendflush) {
988 				long long int j = i ^ (1LL << 63);
989 				req.type = htonl(NBD_CMD_FLUSH);
990 				memcpy(&(req.handle), &j, sizeof(j));
991 				req.from = 0;
992 				req.len = 0;
993 				if (write_all(sock, &req, sizeof(req)) < 0) {
994 					retval = -1;
995 					goto err_open;
996 				}
997 				req.len = htonl(1024);
998 				++requests;
999 			}
1000 		}
1001 		do {
1002 			FD_ZERO(&set);
1003 			FD_SET(sock, &set);
1004 			tv.tv_sec = 0;
1005 			tv.tv_usec = 0;
1006 			select(sock + 1, &set, NULL, NULL, &tv);
1007 			if (FD_ISSET(sock, &set)) {
1008 				/* Okay, there's something ready for
1009 				 * reading here */
1010 				int rv;
1011 				if ((rv =
1012 				     read_packet_check_header(sock,
1013 							      (testflags &
1014 							       TEST_WRITE) ? 0 :
1015 							      1024, i)) < 0) {
1016 					if (!(testflags & TEST_EXPECT_ERROR)
1017 					    || rv != -2) {
1018 						retval = -1;
1019 					} else {
1020 						printf("\n");
1021 					}
1022 					goto err_open;
1023 				} else {
1024 					if (testflags & TEST_EXPECT_ERROR) {
1025 						retval = -1;
1026 						goto err_open;
1027 					}
1028 				}
1029 				--requests;
1030 			}
1031 		} while (FD_ISSET(sock, &set));
1032 		/* Now wait until we can write again or until a second have
1033 		 * passed, whichever comes first*/
1034 		FD_ZERO(&set);
1035 		FD_SET(sock, &set);
1036 		tv.tv_sec = 1;
1037 		tv.tv_usec = 0;
1038 		do_write = select(sock + 1, NULL, &set, NULL, &tv);
1039 		if (!do_write)
1040 			printf("Select finished\n");
1041 		if (do_write < 0) {
1042 			snprintf(errstr, errstr_len, "select: %s",
1043 				 strerror(errno));
1044 			retval = -1;
1045 			goto err_open;
1046 		}
1047 		if(print == NULL) {
1048 			printf("%d: Requests: %d  \r", (int)mypid, requests);
1049 		}
1050 	}
1051 	/* Now empty the read buffer */
1052 	do {
1053 		FD_ZERO(&set);
1054 		FD_SET(sock, &set);
1055 		tv.tv_sec = 0;
1056 		tv.tv_usec = 0;
1057 		select(sock + 1, &set, NULL, NULL, &tv);
1058 		if (FD_ISSET(sock, &set)) {
1059 			/* Okay, there's something ready for
1060 			 * reading here */
1061 			read_packet_check_header(sock,
1062 						 (testflags & TEST_WRITE) ? 0 :
1063 						 1024, i);
1064 			--requests;
1065 		}
1066 		if(print == NULL) {
1067 			printf("%d: Requests: %d  \r", (int)mypid, requests);
1068 		}
1069 	} while (requests);
1070 	printf("%d: Requests: %d  \n", (int)mypid, requests);
1071 	if (gettimeofday(&stop, NULL) < 0) {
1072 		retval = -1;
1073 		snprintf(errstr, errstr_len, "Could not measure end time: %s",
1074 			 strerror(errno));
1075 		goto err_open;
1076 	}
1077 	timespan = timeval_diff_to_double(&stop, &start);
1078 	speed = size / timespan;
1079 	if (speed > 1024) {
1080 		speed = speed / 1024.0;
1081 		speedchar[0] = 'K';
1082 	}
1083 	if (speed > 1024) {
1084 		speed = speed / 1024.0;
1085 		speedchar[0] = 'M';
1086 	}
1087 	if (speed > 1024) {
1088 		speed = speed / 1024.0;
1089 		speedchar[0] = 'G';
1090 	}
1091 	g_message
1092 	    ("%d: Throughput %s test (%s flushes) complete. Took %.3f seconds to complete, %.3f%sib/s",
1093 	     (int)getpid(), (testflags & TEST_WRITE) ? "write" : "read",
1094 	     (testflags & TEST_FLUSH) ? "with" : "without", timespan, speed,
1095 	     speedchar);
1096 
1097 err_open:
1098 	if (close_sock) {
1099 		close_connection(sock, CONNECTION_CLOSE_PROPERLY);
1100 	}
1101 err:
1102 	return retval;
1103 }
1104 
1105 /*
1106  * fill 512 byte buffer 'buf' with a hashed selection of interesting data based
1107  * only on handle and blknum. The first word is blknum, and the second handle, for ease
1108  * of understanding. Things with handle 0 are blank.
1109  */
makebuf(char * buf,uint64_t seq,uint64_t blknum)1110 static inline void makebuf(char *buf, uint64_t seq, uint64_t blknum)
1111 {
1112 	uint64_t x = ((uint64_t) blknum) ^ (seq << 32) ^ (seq >> 32);
1113 	uint64_t *p = (uint64_t *) buf;
1114 	int i;
1115 	if (!seq) {
1116 		bzero(buf, 512);
1117 		return;
1118 	}
1119 	for (i = 0; i < 512 / sizeof(uint64_t); i++) {
1120 		int s;
1121 		*(p++) = x;
1122 		x += 0xFEEDA1ECDEADBEEFULL + i + (((uint64_t) i) << 56);
1123 		s = x & 63;
1124 		x = x ^ (x << s) ^ (x >> (64 - s)) ^ 0xAA55AA55AA55AA55ULL ^
1125 		    seq;
1126 	}
1127 }
1128 
checkbuf(char * buf,uint64_t seq,uint64_t blknum)1129 static inline int checkbuf(char *buf, uint64_t seq, uint64_t blknum)
1130 {
1131 	uint64_t cmp[64];	// 512/8 = 64
1132 	makebuf((char *)cmp, seq, blknum);
1133 	return memcmp(cmp, buf, 512) ? -1 : 0;
1134 }
1135 
dumpcommand(char * text,uint32_t command)1136 static inline void dumpcommand(char *text, uint32_t command)
1137 {
1138 #ifdef DEBUG_COMMANDS
1139 	command = ntohl(command);
1140 	char *ctext;
1141 	switch (command & NBD_CMD_MASK_COMMAND) {
1142 	case NBD_CMD_READ:
1143 		ctext = "NBD_CMD_READ";
1144 		break;
1145 	case NBD_CMD_WRITE:
1146 		ctext = "NBD_CMD_WRITE";
1147 		break;
1148 	case NBD_CMD_DISC:
1149 		ctext = "NBD_CMD_DISC";
1150 		break;
1151 	case NBD_CMD_FLUSH:
1152 		ctext = "NBD_CMD_FLUSH";
1153 		break;
1154 	default:
1155 		ctext = "UNKNOWN";
1156 		break;
1157 	}
1158 	printf("%s: %s [%s] (0x%08x)\n",
1159 	       text,
1160 	       ctext, (command & NBD_CMD_FLAG_FUA) ? "FUA" : "NONE", command);
1161 #endif
1162 }
1163 
1164 /* return an unused handle */
getrandomhandle(GHashTable * phash)1165 uint64_t getrandomhandle(GHashTable * phash)
1166 {
1167 	uint64_t handle = 0;
1168 	int i;
1169 	do {
1170 		/* RAND_MAX may be as low as 2^15 */
1171 		for (i = 1; i <= 5; i++)
1172 			handle ^= random() ^ (handle << 15);
1173 	} while (g_hash_table_lookup(phash, &handle));
1174 	return handle;
1175 }
1176 
integrity_test(char * name,int sock,char close_sock,int testflags)1177 int integrity_test(char *name, int sock, char close_sock, int testflags)
1178 {
1179 	struct nbd_reply rep;
1180 	fd_set rset;
1181 	fd_set wset;
1182 	struct timeval tv;
1183 	struct timeval start;
1184 	struct timeval stop;
1185 	double timespan;
1186 	double speed;
1187 	char speedchar[2] = { '\0', '\0' };
1188 	int retval = -1;
1189 	int serverflags = 0;
1190 	pid_t G_GNUC_UNUSED mypid = getpid();
1191 	int blkhashfd = -1;
1192 	char *blkhashname = NULL;
1193 	struct blkitem *blkhash = NULL;
1194 	int logfd = -1;
1195 	uint64_t seq = 1;
1196 	uint64_t processed = 0;
1197 	uint64_t printer = 0;
1198 	char *do_print = getenv("NBD_TEST_SILENT");
1199 	uint64_t xfer = 0;
1200 	int readtransactionfile = 1;
1201 	int blocked = 0;
1202 	struct rclist txqueue = { NULL, NULL, 0 };
1203 	struct rclist inflight = { NULL, NULL, 0 };
1204 	struct chunklist txbuf = { NULL, NULL, 0 };
1205 
1206 	GHashTable *handlehash = g_hash_table_new(g_int64_hash, g_int64_equal);
1207 
1208 	size = 0;
1209 	if ((sock =
1210 		 setup_connection_common(sock, name,
1211 				  CONNECTION_TYPE_FULL,
1212 				  &serverflags, testflags)) < 0) {
1213 		g_warning("Could not open socket: %s", errstr);
1214 		goto err;
1215 	}
1216 
1217 	if ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
1218 	    != (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
1219 		g_warning
1220 		    ("Server flags do not support FLUSH and FUA - these may error");
1221 
1222 #ifdef HAVE_MKSTEMP
1223 	blkhashname = strdup("/tmp/blkarray-XXXXXX");
1224 	if (!blkhashname || (-1 == (blkhashfd = mkstemp(blkhashname)))) {
1225 		g_warning("Could not open temp file: %s", strerror(errno));
1226 		goto err;
1227 	}
1228 #else
1229 	/* use tmpnam here to avoid further feature test nightmare */
1230 	if (-1 == (blkhashfd = open(blkhashname = strdup(tmpnam(NULL)),
1231 				    O_CREAT | O_RDWR,
1232 				    S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH))) {
1233 		g_warning("Could not open temp file: %s", strerror(errno));
1234 		goto err;
1235 	}
1236 #endif
1237 	/* Ensure space freed if we die */
1238 	if (-1 == unlink(blkhashname)) {
1239 		g_warning("Could not unlink temp file: %s", strerror(errno));
1240 		goto err;
1241 	}
1242 
1243 	if (-1 ==
1244 	    lseek(blkhashfd, (off_t) ((size >> 9) * sizeof(struct blkitem)),
1245 		  SEEK_SET)) {
1246 		g_warning("Could not llseek temp file: %s", strerror(errno));
1247 		goto err;
1248 	}
1249 
1250 	if (-1 == write(blkhashfd, "\0", 1)) {
1251 		g_warning("Could not write temp file: %s", strerror(errno));
1252 		goto err;
1253 	}
1254 
1255 	if (NULL == (blkhash = mmap(NULL,
1256 				    (size >> 9) * sizeof(struct blkitem),
1257 				    PROT_READ | PROT_WRITE,
1258 				    MAP_SHARED, blkhashfd, 0))) {
1259 		g_warning("Could not mmap temp file: %s", strerror(errno));
1260 		goto err;
1261 	}
1262 
1263 	if (-1 == (logfd = open(transactionlog, O_RDONLY))) {
1264 		g_warning("Could open log file: %s", strerror(errno));
1265 		goto err;
1266 	}
1267 
1268 	if (gettimeofday(&start, NULL) < 0) {
1269 		snprintf(errstr, errstr_len, "Could not measure start time: %s",
1270 			 strerror(errno));
1271 		goto err_open;
1272 	}
1273 
1274 	while (readtransactionfile || txqueue.numitems || txbuf.numitems
1275 	       || inflight.numitems) {
1276 		int ret;
1277 
1278 		uint32_t magic;
1279 		uint32_t command;
1280 		uint64_t from;
1281 		uint32_t len;
1282 		struct reqcontext *prc;
1283 
1284 		*errstr = 0;
1285 
1286 		FD_ZERO(&wset);
1287 		FD_ZERO(&rset);
1288 		if (readtransactionfile)
1289 			FD_SET(logfd, &rset);
1290 		if ((!blocked && txqueue.numitems) || txbuf.numitems)
1291 			FD_SET(sock, &wset);
1292 		if (inflight.numitems)
1293 			FD_SET(sock, &rset);
1294 		tv.tv_sec = 5;
1295 		tv.tv_usec = 0;
1296 		ret =
1297 		    select(1 + ((sock > logfd) ? sock : logfd), &rset, &wset,
1298 			   NULL, &tv);
1299 		if (ret == 0) {
1300 			snprintf(errstr, errstr_len,
1301 				 "Timeout reading from socket");
1302 			goto err_open;
1303 		} else if (ret < 0) {
1304 			g_warning("Could not mmap temp file: %s", errstr);
1305 			goto err;
1306 		}
1307 		/* We know we've got at least one thing to do here then */
1308 
1309 		/* Get a command from the transaction log */
1310 		if (FD_ISSET(logfd, &rset)) {
1311 
1312 			/* Read a request or reply from the transaction file */
1313 			READ_ALL_ERRCHK(logfd,
1314 					&magic,
1315 					sizeof(magic),
1316 					err_open,
1317 					"Could not read transaction log: %s",
1318 					strerror(errno));
1319 			magic = ntohl(magic);
1320 			switch (magic) {
1321 			case NBD_REQUEST_MAGIC:
1322 				if (NULL ==
1323 				    (prc =
1324 				     calloc(1, sizeof(struct reqcontext)))) {
1325 					snprintf(errstr, errstr_len,
1326 						 "Could not allocate request");
1327 					goto err_open;
1328 				}
1329 				READ_ALL_ERRCHK(logfd,
1330 						sizeof(magic) +
1331 						(char *)&(prc->req),
1332 						sizeof(struct nbd_request) -
1333 						sizeof(magic), err_open,
1334 						"Could not read transaction log: %s",
1335 						strerror(errno));
1336 				prc->req.magic = htonl(NBD_REQUEST_MAGIC);
1337 				memcpy(prc->orighandle, prc->req.handle, 8);
1338 				prc->seq = seq++;
1339 				if ((ntohl(prc->req.type) &
1340 				     NBD_CMD_MASK_COMMAND) == NBD_CMD_DISC) {
1341 					/* no more to read; don't enqueue as no reply
1342 					 * we will disconnect manually at the end
1343 					 */
1344 					readtransactionfile = 0;
1345 					free(prc);
1346 				} else {
1347 					dumpcommand("Enqueuing command",
1348 						    prc->req.type);
1349 					rclist_addtail(&txqueue, prc);
1350 				}
1351 				prc = NULL;
1352 				break;
1353 			case NBD_REPLY_MAGIC:
1354 				READ_ALL_ERRCHK(logfd,
1355 						sizeof(magic) + (char *)(&rep),
1356 						sizeof(struct nbd_reply) -
1357 						sizeof(magic), err_open,
1358 						"Could not read transaction log: %s",
1359 						strerror(errno));
1360 
1361 				if (rep.error) {
1362 					snprintf(errstr, errstr_len,
1363 						 "Transaction log file contained errored transaction");
1364 					goto err_open;
1365 				}
1366 
1367 				/* We do not need to consume data on a read reply as there is
1368 				 * none in the log */
1369 				break;
1370 			default:
1371 				snprintf(errstr, errstr_len,
1372 					 "Could not measure start time: %08x",
1373 					 magic);
1374 				goto err_open;
1375 			}
1376 		}
1377 
1378 		/* See if we have a write we can do */
1379 		if (FD_ISSET(sock, &wset)) {
1380 			if ((!(txqueue.head) && !(txbuf.head)) || blocked)
1381 				g_warning
1382 				    ("Socket write FD set but we shouldn't have been interested");
1383 
1384 			/* If there is no buffered data, generate some */
1385 			if (!blocked && !(txbuf.head)
1386 			    && (NULL != (prc = txqueue.head))) {
1387 				if (ntohl(prc->req.magic) != NBD_REQUEST_MAGIC) {
1388 					g_warning
1389 					    ("Asked to write a request without a magic number");
1390 					goto err_open;
1391 				}
1392 
1393 				command = ntohl(prc->req.type);
1394 				from = ntohll(prc->req.from);
1395 				len = ntohl(prc->req.len);
1396 
1397 				/* First check whether we can touch this command at all. If this
1398 				 * command is a read, and there is an inflight write, OR if this
1399 				 * command is a write, and there is an inflight read or write, then
1400 				 * we need to leave the command alone and signal that we are blocked
1401 				 */
1402 
1403 				if (!looseordering) {
1404 					uint64_t cfrom;
1405 					uint32_t clen;
1406 					cfrom = from;
1407 					clen = len;
1408 					while (clen > 0) {
1409 						uint64_t blknum = cfrom >> 9;
1410 						if (cfrom >= size) {
1411 							snprintf(errstr,
1412 								 errstr_len,
1413 								 "offset %llx beyond size %llx",
1414 								 (long long int)
1415 								 cfrom,
1416 								 (long long int)
1417 								 size);
1418 							goto err_open;
1419 						}
1420 						if (blkhash[blknum].inflightw ||
1421 						    (blkhash[blknum].inflightr
1422 						     &&
1423 						     ((command &
1424 						       NBD_CMD_MASK_COMMAND) ==
1425 						      NBD_CMD_WRITE))) {
1426 							blocked = 1;
1427 							break;
1428 						}
1429 						cfrom += 512;
1430 						clen -= 512;
1431 					}
1432 				}
1433 
1434 				if (blocked)
1435 					goto skipdequeue;
1436 
1437 				rclist_unlink(&txqueue, prc);
1438 				rclist_addtail(&inflight, prc);
1439 
1440 				dumpcommand("Sending command", prc->req.type);
1441 				/* we rewrite the handle as they otherwise may not be unique */
1442 				*((uint64_t *) (prc->req.handle)) =
1443 				    getrandomhandle(handlehash);
1444 				g_hash_table_insert(handlehash, prc->req.handle,
1445 						    prc);
1446 				addbuffer(&txbuf, &(prc->req),
1447 					  sizeof(struct nbd_request));
1448 				switch (command & NBD_CMD_MASK_COMMAND) {
1449 				case NBD_CMD_WRITE:
1450 					xfer += len;
1451 					while (len > 0) {
1452 						uint64_t blknum = from >> 9;
1453 						char dbuf[512];
1454 						if (from >= size) {
1455 							snprintf(errstr,
1456 								 errstr_len,
1457 								 "offset %llx beyond size %llx",
1458 								 (long long int)
1459 								 from,
1460 								 (long long int)
1461 								 size);
1462 							goto err_open;
1463 						}
1464 						(blkhash[blknum].inflightw)++;
1465 						/* work out what we should be writing */
1466 						makebuf(dbuf, prc->seq, blknum);
1467 						addbuffer(&txbuf, dbuf, 512);
1468 						from += 512;
1469 						len -= 512;
1470 					}
1471 					break;
1472 				case NBD_CMD_READ:
1473 					xfer += len;
1474 					while (len > 0) {
1475 						uint64_t blknum = from >> 9;
1476 						if (from >= size) {
1477 							snprintf(errstr,
1478 								 errstr_len,
1479 								 "offset %llx beyond size %llx",
1480 								 (long long int)
1481 								 from,
1482 								 (long long int)
1483 								 size);
1484 							goto err_open;
1485 						}
1486 						(blkhash[blknum].inflightr)++;
1487 						from += 512;
1488 						len -= 512;
1489 					}
1490 					break;
1491 				case NBD_CMD_DISC:
1492 				case NBD_CMD_FLUSH:
1493 					break;
1494 				default:
1495 					snprintf(errstr, errstr_len,
1496 						 "Incomprehensible command: %08x",
1497 						 command);
1498 					goto err_open;
1499 					break;
1500 				}
1501 
1502 				prc = NULL;
1503 			}
1504 skipdequeue:
1505 
1506 			/* there should be some now */
1507 			if (writebuffer(sock, &txbuf) < 0) {
1508 				snprintf(errstr, errstr_len,
1509 					 "Failed to write to socket buffer: %s",
1510 					 strerror(errno));
1511 				goto err_open;
1512 			}
1513 
1514 		}
1515 
1516 		/* See if there is a reply to be processed from the socket */
1517 		if (FD_ISSET(sock, &rset)) {
1518 			/* Okay, there's something ready for
1519 			 * reading here */
1520 
1521 			READ_ALL_ERRCHK(sock,
1522 					&rep,
1523 					sizeof(struct nbd_reply),
1524 					err_open,
1525 					"Could not read from server socket: %s",
1526 					strerror(errno));
1527 
1528 			if (rep.magic != htonl(NBD_REPLY_MAGIC)) {
1529 				snprintf(errstr, errstr_len,
1530 					 "Bad magic from server");
1531 				goto err_open;
1532 			}
1533 
1534 			if (rep.error) {
1535 				snprintf(errstr, errstr_len,
1536 					 "Server errored a transaction");
1537 				goto err_open;
1538 			}
1539 
1540 			uint64_t handle;
1541 			memcpy(&handle, rep.handle, 8);
1542 			prc = g_hash_table_lookup(handlehash, &handle);
1543 			if (!prc) {
1544 				snprintf(errstr, errstr_len,
1545 					 "Unrecognised handle in reply: 0x%llX",
1546 					 *(long long unsigned int *)(rep.
1547 								     handle));
1548 				goto err_open;
1549 			}
1550 			if (!g_hash_table_remove(handlehash, &handle)) {
1551 				snprintf(errstr, errstr_len,
1552 					 "Could not remove handle from hash: 0x%llX",
1553 					 *(long long unsigned int *)(rep.
1554 								     handle));
1555 				goto err_open;
1556 			}
1557 
1558 			if (prc->req.magic != htonl(NBD_REQUEST_MAGIC)) {
1559 				snprintf(errstr, errstr_len,
1560 					 "Bad magic in inflight data: %08x",
1561 					 prc->req.magic);
1562 				goto err_open;
1563 			}
1564 
1565 			dumpcommand("Processing reply to command",
1566 				    prc->req.type);
1567 			command = ntohl(prc->req.type);
1568 			from = ntohll(prc->req.from);
1569 			len = ntohl(prc->req.len);
1570 
1571 			switch (command & NBD_CMD_MASK_COMMAND) {
1572 			case NBD_CMD_READ:
1573 				while (len > 0) {
1574 					uint64_t blknum = from >> 9;
1575 					char dbuf[512];
1576 					if (from >= size) {
1577 						snprintf(errstr, errstr_len,
1578 							 "offset %llx beyond size %llx",
1579 							 (long long int)from,
1580 							 (long long int)size);
1581 						goto err_open;
1582 					}
1583 					READ_ALL_ERRCHK(sock,
1584 							dbuf,
1585 							512,
1586 							err_open,
1587 							"Could not read data: %s",
1588 							strerror(errno));
1589 					if (--(blkhash[blknum].inflightr) < 0) {
1590 						snprintf(errstr, errstr_len,
1591 							 "Received a read reply for offset %llx when not in flight",
1592 							 (long long int)from);
1593 						goto err_open;
1594 					}
1595 					/* work out what we was written */
1596 					if (checkbuf
1597 					    (dbuf, blkhash[blknum].seq,
1598 					     blknum)) {
1599 						snprintf(errstr, errstr_len,
1600 							 "Bad reply data: I wanted blk %08x, seq %08x but I got (at a guess) blk %08x, seq %08x",
1601 							 (unsigned int)blknum,
1602 							 blkhash[blknum].seq,
1603 							 ((uint32_t
1604 							   *) (dbuf))[0],
1605 							 ((uint32_t
1606 							   *) (dbuf))[1]
1607 						    );
1608 						goto err_open;
1609 
1610 					}
1611 					from += 512;
1612 					len -= 512;
1613 				}
1614 				break;
1615 			case NBD_CMD_WRITE:
1616 				/* subsequent reads should get data with this seq */
1617 				while (len > 0) {
1618 					uint64_t blknum = from >> 9;
1619 					if (--(blkhash[blknum].inflightw) < 0) {
1620 						snprintf(errstr, errstr_len,
1621 							 "Received a write reply for offset %llx when not in flight",
1622 							 (long long int)from);
1623 						goto err_open;
1624 					}
1625 					blkhash[blknum].seq =
1626 					    (uint32_t) (prc->seq);
1627 					from += 512;
1628 					len -= 512;
1629 				}
1630 				break;
1631 			default:
1632 				break;
1633 			}
1634 			blocked = 0;
1635 			processed++;
1636 			rclist_unlink(&inflight, prc);
1637 			prc->req.magic = 0;	/* so a duplicate reply is detected */
1638 			free(prc);
1639 		}
1640 
1641 		if ((do_print == NULL && !(printer++ % 5000))
1642 		    || !(readtransactionfile || txqueue.numitems
1643 			 || inflight.numitems))
1644 			printf
1645 			    ("%d: Seq %08lld Queued: %08d Inflight: %08d Done: %08lld\r",
1646 			     (int)mypid, (long long int)seq, txqueue.numitems,
1647 			     inflight.numitems, (long long int)processed);
1648 
1649 	}
1650 
1651 	printf("\n");
1652 
1653 	if (gettimeofday(&stop, NULL) < 0) {
1654 		snprintf(errstr, errstr_len, "Could not measure end time: %s",
1655 			 strerror(errno));
1656 		goto err_open;
1657 	}
1658 	timespan = timeval_diff_to_double(&stop, &start);
1659 	speed = xfer / timespan;
1660 	if (speed > 1024) {
1661 		speed = speed / 1024.0;
1662 		speedchar[0] = 'K';
1663 	}
1664 	if (speed > 1024) {
1665 		speed = speed / 1024.0;
1666 		speedchar[0] = 'M';
1667 	}
1668 	if (speed > 1024) {
1669 		speed = speed / 1024.0;
1670 		speedchar[0] = 'G';
1671 	}
1672 	g_message
1673 	    ("%d: Integrity %s test complete. Took %.3f seconds to complete, %.3f%sib/s",
1674 	     (int)getpid(), (testflags & TEST_WRITE) ? "write" : "read",
1675 	     timespan, speed, speedchar);
1676 
1677 	retval = 0;
1678 
1679 err_open:
1680 	if (close_sock) {
1681 		close_connection(sock, CONNECTION_CLOSE_PROPERLY);
1682 	}
1683 err:
1684 	if (size && blkhash)
1685 		munmap(blkhash, (size >> 9) * sizeof(struct blkitem));
1686 
1687 	if (blkhashfd != -1)
1688 		close(blkhashfd);
1689 
1690 	if (logfd != -1)
1691 		close(logfd);
1692 
1693 	if (blkhashname)
1694 		free(blkhashname);
1695 
1696 	if (*errstr)
1697 		g_warning("%s", errstr);
1698 
1699 	g_hash_table_destroy(handlehash);
1700 
1701 	return retval;
1702 }
1703 
handle_nonopt(char * opt,gchar ** hostname,long int * p)1704 void handle_nonopt(char *opt, gchar ** hostname, long int *p)
1705 {
1706 	static int nonopt = 0;
1707 
1708 	switch (nonopt) {
1709 	case 0:
1710 		*hostname = g_strdup(opt);
1711 		nonopt++;
1712 		break;
1713 	case 1:
1714 		*p = (strtol(opt, NULL, 0));
1715 		if (*p == LONG_MIN || *p == LONG_MAX) {
1716 			g_critical("Could not parse port number: %s",
1717 				   strerror(errno));
1718 			exit(EXIT_FAILURE);
1719 		}
1720 		break;
1721 	}
1722 }
1723 
/* Signature shared by the test routines main() can select: export name,
 * connected socket, whether to close the connection when done, TEST_* flags. */
typedef int (*testfunc) (char *name, int sock, char close_sock, int testflags);
1725 
main(int argc,char ** argv)1726 int main(int argc, char **argv)
1727 {
1728 	gchar *hostname = NULL, *unixsock = NULL;
1729 	long int p = 10809;
1730 	char *name = NULL;
1731 	int sock = -1;
1732 	int c;
1733 	int testflags = 0;
1734 	testfunc test = throughput_test;
1735 
1736 #if HAVE_GNUTLS
1737 	tlssession_init();
1738 #endif
1739 
1740 	/* Ignore SIGPIPE as we want to pick up the error from write() */
1741 	signal(SIGPIPE, SIG_IGN);
1742 
1743 	errstr[errstr_len] = '\0';
1744 
1745 	if (argc < 3) {
1746 		g_message("%d: Not enough arguments", (int)getpid());
1747 		g_message("%d: Usage: %s <hostname> <port>", (int)getpid(),
1748 			  argv[0]);
1749 		g_message("%d: Or: %s <hostname> -N <exportname> [<port>]",
1750 			  (int)getpid(), argv[0]);
1751 		g_message("%d: Or: %s -u <unix socket> -N <exportname>",
1752 			  (int)getpid(), argv[0]);
1753 		exit(EXIT_FAILURE);
1754 	}
1755 	logging(MY_NAME);
1756 	while ((c = getopt(argc, argv, "FN:t:owfilu:hC:K:A:H:I")) >= 0) {
1757 		switch (c) {
1758 		case 1:
1759 			handle_nonopt(optarg, &hostname, &p);
1760 			break;
1761 		case 'N':
1762 			name = g_strdup(optarg);
1763 			break;
1764 		case 'F':
1765 			testflags |= TEST_EXPECT_ERROR;
1766 			break;
1767 		case 't':
1768 			transactionlog = g_strdup(optarg);
1769 			break;
1770 		case 'o':
1771 			test = oversize_test;
1772 			break;
1773 		case 'l':
1774 			looseordering = 1;
1775 			break;
1776 		case 'w':
1777 			testflags |= TEST_WRITE;
1778 			break;
1779 		case 'f':
1780 			testflags |= TEST_FLUSH;
1781 			break;
1782 		case 'I':
1783 #ifndef ISSERVER
1784 			err_nonfatal("inetd mode not supported without syslog support");
1785 			return 77;
1786 #else
1787 			p = -1;
1788 			break;
1789 #endif
1790 		case 'i':
1791 			test = integrity_test;
1792 			break;
1793 		case 'u':
1794 			unixsock = g_strdup(optarg);
1795 			break;
1796 		case 'h':
1797 			test = handshake_test;
1798 			testflags |= TEST_HANDSHAKE;
1799 			break;
1800 #if HAVE_GNUTLS
1801 		case 'C':
1802 			certfile=g_strdup(optarg);
1803 			break;
1804 		case 'K':
1805 			keyfile=g_strdup(optarg);
1806 			break;
1807 		case 'A':
1808 			cacertfile=g_strdup(optarg);
1809 			break;
1810 		case 'H':
1811 			tlshostname=g_strdup(optarg);
1812 			break;
1813 #else
1814 		case 'C':
1815 		case 'K':
1816 		case 'H':
1817 		case 'A':
1818 			g_warning("TLS support not compiled in");
1819 			/* Do not change this - looked for by test suite */
1820 			exit(77);
1821 #endif
1822 		}
1823 	}
1824 
1825 	if (p != -1) {
1826 		while (optind < argc) {
1827 			handle_nonopt(argv[optind++], &hostname, &p);
1828 		}
1829 	}
1830 
1831 	if (keyfile && !certfile)
1832 		certfile = g_strdup(keyfile);
1833 
1834 	if (!tlshostname && hostname)
1835 		tlshostname = g_strdup(hostname);
1836 
1837 	if (hostname != NULL) {
1838 		sock = setup_inet_connection(hostname, p);
1839 	} else if (unixsock != NULL) {
1840 		sock = setup_unix_connection(unixsock);
1841 	} else if (p == -1) {
1842 		sock = setup_inetd_connection(argv + optind);
1843 	} else {
1844 		g_error("need a hostname, a unix domain socket or inetd-mode command line!");
1845 		return -1;
1846 	}
1847 
1848 	if (sock == -1) {
1849 		g_warning("Could not establish a connection: %s", errstr);
1850 		exit(EXIT_FAILURE);
1851 	}
1852 
1853 	if (test(name, sock, TRUE, testflags)
1854 	    < 0) {
1855 		g_warning("Could not run test: %s", errstr);
1856 		exit(EXIT_FAILURE);
1857 	}
1858 
1859 	return 0;
1860 }
1861