1 /*
2  * DNS Reply Tool (drool)
3  *
4  * Copyright (c) 2017-2018, OARC, Inc.
5  * Copyright (c) 2017, Comcast Corporation
6  * All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  *
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  *
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in
17  *    the documentation and/or other materials provided with the
18  *    distribution.
19  *
20  * 3. Neither the name of the copyright holder nor the names of its
21  *    contributors may be used to endorse or promote products derived
22  *    from this software without specific prior written permission.
23  *
24  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
25  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
26  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
27  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
28  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
29  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
30  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
31  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
33  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
34  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35  * POSSIBILITY OF SUCH DAMAGE.
36  */
37 
38 #include "config.h"
39 
40 #include "client.h"
41 #include "log.h"
42 #include "assert.h"
43 
44 #include <stdlib.h>
45 #include <unistd.h>
46 #include <fcntl.h>
47 #include <errno.h>
48 #include <string.h>
49 
50 /*
51  * EV callbacks
52  */
53 
/*
 * Read-readiness callback of the shutdown watcher, which client_close()
 * starts after a successful shutdown(SHUT_RDWR).
 *
 * Drains whatever is still pending on the socket:
 *  - recv() returned data: keep the watcher running and wait for more.
 *  - recv() failed with EAGAIN/EWOULDBLOCK/EINTR: spurious or interrupted
 *    wakeup on the non-blocking socket, so retry on the next event.
 *  - end-of-stream (0) or any other error: stop the watcher, mark the client
 *    CLIENT_CLOSED and invoke the owner's callback.
 *
 * The descriptor is not closed here.
 */
static void client_shutdown(struct ev_loop* loop, ev_io* w, int revents)
{
    drool_client_t* client;
    char            buf[512];
    ssize_t         nrecv;

    /* TODO: Check revents for EV_ERROR */

    drool_assert(loop);
    drool_assert(w);
    client = (drool_client_t*)(w->data);
    drool_assert(client);

    nrecv = recv(client->fd, buf, sizeof(buf), 0);
    if (nrecv > 0)
        return;

    /* Nothing to read right now is not the same as the peer having closed. */
    if (nrecv < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
        return;

    ev_io_stop(loop, w);
    client->state = CLIENT_CLOSED;
    client->callback(client, loop);
}
73 
/*
 * Read-readiness callback of the read watcher, started once the whole query
 * has been sent and a reply is expected.
 *
 * Performs a single recvfrom() into a 64 KiB stack buffer; the data itself is
 * discarded. The outcome decides the client state:
 *  - EAGAIN/EWOULDBLOCK: keep waiting (watcher stays active);
 *  - ECONNREFUSED/ENETUNREACH: CLIENT_FAILED;
 *  - any other error: CLIENT_ERRNO with errnum set;
 *  - any non-negative result (including 0): CLIENT_SUCCESS.
 * In every case except "keep waiting", the watcher is stopped and the owner's
 * callback is invoked.
 */
static void client_read(struct ev_loop* loop, ev_io* w, int revents)
{
    char            reply[64 * 1024];
    drool_client_t* client;
    ssize_t         got;
    int             err;

    /* TODO: Check revents for EV_ERROR */

    drool_assert(loop);
    drool_assert(w);
    client = (drool_client_t*)(w->data);
    drool_assert(client);

    /* TODO: decide how much to read per event (currently one recvfrom()). */
    /* TODO: optionally record the sender address via recvfrom()'s last two
     * arguments (from_addr/from_addrlen) when have_from_addr is set. */
    got = recvfrom(client->fd, reply, sizeof(reply), 0, 0, 0);

    if (got >= 0) {
        ev_io_stop(loop, w);
        client->state = CLIENT_SUCCESS;
        client->callback(client, loop);
        return;
    }

    err = errno;
    if (err == EAGAIN || err == EWOULDBLOCK)
        return;

    if (err == ECONNREFUSED || err == ENETUNREACH) {
        client->state = CLIENT_FAILED;
    } else {
        client->errnum = err;
        client->state  = CLIENT_ERRNO;
    }
    ev_io_stop(loop, w);
    client->callback(client, loop);
}
127 
/*
 * Write-readiness callback of the write watcher. It has two roles.
 *
 * 1. While CLIENT_CONNECTING it completes a non-blocking connect():
 *    - the watcher is stopped and SO_ERROR is read;
 *    - the state becomes CLIENT_CONNECTED, CLIENT_FAILED
 *      (ECONNREFUSED/ENETUNREACH) or CLIENT_ERRNO;
 *    - the owner's callback is invoked.
 *
 * 2. Otherwise it continues sending the query:
 *    - For stream sockets, the 2-byte network-order length prefix is sent
 *      first, once. A short write of the prefix is treated as
 *      CLIENT_FAILED with errnum ENOBUFS.
 *    - Then the remaining query bytes are sent, starting at client->sent.
 *    - EAGAIN/EWOULDBLOCK leaves the watcher running so the send is retried
 *      on the next writability event.
 *    - Other errors stop the watcher and set CLIENT_FAILED (ECONNRESET) or
 *      CLIENT_ERRNO, record errnum, and invoke the callback.
 *
 * When the whole query has been sent, the watcher is stopped. With
 * skip_reply the client finishes as CLIENT_SUCCESS and the callback runs.
 * Otherwise the read watcher is started and the state becomes
 * CLIENT_RECIVING.
 *
 * errno is saved immediately after a failing sendto(), before ev_io_stop()
 * or anything else can overwrite it.
 */
static void client_write(struct ev_loop* loop, ev_io* w, int revents)
{
    drool_client_t* client;
    ssize_t         nsent;
    int             err;

    /* TODO: Check revents for EV_ERROR */

    drool_assert(loop);
    drool_assert(w);
    client = (drool_client_t*)(w->data);
    drool_assert(client);

    if (client->state == CLIENT_CONNECTING) {
        int       soerr = 0;
        socklen_t len   = sizeof(soerr);

        ev_io_stop(loop, w);

        if (getsockopt(client->fd, SOL_SOCKET, SO_ERROR, (void*)&soerr, &len) < 0) {
            client->errnum = errno;
            client->state  = CLIENT_ERRNO;
        } else if (soerr) {
            switch (soerr) {
            case ECONNREFUSED:
            case ENETUNREACH:
                client->state = CLIENT_FAILED;
                break;

            default:
                client->errnum = soerr;
                client->state  = CLIENT_ERRNO;
                break;
            }
        } else {
            client->state        = CLIENT_CONNECTED;
            client->is_connected = 1;
        }

        client->callback(client, loop);
        return;
    }

    if (client->is_stream && !client->sent_length) {
        uint16_t length = htons(query_length(client->query));

        if (client->have_to_addr)
            nsent = sendto(client->fd, &length, 2, 0, (struct sockaddr*)&(client->to_addr), client->to_addrlen);
        else
            nsent = sendto(client->fd, &length, 2, 0, 0, 0);
        if (nsent < 0) {
            err = errno; /* save before ev_io_stop()/callback can clobber it */
            if (err == EAGAIN || err == EWOULDBLOCK)
                return;

            ev_io_stop(loop, w);
            client->errnum = err;
            client->state  = err == ECONNRESET ? CLIENT_FAILED : CLIENT_ERRNO;
            client->callback(client, loop);
            return;
        } else if (nsent != 2) {
            ev_io_stop(loop, w);
            client->errnum = ENOBUFS;
            client->state  = CLIENT_FAILED;
            client->callback(client, loop);
            return;
        }

        client->sent_length = 1;
    }

    if (client->have_to_addr)
        nsent = sendto(client->fd, query_raw(client->query) + client->sent, query_length(client->query) - client->sent, 0, (struct sockaddr*)&(client->to_addr), client->to_addrlen);
    else
        nsent = sendto(client->fd, query_raw(client->query) + client->sent, query_length(client->query) - client->sent, 0, 0, 0);
    if (nsent < 0) {
        err = errno; /* save before ev_io_stop()/callback can clobber it */
        if (err == EAGAIN || err == EWOULDBLOCK)
            return;

        ev_io_stop(loop, w);
        client->errnum = err;
        client->state  = err == ECONNRESET ? CLIENT_FAILED : CLIENT_ERRNO;
        client->callback(client, loop);
        return;
    }

    client->sent += nsent;
    if (client->sent < query_length(client->query))
        return;

    ev_io_stop(loop, w);
    if (client->skip_reply) {
        client->state = CLIENT_SUCCESS;
        client->callback(client, loop);
        return;
    }
    ev_io_start(loop, &(client->read_watcher));
    client->state = CLIENT_RECIVING;
}
241 
242 /*
243  * New/free functions
244  */
245 
client_new(drool_query_t * query,drool_client_callback_t callback)246 drool_client_t* client_new(drool_query_t* query, drool_client_callback_t callback)
247 {
248     drool_client_t* client;
249 
250     drool_assert(query);
251     if (!query) {
252         return 0;
253     }
254     drool_assert(callback);
255     if (!callback) {
256         return 0;
257     }
258 
259     if (!query_have_raw(query)) {
260         return 0;
261     }
262 
263     if ((client = calloc(1, sizeof(drool_client_t)))) {
264         client->query              = query;
265         client->callback           = callback;
266         client->write_watcher.data = (void*)client;
267         ev_init(&(client->write_watcher), &client_write);
268         client->read_watcher.data = (void*)client;
269         ev_init(&(client->read_watcher), &client_read);
270         client->shutdown_watcher.data = (void*)client;
271         ev_init(&(client->shutdown_watcher), &client_shutdown);
272     }
273 
274     return client;
275 }
276 
/*
 * Release a client and everything it owns. A 0 client is ignored.
 *
 * If a socket is still open it is shut down (when still marked connected)
 * and closed. The owned query, if any, is freed with query_free().
 *
 * NOTE(review): the ev watchers are not stopped here (there is no loop
 * argument). Callers presumably must ensure none of them is active, e.g. the
 * client is not in CLIENT_CLOSING, before freeing; confirm.
 */
void client_free(drool_client_t* client)
{
    if (!client)
        return;

    if (client->have_fd) {
        if (client->is_connected)
            shutdown(client->fd, SHUT_RDWR);
        close(client->fd);
    }

    if (client->query)
        query_free(client->query);

    free(client);
}
292 
293 /*
294  * Get/set functions
295  */
296 
/*
 * Return the next client as stored by client_set_next(). Returns 0 if none
 * was set or if `client` is 0.
 */
inline drool_client_t* client_next(drool_client_t* client)
{
    drool_assert(client);
    if (!client) {
        return 0;
    }
    return client->next;
}
302 
/*
 * Return the previous client as stored by client_set_prev(). Returns 0 if
 * none was set or if `client` is 0.
 */
inline drool_client_t* client_prev(drool_client_t* client)
{
    drool_assert(client);
    if (!client) {
        return 0;
    }
    return client->prev;
}
308 
client_fd(const drool_client_t * client)309 inline int client_fd(const drool_client_t* client)
310 {
311     drool_assert(client);
312     return client->fd;
313 }
314 
client_query(const drool_client_t * client)315 inline const drool_query_t* client_query(const drool_client_t* client)
316 {
317     drool_assert(client);
318     return client->query;
319 }
320 
client_state(const drool_client_t * client)321 inline drool_client_state_t client_state(const drool_client_t* client)
322 {
323     drool_assert(client);
324     return client->state;
325 }
326 
client_is_connected(const drool_client_t * client)327 inline int client_is_connected(const drool_client_t* client)
328 {
329     drool_assert(client);
330     return client->is_connected;
331 }
332 
client_errno(const drool_client_t * client)333 inline int client_errno(const drool_client_t* client)
334 {
335     drool_assert(client);
336     return client->errnum;
337 }
338 
client_start(const drool_client_t * client)339 inline ev_tstamp client_start(const drool_client_t* client)
340 {
341     drool_assert(client);
342     return client->start;
343 }
344 
client_is_dgram(const drool_client_t * client)345 inline int client_is_dgram(const drool_client_t* client)
346 {
347     drool_assert(client);
348     return client->is_dgram;
349 }
350 
client_is_stream(const drool_client_t * client)351 inline int client_is_stream(const drool_client_t* client)
352 {
353     drool_assert(client);
354     return client->is_stream;
355 }
356 
/*
 * Store `next` (may be 0) as the client's successor link.
 * Returns 0 on success, 1 if `client` is 0.
 */
int client_set_next(drool_client_t* client, drool_client_t* next)
{
    drool_assert(client);
    if (client) {
        client->next = next;
        return 0;
    }
    return 1;
}
368 
/*
 * Store `prev` (may be 0) as the client's predecessor link.
 * Returns 0 on success, 1 if `client` is 0.
 */
int client_set_prev(drool_client_t* client, drool_client_t* prev)
{
    drool_assert(client);
    if (client) {
        client->prev = prev;
        return 0;
    }
    return 1;
}
380 
/*
 * Record `start` so it can be read back with client_start().
 * Returns 0 on success, 1 if `client` is 0.
 */
int client_set_start(drool_client_t* client, ev_tstamp start)
{
    drool_assert(client);
    if (client) {
        client->start = start;
        return 0;
    }
    return 1;
}
392 
/*
 * Do not wait for a reply. Once the query has been completely sent, the
 * client finishes as CLIENT_SUCCESS instead of starting the read watcher.
 * Returns 0 on success, 1 if `client` is 0.
 */
int client_set_skip_reply(drool_client_t* client)
{
    drool_assert(client);
    if (client) {
        client->skip_reply = 1;
        return 0;
    }
    return 1;
}
404 
client_release_query(drool_client_t * client)405 drool_query_t* client_release_query(drool_client_t* client)
406 {
407     drool_query_t* query;
408 
409     drool_assert(client);
410     if (!client) {
411         return 0;
412     }
413 
414     query         = client->query;
415     client->query = 0;
416 
417     return query;
418 }
419 
420 /*
421  * Control functions
422  */
423 
/*
 * Create a non-blocking socket for `client` and, for TCP, start connecting
 * to `addr`.
 *
 * `ipproto` selects the transport:
 *  - IPPROTO_UDP: a SOCK_DGRAM socket is created and `addr` is stored as the
 *    destination for every later sendto(); no connect() is done.
 *  - IPPROTO_TCP: a SOCK_STREAM socket is created and connect()ed to `addr`.
 *  - any other value is rejected.
 *
 * Returns 0 when:
 *  - the state is CLIENT_CONNECTED (UDP, or a TCP connect that completed
 *    immediately), or
 *  - the state is CLIENT_CONNECTING: connect() returned EINPROGRESS and the
 *    write watcher (client_write) was started to finish it.
 *
 * Returns 1 when:
 *  - an argument is invalid, `addrlen` exceeds sockaddr_storage, or the
 *    client is not CLIENT_NEW; the state is unchanged;
 *  - socket setup or connect() fails. The state becomes CLIENT_FAILED
 *    (ECONNREFUSED/ENETUNREACH) or CLIENT_ERRNO (errnum holds errno).
 *
 * Once a socket exists, have_fd is set so client_free()/client_close() will
 * close it.
 */
int client_connect(drool_client_t* client, int ipproto, const struct sockaddr* addr, socklen_t addrlen, struct ev_loop* loop)
{
    int socket_type, flags;

    drool_assert(client);
    if (!client) {
        return 1;
    }
    drool_assert(addr);
    if (!addr) {
        return 1;
    }
    drool_assert(addrlen);
    if (!addrlen) {
        return 1;
    }
    if (addrlen > sizeof(struct sockaddr_storage)) {
        return 1;
    }
    drool_assert(loop);
    if (!loop) {
        return 1;
    }
    if (client->state != CLIENT_NEW) {
        return 1;
    }

    switch (ipproto) {
    case IPPROTO_UDP:
        socket_type = SOCK_DGRAM;
        memcpy(&(client->to_addr), addr, addrlen);
        client->to_addrlen   = addrlen;
        client->have_to_addr = 1;
        client->is_dgram     = 1;
        break;

    case IPPROTO_TCP:
        socket_type       = SOCK_STREAM;
        client->is_stream = 1;
        break;

    default:
        return 1;
    }

    if ((client->fd = socket(addr->sa_family, socket_type, ipproto)) < 0) {
        client->errnum = errno;
        client->state  = CLIENT_ERRNO;
        return 1;
    }
    client->have_fd = 1;

    /*
     * POSIX only promises "a value other than -1" when F_SETFL succeeds, so
     * compare both calls against -1 instead of treating any non-zero result
     * as an error.
     */
    if ((flags = fcntl(client->fd, F_GETFL)) == -1
        || fcntl(client->fd, F_SETFL, flags | O_NONBLOCK) == -1) {
        client->errnum = errno;
        client->state  = CLIENT_ERRNO;
        return 1;
    }

    ev_io_set(&(client->write_watcher), client->fd, EV_WRITE);
    ev_io_set(&(client->read_watcher), client->fd, EV_READ);
    ev_io_set(&(client->shutdown_watcher), client->fd, EV_READ);

    if (socket_type == SOCK_STREAM && connect(client->fd, addr, addrlen) < 0) {
        switch (errno) {
        case EINPROGRESS:
            /* Completion is reported by client_write() via SO_ERROR. */
            ev_io_start(loop, &(client->write_watcher));
            client->state = CLIENT_CONNECTING;
            return 0;

        case ECONNREFUSED:
        case ENETUNREACH:
            client->state = CLIENT_FAILED;
            break;

        default:
            client->errnum = errno;
            client->state  = CLIENT_ERRNO;
            break;
        }
        return 1;
    }

    client->state        = CLIENT_CONNECTED;
    client->is_connected = 1;
    return 0;
}
511 
/*
 * Start sending the client's query. The client must be CLIENT_CONNECTED.
 *
 * For stream sockets, the 2-byte network-order length prefix is sent first
 * (once), followed by the query bytes. UDP clients pass the stored
 * destination address to sendto().
 *
 * Outcomes:
 *  - would block, or a partial query write: the write watcher is started,
 *    the state becomes CLIENT_SENDING and client_write() finishes the send;
 *    returns 0.
 *  - everything sent with skip_reply: CLIENT_SUCCESS, returns 0 (no callback
 *    is invoked).
 *  - everything sent: the read watcher is started, CLIENT_RECIVING,
 *    returns 0.
 *  - send error: errnum is set, the state is CLIENT_FAILED (ECONNRESET, or
 *    ENOBUFS for a short prefix write) or CLIENT_ERRNO; returns 1.
 *  - bad arguments or wrong state: returns 1 with nothing changed.
 */
int client_send(drool_client_t* client, struct ev_loop* loop)
{
    const struct sockaddr* dest    = 0;
    socklen_t              destlen = 0;
    ssize_t                written;
    int                    err;

    drool_assert(client);
    if (!client)
        return 1;
    drool_assert(loop);
    if (!loop)
        return 1;
    if (client->state != CLIENT_CONNECTED)
        return 1;

    /* An unset destination means sendto(..., 0, 0): use the connected peer. */
    if (client->have_to_addr) {
        dest    = (struct sockaddr*)&(client->to_addr);
        destlen = client->to_addrlen;
    }

    if (client->is_stream && !client->sent_length) {
        uint16_t prefix = htons(query_length(client->query));

        written = sendto(client->fd, &prefix, 2, 0, dest, destlen);
        if (written < 0) {
            err = errno;
            if (err == EAGAIN || err == EWOULDBLOCK) {
                ev_io_start(loop, &(client->write_watcher));
                client->state = CLIENT_SENDING;
                return 0;
            }
            client->errnum = err;
            client->state  = err == ECONNRESET ? CLIENT_FAILED : CLIENT_ERRNO;
            return 1;
        }
        if (written != 2) {
            client->errnum = ENOBUFS;
            client->state  = CLIENT_FAILED;
            return 1;
        }
        client->sent_length = 1;
    }

    written = sendto(client->fd, query_raw(client->query), query_length(client->query), 0, dest, destlen);
    if (written < 0) {
        err = errno;
        if (err == EAGAIN || err == EWOULDBLOCK) {
            ev_io_start(loop, &(client->write_watcher));
            client->state = CLIENT_SENDING;
            return 0;
        }
        client->errnum = err;
        client->state  = err == ECONNRESET ? CLIENT_FAILED : CLIENT_ERRNO;
        return 1;
    }

    if (written < query_length(client->query)) {
        /* Remember progress; client_write() resumes from here. */
        client->sent = written;
        ev_io_start(loop, &(client->write_watcher));
        client->state = CLIENT_SENDING;
        return 0;
    }

    if (!client->skip_reply) {
        ev_io_start(loop, &(client->read_watcher));
        client->state = CLIENT_RECIVING;
    } else {
        client->state = CLIENT_SUCCESS;
    }
    return 0;
}
600 
/*
 * Prepare a client that finished with CLIENT_SUCCESS to send another query
 * over the same socket.
 *
 * The client takes `query`. The previously stored query is released with
 * query_free(), unless it is the very same object. The send progress, receive
 * counter and stream length-prefix flag are reset, and the state returns to
 * CLIENT_CONNECTED, ready for client_send().
 *
 * Returns 0 on success. Returns 1 without modifying the client or freeing
 * `query` in these cases:
 *  - an argument is missing;
 *  - the client is not CLIENT_SUCCESS;
 *  - `query` has no raw wire data. This is the same requirement
 *    client_new() enforces, because client_send() transmits
 *    query_raw()/query_length().
 */
int client_reuse(drool_client_t* client, drool_query_t* query)
{
    drool_assert(client);
    if (!client) {
        return 1;
    }
    drool_assert(query);
    if (!query) {
        return 1;
    }
    if (client->state != CLIENT_SUCCESS) {
        return 1;
    }
    if (!query_have_raw(query)) {
        return 1;
    }

    /* Re-submitting the query we already hold must not free it. */
    if (client->query && client->query != query)
        query_free(client->query);
    client->query       = query;
    client->sent        = 0;
    client->recv        = 0;
    client->state       = CLIENT_CONNECTED;
    client->sent_length = 0;

    return 0;
}
625 
/*
 * Begin or finish closing the client's socket.
 *
 * Behaviour depends on the current state:
 *  - CLIENT_CLOSING (close already in progress): returns 0 and does nothing.
 *  - In-flight states (CONNECTING/SENDING/RECIVING): the write and read
 *    watchers are stopped first.
 *
 * If the socket is still marked connected, shutdown(SHUT_RDWR) is attempted.
 * On success:
 *  - the shutdown watcher (client_shutdown) is started;
 *  - the state becomes CLIENT_CLOSING;
 *  - the descriptor stays open and have_fd stays set (client_free() closes
 *    it later).
 *
 * Otherwise:
 *  - any open descriptor is closed immediately;
 *  - the state becomes CLIENT_CLOSED;
 *  - no callback is invoked.
 *
 * Returns 1 only for missing arguments, 0 otherwise.
 */
int client_close(drool_client_t* client, struct ev_loop* loop)
{
    drool_assert(client);
    if (!client)
        return 1;
    drool_assert(loop);
    if (!loop)
        return 1;

    if (client->state == CLIENT_CLOSING)
        return 0;

    if (client->state == CLIENT_CONNECTING
        || client->state == CLIENT_SENDING
        || client->state == CLIENT_RECIVING) {
        ev_io_stop(loop, &(client->write_watcher));
        ev_io_stop(loop, &(client->read_watcher));
    }

    if (client->have_fd && client->is_connected) {
        client->is_connected = 0;
        if (shutdown(client->fd, SHUT_RDWR) == 0) {
            ev_io_start(loop, &(client->shutdown_watcher));
            client->state = CLIENT_CLOSING;
            return 0;
        }
    }

    if (client->have_fd) {
        close(client->fd);
        client->have_fd = 0;
    }
    client->state = CLIENT_CLOSED;

    return 0;
}
668