1 
2 #ifdef _WIN32
3 #define _CRT_SECURE_NO_WARNINGS 1
4 #endif
5 
6 #include <stdio.h>
7 
8 #include "mdns.h"
9 
10 #include <errno.h>
11 
12 #ifdef _WIN32
13 #include <iphlpapi.h>
// Map sleep(seconds) onto Win32 Sleep(milliseconds). The argument is parenthesized so
// that an expression such as sleep(a + b) is scaled as a whole.
#define sleep(x) Sleep((x) * 1000)
15 #else
16 #include <netdb.h>
17 #include <ifaddrs.h>
18 #endif
19 
20 static char addrbuffer[64];
21 static char entrybuffer[256];
22 static char namebuffer[256];
23 static char sendbuffer[256];
24 static mdns_record_txt_t txtbuffer[128];
25 
26 static uint32_t service_address_ipv4;
27 static uint8_t service_address_ipv6[16];
28 
29 static int has_ipv4;
30 static int has_ipv6;
31 
// Description of the service announced by service_mdns(); handed to service_callback()
// through its user_data pointer.
typedef struct service_record_t {
	// Service name compared against incoming PTR questions
	const char* service;
	// Host name included in answers
	const char* hostname;
	// IPv4 address as stored in sin_addr.s_addr, or 0 when no IPv4 address was found
	uint32_t address_ipv4;
	// Pointer to 16 bytes of IPv6 address, or 0 when no IPv6 address was found
	uint8_t* address_ipv6;
	// Service port
	int port;
} service_record_t;
39 
40 static mdns_string_t
ipv4_address_to_string(char * buffer,size_t capacity,const struct sockaddr_in * addr,size_t addrlen)41 ipv4_address_to_string(char* buffer, size_t capacity, const struct sockaddr_in* addr,
42                        size_t addrlen) {
43 	char host[NI_MAXHOST] = {0};
44 	char service[NI_MAXSERV] = {0};
45 	int ret = getnameinfo((const struct sockaddr*)addr, (socklen_t)addrlen, host, NI_MAXHOST,
46 	                      service, NI_MAXSERV, NI_NUMERICSERV | NI_NUMERICHOST);
47 	int len = 0;
48 	if (ret == 0) {
49 		if (addr->sin_port != 0)
50 			len = snprintf(buffer, capacity, "%s:%s", host, service);
51 		else
52 			len = snprintf(buffer, capacity, "%s", host);
53 	}
54 	if (len >= (int)capacity)
55 		len = (int)capacity - 1;
56 	mdns_string_t str;
57 	str.str = buffer;
58 	str.length = len;
59 	return str;
60 }
61 
62 static mdns_string_t
ipv6_address_to_string(char * buffer,size_t capacity,const struct sockaddr_in6 * addr,size_t addrlen)63 ipv6_address_to_string(char* buffer, size_t capacity, const struct sockaddr_in6* addr,
64                        size_t addrlen) {
65 	char host[NI_MAXHOST] = {0};
66 	char service[NI_MAXSERV] = {0};
67 	int ret = getnameinfo((const struct sockaddr*)addr, (socklen_t)addrlen, host, NI_MAXHOST,
68 	                      service, NI_MAXSERV, NI_NUMERICSERV | NI_NUMERICHOST);
69 	int len = 0;
70 	if (ret == 0) {
71 		if (addr->sin6_port != 0)
72 			len = snprintf(buffer, capacity, "[%s]:%s", host, service);
73 		else
74 			len = snprintf(buffer, capacity, "%s", host);
75 	}
76 	if (len >= (int)capacity)
77 		len = (int)capacity - 1;
78 	mdns_string_t str;
79 	str.str = buffer;
80 	str.length = len;
81 	return str;
82 }
83 
84 static mdns_string_t
ip_address_to_string(char * buffer,size_t capacity,const struct sockaddr * addr,size_t addrlen)85 ip_address_to_string(char* buffer, size_t capacity, const struct sockaddr* addr, size_t addrlen) {
86 	if (addr->sa_family == AF_INET6)
87 		return ipv6_address_to_string(buffer, capacity, (const struct sockaddr_in6*)addr, addrlen);
88 	return ipv4_address_to_string(buffer, capacity, (const struct sockaddr_in*)addr, addrlen);
89 }
90 
91 static int
query_callback(int sock,const struct sockaddr * from,size_t addrlen,mdns_entry_type_t entry,uint16_t query_id,uint16_t rtype,uint16_t rclass,uint32_t ttl,const void * data,size_t size,size_t name_offset,size_t name_length,size_t record_offset,size_t record_length,void * user_data)92 query_callback(int sock, const struct sockaddr* from, size_t addrlen, mdns_entry_type_t entry,
93                uint16_t query_id, uint16_t rtype, uint16_t rclass, uint32_t ttl, const void* data,
94                size_t size, size_t name_offset, size_t name_length, size_t record_offset,
95                size_t record_length, void* user_data) {
96 	(void)sizeof(sock);
97 	(void)sizeof(query_id);
98 	(void)sizeof(name_length);
99 	(void)sizeof(user_data);
100 	mdns_string_t fromaddrstr = ip_address_to_string(addrbuffer, sizeof(addrbuffer), from, addrlen);
101 	const char* entrytype = (entry == MDNS_ENTRYTYPE_ANSWER) ?
102 	                            "answer" :
103 	                            ((entry == MDNS_ENTRYTYPE_AUTHORITY) ? "authority" : "additional");
104 	mdns_string_t entrystr =
105 	    mdns_string_extract(data, size, &name_offset, entrybuffer, sizeof(entrybuffer));
106 	if (rtype == MDNS_RECORDTYPE_PTR) {
107 		mdns_string_t namestr = mdns_record_parse_ptr(data, size, record_offset, record_length,
108 		                                              namebuffer, sizeof(namebuffer));
109 		printf("%.*s : %s %.*s PTR %.*s rclass 0x%x ttl %u length %d\n",
110 		       MDNS_STRING_FORMAT(fromaddrstr), entrytype, MDNS_STRING_FORMAT(entrystr),
111 		       MDNS_STRING_FORMAT(namestr), rclass, ttl, (int)record_length);
112 	} else if (rtype == MDNS_RECORDTYPE_SRV) {
113 		mdns_record_srv_t srv = mdns_record_parse_srv(data, size, record_offset, record_length,
114 		                                              namebuffer, sizeof(namebuffer));
115 		printf("%.*s : %s %.*s SRV %.*s priority %d weight %d port %d\n",
116 		       MDNS_STRING_FORMAT(fromaddrstr), entrytype, MDNS_STRING_FORMAT(entrystr),
117 		       MDNS_STRING_FORMAT(srv.name), srv.priority, srv.weight, srv.port);
118 	} else if (rtype == MDNS_RECORDTYPE_A) {
119 		struct sockaddr_in addr;
120 		mdns_record_parse_a(data, size, record_offset, record_length, &addr);
121 		mdns_string_t addrstr =
122 		    ipv4_address_to_string(namebuffer, sizeof(namebuffer), &addr, sizeof(addr));
123 		printf("%.*s : %s %.*s A %.*s\n", MDNS_STRING_FORMAT(fromaddrstr), entrytype,
124 		       MDNS_STRING_FORMAT(entrystr), MDNS_STRING_FORMAT(addrstr));
125 	} else if (rtype == MDNS_RECORDTYPE_AAAA) {
126 		struct sockaddr_in6 addr;
127 		mdns_record_parse_aaaa(data, size, record_offset, record_length, &addr);
128 		mdns_string_t addrstr =
129 		    ipv6_address_to_string(namebuffer, sizeof(namebuffer), &addr, sizeof(addr));
130 		printf("%.*s : %s %.*s AAAA %.*s\n", MDNS_STRING_FORMAT(fromaddrstr), entrytype,
131 		       MDNS_STRING_FORMAT(entrystr), MDNS_STRING_FORMAT(addrstr));
132 	} else if (rtype == MDNS_RECORDTYPE_TXT) {
133 		size_t parsed = mdns_record_parse_txt(data, size, record_offset, record_length, txtbuffer,
134 		                                      sizeof(txtbuffer) / sizeof(mdns_record_txt_t));
135 		for (size_t itxt = 0; itxt < parsed; ++itxt) {
136 			if (txtbuffer[itxt].value.length) {
137 				printf("%.*s : %s %.*s TXT %.*s = %.*s\n", MDNS_STRING_FORMAT(fromaddrstr),
138 				       entrytype, MDNS_STRING_FORMAT(entrystr),
139 				       MDNS_STRING_FORMAT(txtbuffer[itxt].key),
140 				       MDNS_STRING_FORMAT(txtbuffer[itxt].value));
141 			} else {
142 				printf("%.*s : %s %.*s TXT %.*s\n", MDNS_STRING_FORMAT(fromaddrstr), entrytype,
143 				       MDNS_STRING_FORMAT(entrystr), MDNS_STRING_FORMAT(txtbuffer[itxt].key));
144 			}
145 		}
146 	} else {
147 		printf("%.*s : %s %.*s type %u rclass 0x%x ttl %u length %d\n",
148 		       MDNS_STRING_FORMAT(fromaddrstr), entrytype, MDNS_STRING_FORMAT(entrystr), rtype,
149 		       rclass, ttl, (int)record_length);
150 	}
151 	return 0;
152 }
153 
154 static int
service_callback(int sock,const struct sockaddr * from,size_t addrlen,mdns_entry_type_t entry,uint16_t query_id,uint16_t rtype,uint16_t rclass,uint32_t ttl,const void * data,size_t size,size_t name_offset,size_t name_length,size_t record_offset,size_t record_length,void * user_data)155 service_callback(int sock, const struct sockaddr* from, size_t addrlen, mdns_entry_type_t entry,
156                  uint16_t query_id, uint16_t rtype, uint16_t rclass, uint32_t ttl, const void* data,
157                  size_t size, size_t name_offset, size_t name_length, size_t record_offset,
158                  size_t record_length, void* user_data) {
159 	(void)sizeof(name_offset);
160 	(void)sizeof(name_length);
161 	(void)sizeof(ttl);
162 	if (entry != MDNS_ENTRYTYPE_QUESTION)
163 		return 0;
164 	mdns_string_t fromaddrstr = ip_address_to_string(addrbuffer, sizeof(addrbuffer), from, addrlen);
165 	if (rtype == MDNS_RECORDTYPE_PTR) {
166 		mdns_string_t service = mdns_record_parse_ptr(data, size, record_offset, record_length,
167 		                                              namebuffer, sizeof(namebuffer));
168 		printf("%.*s : question PTR %.*s\n", MDNS_STRING_FORMAT(fromaddrstr),
169 		       MDNS_STRING_FORMAT(service));
170 
171 		const char dns_sd[] = "_services._dns-sd._udp.local.";
172 		const service_record_t* service_record = (const service_record_t*)user_data;
173 		size_t service_length = strlen(service_record->service);
174 		if ((service.length == (sizeof(dns_sd) - 1)) &&
175 		    (strncmp(service.str, dns_sd, sizeof(dns_sd) - 1) == 0)) {
176 			printf("  --> answer %s\n", service_record->service);
177 			mdns_discovery_answer(sock, from, addrlen, sendbuffer, sizeof(sendbuffer),
178 			                      service_record->service, service_length);
179 		} else if ((service.length == service_length) &&
180 		           (strncmp(service.str, service_record->service, service_length) == 0)) {
181 			uint16_t unicast = (rclass & MDNS_UNICAST_RESPONSE);
182 			printf("  --> answer %s.%s port %d (%s)\n", service_record->hostname,
183 			       service_record->service, service_record->port,
184 			       (unicast ? "unicast" : "multicast"));
185 			if (!unicast)
186 				addrlen = 0;
187 			char txt_record[] = "test=1";
188 			mdns_query_answer(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), query_id,
189 			                  service_record->service, service_length, service_record->hostname,
190 			                  strlen(service_record->hostname), service_record->address_ipv4,
191 			                  service_record->address_ipv6, (uint16_t)service_record->port,
192 			                  txt_record, sizeof(txt_record));
193 		}
194 	} else if (rtype == MDNS_RECORDTYPE_SRV) {
195 		mdns_record_srv_t service = mdns_record_parse_srv(data, size, record_offset, record_length,
196 		                                                  namebuffer, sizeof(namebuffer));
197 		printf("%.*s : question SRV %.*s\n", MDNS_STRING_FORMAT(fromaddrstr),
198 		       MDNS_STRING_FORMAT(service.name));
199 #if 0
200 		if ((service.length == service_length) &&
201 		    (strncmp(service.str, service_record->service, service_length) == 0)) {
202 			uint16_t unicast = (rclass & MDNS_UNICAST_RESPONSE);
203 			printf("  --> answer %s.%s port %d (%s)\n", service_record->hostname,
204 			       service_record->service, service_record->port,
205 			       (unicast ? "unicast" : "multicast"));
206 			if (!unicast)
207 				addrlen = 0;
208 			char txt_record[] = "test=1";
209 			mdns_query_answer(sock, from, addrlen, sendbuffer, sizeof(sendbuffer), query_id,
210 			                  service_record->service, service_length, service_record->hostname,
211 			                  strlen(service_record->hostname), service_record->address_ipv4,
212 			                  service_record->address_ipv6, (uint16_t)service_record->port,
213 			                  txt_record, sizeof(txt_record));
214 		}
215 #endif
216 	}
217 	return 0;
218 }
219 
220 static int
open_client_sockets(int * sockets,int max_sockets,int port)221 open_client_sockets(int* sockets, int max_sockets, int port) {
222 	// When sending, each socket can only send to one network interface
223 	// Thus we need to open one socket for each interface and address family
224 	int num_sockets = 0;
225 
226 #ifdef _WIN32
227 
228 	IP_ADAPTER_ADDRESSES* adapter_address = 0;
229 	ULONG address_size = 8000;
230 	unsigned int ret;
231 	unsigned int num_retries = 4;
232 	do {
233 		adapter_address = malloc(address_size);
234 		ret = GetAdaptersAddresses(AF_UNSPEC, GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_ANYCAST, 0,
235 		                           adapter_address, &address_size);
236 		if (ret == ERROR_BUFFER_OVERFLOW) {
237 			free(adapter_address);
238 			adapter_address = 0;
239 		} else {
240 			break;
241 		}
242 	} while (num_retries-- > 0);
243 
244 	if (!adapter_address || (ret != NO_ERROR)) {
245 		free(adapter_address);
246 		printf("Failed to get network adapter addresses\n");
247 		return num_sockets;
248 	}
249 
250 	int first_ipv4 = 1;
251 	int first_ipv6 = 1;
252 	for (PIP_ADAPTER_ADDRESSES adapter = adapter_address; adapter; adapter = adapter->Next) {
253 		if (adapter->TunnelType == TUNNEL_TYPE_TEREDO)
254 			continue;
255 		if (adapter->OperStatus != IfOperStatusUp)
256 			continue;
257 
258 		for (IP_ADAPTER_UNICAST_ADDRESS* unicast = adapter->FirstUnicastAddress; unicast;
259 		     unicast = unicast->Next) {
260 			if (unicast->Address.lpSockaddr->sa_family == AF_INET) {
261 				struct sockaddr_in* saddr = (struct sockaddr_in*)unicast->Address.lpSockaddr;
262 				if ((saddr->sin_addr.S_un.S_un_b.s_b1 != 127) ||
263 				    (saddr->sin_addr.S_un.S_un_b.s_b2 != 0) ||
264 				    (saddr->sin_addr.S_un.S_un_b.s_b3 != 0) ||
265 				    (saddr->sin_addr.S_un.S_un_b.s_b4 != 1)) {
266 					int log_addr = 0;
267 					if (first_ipv4) {
268 						service_address_ipv4 = saddr->sin_addr.S_un.S_addr;
269 						first_ipv4 = 0;
270 						log_addr = 1;
271 					}
272 					has_ipv4 = 1;
273 					if (num_sockets < max_sockets) {
274 						saddr->sin_port = htons((unsigned short)port);
275 						int sock = mdns_socket_open_ipv4(saddr);
276 						if (sock >= 0) {
277 							sockets[num_sockets++] = sock;
278 							log_addr = 1;
279 						} else {
280 							log_addr = 0;
281 						}
282 					}
283 					if (log_addr) {
284 						char buffer[128];
285 						mdns_string_t addr = ipv4_address_to_string(buffer, sizeof(buffer), saddr,
286 						                                            sizeof(struct sockaddr_in));
287 						printf("Local IPv4 address: %.*s\n", MDNS_STRING_FORMAT(addr));
288 					}
289 				}
290 			} else if (unicast->Address.lpSockaddr->sa_family == AF_INET6) {
291 				struct sockaddr_in6* saddr = (struct sockaddr_in6*)unicast->Address.lpSockaddr;
292 				static const unsigned char localhost[] = {0, 0, 0, 0, 0, 0, 0, 0,
293 				                                          0, 0, 0, 0, 0, 0, 0, 1};
294 				static const unsigned char localhost_mapped[] = {0, 0, 0,    0,    0,    0, 0, 0,
295 				                                                 0, 0, 0xff, 0xff, 0x7f, 0, 0, 1};
296 				if ((unicast->DadState == NldsPreferred) &&
297 				    memcmp(saddr->sin6_addr.s6_addr, localhost, 16) &&
298 				    memcmp(saddr->sin6_addr.s6_addr, localhost_mapped, 16)) {
299 					int log_addr = 0;
300 					if (first_ipv6) {
301 						memcpy(service_address_ipv6, &saddr->sin6_addr, 16);
302 						first_ipv6 = 0;
303 						log_addr = 1;
304 					}
305 					has_ipv6 = 1;
306 					if (num_sockets < max_sockets) {
307 						saddr->sin6_port = htons((unsigned short)port);
308 						int sock = mdns_socket_open_ipv6(saddr);
309 						if (sock >= 0) {
310 							sockets[num_sockets++] = sock;
311 							log_addr = 1;
312 						} else {
313 							log_addr = 0;
314 						}
315 					}
316 					if (log_addr) {
317 						char buffer[128];
318 						mdns_string_t addr = ipv6_address_to_string(buffer, sizeof(buffer), saddr,
319 						                                            sizeof(struct sockaddr_in6));
320 						printf("Local IPv6 address: %.*s\n", MDNS_STRING_FORMAT(addr));
321 					}
322 				}
323 			}
324 		}
325 	}
326 
327 	free(adapter_address);
328 
329 #else
330 
331 	struct ifaddrs* ifaddr = 0;
332 	struct ifaddrs* ifa = 0;
333 
334 	if (getifaddrs(&ifaddr) < 0)
335 		printf("Unable to get interface addresses\n");
336 
337 	int first_ipv4 = 1;
338 	int first_ipv6 = 1;
339 	for (ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
340 		if (!ifa->ifa_addr)
341 			continue;
342 
343 		if (ifa->ifa_addr->sa_family == AF_INET) {
344 			struct sockaddr_in* saddr = (struct sockaddr_in*)ifa->ifa_addr;
345 			if (saddr->sin_addr.s_addr != htonl(INADDR_LOOPBACK)) {
346 				int log_addr = 0;
347 				if (first_ipv4) {
348 					service_address_ipv4 = saddr->sin_addr.s_addr;
349 					first_ipv4 = 0;
350 					log_addr = 1;
351 				}
352 				has_ipv4 = 1;
353 				if (num_sockets < max_sockets) {
354 					saddr->sin_port = htons(port);
355 					int sock = mdns_socket_open_ipv4(saddr);
356 					if (sock >= 0) {
357 						sockets[num_sockets++] = sock;
358 						log_addr = 1;
359 					} else {
360 						log_addr = 0;
361 					}
362 				}
363 				if (log_addr) {
364 					char buffer[128];
365 					mdns_string_t addr = ipv4_address_to_string(buffer, sizeof(buffer), saddr,
366 					                                            sizeof(struct sockaddr_in));
367 					printf("Local IPv4 address: %.*s\n", MDNS_STRING_FORMAT(addr));
368 				}
369 			}
370 		} else if (ifa->ifa_addr->sa_family == AF_INET6) {
371 			struct sockaddr_in6* saddr = (struct sockaddr_in6*)ifa->ifa_addr;
372 			static const unsigned char localhost[] = {0, 0, 0, 0, 0, 0, 0, 0,
373 			                                          0, 0, 0, 0, 0, 0, 0, 1};
374 			static const unsigned char localhost_mapped[] = {0, 0, 0,    0,    0,    0, 0, 0,
375 			                                                 0, 0, 0xff, 0xff, 0x7f, 0, 0, 1};
376 			if (memcmp(saddr->sin6_addr.s6_addr, localhost, 16) &&
377 			    memcmp(saddr->sin6_addr.s6_addr, localhost_mapped, 16)) {
378 				int log_addr = 0;
379 				if (first_ipv6) {
380 					memcpy(service_address_ipv6, &saddr->sin6_addr, 16);
381 					first_ipv6 = 0;
382 					log_addr = 1;
383 				}
384 				has_ipv6 = 1;
385 				if (num_sockets < max_sockets) {
386 					saddr->sin6_port = htons(port);
387 					int sock = mdns_socket_open_ipv6(saddr);
388 					if (sock >= 0) {
389 						sockets[num_sockets++] = sock;
390 						log_addr = 1;
391 					} else {
392 						log_addr = 0;
393 					}
394 				}
395 				if (log_addr) {
396 					char buffer[128];
397 					mdns_string_t addr = ipv6_address_to_string(buffer, sizeof(buffer), saddr,
398 					                                            sizeof(struct sockaddr_in6));
399 					printf("Local IPv6 address: %.*s\n", MDNS_STRING_FORMAT(addr));
400 				}
401 			}
402 		}
403 	}
404 
405 	freeifaddrs(ifaddr);
406 
407 #endif
408 
409 	return num_sockets;
410 }
411 
412 static int
open_service_sockets(int * sockets,int max_sockets)413 open_service_sockets(int* sockets, int max_sockets) {
414 	// When recieving, each socket can recieve data from all network interfaces
415 	// Thus we only need to open one socket for each address family
416 	int num_sockets = 0;
417 
418 	// Call the client socket function to enumerate and get local addresses,
419 	// but not open the actual sockets
420 	open_client_sockets(0, 0, 0);
421 
422 	if (num_sockets < max_sockets) {
423 		struct sockaddr_in sock_addr;
424 		memset(&sock_addr, 0, sizeof(struct sockaddr_in));
425 		sock_addr.sin_family = AF_INET;
426 #ifdef _WIN32
427 		sock_addr.sin_addr = in4addr_any;
428 #else
429 		sock_addr.sin_addr.s_addr = INADDR_ANY;
430 #endif
431 		sock_addr.sin_port = htons(MDNS_PORT);
432 #ifdef __APPLE__
433 		sock_addr.sin_len = sizeof(struct sockaddr_in);
434 #endif
435 		int sock = mdns_socket_open_ipv4(&sock_addr);
436 		if (sock >= 0)
437 			sockets[num_sockets++] = sock;
438 	}
439 
440 	if (num_sockets < max_sockets) {
441 		struct sockaddr_in6 sock_addr;
442 		memset(&sock_addr, 0, sizeof(struct sockaddr_in6));
443 		sock_addr.sin6_family = AF_INET6;
444 		sock_addr.sin6_addr = in6addr_any;
445 		sock_addr.sin6_port = htons(MDNS_PORT);
446 #ifdef __APPLE__
447 		sock_addr.sin6_len = sizeof(struct sockaddr_in6);
448 #endif
449 		int sock = mdns_socket_open_ipv6(&sock_addr);
450 		if (sock >= 0)
451 			sockets[num_sockets++] = sock;
452 	}
453 
454 	return num_sockets;
455 }
456 
457 static int
send_dns_sd(void)458 send_dns_sd(void) {
459 	int sockets[32];
460 	int num_sockets = open_client_sockets(sockets, sizeof(sockets) / sizeof(sockets[0]), 0);
461 	if (num_sockets <= 0) {
462 		printf("Failed to open any client sockets\n");
463 		return -1;
464 	}
465 	printf("Opened %d socket%s for DNS-SD\n", num_sockets, num_sockets ? "s" : "");
466 
467 	printf("Sending DNS-SD discovery\n");
468 	for (int isock = 0; isock < num_sockets; ++isock) {
469 		if (mdns_discovery_send(sockets[isock]))
470 			printf("Failed to send DNS-DS discovery: %s\n", strerror(errno));
471 	}
472 
473 	size_t capacity = 2048;
474 	void* buffer = malloc(capacity);
475 	void* user_data = 0;
476 	size_t records;
477 
478 	// This is a simple implementation that loops for 5 seconds or as long as we get replies
479 	int res;
480 	printf("Reading DNS-SD replies\n");
481 	do {
482 		struct timeval timeout;
483 		timeout.tv_sec = 5;
484 		timeout.tv_usec = 0;
485 
486 		int nfds = 0;
487 		fd_set readfs;
488 		FD_ZERO(&readfs);
489 		for (int isock = 0; isock < num_sockets; ++isock) {
490 			if (sockets[isock] >= nfds)
491 				nfds = sockets[isock] + 1;
492 			FD_SET(sockets[isock], &readfs);
493 		}
494 
495 		records = 0;
496 		res = select(nfds, &readfs, 0, 0, &timeout);
497 		if (res > 0) {
498 			for (int isock = 0; isock < num_sockets; ++isock) {
499 				if (FD_ISSET(sockets[isock], &readfs)) {
500 					records += mdns_discovery_recv(sockets[isock], buffer, capacity, query_callback,
501 					                               user_data);
502 				}
503 			}
504 		}
505 	} while (res > 0);
506 
507 	free(buffer);
508 
509 	for (int isock = 0; isock < num_sockets; ++isock)
510 		mdns_socket_close(sockets[isock]);
511 	printf("Closed socket%s\n", num_sockets ? "s" : "");
512 
513 	return 0;
514 }
515 
516 static int
send_mdns_query(const char * service)517 send_mdns_query(const char* service) {
518 	int sockets[32];
519 	int query_id[32];
520 	int num_sockets = open_client_sockets(sockets, sizeof(sockets) / sizeof(sockets[0]), 0);
521 	if (num_sockets <= 0) {
522 		printf("Failed to open any client sockets\n");
523 		return -1;
524 	}
525 	printf("Opened %d socket%s for mDNS query\n", num_sockets, num_sockets ? "s" : "");
526 
527 	size_t capacity = 2048;
528 	void* buffer = malloc(capacity);
529 	void* user_data = 0;
530 	size_t records;
531 
532 	printf("Sending mDNS query: %s\n", service);
533 	for (int isock = 0; isock < num_sockets; ++isock) {
534 		query_id[isock] = mdns_query_send(sockets[isock], MDNS_RECORDTYPE_PTR, service,
535 		                                  strlen(service), buffer, capacity, 0);
536 		if (query_id[isock] < 0)
537 			printf("Failed to send mDNS query: %s\n", strerror(errno));
538 	}
539 
540 	// This is a simple implementation that loops for 5 seconds or as long as we get replies
541 	int res;
542 	printf("Reading mDNS query replies\n");
543 	do {
544 		struct timeval timeout;
545 		timeout.tv_sec = 5;
546 		timeout.tv_usec = 0;
547 
548 		int nfds = 0;
549 		fd_set readfs;
550 		FD_ZERO(&readfs);
551 		for (int isock = 0; isock < num_sockets; ++isock) {
552 			if (sockets[isock] >= nfds)
553 				nfds = sockets[isock] + 1;
554 			FD_SET(sockets[isock], &readfs);
555 		}
556 
557 		records = 0;
558 		res = select(nfds, &readfs, 0, 0, &timeout);
559 		if (res > 0) {
560 			for (int isock = 0; isock < num_sockets; ++isock) {
561 				if (FD_ISSET(sockets[isock], &readfs)) {
562 					records += mdns_query_recv(sockets[isock], buffer, capacity, query_callback,
563 					                           user_data, query_id[isock]);
564 				}
565 				FD_SET(sockets[isock], &readfs);
566 			}
567 		}
568 	} while (res > 0);
569 
570 	free(buffer);
571 
572 	for (int isock = 0; isock < num_sockets; ++isock)
573 		mdns_socket_close(sockets[isock]);
574 	printf("Closed socket%s\n", num_sockets ? "s" : "");
575 
576 	return 0;
577 }
578 
579 static int
service_mdns(const char * hostname,const char * service,int service_port)580 service_mdns(const char* hostname, const char* service, int service_port) {
581 	int sockets[32];
582 	int num_sockets = open_service_sockets(sockets, sizeof(sockets) / sizeof(sockets[0]));
583 	if (num_sockets <= 0) {
584 		printf("Failed to open any client sockets\n");
585 		return -1;
586 	}
587 	printf("Opened %d socket%s for mDNS service\n", num_sockets, num_sockets ? "s" : "");
588 
589 	printf("Service mDNS: %s:%d\n", service, service_port);
590 	printf("Hostname: %s\n", hostname);
591 
592 	size_t capacity = 2048;
593 	void* buffer = malloc(capacity);
594 
595 	service_record_t service_record;
596 	service_record.service = service;
597 	service_record.hostname = hostname;
598 	service_record.address_ipv4 = has_ipv4 ? service_address_ipv4 : 0;
599 	service_record.address_ipv6 = has_ipv6 ? service_address_ipv6 : 0;
600 	service_record.port = service_port;
601 
602 	// This is a crude implementation that checks for incoming queries
603 	while (1) {
604 		int nfds = 0;
605 		fd_set readfs;
606 		FD_ZERO(&readfs);
607 		for (int isock = 0; isock < num_sockets; ++isock) {
608 			if (sockets[isock] >= nfds)
609 				nfds = sockets[isock] + 1;
610 			FD_SET(sockets[isock], &readfs);
611 		}
612 
613 		if (select(nfds, &readfs, 0, 0, 0) >= 0) {
614 			for (int isock = 0; isock < num_sockets; ++isock) {
615 				if (FD_ISSET(sockets[isock], &readfs)) {
616 					mdns_socket_listen(sockets[isock], buffer, capacity, service_callback,
617 					                   &service_record);
618 				}
619 				FD_SET(sockets[isock], &readfs);
620 			}
621 		} else {
622 			break;
623 		}
624 	}
625 
626 	free(buffer);
627 
628 	for (int isock = 0; isock < num_sockets; ++isock)
629 		mdns_socket_close(sockets[isock]);
630 	printf("Closed socket%s\n", num_sockets ? "s" : "");
631 
632 	return 0;
633 }
634 
// Program entry point.
//
// Options:
//   --discovery         send a DNS-SD discovery request (default)
//   --query <service>   send an mDNS PTR query for <service>
//   --service <service> answer queries for <service>
//   --hostname <name>   host name to announce (defaults to the local host name)
//   --port <port>       service port to announce (defaults to 42424)
//
// Returns the result of the selected mode (0 on success, -1 on failure), or -1 when
// WinSock cannot be initialized.
int
main(int argc, const char* const* argv) {
	// 0 = DNS-SD discovery, 1 = query, 2 = service
	int mode = 0;
	const char* service = "_test-mdns._tcp.local.";
	const char* hostname = "dummy-host";
	int service_port = 42424;

#ifdef _WIN32

	WORD versionWanted = MAKEWORD(1, 1);
	WSADATA wsaData;
	if (WSAStartup(versionWanted, &wsaData)) {
		printf("Failed to initialize WinSock\n");
		return -1;
	}

	char hostname_buffer[256];
	DWORD hostname_size = (DWORD)sizeof(hostname_buffer);
	if (GetComputerNameA(hostname_buffer, &hostname_size))
		hostname = hostname_buffer;

#else

	char hostname_buffer[256];
	size_t hostname_size = sizeof(hostname_buffer);
	if (gethostname(hostname_buffer, hostname_size) == 0) {
		// gethostname may leave a truncated name unterminated
		hostname_buffer[hostname_size - 1] = 0;
		hostname = hostname_buffer;
	}

#endif

	// argv[0] is the program name, so option parsing starts at index 1
	for (int iarg = 1; iarg < argc; ++iarg) {
		if (strcmp(argv[iarg], "--discovery") == 0) {
			mode = 0;
		} else if (strcmp(argv[iarg], "--query") == 0) {
			mode = 1;
			++iarg;
			if (iarg < argc)
				service = argv[iarg];
		} else if (strcmp(argv[iarg], "--service") == 0) {
			mode = 2;
			++iarg;
			if (iarg < argc)
				service = argv[iarg];
		} else if (strcmp(argv[iarg], "--hostname") == 0) {
			++iarg;
			if (iarg < argc)
				hostname = argv[iarg];
		} else if (strcmp(argv[iarg], "--port") == 0) {
			++iarg;
			if (iarg < argc)
				service_port = atoi(argv[iarg]);
		}
	}

	int ret = 0;
	if (mode == 0)
		ret = send_dns_sd();
	else if (mode == 1)
		ret = send_mdns_query(service);
	else if (mode == 2)
		ret = service_mdns(hostname, service, service_port);

#ifdef _WIN32
	WSACleanup();
#endif

	// Propagate the result of the selected mode as the process exit status
	return ret;
}
703