1 /* Copyright (c) 2007 Google Inc.
2  *
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <sys/types.h>
17 #include <sys/socket.h>
18 #include <netinet/in.h>
19 #include <arpa/inet.h>
20 #include <strings.h>
21 #include <string.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <getopt.h>
25 
26 #include "query_record.h"
27 #include "check_record.h"
28 
/* Advance *ptr past a DNS NAME without decoding it (saves having to follow
 * compressed labels).
 *
 * A NAME is a run of length-prefixed labels terminated either by a zero
 * byte (the root label) or by a two-byte compression pointer (first byte
 * has either of its top two bits set).  On return *ptr points just past the
 * NAME.  The root name "." (a lone zero byte) is consumed as one byte.
 *
 * If the NAME would extend beyond `end`, *ptr is clamped to `end`, so it
 * never points outside the packet; callers detect truncation from the
 * number of bytes left.  If *ptr is already at or past `end`, it is left
 * unchanged.
 */
void skip_name(char** ptr, char* end) {
  while (*ptr < end) {
    size_t avail = (size_t)(end - *ptr);
    unsigned char label = (unsigned char)**ptr;

    if (label == 0) {
      /* Root label: end of the name. */
      *ptr += 1;
      return;
    }
    if ((label & 0xc0) != 0) {
      /* Compression pointer: two bytes, and it always ends the name. */
      *ptr += avail < 2 ? avail : 2;
      return;
    }
    if ((size_t)label + 1 > avail) {
      /* Label runs past the packet: clamp instead of overrunning. */
      *ptr = end;
      return;
    }
    /* Length byte plus the label's bytes. */
    *ptr += (size_t)label + 1;
  }
}
45 
/* Read a 16-bit value stored in network (big-endian) byte order at *ptr and
 * advance *ptr by 2.
 *
 * Returns 0 and leaves *ptr unchanged if fewer than 2 bytes remain before
 * `end`.  The bytes are combined one at a time, so *ptr does not need to be
 * aligned for a short and no type-punned load is performed.  Values above
 * 0x7fff are converted to short exactly as the uint16_t result of ntohs()
 * would be.
 */
short read_short(char** ptr, char* end) {
  if (end - *ptr < 2)
    return 0;

  const unsigned char* bytes = (const unsigned char*)*ptr;
  unsigned int value = ((unsigned int)bytes[0] << 8) | bytes[1];
  *ptr += 2;
  return (short)value;
}
55 
/* Read a 32-bit value stored in network (big-endian) byte order at *ptr and
 * advance *ptr by 4.
 *
 * Returns 0 and leaves *ptr unchanged if fewer than 4 bytes remain before
 * `end`.  The bytes are combined one at a time, so *ptr does not need to be
 * aligned for an int and no type-punned load is performed.  Values above
 * INT_MAX are converted to int exactly as the uint32_t result of ntohl()
 * would be.
 */
int read_int(char** ptr, char* end) {
  if (end - *ptr < 4)
    return 0;

  const unsigned char* bytes = (const unsigned char*)*ptr;
  unsigned long value = ((unsigned long)bytes[0] << 24) |
                        ((unsigned long)bytes[1] << 16) |
                        ((unsigned long)bytes[2] << 8) |
                        (unsigned long)bytes[3];
  *ptr += 4;
  return (int)value;
}
65 
/* Print the command-line synopsis to stdout.  Declared with (void) so it has
 * a real prototype (an empty parameter list declares none before C23). */
void usage(void) {
  printf("Usage: dnswall -b (bind ip) -B (bind port) -f (forwarder ip) -F (forwarder port)\n");
}
69 
main(int argc,char ** argv)70 int main(int argc, char** argv) {
71   char* bind_addr = NULL;
72   int bind_port = 0;
73   char* forward_addr = NULL;
74   int forward_port = 0;
75 
76   // Get command-line options
77   int c;
78   while ((c = getopt(argc, argv, "hb:f:B:F:")) != -1) {
79     switch (c) {
80       case 'h':
81         usage();
82         return 0;
83       case 'b':
84         bind_addr = optarg;
85         break;
86       case 'f':
87         forward_addr = optarg;
88         break;
89       case 'B':
90         bind_port = atoi(optarg);
91         break;
92       case 'F':
93         forward_port = atoi(optarg);
94         break;
95     }
96   }
97 
98   if (bind_addr == NULL || forward_addr == NULL ||
99       bind_port == 0 || forward_port == 0) {
100     usage();
101     return 1;
102   }
103 
104   // Create the (only) socket
105   int sock;
106   if ((sock = socket(PF_INET, SOCK_DGRAM, 0)) < 0)
107     return 0;
108 
109   // Bind to the appropriate IP and port
110   struct sockaddr_in addr;
111   bzero(&addr, sizeof(addr));
112   addr.sin_family = AF_INET;
113   addr.sin_addr.s_addr = inet_addr(bind_addr);
114   addr.sin_port = htons(bind_port);
115   if ((bind(sock, (struct sockaddr *)&addr, sizeof(addr))) < 0)
116     return 0;
117 
118   // Setup the sockaddr struct for the forwarder (will be using this a lot)
119   struct sockaddr_in dst_addr;
120   bzero(&dst_addr, sizeof(dst_addr));
121   dst_addr.sin_family = AF_INET;
122   dst_addr.sin_addr.s_addr = inet_addr(forward_addr);
123   dst_addr.sin_port = htons(forward_port);
124 
125   InitQueryRecordHeap();
126 
127   char msg[1024];
128   while (1) {
129     int addrlen = sizeof(addr);
130     int len = recvfrom(sock, msg, sizeof(msg), 0,
131                        (struct sockaddr *)&addr, (socklen_t *)&addrlen);
132 
133     // If there was an error or the msg was too big (specified max in RFC),
134     // then just drop it
135     if (len <= 0 || len > 512)
136       continue;
137 
138     // If the packet is a query, then proxy (almost) without change
139     if ((*((short *)&msg[2]) & 0x8000) == 0) {
140       QueryRecord* record = AllocQueryRecord();
141       // Record the old id and source address
142       record->old_id = ntohs(*((short*)&msg[0]));
143       record->src_addr = addr;
144 
145       // Replace the id with our own query identifier
146       *((short *)&msg[0]) = htons(record->id);
147 
148       // Send!
149       sendto(sock, msg, len, 0,
150              (struct sockaddr *)&dst_addr,
151              sizeof(dst_addr));
152     }
153     // If the packet is a response, then check it for invalid records
154     else {
155       // If it didn't come from the real forwarder, ignore it
156       if (addr.sin_addr.s_addr != dst_addr.sin_addr.s_addr)
157         continue;
158 
159       // Extract our query identifier from the packet
160       int id = ntohs(*((short *)&msg[0]));
161       QueryRecord* record = GetQueryRecordById(id);
162 
163       // Validate that this query id is currently in flight.
164       if (!record)
165         continue;
166 
167       int valid = 1;
168 
169       // Start of the real payload
170       char *ptr = &msg[12];
171       // Skip all the query records
172       for (int count = ntohs(*((short*)&msg[4])); count > 0; count--) {
173         skip_name(&ptr, msg + len);
174         int type = read_short(&ptr, msg + len);
175         int class = read_short(&ptr, msg + len);
176 
177         // CNAME queries are not allowed
178         if (type == 5 && class == 1)
179           valid = 0;
180       }
181 
182       // Start of the response section
183       char *end = ptr;
184       for (int count = ntohs(*((short*)&msg[6])); count > 0; count--) {
185         skip_name(&ptr, msg + len);
186 
187         int type = read_short(&ptr, msg + len);
188         int class = read_short(&ptr, msg + len);
189         /* int ttl = */ read_int(&ptr, msg + len);
190 
191         int rlen = read_short(&ptr, msg + len);
192         // If its an A record, check it for private IPs
193         if (class == 1 && type == 1) {
194           if (CheckARecord(ptr, msg + len) == 0)
195             valid = 0;
196         }
197 
198         // If its an AAAA record, check it for private IPs
199         if (class == 1 && type == 28) {
200           if (CheckAAAARecord(ptr, msg + len) == 0)
201             valid = 0;
202         }
203         ptr += rlen;
204       }
205 
206       // If this was an invalid response, substitute in an appropriate
207       // NXDOMAIN response by mangling the original response.
208       if (!valid) {
209         // Set the response code to NXDOMAIN
210         *((short*)&msg[2]) = htons((ntohs(*((short*)&msg[2])) & 0xfff0) | 3);
211 
212         // Set the number of response entries to 0, authority to 1, and
213         // additional to 0
214         *((short*)&msg[6]) = htons(0);
215         *((short*)&msg[8]) = htons(1);
216         *((short*)&msg[10]) = htons(0);
217 
218         *end = 0; end++; // name = '.'
219         *((short*)end) = htons(6); end += 2; // type = 6 (SOA)
220         *((short*)end) = htons(1); end += 2; // class = 1 (IN)
221         *((int*)end) = htonl(3600); end += 4; // ttl = 3600
222 
223         // Skip the length field, we'll fill it in when we know
224         char *soabegin = end;
225         end += 2;
226 
227         *end = 0; end++; // mname = '.'
228         *end = 0; end++; // email = '.'
229         *((int*)end) = htonl(1); end += 4; // serial = 1
230         *((int*)end) = htonl(3600); end += 4; // refresh time = 3600
231         *((int*)end) = htonl(600); end += 4; // retry time = 600
232         *((int*)end) = htonl(86400); end += 4; // expire time = 86400
233         *((int*)end) = htonl(3600); end += 4; // minimum ttl = 3600
234 
235         // Fill in the length now that we know what it is
236         *((short*)soabegin) = htons(end - soabegin - 2);
237 
238         // Set the new length of the packet
239         len = end - msg;
240       }
241 
242       // Put back the original id
243       *((short *)&msg[0]) = htons(record->old_id);
244 
245       // Send!
246       sendto(sock, msg, len, 0,
247              (struct sockaddr *)&record->src_addr,
248              sizeof(record->src_addr));
249 
250       FreeQueryRecord(record);
251     }
252   }
253 }
254