1 /* Public domain, from djbdns-1.05. */
2 
3 #include <sys/types.h>
4 #include <sys/socket.h>
5 #include <unistd.h>
6 #include "socket.h"
7 #include "alloc.h"
8 #include "error.h"
9 #include "byte.h"
10 #include "uint16.h"
11 #include "dns.h"
12 
/*
  Decide whether a reply must be re-fetched over TCP.
  Returns 1 if the 12-byte DNS header cannot be copied out of buf
  (dns_packet_copy returned 0), or if the TC flag -- mask 0x02 of header
  byte 2 in the RFC 1035 layout -- is set. Returns 0 otherwise.
*/
static int serverwantstcp(const char *buf,unsigned int len)
{
  char header[12];

  if (dns_packet_copy(buf,len,0,header,12) == 0) return 1;
  return (header[2] & 2) ? 1 : 0;
}
21 
serverfailed(const char * buf,unsigned int len)22 static int serverfailed(const char *buf,unsigned int len)
23 {
24   char out[12];
25   unsigned int rcode;
26 
27   if (!dns_packet_copy(buf,len,0,out,12)) return 1;
28   rcode = out[3];
29   rcode &= 15;
30   if (rcode && (rcode != 3)) { errno = error_again; return 1; }
31   return 0;
32 }
33 
irrelevant(const struct dns_transmit * d,const char * buf,unsigned int len)34 static int irrelevant(const struct dns_transmit *d,const char *buf,unsigned int len)
35 {
36   char out[12];
37   char *dn;
38   unsigned int pos;
39 
40   pos = dns_packet_copy(buf,len,0,out,12); if (!pos) return 1;
41   if (byte_diff(out,2,d->query + 2)) return 1;
42   if (out[4] != 0) return 1;
43   if (out[5] != 1) return 1;
44 
45   dn = 0;
46   pos = dns_packet_getname(buf,len,pos,&dn); if (!pos) return 1;
47   if (!dns_domain_equal(dn,d->query + 14)) { alloc_free(dn); return 1; }
48   alloc_free(dn);
49 
50   pos = dns_packet_copy(buf,len,pos,out,4); if (!pos) return 1;
51   if (byte_diff(out,2,d->qtype)) return 1;
52   if (byte_diff(out + 2,2,DNS_C_IN)) return 1;
53 
54   return 0;
55 }
56 
packetfree(struct dns_transmit * d)57 static void packetfree(struct dns_transmit *d)
58 {
59   if (!d->packet) return;
60   alloc_free(d->packet);
61   d->packet = 0;
62 }
63 
queryfree(struct dns_transmit * d)64 static void queryfree(struct dns_transmit *d)
65 {
66   if (!d->query) return;
67   alloc_free(d->query);
68   d->query = 0;
69 }
70 
socketfree(struct dns_transmit * d)71 static void socketfree(struct dns_transmit *d)
72 {
73   if (!d->s1) return;
74   close(d->s1 - 1);
75   d->s1 = 0;
76 }
77 
/*
  Release every resource d may own: the encoded query, the socket and the
  stored reply.
  Each helper checks its own field first, so this is safe to call on a
  partially set-up or already-freed transmit.
  The call order is kept as-is: failure paths call this after errno has
  been set.
*/
void dns_transmit_free(struct dns_transmit *d)
{
  queryfree(d);    /* d->query  -> 0 */
  socketfree(d);   /* d->s1     -> 0 */
  packetfree(d);   /* d->packet -> 0 */
}
84 
randombind(struct dns_transmit * d)85 static int randombind(struct dns_transmit *d)
86 {
87   int j;
88 
89   for (j = 0;j < 10;++j)
90     if (socket_bind4(d->s1 - 1,d->localip,1025 + dns_random(64510)) == 0)
91       return 0;
92   if (socket_bind4(d->s1 - 1,d->localip,0) == 0)
93     return 0;
94   return -1;
95 }
96 
/* UDP reply deadlines indexed by d->udploop (0..3); thisudp() passes them to taia_uint, presumably as seconds. */
static const int timeouts[] = { 1, 3, 11, 45 };
98 
thisudp(struct dns_transmit * d)99 static int thisudp(struct dns_transmit *d)
100 {
101   const char *ip;
102 
103   socketfree(d);
104 
105   while (d->udploop < 4) {
106     for (;d->curserver < 16;++d->curserver) {
107       ip = d->servers + 4 * d->curserver;
108       if (byte_diff(ip,4,"\0\0\0\0")) {
109 	d->query[2] = dns_random(256);
110 	d->query[3] = dns_random(256);
111 
112         d->s1 = 1 + socket_udp();
113         if (!d->s1) { dns_transmit_free(d); return -1; }
114 	if (randombind(d) == -1) { dns_transmit_free(d); return -1; }
115 
116         if (socket_connect4(d->s1 - 1,ip,53) == 0)
117           if (send(d->s1 - 1,d->query + 2,d->querylen - 2,0) == d->querylen - 2) {
118             struct taia now;
119             taia_now(&now);
120             taia_uint(&d->deadline,timeouts[d->udploop]);
121             taia_add(&d->deadline,&d->deadline,&now);
122             d->tcpstate = 0;
123             return 0;
124           }
125 
126         socketfree(d);
127       }
128     }
129 
130     ++d->udploop;
131     d->curserver = 0;
132   }
133 
134   dns_transmit_free(d); return -1;
135 }
136 
firstudp(struct dns_transmit * d)137 static int firstudp(struct dns_transmit *d)
138 {
139   d->curserver = 0;
140   return thisudp(d);
141 }
142 
nextudp(struct dns_transmit * d)143 static int nextudp(struct dns_transmit *d)
144 {
145   ++d->curserver;
146   return thisudp(d);
147 }
148 
thistcp(struct dns_transmit * d)149 static int thistcp(struct dns_transmit *d)
150 {
151   struct taia now;
152   const char *ip;
153 
154   socketfree(d);
155   packetfree(d);
156 
157   for (;d->curserver < 16;++d->curserver) {
158     ip = d->servers + 4 * d->curserver;
159     if (byte_diff(ip,4,"\0\0\0\0")) {
160       d->query[2] = dns_random(256);
161       d->query[3] = dns_random(256);
162 
163       d->s1 = 1 + socket_tcp();
164       if (!d->s1) { dns_transmit_free(d); return -1; }
165       if (randombind(d) == -1) { dns_transmit_free(d); return -1; }
166 
167       taia_now(&now);
168       taia_uint(&d->deadline,10);
169       taia_add(&d->deadline,&d->deadline,&now);
170       if (socket_connect4(d->s1 - 1,ip,53) == 0) {
171         d->tcpstate = 2;
172         return 0;
173       }
174       if ((errno == error_inprogress) || (errno == error_wouldblock)) {
175         d->tcpstate = 1;
176         return 0;
177       }
178 
179       socketfree(d);
180     }
181   }
182 
183   dns_transmit_free(d); return -1;
184 }
185 
firsttcp(struct dns_transmit * d)186 static int firsttcp(struct dns_transmit *d)
187 {
188   d->curserver = 0;
189   return thistcp(d);
190 }
191 
nexttcp(struct dns_transmit * d)192 static int nexttcp(struct dns_transmit *d)
193 {
194   ++d->curserver;
195   return thistcp(d);
196 }
197 
dns_transmit_start(struct dns_transmit * d,const char servers[64],int flagrecursive,const char * q,const char qtype[2],const char localip[4])198 int dns_transmit_start(struct dns_transmit *d,const char servers[64],int flagrecursive,const char *q,const char qtype[2],const char localip[4])
199 {
200   unsigned int len;
201 
202   dns_transmit_free(d);
203   errno = error_io;
204 
205   len = dns_domain_length(q);
206   d->querylen = len + 18;
207   d->query = alloc(d->querylen);
208   if (!d->query) return -1;
209 
210   uint16_pack_big(d->query,len + 16);
211   byte_copy(d->query + 2,12,flagrecursive ? "\0\0\1\0\0\1\0\0\0\0\0\0" : "\0\0\0\0\0\1\0\0\0\0\0\0gcc-bug-workaround");
212   byte_copy(d->query + 14,len,q);
213   byte_copy(d->query + 14 + len,2,qtype);
214   byte_copy(d->query + 16 + len,2,DNS_C_IN);
215 
216   byte_copy(d->qtype,2,qtype);
217   d->servers = servers;
218   byte_copy(d->localip,4,localip);
219 
220   d->udploop = flagrecursive ? 1 : 0;
221 
222   if (len + 16 > 512) return firsttcp(d);
223   return firstudp(d);
224 }
225 
/*
  Describe what d is waiting for, for use with iopause.
  x->fd is set to d's socket. x->events is set as follows:
    states 1 and 2 (TCP connect / query send): IOPAUSE_WRITE
    states 0, 3, 4 and 5 (awaiting data):      IOPAUSE_READ
  Any other state leaves x->events untouched.
  *deadline is lowered to d->deadline if that is earlier.
*/
void dns_transmit_io(struct dns_transmit *d,iopause_fd *x,struct taia *deadline)
{
  x->fd = d->s1 - 1;

  if (d->tcpstate == 1 || d->tcpstate == 2)
    x->events = IOPAUSE_WRITE;
  else if (d->tcpstate == 0 || d->tcpstate == 3 || d->tcpstate == 4 || d->tcpstate == 5)
    x->events = IOPAUSE_READ;

  if (taia_less(&d->deadline,deadline)) *deadline = d->deadline;
}
242 
dns_transmit_get(struct dns_transmit * d,const iopause_fd * x,const struct taia * when)243 int dns_transmit_get(struct dns_transmit *d,const iopause_fd *x,const struct taia *when)
244 {
245   char udpbuf[513];
246   unsigned char ch;
247   int r;
248   int fd;
249 
250   errno = error_io;
251   fd = d->s1 - 1;
252 
253   if (!x->revents) {
254     if (taia_less(when,&d->deadline)) return 0;
255     errno = error_timeout;
256     if (d->tcpstate == 0) return nextudp(d);
257     return nexttcp(d);
258   }
259 
260   if (d->tcpstate == 0) {
261 /*
262 have attempted to send UDP query to each server udploop times
263 have sent query to curserver on UDP socket s
264 */
265     r = recv(fd,udpbuf,sizeof udpbuf,0);
266     if (r <= 0) {
267       if (errno == error_connrefused) if (d->udploop == 2) return 0;
268       return nextudp(d);
269     }
270     if ((unsigned)r + 1 > sizeof udpbuf) return 0;
271 
272     if (irrelevant(d,udpbuf,r)) return 0;
273     if (serverwantstcp(udpbuf,r)) return firsttcp(d);
274     if (serverfailed(udpbuf,r)) {
275       if (d->udploop == 2) return 0;
276       return nextudp(d);
277     }
278     socketfree(d);
279 
280     d->packetlen = r;
281     d->packet = alloc(d->packetlen);
282     if (!d->packet) { dns_transmit_free(d); return -1; }
283     byte_copy(d->packet,d->packetlen,udpbuf);
284     queryfree(d);
285     return 1;
286   }
287 
288   if (d->tcpstate == 1) {
289 /*
290 have sent connection attempt to curserver on TCP socket s
291 pos not defined
292 */
293     if (!socket_connected(fd)) return nexttcp(d);
294     d->pos = 0;
295     d->tcpstate = 2;
296     return 0;
297   }
298 
299   if (d->tcpstate == 2) {
300 /*
301 have connection to curserver on TCP socket s
302 have sent pos bytes of query
303 */
304     r = write(fd,d->query + d->pos,d->querylen - d->pos);
305     if (r <= 0) return nexttcp(d);
306     d->pos += r;
307     if (d->pos == d->querylen) {
308       struct taia now;
309       taia_now(&now);
310       taia_uint(&d->deadline,10);
311       taia_add(&d->deadline,&d->deadline,&now);
312       d->tcpstate = 3;
313     }
314     return 0;
315   }
316 
317   if (d->tcpstate == 3) {
318 /*
319 have sent entire query to curserver on TCP socket s
320 pos not defined
321 */
322     r = read(fd,&ch,1);
323     if (r <= 0) return nexttcp(d);
324     d->packetlen = ch;
325     d->tcpstate = 4;
326     return 0;
327   }
328 
329   if (d->tcpstate == 4) {
330 /*
331 have sent entire query to curserver on TCP socket s
332 pos not defined
333 have received one byte of packet length into packetlen
334 */
335     r = read(fd,&ch,1);
336     if (r <= 0) return nexttcp(d);
337     d->packetlen <<= 8;
338     d->packetlen += ch;
339     d->tcpstate = 5;
340     d->pos = 0;
341     d->packet = alloc(d->packetlen);
342     if (!d->packet) { dns_transmit_free(d); return -1; }
343     return 0;
344   }
345 
346   if (d->tcpstate == 5) {
347 /*
348 have sent entire query to curserver on TCP socket s
349 have received entire packet length into packetlen
350 packet is allocated
351 have received pos bytes of packet
352 */
353     r = read(fd,d->packet + d->pos,d->packetlen - d->pos);
354     if (r <= 0) return nexttcp(d);
355     d->pos += r;
356     if (d->pos < d->packetlen) return 0;
357 
358     socketfree(d);
359     if (irrelevant(d,d->packet,d->packetlen)) return nexttcp(d);
360     if (serverwantstcp(d->packet,d->packetlen)) return nexttcp(d);
361     if (serverfailed(d->packet,d->packetlen)) return nexttcp(d);
362 
363     queryfree(d);
364     return 1;
365   }
366 
367   return 0;
368 }
369