1 /*
2  * Copyright (c) 2020 Apple Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "mdns_message.h"
18 
19 #include "mdns_objects.h"
20 
21 #include "DNSMessage.h"
22 #include <CoreUtils/CoreUtils.h>
23 
24 //======================================================================================================================
25 // MARK: - Message Kind Definition
26 
27 struct mdns_message_s {
28 	struct mdns_object_s	base;				// Object base.
29 	dispatch_data_t			msg_data;			// Underlying object for message data.
30 	const uint8_t *			msg_ptr;			// Pointer to first byte of message data.
31 	size_t					msg_len;			// Length of message.
32 	bool					print_body_only;	// True if only the message body should be printed in description.
33 };
34 
35 MDNS_OBJECT_SUBKIND_DEFINE(message);
36 
37 typedef const struct mdns_message_kind_s *	mdns_message_kind_t;
38 struct mdns_message_kind_s {
39 	struct mdns_kind_s	base;
40 	const char *		name;
41 };
42 
/*
 * Defines the boilerplate for a message subkind called NAME:
 *   - a forward declaration of _mdns_<NAME>_message_finalize(), which the user of the macro must define,
 *   - the static kind descriptor _mdns_<NAME>_message_kind, whose superkind is the base message kind,
 *   - the allocator _mdns_<NAME>_message_alloc(), which returns NULL if the object allocation fails and
 *     otherwise stamps the new object with the subkind's kind descriptor.
 */
#define MDNS_MESSAGE_SUBKIND_DEFINE(NAME)															\
	static void																						\
	_mdns_ ## NAME ## _message_finalize(mdns_ ## NAME ## _message_t message);						\
																									\
	static const struct mdns_message_kind_s _mdns_ ## NAME ## _message_kind = {						\
		.base = {																					\
			.superkind	= &_mdns_message_kind,														\
			.name		= "mdns_" # NAME "_message",												\
			.finalize	= _mdns_ ## NAME ## _message_finalize										\
		},																							\
		.name = # NAME "_message"																	\
	};																								\
																									\
	static mdns_ ## NAME ## _message_t																\
	_mdns_ ## NAME ## _message_alloc(void)															\
	{																								\
		const mdns_ ## NAME ## _message_t new_message =												\
			mdns_ ## NAME ## _message_object_alloc(sizeof(struct mdns_ ## NAME ## _message_s));		\
		require_return_value(new_message, NULL);													\
																									\
		((mdns_object_t)new_message)->kind = &_mdns_ ## NAME ## _message_kind.base;					\
		return new_message;																			\
	}																								\
	MDNS_BASE_CHECK(NAME ## _message, message)
68 
69 //======================================================================================================================
70 // MARK: - Query Message Kind Definition
71 
72 struct mdns_query_message_s {
73 	struct mdns_message_s	base;				// Message object base.
74 	uint8_t *				qname;				// Question's QNAME.
75 	uint16_t				qtype;				// Question's QTYPE.
76 	uint16_t				qclass;				// Question's QCLASS.
77 	uint16_t				msg_id;				// Message ID.
78 	bool					set_ad_bit;			// True if the AD (authentic data) bit is to be set.
79 	bool					set_cd_bit;			// True if the CD (checking disabled) bit is to be set.
80 	bool					set_do_bit;			// True if the DO (DNSSEC OK) bit is to be set in OPT record.
81 	bool					use_edns0_padding;	// True if the query uses EDNS0 padding.
82 	bool					constructed;		// True if the message has been constructed.
83 };
84 
85 MDNS_MESSAGE_SUBKIND_DEFINE(query);
86 
87 //======================================================================================================================
88 // MARK: - Local Prototypes
89 
90 static OSStatus
91 _mdns_message_init(mdns_any_message_t message, dispatch_data_t msg_data, mdns_message_init_options_t options);
92 
93 static OSStatus
94 _mdns_message_set_msg_data(mdns_any_message_t message, dispatch_data_t msg_data);
95 
96 const uint8_t *
97 _mdns_query_message_get_qname_safe(mdns_query_message_t query_message);
98 
99 //======================================================================================================================
// MARK: - Message Public Methods
101 
102 mdns_message_t
mdns_message_create_with_dispatch_data(const dispatch_data_t data,const mdns_message_init_options_t options)103 mdns_message_create_with_dispatch_data(const dispatch_data_t data, const mdns_message_init_options_t options)
104 {
105 	mdns_message_t message = NULL;
106 	mdns_message_t obj = _mdns_message_alloc();
107 	require_quiet(obj, exit);
108 
109 	const OSStatus err = _mdns_message_init(obj, data, options);
110 	require_noerr_quiet(err, exit);
111 
112 	message = obj;
113 	obj = NULL;
114 
115 exit:
116 	mdns_release_null_safe(obj);
117 	return message;
118 }
119 
120 //======================================================================================================================
121 
122 dispatch_data_t
mdns_message_get_dispatch_data(const mdns_message_t me)123 mdns_message_get_dispatch_data(const mdns_message_t me)
124 {
125 	return me->msg_data;
126 }
127 
128 //======================================================================================================================
129 
130 const uint8_t *
mdns_message_get_byte_ptr(const mdns_message_t me)131 mdns_message_get_byte_ptr(const mdns_message_t me)
132 {
133 	return me->msg_ptr;
134 }
135 
136 //======================================================================================================================
137 
138 size_t
mdns_message_get_length(const mdns_message_t me)139 mdns_message_get_length(const mdns_message_t me)
140 {
141 	return me->msg_len;
142 }
143 
144 //======================================================================================================================
145 // MARK: - Message Private Methods
146 
147 static char *
_mdns_message_copy_description(mdns_message_t me,__unused const bool debug,const bool privacy)148 _mdns_message_copy_description(mdns_message_t me, __unused const bool debug, const bool privacy)
149 {
150 	char *description = NULL;
151 	if (me->msg_ptr) {
152 		DNSMessageToStringFlags flags = kDNSMessageToStringFlag_OneLine;
153 		if (me->print_body_only) {
154 			flags |= kDNSMessageToStringFlag_BodyOnly;
155 		}
156 		if (privacy) {
157 			flags |= kDNSMessageToStringFlag_Privacy;
158 		}
159 		DNSMessageToString(me->msg_ptr, me->msg_len, flags, &description);
160 	}
161 	return description;
162 }
163 
164 //======================================================================================================================
165 
166 static void
_mdns_message_finalize(const mdns_message_t me)167 _mdns_message_finalize(const mdns_message_t me)
168 {
169 	me->msg_ptr = NULL;
170 	dispatch_forget(&me->msg_data);
171 }
172 
173 //======================================================================================================================
174 
175 static OSStatus
_mdns_message_init(const mdns_any_message_t any,const dispatch_data_t msg_data,const mdns_message_init_options_t options)176 _mdns_message_init(const mdns_any_message_t any, const dispatch_data_t msg_data,
177 	const mdns_message_init_options_t options)
178 {
179 	const mdns_message_t me = any.message;
180 	if (options & mdns_message_init_option_disable_header_printing) {
181 		me->print_body_only = true;
182 	}
183 	return _mdns_message_set_msg_data(me, msg_data);
184 }
185 
186 //======================================================================================================================
187 
188 static OSStatus
_mdns_message_set_msg_data(const mdns_any_message_t any,const dispatch_data_t msg_data)189 _mdns_message_set_msg_data(const mdns_any_message_t any, const dispatch_data_t msg_data)
190 {
191 	dispatch_data_t	msg_data_new;
192 	const uint8_t *	msg_ptr;
193 	size_t			msg_len;
194 	if (msg_data) {
195 		msg_data_new = dispatch_data_create_map(msg_data, (const void **)&msg_ptr, &msg_len);
196 		require_return_value(msg_data_new, kNoMemoryErr);
197 	} else {
198 		msg_data_new = dispatch_data_empty;
199 		msg_ptr = NULL;
200 		msg_len = 0;
201 	}
202 	const mdns_message_t me = any.message;
203 	dispatch_release_null_safe(me->msg_data);
204 	me->msg_data = msg_data_new;
205 	me->msg_ptr  = msg_ptr;
206 	me->msg_len  = msg_len;
207 	return kNoErr;
208 }
209 
210 //======================================================================================================================
// MARK: - Query Message Public Methods
212 
213 mdns_query_message_t
mdns_query_message_create(const mdns_message_init_options_t options)214 mdns_query_message_create(const mdns_message_init_options_t options)
215 {
216 	mdns_query_message_t message = NULL;
217 	mdns_query_message_t obj = _mdns_query_message_alloc();
218 	require_quiet(obj, exit);
219 
220 	const OSStatus err = _mdns_message_init(obj, NULL, options);
221 	require_noerr_quiet(err, exit);
222 
223 	message = obj;
224 	obj = NULL;
225 
226 exit:
227 	mdns_release_null_safe(obj);
228 	return message;
229 }
230 
231 //======================================================================================================================
232 
233 OSStatus
mdns_query_message_set_qname(const mdns_query_message_t me,const uint8_t * const qname)234 mdns_query_message_set_qname(const mdns_query_message_t me, const uint8_t * const qname)
235 {
236 	require_return_value(!me->constructed, kNoErr);
237 
238 	uint8_t *qname_dup = NULL;
239 	OSStatus err = DomainNameDup(qname, &qname_dup, NULL);
240 	require_noerr_quiet(err, exit);
241 
242 	FreeNullSafe(me->qname);
243 	me->qname = qname_dup;
244 	qname_dup = NULL;
245 	err = kNoErr;
246 
247 exit:
248 	return err;
249 }
250 
251 //======================================================================================================================
252 
253 void
mdns_query_message_set_qtype(const mdns_query_message_t me,const uint16_t qtype)254 mdns_query_message_set_qtype(const mdns_query_message_t me, const uint16_t qtype)
255 {
256 	require_return(!me->constructed);
257 	me->qtype = qtype;
258 }
259 
260 //======================================================================================================================
261 
262 void
mdns_query_message_set_qclass(const mdns_query_message_t me,const uint16_t qclass)263 mdns_query_message_set_qclass(const mdns_query_message_t me, const uint16_t qclass)
264 {
265 	require_return(!me->constructed);
266 	me->qclass = qclass;
267 }
268 
269 //======================================================================================================================
270 
271 void
mdns_query_message_set_message_id(const mdns_query_message_t me,const uint16_t msg_id)272 mdns_query_message_set_message_id(const mdns_query_message_t me, const uint16_t msg_id)
273 {
274 	require_return(!me->constructed);
275 	me->msg_id = msg_id;
276 }
277 
278 //======================================================================================================================
279 
280 void
mdns_query_message_set_ad_bit(const mdns_query_message_t me,const bool set)281 mdns_query_message_set_ad_bit(const mdns_query_message_t me, const bool set)
282 {
283 	require_return(!me->constructed);
284 	me->set_ad_bit = set;
285 }
286 
287 //======================================================================================================================
288 
289 void
mdns_query_message_set_cd_bit(const mdns_query_message_t me,const bool set)290 mdns_query_message_set_cd_bit(const mdns_query_message_t me, const bool set)
291 {
292 	require_return(!me->constructed);
293 	me->set_cd_bit = set;
294 }
295 
296 //======================================================================================================================
297 
298 void
mdns_query_message_set_do_bit(const mdns_query_message_t me,const bool set)299 mdns_query_message_set_do_bit(const mdns_query_message_t me, const bool set)
300 {
301 	require_return(!me->constructed);
302 	me->set_do_bit = set;
303 }
304 
305 //======================================================================================================================
306 
307 void
mdns_query_message_use_edns0_padding(const mdns_query_message_t me,const bool use)308 mdns_query_message_use_edns0_padding(const mdns_query_message_t me, const bool use)
309 {
310 	require_return(!me->constructed);
311 	me->use_edns0_padding = use;
312 }
313 
314 //======================================================================================================================
315 
316 #define MDNS_EDNS0_PADDING_OVERHEAD		 15	// Size of OPT pseudo-RR with OPTION-CODE and OPTION-LENGTH
317 #define MDNS_EDNS0_PADDING_BLOCK_SIZE	128	// <https://tools.ietf.org/html/rfc8467#section-4.1>
318 
319 #define MDNS_QUERY_MESSAGE_BUFFER_SIZE \
320 	RoundUp(kDNSQueryMessageMaxLen + MDNS_EDNS0_PADDING_OVERHEAD, MDNS_EDNS0_PADDING_BLOCK_SIZE)
321 
322 static OSStatus
323 _mdns_query_message_add_edns0_padding(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE], size_t query_len,
324 	bool set_do_bit, size_t *out_len);
325 
326 static OSStatus
327 _mdns_query_message_add_edns0_dnssec_ok(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE], size_t query_len,
328 	size_t *out_len);
329 
330 OSStatus
mdns_query_message_construct(const mdns_query_message_t me)331 mdns_query_message_construct(const mdns_query_message_t me)
332 {
333 	uint16_t flags = kDNSHeaderFlag_RecursionDesired;
334 	if (me->set_ad_bit) {
335 		flags |= kDNSHeaderFlag_AuthenticData;
336 	}
337 	if (me->set_cd_bit) {
338 		flags |= kDNSHeaderFlag_CheckingDisabled;
339 	}
340 	uint8_t	query_buf[MDNS_QUERY_MESSAGE_BUFFER_SIZE];
341 	size_t	query_len;
342 	const uint8_t * const qname = _mdns_query_message_get_qname_safe(me);
343 	OSStatus err = DNSMessageWriteQuery(me->msg_id, flags, qname, me->qtype, me->qclass, query_buf, &query_len);
344 	require_noerr_quiet(err, exit);
345 
346 	if (me->use_edns0_padding) {
347 		err = _mdns_query_message_add_edns0_padding(query_buf, query_len, me->set_do_bit, &query_len);
348 		require_noerr_quiet(err, exit);
349 	} else if (me->set_do_bit) {
350 		err = _mdns_query_message_add_edns0_dnssec_ok(query_buf, query_len, &query_len);
351 		require_noerr_quiet(err, exit);
352 	}
353 	dispatch_data_t query_data = dispatch_data_create(query_buf, query_len, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
354 	require_action_quiet(query_data, exit, err = kNoMemoryErr);
355 
356 	err = _mdns_message_set_msg_data(me, query_data);
357 	dispatch_forget(&query_data);
358 	require_noerr_quiet(err, exit);
359 
360 	me->constructed = true;
361 
362 exit:
363 	return err;
364 }
365 
366 static OSStatus
_mdns_query_message_add_edns0_padding(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE],const size_t query_len,const bool set_do_bit,size_t * const out_len)367 _mdns_query_message_add_edns0_padding(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE], const size_t query_len,
368 	const bool set_do_bit, size_t * const out_len)
369 {
370 	const size_t new_len = RoundUp(query_len + MDNS_EDNS0_PADDING_OVERHEAD, MDNS_EDNS0_PADDING_BLOCK_SIZE);
371 	require_return_value(new_len <= MDNS_QUERY_MESSAGE_BUFFER_SIZE, kSizeErr);
372 
373 	uint8_t * const			end		= &query_buf[query_len];
374 	const uint8_t * const	new_end	= &query_buf[new_len];
375 	memset(end, 0, (size_t)(new_end - end));
376 
377 	check_compile_time_code(MDNS_EDNS0_PADDING_OVERHEAD == sizeof(dns_fixed_fields_opt1));
378 
379 	dns_fixed_fields_opt1 * const	pad_opt		= (dns_fixed_fields_opt1 *)end;
380 	const uint8_t * const			pad_start	= (const uint8_t *)&pad_opt[1];
381 	dns_fixed_fields_opt1_set_type(pad_opt, kDNSRecordType_OPT);
382 	dns_fixed_fields_opt1_set_udp_payload_size(pad_opt, 512);
383 	dns_fixed_fields_opt1_set_rdlen(pad_opt, (uint16_t)(new_end - pad_opt->option_code));
384 	dns_fixed_fields_opt1_set_option_code(pad_opt, kDNSEDNS0OptionCode_Padding);
385 	dns_fixed_fields_opt1_set_option_length(pad_opt, (uint16_t)(new_end - pad_start));
386 	if (set_do_bit) {
387 		dns_fixed_fields_opt1_set_extended_flags(pad_opt, kDNSExtendedFlag_DNSSECOK);
388 	}
389 	DNSHeaderSetAdditionalCount((DNSHeader *)&query_buf[0], 1);
390 	if (out_len) {
391 		*out_len = new_len;
392 	}
393 	return kNoErr;
394 }
395 
396 static OSStatus
_mdns_query_message_add_edns0_dnssec_ok(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE],const size_t query_len,size_t * const out_len)397 _mdns_query_message_add_edns0_dnssec_ok(uint8_t query_buf[static MDNS_QUERY_MESSAGE_BUFFER_SIZE],
398 	const size_t query_len, size_t * const out_len)
399 {
400 	dns_fixed_fields_opt *opt;
401 	const size_t new_len = query_len + sizeof(*opt);
402 	require_return_value(new_len <= MDNS_QUERY_MESSAGE_BUFFER_SIZE, kSizeErr);
403 
404 	opt = (dns_fixed_fields_opt *)&query_buf[query_len];
405 	memset(opt, 0, sizeof(*opt));
406 	dns_fixed_fields_opt_set_type(opt, kDNSRecordType_OPT);
407 	dns_fixed_fields_opt_set_udp_payload_size(opt, 512);
408 	dns_fixed_fields_opt_set_extended_flags(opt, kDNSExtendedFlag_DNSSECOK);
409 
410 	DNSHeaderSetAdditionalCount((DNSHeader *)&query_buf[0], 1);
411 	if (out_len) {
412 		*out_len = new_len;
413 	}
414 	return kNoErr;
415 }
416 
417 //======================================================================================================================
418 
419 const uint8_t *
mdns_query_message_get_qname(const mdns_query_message_t me)420 mdns_query_message_get_qname(const mdns_query_message_t me)
421 {
422 	return _mdns_query_message_get_qname_safe(me);
423 }
424 
425 //======================================================================================================================
426 
427 uint16_t
mdns_query_message_get_qtype(const mdns_query_message_t me)428 mdns_query_message_get_qtype(const mdns_query_message_t me)
429 {
430 	return me->qtype;
431 }
432 
433 //======================================================================================================================
434 
435 uint16_t
mdns_query_message_get_qclass(const mdns_query_message_t me)436 mdns_query_message_get_qclass(const mdns_query_message_t me)
437 {
438 	return me->qclass;
439 }
440 
441 //======================================================================================================================
442 
443 uint16_t
mdns_query_message_get_message_id(const mdns_query_message_t me)444 mdns_query_message_get_message_id(const mdns_query_message_t me)
445 {
446 	return me->msg_id;
447 }
448 
449 //======================================================================================================================
450 
451 bool
mdns_query_message_do_bit_is_set(const mdns_query_message_t me)452 mdns_query_message_do_bit_is_set(const mdns_query_message_t me)
453 {
454 	return me->set_do_bit;
455 }
456 
457 //======================================================================================================================
458 // MARK: - Query Message Private Methods
459 
460 static void
_mdns_query_message_finalize(const mdns_query_message_t me)461 _mdns_query_message_finalize(const mdns_query_message_t me)
462 {
463 	ForgetMem(&me->qname);
464 }
465 
466 //======================================================================================================================
467 
468 const uint8_t *
_mdns_query_message_get_qname_safe(const mdns_query_message_t me)469 _mdns_query_message_get_qname_safe(const mdns_query_message_t me)
470 {
471 	return (me->qname ? me->qname : (const uint8_t *)"");
472 }
473