1 /* Copyright (c) 2007 Google Inc.
2 *
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <sys/types.h>
17 #include <sys/socket.h>
18 #include <netinet/in.h>
19 #include <arpa/inet.h>
20 #include <strings.h>
21 #include <string.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <getopt.h>
25
26 #include "query_record.h"
27 #include "check_record.h"
28
/* Advance *ptr past an encoded domain NAME without decoding it, so that
 * compression pointers never have to be followed.
 *
 * A NAME is a sequence of length-prefixed labels that ends either with a
 * zero octet (the root label) or with a two-octet compression pointer
 * (top two bits of the first octet set). Octets with only one of the top
 * two bits set are reserved label types; like compression pointers they
 * are treated as a two-octet terminator.
 *
 * On return *ptr points just past the NAME. If the NAME runs off the end
 * of the buffer, *ptr is set to end, so it never points beyond end. If
 * *ptr is already at or past end on entry it is left unchanged.
 */
void skip_name(char** ptr, char* end) {
  char* p = *ptr;
  if (p >= end)
    return;

  while (p < end) {
    unsigned char octet = (unsigned char)*p;
    if (octet == 0) {
      /* Zero-length root label: the NAME ends right after it. */
      *ptr = p + 1;
      return;
    }
    if ((octet & 0xc0) != 0) {
      /* Compression pointer: two octets, always the last element. */
      *ptr = (end - p >= 2) ? p + 2 : end;
      return;
    }
    /* Ordinary label: one length octet followed by `octet` bytes. */
    if (end - p <= octet)
      break;  /* label body extends past the end of the buffer */
    p += octet + 1;
  }
  /* No terminator before end: clamp instead of running past the buffer. */
  *ptr = end;
}
45
/* Read a 16-bit value stored in network byte order (big-endian) at *ptr
 * and advance *ptr by 2.
 *
 * The value is assembled byte by byte, so *ptr does not need to be
 * aligned for a short and no type-punned access is made. If fewer than 2
 * bytes remain before end, returns 0 and leaves *ptr unchanged.
 *
 * The result is a signed short, so wire values >= 0x8000 come back
 * negative. Callers that need the full 0..65535 range should convert the
 * result to unsigned short.
 */
short read_short(char** ptr, char* end) {
  if (end - *ptr < 2)
    return 0;

  const unsigned char* bytes = (const unsigned char*)*ptr;
  unsigned int value = ((unsigned int)bytes[0] << 8) | bytes[1];
  *ptr += 2;
  return (short)value;
}
55
/* Read a 32-bit value stored in network byte order (big-endian) at *ptr
 * and advance *ptr by 4.
 *
 * The value is assembled byte by byte, so *ptr does not need to be
 * aligned for an int and no type-punned access is made. If fewer than 4
 * bytes remain before end, returns 0 and leaves *ptr unchanged.
 */
int read_int(char** ptr, char* end) {
  if (end - *ptr < 4)
    return 0;

  const unsigned char* bytes = (const unsigned char*)*ptr;
  /* Widen before shifting so bytes[0] << 24 cannot overflow a signed int. */
  unsigned long value = ((unsigned long)bytes[0] << 24) |
                        ((unsigned long)bytes[1] << 16) |
                        ((unsigned long)bytes[2] << 8) |
                        (unsigned long)bytes[3];
  *ptr += 4;
  return (int)value;
}
65
/* Print the command-line synopsis to stdout. */
void usage(void) {
  printf("Usage: dnswall -b (bind ip) -B (bind port) -f (forwarder ip) -F (forwarder port)\n");
}
69
main(int argc,char ** argv)70 int main(int argc, char** argv) {
71 char* bind_addr = NULL;
72 int bind_port = 0;
73 char* forward_addr = NULL;
74 int forward_port = 0;
75
76 // Get command-line options
77 int c;
78 while ((c = getopt(argc, argv, "hb:f:B:F:")) != -1) {
79 switch (c) {
80 case 'h':
81 usage();
82 return 0;
83 case 'b':
84 bind_addr = optarg;
85 break;
86 case 'f':
87 forward_addr = optarg;
88 break;
89 case 'B':
90 bind_port = atoi(optarg);
91 break;
92 case 'F':
93 forward_port = atoi(optarg);
94 break;
95 }
96 }
97
98 if (bind_addr == NULL || forward_addr == NULL ||
99 bind_port == 0 || forward_port == 0) {
100 usage();
101 return 1;
102 }
103
104 // Create the (only) socket
105 int sock;
106 if ((sock = socket(PF_INET, SOCK_DGRAM, 0)) < 0)
107 return 0;
108
109 // Bind to the appropriate IP and port
110 struct sockaddr_in addr;
111 bzero(&addr, sizeof(addr));
112 addr.sin_family = AF_INET;
113 addr.sin_addr.s_addr = inet_addr(bind_addr);
114 addr.sin_port = htons(bind_port);
115 if ((bind(sock, (struct sockaddr *)&addr, sizeof(addr))) < 0)
116 return 0;
117
118 // Setup the sockaddr struct for the forwarder (will be using this a lot)
119 struct sockaddr_in dst_addr;
120 bzero(&dst_addr, sizeof(dst_addr));
121 dst_addr.sin_family = AF_INET;
122 dst_addr.sin_addr.s_addr = inet_addr(forward_addr);
123 dst_addr.sin_port = htons(forward_port);
124
125 InitQueryRecordHeap();
126
127 char msg[1024];
128 while (1) {
129 int addrlen = sizeof(addr);
130 int len = recvfrom(sock, msg, sizeof(msg), 0,
131 (struct sockaddr *)&addr, (socklen_t *)&addrlen);
132
133 // If there was an error or the msg was too big (specified max in RFC),
134 // then just drop it
135 if (len <= 0 || len > 512)
136 continue;
137
138 // If the packet is a query, then proxy (almost) without change
139 if ((*((short *)&msg[2]) & 0x8000) == 0) {
140 QueryRecord* record = AllocQueryRecord();
141 // Record the old id and source address
142 record->old_id = ntohs(*((short*)&msg[0]));
143 record->src_addr = addr;
144
145 // Replace the id with our own query identifier
146 *((short *)&msg[0]) = htons(record->id);
147
148 // Send!
149 sendto(sock, msg, len, 0,
150 (struct sockaddr *)&dst_addr,
151 sizeof(dst_addr));
152 }
153 // If the packet is a response, then check it for invalid records
154 else {
155 // If it didn't come from the real forwarder, ignore it
156 if (addr.sin_addr.s_addr != dst_addr.sin_addr.s_addr)
157 continue;
158
159 // Extract our query identifier from the packet
160 int id = ntohs(*((short *)&msg[0]));
161 QueryRecord* record = GetQueryRecordById(id);
162
163 // Validate that this query id is currently in flight.
164 if (!record)
165 continue;
166
167 int valid = 1;
168
169 // Start of the real payload
170 char *ptr = &msg[12];
171 // Skip all the query records
172 for (int count = ntohs(*((short*)&msg[4])); count > 0; count--) {
173 skip_name(&ptr, msg + len);
174 int type = read_short(&ptr, msg + len);
175 int class = read_short(&ptr, msg + len);
176
177 // CNAME queries are not allowed
178 if (type == 5 && class == 1)
179 valid = 0;
180 }
181
182 // Start of the response section
183 char *end = ptr;
184 for (int count = ntohs(*((short*)&msg[6])); count > 0; count--) {
185 skip_name(&ptr, msg + len);
186
187 int type = read_short(&ptr, msg + len);
188 int class = read_short(&ptr, msg + len);
189 /* int ttl = */ read_int(&ptr, msg + len);
190
191 int rlen = read_short(&ptr, msg + len);
192 // If its an A record, check it for private IPs
193 if (class == 1 && type == 1) {
194 if (CheckARecord(ptr, msg + len) == 0)
195 valid = 0;
196 }
197
198 // If its an AAAA record, check it for private IPs
199 if (class == 1 && type == 28) {
200 if (CheckAAAARecord(ptr, msg + len) == 0)
201 valid = 0;
202 }
203 ptr += rlen;
204 }
205
206 // If this was an invalid response, substitute in an appropriate
207 // NXDOMAIN response by mangling the original response.
208 if (!valid) {
209 // Set the response code to NXDOMAIN
210 *((short*)&msg[2]) = htons((ntohs(*((short*)&msg[2])) & 0xfff0) | 3);
211
212 // Set the number of response entries to 0, authority to 1, and
213 // additional to 0
214 *((short*)&msg[6]) = htons(0);
215 *((short*)&msg[8]) = htons(1);
216 *((short*)&msg[10]) = htons(0);
217
218 *end = 0; end++; // name = '.'
219 *((short*)end) = htons(6); end += 2; // type = 6 (SOA)
220 *((short*)end) = htons(1); end += 2; // class = 1 (IN)
221 *((int*)end) = htonl(3600); end += 4; // ttl = 3600
222
223 // Skip the length field, we'll fill it in when we know
224 char *soabegin = end;
225 end += 2;
226
227 *end = 0; end++; // mname = '.'
228 *end = 0; end++; // email = '.'
229 *((int*)end) = htonl(1); end += 4; // serial = 1
230 *((int*)end) = htonl(3600); end += 4; // refresh time = 3600
231 *((int*)end) = htonl(600); end += 4; // retry time = 600
232 *((int*)end) = htonl(86400); end += 4; // expire time = 86400
233 *((int*)end) = htonl(3600); end += 4; // minimum ttl = 3600
234
235 // Fill in the length now that we know what it is
236 *((short*)soabegin) = htons(end - soabegin - 2);
237
238 // Set the new length of the packet
239 len = end - msg;
240 }
241
242 // Put back the original id
243 *((short *)&msg[0]) = htons(record->old_id);
244
245 // Send!
246 sendto(sock, msg, len, 0,
247 (struct sockaddr *)&record->src_addr,
248 sizeof(record->src_addr));
249
250 FreeQueryRecord(record);
251 }
252 }
253 }
254