1 /* $Id$ */
2
3 /*
4 * Copyright (c) 2005 Nicholas Marriott <nicholas.marriott@gmail.com>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF MIND, USE, DATA OR PROFITS, WHETHER
15 * IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
16 * OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19 #include <sys/types.h>
20 #include <sys/time.h>
21
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <poll.h>
25 #include <stdarg.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <unistd.h>
30
31 #include <openssl/ssl.h>
32 #include <openssl/err.h>
33
34 #include "fdm.h"
35
/*
 * Per-operation debug logging, switched off on purpose: the empty definition
 * on the first line means IO_DEBUG is already defined, so the #ifndef branch
 * is skipped and every IO_DEBUG() call expands to nothing. Remove the empty
 * definition to send the messages to log_debug3, prefixed with the calling
 * function and the io's fd.
 */
#define IO_DEBUG(io, fmt, ...)
#ifndef IO_DEBUG
#define IO_DEBUG(io, fmt, ...) \
	log_debug3("%s: (%d) " fmt, __func__, (io)->fd, ## __VA_ARGS__)
#endif
41
/* Helpers used by the poll functions; defined further down this file. */
int	io_before_poll(struct io *io, struct pollfd *pfd);
int	io_after_poll(struct io *io, struct pollfd *pfd);
int	io_push(struct io *io);
int	io_fill(struct io *io);
47
48 /* Create a struct io for the specified socket and SSL descriptors. */
49 struct io *
io_create(int fd,SSL * ssl,const char * eol)50 io_create(int fd, SSL *ssl, const char *eol)
51 {
52 struct io *io;
53 int mode;
54
55 io = xcalloc(1, sizeof *io);
56 io->fd = fd;
57 io->ssl = ssl;
58 io->dup_fd = -1;
59
60 /* Set non-blocking. */
61 if ((mode = fcntl(fd, F_GETFL)) == -1)
62 fatal("fcntl failed");
63 if (fcntl(fd, F_SETFL, mode|O_NONBLOCK) == -1)
64 fatal("fcntl failed");
65
66 io->flags = 0;
67 io->error = NULL;
68
69 io->rd = buffer_create(IO_BLOCKSIZE);
70 io->wr = buffer_create(IO_BLOCKSIZE);
71
72 io->lbuf = NULL;
73 io->llen = 0;
74
75 io->eol = eol;
76
77 return (io);
78 }
79
80 /* Mark io as read only. */
81 void
io_readonly(struct io * io)82 io_readonly(struct io *io)
83 {
84 buffer_destroy(io->wr);
85 io->wr = NULL;
86 }
87
88 /* Mark io as write only. */
89 void
io_writeonly(struct io * io)90 io_writeonly(struct io *io)
91 {
92 buffer_destroy(io->rd);
93 io->rd = NULL;
94 }
95
96 /* Free a struct io. */
97 void
io_free(struct io * io)98 io_free(struct io *io)
99 {
100 if (io->lbuf != NULL)
101 xfree(io->lbuf);
102 if (io->error != NULL)
103 xfree(io->error);
104 if (io->rd != NULL)
105 buffer_destroy(io->rd);
106 if (io->wr != NULL)
107 buffer_destroy(io->wr);
108 xfree(io);
109 }
110
111 /* Close io sockets. */
112 void
io_close(struct io * io)113 io_close(struct io *io)
114 {
115 if (io->ssl != NULL) {
116 SSL_CTX_free(SSL_get_SSL_CTX(io->ssl));
117 SSL_free(io->ssl);
118 }
119 close(io->fd);
120 }
121
122 /* Poll the io. */
123 int
io_poll(struct io * io,int timeout,char ** cause)124 io_poll(struct io *io, int timeout, char **cause)
125 {
126 return (io_polln(&io, 1, NULL, timeout, cause));
127 }
128
129 /* Poll multiple IOs. */
130 int
io_polln(struct io ** iop,u_int n,struct io ** rio,int timeout,char ** cause)131 io_polln(struct io **iop, u_int n, struct io **rio, int timeout, char **cause)
132 {
133 struct io *io;
134 struct pollfd *pfds;
135 int error;
136 u_int i;
137
138 /* Fill in all the pollfds. */
139 pfds = xcalloc(n, sizeof *pfds);
140 for (i = 0; i < n; i++) {
141 io = iop[i];
142 if (rio != NULL)
143 *rio = io;
144 switch (io_before_poll(io, &pfds[i])) {
145 case 0:
146 /* Found a closed io. */
147 xfree(pfds);
148 return (0);
149 case -1:
150 goto error;
151 }
152 }
153
154 /* Do the poll. */
155 error = poll(pfds, n, timeout);
156 if (error == 0 || error == -1) {
157 IO_DEBUG(io, "poll returned: %d (errno=%d)", error, errno);
158 xfree(pfds);
159
160 if (error == 0) {
161 if (timeout == 0) {
162 errno = EAGAIN;
163 return (-1);
164 }
165 errno = ETIMEDOUT;
166 }
167
168 if (errno == EINTR)
169 return (1);
170
171 if (rio != NULL)
172 *rio = NULL;
173 if (cause != NULL)
174 xasprintf(cause, "io: poll: %s", strerror(errno));
175 return (-1);
176 }
177
178 /* Check all the ios. */
179 for (i = 0; i < n; i++) {
180 io = iop[i];
181 if (rio != NULL)
182 *rio = io;
183 if (io_after_poll(io, &pfds[i]) == -1)
184 goto error;
185 }
186
187 xfree(pfds);
188 return (1);
189
190 error:
191 if (cause != NULL)
192 *cause = xstrdup(io->error);
193 xfree(pfds);
194 errno = 0;
195 return (-1);
196 }
197
198 /* Set up an io for polling. */
199 int
io_before_poll(struct io * io,struct pollfd * pfd)200 io_before_poll(struct io *io, struct pollfd *pfd)
201 {
202 /* If io is NULL, don't let poll do anything with this one. */
203 if (io == NULL) {
204 memset(pfd, 0, sizeof *pfd);
205 pfd->fd = -1;
206 return (1);
207 }
208
209 /* Check for errors or closure. */
210 if (io->error != NULL)
211 return (-1);
212 if (IO_CLOSED(io))
213 return (0);
214
215 /* Fill in pollfd. */
216 memset(pfd, 0, sizeof *pfd);
217 if (io->ssl != NULL)
218 pfd->fd = SSL_get_fd(io->ssl);
219 else
220 pfd->fd = io->fd;
221 if (io->rd != NULL)
222 pfd->events |= POLLIN;
223 if (io->wr != NULL && (BUFFER_USED(io->wr) != 0 ||
224 (io->flags & (IOF_NEEDFILL|IOF_NEEDPUSH|IOF_MUSTWR)) != 0))
225 pfd->events |= POLLOUT;
226
227 IO_DEBUG(io, "poll in: 0x%03x", pfd->events);
228
229 return (1);
230 }
231
232 /* Handle io after polling. */
233 int
io_after_poll(struct io * io,struct pollfd * pfd)234 io_after_poll(struct io *io, struct pollfd *pfd)
235 {
236 /* Ignore NULL ios. */
237 if (io == NULL)
238 return (1);
239
240 IO_DEBUG(io, "poll out: 0x%03x", pfd->revents);
241
242 /* Close on POLLERR or POLLNVAL hard. */
243 if (pfd->revents & (POLLERR|POLLNVAL)) {
244 io->flags |= IOF_CLOSED;
245 return (0);
246 }
247 /* Close on POLLHUP but only if there is nothing to read. */
248 if (pfd->revents & POLLHUP && (pfd->revents & POLLIN) == 0) {
249 io->flags |= IOF_CLOSED;
250 return (0);
251 }
252
253 /* Check for repeated read/write. */
254 if ((io->flags & (IOF_NEEDPUSH|IOF_NEEDFILL)) != 0) {
255 /*
256 * If a repeated read/write is necessary, the socket must be
257 * ready for both reading and writing
258 */
259 if (pfd->revents & (POLLOUT|POLLIN)) {
260 if (io->flags & IOF_NEEDPUSH) {
261 switch (io_push(io)) {
262 case 0:
263 io->flags |= IOF_CLOSED;
264 return (0);
265 case -1:
266 return (-1);
267 }
268 }
269 if (io->flags & IOF_NEEDFILL) {
270 switch (io_fill(io)) {
271 case 0:
272 io->flags |= IOF_CLOSED;
273 return (0);
274 case -1:
275 return (-1);
276 }
277 }
278 }
279 return (1);
280 }
281
282 /* Otherwise try to read and write. */
283 if (io->wr != NULL && pfd->revents & POLLOUT) {
284 switch (io_push(io)) {
285 case 0:
286 io->flags |= IOF_CLOSED;
287 return (0);
288 case -1:
289 return (-1);
290 }
291 }
292 if (io->rd != NULL && pfd->revents & POLLIN) {
293 switch (io_fill(io)) {
294 case 0:
295 io->flags |= IOF_CLOSED;
296 return (0);
297 case -1:
298 return (-1);
299 }
300 }
301
302 return (1);
303 }
304
305 /*
306 * Fill read buffer. Returns 0 for closed, -1 for error, 1 for success,
307 * a la read(2).
308 */
309 int
io_fill(struct io * io)310 io_fill(struct io *io)
311 {
312 ssize_t n;
313 int error;
314
315 again:
316 /* Ensure there is at least some minimum space in the buffer. */
317 buffer_ensure(io->rd, IO_WATERMARK);
318
319 /* Attempt to read as much as the buffer has available. */
320 if (io->ssl == NULL) {
321 n = read(io->fd, BUFFER_IN(io->rd), BUFFER_FREE(io->rd));
322 IO_DEBUG(io, "read returned %zd (errno=%d)", n, errno);
323 if (n == 0 || (n == -1 && errno == EPIPE))
324 return (0);
325 if (n == -1 && errno != EINTR && errno != EAGAIN) {
326 if (io->error != NULL)
327 xfree(io->error);
328 xasprintf(&io->error, "io: read: %s", strerror(errno));
329 return (-1);
330 }
331 } else {
332 n = SSL_read(io->ssl, BUFFER_IN(io->rd), BUFFER_FREE(io->rd));
333 IO_DEBUG(io, "SSL_read returned %zd", n);
334 if (n == 0)
335 return (0);
336 if (n < 0) {
337 switch (error = SSL_get_error(io->ssl, n)) {
338 case SSL_ERROR_WANT_READ:
339 /*
340 * A repeat is certain (poll on the socket will
341 * still return data ready) so this can be
342 * ignored.
343 */
344 break;
345 case SSL_ERROR_WANT_WRITE:
346 io->flags |= IOF_NEEDFILL;
347 break;
348 case SSL_ERROR_SYSCALL:
349 if (errno == EAGAIN || errno == EINTR)
350 break;
351 /* FALLTHROUGH */
352 default:
353 if (io->error != NULL)
354 xfree(io->error);
355 io->error = sslerror2(error, "SSL_read");
356 return (-1);
357 }
358 }
359 }
360
361 /* Test for > 0 since SSL_read can return any -ve on error. */
362 if (n > 0) {
363 IO_DEBUG(io, "read %zd bytes", n);
364
365 /* Copy out the duplicate fd. Errors are just ignored. */
366 if (io->dup_fd != -1) {
367 write(io->dup_fd, "< ", 2);
368 write(io->dup_fd, BUFFER_IN(io->rd), n);
369 }
370
371 /* Adjust the buffer size. */
372 buffer_add(io->rd, n);
373
374 /* Reset the need flags. */
375 io->flags &= ~IOF_NEEDFILL;
376
377 goto again;
378 }
379
380 return (1);
381 }
382
383 /* Empty write buffer. */
384 int
io_push(struct io * io)385 io_push(struct io *io)
386 {
387 ssize_t n;
388 int error;
389
390 /* If nothing to write, return. */
391 if (BUFFER_USED(io->wr) == 0)
392 return (1);
393
394 /* Write as much as possible. */
395 if (io->ssl == NULL) {
396 n = write(io->fd, BUFFER_OUT(io->wr), BUFFER_USED(io->wr));
397 IO_DEBUG(io, "write returned %zd (errno=%d)", n, errno);
398 if (n == 0 || (n == -1 && errno == EPIPE))
399 return (0);
400 if (n == -1 && errno != EINTR && errno != EAGAIN) {
401 if (io->error != NULL)
402 xfree(io->error);
403 xasprintf(&io->error, "io: write: %s", strerror(errno));
404 return (-1);
405 }
406 } else {
407 n = SSL_write(io->ssl, BUFFER_OUT(io->wr), BUFFER_USED(io->wr));
408 IO_DEBUG(io, "SSL_write returned %zd", n);
409 if (n == 0)
410 return (0);
411 if (n < 0) {
412 switch (error = SSL_get_error(io->ssl, n)) {
413 case SSL_ERROR_WANT_READ:
414 io->flags |= IOF_NEEDPUSH;
415 break;
416 case SSL_ERROR_WANT_WRITE:
417 /*
418 * A repeat is certain (buffer still has data)
419 * so this can be ignored
420 */
421 break;
422 case SSL_ERROR_SYSCALL:
423 if (errno == EAGAIN || errno == EINTR)
424 break;
425 /* FALLTHROUGH */
426 default:
427 if (io->error != NULL)
428 xfree(io->error);
429 io->error = sslerror2(error, "SSL_write");
430 return (-1);
431 }
432 }
433 }
434
435 /* Test for > 0 since SSL_write can return any -ve on error. */
436 if (n > 0) {
437 IO_DEBUG(io, "wrote %zd bytes", n);
438
439 /* Copy out the duplicate fd. */
440 if (io->dup_fd != -1) {
441 write(io->dup_fd, "> ", 2);
442 write(io->dup_fd, BUFFER_OUT(io->wr), n);
443 }
444
445 /* Adjust the buffer size. */
446 buffer_remove(io->wr, n);
447
448 /* Reset the need flags. */
449 io->flags &= ~IOF_NEEDPUSH;
450 }
451
452 return (1);
453 }
454
455 /* Return a specific number of bytes from the read buffer, if available. */
456 void *
io_read(struct io * io,size_t len)457 io_read(struct io *io, size_t len)
458 {
459 void *buf;
460
461 IO_DEBUG(io, "in: %zu bytes, rd: used=%zu, free=%zu", len,
462 BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
463
464 if (io->error != NULL)
465 return (NULL);
466
467 if (BUFFER_USED(io->rd) < len)
468 return (NULL);
469
470 buf = xmalloc(len);
471 buffer_read(io->rd, buf, len);
472
473 IO_DEBUG(io, "out: %zu bytes, rd: used=%zu, free=%zu", len,
474 BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
475
476 return (buf);
477 }
478
479 /* Return a specific number of bytes from the read buffer, if available. */
480 int
io_read2(struct io * io,void * buf,size_t len)481 io_read2(struct io *io, void *buf, size_t len)
482 {
483 if (io->error != NULL)
484 return (-1);
485
486 IO_DEBUG(io, "in: %zu bytes, rd: used=%zu, free=%zu", len,
487 BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
488
489 if (BUFFER_USED(io->rd) < len)
490 return (1);
491
492 buffer_read(io->rd, buf, len);
493
494 IO_DEBUG(io, "out: %zu bytes, rd: used=%zu, free=%zu", len,
495 BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
496
497 return (0);
498 }
499
500 /* Write a block to the io write buffer. */
501 void
io_write(struct io * io,const void * buf,size_t len)502 io_write(struct io *io, const void *buf, size_t len)
503 {
504 if (io->error != NULL)
505 return;
506
507 IO_DEBUG(io, "in: %zu bytes, wr: used=%zu, free=%zu", len,
508 BUFFER_USED(io->wr), BUFFER_FREE(io->wr));
509
510 buffer_write(io->wr, buf, len);
511
512 IO_DEBUG(io, "out: %zu bytes, wr: used=%zu, free=%zu", len,
513 BUFFER_USED(io->wr), BUFFER_FREE(io->wr));
514 }
515
516 /*
517 * Return a line from the read buffer. EOL is stripped and the string returned
518 * is zero-terminated.
519 */
520 char *
io_readline2(struct io * io,char ** buf,size_t * len)521 io_readline2(struct io *io, char **buf, size_t *len)
522 {
523 char *ptr, *base;
524 size_t size, maxlen, eollen;
525
526 if (io->error != NULL)
527 return (NULL);
528
529 maxlen = BUFFER_USED(io->rd);
530 if (maxlen > IO_MAXLINELEN)
531 maxlen = IO_MAXLINELEN;
532 eollen = strlen(io->eol);
533 if (BUFFER_USED(io->rd) < eollen)
534 return (NULL);
535
536 IO_DEBUG(io, "in: rd: used=%zu, free=%zu",
537 BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
538
539 base = ptr = BUFFER_OUT(io->rd);
540 for (;;) {
541 /* Find the first character in the EOL string. */
542 ptr = memchr(ptr, *io->eol, maxlen - (ptr - base));
543
544 if (ptr != NULL) {
545 /* Found. Is there enough space for the rest? */
546 if (ptr - base + eollen > maxlen) {
547 /*
548 * No, this isn't it. Set ptr to NULL to handle
549 * as not found.
550 */
551 ptr = NULL;
552 } else if (strncmp(ptr, io->eol, eollen) == 0) {
553 /* This is an EOL. */
554 size = ptr - base;
555 break;
556 }
557 }
558 if (ptr == NULL) {
559 IO_DEBUG(io,
560 "not found (%zu, %d)", maxlen, IO_CLOSED(io));
561
562 /*
563 * Not found within the length searched. If that was
564 * the maximum length, this is an error.
565 */
566 if (maxlen == IO_MAXLINELEN) {
567 if (io->error != NULL)
568 xfree(io->error);
569 io->error =
570 xstrdup("io: maximum line length exceeded");
571 return (NULL);
572 }
573
574 /*
575 * If the socket has closed, just return all the data
576 * (the buffer is known to be at least eollen long).
577 */
578 if (!IO_CLOSED(io))
579 return (NULL);
580 size = BUFFER_USED(io->rd);
581
582 ENSURE_FOR(*buf, *len, size, 1);
583 buffer_read(io->rd, *buf, size);
584 (*buf)[size] = '\0';
585 return (*buf);
586 }
587
588 /* Start again from the next character. */
589 ptr++;
590 }
591
592 /* Copy the line and remove it from the buffer. */
593 ENSURE_FOR(*buf, *len, size, 1);
594 if (size != 0)
595 buffer_read(io->rd, *buf, size);
596 (*buf)[size] = '\0';
597
598 /* Discard the EOL from the buffer. */
599 buffer_remove(io->rd, eollen);
600
601 IO_DEBUG(io, "out: %zu bytes, rd: used=%zu, free=%zu",
602 size, BUFFER_USED(io->rd), BUFFER_FREE(io->rd));
603
604 return (*buf);
605 }
606
607 /* Return a line from the read buffer in a new buffer. */
608 char *
io_readline(struct io * io)609 io_readline(struct io *io)
610 {
611 char *line;
612
613 if (io->error != NULL)
614 return (NULL);
615
616 if (io->lbuf == NULL) {
617 io->llen = IO_LINESIZE;
618 io->lbuf = xmalloc(io->llen);
619 }
620
621 if ((line = io_readline2(io, &io->lbuf, &io->llen)) != NULL)
622 io->lbuf = NULL;
623 return (line);
624 }
625
626 /* Write a line to the io write buffer. */
627 void printflike2
io_writeline(struct io * io,const char * fmt,...)628 io_writeline(struct io *io, const char *fmt, ...)
629 {
630 va_list ap;
631
632 if (io->error != NULL)
633 return;
634
635 va_start(ap, fmt);
636 io_vwriteline(io, fmt, ap);
637 va_end(ap);
638 }
639
640 /* Write a line to the io write buffer from a va_list. */
641 void
io_vwriteline(struct io * io,const char * fmt,va_list ap)642 io_vwriteline(struct io *io, const char *fmt, va_list ap)
643 {
644 int n;
645 va_list aq;
646
647 if (io->error != NULL)
648 return;
649
650 IO_DEBUG(io, "in: wr: used=%zu, free=%zu",
651 BUFFER_USED(io->wr), BUFFER_FREE(io->wr));
652
653 if (fmt != NULL) {
654 va_copy(aq, ap);
655 n = xvsnprintf(NULL, 0, fmt, aq);
656 va_end(aq);
657
658 buffer_ensure(io->wr, n + 1);
659 xvsnprintf(BUFFER_IN(io->wr), n + 1, fmt, ap);
660 buffer_add(io->wr, n);
661 } else
662 n = 0;
663 io_write(io, io->eol, strlen(io->eol));
664
665 IO_DEBUG(io, "out: %zu bytes, wr: used=%zu, free=%zu",
666 n + strlen(io->eol), BUFFER_USED(io->wr), BUFFER_FREE(io->wr));
667 }
668
669 /* Poll until a line is received. */
670 int
io_pollline(struct io * io,char ** line,int timeout,char ** cause)671 io_pollline(struct io *io, char **line, int timeout, char **cause)
672 {
673 int res;
674
675 if (io->lbuf == NULL) {
676 io->llen = IO_LINESIZE;
677 io->lbuf = xmalloc(io->llen);
678 }
679
680 res = io_pollline2(io, line, &io->lbuf, &io->llen, timeout, cause);
681 if (res == 1)
682 io->lbuf = NULL;
683 return (res);
684 }
685
/*
 * Poll until a complete line can be read into the caller's buffer *buf,
 * which is grown as needed with *len updated.
 *
 * On success *line points to the line and 1 is returned. Otherwise *line is
 * NULL and io_poll's result is returned: 0 if the io is closed, -1 on error.
 */
int
io_pollline2(struct io *io, char **line, char **buf, size_t *len, int timeout,
    char **cause)
{
	int	rc;

	while ((*line = io_readline2(io, buf, len)) == NULL) {
		rc = io_poll(io, timeout, cause);
		if (rc != 1)
			return (rc);
	}

	return (1);
}
702
703 /* Poll until all data in the write buffer has been written to the socket. */
704 int
io_flush(struct io * io,int timeout,char ** cause)705 io_flush(struct io *io, int timeout, char **cause)
706 {
707 while (BUFFER_USED(io->wr) != 0) {
708 if (io_poll(io, timeout, cause) != 1)
709 return (-1);
710 }
711
712 return (0);
713 }
714
715 /* Poll until len bytes have been read into the read buffer. */
716 int
io_wait(struct io * io,size_t len,int timeout,char ** cause)717 io_wait(struct io *io, size_t len, int timeout, char **cause)
718 {
719 while (BUFFER_USED(io->rd) < len) {
720 if (io_poll(io, timeout, cause) != 1)
721 return (-1);
722 }
723
724 return (0);
725 }
726
727 /* Poll if there is lots of data to write. */
728 int
io_update(struct io * io,int timeout,char ** cause)729 io_update(struct io *io, int timeout, char **cause)
730 {
731 if (BUFFER_USED(io->wr) < IO_FLUSHSIZE)
732 return (1);
733
734 return (io_poll(io, timeout, cause));
735 }
736