1 /*
2 * tinysvcmdns - a tiny MDNS implementation for publishing services
3 * Copyright (C) 2011 Darell Tan
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 * 3. The name of the author may not be used to endorse or promote products
15 * derived from this software without specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
21 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
22 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "mdns.h"
30 #include <stdint.h>
31 #include <stdlib.h>
32 #include <stdio.h>
33 #include <string.h>
34 #include <assert.h>
35
36 #ifdef _WIN32
37 #include <winsock.h>
38 #include <in6addr.h>
39 #else
40 #include <netinet/in.h>
41 #endif
42
43
// Default TTL value applied by FILL_RR_ENTRY to records built with the
// rr_create* helpers (DNS TTLs are expressed in seconds).
#define DEFAULT_TTL 120


// Entry of the name-compression list that mdns_encode_pkt creates and
// mdns_encode_name extends: a name suffix already written into the packet,
// and where it was written.
struct name_comp {
	uint8_t *label; // label: the uncompressed name suffix that was written
	size_t pos; // position in msg: byte offset of that suffix from pkt_buf

	struct name_comp *next; // next entry; NULL terminates the list
};
53
54 // ----- label functions -----
55
// Duplicates an uncompressed name label: a sequence of length-prefixed
// labels terminated by a zero byte.
// Returns a newly-allocated copy that must be free()d, or NULL if memory is
// exhausted.
//
// Deliberately not declared `inline`. In C99, a plain `inline` definition
// with external linkage provides no external symbol, so non-inlined calls
// could fail to link. It also avoids strdup(), which is POSIX rather than
// ISO C.
uint8_t *dup_nlabel(const uint8_t *n) {
	size_t len;
	uint8_t *copy;

	assert(n[0] <= 63); // prevent mis-use

	// copy up to and including the terminating zero byte
	len = strlen((const char *) n) + 1;
	copy = malloc(len);
	if (copy == NULL)
		return NULL;

	memcpy(copy, n, len);
	return copy;
}
61
// Duplicates a single label: one length byte followed by that many bytes.
// The copy gets an extra terminating zero byte.
// Labels longer than 63 bytes are rejected. A 63-byte label is accepted,
// matching create_label().
// Returns a newly-allocated copy that must be free()d, or NULL.
uint8_t *dup_label(const uint8_t *label) {
	size_t len;
	uint8_t *newlabel;

	if (label[0] > 63)
		return NULL;

	len = (size_t) label[0] + 1; // length byte + contents
	newlabel = malloc(len + 1);
	if (newlabel == NULL)
		return NULL;

	// memcpy rather than strncpy: label contents may contain zero bytes
	memcpy(newlabel, label, len);
	newlabel[len] = '\0';
	return newlabel;
}
72
// Concatenates two uncompressed name labels. The result holds the labels of
// n1, then the labels of n2, then a single terminating zero byte.
// Returns a newly-allocated name that must be free()d, or NULL if memory is
// exhausted.
uint8_t *join_nlabel(const uint8_t *n1, const uint8_t *n2) {
	size_t len1, len2;
	uint8_t *s;

	assert(n1[0] <= 63 && n2[0] <= 63); // detect misuse

	len1 = strlen((const char *) n1);
	len2 = strlen((const char *) n2);

	s = malloc(len1 + len2 + 1);
	if (s == NULL)
		return NULL;

	memcpy(s, n1, len1);
	memcpy(s + len1, n2, len2);
	s[len1 + len2] = '\0';
	return s;
}
88
// Returns a human-readable name label in dotted form. For example,
// "\3foo\5local" becomes "foo.local.": every label is followed by a dot, and
// the root name becomes "".
// The buffer is sized from the name itself, so names longer than 255
// characters no longer overflow a fixed buffer.
// free() after use. Returns NULL if memory is exhausted.
char *nlabel_to_str(const uint8_t *name) {
	const uint8_t *p;
	size_t total = 0;
	char *str, *out;

	assert(name != NULL);

	// pass 1: each label contributes its contents plus one '.'
	for (p = name; *p; p += *p + 1)
		total += (size_t) *p + 1;

	str = out = malloc(total + 1);
	if (str == NULL)
		return NULL;

	// pass 2: copy the label contents, appending '.' after each label
	for (p = name; *p; p += *p + 1) {
		memcpy(out, p + 1, *p);
		out += *p;
		*out++ = '.';
	}
	*out = '\0';

	return str;
}
111
// Returns the on-the-wire length of the name field that starts at `off`.
// Does NOT uncompress the field. It stops at the terminating zero byte
// (counted as 1) or at the first compression pointer (counted as 2), so the
// result can be as small as 2 bytes, or 1 for the root.
// Returns 0 if the field runs past the end of the packet or uses a reserved
// label type (0x40/0x80 prefixes). A non-zero result therefore never
// reaches beyond pkt_len.
static size_t label_len(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
	size_t pos = off;

	while (pos < pkt_len) {
		uint8_t b = pkt_buf[pos];

		if (b == 0)
			return pos - off + 1;

		if ((b & 0xC0) == 0xC0) {
			// compression pointer: its second byte must be inside the packet
			if (pkt_len - pos < 2)
				return 0;
			return pos - off + 2;
		}

		if ((b & 0xC0) != 0)
			return 0; // reserved label types

		// skip the length byte and the label contents
		pos += (size_t) b + 1;
	}

	// ran off the end of the packet without finding a terminator
	return 0;
}
133
// Creates a single label from a C string: one length byte, the text, and a
// terminating zero byte.
// Returns NULL if txt is longer than 63 bytes (the DNS label limit) or if
// memory is exhausted. free() after use.
uint8_t *create_label(const char *txt) {
	size_t len;
	uint8_t *s;

	assert(txt != NULL);
	len = strlen(txt);
	if (len > 63)
		return NULL;

	s = malloc(len + 2);
	if (s == NULL)
		return NULL;

	s[0] = (uint8_t) len;
	memcpy(s + 1, txt, len);
	s[len + 1] = '\0';

	return s;
}
152
// Creates an uncompressed name label from a dotted DNS name. For example,
// "apple.b.com" becomes "\5apple\1b\3com\0".
// A single trailing dot is accepted, so "b.com." is the same as "b.com".
// Both "" and "." yield the root name "\0".
// Returns NULL if any label is empty (e.g. "a..b") or longer than 63 bytes,
// or if memory is exhausted. free() after use.
uint8_t *create_nlabel(const char *name) {
	size_t len;
	uint8_t *label, *out;

	assert(name != NULL);

	len = strlen(name);
	if (len > 0 && name[len - 1] == '.')
		len--; // fully-qualified form: drop the trailing root dot

	// Every dot becomes a length byte, plus one leading length byte and the
	// terminator. For example, "a.b" (len 3) becomes "\1a\1b\0" (len + 2).
	label = malloc(len + 2);
	if (label == NULL)
		return NULL;

	out = label;
	if (len > 0) {
		const char *p = name;
		const char *end = name + len;

		for (;;) {
			const char *dot = memchr(p, '.', (size_t) (end - p));
			const char *seg_end = dot ? dot : end;
			size_t seg_len = (size_t) (seg_end - p);

			if (seg_len == 0 || seg_len > 63) {
				free(label);
				return NULL;
			}

			*out++ = (uint8_t) seg_len;
			memcpy(out, p, seg_len);
			out += seg_len;

			if (dot == NULL)
				break;
			p = dot + 1;
		}
	}
	*out = '\0';

	return label;
}
187
// Copies the length-prefixed string at `off` in the packet (a length byte
// followed by that many bytes) into a newly-allocated buffer with an extra
// terminating zero byte.
// Used for TXT character-strings, whose length byte allows up to 255 bytes,
// so no 63-byte label cap is applied here.
// Returns NULL if the string does not fit inside the packet or memory is
// exhausted. free() after use.
static uint8_t *copy_label(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
	size_t len;
	uint8_t *label;

	// off == pkt_len would already read one byte past the buffer
	if (off >= pkt_len)
		return NULL;

	len = (size_t) pkt_buf[off] + 1; // length byte + contents
	if (len > pkt_len - off) {
		DEBUG_PRINTF("label length exceeds packet buffer\n");
		return NULL;
	}

	label = malloc(len + 1);
	if (label == NULL)
		return NULL;

	memcpy(label, pkt_buf + off, len);
	label[len] = '\0';
	return label;
}
204
// Walks the possibly-compressed name at `off`, following compression
// pointers (RFC 1035 4.1.4). When `out` is non-NULL, the uncompressed labels
// (length byte + contents, no terminator) are copied into it.
// Returns the number of label bytes, or SIZE_MAX if the name is malformed.
// A name is malformed when it:
//   - runs past the packet,
//   - uses a reserved label type,
//   - exceeds the 255-byte DNS name limit, or
//   - follows too many pointers (e.g. a pointer loop).
static size_t nlabel_walk(const uint8_t *pkt_buf, size_t pkt_len, size_t off,
		uint8_t *out) {
	size_t pos = off;
	size_t len = 0;
	unsigned jumps = 0;

	for (;;) {
		uint8_t b;

		if (pos >= pkt_len)
			return SIZE_MAX;

		b = pkt_buf[pos];
		if (b == 0)
			return len;

		if ((b & 0xC0) == 0xC0) {
			// 14-bit target offset split across this byte and the next one;
			// the jump limit stops pointer loops
			if (pkt_len - pos < 2 || ++jumps > 127)
				return SIZE_MAX;
			pos = ((size_t) (b & 0x3F) << 8) | pkt_buf[pos + 1];
			continue;
		}

		if ((b & 0xC0) != 0)
			return SIZE_MAX; // reserved label types (0x40 / 0x80)

		// The label must lie inside the packet. The name plus its root
		// terminator must also stay within 255 bytes.
		if ((size_t) b + 1 > pkt_len - pos || len + b + 1 > 254)
			return SIZE_MAX;

		if (out != NULL)
			memcpy(out + len, pkt_buf + pos, (size_t) b + 1);
		len += (size_t) b + 1;
		pos += (size_t) b + 1;
	}
}

// Uncompresses the name at `off` into a newly-allocated label sequence with
// a terminating zero byte.
// Returns NULL if off is outside the packet, the name is malformed (see
// nlabel_walk) or memory is exhausted. free() after use.
static uint8_t *uncompress_nlabel(uint8_t *pkt_buf, size_t pkt_len, size_t off) {
	size_t len;
	uint8_t *str;

	if (off >= pkt_len)
		return NULL;

	// pass 1 measures the name, pass 2 copies it; both share one walker
	len = nlabel_walk(pkt_buf, pkt_len, off, NULL);
	if (len == SIZE_MAX)
		return NULL;

	str = malloc(len + 1);
	if (str == NULL)
		return NULL;

	nlabel_walk(pkt_buf, pkt_len, off, str);
	str[len] = '\0';

	return str;
}
252
253 // ----- RR list & group functions -----
254
rr_get_type_name(enum rr_type type)255 const char *rr_get_type_name(enum rr_type type) {
256 switch (type) {
257 case RR_A: return "A";
258 case RR_PTR: return "PTR";
259 case RR_TXT: return "TXT";
260 case RR_AAAA: return "AAAA";
261 case RR_SRV: return "SRV";
262 case RR_NSEC: return "NSEC";
263 case RR_ANY: return "ANY";
264 }
265 return NULL;
266 }
267
rr_entry_destroy(struct rr_entry * rr)268 void rr_entry_destroy(struct rr_entry *rr) {
269 struct rr_data_txt *txt_rec;
270 assert(rr);
271
272 // check rr_type and free data elements
273 switch (rr->type) {
274 case RR_PTR:
275 if (rr->data.PTR.name)
276 free(rr->data.PTR.name);
277 // don't free entry
278 break;
279
280 case RR_TXT:
281 txt_rec = &rr->data.TXT;
282 while (txt_rec) {
283 struct rr_data_txt *next = txt_rec->next;
284 if (txt_rec->txt)
285 free(txt_rec->txt);
286
287 // only free() if it wasn't part of the struct
288 if (txt_rec != &rr->data.TXT)
289 free(txt_rec);
290
291 txt_rec = next;
292 }
293 break;
294
295 case RR_SRV:
296 if (rr->data.SRV.target)
297 free(rr->data.SRV.target);
298 break;
299
300 default:
301 // nothing to free
302 break;
303 }
304
305 free(rr->name);
306 free(rr);
307 }
308
309 // destroys an RR list (and optionally, items)
rr_list_destroy(struct rr_list * rr,char destroy_items)310 void rr_list_destroy(struct rr_list *rr, char destroy_items) {
311 struct rr_list *rr_next;
312
313 for (; rr; rr = rr_next) {
314 rr_next = rr->next;
315 if (destroy_items)
316 rr_entry_destroy(rr->e);
317 free(rr);
318 }
319 }
320
rr_list_count(struct rr_list * rr)321 int rr_list_count(struct rr_list *rr) {
322 int i = 0;
323 for (; rr; i++, rr = rr->next);
324 return i;
325 }
326
rr_list_remove(struct rr_list ** rr_head,struct rr_entry * rr)327 struct rr_entry *rr_list_remove(struct rr_list **rr_head, struct rr_entry *rr) {
328 struct rr_list *le = *rr_head, *pe = NULL;
329 for (; le; le = le->next) {
330 if (le->e == rr) {
331 if (pe == NULL) {
332 *rr_head = le->next;
333 free(le);
334 return rr;
335 } else {
336 pe->next = le->next;
337 free(le);
338 return rr;
339 }
340 }
341 pe = le;
342 }
343 return NULL;
344 }
345
346 // appends an rr_entry to an RR list
347 // if the RR is already in the list, it will not be added
348 // RRs are compared by memory location - not its contents
349 // return value of 0 means item not added
rr_list_append(struct rr_list ** rr_head,struct rr_entry * rr)350 int rr_list_append(struct rr_list **rr_head, struct rr_entry *rr) {
351 struct rr_list *node = malloc(sizeof(struct rr_list));
352 node->e = rr;
353 node->next = NULL;
354
355 if (*rr_head == NULL) {
356 *rr_head = node;
357 } else {
358 struct rr_list *e = *rr_head, *taile;
359 for (; e; e = e->next) {
360 // already in list - don't add
361 if (e->e == rr) {
362 free(node);
363 return 0;
364 }
365 if (e->next == NULL)
366 taile = e;
367 }
368 taile->next = node;
369 }
370 return 1;
371 }
372
// Fills the common fields of a freshly-allocated rr_entry:
//   - name and type from the arguments,
//   - DEFAULT_TTL,
//   - cache-flush bit set,
//   - class 1 (IN).
// Wrapped in do { } while (0) so the macro acts as a single statement, which
// makes it safe in an unbraced if/else. Arguments are parenthesized.
#define FILL_RR_ENTRY(rr, _name, _type) \
	do { \
		(rr)->name = (_name); \
		(rr)->type = (_type); \
		(rr)->ttl = DEFAULT_TTL; \
		(rr)->cache_flush = 1; \
		(rr)->rr_class = 1; \
	} while (0)
379
rr_create_a(uint8_t * name,uint32_t addr)380 struct rr_entry *rr_create_a(uint8_t *name, uint32_t addr) {
381 DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
382 FILL_RR_ENTRY(rr, name, RR_A);
383 rr->data.A.addr = addr;
384 return rr;
385 }
386
rr_create_aaaa(uint8_t * name,struct in6_addr * addr)387 struct rr_entry *rr_create_aaaa(uint8_t *name, struct in6_addr *addr) {
388 DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
389 FILL_RR_ENTRY(rr, name, RR_AAAA);
390 rr->data.AAAA.addr = addr;
391 return rr;
392 }
393
rr_create_srv(uint8_t * name,uint16_t port,uint8_t * target)394 struct rr_entry *rr_create_srv(uint8_t *name, uint16_t port, uint8_t *target) {
395 DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
396 FILL_RR_ENTRY(rr, name, RR_SRV);
397 rr->data.SRV.port = port;
398 rr->data.SRV.target = target;
399 return rr;
400 }
401
rr_create_ptr(uint8_t * name,struct rr_entry * d_rr)402 struct rr_entry *rr_create_ptr(uint8_t *name, struct rr_entry *d_rr) {
403 DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
404 FILL_RR_ENTRY(rr, name, RR_PTR);
405 rr->cache_flush = 0; // PTRs shouldn't have their cache flush bit set
406 rr->data.PTR.entry = d_rr;
407 return rr;
408 }
409
rr_create(uint8_t * name,enum rr_type type)410 struct rr_entry *rr_create(uint8_t *name, enum rr_type type) {
411 DECL_MALLOC_ZERO_STRUCT(rr, rr_entry);
412 FILL_RR_ENTRY(rr, name, type);
413 return rr;
414 }
415
rr_set_nsec(struct rr_entry * rr_nsec,enum rr_type type)416 void rr_set_nsec(struct rr_entry *rr_nsec, enum rr_type type) {
417 assert(rr_nsec->type = RR_NSEC);
418 assert((type / 8) < sizeof(rr_nsec->data.NSEC.bitmap));
419
420 rr_nsec->data.NSEC.bitmap[ type / 8 ] = 1 << (7 - (type % 8));
421 }
422
rr_add_txt(struct rr_entry * rr_txt,const char * txt)423 void rr_add_txt(struct rr_entry *rr_txt, const char *txt) {
424 struct rr_data_txt *txt_rec;
425 assert(rr_txt->type == RR_TXT);
426
427 txt_rec = &rr_txt->data.TXT;
428
429 // is current data filled?
430 if (txt_rec->txt == NULL) {
431 txt_rec->txt = create_label(txt);
432 return;
433 }
434
435 // find the last node
436 for (; txt_rec->next; txt_rec = txt_rec->next);
437
438 // create a new empty node
439 txt_rec->next = malloc(sizeof(struct rr_data_txt));
440
441 txt_rec = txt_rec->next;
442 txt_rec->txt = create_label(txt);
443 txt_rec->next = NULL;
444 }
445
446 // adds a record to an rr_group
rr_group_add(struct rr_group ** group,struct rr_entry * rr)447 void rr_group_add(struct rr_group **group, struct rr_entry *rr) {
448 struct rr_group *g;
449
450 assert(rr != NULL);
451
452 if (*group) {
453 g = rr_group_find(*group, rr->name);
454 if (g) {
455 rr_list_append(&g->rr, rr);
456 return;
457 }
458 }
459
460 MALLOC_ZERO_STRUCT(g, rr_group);
461 g->name = dup_nlabel(rr->name);
462 rr_list_append(&g->rr, rr);
463
464 // prepend to list
465 g->next = *group;
466 *group = g;
467 }
468
469 // finds a rr_group matching the given name
rr_group_find(struct rr_group * g,uint8_t * name)470 struct rr_group *rr_group_find(struct rr_group* g, uint8_t *name) {
471 for (; g; g = g->next) {
472 if (cmp_nlabel(g->name, name) == 0)
473 return g;
474 }
475 return NULL;
476 }
477
rr_entry_find(struct rr_list * rr_list,uint8_t * name,uint16_t type)478 struct rr_entry *rr_entry_find(struct rr_list *rr_list, uint8_t *name, uint16_t type) {
479 struct rr_list *rr = rr_list;
480 for (; rr; rr = rr->next) {
481 if (rr->e->type == type && cmp_nlabel(rr->e->name, name) == 0)
482 return rr->e;
483 }
484 return NULL;
485 }
486
487 // looks for a matching entry in rr_list
488 // if entry is a PTR, we need to check if the PTR target also matches
rr_entry_match(struct rr_list * rr_list,struct rr_entry * entry)489 struct rr_entry *rr_entry_match(struct rr_list *rr_list, struct rr_entry *entry) {
490 struct rr_list *rr = rr_list;
491 for (; rr; rr = rr->next) {
492 if (rr->e->type == entry->type && cmp_nlabel(rr->e->name, entry->name) == 0) {
493 if (entry->type != RR_PTR) {
494 return rr->e;
495 } else if (cmp_nlabel(MDNS_RR_GET_PTR_NAME(entry), MDNS_RR_GET_PTR_NAME(rr->e)) == 0) {
496 // if it's a PTR, we need to make sure PTR target also matches
497 return rr->e;
498 }
499 }
500 }
501 return NULL;
502 }
503
rr_group_destroy(struct rr_group * group)504 void rr_group_destroy(struct rr_group *group) {
505 struct rr_group *g = group;
506
507 while (g) {
508 struct rr_group *nextg = g->next;
509 free(g->name);
510 rr_list_destroy(g->rr, 1);
511 free(g);
512 g = nextg;
513 }
514 }
515
// Writes v at ptr as two bytes in big-endian (network) order.
// Returns the position just past the bytes written.
uint8_t *mdns_write_u16(uint8_t *ptr, const uint16_t v) {
	ptr[0] = (uint8_t) ((v >> 8) & 0xFF);
	ptr[1] = (uint8_t) (v & 0xFF);
	return ptr + 2;
}
521
// Writes v at ptr as four bytes in big-endian (network) order.
// Returns the position just past the bytes written.
uint8_t *mdns_write_u32(uint8_t *ptr, const uint32_t v) {
	int shift;

	for (shift = 24; shift >= 0; shift -= 8)
		*ptr++ = (uint8_t) ((v >> shift) & 0xFF);

	return ptr;
}
529
// Reads a big-endian (network order) 16-bit value from ptr.
uint16_t mdns_read_u16(const uint8_t *ptr) {
	uint16_t hi = ptr[0];
	uint16_t lo = ptr[1];

	return (uint16_t) ((hi << 8) | lo);
}
534
// Reads a big-endian (network order) 32-bit value from ptr.
// Each byte is widened to uint32_t before shifting. Otherwise it is promoted
// to (signed) int, and shifting a byte >= 0x80 left by 24 overflows int,
// which is undefined behavior.
uint32_t mdns_read_u32(const uint8_t *ptr) {
	return ((uint32_t) ptr[0] << 24) |
	       ((uint32_t) ptr[1] << 16) |
	       ((uint32_t) ptr[2] << 8) |
	       ((uint32_t) ptr[3] << 0);
}
541
542 // initialize the packet for reply
543 // clears the packet of list structures but not its list items
mdns_init_reply(struct mdns_pkt * pkt,uint16_t id)544 void mdns_init_reply(struct mdns_pkt *pkt, uint16_t id) {
545 // copy transaction ID
546 pkt->id = id;
547
548 // response flags
549 pkt->flags = MDNS_FLAG_RESP | MDNS_FLAG_AA;
550
551 rr_list_destroy(pkt->rr_qn, 0);
552 rr_list_destroy(pkt->rr_ans, 0);
553 rr_list_destroy(pkt->rr_auth, 0);
554 rr_list_destroy(pkt->rr_add, 0);
555
556 pkt->rr_qn = NULL;
557 pkt->rr_ans = NULL;
558 pkt->rr_auth = NULL;
559 pkt->rr_add = NULL;
560
561 pkt->num_qn = 0;
562 pkt->num_ans_rr = 0;
563 pkt->num_auth_rr = 0;
564 pkt->num_add_rr = 0;
565 }
566
567 // destroys an mdns_pkt struct, including its contents
mdns_pkt_destroy(struct mdns_pkt * p)568 void mdns_pkt_destroy(struct mdns_pkt *p) {
569 rr_list_destroy(p->rr_qn, 1);
570 rr_list_destroy(p->rr_ans, 1);
571 rr_list_destroy(p->rr_auth, 1);
572 rr_list_destroy(p->rr_add, 1);
573
574 free(p);
575 }
576
577
578 // parse the MDNS questions section
579 // stores the parsed data in the given mdns_pkt struct
mdns_parse_qn(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct mdns_pkt * pkt)580 static size_t mdns_parse_qn(uint8_t *pkt_buf, size_t pkt_len, size_t off,
581 struct mdns_pkt *pkt) {
582 const uint8_t *p = pkt_buf + off;
583 struct rr_entry *rr;
584 uint8_t *name;
585
586 assert(pkt != NULL);
587
588 rr = malloc(sizeof(struct rr_entry));
589 memset(rr, 0, sizeof(struct rr_entry));
590
591 name = uncompress_nlabel(pkt_buf, pkt_len, off);
592 p += label_len(pkt_buf, pkt_len, off);
593 rr->name = name;
594
595 rr->type = mdns_read_u16(p);
596 p += sizeof(uint16_t);
597
598 rr->unicast_query = (*p & 0x80) == 0x80;
599 rr->rr_class = mdns_read_u16(p) & ~0x80;
600 p += sizeof(uint16_t);
601
602 rr_list_append(&pkt->rr_qn, rr);
603
604 return p - (pkt_buf + off);
605 }
606
607 // parse the MDNS RR section
608 // stores the parsed data in the given mdns_pkt struct
mdns_parse_rr(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct mdns_pkt * pkt)609 static size_t mdns_parse_rr(uint8_t *pkt_buf, size_t pkt_len, size_t off,
610 struct mdns_pkt *pkt) {
611 const uint8_t *p = pkt_buf + off;
612 const uint8_t *e = pkt_buf + pkt_len;
613 struct rr_entry *rr;
614 uint8_t *name;
615 size_t rr_data_len = 0;
616 struct rr_data_txt *txt_rec;
617 int parse_error = 0;
618
619 assert(pkt != NULL);
620
621 if (off > pkt_len)
622 return 0;
623
624 rr = malloc(sizeof(struct rr_entry));
625 memset(rr, 0, sizeof(struct rr_entry));
626
627 name = uncompress_nlabel(pkt_buf, pkt_len, off);
628 p += label_len(pkt_buf, pkt_len, off);
629 rr->name = name;
630
631 rr->type = mdns_read_u16(p);
632 p += sizeof(uint16_t);
633
634 rr->cache_flush = (*p & 0x80) == 0x80;
635 rr->rr_class = mdns_read_u16(p) & ~0x80;
636 p += sizeof(uint16_t);
637
638 rr->ttl = mdns_read_u32(p);
639 p += sizeof(uint32_t);
640
641 // RR data
642 rr_data_len = mdns_read_u16(p);
643 p += sizeof(uint16_t);
644
645 if (p + rr_data_len > e) {
646 DEBUG_PRINTF("rr_data_len goes beyond packet buffer: %lu > %lu\n", rr_data_len, e - p);
647 rr_entry_destroy(rr);
648 return 0;
649 }
650
651 e = p + rr_data_len;
652
653 // see if we can parse the RR data
654 switch (rr->type) {
655 case RR_A:
656 if (rr_data_len < sizeof(uint32_t)) {
657 DEBUG_PRINTF("invalid rr_data_len=%lu for A record\n", rr_data_len);
658 parse_error = 1;
659 break;
660 }
661 rr->data.A.addr = ntohl(mdns_read_u32(p)); /* addr already in net order */
662 p += sizeof(uint32_t);
663 break;
664
665 case RR_AAAA:
666 if (rr_data_len < sizeof(struct in6_addr)) {
667 DEBUG_PRINTF("invalid rr_data_len=%lu for AAAA record\n", rr_data_len);
668 parse_error = 1;
669 break;
670 }
671 rr->data.AAAA.addr = malloc(sizeof(struct in6_addr));
672 for (int i = 0; i < sizeof(struct in6_addr); i++)
673 rr->data.AAAA.addr->s6_addr[i] = p[i];
674 p += sizeof(struct in6_addr);
675 break;
676
677 case RR_PTR:
678 rr->data.PTR.name = uncompress_nlabel(pkt_buf, pkt_len, p - pkt_buf);
679 if (rr->data.PTR.name == NULL) {
680 DEBUG_PRINTF("unable to parse/uncompress label for PTR name\n");
681 parse_error = 1;
682 break;
683 }
684 p += rr_data_len;
685 break;
686
687 case RR_TXT:
688 txt_rec = &rr->data.TXT;
689
690 // not supposed to happen, but we should handle it
691 if (rr_data_len == 0) {
692 DEBUG_PRINTF("WARN: rr_data_len for TXT is 0\n");
693 txt_rec->txt = create_label("");
694 break;
695 }
696
697 while (1) {
698 txt_rec->txt = copy_label(pkt_buf, pkt_len, p - pkt_buf);
699 if (txt_rec->txt == NULL) {
700 DEBUG_PRINTF("unable to copy label for TXT record\n");
701 parse_error = 1;
702 break;
703 }
704 p += txt_rec->txt[0] + 1;
705
706 if (p >= e)
707 break;
708
709 // allocate another record
710 txt_rec->next = malloc(sizeof(struct rr_data_txt));
711 txt_rec = txt_rec->next;
712 txt_rec->next = NULL;
713 }
714 break;
715
716 default:
717 // skip to end of RR data
718 p = e;
719 }
720
721 // if there was a parse error, destroy partial rr_entry
722 if (parse_error) {
723 rr_entry_destroy(rr);
724 return 0;
725 }
726
727 rr_list_append(&pkt->rr_ans, rr);
728
729 return p - (pkt_buf + off);
730 }
731
732 // parse a MDNS packet into an mdns_pkt struct
mdns_parse_pkt(uint8_t * pkt_buf,size_t pkt_len)733 struct mdns_pkt *mdns_parse_pkt(uint8_t *pkt_buf, size_t pkt_len) {
734 uint8_t *p = pkt_buf;
735 size_t off;
736 struct mdns_pkt *pkt;
737 int i;
738
739 if (pkt_len < 12)
740 return NULL;
741
742 MALLOC_ZERO_STRUCT(pkt, mdns_pkt);
743
744 // parse header
745 pkt->id = mdns_read_u16(p); p += sizeof(uint16_t);
746 pkt->flags = mdns_read_u16(p); p += sizeof(uint16_t);
747 pkt->num_qn = mdns_read_u16(p); p += sizeof(uint16_t);
748 pkt->num_ans_rr = mdns_read_u16(p); p += sizeof(uint16_t);
749 pkt->num_auth_rr = mdns_read_u16(p); p += sizeof(uint16_t);
750 pkt->num_add_rr = mdns_read_u16(p); p += sizeof(uint16_t);
751
752 off = p - pkt_buf;
753
754 // parse questions
755 for (i = 0; i < pkt->num_qn; i++) {
756 size_t l = mdns_parse_qn(pkt_buf, pkt_len, off, pkt);
757 if (! l) {
758 DEBUG_PRINTF("error parsing question #%d\n", i);
759 mdns_pkt_destroy(pkt);
760 return NULL;
761 }
762
763 off += l;
764 }
765
766 // parse answer RRs
767 for (i = 0; i < pkt->num_ans_rr; i++) {
768 size_t l = mdns_parse_rr(pkt_buf, pkt_len, off, pkt);
769 if (! l) {
770 DEBUG_PRINTF("error parsing answer #%d\n", i);
771 mdns_pkt_destroy(pkt);
772 return NULL;
773 }
774
775 off += l;
776 }
777
778 // TODO: parse the authority and additional RR sections
779
780 return pkt;
781 }
782
783 // encodes a name (label) into a packet using the name compression scheme
784 // encoded names will be added to the compression list for subsequent use
mdns_encode_name(uint8_t * pkt_buf,size_t pkt_len,size_t off,const uint8_t * name,struct name_comp * comp)785 static size_t mdns_encode_name(uint8_t *pkt_buf, size_t pkt_len, size_t off,
786 const uint8_t *name, struct name_comp *comp) {
787 struct name_comp *c, *c_tail = NULL;
788 uint8_t *p = pkt_buf + off;
789 size_t len = 0;
790
791 if (name) {
792 while (*name) {
793 // find match for compression
794 for (c = comp; c; c = c->next) {
795 if (cmp_nlabel(name, c->label) == 0) {
796 mdns_write_u16(p, 0xC000 | (c->pos & ~0xC000));
797 return len + sizeof(uint16_t);
798 }
799
800 if (c->next == NULL)
801 c_tail = c;
802 }
803
804 // copy this segment
805 int segment_len = *name + 1;
806 strncpy((char *) p, (char *) name, segment_len);
807
808 // cache the name for subsequent compression
809 DECL_MALLOC_ZERO_STRUCT(new_c, name_comp);
810
811 new_c->label = (uint8_t *) name;
812 new_c->pos = p - pkt_buf;
813 c_tail->next = new_c;
814
815 // advance to next name segment
816 p += segment_len;
817 len += segment_len;
818 name += segment_len;
819 }
820 }
821
822 *p = '\0'; // root "label"
823 len += 1;
824
825 return len;
826 }
827
828 // encodes an RR entry at the given offset
829 // returns the size of the entire RR entry
mdns_encode_rr(uint8_t * pkt_buf,size_t pkt_len,size_t off,struct rr_entry * rr,struct name_comp * comp)830 static size_t mdns_encode_rr(uint8_t *pkt_buf, size_t pkt_len, size_t off,
831 struct rr_entry *rr, struct name_comp *comp) {
832 uint8_t *p = pkt_buf + off, *p_data;
833 size_t l;
834 struct rr_data_txt *txt_rec;
835 uint8_t *label;
836 int i;
837
838 assert(off < pkt_len);
839
840 // name
841 l = mdns_encode_name(pkt_buf, pkt_len, off, rr->name, comp);
842 assert(l != 0);
843 p += l;
844
845 // type
846 p = mdns_write_u16(p, rr->type);
847
848 // class & cache flush
849 p = mdns_write_u16(p, (rr->rr_class & ~0x8000) | (rr->cache_flush << 15));
850
851 // TTL
852 p = mdns_write_u32(p, rr->ttl);
853
854 // data length (filled in later)
855 p += sizeof(uint16_t);
856
857 // start of data marker
858 p_data = p;
859
860 switch (rr->type) {
861 case RR_A:
862 /* htonl() needed coz addr already in net order */
863 p = mdns_write_u32(p, htonl(rr->data.A.addr));
864 break;
865
866 case RR_AAAA:
867 for (i = 0; i < sizeof(struct in6_addr); i++)
868 *p++ = rr->data.AAAA.addr->s6_addr[i];
869 break;
870
871 case RR_PTR:
872 label = rr->data.PTR.name ?
873 rr->data.PTR.name :
874 rr->data.PTR.entry->name;
875 p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf, label, comp);
876 break;
877
878 case RR_TXT:
879 txt_rec = &rr->data.TXT;
880 for (; txt_rec; txt_rec = txt_rec->next) {
881 int len = txt_rec->txt[0] + 1;
882 strncpy((char *) p, (char *) txt_rec->txt, len);
883 p += len;
884 }
885 break;
886
887 case RR_SRV:
888 p = mdns_write_u16(p, rr->data.SRV.priority);
889
890 p = mdns_write_u16(p, rr->data.SRV.weight);
891
892 p = mdns_write_u16(p, rr->data.SRV.port);
893
894 p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf,
895 rr->data.SRV.target, comp);
896 break;
897
898 case RR_NSEC:
899 p += mdns_encode_name(pkt_buf, pkt_len, p - pkt_buf,
900 rr->name, comp);
901
902 *p++ = 0; // bitmap window/block number
903
904 *p++ = sizeof(rr->data.NSEC.bitmap); // bitmap length
905
906 for (i = 0; i < sizeof(rr->data.NSEC.bitmap); i++)
907 *p++ = rr->data.NSEC.bitmap[i];
908
909 break;
910
911 default:
912 DEBUG_PRINTF("unhandled rr type 0x%02x\n", rr->type);
913 }
914
915 // calculate data length based on p
916 l = p - p_data;
917
918 // fill in the length
919 mdns_write_u16(p - l - sizeof(uint16_t), l);
920
921 return p - pkt_buf - off;
922 }
923
924 // encodes a MDNS packet from the given mdns_pkt struct into a buffer
925 // returns the size of the entire MDNS packet
mdns_encode_pkt(struct mdns_pkt * answer,uint8_t * pkt_buf,size_t pkt_len)926 size_t mdns_encode_pkt(struct mdns_pkt *answer, uint8_t *pkt_buf, size_t pkt_len) {
927 struct name_comp *comp;
928 uint8_t *p = pkt_buf;
929 //uint8_t *e = pkt_buf + pkt_len;
930 size_t off;
931 int i;
932
933 assert(answer != NULL);
934 assert(pkt_len >= 12);
935
936 if (p == NULL)
937 return -1;
938
939 // this is an Answer - number of qns should be zero
940 assert(answer->num_qn == 0);
941
942 p = mdns_write_u16(p, answer->id);
943 p = mdns_write_u16(p, answer->flags);
944 p = mdns_write_u16(p, answer->num_qn);
945 p = mdns_write_u16(p, answer->num_ans_rr);
946 p = mdns_write_u16(p, answer->num_auth_rr);
947 p = mdns_write_u16(p, answer->num_add_rr);
948
949 off = p - pkt_buf;
950
951 // allocate list for name compression
952 comp = malloc(sizeof(struct name_comp));
953 if (comp == NULL)
954 return -1;
955 memset(comp, 0, sizeof(struct name_comp));
956
957 // dummy entry
958 comp->label = (uint8_t *) "";
959 comp->pos = 0;
960
961 // skip encoding of qn
962
963 struct rr_list *rr_set[] = {
964 answer->rr_ans,
965 answer->rr_auth,
966 answer->rr_add
967 };
968
969 // encode answer, authority and additional RRs
970 for (i = 0; i < sizeof(rr_set) / sizeof(rr_set[0]); i++) {
971 struct rr_list *rr = rr_set[i];
972 for (; rr; rr = rr->next) {
973 size_t l = mdns_encode_rr(pkt_buf, pkt_len, off, rr->e, comp);
974 off += l;
975
976 if (off >= pkt_len) {
977 DEBUG_PRINTF("packet buffer too small\n");
978 return -1;
979 }
980 }
981
982 }
983
984 // free name compression list
985 while (comp) {
986 struct name_comp *c = comp->next;
987 free(comp);
988 comp = c;
989 }
990
991 return off;
992 }
993
994