/*
 * Copyright (c) 2018-2021, OARC, Inc.
 * All rights reserved.
 *
 * This file is part of dnsjit.
 *
 * dnsjit is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * dnsjit is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with dnsjit.  If not, see <http://www.gnu.org/licenses/>.
 */
20 
21 #include "config.h"
22 
23 #include "output/tcpcli.h"
24 #include "core/assert.h"
25 #include "core/object/dns.h"
26 #include "core/object/payload.h"
27 
28 #include <sys/types.h>
29 #include <sys/socket.h>
30 #include <netdb.h>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <string.h>
34 #include <arpa/inet.h>
35 #include <poll.h>
36 
37 static core_log_t      _log      = LOG_T_INIT("output.tcpcli");
38 static output_tcpcli_t _defaults = {
39     LOG_T_INIT_OBJ("output.tcpcli"),
40     0, 0, -1,
41     { 0 }, CORE_OBJECT_PAYLOAD_INIT(0),
42     0, 0, 0, 0,
43     { 5, 0 }, 1
44 };
45 
output_tcpcli_log()46 core_log_t* output_tcpcli_log()
47 {
48     return &_log;
49 }
50 
/*
 * Initialize `self` from the module defaults and point the reusable payload
 * object at the instance's own receive buffer.
 */
void output_tcpcli_init(output_tcpcli_t* self)
{
    mlassert_self();

    memcpy(self, &_defaults, sizeof(*self));

    /* The defaults cannot carry this pointer since it is per instance. */
    self->pkt.payload = &self->recvbuf[0];
}
58 
/*
 * Shut down and close the connection held by `self`, if there is one.
 *
 * The descriptor is reset to -1 afterwards so that calling this again, or
 * any later `fd > -1` check, does not act on a stale descriptor number that
 * the process may already have reused for something else.
 */
void output_tcpcli_destroy(output_tcpcli_t* self)
{
    mlassert_self();

    if (self->fd > -1) {
        shutdown(self->fd, SHUT_RDWR);
        close(self->fd);
        self->fd = -1;
    }
}
68 
/*
 * Resolve `host`/`port` and open a TCP connection to the first address
 * returned by getaddrinfo().
 *
 * Returns 0 on success, -1 if the name could not be resolved and -2 if the
 * socket could not be created or connected. On any failure `self->fd` is
 * left at -1 and no descriptor stays open, so the call can be retried.
 */
int output_tcpcli_connect(output_tcpcli_t* self, const char* host, const char* port)
{
    struct addrinfo* addr = 0;
    int              err;
    mlassert_self();
    lassert(host, "host is nil");
    lassert(port, "port is nil");

    if (self->fd > -1) {
        lfatal("already connected");
    }

    if ((err = getaddrinfo(host, port, 0, &addr))) {
        lcritical("getaddrinfo(%s, %s) error %s", host, port, gai_strerror(err));
        return -1;
    }
    if (!addr) {
        lcritical("getaddrinfo failed, no address returned");
        return -1;
    }

    if ((self->fd = socket(addr->ai_addr->sa_family, SOCK_STREAM, 0)) < 0) {
        lcritical("socket() error %s", core_log_errstr(errno));
        self->fd = -1;
        freeaddrinfo(addr);
        return -2;
    }

    if (connect(self->fd, addr->ai_addr, addr->ai_addrlen)) {
        /* Log first, close() may overwrite errno. */
        lcritical("connect() error %s", core_log_errstr(errno));
        /* Do not leak the socket or leave the object looking connected. */
        close(self->fd);
        self->fd = -1;
        freeaddrinfo(addr);
        return -2;
    }

    freeaddrinfo(addr);
    return 0;
}
105 
/*
 * Report whether the connected socket has O_NONBLOCK set.
 *
 * Returns 1 if the flag is set, 0 if it is not and -1 if the descriptor
 * flags could not be read with fcntl().
 */
int output_tcpcli_nonblocking(output_tcpcli_t* self)
{
    int status;
    mlassert_self();

    if (self->fd < 0) {
        lfatal("not connected");
    }

    status = fcntl(self->fd, F_GETFL);
    if (status == -1) {
        return -1;
    }

    return (status & O_NONBLOCK) != 0;
}
122 
/*
 * Switch the connected socket between non-blocking (`nonblocking` != 0) and
 * blocking (`nonblocking` == 0) mode.
 *
 * Returns 0 on success and -1 if fcntl() failed. `self->blocking` is only
 * updated once the descriptor flags were changed successfully, so the
 * object never disagrees with the socket after a failure.
 */
int output_tcpcli_set_nonblocking(output_tcpcli_t* self, int nonblocking)
{
    int flags;
    mlassert_self();

    if (self->fd < 0) {
        lfatal("not connected");
    }

    if ((flags = fcntl(self->fd, F_GETFL)) == -1) {
        lcritical("fcntl(FL_GETFL) error %s", core_log_errstr(errno));
        return -1;
    }

    if (nonblocking) {
        flags |= O_NONBLOCK;
    } else {
        flags &= ~O_NONBLOCK;
    }

    /*
     * Apply exactly the computed flags; OR-ing O_NONBLOCK back in here would
     * make it impossible to return the socket to blocking mode.
     */
    if (fcntl(self->fd, F_SETFL, flags) == -1) {
        lcritical("fcntl(FL_SETFL, %x) error %s", flags, core_log_errstr(errno));
        return -1;
    }

    self->blocking = nonblocking ? 0 : 1;

    return 0;
}
152 
/*
 * Send all `len` bytes of `buf` on `fd`, retrying immediately whenever the
 * socket reports EAGAIN/EWOULDBLOCK or accepts only part of the data.
 * sendto() is always called at least once, even when `len` is 0.
 *
 * Returns 0 once everything was sent, -1 on any other sendto() error.
 */
static int _send_all(int fd, const uint8_t* buf, size_t len)
{
    size_t sent = 0;

    for (;;) {
        ssize_t ret = sendto(fd, buf + sent, len - sent, 0, 0, 0);
        if (ret > -1) {
            sent += ret;
            if (sent < len)
                continue;
            return 0;
        }
        switch (errno) {
        case EAGAIN:
#if EAGAIN != EWOULDBLOCK
        case EWOULDBLOCK:
#endif
            continue;
        default:
            return -1;
        }
    }
}

/*
 * Receiver callback: send the payload found in the object chain as one
 * message, prefixed with its length as a 16-bit integer in network byte
 * order.
 *
 * CORE_OBJECT_DNS objects are skipped by following obj_prev until a
 * CORE_OBJECT_PAYLOAD is found; any other object type, or the end of the
 * chain, ends the call without sending. `self->pkts` is incremented when
 * both the prefix and the payload were sent, `self->errs` otherwise.
 */
static void _receive(output_tcpcli_t* self, const core_object_t* obj)
{
    const uint8_t* payload;
    size_t         len;
    uint16_t       dnslen;
    mlassert_self();

    while (obj && obj->obj_type == CORE_OBJECT_DNS) {
        obj = obj->obj_prev;
    }
    if (!obj || obj->obj_type != CORE_OBJECT_PAYLOAD) {
        return;
    }

    payload = ((const core_object_payload_t*)obj)->payload;
    len     = ((const core_object_payload_t*)obj)->len;

    /*
     * The length prefix is only 16 bits wide. A larger payload would be sent
     * with a truncated length and desynchronize the stream for every message
     * after it, so refuse it instead.
     */
    if (len > (size_t)0xffff) {
        self->errs++;
        return;
    }
    dnslen = htons((uint16_t)len);

    if (_send_all(self->fd, (const uint8_t*)&dnslen, sizeof(dnslen))
        || _send_all(self->fd, payload, len)) {
        self->errs++;
        return;
    }

    self->pkts++;
}
222 
output_tcpcli_receiver(output_tcpcli_t * self)223 core_receiver_t output_tcpcli_receiver(output_tcpcli_t* self)
224 {
225     mlassert_self();
226 
227     if (self->fd < 0) {
228         lfatal("not connected");
229     }
230 
231     return (core_receiver_t)_receive;
232 }
233 
_produce(output_tcpcli_t * self)234 static const core_object_t* _produce(output_tcpcli_t* self)
235 {
236     ssize_t       n, recv = 0;
237     uint16_t      dnslen;
238     struct pollfd p;
239     int           to = 0;
240     mlassert_self();
241 
242     // Check if last recvfrom() got more then we needed
243     if (!self->have_dnslen && self->recv > self->dnslen) {
244         recv = self->recv - self->dnslen;
245         if (recv < sizeof(dnslen)) {
246             memcpy(((uint8_t*)&dnslen), self->recvbuf + self->dnslen, recv);
247         } else {
248             memcpy(((uint8_t*)&dnslen), self->recvbuf + self->dnslen, sizeof(dnslen));
249 
250             if (recv > sizeof(dnslen)) {
251                 self->recv = recv - sizeof(dnslen);
252                 memmove(self->recvbuf, self->recvbuf + self->dnslen + sizeof(dnslen), self->recv);
253             } else {
254                 self->recv = 0;
255             }
256 
257             self->dnslen      = ntohs(dnslen);
258             self->have_dnslen = 1;
259 
260             if (self->recv > self->dnslen) {
261                 self->pkts_recv++;
262                 self->pkt.len     = self->dnslen;
263                 self->have_dnslen = 0;
264                 return (core_object_t*)&self->pkt;
265             }
266         }
267     }
268 
269     if (self->blocking) {
270         p.fd      = self->fd;
271         p.events  = POLLIN;
272         p.revents = 0;
273         to        = (self->timeout.sec * 1e3) + (self->timeout.nsec / 1e6); //NOSONAR
274         if (!to) {
275             to = 1;
276         }
277     }
278 
279     if (!self->have_dnslen) {
280         for (;;) {
281             n = poll(&p, 1, to);
282             if (n < 0 || (p.revents & (POLLERR | POLLHUP | POLLNVAL))) {
283                 self->errs++;
284                 return 0;
285             }
286             if (!n || !(p.revents & POLLIN)) {
287                 if (recv) {
288                     self->errs++;
289                     return 0;
290                 }
291                 self->pkt.len = 0;
292                 return (core_object_t*)&self->pkt;
293             }
294 
295             n = recvfrom(self->fd, ((uint8_t*)&dnslen) + recv, sizeof(dnslen) - recv, 0, 0, 0);
296             if (n > 0) {
297                 recv += n;
298                 if (recv < sizeof(dnslen))
299                     continue;
300                 break;
301             }
302             if (!n) {
303                 break;
304             }
305             switch (errno) {
306             case EAGAIN:
307 #if EAGAIN != EWOULDBLOCK
308             case EWOULDBLOCK:
309 #endif
310                 continue;
311             default:
312                 break;
313             }
314             self->errs++;
315             break;
316         }
317 
318         if (n < 1) {
319             return 0;
320         }
321 
322         self->dnslen      = ntohs(dnslen);
323         self->have_dnslen = 1;
324         self->recv        = 0;
325     }
326 
327     for (;;) {
328         n = poll(&p, 1, to);
329         if (n < 0 || (p.revents & (POLLERR | POLLHUP | POLLNVAL))) {
330             self->errs++;
331             return 0;
332         }
333         if (!n || !(p.revents & POLLIN)) {
334             self->pkt.len = 0;
335             return (core_object_t*)&self->pkt;
336         }
337 
338         n = recvfrom(self->fd, self->recvbuf + self->recv, sizeof(self->recvbuf) - self->recv, 0, 0, 0);
339         if (n > 0) {
340             self->recv += n;
341             if (self->recv < self->dnslen)
342                 continue;
343             break;
344         }
345         if (!n) {
346             break;
347         }
348         switch (errno) {
349         case EAGAIN:
350 #if EAGAIN != EWOULDBLOCK
351         case EWOULDBLOCK:
352 #endif
353             self->pkt.len = 0;
354             return (core_object_t*)&self->pkt;
355         default:
356             break;
357         }
358         self->errs++;
359         break;
360     }
361 
362     if (n < 1) {
363         return 0;
364     }
365 
366     self->pkts_recv++;
367     self->pkt.len     = self->dnslen;
368     self->have_dnslen = 0;
369     return (core_object_t*)&self->pkt;
370 }
371 
output_tcpcli_producer(output_tcpcli_t * self)372 core_producer_t output_tcpcli_producer(output_tcpcli_t* self)
373 {
374     mlassert_self();
375 
376     if (self->fd < 0) {
377         lfatal("not connected");
378     }
379 
380     return (core_producer_t)_produce;
381 }
382