1 #include <u.h>
2 #include <libc.h>
3 #include <ip.h>
4 #include <bio.h>
5 #include <ndb.h>
6 #include "dns.h"
7 
/*
 *  a dictionary of domain names for packing messages.
 *  pname records, for each name suffix it packs, where in the
 *  message that suffix starts, so a later occurrence of the same
 *  suffix can be replaced by a compression pointer to it.
 */
enum
{
	Ndict=	64		/* maximum number of dictionary entries */
};
typedef struct Dict	Dict;
struct Dict
{
	struct {
		ushort	offset;		/* pointer to packed name in message (relative to start) */
		char	*name;		/* pointer to unpacked name in buf; later suffixes of a name point into the same copy */
	} x[Ndict];
	int n;			/* size of dictionary (entries of x in use) */
	uchar *start;		/* start of packed message; offsets are measured from here */
	char buf[4*1024];	/* buffer for unpacked names (NUL-terminated copies); when full, names are simply not entered */
	char *ep;		/* first free char in buf */
};
27 
/*
 *  shorthands for packing one field at the caller's cursor.  each
 *  expands to an assignment to a local `p' and also uses the caller's
 *  local `ep' (and, for NAME, the dictionary `dp'), so they only work
 *  inside functions that declare those names.  the packers below
 *  return ep+1 when a field does not fit, and return ep+1 again when
 *  called with p already past ep, so a caller need only test p > ep
 *  once after its last field.
 */
#define NAME(x)		p = pname(p, ep, x, dp)
#define SYMBOL(x)	p = psym(p, ep, x)
#define STRING(x)	p = pstr(p, ep, x)
#define BYTES(x, n)	p = pbytes(p, ep, x, n)
#define USHORT(x)	p = pushort(p, ep, x)
#define UCHAR(x)	p = puchar(p, ep, x)
#define ULONG(x)	p = pulong(p, ep, x)
#define V4ADDR(x)	p = pv4addr(p, ep, x)
#define V6ADDR(x)	p = pv6addr(p, ep, x)
37 
38 static uchar*
psym(uchar * p,uchar * ep,char * np)39 psym(uchar *p, uchar *ep, char *np)
40 {
41 	int n;
42 
43 	n = strlen(np);
44 	if(n >= Strlen)			/* DNS maximum length string */
45 		n = Strlen - 1;
46 	if(ep - p < n+1)		/* see if it fits in the buffer */
47 		return ep+1;
48 	*p++ = n;
49 	memcpy(p, np, n);
50 	return p + n;
51 }
52 
53 static uchar*
pstr(uchar * p,uchar * ep,char * np)54 pstr(uchar *p, uchar *ep, char *np)
55 {
56 	int n;
57 
58 	n = strlen(np);
59 	if(n >= Strlen)			/* DNS maximum length string */
60 		n = Strlen - 1;
61 	if(ep - p < n+1)		/* see if it fits in the buffer */
62 		return ep+1;
63 	*p++ = n;
64 	memcpy(p, np, n);
65 	return p + n;
66 }
67 
68 static uchar*
pbytes(uchar * p,uchar * ep,uchar * np,int n)69 pbytes(uchar *p, uchar *ep, uchar *np, int n)
70 {
71 	if(ep - p < n)
72 		return ep+1;
73 	memcpy(p, np, n);
74 	return p + n;
75 }
76 
77 static uchar*
puchar(uchar * p,uchar * ep,int val)78 puchar(uchar *p, uchar *ep, int val)
79 {
80 	if(ep - p < 1)
81 		return ep+1;
82 	*p++ = val;
83 	return p;
84 }
85 
86 static uchar*
pushort(uchar * p,uchar * ep,int val)87 pushort(uchar *p, uchar *ep, int val)
88 {
89 	if(ep - p < 2)
90 		return ep+1;
91 	*p++ = val>>8;
92 	*p++ = val;
93 	return p;
94 }
95 
96 static uchar*
pulong(uchar * p,uchar * ep,int val)97 pulong(uchar *p, uchar *ep, int val)
98 {
99 	if(ep - p < 4)
100 		return ep+1;
101 	*p++ = val>>24;
102 	*p++ = val>>16;
103 	*p++ = val>>8;
104 	*p++ = val;
105 	return p;
106 }
107 
108 static uchar*
pv4addr(uchar * p,uchar * ep,char * name)109 pv4addr(uchar *p, uchar *ep, char *name)
110 {
111 	uchar ip[IPaddrlen];
112 
113 	if(ep - p < 4)
114 		return ep+1;
115 	parseip(ip, name);
116 	v6tov4(p, ip);
117 	return p + 4;
118 
119 }
120 
121 static uchar*
pv6addr(uchar * p,uchar * ep,char * name)122 pv6addr(uchar *p, uchar *ep, char *name)
123 {
124 	if(ep - p < IPaddrlen)
125 		return ep+1;
126 	parseip(p, name);
127 	return p + IPaddrlen;
128 
129 }
130 
131 static uchar*
pname(uchar * p,uchar * ep,char * np,Dict * dp)132 pname(uchar *p, uchar *ep, char *np, Dict *dp)
133 {
134 	char *cp;
135 	int i;
136 	char *last;		/* last component packed */
137 
138 	if(strlen(np) >= Domlen)	/* make sure we don't exceed DNS limits */
139 		return ep+1;
140 
141 	last = 0;
142 	while(*np){
143 		/* look through every component in the dictionary for a match */
144 		for(i = 0; i < dp->n; i++){
145 			if(strcmp(np, dp->x[i].name) == 0){
146 				if(ep - p < 2)
147 					return ep+1;
148 				*p++ = (dp->x[i].offset>>8) | 0xc0;
149 				*p++ = dp->x[i].offset;
150 				return p;
151 			}
152 		}
153 
154 		/* if there's room, enter this name in dictionary */
155 		if(dp->n < Ndict){
156 			if(last){
157 				/* the whole name is already in dp->buf */
158 				last = strchr(last, '.') + 1;
159 				dp->x[dp->n].name = last;
160 				dp->x[dp->n].offset = p - dp->start;
161 				dp->n++;
162 			} else {
163 				/* add to dp->buf */
164 				i = strlen(np);
165 				if(dp->ep + i + 1 < &dp->buf[sizeof(dp->buf)]){
166 					strcpy(dp->ep, np);
167 					dp->x[dp->n].name = dp->ep;
168 					last = dp->ep;
169 					dp->x[dp->n].offset = p - dp->start;
170 					dp->ep += i + 1;
171 					dp->n++;
172 				}
173 			}
174 		}
175 
176 		/* put next component into message */
177 		cp = strchr(np, '.');
178 		if(cp == 0){
179 			i = strlen(np);
180 			cp = np + i;	/* point to null terminator */
181 		} else {
182 			i = cp - np;
183 			cp++;		/* point past '.' */
184 		}
185 		if(ep-p < i+1)
186 			return ep+1;
187 		*p++ = i;		/* count of chars in label */
188 		memcpy(p, np, i);
189 		np = cp;
190 		p += i;
191 	}
192 
193 	if(p >= ep)
194 		return ep+1;
195 	*p++ = 0;	/* add top level domain */
196 
197 	return p;
198 }
199 
200 static uchar*
convRR2M(RR * rp,uchar * p,uchar * ep,Dict * dp)201 convRR2M(RR *rp, uchar *p, uchar *ep, Dict *dp)
202 {
203 	uchar *lp, *data;
204 	int len, ttl;
205 	Txt *t;
206 
207 	NAME(rp->owner->name);
208 	USHORT(rp->type);
209 	USHORT(rp->owner->class);
210 
211 	/* egregious overuse of ttl (it's absolute time in the cache) */
212 	if(rp->db)
213 		ttl = rp->ttl;
214 	else
215 		ttl = rp->ttl - now;
216 	if(ttl < 0)
217 		ttl = 0;
218 	ULONG(ttl);
219 
220 	lp = p;			/* leave room for the rdata length */
221 	p += 2;
222 	data = p;
223 
224 	if(data >= ep)
225 		return p+1;
226 
227 	switch(rp->type){
228 	case Thinfo:
229 		SYMBOL(rp->cpu->name);
230 		SYMBOL(rp->os->name);
231 		break;
232 	case Tcname:
233 	case Tmb:
234 	case Tmd:
235 	case Tmf:
236 	case Tns:
237 		NAME(rp->host->name);
238 		break;
239 	case Tmg:
240 	case Tmr:
241 		NAME(rp->mb->name);
242 		break;
243 	case Tminfo:
244 		NAME(rp->rmb->name);
245 		NAME(rp->mb->name);
246 		break;
247 	case Tmx:
248 		USHORT(rp->pref);
249 		NAME(rp->host->name);
250 		break;
251 	case Ta:
252 		V4ADDR(rp->ip->name);
253 		break;
254 	case Taaaa:
255 		V6ADDR(rp->ip->name);
256 		break;
257 	case Tptr:
258 		NAME(rp->ptr->name);
259 		break;
260 	case Tsoa:
261 		NAME(rp->host->name);
262 		NAME(rp->rmb->name);
263 		ULONG(rp->soa->serial);
264 		ULONG(rp->soa->refresh);
265 		ULONG(rp->soa->retry);
266 		ULONG(rp->soa->expire);
267 		ULONG(rp->soa->minttl);
268 		break;
269 	case Ttxt:
270 		for(t = rp->txt; t != nil; t = t->next)
271 			STRING(t->p);
272 		break;
273 	case Tnull:
274 		BYTES(rp->null->data, rp->null->dlen);
275 		break;
276 	case Trp:
277 		NAME(rp->rmb->name);
278 		NAME(rp->rp->name);
279 		break;
280 	case Tkey:
281 		USHORT(rp->key->flags);
282 		UCHAR(rp->key->proto);
283 		UCHAR(rp->key->alg);
284 		BYTES(rp->key->data, rp->key->dlen);
285 		break;
286 	case Tsig:
287 		USHORT(rp->sig->type);
288 		UCHAR(rp->sig->alg);
289 		UCHAR(rp->sig->labels);
290 		ULONG(rp->sig->ttl);
291 		ULONG(rp->sig->exp);
292 		ULONG(rp->sig->incep);
293 		USHORT(rp->sig->tag);
294 		NAME(rp->sig->signer->name);
295 		BYTES(rp->sig->data, rp->sig->dlen);
296 		break;
297 	case Tcert:
298 		USHORT(rp->cert->type);
299 		USHORT(rp->cert->tag);
300 		UCHAR(rp->cert->alg);
301 		BYTES(rp->cert->data, rp->cert->dlen);
302 		break;
303 	}
304 
305 	/* stuff in the rdata section length */
306 	len = p - data;
307 	*lp++ = len >> 8;
308 	*lp = len;
309 
310 	return p;
311 }
312 
313 static uchar*
convQ2M(RR * rp,uchar * p,uchar * ep,Dict * dp)314 convQ2M(RR *rp, uchar *p, uchar *ep, Dict *dp)
315 {
316 	NAME(rp->owner->name);
317 	USHORT(rp->type);
318 	USHORT(rp->owner->class);
319 	return p;
320 }
321 
322 static uchar*
rrloop(RR * rp,int * countp,uchar * p,uchar * ep,Dict * dp,int quest)323 rrloop(RR *rp, int *countp, uchar *p, uchar *ep, Dict *dp, int quest)
324 {
325 	uchar *np;
326 
327 	*countp = 0;
328 	for(; rp && p < ep; rp = rp->next){
329 		if(quest)
330 			np = convQ2M(rp, p, ep, dp);
331 		else
332 			np = convRR2M(rp, p, ep, dp);
333 		if(np > ep)
334 			break;
335 		p = np;
336 		(*countp)++;
337 	}
338 	return p;
339 }
340 
341 /*
342  *  convert into a message
343  */
344 int
convDNS2M(DNSmsg * m,uchar * buf,int len)345 convDNS2M(DNSmsg *m, uchar *buf, int len)
346 {
347 	uchar *p, *ep, *np;
348 	Dict d;
349 
350 	d.n = 0;
351 	d.start = buf;
352 	d.ep = d.buf;
353 	memset(buf, 0, len);
354 	m->qdcount = m->ancount = m->nscount = m->arcount = 0;
355 
356 	/* first pack in the RR's so we can get real counts */
357 	p = buf + 12;
358 	ep = buf + len;
359 	p = rrloop(m->qd, &m->qdcount, p, ep, &d, 1);
360 	p = rrloop(m->an, &m->ancount, p, ep, &d, 0);
361 	p = rrloop(m->ns, &m->nscount, p, ep, &d, 0);
362 	p = rrloop(m->ar, &m->arcount, p, ep, &d, 0);
363 	if(p > ep)
364 		return -1;
365 
366 	/* now pack the rest */
367 	np = p;
368 	p = buf;
369 	ep = buf + len;
370 	USHORT(m->id);
371 	USHORT(m->flags);
372 	USHORT(m->qdcount);
373 	USHORT(m->ancount);
374 	USHORT(m->nscount);
375 	USHORT(m->arcount);
376 	if(p > ep)
377 		return -1;
378 
379 	return np - buf;
380 }
381