1 /*
2  * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
3  *
4  * This Source Code Form is subject to the terms of the Mozilla Public
5  * License, v. 2.0. If a copy of the MPL was not distributed with this
6  * file, you can obtain one at https://mozilla.org/MPL/2.0/.
7  *
8  * See the COPYRIGHT file distributed with this work for additional
9  * information regarding copyright ownership.
10  */
11 
12 #include <inttypes.h>
13 
14 #include <isc/list.h>
15 #include <isc/mem.h>
16 #include <isc/refcount.h>
17 #include <isc/result.h>
18 #include <isc/rwlock.h>
19 #include <isc/util.h>
20 
21 #include <dns/name.h>
22 #include <dns/rbt.h>
23 #include <dns/transport.h>
24 
/*
 * Magic value stamped on a live dns_transport_t by dns_transport_new()
 * and cleared by transport_destroy().
 */
#define TRANSPORT_MAGIC	     ISC_MAGIC('T', 'r', 'n', 's')
#define VALID_TRANSPORT(ptr) ISC_MAGIC_VALID(ptr, TRANSPORT_MAGIC)

/*
 * Magic value stamped on a live dns_transport_list_t by
 * dns_transport_list_new() and cleared by transport_list_destroy().
 */
#define TRANSPORT_LIST_MAGIC	  ISC_MAGIC('T', 'r', 'L', 's')
#define VALID_TRANSPORT_LIST(ptr) ISC_MAGIC_VALID(ptr, TRANSPORT_LIST_MAGIC)

/*
 * Reference-counted registry of named transports: one RBT per transport
 * type, indexed by dns_transport_type_t.
 */
struct dns_transport_list {
	unsigned int magic;
	isc_refcount_t references;
	isc_mem_t *mctx;
	/* Taken around insertions into and lookups in 'transports'. */
	isc_rwlock_t lock;
	/*
	 * Node data is a dns_transport_t *; the tree owns one reference to
	 * it, released by free_dns_transport() (the tree's data deleter).
	 */
	dns_rbt_t *transports[DNS_TRANSPORT_COUNT];
};
38 
/*
 * One named transport configuration.
 *
 * The 'tls' strings may only be set on TLS and HTTP transports, and the
 * 'doh' fields only on HTTP transports (see the REQUIREs in the setters).
 * All strings are private copies allocated from 'mctx'. They start out
 * NULL and are freed by transport_destroy().
 */
struct dns_transport {
	unsigned int magic;
	isc_refcount_t references;
	isc_mem_t *mctx;
	dns_transport_type_t type;
	struct {
		char *certfile;
		char *keyfile;
		char *cafile;
		char *hostname;
	} tls;
	struct {
		char *endpoint;
		dns_http_mode_t mode;
	} doh;
};
55 
56 static void
free_dns_transport(void * node,void * arg)57 free_dns_transport(void *node, void *arg) {
58 	dns_transport_t *transport = node;
59 
60 	REQUIRE(node != NULL);
61 
62 	UNUSED(arg);
63 
64 	dns_transport_detach(&transport);
65 }
66 
67 static isc_result_t
list_add(dns_transport_list_t * list,const dns_name_t * name,const dns_transport_type_t type,dns_transport_t * transport)68 list_add(dns_transport_list_t *list, const dns_name_t *name,
69 	 const dns_transport_type_t type, dns_transport_t *transport) {
70 	isc_result_t result;
71 	dns_rbt_t *rbt = NULL;
72 
73 	RWLOCK(&list->lock, isc_rwlocktype_write);
74 	rbt = list->transports[type];
75 	INSIST(rbt != NULL);
76 
77 	result = dns_rbt_addname(rbt, name, transport);
78 
79 	RWUNLOCK(&list->lock, isc_rwlocktype_write);
80 
81 	return (result);
82 }
83 
84 dns_transport_type_t
dns_transport_get_type(dns_transport_t * transport)85 dns_transport_get_type(dns_transport_t *transport) {
86 	REQUIRE(VALID_TRANSPORT(transport));
87 
88 	return (transport->type);
89 }
90 
91 char *
dns_transport_get_certfile(dns_transport_t * transport)92 dns_transport_get_certfile(dns_transport_t *transport) {
93 	REQUIRE(VALID_TRANSPORT(transport));
94 
95 	return (transport->tls.certfile);
96 }
97 
98 char *
dns_transport_get_keyfile(dns_transport_t * transport)99 dns_transport_get_keyfile(dns_transport_t *transport) {
100 	REQUIRE(VALID_TRANSPORT(transport));
101 
102 	return (transport->tls.keyfile);
103 }
104 
105 char *
dns_transport_get_cafile(dns_transport_t * transport)106 dns_transport_get_cafile(dns_transport_t *transport) {
107 	REQUIRE(VALID_TRANSPORT(transport));
108 
109 	return (transport->tls.cafile);
110 }
111 
112 char *
dns_transport_get_hostname(dns_transport_t * transport)113 dns_transport_get_hostname(dns_transport_t *transport) {
114 	REQUIRE(VALID_TRANSPORT(transport));
115 
116 	return (transport->tls.hostname);
117 }
118 
119 char *
dns_transport_get_endpoint(dns_transport_t * transport)120 dns_transport_get_endpoint(dns_transport_t *transport) {
121 	REQUIRE(VALID_TRANSPORT(transport));
122 
123 	return (transport->doh.endpoint);
124 }
125 
126 dns_http_mode_t
dns_transport_get_mode(dns_transport_t * transport)127 dns_transport_get_mode(dns_transport_t *transport) {
128 	REQUIRE(VALID_TRANSPORT(transport));
129 
130 	return (transport->doh.mode);
131 }
132 
133 dns_transport_t *
dns_transport_new(const dns_name_t * name,dns_transport_type_t type,dns_transport_list_t * list)134 dns_transport_new(const dns_name_t *name, dns_transport_type_t type,
135 		  dns_transport_list_t *list) {
136 	dns_transport_t *transport = isc_mem_get(list->mctx,
137 						 sizeof(*transport));
138 	*transport = (dns_transport_t){ .type = type };
139 	isc_refcount_init(&transport->references, 1);
140 	isc_mem_attach(list->mctx, &transport->mctx);
141 	transport->magic = TRANSPORT_MAGIC;
142 
143 	list_add(list, name, type, transport);
144 
145 	return (transport);
146 }
147 
148 void
dns_transport_set_certfile(dns_transport_t * transport,const char * certfile)149 dns_transport_set_certfile(dns_transport_t *transport, const char *certfile) {
150 	REQUIRE(VALID_TRANSPORT(transport));
151 	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
152 		transport->type == DNS_TRANSPORT_HTTP);
153 
154 	if (certfile != NULL) {
155 		transport->tls.certfile = isc_mem_strdup(transport->mctx,
156 							 certfile);
157 	}
158 }
159 
160 void
dns_transport_set_keyfile(dns_transport_t * transport,const char * keyfile)161 dns_transport_set_keyfile(dns_transport_t *transport, const char *keyfile) {
162 	REQUIRE(VALID_TRANSPORT(transport));
163 	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
164 		transport->type == DNS_TRANSPORT_HTTP);
165 
166 	if (keyfile != NULL) {
167 		transport->tls.keyfile = isc_mem_strdup(transport->mctx,
168 							keyfile);
169 	}
170 }
171 
172 void
dns_transport_set_cafile(dns_transport_t * transport,const char * cafile)173 dns_transport_set_cafile(dns_transport_t *transport, const char *cafile) {
174 	REQUIRE(VALID_TRANSPORT(transport));
175 	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
176 		transport->type == DNS_TRANSPORT_HTTP);
177 
178 	if (cafile != NULL) {
179 		transport->tls.cafile = isc_mem_strdup(transport->mctx, cafile);
180 	}
181 }
182 
183 void
dns_transport_set_hostname(dns_transport_t * transport,const char * hostname)184 dns_transport_set_hostname(dns_transport_t *transport, const char *hostname) {
185 	REQUIRE(VALID_TRANSPORT(transport));
186 	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
187 		transport->type == DNS_TRANSPORT_HTTP);
188 
189 	if (hostname != NULL) {
190 		transport->tls.hostname = isc_mem_strdup(transport->mctx,
191 							 hostname);
192 	}
193 }
194 
195 void
dns_transport_set_endpoint(dns_transport_t * transport,const char * endpoint)196 dns_transport_set_endpoint(dns_transport_t *transport, const char *endpoint) {
197 	REQUIRE(VALID_TRANSPORT(transport));
198 	REQUIRE(transport->type == DNS_TRANSPORT_HTTP);
199 
200 	if (endpoint != NULL) {
201 		transport->doh.endpoint = isc_mem_strdup(transport->mctx,
202 							 endpoint);
203 	}
204 }
205 
206 void
dns_transport_set_mode(dns_transport_t * transport,dns_http_mode_t mode)207 dns_transport_set_mode(dns_transport_t *transport, dns_http_mode_t mode) {
208 	REQUIRE(VALID_TRANSPORT(transport));
209 	REQUIRE(transport->type == DNS_TRANSPORT_HTTP);
210 
211 	transport->doh.mode = mode;
212 }
213 
214 static void
transport_destroy(dns_transport_t * transport)215 transport_destroy(dns_transport_t *transport) {
216 	isc_refcount_destroy(&transport->references);
217 	transport->magic = 0;
218 
219 	if (transport->doh.endpoint != NULL) {
220 		isc_mem_free(transport->mctx, transport->doh.endpoint);
221 	}
222 	if (transport->tls.hostname != NULL) {
223 		isc_mem_free(transport->mctx, transport->tls.hostname);
224 	}
225 	if (transport->tls.cafile != NULL) {
226 		isc_mem_free(transport->mctx, transport->tls.cafile);
227 	}
228 	if (transport->tls.keyfile != NULL) {
229 		isc_mem_free(transport->mctx, transport->tls.keyfile);
230 	}
231 	if (transport->tls.certfile != NULL) {
232 		isc_mem_free(transport->mctx, transport->tls.certfile);
233 	}
234 
235 	isc_mem_putanddetach(&transport->mctx, transport, sizeof(*transport));
236 }
237 
238 void
dns_transport_attach(dns_transport_t * source,dns_transport_t ** targetp)239 dns_transport_attach(dns_transport_t *source, dns_transport_t **targetp) {
240 	REQUIRE(source != NULL);
241 	REQUIRE(targetp != NULL && *targetp == NULL);
242 
243 	isc_refcount_increment(&source->references);
244 
245 	*targetp = source;
246 }
247 
248 void
dns_transport_detach(dns_transport_t ** transportp)249 dns_transport_detach(dns_transport_t **transportp) {
250 	dns_transport_t *transport = NULL;
251 
252 	REQUIRE(transportp != NULL);
253 	REQUIRE(VALID_TRANSPORT(*transportp));
254 
255 	transport = *transportp;
256 	*transportp = NULL;
257 
258 	if (isc_refcount_decrement(&transport->references) == 1) {
259 		transport_destroy(transport);
260 	}
261 }
262 
263 dns_transport_t *
dns_transport_find(const dns_transport_type_t type,const dns_name_t * name,dns_transport_list_t * list)264 dns_transport_find(const dns_transport_type_t type, const dns_name_t *name,
265 		   dns_transport_list_t *list) {
266 	isc_result_t result;
267 	dns_transport_t *transport = NULL;
268 	dns_rbt_t *rbt = NULL;
269 
270 	REQUIRE(VALID_TRANSPORT_LIST(list));
271 	REQUIRE(list->transports[type] != NULL);
272 
273 	rbt = list->transports[type];
274 
275 	RWLOCK(&list->lock, isc_rwlocktype_read);
276 	result = dns_rbt_findname(rbt, name, 0, NULL, (void *)&transport);
277 	if (result == ISC_R_SUCCESS) {
278 		isc_refcount_increment(&transport->references);
279 	}
280 	RWUNLOCK(&list->lock, isc_rwlocktype_read);
281 
282 	return (transport);
283 }
284 
285 dns_transport_list_t *
dns_transport_list_new(isc_mem_t * mctx)286 dns_transport_list_new(isc_mem_t *mctx) {
287 	dns_transport_list_t *list = isc_mem_get(mctx, sizeof(*list));
288 
289 	*list = (dns_transport_list_t){ 0 };
290 
291 	isc_rwlock_init(&list->lock, 0, 0);
292 
293 	isc_mem_attach(mctx, &list->mctx);
294 	isc_refcount_init(&list->references, 1);
295 
296 	list->magic = TRANSPORT_LIST_MAGIC;
297 
298 	for (size_t type = 0; type < DNS_TRANSPORT_COUNT; type++) {
299 		isc_result_t result;
300 		result = dns_rbt_create(list->mctx, free_dns_transport, NULL,
301 					&list->transports[type]);
302 		RUNTIME_CHECK(result == ISC_R_SUCCESS);
303 	}
304 
305 	return (list);
306 }
307 
308 void
dns_transport_list_attach(dns_transport_list_t * source,dns_transport_list_t ** targetp)309 dns_transport_list_attach(dns_transport_list_t *source,
310 			  dns_transport_list_t **targetp) {
311 	REQUIRE(VALID_TRANSPORT_LIST(source));
312 	REQUIRE(targetp != NULL && *targetp == NULL);
313 
314 	isc_refcount_increment(&source->references);
315 
316 	*targetp = source;
317 }
318 
319 static void
transport_list_destroy(dns_transport_list_t * list)320 transport_list_destroy(dns_transport_list_t *list) {
321 	isc_refcount_destroy(&list->references);
322 	list->magic = 0;
323 
324 	for (size_t type = 0; type < DNS_TRANSPORT_COUNT; type++) {
325 		if (list->transports[type] != NULL) {
326 			dns_rbt_destroy(&list->transports[type]);
327 		}
328 	}
329 	isc_rwlock_destroy(&list->lock);
330 	isc_mem_putanddetach(&list->mctx, list, sizeof(*list));
331 }
332 
333 void
dns_transport_list_detach(dns_transport_list_t ** listp)334 dns_transport_list_detach(dns_transport_list_t **listp) {
335 	dns_transport_list_t *list = NULL;
336 
337 	REQUIRE(listp != NULL);
338 	REQUIRE(VALID_TRANSPORT_LIST(*listp));
339 
340 	list = *listp;
341 	*listp = NULL;
342 
343 	if (isc_refcount_decrement(&list->references) == 1) {
344 		transport_list_destroy(list);
345 	}
346 }
347