1 /*
2  * tinysvcmdns - a tiny MDNS implementation for publishing services
3  * Copyright (C) 2011 Darell Tan
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  * 3. The name of the author may not be used to endorse or promote products
15  *    derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
21  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
22  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "mdns.h"
30 #include <stdint.h>
31 #include <stdlib.h>
32 #include <stdio.h>
33 #include <string.h>
34 #include <assert.h>
35 
36 #ifdef _WIN32
37 #include <winsock.h>
38 #include <in6addr.h>
39 #else
40 #include <netinet/in.h>
41 #endif
42 
43 
// TTL (seconds, as carried in the DNS TTL field) given to records built with FILL_RR_ENTRY
#define DEFAULT_TTL		120


/*
 * Entry in the name-compression table used while encoding a packet.
 * Each entry remembers a name suffix that has already been written into
 * the packet and the byte offset at which it was written, so later names
 * ending in the same suffix can be replaced by a compression pointer.
 */
struct name_comp {
	uint8_t *label;			/* suffix, in uncompressed nlabel form */
	size_t pos;				/* offset of that suffix in the packet buffer */

	struct name_comp *next;	/* next table entry, NULL at the end */
};
53 
54 // ----- label functions -----
55 
56 // duplicates a name
dup_nlabel(const uint8_t * n)57 inline uint8_t *dup_nlabel(const uint8_t *n) {
58 	assert(n[0] <= 63);	// prevent mis-use
59 	return (uint8_t *) strdup((char *) n);
60 }
61 
62 // duplicates a label
dup_label(const uint8_t * label)63 uint8_t *dup_label(const uint8_t *label) {
64 	int len = *label + 1;
65 	if (len > 63)
66 		return NULL;
67 	uint8_t *newlabel = malloc(len + 1);
68 	strncpy((char *) newlabel, (char *) label, len);
69 	newlabel[len] = '\0';
70 	return newlabel;
71 }
72 
join_nlabel(const uint8_t * n1,const uint8_t * n2)73 uint8_t *join_nlabel(const uint8_t *n1, const uint8_t *n2) {
74 	int len1, len2;
75 	uint8_t *s;
76 
77 	assert(n1[0] <= 63 && n2[0] <= 63);	// detect misuse
78 
79 	len1 = strlen((char *) n1);
80 	len2 = strlen((char *) n2);
81 
82 	s = malloc(len1 + len2 + 1);
83 	strncpy((char *) s, (char *) n1, len1);
84 	strncpy((char *) s+len1, (char *) n2, len2);
85 	s[len1 + len2] = '\0';
86 	return s;
87 }
88 
89 // returns a human-readable name label in dotted form
nlabel_to_str(const uint8_t * name)90 char *nlabel_to_str(const uint8_t *name) {
91 	char *label, *labelp;
92 	const uint8_t *p;
93 
94 	assert(name != NULL);
95 
96 	label = labelp = malloc(256);
97 
98 	for (p = name; *p; p++) {
99 		strncpy(labelp, (char *) p + 1, *p);
100 		labelp += *p;
101 		*labelp = '.';
102 		labelp++;
103 
104 		p += *p;
105 	}
106 
107 	*labelp = '\0';
108 
109 	return label;
110 }
111 
112 // returns the length of a label field
113 // does NOT uncompress the field, so it could be as small as 2 bytes
114 // or 1 for the root
label_len(uint8_t * pkt_buf,size_t pkt_len,size_t off)115 static size_t label_len(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
116 	uint8_t *p;
117 	uint8_t *e = pkt_buf + pkt_len;
118 	size_t len = 0;
119 
120 	for (p = pkt_buf + off; p < e; p++) {
121 		if (*p == 0) {
122 			return len + 1;
123 		} else if ((*p & 0xC0) == 0xC0) {
124 			return len + 2;
125 		} else {
126 			len += *p + 1;
127 			p += *p;
128 		}
129 	}
130 
131 	return len;
132 }
133 
134 // creates a label
135 // free() after use
create_label(const char * txt)136 uint8_t *create_label(const char *txt) {
137 	int len;
138 	uint8_t *s;
139 
140 	assert(txt != NULL);
141 	len = strlen(txt);
142 	if (len > 63)
143 		return NULL;
144 
145 	s = malloc(len + 2);
146 	s[0] = len;
147 	strncpy((char *) s + 1, txt, len);
148 	s[len + 1] = '\0';
149 
150 	return s;
151 }
152 
153 // creates a uncompressed name label given a DNS name like "apple.b.com"
154 // free() after use
create_nlabel(const char * name)155 uint8_t *create_nlabel(const char *name) {
156 	char *label;
157 	char *p, *e, *lenpos;
158 	int len = 0;
159 
160 	assert(name != NULL);
161 
162 	len = strlen(name);
163 	label = malloc(len + 1 + 1);
164 	if (label == NULL)
165 		return NULL;
166 
167 	strncpy((char *) label + 1, name, len);
168 	label[len + 1] = '\0';
169 
170 	p = label;
171 	e = p + len;
172 	lenpos = p;
173 
174 	while (p < e) {
175 		*lenpos = 0;
176 		char *dot = memchr(p + 1, '.', e - p - 1);
177 		if (dot == NULL)
178 			dot = e + 1;
179 		*lenpos = dot - p - 1;
180 
181 		p = dot;
182 		lenpos = dot;
183 	}
184 
185 	return (uint8_t *) label;
186 }
187 
188 // copies a label from the buffer into a newly-allocated string
189 // free() after use
copy_label(uint8_t * pkt_buf,size_t pkt_len,size_t off)190 static uint8_t *copy_label(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
191 	int len;
192 
193 	if (off > pkt_len)
194 		return NULL;
195 
196 	len = pkt_buf[off] + 1;
197 	if (off + len > pkt_len) {
198 		DEBUG_PRINTF("label length exceeds packet buffer\n");
199 		return NULL;
200 	}
201 
202 	return dup_label(pkt_buf + off);
203 }
204 
205 // uncompresses a name
206 // free() after use
uncompress_nlabel(uint8_t * pkt_buf,size_t pkt_len,size_t off)207 static uint8_t *uncompress_nlabel(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
208 	uint8_t *p;
209 	uint8_t *e = pkt_buf + pkt_len;
210 	size_t len = 0;
211 	char *str, *sp;
212 	if (off >= pkt_len)
213 		return NULL;
214 
215 	// calculate length of uncompressed label
216 	for (p = pkt_buf + off; *p && p < e; p++) {
217 		size_t llen = 0;
218 		if ((*p & 0xC0) == 0xC0) {
219 			uint8_t *p2 = pkt_buf + (((p[0] & ~0xC0) << 8) | p[1]);
220 			llen = *p2 + 1;
221 			p = p2 + llen - 1;
222 		} else {
223 			llen = *p + 1;
224 			p += llen - 1;
225 		}
226 		len += llen;
227 	}
228 
229 	str = sp = malloc(len + 1);
230 	if (str == NULL)
231 		return NULL;
232 
233 	// FIXME: must merge this with above code
234 	for (p = pkt_buf + off; *p && p < e; p++) {
235 		size_t llen = 0;
236 		if ((*p & 0xC0) == 0xC0) {
237 			uint8_t *p2 = pkt_buf + (((p[0] & ~0xC0) << 8) | p[1]);
238 			llen = *p2 + 1;
239 			strncpy(sp, (char *) p2, llen);
240 			p = p2 + llen - 1;
241 		} else {
242 			llen = *p + 1;
243 			strncpy(sp, (char *) p, llen);
244 			p += llen - 1;
245 		}
246 		sp += llen;
247 	}
248 	*sp = '\0';
249 
250 	return (uint8_t *) str;
251 }
252 
253 // ----- RR list & group functions -----
254 
rr_get_type_name(enum rr_type type)255 const char *rr_get_type_name(enum rr_type type) {
256 	switch (type) {
257 		case RR_A:		return "A";
258 		case RR_PTR:	return "PTR";
259 		case RR_TXT:	return "TXT";
260 		case RR_AAAA:	return "AAAA";
261 		case RR_SRV:	return "SRV";
262 		case RR_NSEC:	return "NSEC";
263 		case RR_ANY:	return "ANY";
264 	}
265 	return NULL;
266 }
267 
rr_entry_destroy(struct rr_entry * rr)268 void rr_entry_destroy(struct rr_entry *rr) {
269 	struct rr_data_txt *txt_rec;
270 	assert(rr);
271 
272 	// check rr_type and free data elements
273 	switch (rr->type) {
274 		case RR_PTR:
275 			if (rr->data.PTR.name)
276 				free(rr->data.PTR.name);
277 			// don't free entry
278 			break;
279 
280 		case RR_TXT:
281 			txt_rec = &rr->data.TXT;
282 			while (txt_rec) {
283 				struct rr_data_txt *next = txt_rec->next;
284 				if (txt_rec->txt)
285 					free(txt_rec->txt);
286 
287 				// only free() if it wasn't part of the struct
288 				if (txt_rec != &rr->data.TXT)
289 					free(txt_rec);
290 
291 				txt_rec = next;
292 			}
293 			break;
294 
295 		case RR_SRV:
296 			if (rr->data.SRV.target)
297 				free(rr->data.SRV.target);
298 			break;
299 
300 		default:
301 			// nothing to free
302 			break;
303 	}
304 
305 	free(rr->name);
306 	free(rr);
307 }
308 
309 // destroys an RR list (and optionally, items)
rr_list_destroy(struct rr_list * rr,char destroy_items)310 void rr_list_destroy(struct rr_list *rr, char destroy_items) {
311 	struct rr_list *rr_next;
312 
313 	for (; rr; rr = rr_next) {
314 		rr_next = rr->next;
315 		if (destroy_items)
316 			rr_entry_destroy(rr->e);
317 		free(rr);
318 	}
319 }
320 
rr_list_count(struct rr_list * rr)321 int rr_list_count(struct rr_list *rr) {
322 	int i = 0;
323 	for (; rr; i++, rr = rr->next);
324 	return i;
325 }
326 
rr_list_remove(struct rr_list ** rr_head,struct rr_entry * rr)327 struct rr_entry *rr_list_remove(struct rr_list **rr_head, struct rr_entry *rr) {
328 	struct rr_list *le = *rr_head, *pe = NULL;
329 	for (; le; le = le->next) {
330 		if (le->e == rr) {
331 			if (pe == NULL) {
332 				*rr_head = le->next;
333 				free(le);
334 				return rr;
335 			} else {
336 				pe->next = le->next;
337 				free(le);
338 				return rr;
339 			}
340 		}
341 		pe = le;
342 	}
343 	return NULL;
344 }
345 
346 // appends an rr_entry to an RR list
347 // if the RR is already in the list, it will not be added
348 // RRs are compared by memory location - not its contents
349 // return value of 0 means item not added
rr_list_append(struct rr_list ** rr_head,struct rr_entry * rr)350 int rr_list_append(struct rr_list **rr_head, struct rr_entry *rr) {
351 	struct rr_list *node = malloc(sizeof(struct rr_list));
352 	node->e = rr;
353 	node->next = NULL;
354 
355 	if (*rr_head == NULL) {
356 		*rr_head = node;
357 	} else {
358 		struct rr_list *e = *rr_head, *taile;
359 		for (; e; e = e->next) {
360 			// already in list - don't add
361 			if (e->e == rr) {
362 				free(node);
363 				return 0;
364 			}
365 			if (e->next == NULL)
366 				taile = e;
367 		}
368 		taile->next = node;
369 	}
370 	return 1;
371 }
372 
373 #define FILL_RR_ENTRY(rr, _name, _type)	\
374 	rr->name = _name;			\
375 	rr->type = _type;			\
376 	rr->ttl  = DEFAULT_TTL;		\
377 	rr->cache_flush = 1;		\
378 	rr->rr_class  = 1;
379 
rr_create_a(uint8_t * name,uint32_t addr)380 struct rr_entry *rr_create_a(uint8_t *name, uint32_t addr) {
381 	DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
382 	FILL_RR_ENTRY(rr, name, RR_A);
383 	rr->data.A.addr = addr;
384 	return rr;
385 }
386 
rr_create_aaaa(uint8_t * name,struct in6_addr * addr)387 struct rr_entry *rr_create_aaaa(uint8_t *name, struct in6_addr *addr) {
388 	DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
389 	FILL_RR_ENTRY(rr, name, RR_AAAA);
390 	rr->data.AAAA.addr = addr;
391 	return rr;
392 }
393 
rr_create_srv(uint8_t * name,uint16_t port,uint8_t * target)394 struct rr_entry *rr_create_srv(uint8_t *name, uint16_t port, uint8_t *target) {
395 	DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
396 	FILL_RR_ENTRY(rr, name, RR_SRV);
397 	rr->data.SRV.port = port;
398 	rr->data.SRV.target = target;
399 	return rr;
400 }
401 
rr_create_ptr(uint8_t * name,struct rr_entry * d_rr)402 struct rr_entry *rr_create_ptr(uint8_t *name, struct rr_entry *d_rr) {
403 	DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
404 	FILL_RR_ENTRY(rr, name, RR_PTR);
405 	rr->cache_flush = 0;	// PTRs shouldn't have their cache flush bit set
406 	rr->data.PTR.entry = d_rr;
407 	return rr;
408 }
409 
rr_create(uint8_t * name,enum rr_type type)410 struct rr_entry *rr_create(uint8_t *name, enum rr_type type) {
411 	DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
412 	FILL_RR_ENTRY(rr, name, type);
413 	return rr;
414 }
415 
rr_set_nsec(struct rr_entry * rr_nsec,enum rr_type type)416 void rr_set_nsec(struct rr_entry *rr_nsec, enum rr_type type) {
417 	assert(rr_nsec->type = RR_NSEC);
418 	assert((type / 8) < sizeof(rr_nsec->data.NSEC.bitmap));
419 
420 	rr_nsec->data.NSEC.bitmap[ type / 8 ] = 1 << (7 - (type % 8));
421 }
422 
rr_add_txt(struct rr_entry * rr_txt,const char * txt)423 void rr_add_txt(struct rr_entry *rr_txt, const char *txt) {
424 	struct rr_data_txt *txt_rec;
425 	assert(rr_txt->type == RR_TXT);
426 
427 	txt_rec = &rr_txt->data.TXT;
428 
429 	// is current data filled?
430 	if (txt_rec->txt == NULL) {
431 		txt_rec->txt = create_label(txt);
432 		return;
433 	}
434 
435 	// find the last node
436 	for (; txt_rec->next; txt_rec = txt_rec->next);
437 
438 	// create a new empty node
439 	txt_rec->next = malloc(sizeof(struct rr_data_txt));
440 
441 	txt_rec = txt_rec->next;
442 	txt_rec->txt = create_label(txt);
443 	txt_rec->next = NULL;
444 }
445 
446 // adds a record to an rr_group
rr_group_add(struct rr_group ** group,struct rr_entry * rr)447 void rr_group_add(struct rr_group **group, struct rr_entry *rr) {
448 	struct rr_group *g;
449 
450 	assert(rr != NULL);
451 
452 	if (*group) {
453 		g = rr_group_find(*group, rr->name);
454 		if (g) {
455 			rr_list_append(&g->rr, rr);
456 			return;
457 		}
458 	}
459 
460 	MALLOC_ZERO_STRUCT(g, rr_group);
461 	g->name = dup_nlabel(rr->name);
462 	rr_list_append(&g->rr, rr);
463 
464 	// prepend to list
465 	g->next = *group;
466 	*group = g;
467 }
468 
469 // finds a rr_group matching the given name
rr_group_find(struct rr_group * g,uint8_t * name)470 struct rr_group *rr_group_find(struct rr_group* g, uint8_t *name) {
471 	for (; g; g = g->next) {
472 		if (cmp_nlabel(g->name, name) == 0)
473 			return g;
474 	}
475 	return NULL;
476 }
477 
rr_entry_find(struct rr_list * rr_list,uint8_t * name,uint16_t type)478 struct rr_entry *rr_entry_find(struct rr_list *rr_list, uint8_t *name, uint16_t type) {
479 	struct rr_list *rr = rr_list;
480 	for (; rr; rr = rr->next) {
481 		if (rr->e->type == type && cmp_nlabel(rr->e->name, name) == 0)
482 			return rr->e;
483 	}
484 	return NULL;
485 }
486 
487 // looks for a matching entry in rr_list
488 // if entry is a PTR, we need to check if the PTR target also matches
rr_entry_match(struct rr_list * rr_list,struct rr_entry * entry)489 struct rr_entry *rr_entry_match(struct rr_list *rr_list, struct rr_entry *entry) {
490 	struct rr_list *rr = rr_list;
491 	for (; rr; rr = rr->next) {
492 		if (rr->e->type == entry->type && cmp_nlabel(rr->e->name, entry->name) == 0) {
493 			if (entry->type != RR_PTR) {
494 				return rr->e;
495 			} else if (cmp_nlabel(MDNS_RR_GET_PTR_NAME(entry), MDNS_RR_GET_PTR_NAME(rr->e)) == 0) {
496 				// if it's a PTR, we need to make sure PTR target also matches
497 				return rr->e;
498 			}
499 		}
500 	}
501 	return NULL;
502 }
503 
rr_group_destroy(struct rr_group * group)504 void rr_group_destroy(struct rr_group *group) {
505 	struct rr_group *g = group;
506 
507 	while (g) {
508 		struct rr_group *nextg = g->next;
509 		free(g->name);
510 		rr_list_destroy(g->rr, 1);
511 		free(g);
512 		g = nextg;
513 	}
514 }
515 
mdns_write_u16(uint8_t * ptr,const uint16_t v)516 uint8_t *mdns_write_u16(uint8_t *ptr, const uint16_t v) {
517 	*ptr++ = (uint8_t) (v >> 8) & 0xFF;
518 	*ptr++ = (uint8_t) (v >> 0) & 0xFF;
519 	return ptr;
520 }
521 
mdns_write_u32(uint8_t * ptr,const uint32_t v)522 uint8_t *mdns_write_u32(uint8_t *ptr, const uint32_t v) {
523 	*ptr++ = (uint8_t) (v >> 24) & 0xFF;
524 	*ptr++ = (uint8_t) (v >> 16) & 0xFF;
525 	*ptr++ = (uint8_t) (v >>  8) & 0xFF;
526 	*ptr++ = (uint8_t) (v >>  0) & 0xFF;
527 	return ptr;
528 }
529 
mdns_read_u16(const uint8_t * ptr)530 uint16_t mdns_read_u16(const uint8_t *ptr) {
531 	return  ((ptr[0] & 0xFF) << 8) |
532 			((ptr[1] & 0xFF) << 0);
533 }
534 
mdns_read_u32(const uint8_t * ptr)535 uint32_t mdns_read_u32(const uint8_t *ptr) {
536 	return  ((ptr[0] & 0xFF) << 24) |
537 			((ptr[1] & 0xFF) << 16) |
538 			((ptr[2] & 0xFF) <<  8) |
539 			((ptr[3] & 0xFF) <<  0);
540 }
541 
542 // initialize the packet for reply
543 // clears the packet of list structures but not its list items
mdns_init_reply(struct mdns_pkt * pkt,uint16_t id)544 void mdns_init_reply(struct mdns_pkt *pkt, uint16_t id) {
545 	// copy transaction ID
546 	pkt->id = id;
547 
548 	// response flags
549 	pkt->flags = MDNS_FLAG_RESP | MDNS_FLAG_AA;
550 
551 	rr_list_destroy(pkt->rr_qn,   0);
552 	rr_list_destroy(pkt->rr_ans,  0);
553 	rr_list_destroy(pkt->rr_auth, 0);
554 	rr_list_destroy(pkt->rr_add,  0);
555 
556 	pkt->rr_qn    = NULL;
557 	pkt->rr_ans   = NULL;
558 	pkt->rr_auth  = NULL;
559 	pkt->rr_add   = NULL;
560 
561 	pkt->num_qn = 0;
562 	pkt->num_ans_rr = 0;
563 	pkt->num_auth_rr = 0;
564 	pkt->num_add_rr = 0;
565 }
566 
567 // destroys an mdns_pkt struct, including its contents
mdns_pkt_destroy(struct mdns_pkt * p)568 void mdns_pkt_destroy(struct mdns_pkt *p) {
569 	rr_list_destroy(p->rr_qn, 1);
570 	rr_list_destroy(p->rr_ans, 1);
571 	rr_list_destroy(p->rr_auth, 1);
572 	rr_list_destroy(p->rr_add, 1);
573 
574 	free(p);
575 }
576 
577 
578 // parse the MDNS questions section
579 // stores the parsed data in the given mdns_pkt struct
mdns_parse_qn(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct mdns_pkt * pkt)580 static size_t mdns_parse_qn(uint8_t *pkt_buf, size_t pkt_len, size_t off,
581 		struct mdns_pkt *pkt) {
582 	const uint8_t *p = pkt_buf + off;
583 	struct rr_entry *rr;
584 	uint8_t *name;
585 
586 	assert(pkt != NULL);
587 
588 	rr = malloc(sizeof(struct rr_entry));
589 	memset(rr, 0, sizeof(struct rr_entry));
590 
591 	name = uncompress_nlabel(pkt_buf, pkt_len, off);
592 	p += label_len(pkt_buf, pkt_len, off);
593 	rr->name = name;
594 
595 	rr->type = mdns_read_u16(p);
596 	p += sizeof(uint16_t);
597 
598 	rr->unicast_query = (*p & 0x80) == 0x80;
599 	rr->rr_class = mdns_read_u16(p) & ~0x80;
600 	p += sizeof(uint16_t);
601 
602 	rr_list_append(&pkt->rr_qn, rr);
603 
604 	return p - (pkt_buf + off);
605 }
606 
607 // parse the MDNS RR section
608 // stores the parsed data in the given mdns_pkt struct
mdns_parse_rr(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct mdns_pkt * pkt)609 static size_t mdns_parse_rr(uint8_t *pkt_buf, size_t pkt_len, size_t off,
610 		struct mdns_pkt *pkt) {
611 	const uint8_t *p = pkt_buf + off;
612 	const uint8_t *e = pkt_buf + pkt_len;
613 	struct rr_entry *rr;
614 	uint8_t *name;
615 	size_t rr_data_len = 0;
616 	struct rr_data_txt *txt_rec;
617 	int parse_error = 0;
618 
619 	assert(pkt != NULL);
620 
621 	if (off > pkt_len)
622 		return 0;
623 
624 	rr = malloc(sizeof(struct rr_entry));
625 	memset(rr, 0, sizeof(struct rr_entry));
626 
627 	name = uncompress_nlabel(pkt_buf, pkt_len, off);
628 	p += label_len(pkt_buf, pkt_len, off);
629 	rr->name = name;
630 
631 	rr->type = mdns_read_u16(p);
632 	p += sizeof(uint16_t);
633 
634 	rr->cache_flush = (*p & 0x80) == 0x80;
635 	rr->rr_class = mdns_read_u16(p) & ~0x80;
636 	p += sizeof(uint16_t);
637 
638 	rr->ttl = mdns_read_u32(p);
639 	p += sizeof(uint32_t);
640 
641 	// RR data
642 	rr_data_len = mdns_read_u16(p);
643 	p += sizeof(uint16_t);
644 
645 	if (p + rr_data_len > e) {
646 		DEBUG_PRINTF("rr_data_len goes beyond packet buffer: %lu > %lu\n", rr_data_len, e - p);
647 		rr_entry_destroy(rr);
648 		return 0;
649 	}
650 
651 	e = p + rr_data_len;
652 
653 	// see if we can parse the RR data
654 	switch (rr->type) {
655 		case RR_A:
656 			if (rr_data_len < sizeof(uint32_t)) {
657 				DEBUG_PRINTF("invalid rr_data_len=%lu for A record\n", rr_data_len);
658 				parse_error = 1;
659 				break;
660 			}
661 			rr->data.A.addr = ntohl(mdns_read_u32(p)); /* addr already in net order */
662 			p += sizeof(uint32_t);
663 			break;
664 
665 		case RR_AAAA:
666 			if (rr_data_len < sizeof(struct in6_addr)) {
667 				DEBUG_PRINTF("invalid rr_data_len=%lu for AAAA record\n", rr_data_len);
668 				parse_error = 1;
669 				break;
670 			}
671 			rr->data.AAAA.addr = malloc(sizeof(struct in6_addr));
672 			for (int i = 0; i < sizeof(struct in6_addr); i++)
673 				rr->data.AAAA.addr->s6_addr[i] = p[i];
674 			p += sizeof(struct in6_addr);
675 			break;
676 
677 		case RR_PTR:
678 			rr->data.PTR.name = uncompress_nlabel(pkt_buf, pkt_len, p - pkt_buf);
679 			if (rr->data.PTR.name == NULL) {
680 				DEBUG_PRINTF("unable to parse/uncompress label for PTR name\n");
681 				parse_error = 1;
682 				break;
683 			}
684 			p += rr_data_len;
685 			break;
686 
687 		case RR_TXT:
688 			txt_rec = &rr->data.TXT;
689 
690 			// not supposed to happen, but we should handle it
691 			if (rr_data_len == 0) {
692 				DEBUG_PRINTF("WARN: rr_data_len for TXT is 0\n");
693 				txt_rec->txt = create_label("");
694 				break;
695 			}
696 
697 			while (1) {
698 				txt_rec->txt = copy_label(pkt_buf, pkt_len, p - pkt_buf);
699 				if (txt_rec->txt == NULL) {
700 					DEBUG_PRINTF("unable to copy label for TXT record\n");
701 					parse_error = 1;
702 					break;
703 				}
704 				p += txt_rec->txt[0] + 1;
705 
706 				if (p >= e)
707 					break;
708 
709 				// allocate another record
710 				txt_rec->next = malloc(sizeof(struct rr_data_txt));
711 				txt_rec = txt_rec->next;
712 				txt_rec->next = NULL;
713 			}
714 			break;
715 
716 		default:
717 			// skip to end of RR data
718 			p = e;
719 	}
720 
721 	// if there was a parse error, destroy partial rr_entry
722 	if (parse_error) {
723 		rr_entry_destroy(rr);
724 		return 0;
725 	}
726 
727 	rr_list_append(&pkt->rr_ans, rr);
728 
729 	return p - (pkt_buf + off);
730 }
731 
732 // parse a MDNS packet into an mdns_pkt struct
mdns_parse_pkt(uint8_t * pkt_buf,size_t pkt_len)733 struct mdns_pkt *mdns_parse_pkt(uint8_t *pkt_buf, size_t pkt_len) {
734 	uint8_t *p = pkt_buf;
735 	size_t off;
736 	struct mdns_pkt *pkt;
737 	int i;
738 
739 	if (pkt_len < 12)
740 		return NULL;
741 
742 	MALLOC_ZERO_STRUCT(pkt, mdns_pkt);
743 
744 	// parse header
745 	pkt->id 			= mdns_read_u16(p); p += sizeof(uint16_t);
746 	pkt->flags 			= mdns_read_u16(p); p += sizeof(uint16_t);
747 	pkt->num_qn 		= mdns_read_u16(p); p += sizeof(uint16_t);
748 	pkt->num_ans_rr 	= mdns_read_u16(p); p += sizeof(uint16_t);
749 	pkt->num_auth_rr 	= mdns_read_u16(p); p += sizeof(uint16_t);
750 	pkt->num_add_rr 	= mdns_read_u16(p); p += sizeof(uint16_t);
751 
752 	off = p - pkt_buf;
753 
754 	// parse questions
755 	for (i = 0; i < pkt->num_qn; i++) {
756 		size_t l = mdns_parse_qn(pkt_buf, pkt_len, off, pkt);
757 		if (! l) {
758 			DEBUG_PRINTF("error parsing question #%d\n", i);
759 			mdns_pkt_destroy(pkt);
760 			return NULL;
761 		}
762 
763 		off += l;
764 	}
765 
766 	// parse answer RRs
767 	for (i = 0; i < pkt->num_ans_rr; i++) {
768 		size_t l = mdns_parse_rr(pkt_buf, pkt_len, off, pkt);
769 		if (! l) {
770 			DEBUG_PRINTF("error parsing answer #%d\n", i);
771 			mdns_pkt_destroy(pkt);
772 			return NULL;
773 		}
774 
775 		off += l;
776 	}
777 
778 	// TODO: parse the authority and additional RR sections
779 
780 	return pkt;
781 }
782 
783 // encodes a name (label) into a packet using the name compression scheme
784 // encoded names will be added to the compression list for subsequent use
mdns_encode_name(uint8_t * pkt_buf,size_t pkt_len,size_t off,const uint8_t * name,struct name_comp * comp)785 static size_t mdns_encode_name(uint8_t *pkt_buf, size_t pkt_len, size_t off,
786 		const uint8_t *name, struct name_comp *comp) {
787 	struct name_comp *c, *c_tail = NULL;
788 	uint8_t *p = pkt_buf + off;
789 	size_t len = 0;
790 
791 	if (name) {
792 		while (*name) {
793 			// find match for compression
794 			for (c = comp; c; c = c->next) {
795 				if (cmp_nlabel(name, c->label) == 0) {
796 					mdns_write_u16(p, 0xC000 | (c->pos & ~0xC000));
797 					return len + sizeof(uint16_t);
798 				}
799 
800 				if (c->next == NULL)
801 					c_tail = c;
802 			}
803 
804 			// copy this segment
805 			int segment_len = *name + 1;
806 			strncpy((char *) p, (char *) name, segment_len);
807 
808 			// cache the name for subsequent compression
809 			DECL_MALLOC_ZERO_STRUCT(new_c, name_comp);
810 
811 			new_c->label = (uint8_t *) name;
812 			new_c->pos = p - pkt_buf;
813 			c_tail->next = new_c;
814 
815 			// advance to next name segment
816 			p += segment_len;
817 			len += segment_len;
818 			name += segment_len;
819 		}
820 	}
821 
822 	*p = '\0';	// root "label"
823 	len += 1;
824 
825 	return len;
826 }
827 
828 // encodes an RR entry at the given offset
829 // returns the size of the entire RR entry
mdns_encode_rr(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct rr_entry * rr,struct name_comp * comp)830 static size_t mdns_encode_rr(uint8_t *pkt_buf, size_t pkt_len, size_t off,
831 		struct rr_entry *rr, struct name_comp *comp) {
832 	uint8_t *p = pkt_buf + off, *p_data;
833 	size_t l;
834 	struct rr_data_txt *txt_rec;
835 	uint8_t *label;
836 	int i;
837 
838 	assert(off < pkt_len);
839 
840 	// name
841 	l = mdns_encode_name(pkt_buf, pkt_len, off, rr->name, comp);
842 	assert(l != 0);
843 	p += l;
844 
845 	// type
846 	p = mdns_write_u16(p, rr->type);
847 
848 	// class & cache flush
849 	p = mdns_write_u16(p, (rr->rr_class & ~0x8000) | (rr->cache_flush << 15));
850 
851 	// TTL
852 	p = mdns_write_u32(p, rr->ttl);
853 
854 	// data length (filled in later)
855 	p += sizeof(uint16_t);
856 
857 	// start of data marker
858 	p_data = p;
859 
860 	switch (rr->type) {
861 		case RR_A:
862 			/* htonl() needed coz addr already in net order */
863 			p = mdns_write_u32(p, htonl(rr->data.A.addr));
864 			break;
865 
866 		case RR_AAAA:
867 			for (i = 0; i < sizeof(struct in6_addr); i++)
868 				*p++ = rr->data.AAAA.addr->s6_addr[i];
869 			break;
870 
871 		case RR_PTR:
872 			label = rr->data.PTR.name ?
873 					rr->data.PTR.name :
874 					rr->data.PTR.entry->name;
875 			p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf, label, comp);
876 			break;
877 
878 		case RR_TXT:
879 			txt_rec = &rr->data.TXT;
880 			for (; txt_rec; txt_rec = txt_rec->next) {
881 				int len = txt_rec->txt[0] + 1;
882 				strncpy((char *) p, (char *) txt_rec->txt, len);
883 				p += len;
884 			}
885 			break;
886 
887 		case RR_SRV:
888 			p = mdns_write_u16(p, rr->data.SRV.priority);
889 
890 			p = mdns_write_u16(p, rr->data.SRV.weight);
891 
892 			p = mdns_write_u16(p, rr->data.SRV.port);
893 
894 			p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf,
895 					rr->data.SRV.target, comp);
896 			break;
897 
898 		case RR_NSEC:
899 			p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf,
900 					rr->name, comp);
901 
902 			*p++ = 0;	// bitmap window/block number
903 
904 			*p++ = sizeof(rr->data.NSEC.bitmap);	// bitmap length
905 
906 			for (i = 0; i < sizeof(rr->data.NSEC.bitmap); i++)
907 				*p++ = rr->data.NSEC.bitmap[i];
908 
909 			break;
910 
911 		default:
912 			DEBUG_PRINTF("unhandled rr type 0x%02x\n", rr->type);
913 	}
914 
915 	// calculate data length based on p
916 	l = p - p_data;
917 
918 	// fill in the length
919 	mdns_write_u16(p - l - sizeof(uint16_t), l);
920 
921 	return p - pkt_buf - off;
922 }
923 
924 // encodes a MDNS packet from the given mdns_pkt struct into a buffer
925 // returns the size of the entire MDNS packet
mdns_encode_pkt(struct mdns_pkt * answer,uint8_t * pkt_buf,size_t pkt_len)926 size_t mdns_encode_pkt(struct mdns_pkt *answer, uint8_t *pkt_buf, size_t pkt_len) {
927 	struct name_comp *comp;
928 	uint8_t *p = pkt_buf;
929 	//uint8_t *e = pkt_buf + pkt_len;
930 	size_t off;
931 	int i;
932 
933 	assert(answer != NULL);
934 	assert(pkt_len >= 12);
935 
936 	if (p == NULL)
937 		return -1;
938 
939 	// this is an Answer - number of qns should be zero
940 	assert(answer->num_qn == 0);
941 
942 	p = mdns_write_u16(p, answer->id);
943 	p = mdns_write_u16(p, answer->flags);
944 	p = mdns_write_u16(p, answer->num_qn);
945 	p = mdns_write_u16(p, answer->num_ans_rr);
946 	p = mdns_write_u16(p, answer->num_auth_rr);
947 	p = mdns_write_u16(p, answer->num_add_rr);
948 
949 	off = p - pkt_buf;
950 
951 	// allocate list for name compression
952 	comp = malloc(sizeof(struct name_comp));
953 	if (comp == NULL)
954 		return -1;
955 	memset(comp, 0, sizeof(struct name_comp));
956 
957 	// dummy entry
958 	comp->label = (uint8_t *) "";
959 	comp->pos = 0;
960 
961 	// skip encoding of qn
962 
963 	struct rr_list *rr_set[] = {
964 		answer->rr_ans,
965 		answer->rr_auth,
966 		answer->rr_add
967 	};
968 
969 	// encode answer, authority and additional RRs
970 	for (i = 0; i < sizeof(rr_set) / sizeof(rr_set[0]); i++) {
971 		struct rr_list *rr = rr_set[i];
972 		for (; rr; rr = rr->next) {
973 			size_t l = mdns_encode_rr(pkt_buf, pkt_len, off, rr->e, comp);
974 			off += l;
975 
976 			if (off >= pkt_len) {
977 				DEBUG_PRINTF("packet buffer too small\n");
978 				return -1;
979 			}
980 		}
981 
982 	}
983 
984 	// free name compression list
985 	while (comp) {
986 		struct name_comp *c = comp->next;
987 		free(comp);
988 		comp = c;
989 	}
990 
991 	return off;
992 }
993 
994