1 // SPDX-License-Identifier: LGPL-2.1+
2 /*
3  * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  * Copyright (C) 2008-2012 Pablo Neira Ayuso <pablo@netfilter.org>.
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <errno.h>
10 #include <linux/genetlink.h>
11 #include <linux/if_link.h>
12 #include <linux/netlink.h>
13 #include <linux/rtnetlink.h>
14 #include <netinet/in.h>
15 #include <stdbool.h>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19 #include <sys/socket.h>
20 #include <time.h>
21 #include <unistd.h>
22 #include <fcntl.h>
23 #include <assert.h>
24 
25 #include "wireguard.h"
26 
27 /* wireguard.h netlink uapi: */
28 
/* Generic netlink family name; also compared against IFLA_INFO_KIND below to
 * recognise WireGuard interfaces. */
#define WG_GENL_NAME "wireguard"
/* Generic netlink protocol version passed when opening the family socket. */
#define WG_GENL_VERSION 1
31 
/* Generic netlink commands of the WireGuard family (placed in genlmsghdr.cmd). */
enum wg_cmd {
	WG_CMD_GET_DEVICE,
	WG_CMD_SET_DEVICE,
	__WG_CMD_MAX	/* sentinel: one past the last command */
};
37 
/* Bit flags for the WGDEVICE_A_FLAGS attribute. */
enum wgdevice_flag {
	WGDEVICE_F_REPLACE_PEERS = 1U << 0
};
/* Netlink attribute types used at the device level of a WireGuard message. */
enum wgdevice_attribute {
	WGDEVICE_A_UNSPEC,
	WGDEVICE_A_IFINDEX,
	WGDEVICE_A_IFNAME,
	WGDEVICE_A_PRIVATE_KEY,
	WGDEVICE_A_PUBLIC_KEY,
	WGDEVICE_A_FLAGS,
	WGDEVICE_A_LISTEN_PORT,
	WGDEVICE_A_FWMARK,
	WGDEVICE_A_PEERS,
	__WGDEVICE_A_LAST	/* sentinel: one past the last attribute */
};
53 
/* Bit flags for the WGPEER_A_FLAGS attribute. */
enum wgpeer_flag {
	WGPEER_F_REMOVE_ME = 1U << 0,
	WGPEER_F_REPLACE_ALLOWEDIPS = 1U << 1
};
/* Netlink attribute types used inside a peer entry. */
enum wgpeer_attribute {
	WGPEER_A_UNSPEC,
	WGPEER_A_PUBLIC_KEY,
	WGPEER_A_PRESHARED_KEY,
	WGPEER_A_FLAGS,
	WGPEER_A_ENDPOINT,
	WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
	WGPEER_A_LAST_HANDSHAKE_TIME,
	WGPEER_A_RX_BYTES,
	WGPEER_A_TX_BYTES,
	WGPEER_A_ALLOWEDIPS,
	WGPEER_A_PROTOCOL_VERSION,
	__WGPEER_A_LAST	/* sentinel: one past the last attribute */
};
72 
/* Netlink attribute types used inside an allowed-IP entry. */
enum wgallowedip_attribute {
	WGALLOWEDIP_A_UNSPEC,
	WGALLOWEDIP_A_FAMILY,
	WGALLOWEDIP_A_IPADDR,
	WGALLOWEDIP_A_CIDR_MASK,
	__WGALLOWEDIP_A_LAST	/* sentinel: one past the last attribute */
};
80 
81 /* libmnl mini library: */
82 
/* Port id passed to mnl_socket_bind() by the callers in this file. */
#define MNL_SOCKET_AUTOPID 0
/* Alignment, in bytes, applied to netlink headers and attributes. */
#define MNL_ALIGNTO 4
/* Rounds len up to the next multiple of MNL_ALIGNTO. */
#define MNL_ALIGN(len) (((len)+MNL_ALIGNTO-1) & ~(MNL_ALIGNTO-1))
/* Aligned sizes of the netlink message header and the attribute header. */
#define MNL_NLMSG_HDRLEN MNL_ALIGN(sizeof(struct nlmsghdr))
#define MNL_ATTR_HDRLEN MNL_ALIGN(sizeof(struct nlattr))
88 
/* Payload kinds understood by mnl_attr_validate(); also used as the index into
 * the mnl_attr_data_type_len table. */
enum mnl_attr_data_type {
	MNL_TYPE_UNSPEC,
	MNL_TYPE_U8,
	MNL_TYPE_U16,
	MNL_TYPE_U32,
	MNL_TYPE_U64,
	MNL_TYPE_STRING,
	MNL_TYPE_FLAG,
	MNL_TYPE_MSECS,
	MNL_TYPE_NESTED,
	MNL_TYPE_NESTED_COMPAT,
	MNL_TYPE_NUL_STRING,
	MNL_TYPE_BINARY,
	MNL_TYPE_MAX,	/* sentinel: number of types, not a valid type itself */
};
104 
/* Iterates `attr` over the attributes of message `nlh`, starting `offset`
 * bytes (aligned) into the payload and stopping at the first attribute that
 * fails mnl_attr_ok() against the remaining bytes. */
#define mnl_attr_for_each(attr, nlh, offset) \
	for ((attr) = mnl_nlmsg_get_payload_offset((nlh), (offset)); \
	     mnl_attr_ok((attr), (char *)mnl_nlmsg_get_payload_tail(nlh) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Iterates `attr` over the attributes contained in the payload of the nested
 * attribute `nest`. */
#define mnl_attr_for_each_nested(attr, nest) \
	for ((attr) = mnl_attr_get_payload(nest); \
	     mnl_attr_ok((attr), (char *)mnl_attr_get_payload(nest) + mnl_attr_get_payload_len(nest) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Iterates over the attributes in a raw buffer of `payload_size` bytes.
 * NOTE: the macro takes no iterator argument; it expects a variable named
 * `attr` to exist in the scope where it is expanded. */
#define mnl_attr_for_each_payload(payload, payload_size) \
	for ((attr) = (payload); \
	     mnl_attr_ok((attr), (char *)(payload) + payload_size - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))
119 
/* Callback return codes. The parsing loops in this file stop on any value
 * <= MNL_CB_STOP and continue on positive values. */
#define MNL_CB_ERROR	-1
#define MNL_CB_STOP	0
#define MNL_CB_OK	1

/* Per-attribute callback and per-message callback signatures. */
typedef int (*mnl_attr_cb_t)(const struct nlattr *attr, void *data);
typedef int (*mnl_cb_t)(const struct nlmsghdr *nlh, void *data);

/* Element count of a true array (not valid for pointers). */
#ifndef MNL_ARRAY_SIZE
#define MNL_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0]))
#endif
130 
/* Returns the buffer size used for netlink I/O in this file: the system page
 * size as reported by sysconf(), capped at 8192 bytes. A non-zero result is
 * cached in a function-local static and reused on later calls. */
static size_t mnl_ideal_socket_buffer_size(void)
{
	static size_t cached = 0;

	if (!cached) {
		const size_t page = (size_t)sysconf(_SC_PAGESIZE);

		cached = page > 8192 ? 8192 : page;
	}
	return cached;
}
142 
mnl_nlmsg_size(size_t len)143 static size_t mnl_nlmsg_size(size_t len)
144 {
145 	return len + MNL_NLMSG_HDRLEN;
146 }
147 
/* Zeroes an aligned netlink header at the start of `buf`, sets nlmsg_len to
 * the header length and returns `buf` typed as the header. */
static struct nlmsghdr *mnl_nlmsg_put_header(void *buf)
{
	const int hdr_len = MNL_ALIGN(sizeof(struct nlmsghdr));
	struct nlmsghdr *hdr = memset(buf, 0, hdr_len);

	hdr->nlmsg_len = hdr_len;
	return hdr;
}
157 
/* Appends a zero-filled region of `size` bytes (rounded up to the alignment)
 * at the current end of the message, grows nlmsg_len accordingly and returns
 * a pointer to the start of that region. */
static void *mnl_nlmsg_put_extra_header(struct nlmsghdr *nlh, size_t size)
{
	const size_t padded = MNL_ALIGN(size);
	char *extra = (char *)nlh + nlh->nlmsg_len;

	memset(extra, 0, padded);
	nlh->nlmsg_len += padded;
	return extra;
}
166 
mnl_nlmsg_get_payload(const struct nlmsghdr * nlh)167 static void *mnl_nlmsg_get_payload(const struct nlmsghdr *nlh)
168 {
169 	return (void *)nlh + MNL_NLMSG_HDRLEN;
170 }
171 
mnl_nlmsg_get_payload_offset(const struct nlmsghdr * nlh,size_t offset)172 static void *mnl_nlmsg_get_payload_offset(const struct nlmsghdr *nlh, size_t offset)
173 {
174 	return (void *)nlh + MNL_NLMSG_HDRLEN + MNL_ALIGN(offset);
175 }
176 
/*
 * Checks that a complete netlink message is available at `nlh`.
 *
 * nlh: candidate message header.
 * len: number of bytes remaining in the buffer starting at `nlh`.
 *
 * Returns true when at least a header fits in `len`, the header's own
 * nlmsg_len covers at least a header, and nlmsg_len does not exceed `len`.
 */
static bool mnl_nlmsg_ok(const struct nlmsghdr *nlh, int len)
{
	if (len < (int)sizeof(struct nlmsghdr))
		return false;
	if (nlh->nlmsg_len < sizeof(struct nlmsghdr))
		return false;
	/* `len` is positive here, so compare as unsigned. Casting nlmsg_len to
	 * int instead would turn lengths above INT_MAX into negative numbers
	 * that wrongly pass the bound check. */
	return nlh->nlmsg_len <= (unsigned int)len;
}
183 
/* Advances to the message following `nlh`: subtracts the aligned length of
 * `nlh` from `*len` and returns the address just past it. The caller is
 * expected to validate the result (the loop in this file uses mnl_nlmsg_ok). */
static struct nlmsghdr *mnl_nlmsg_next(const struct nlmsghdr *nlh, int *len)
{
	const unsigned int step = MNL_ALIGN(nlh->nlmsg_len);

	*len -= step;
	return (struct nlmsghdr *)((char *)nlh + step);
}
189 
/* Returns the address just past the message, i.e. `nlh` plus its aligned
 * nlmsg_len. New attributes are written at this position. */
static void *mnl_nlmsg_get_payload_tail(const struct nlmsghdr *nlh)
{
	char *base = (char *)nlh;

	return base + MNL_ALIGN(nlh->nlmsg_len);
}
194 
/* Sequence check: when either the message's sequence number or the expected
 * `seq` is zero the check is skipped (true); otherwise they must be equal. */
static bool mnl_nlmsg_seq_ok(const struct nlmsghdr *nlh, unsigned int seq)
{
	if (!nlh->nlmsg_seq || !seq)
		return true;
	return nlh->nlmsg_seq == seq;
}
199 
/* Port id check: when either the message's nlmsg_pid or the expected `portid`
 * is zero the check is skipped (true); otherwise they must be equal. */
static bool mnl_nlmsg_portid_ok(const struct nlmsghdr *nlh, unsigned int portid)
{
	if (!nlh->nlmsg_pid || !portid)
		return true;
	return nlh->nlmsg_pid == portid;
}
204 
/* Returns the attribute type with the bits outside NLA_TYPE_MASK removed. */
static uint16_t mnl_attr_get_type(const struct nlattr *attr)
{
	const uint16_t raw = attr->nla_type;

	return raw & NLA_TYPE_MASK;
}
209 
mnl_attr_get_payload_len(const struct nlattr * attr)210 static uint16_t mnl_attr_get_payload_len(const struct nlattr *attr)
211 {
212 	return attr->nla_len - MNL_ATTR_HDRLEN;
213 }
214 
mnl_attr_get_payload(const struct nlattr * attr)215 static void *mnl_attr_get_payload(const struct nlattr *attr)
216 {
217 	return (void *)attr + MNL_ATTR_HDRLEN;
218 }
219 
/* Checks that a complete attribute is available at `attr` within the `len`
 * remaining bytes: a header must fit, nla_len must cover at least a header,
 * and nla_len must not exceed `len`. */
static bool mnl_attr_ok(const struct nlattr *attr, int len)
{
	if (len < (int)sizeof(struct nlattr))
		return false;
	if (attr->nla_len < sizeof(struct nlattr))
		return false;
	return (int)attr->nla_len <= len;
}
226 
/* Returns the address of the attribute following `attr`, i.e. `attr` plus its
 * aligned nla_len. The result is not validated here. */
static struct nlattr *mnl_attr_next(const struct nlattr *attr)
{
	char *base = (char *)attr;

	return (struct nlattr *)(base + MNL_ALIGN(attr->nla_len));
}
231 
/* Returns 1 when the attribute's type does not exceed `max`; otherwise sets
 * errno to EOPNOTSUPP and returns -1. */
static int mnl_attr_type_valid(const struct nlattr *attr, uint16_t max)
{
	if (mnl_attr_get_type(attr) <= max)
		return 1;

	errno = EOPNOTSUPP;
	return -1;
}
240 
__mnl_attr_validate(const struct nlattr * attr,enum mnl_attr_data_type type,size_t exp_len)241 static int __mnl_attr_validate(const struct nlattr *attr,
242 			       enum mnl_attr_data_type type, size_t exp_len)
243 {
244 	uint16_t attr_len = mnl_attr_get_payload_len(attr);
245 	const char *attr_data = mnl_attr_get_payload(attr);
246 
247 	if (attr_len < exp_len) {
248 		errno = ERANGE;
249 		return -1;
250 	}
251 	switch(type) {
252 	case MNL_TYPE_FLAG:
253 		if (attr_len > 0) {
254 			errno = ERANGE;
255 			return -1;
256 		}
257 		break;
258 	case MNL_TYPE_NUL_STRING:
259 		if (attr_len == 0) {
260 			errno = ERANGE;
261 			return -1;
262 		}
263 		if (attr_data[attr_len-1] != '\0') {
264 			errno = EINVAL;
265 			return -1;
266 		}
267 		break;
268 	case MNL_TYPE_STRING:
269 		if (attr_len == 0) {
270 			errno = ERANGE;
271 			return -1;
272 		}
273 		break;
274 	case MNL_TYPE_NESTED:
275 
276 		if (attr_len == 0)
277 			break;
278 
279 		if (attr_len < MNL_ATTR_HDRLEN) {
280 			errno = ERANGE;
281 			return -1;
282 		}
283 		break;
284 	default:
285 
286 		break;
287 	}
288 	if (exp_len && attr_len > exp_len) {
289 		errno = ERANGE;
290 		return -1;
291 	}
292 	return 0;
293 }
294 
/* Exact payload length required for each fixed-size type, indexed by
 * enum mnl_attr_data_type. Types not listed are zero, which
 * __mnl_attr_validate() treats as "no exact-length check". */
static const size_t mnl_attr_data_type_len[MNL_TYPE_MAX] = {
	[MNL_TYPE_U8]		= sizeof(uint8_t),
	[MNL_TYPE_U16]		= sizeof(uint16_t),
	[MNL_TYPE_U32]		= sizeof(uint32_t),
	[MNL_TYPE_U64]		= sizeof(uint64_t),
	[MNL_TYPE_MSECS]	= sizeof(uint64_t),
};
302 
mnl_attr_validate(const struct nlattr * attr,enum mnl_attr_data_type type)303 static int mnl_attr_validate(const struct nlattr *attr, enum mnl_attr_data_type type)
304 {
305 	int exp_len;
306 
307 	if (type >= MNL_TYPE_MAX) {
308 		errno = EINVAL;
309 		return -1;
310 	}
311 	exp_len = mnl_attr_data_type_len[type];
312 	return __mnl_attr_validate(attr, type, exp_len);
313 }
314 
/* Calls `cb` for every attribute of `nlh` found `offset` bytes into the
 * payload. Stops early and returns the callback's value as soon as it is
 * <= MNL_CB_STOP; otherwise returns the last callback value, or MNL_CB_OK
 * when there were no attributes. */
static int mnl_attr_parse(const struct nlmsghdr *nlh, unsigned int offset,
			  mnl_attr_cb_t cb, void *data)
{
	const struct nlattr *cur;
	int rc = MNL_CB_OK;

	mnl_attr_for_each(cur, nlh, offset) {
		rc = cb(cur, data);
		if (rc <= MNL_CB_STOP)
			break;
	}
	return rc;
}
326 
/* Calls `cb` for every attribute inside the payload of `nested`. Stops early
 * and returns the callback's value as soon as it is <= MNL_CB_STOP; otherwise
 * returns the last callback value, or MNL_CB_OK for an empty nest. */
static int mnl_attr_parse_nested(const struct nlattr *nested, mnl_attr_cb_t cb,
				 void *data)
{
	const struct nlattr *cur;
	int rc = MNL_CB_OK;

	mnl_attr_for_each_nested(cur, nested) {
		rc = cb(cur, data);
		if (rc <= MNL_CB_STOP)
			break;
	}
	return rc;
}
338 
/* Reads the first payload byte of `attr`. */
static uint8_t mnl_attr_get_u8(const struct nlattr *attr)
{
	const uint8_t *payload = mnl_attr_get_payload(attr);

	return *payload;
}
343 
/* Reads a host-order 16-bit value from the start of the payload of `attr`. */
static uint16_t mnl_attr_get_u16(const struct nlattr *attr)
{
	uint16_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}
348 
/* Reads a host-order 32-bit value from the start of the payload of `attr`. */
static uint32_t mnl_attr_get_u32(const struct nlattr *attr)
{
	uint32_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}
353 
/* Reads a host-order 64-bit value from the start of the payload of `attr`.
 * The bytes are copied out rather than dereferenced in place. */
static uint64_t mnl_attr_get_u64(const struct nlattr *attr)
{
	const void *payload = mnl_attr_get_payload(attr);
	uint64_t value;

	memcpy(&value, payload, sizeof(value));
	return value;
}
360 
/* Returns the payload of `attr` viewed as a character string. No check for a
 * terminating NUL is made here. */
static const char *mnl_attr_get_str(const struct nlattr *attr)
{
	const char *str = mnl_attr_get_payload(attr);

	return str;
}
365 
/* Appends an attribute of `type` holding `len` bytes copied from `data` at
 * the tail of `nlh`, zero-fills the alignment padding after the payload and
 * grows nlmsg_len by the aligned attribute size. No bounds check is done
 * here; see mnl_attr_put_check() for the checked variant. */
static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len,
			 const void *data)
{
	struct nlattr *attr = mnl_nlmsg_get_payload_tail(nlh);
	char *payload = mnl_attr_get_payload(attr);
	const uint16_t attr_len = MNL_ALIGN(sizeof(struct nlattr)) + len;
	const int padding = MNL_ALIGN(len) - len;

	attr->nla_type = type;
	attr->nla_len = attr_len;
	memcpy(payload, data, len);
	if (padding > 0)
		memset(payload + len, 0, padding);
	nlh->nlmsg_len += MNL_ALIGN(attr_len);
}
381 
/* Appends a 16-bit attribute of `type` with value `data` (unchecked). */
static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data)
{
	const uint16_t value = data;

	mnl_attr_put(nlh, type, sizeof(value), &value);
}
386 
/* Appends a 32-bit attribute of `type` with value `data` (unchecked). */
static void mnl_attr_put_u32(struct nlmsghdr *nlh, uint16_t type, uint32_t data)
{
	const uint32_t value = data;

	mnl_attr_put(nlh, type, sizeof(value), &value);
}
391 
/* Appends the string `data`, including its terminating NUL, as an attribute
 * of `type` (unchecked). */
static void mnl_attr_put_strz(struct nlmsghdr *nlh, uint16_t type, const char *data)
{
	const size_t size_with_nul = strlen(data) + 1;

	mnl_attr_put(nlh, type, size_with_nul, data);
}
396 
/* Opens a nested attribute of `type` at the tail of `nlh`: writes the type
 * with NLA_F_NESTED set, reserves room for the attribute header and returns
 * the nest so mnl_attr_nest_end() can later fill in its length. */
static struct nlattr *mnl_attr_nest_start(struct nlmsghdr *nlh, uint16_t type)
{
	struct nlattr *nest = mnl_nlmsg_get_payload_tail(nlh);

	nlh->nlmsg_len += MNL_ALIGN(sizeof(struct nlattr));
	nest->nla_type = NLA_F_NESTED | type;
	return nest;
}
405 
mnl_attr_put_check(struct nlmsghdr * nlh,size_t buflen,uint16_t type,size_t len,const void * data)406 static bool mnl_attr_put_check(struct nlmsghdr *nlh, size_t buflen,
407 			       uint16_t type, size_t len, const void *data)
408 {
409 	if (nlh->nlmsg_len + MNL_ATTR_HDRLEN + MNL_ALIGN(len) > buflen)
410 		return false;
411 	mnl_attr_put(nlh, type, len, data);
412 	return true;
413 }
414 
/* Bounds-checked append of an 8-bit attribute; see mnl_attr_put_check(). */
static bool mnl_attr_put_u8_check(struct nlmsghdr *nlh, size_t buflen,
				  uint16_t type, uint8_t data)
{
	const uint8_t value = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(value), &value);
}
420 
/* Bounds-checked append of a 16-bit attribute; see mnl_attr_put_check(). */
static bool mnl_attr_put_u16_check(struct nlmsghdr *nlh, size_t buflen,
				   uint16_t type, uint16_t data)
{
	const uint16_t value = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(value), &value);
}
426 
/* Bounds-checked append of a 32-bit attribute; see mnl_attr_put_check(). */
static bool mnl_attr_put_u32_check(struct nlmsghdr *nlh, size_t buflen,
				   uint16_t type, uint32_t data)
{
	const uint32_t value = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(value), &value);
}
432 
mnl_attr_nest_start_check(struct nlmsghdr * nlh,size_t buflen,uint16_t type)433 static struct nlattr *mnl_attr_nest_start_check(struct nlmsghdr *nlh, size_t buflen,
434 						uint16_t type)
435 {
436 	if (nlh->nlmsg_len + MNL_ATTR_HDRLEN > buflen)
437 		return NULL;
438 	return mnl_attr_nest_start(nlh, type);
439 }
440 
/* Closes a nest opened at `start`: its nla_len becomes the distance from
 * `start` to the current tail of the message. */
static void mnl_attr_nest_end(struct nlmsghdr *nlh, struct nlattr *start)
{
	const char *tail = mnl_nlmsg_get_payload_tail(nlh);

	start->nla_len = tail - (const char *)start;
}
445 
/* Discards a nest opened at `start` together with everything appended after
 * it, by shrinking nlmsg_len back by the distance from `start` to the tail. */
static void mnl_attr_nest_cancel(struct nlmsghdr *nlh, struct nlattr *start)
{
	const char *tail = mnl_nlmsg_get_payload_tail(nlh);

	nlh->nlmsg_len -= tail - (const char *)start;
}
450 
mnl_cb_noop(const struct nlmsghdr * nlh,void * data)451 static int mnl_cb_noop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
452 {
453 	return MNL_CB_OK;
454 }
455 
mnl_cb_error(const struct nlmsghdr * nlh,void * data)456 static int mnl_cb_error(const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
457 {
458 	const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
459 
460 	if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) {
461 		errno = EBADMSG;
462 		return MNL_CB_ERROR;
463 	}
464 
465 	if (err->error < 0)
466 		errno = -err->error;
467 	else
468 		errno = err->error;
469 
470 	return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
471 }
472 
mnl_cb_stop(const struct nlmsghdr * nlh,void * data)473 static int mnl_cb_stop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
474 {
475 	return MNL_CB_STOP;
476 }
477 
/* Default handlers for netlink control message types (those below
 * NLMSG_MIN_TYPE). __mnl_cb_run() falls back to this table for control types
 * that lie beyond the caller-supplied handler array. */
static const mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = {
	[NLMSG_NOOP]	= mnl_cb_noop,
	[NLMSG_ERROR]	= mnl_cb_error,
	[NLMSG_DONE]	= mnl_cb_stop,
	[NLMSG_OVERRUN]	= mnl_cb_noop,
};
484 
/*
 * Walks every netlink message in `buf` (`numbytes` long) and dispatches it.
 *
 * Per message, before dispatching:
 *   - a port id mismatch returns -1 with errno = ESRCH,
 *   - a sequence mismatch returns -1 with errno = EPROTO,
 *   - NLM_F_DUMP_INTR set returns -1 with errno = EINTR.
 *
 * Dispatch: data messages (type >= NLMSG_MIN_TYPE) go to `cb_data`; control
 * messages with a type below `cb_ctl_array_len` go to `cb_ctl_array` (when
 * the array and the entry are non-NULL); remaining control messages go to
 * default_cb_array. Messages with no handler are skipped.
 *
 * Returns the first handler result that is <= MNL_CB_STOP, otherwise the last
 * handler result (MNL_CB_OK if no handler ran).
 */
static int __mnl_cb_run(const void *buf, size_t numbytes,
			unsigned int seq, unsigned int portid,
			mnl_cb_t cb_data, void *data,
			const mnl_cb_t *cb_ctl_array,
			unsigned int cb_ctl_array_len)
{
	const struct nlmsghdr *nlh = buf;
	int remaining = numbytes;
	int ret = MNL_CB_OK;

	for (; mnl_nlmsg_ok(nlh, remaining); nlh = mnl_nlmsg_next(nlh, &remaining)) {
		mnl_cb_t handler;

		if (!mnl_nlmsg_portid_ok(nlh, portid)) {
			errno = ESRCH;
			return -1;
		}
		if (!mnl_nlmsg_seq_ok(nlh, seq)) {
			errno = EPROTO;
			return -1;
		}
		if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) {
			errno = EINTR;
			return -1;
		}

		/* Pick the handler responsible for this message type. */
		if (nlh->nlmsg_type >= NLMSG_MIN_TYPE)
			handler = cb_data;
		else if (nlh->nlmsg_type < cb_ctl_array_len)
			handler = cb_ctl_array ? cb_ctl_array[nlh->nlmsg_type] : NULL;
		else
			handler = default_cb_array[nlh->nlmsg_type];

		if (!handler)
			continue;

		ret = handler(nlh, data);
		if (ret <= MNL_CB_STOP)
			return ret;
	}
	return ret;
}
533 
/* Runs the message loop with a caller-supplied table of control-message
 * handlers; forwards all arguments to __mnl_cb_run() unchanged. */
static int mnl_cb_run2(const void *buf, size_t numbytes, unsigned int seq,
		       unsigned int portid, mnl_cb_t cb_data, void *data,
		       const mnl_cb_t *cb_ctl_array, unsigned int cb_ctl_array_len)
{
	const int result = __mnl_cb_run(buf, numbytes, seq, portid,
					cb_data, data,
					cb_ctl_array, cb_ctl_array_len);

	return result;
}
541 
/* Runs the message loop with no caller-supplied control handlers, so control
 * messages are handled by default_cb_array. */
static int mnl_cb_run(const void *buf, size_t numbytes, unsigned int seq,
		      unsigned int portid, mnl_cb_t cb_data, void *data)
{
	const mnl_cb_t *no_ctl_handlers = NULL;

	return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data,
			    no_ctl_handlers, 0);
}
547 
/* A netlink socket together with its local address. */
struct mnl_socket {
	int 			fd;	/* descriptor returned by socket() */
	struct sockaddr_nl	addr;	/* local address, filled in by mnl_socket_bind() */
};
552 
mnl_socket_get_portid(const struct mnl_socket * nl)553 static unsigned int mnl_socket_get_portid(const struct mnl_socket *nl)
554 {
555 	return nl->addr.nl_pid;
556 }
557 
__mnl_socket_open(int bus,int flags)558 static struct mnl_socket *__mnl_socket_open(int bus, int flags)
559 {
560 	struct mnl_socket *nl;
561 
562 	nl = calloc(1, sizeof(struct mnl_socket));
563 	if (nl == NULL)
564 		return NULL;
565 
566 	nl->fd = socket(AF_NETLINK, SOCK_RAW | flags, bus);
567 	if (nl->fd == -1) {
568 		free(nl);
569 		return NULL;
570 	}
571 
572 	return nl;
573 }
574 
/* Opens a netlink socket on protocol `bus` with no extra socket-type flags. */
static struct mnl_socket *mnl_socket_open(int bus)
{
	const int no_flags = 0;

	return __mnl_socket_open(bus, no_flags);
}
579 
mnl_socket_bind(struct mnl_socket * nl,unsigned int groups,pid_t pid)580 static int mnl_socket_bind(struct mnl_socket *nl, unsigned int groups, pid_t pid)
581 {
582 	int ret;
583 	socklen_t addr_len;
584 
585 	nl->addr.nl_family = AF_NETLINK;
586 	nl->addr.nl_groups = groups;
587 	nl->addr.nl_pid = pid;
588 
589 	ret = bind(nl->fd, (struct sockaddr *) &nl->addr, sizeof (nl->addr));
590 	if (ret < 0)
591 		return ret;
592 
593 	addr_len = sizeof(nl->addr);
594 	ret = getsockname(nl->fd, (struct sockaddr *) &nl->addr, &addr_len);
595 	if (ret < 0)
596 		return ret;
597 
598 	if (addr_len != sizeof(nl->addr)) {
599 		errno = EINVAL;
600 		return -1;
601 	}
602 	if (nl->addr.nl_family != AF_NETLINK) {
603 		errno = EINVAL;
604 		return -1;
605 	}
606 	return 0;
607 }
608 
mnl_socket_sendto(const struct mnl_socket * nl,const void * buf,size_t len)609 static ssize_t mnl_socket_sendto(const struct mnl_socket *nl, const void *buf,
610 				 size_t len)
611 {
612 	static const struct sockaddr_nl snl = {
613 		.nl_family = AF_NETLINK
614 	};
615 	return sendto(nl->fd, buf, len, 0,
616 		      (struct sockaddr *) &snl, sizeof(snl));
617 }
618 
mnl_socket_recvfrom(const struct mnl_socket * nl,void * buf,size_t bufsiz)619 static ssize_t mnl_socket_recvfrom(const struct mnl_socket *nl, void *buf,
620 				   size_t bufsiz)
621 {
622 	ssize_t ret;
623 	struct sockaddr_nl addr;
624 	struct iovec iov = {
625 		.iov_base	= buf,
626 		.iov_len	= bufsiz,
627 	};
628 	struct msghdr msg = {
629 		.msg_name	= &addr,
630 		.msg_namelen	= sizeof(struct sockaddr_nl),
631 		.msg_iov	= &iov,
632 		.msg_iovlen	= 1,
633 		.msg_control	= NULL,
634 		.msg_controllen	= 0,
635 		.msg_flags	= 0,
636 	};
637 	ret = recvmsg(nl->fd, &msg, 0);
638 	if (ret == -1)
639 		return ret;
640 
641 	if (msg.msg_flags & MSG_TRUNC) {
642 		errno = ENOSPC;
643 		return -1;
644 	}
645 	if (msg.msg_namelen != sizeof(struct sockaddr_nl)) {
646 		errno = EINVAL;
647 		return -1;
648 	}
649 	return ret;
650 }
651 
mnl_socket_close(struct mnl_socket * nl)652 static int mnl_socket_close(struct mnl_socket *nl)
653 {
654 	int ret = close(nl->fd);
655 	free(nl);
656 	return ret;
657 }
658 
659 /* mnlg mini library: */
660 
/* State of one generic netlink conversation. */
struct mnlg_socket {
	struct mnl_socket *nl;	/* underlying netlink socket */
	char *buf;		/* buffer used for both building requests and receiving replies */
	uint16_t id;		/* family id resolved in mnlg_socket_open() */
	uint8_t version;	/* family version placed in each request */
	unsigned int seq;	/* sequence number of the most recent request */
	unsigned int portid;	/* local port id, used to validate replies */
};
669 
__mnlg_msg_prepare(struct mnlg_socket * nlg,uint8_t cmd,uint16_t flags,uint16_t id,uint8_t version)670 static struct nlmsghdr *__mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
671 					   uint16_t flags, uint16_t id,
672 					   uint8_t version)
673 {
674 	struct nlmsghdr *nlh;
675 	struct genlmsghdr *genl;
676 
677 	nlh = mnl_nlmsg_put_header(nlg->buf);
678 	nlh->nlmsg_type	= id;
679 	nlh->nlmsg_flags = flags;
680 	nlg->seq = time(NULL);
681 	nlh->nlmsg_seq = nlg->seq;
682 
683 	genl = mnl_nlmsg_put_extra_header(nlh, sizeof(struct genlmsghdr));
684 	genl->cmd = cmd;
685 	genl->version = version;
686 
687 	return nlh;
688 }
689 
mnlg_msg_prepare(struct mnlg_socket * nlg,uint8_t cmd,uint16_t flags)690 static struct nlmsghdr *mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
691 					 uint16_t flags)
692 {
693 	return __mnlg_msg_prepare(nlg, cmd, flags, nlg->id, nlg->version);
694 }
695 
mnlg_socket_send(struct mnlg_socket * nlg,const struct nlmsghdr * nlh)696 static int mnlg_socket_send(struct mnlg_socket *nlg, const struct nlmsghdr *nlh)
697 {
698 	return mnl_socket_sendto(nlg->nl, nlh, nlh->nlmsg_len);
699 }
700 
mnlg_cb_noop(const struct nlmsghdr * nlh,void * data)701 static int mnlg_cb_noop(const struct nlmsghdr *nlh, void *data)
702 {
703 	(void)nlh;
704 	(void)data;
705 	return MNL_CB_OK;
706 }
707 
mnlg_cb_error(const struct nlmsghdr * nlh,void * data)708 static int mnlg_cb_error(const struct nlmsghdr *nlh, void *data)
709 {
710 	const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
711 	(void)data;
712 
713 	if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) {
714 		errno = EBADMSG;
715 		return MNL_CB_ERROR;
716 	}
717 	/* Netlink subsystems returns the errno value with different signess */
718 	if (err->error < 0)
719 		errno = -err->error;
720 	else
721 		errno = err->error;
722 
723 	return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
724 }
725 
mnlg_cb_stop(const struct nlmsghdr * nlh,void * data)726 static int mnlg_cb_stop(const struct nlmsghdr *nlh, void *data)
727 {
728 	(void)data;
729 	if (nlh->nlmsg_flags & NLM_F_MULTI && nlh->nlmsg_len == mnl_nlmsg_size(sizeof(int))) {
730 		int error = *(int *)mnl_nlmsg_get_payload(nlh);
731 		/* Netlink subsystems returns the errno value with different signess */
732 		if (error < 0)
733 			errno = -error;
734 		else
735 			errno = error;
736 
737 		return error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
738 	}
739 	return MNL_CB_STOP;
740 }
741 
/* Control-message handlers passed to mnl_cb_run2() by mnlg_socket_recv_run(). */
static const mnl_cb_t mnlg_cb_array[] = {
	[NLMSG_NOOP]	= mnlg_cb_noop,
	[NLMSG_ERROR]	= mnlg_cb_error,
	[NLMSG_DONE]	= mnlg_cb_stop,
	[NLMSG_OVERRUN]	= mnlg_cb_noop,
};
748 
/* Receives datagrams into nlg->buf and dispatches their messages (data
 * messages to `data_cb`, control messages via mnlg_cb_array) until either the
 * receive or the dispatch yields a value <= 0, which is then returned. */
static int mnlg_socket_recv_run(struct mnlg_socket *nlg, mnl_cb_t data_cb, void *data)
{
	int rc;

	for (;;) {
		rc = mnl_socket_recvfrom(nlg->nl, nlg->buf,
					 mnl_ideal_socket_buffer_size());
		if (rc <= 0)
			break;

		rc = mnl_cb_run2(nlg->buf, rc, nlg->seq, nlg->portid,
				 data_cb, data, mnlg_cb_array, MNL_ARRAY_SIZE(mnlg_cb_array));
		if (rc <= 0)
			break;
	}

	return rc;
}
764 
get_family_id_attr_cb(const struct nlattr * attr,void * data)765 static int get_family_id_attr_cb(const struct nlattr *attr, void *data)
766 {
767 	const struct nlattr **tb = data;
768 	int type = mnl_attr_get_type(attr);
769 
770 	if (mnl_attr_type_valid(attr, CTRL_ATTR_MAX) < 0)
771 		return MNL_CB_ERROR;
772 
773 	if (type == CTRL_ATTR_FAMILY_ID &&
774 	    mnl_attr_validate(attr, MNL_TYPE_U16) < 0)
775 		return MNL_CB_ERROR;
776 	tb[type] = attr;
777 	return MNL_CB_OK;
778 }
779 
get_family_id_cb(const struct nlmsghdr * nlh,void * data)780 static int get_family_id_cb(const struct nlmsghdr *nlh, void *data)
781 {
782 	uint16_t *p_id = data;
783 	struct nlattr *tb[CTRL_ATTR_MAX + 1] = { 0 };
784 
785 	mnl_attr_parse(nlh, sizeof(struct genlmsghdr), get_family_id_attr_cb, tb);
786 	if (!tb[CTRL_ATTR_FAMILY_ID])
787 		return MNL_CB_ERROR;
788 	*p_id = mnl_attr_get_u16(tb[CTRL_ATTR_FAMILY_ID]);
789 	return MNL_CB_OK;
790 }
791 
/*
 * Opens a generic netlink socket and resolves the numeric id of the family
 * named `family_name` by sending CTRL_CMD_GETFAMILY to GENL_ID_CTRL.
 *
 * family_name: generic netlink family to look up.
 * version:     family version stored for later requests.
 *
 * Returns the new socket with errno cleared, or NULL with errno set to the
 * failure cause. Everything acquired so far is released on failure.
 */
static struct mnlg_socket *mnlg_socket_open(const char *family_name, uint8_t version)
{
	struct mnlg_socket *nlg;
	struct nlmsghdr *nlh;
	int err;

	nlg = malloc(sizeof(*nlg));
	if (!nlg)
		return NULL;
	nlg->id = 0;

	/* err holds a negative errno value for the cleanup path below. */
	err = -ENOMEM;
	nlg->buf = malloc(mnl_ideal_socket_buffer_size());
	if (!nlg->buf)
		goto err_buf_alloc;

	nlg->nl = mnl_socket_open(NETLINK_GENERIC);
	if (!nlg->nl) {
		err = -errno;
		goto err_mnl_socket_open;
	}

	if (mnl_socket_bind(nlg->nl, 0, MNL_SOCKET_AUTOPID) < 0) {
		err = -errno;
		goto err_mnl_socket_bind;
	}

	nlg->portid = mnl_socket_get_portid(nlg->nl);

	/* Ask the generic netlink controller for the family's id. */
	nlh = __mnlg_msg_prepare(nlg, CTRL_CMD_GETFAMILY,
				 NLM_F_REQUEST | NLM_F_ACK, GENL_ID_CTRL, 1);
	mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name);

	if (mnlg_socket_send(nlg, nlh) < 0) {
		err = -errno;
		goto err_mnlg_socket_send;
	}

	/* Clear errno first so a failure that leaves it untouched can be told
	 * apart and reported as ENOSYS. */
	errno = 0;
	if (mnlg_socket_recv_run(nlg, get_family_id_cb, &nlg->id) < 0) {
		/* ENOENT from the lookup is reported as EPROTONOSUPPORT. */
		errno = errno == ENOENT ? EPROTONOSUPPORT : errno;
		err = errno ? -errno : -ENOSYS;
		goto err_mnlg_socket_recv_run;
	}

	nlg->version = version;
	errno = 0;
	return nlg;

	/* Cleanup labels fall through so that each entry point releases
	 * everything acquired before it. */
err_mnlg_socket_recv_run:
err_mnlg_socket_send:
err_mnl_socket_bind:
	mnl_socket_close(nlg->nl);
err_mnl_socket_open:
	free(nlg->buf);
err_buf_alloc:
	free(nlg);
	/* Restore the saved cause; the cleanup calls may have changed errno. */
	errno = -err;
	return NULL;
}
852 
mnlg_socket_close(struct mnlg_socket * nlg)853 static void mnlg_socket_close(struct mnlg_socket *nlg)
854 {
855 	mnl_socket_close(nlg->nl);
856 	free(nlg->buf);
857 	free(nlg);
858 }
859 
860 /* wireguard-specific parts: */
861 
/* Growable buffer holding NUL-terminated strings back to back, with one extra
 * NUL byte after the last string. */
struct string_list {
	char *buffer;	/* heap storage, NULL while empty */
	size_t len;	/* bytes used by the strings, each including its NUL */
	size_t cap;	/* allocated size of buffer */
};

/* Appends a copy of `str` (with its NUL) to `list` and re-writes the extra
 * terminating NUL after it, growing the buffer when needed. Empty strings are
 * ignored. Returns 0 on success or -errno when realloc() fails, in which case
 * the list is left unchanged. */
static int string_list_add(struct string_list *list, const char *str)
{
	const size_t size = strlen(str) + 1;	/* bytes to copy, NUL included */

	if (size == 1)
		return 0;

	if (size >= list->cap - list->len) {
		/* Room for the string plus the list's final NUL. */
		const size_t needed = list->len + size + 1;
		size_t grown = list->cap * 2;
		char *resized;

		if (grown < needed)
			grown = needed;
		resized = realloc(list->buffer, grown);
		if (!resized)
			return -errno;
		list->buffer = resized;
		list->cap = grown;
	}

	memcpy(&list->buffer[list->len], str, size);
	list->len += size;
	list->buffer[list->len] = '\0';
	return 0;
}
892 
/* What read_devices_cb() collects about one link while parsing its message. */
struct interface {
	const char *name;	/* IFLA_IFNAME payload; points into the receive buffer */
	bool is_wireguard;	/* set when IFLA_INFO_KIND equals WG_GENL_NAME */
};
897 
parse_linkinfo(const struct nlattr * attr,void * data)898 static int parse_linkinfo(const struct nlattr *attr, void *data)
899 {
900 	struct interface *interface = data;
901 
902 	if (mnl_attr_get_type(attr) == IFLA_INFO_KIND && !strcmp(WG_GENL_NAME, mnl_attr_get_str(attr)))
903 		interface->is_wireguard = true;
904 	return MNL_CB_OK;
905 }
906 
parse_infomsg(const struct nlattr * attr,void * data)907 static int parse_infomsg(const struct nlattr *attr, void *data)
908 {
909 	struct interface *interface = data;
910 
911 	if (mnl_attr_get_type(attr) == IFLA_LINKINFO)
912 		return mnl_attr_parse_nested(attr, parse_linkinfo, data);
913 	else if (mnl_attr_get_type(attr) == IFLA_IFNAME)
914 		interface->name = mnl_attr_get_str(attr);
915 	return MNL_CB_OK;
916 }
917 
read_devices_cb(const struct nlmsghdr * nlh,void * data)918 static int read_devices_cb(const struct nlmsghdr *nlh, void *data)
919 {
920 	struct string_list *list = data;
921 	struct interface interface = { 0 };
922 	int ret;
923 
924 	ret = mnl_attr_parse(nlh, sizeof(struct ifinfomsg), parse_infomsg, &interface);
925 	if (ret != MNL_CB_OK)
926 		return ret;
927 	if (interface.name && interface.is_wireguard)
928 		ret = string_list_add(list, interface.name);
929 	if (ret < 0)
930 		return ret;
931 	if (nlh->nlmsg_type != NLMSG_DONE)
932 		return MNL_CB_OK + 1;
933 	return MNL_CB_OK;
934 }
935 
/*
 * Collects the names of all WireGuard interfaces into `list` by sending an
 * RTM_GETLINK dump request on a NETLINK_ROUTE socket and feeding every reply
 * to read_devices_cb().
 *
 * Returns 0 on success or a negative errno value on failure. A dump
 * interrupted with EINTR is not treated as a failure; whatever was collected
 * up to that point is kept.
 */
static int fetch_device_names(struct string_list *list)
{
	struct mnl_socket *nl = NULL;
	char *rtnl_buffer = NULL;
	size_t message_len;
	unsigned int portid, seq;
	ssize_t len;
	int ret = 0;
	struct nlmsghdr *nlh;
	struct ifinfomsg *ifm;

	ret = -ENOMEM;
	rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1);
	if (!rtnl_buffer)
		goto cleanup;

	nl = mnl_socket_open(NETLINK_ROUTE);
	if (!nl) {
		ret = -errno;
		goto cleanup;
	}

	if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
		ret = -errno;
		goto cleanup;
	}

	/* seq, portid and message_len are copied into locals because
	 * rtnl_buffer is reused for receiving and the request is overwritten. */
	seq = time(NULL);
	portid = mnl_socket_get_portid(nl);
	nlh = mnl_nlmsg_put_header(rtnl_buffer);
	nlh->nlmsg_type = RTM_GETLINK;
	nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
	nlh->nlmsg_seq = seq;
	ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
	ifm->ifi_family = AF_UNSPEC;
	message_len = nlh->nlmsg_len;

	if (mnl_socket_sendto(nl, rtnl_buffer, message_len) < 0) {
		ret = -errno;
		goto cleanup;
	}

another:
	if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) {
		ret = -errno;
		goto cleanup;
	}
	/* From here on `len` holds the callback result, not a byte count. */
	if ((len = mnl_cb_run(rtnl_buffer, len, seq, portid, read_devices_cb, list)) < 0) {
		/* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed
		 * during the dump. That's unfortunate, but is pretty common on busy
		 * systems that are adding and removing tunnels all the time. Rather
		 * than retrying, potentially indefinitely, we just work with the
		 * partial results. */
		if (errno != EINTR) {
			ret = -errno;
			goto cleanup;
		}
	}
	/* read_devices_cb() returns MNL_CB_OK + 1 to request another datagram. */
	if (len == MNL_CB_OK + 1)
		goto another;
	ret = 0;

cleanup:
	free(rtnl_buffer);
	if (nl)
		mnl_socket_close(nl);
	return ret;
}
1004 
add_del_iface(const char * ifname,bool add)1005 static int add_del_iface(const char *ifname, bool add)
1006 {
1007 	struct mnl_socket *nl = NULL;
1008 	char *rtnl_buffer;
1009 	ssize_t len;
1010 	int ret;
1011 	struct nlmsghdr *nlh;
1012 	struct ifinfomsg *ifm;
1013 	struct nlattr *nest;
1014 
1015 	rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1);
1016 	if (!rtnl_buffer) {
1017 		ret = -ENOMEM;
1018 		goto cleanup;
1019 	}
1020 
1021 	nl = mnl_socket_open(NETLINK_ROUTE);
1022 	if (!nl) {
1023 		ret = -errno;
1024 		goto cleanup;
1025 	}
1026 
1027 	if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
1028 		ret = -errno;
1029 		goto cleanup;
1030 	}
1031 
1032 	nlh = mnl_nlmsg_put_header(rtnl_buffer);
1033 	nlh->nlmsg_type = add ? RTM_NEWLINK : RTM_DELLINK;
1034 	nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | (add ? NLM_F_CREATE | NLM_F_EXCL : 0);
1035 	nlh->nlmsg_seq = time(NULL);
1036 	ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
1037 	ifm->ifi_family = AF_UNSPEC;
1038 	mnl_attr_put_strz(nlh, IFLA_IFNAME, ifname);
1039 	nest = mnl_attr_nest_start(nlh, IFLA_LINKINFO);
1040 	mnl_attr_put_strz(nlh, IFLA_INFO_KIND, WG_GENL_NAME);
1041 	mnl_attr_nest_end(nlh, nest);
1042 
1043 	if (mnl_socket_sendto(nl, rtnl_buffer, nlh->nlmsg_len) < 0) {
1044 		ret = -errno;
1045 		goto cleanup;
1046 	}
1047 	if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) {
1048 		ret = -errno;
1049 		goto cleanup;
1050 	}
1051 	if (mnl_cb_run(rtnl_buffer, len, nlh->nlmsg_seq, mnl_socket_get_portid(nl), NULL, NULL) < 0) {
1052 		ret = -errno;
1053 		goto cleanup;
1054 	}
1055 	ret = 0;
1056 
1057 cleanup:
1058 	free(rtnl_buffer);
1059 	if (nl)
1060 		mnl_socket_close(nl);
1061 	return ret;
1062 }
1063 
/*
 * Applies the configuration described by dev to the interface named
 * dev->name using WG_CMD_SET_DEVICE messages.
 *
 * Device-level attributes (private key, listen port, fwmark, flags) are only
 * written while no peer is pending, i.e. in the first message. Peers and
 * their allowed IPs are appended with the *_check helpers, bounded by
 * mnl_ideal_socket_buffer_size(). When an attribute does not fit, the
 * partially written nest is cancelled, the message built so far is sent, and
 * the function jumps back to `again` to continue with the peer (and allowed
 * IP) that did not fit. The `peer` and `allowedip` cursors carry that
 * position from one message to the next.
 *
 * Returns 0 on success or a negative errno value; errno is set to -ret
 * before returning once the socket has been opened.
 */
int wg_set_device(wg_device *dev)
{
	int ret = 0;
	wg_peer *peer = NULL;
	wg_allowedip *allowedip = NULL;
	struct nlattr *peers_nest, *peer_nest, *allowedips_nest, *allowedip_nest;
	struct nlmsghdr *nlh;
	struct mnlg_socket *nlg;

	nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
	if (!nlg)
		return -errno;

again:
	nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK);
	mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name);

	/* peer is NULL only for the first message; continuation messages
	 * carry peers only. */
	if (!peer) {
		uint32_t flags = 0;

		if (dev->flags & WGDEVICE_HAS_PRIVATE_KEY)
			mnl_attr_put(nlh, WGDEVICE_A_PRIVATE_KEY, sizeof(dev->private_key), dev->private_key);
		if (dev->flags & WGDEVICE_HAS_LISTEN_PORT)
			mnl_attr_put_u16(nlh, WGDEVICE_A_LISTEN_PORT, dev->listen_port);
		if (dev->flags & WGDEVICE_HAS_FWMARK)
			mnl_attr_put_u32(nlh, WGDEVICE_A_FWMARK, dev->fwmark);
		if (dev->flags & WGDEVICE_REPLACE_PEERS)
			flags |= WGDEVICE_F_REPLACE_PEERS;
		if (flags)
			mnl_attr_put_u32(nlh, WGDEVICE_A_FLAGS, flags);
	}
	if (!dev->first_peer)
		goto send;
	/* The nest pointers double as "is this nest currently open" markers
	 * for the toobig_* handlers below. */
	peers_nest = peer_nest = allowedips_nest = allowedip_nest = NULL;
	peers_nest = mnl_attr_nest_start(nlh, WGDEVICE_A_PEERS);
	/* Resume at the peer that did not fit last time, otherwise start at
	 * the head of the list. */
	for (peer = peer ? peer : dev->first_peer; peer; peer = peer->next_peer) {
		uint32_t flags = 0;

		peer_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0);
		if (!peer_nest)
			goto toobig_peers;
		if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PUBLIC_KEY, sizeof(peer->public_key), peer->public_key))
			goto toobig_peers;
		if (peer->flags & WGPEER_REMOVE_ME)
			flags |= WGPEER_F_REMOVE_ME;
		/* A non-NULL allowedip means this peer is being continued from
		 * a previous message: its per-peer settings (and the
		 * replace-allowed-IPs flag) were already sent. */
		if (!allowedip) {
			if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS)
				flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
			if (peer->flags & WGPEER_HAS_PRESHARED_KEY) {
				if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PRESHARED_KEY, sizeof(peer->preshared_key), peer->preshared_key))
					goto toobig_peers;
			}
			if (peer->endpoint.addr.sa_family == AF_INET) {
				if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr4), &peer->endpoint.addr4))
					goto toobig_peers;
			} else if (peer->endpoint.addr.sa_family == AF_INET6) {
				if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr6), &peer->endpoint.addr6))
					goto toobig_peers;
			}
			if (peer->flags & WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL) {
				if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval))
					goto toobig_peers;
			}
		}
		if (flags) {
			if (!mnl_attr_put_u32_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_FLAGS, flags))
				goto toobig_peers;
		}
		if (peer->first_allowedip) {
			/* Start from the head unless resuming mid-list. */
			if (!allowedip)
				allowedip = peer->first_allowedip;
			allowedips_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ALLOWEDIPS);
			if (!allowedips_nest)
				goto toobig_allowedips;
			/* The loop leaves allowedip NULL when the whole list was
			 * written, which is what marks the next peer as "fresh". */
			for (; allowedip; allowedip = allowedip->next_allowedip) {
				allowedip_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0);
				if (!allowedip_nest)
					goto toobig_allowedips;
				if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_FAMILY, allowedip->family))
					goto toobig_allowedips;
				if (allowedip->family == AF_INET) {
					if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip4), &allowedip->ip4))
						goto toobig_allowedips;
				} else if (allowedip->family == AF_INET6) {
					if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip6), &allowedip->ip6))
						goto toobig_allowedips;
				}
				if (!mnl_attr_put_u8_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_CIDR_MASK, allowedip->cidr))
					goto toobig_allowedips;
				mnl_attr_nest_end(nlh, allowedip_nest);
				allowedip_nest = NULL;
			}
			mnl_attr_nest_end(nlh, allowedips_nest);
			allowedips_nest = NULL;
		}

		mnl_attr_nest_end(nlh, peer_nest);
		peer_nest = NULL;
	}
	/* Every peer fit: peer is NULL here, so no further message follows. */
	mnl_attr_nest_end(nlh, peers_nest);
	peers_nest = NULL;
	goto send;
toobig_allowedips:
	/* Drop the half-written allowed IP, keep the ones already complete,
	 * and close the enclosing nests so the message is well formed. peer
	 * and allowedip still point at the entry to retry. */
	if (allowedip_nest)
		mnl_attr_nest_cancel(nlh, allowedip_nest);
	if (allowedips_nest)
		mnl_attr_nest_end(nlh, allowedips_nest);
	mnl_attr_nest_end(nlh, peer_nest);
	mnl_attr_nest_end(nlh, peers_nest);
	goto send;
toobig_peers:
	/* Drop the half-written peer entirely; it is retried in the next
	 * message. */
	if (peer_nest)
		mnl_attr_nest_cancel(nlh, peer_nest);
	mnl_attr_nest_end(nlh, peers_nest);
	goto send;
send:
	if (mnlg_socket_send(nlg, nlh) < 0) {
		ret = -errno;
		goto out;
	}
	/* Cleared so that a failure which leaves errno untouched can be told
	 * apart and reported as EINVAL. */
	errno = 0;
	if (mnlg_socket_recv_run(nlg, NULL, NULL) < 0) {
		ret = errno ? -errno : -EINVAL;
		goto out;
	}
	/* A non-NULL peer means the message above was cut short. */
	if (peer)
		goto again;

out:
	mnlg_socket_close(nlg);
	errno = -ret;
	return ret;
}
1197 
parse_allowedip(const struct nlattr * attr,void * data)1198 static int parse_allowedip(const struct nlattr *attr, void *data)
1199 {
1200 	wg_allowedip *allowedip = data;
1201 
1202 	switch (mnl_attr_get_type(attr)) {
1203 	case WGALLOWEDIP_A_UNSPEC:
1204 		break;
1205 	case WGALLOWEDIP_A_FAMILY:
1206 		if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1207 			allowedip->family = mnl_attr_get_u16(attr);
1208 		break;
1209 	case WGALLOWEDIP_A_IPADDR:
1210 		if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4))
1211 			memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4));
1212 		else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6))
1213 			memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6));
1214 		break;
1215 	case WGALLOWEDIP_A_CIDR_MASK:
1216 		if (!mnl_attr_validate(attr, MNL_TYPE_U8))
1217 			allowedip->cidr = mnl_attr_get_u8(attr);
1218 		break;
1219 	}
1220 
1221 	return MNL_CB_OK;
1222 }
1223 
parse_allowedips(const struct nlattr * attr,void * data)1224 static int parse_allowedips(const struct nlattr *attr, void *data)
1225 {
1226 	wg_peer *peer = data;
1227 	wg_allowedip *new_allowedip = calloc(1, sizeof(wg_allowedip));
1228 	int ret;
1229 
1230 	if (!new_allowedip)
1231 		return MNL_CB_ERROR;
1232 	if (!peer->first_allowedip)
1233 		peer->first_allowedip = peer->last_allowedip = new_allowedip;
1234 	else {
1235 		peer->last_allowedip->next_allowedip = new_allowedip;
1236 		peer->last_allowedip = new_allowedip;
1237 	}
1238 	ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip);
1239 	if (!ret)
1240 		return ret;
1241 	if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128))) {
1242 		errno = EAFNOSUPPORT;
1243 		return MNL_CB_ERROR;
1244 	}
1245 	return MNL_CB_OK;
1246 }
1247 
wg_key_is_zero(const wg_key key)1248 bool wg_key_is_zero(const wg_key key)
1249 {
1250 	volatile uint8_t acc = 0;
1251 	unsigned int i;
1252 
1253 	for (i = 0; i < sizeof(wg_key); ++i) {
1254 		acc |= key[i];
1255 		__asm__ ("" : "=r" (acc) : "0" (acc));
1256 	}
1257 	return 1 & ((acc - 1) >> 8);
1258 }
1259 
parse_peer(const struct nlattr * attr,void * data)1260 static int parse_peer(const struct nlattr *attr, void *data)
1261 {
1262 	wg_peer *peer = data;
1263 
1264 	switch (mnl_attr_get_type(attr)) {
1265 	case WGPEER_A_UNSPEC:
1266 		break;
1267 	case WGPEER_A_PUBLIC_KEY:
1268 		if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key)) {
1269 			memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key));
1270 			peer->flags |= WGPEER_HAS_PUBLIC_KEY;
1271 		}
1272 		break;
1273 	case WGPEER_A_PRESHARED_KEY:
1274 		if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key)) {
1275 			memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key));
1276 			if (!wg_key_is_zero(peer->preshared_key))
1277 				peer->flags |= WGPEER_HAS_PRESHARED_KEY;
1278 		}
1279 		break;
1280 	case WGPEER_A_ENDPOINT: {
1281 		struct sockaddr *addr;
1282 
1283 		if (mnl_attr_get_payload_len(attr) < sizeof(*addr))
1284 			break;
1285 		addr = mnl_attr_get_payload(attr);
1286 		if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4))
1287 			memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4));
1288 		else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6))
1289 			memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6));
1290 		break;
1291 	}
1292 	case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL:
1293 		if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1294 			peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
1295 		break;
1296 	case WGPEER_A_LAST_HANDSHAKE_TIME:
1297 		if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time))
1298 			memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time));
1299 		break;
1300 	case WGPEER_A_RX_BYTES:
1301 		if (!mnl_attr_validate(attr, MNL_TYPE_U64))
1302 			peer->rx_bytes = mnl_attr_get_u64(attr);
1303 		break;
1304 	case WGPEER_A_TX_BYTES:
1305 		if (!mnl_attr_validate(attr, MNL_TYPE_U64))
1306 			peer->tx_bytes = mnl_attr_get_u64(attr);
1307 		break;
1308 	case WGPEER_A_ALLOWEDIPS:
1309 		return mnl_attr_parse_nested(attr, parse_allowedips, peer);
1310 	}
1311 
1312 	return MNL_CB_OK;
1313 }
1314 
parse_peers(const struct nlattr * attr,void * data)1315 static int parse_peers(const struct nlattr *attr, void *data)
1316 {
1317 	wg_device *device = data;
1318 	wg_peer *new_peer = calloc(1, sizeof(wg_peer));
1319 	int ret;
1320 
1321 	if (!new_peer)
1322 		return MNL_CB_ERROR;
1323 	if (!device->first_peer)
1324 		device->first_peer = device->last_peer = new_peer;
1325 	else {
1326 		device->last_peer->next_peer = new_peer;
1327 		device->last_peer = new_peer;
1328 	}
1329 	ret = mnl_attr_parse_nested(attr, parse_peer, new_peer);
1330 	if (!ret)
1331 		return ret;
1332 	if (!(new_peer->flags & WGPEER_HAS_PUBLIC_KEY)) {
1333 		errno = ENXIO;
1334 		return MNL_CB_ERROR;
1335 	}
1336 	return MNL_CB_OK;
1337 }
1338 
parse_device(const struct nlattr * attr,void * data)1339 static int parse_device(const struct nlattr *attr, void *data)
1340 {
1341 	wg_device *device = data;
1342 
1343 	switch (mnl_attr_get_type(attr)) {
1344 	case WGDEVICE_A_UNSPEC:
1345 		break;
1346 	case WGDEVICE_A_IFINDEX:
1347 		if (!mnl_attr_validate(attr, MNL_TYPE_U32))
1348 			device->ifindex = mnl_attr_get_u32(attr);
1349 		break;
1350 	case WGDEVICE_A_IFNAME:
1351 		if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) {
1352 			strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1);
1353 			device->name[sizeof(device->name) - 1] = '\0';
1354 		}
1355 		break;
1356 	case WGDEVICE_A_PRIVATE_KEY:
1357 		if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key)) {
1358 			memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key));
1359 			device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
1360 		}
1361 		break;
1362 	case WGDEVICE_A_PUBLIC_KEY:
1363 		if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key)) {
1364 			memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key));
1365 			device->flags |= WGDEVICE_HAS_PUBLIC_KEY;
1366 		}
1367 		break;
1368 	case WGDEVICE_A_LISTEN_PORT:
1369 		if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1370 			device->listen_port = mnl_attr_get_u16(attr);
1371 		break;
1372 	case WGDEVICE_A_FWMARK:
1373 		if (!mnl_attr_validate(attr, MNL_TYPE_U32))
1374 			device->fwmark = mnl_attr_get_u32(attr);
1375 		break;
1376 	case WGDEVICE_A_PEERS:
1377 		return mnl_attr_parse_nested(attr, parse_peers, device);
1378 	}
1379 
1380 	return MNL_CB_OK;
1381 }
1382 
/*
 * Netlink message callback used when reading a device: parses the attributes
 * located after the generic netlink header of nlh with parse_device(),
 * passing data (the wg_device being filled in) through, and returns that
 * parse result.
 */
static int read_device_cb(const struct nlmsghdr *nlh, void *data)
{
	return mnl_attr_parse(nlh, sizeof(struct genlmsghdr), parse_device, data);
}
1387 
/*
 * Merges adjacent peer entries that share the same public key into a single
 * entry: the allowed IPs of the later entry are appended to the earlier one
 * and the later entry is freed. The same peer can be listed more than once
 * because parse_peers() creates one entry per peer nest it sees.
 *
 * The device's last_peer pointer and each peer's last_allowedip pointer are
 * kept valid throughout.
 */
static void coalesce_peers(wg_device *device)
{
	wg_peer *dup, *peer = device->first_peer;

	while (peer && peer->next_peer) {
		dup = peer->next_peer;
		if (memcmp(peer->public_key, dup->public_key, sizeof(wg_key))) {
			peer = dup;
			continue;
		}
		/* Only splice when the duplicate actually has allowed IPs;
		 * otherwise last_allowedip would be overwritten with NULL while
		 * first_allowedip stays set, and the next merge would
		 * dereference that NULL. */
		if (dup->first_allowedip) {
			if (!peer->first_allowedip)
				peer->first_allowedip = dup->first_allowedip;
			else
				peer->last_allowedip->next_allowedip = dup->first_allowedip;
			peer->last_allowedip = dup->last_allowedip;
		}
		peer->next_peer = dup->next_peer;
		/* Do not leave the list tail pointing at freed memory. */
		if (device->last_peer == dup)
			device->last_peer = peer;
		free(dup);
		/* peer is not advanced: its new neighbour may be a duplicate too. */
	}
}
1409 
/*
 * Reads the configuration of the interface named device_name into a newly
 * allocated wg_device stored in *device.
 *
 * On success returns 0; the caller releases the result with wg_free_device().
 * On failure returns a negative errno value and sets *device to NULL. errno
 * is set to -ret in both cases. A receive that fails with EINTR discards the
 * partial result and restarts the whole request.
 */
int wg_get_device(wg_device **device, const char *device_name)
{
	int ret;
	struct nlmsghdr *nlh;
	struct mnlg_socket *nlg;

try_again:
	/* Must be reset on every attempt: a stale -EINTR from the previous
	 * round would otherwise make a successful retry look like a failure
	 * and loop forever. */
	ret = 0;
	*device = calloc(1, sizeof(wg_device));
	if (!*device) {
		errno = ENOMEM;
		return -ENOMEM;
	}

	nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
	if (!nlg) {
		/* Capture errno before freeing, which may overwrite it. */
		ret = -errno;
		wg_free_device(*device);
		*device = NULL;
		errno = -ret;
		return ret;
	}

	nlh = mnlg_msg_prepare(nlg, WG_CMD_GET_DEVICE, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
	mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, device_name);
	if (mnlg_socket_send(nlg, nlh) < 0) {
		ret = -errno;
		goto out;
	}
	/* Cleared so that a failure which leaves errno untouched can be told
	 * apart and reported as EINVAL. */
	errno = 0;
	if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) {
		ret = errno ? -errno : -EINVAL;
		goto out;
	}
	/* One peer may have been delivered as several consecutive entries. */
	coalesce_peers(*device);

out:
	if (nlg)
		mnlg_socket_close(nlg);
	if (ret) {
		wg_free_device(*device);
		if (ret == -EINTR)
			goto try_again;
		*device = NULL;
	}
	errno = -ret;
	return ret;
}
1453 
1454 /* first\0second\0third\0forth\0last\0\0 */
wg_list_device_names(void)1455 char *wg_list_device_names(void)
1456 {
1457 	struct string_list list = { 0 };
1458 	int ret = fetch_device_names(&list);
1459 
1460 	errno = -ret;
1461 	if (errno) {
1462 		free(list.buffer);
1463 		return NULL;
1464 	}
1465 	return list.buffer ?: strdup("\0");
1466 }
1467 
/*
 * Creates a WireGuard interface named device_name. Returns the result of
 * add_del_iface(): 0 on success or a negative errno value.
 */
int wg_add_device(const char *device_name)
{
	return add_del_iface(device_name, true);
}
1472 
/*
 * Deletes the interface named device_name. Returns the result of
 * add_del_iface(): 0 on success or a negative errno value.
 */
int wg_del_device(const char *device_name)
{
	return add_del_iface(device_name, false);
}
1477 
/*
 * Releases a wg_device together with every peer on its list and every
 * allowed IP on each peer's list. Passing NULL is a no-op.
 */
void wg_free_device(wg_device *dev)
{
	wg_peer *peer;

	if (!dev)
		return;

	peer = dev->first_peer;
	while (peer) {
		/* Remember the successors before the nodes are freed. */
		wg_peer *following_peer = peer->next_peer;
		wg_allowedip *ip = peer->first_allowedip;

		while (ip) {
			wg_allowedip *following_ip = ip->next_allowedip;

			free(ip);
			ip = following_ip;
		}
		free(peer);
		peer = following_peer;
	}
	free(dev);
}
1492 
/*
 * Encodes the three bytes at src as four base64 characters written to dest;
 * no terminator is added. Each 6-bit value is mapped to its character with
 * arithmetic on range comparisons instead of a lookup table or branches.
 */
static void encode_base64(char dest[static 4], const uint8_t src[static 3])
{
	const uint8_t sextet[] = {
		(src[0] >> 2) & 63,
		((src[0] << 4) | (src[1] >> 4)) & 63,
		((src[1] << 2) | (src[2] >> 6)) & 63,
		src[2] & 63
	};
	unsigned int k = 0;

	while (k < 4) {
		const int v = sextet[k];
		int ch = v + 'A';

		/* Each term is non-zero only once v passes the given bound
		 * (the difference goes negative and the shift yields all ones),
		 * moving the offset from 'A'.. to 'a'.., '0'.., '+' and '/'. */
		ch += ((25 - v) >> 8) & 6;
		ch -= ((51 - v) >> 8) & 75;
		ch -= ((61 - v) >> 8) & 15;
		ch += ((62 - v) >> 8) & 3;
		dest[k] = ch;
		++k;
	}
}
1506 
/*
 * Writes the base64 form of key to base64 as a NUL-terminated string. The
 * key is encoded as ten full 3-byte groups plus a final 2-byte group, so the
 * output ends in a single '=' padding character.
 */
void wg_key_to_base64(wg_key_b64_string base64, const wg_key key)
{
	const unsigned int full_groups = 32 / 3;
	const uint8_t *in = key;
	char *out = base64;
	uint8_t tail[3];
	unsigned int g;

	for (g = 0; g < full_groups; ++g) {
		encode_base64(out, in);
		in += 3;
		out += 4;
	}

	/* Two bytes remain; pad with a zero byte to form the last group. */
	tail[0] = in[0];
	tail[1] = in[1];
	tail[2] = 0;
	encode_base64(out, tail);

	/* The fourth character of the last group only encodes padding bits. */
	base64[sizeof(wg_key_b64_string) - 2] = '=';
	base64[sizeof(wg_key_b64_string) - 1] = '\0';
}
1517 
/*
 * Decodes the four base64 characters at src into a 24-bit value (first
 * character in the most significant position).
 *
 * Each character is mapped to its 6-bit value with arithmetic on range
 * comparisons instead of a lookup table or branches; a character outside the
 * base64 alphabet maps to -1. Returns the decoded value, or a negative number
 * if any character was invalid.
 */
static int decode_base64(const char src[static 4])
{
	uint32_t val = 0;
	unsigned int i;

	for (i = 0; i < 4; ++i) {
		/* Exactly one range term is non-zero for a valid character;
		 * none is for an invalid one, leaving -1. */
		const int sextet = -1
			    + ((((('A' - 1) - src[i]) & (src[i] - ('Z' + 1))) >> 8) & (src[i] - 64))
			    + ((((('a' - 1) - src[i]) & (src[i] - ('z' + 1))) >> 8) & (src[i] - 70))
			    + ((((('0' - 1) - src[i]) & (src[i] - ('9' + 1))) >> 8) & (src[i] + 5))
			    + ((((('+' - 1) - src[i]) & (src[i] - ('+' + 1))) >> 8) & 63)
			    + ((((('/' - 1) - src[i]) & (src[i] - ('/' + 1))) >> 8) & 64);

		/* Shift in unsigned arithmetic: left-shifting the -1 marker as
		 * a signed int would be undefined behaviour. The all-ones
		 * pattern still sets the top bit, so the result is negative. */
		val |= (uint32_t)sextet << (18 - 6 * i);
	}
	return (int)val;
}
1533 
/*
 * Decodes the NUL-terminated base64 string base64 into key.
 *
 * The string must have exactly the length of a wg_key_b64_string and end in
 * '='. Invalid characters, and non-zero bits in the part of the last group
 * that does not belong to the key, are accumulated into a flag and reported
 * at the end rather than by returning early.
 *
 * Returns 0 on success or -EINVAL on malformed input; errno is set to the
 * corresponding positive value (0 on success).
 */
int wg_key_from_base64(wg_key key, const wg_key_b64_string base64)
{
	const size_t text_len = sizeof(wg_key_b64_string) - 1;
	volatile uint8_t bad = 0;
	unsigned int g;
	int quad;

	if (strlen(base64) != text_len || base64[text_len - 1] != '=') {
		errno = EINVAL;
		return -EINVAL;
	}

	for (g = 0; g < 32 / 3; ++g) {
		uint8_t *out = &key[g * 3];

		quad = decode_base64(&base64[g * 4]);
		/* A negative result marks an invalid character. */
		bad |= (uint32_t)quad >> 31;
		out[0] = (quad >> 16) & 0xff;
		out[1] = (quad >> 8) & 0xff;
		out[2] = quad & 0xff;
	}

	/* Last group: three characters carry two key bytes; 'A' (value 0)
	 * stands in for the '=' so the group can be decoded normally. */
	quad = decode_base64((const char[]){ base64[g * 4 + 0], base64[g * 4 + 1], base64[g * 4 + 2], 'A' });
	/* The low byte must be zero, otherwise the text encodes extra bits. */
	bad |= ((uint32_t)quad >> 31) | (quad & 0xff);
	key[g * 3 + 0] = (quad >> 16) & 0xff;
	key[g * 3 + 1] = (quad >> 8) & 0xff;

	/* Yields 0 when bad == 0 and EINVAL otherwise, without branching. */
	errno = EINVAL & ~((bad - 1) >> 8);
	return -errno;
}
1560 
1561 typedef int64_t fe[16];
1562 
/*
 * Overwrites count bytes at s with zeros. The function is marked noinline
 * and the memset is followed by an empty asm statement that takes s as an
 * input and declares a "memory" clobber -- presumably so the compiler cannot
 * discard the store as dead when the buffer is never read again.
 */
static __attribute__((noinline)) void memzero_explicit(void *s, size_t count)
{
	memset(s, 0, count);
	__asm__ __volatile__("": :"r"(s) :"memory");
}
1568 
/*
 * One carry pass over a field element: everything above the low 16 bits of
 * each limb is moved into the next limb. The carry out of the last limb is
 * added to limb 0 multiplied by 38.
 */
static void carry(fe o)
{
	int limb;

	for (limb = 0; limb < 15; ++limb) {
		o[limb + 1] += o[limb] >> 16;
		o[limb] &= 0xffff;
	}
	/* Wrap-around from the top limb back to the bottom one. */
	o[0] += 38 * (o[15] >> 16);
	o[15] &= 0xffff;
}
1578 
/*
 * Conditionally exchanges the field elements p and q: with b == 1 they are
 * swapped, with b == 0 they are left unchanged. The choice is made with a
 * bit mask rather than a branch, and the temporaries are wiped afterwards.
 */
static void cswap(fe p, fe q, int b)
{
	int64_t diff, mask = ~(b - 1);	/* all ones for b == 1, zero for b == 0 */
	int limb = 0;

	while (limb < 16) {
		diff = mask & (p[limb] ^ q[limb]);
		p[limb] ^= diff;
		q[limb] ^= diff;
		++limb;
	}

	memzero_explicit(&diff, sizeof(diff));
	memzero_explicit(&mask, sizeof(mask));
	memzero_explicit(&b, sizeof(b));
}
1594 
/*
 * Serialises the field element n into the 32 bytes at o, two bytes per
 * 16-bit limb, least significant byte first.
 *
 * The value is first carried three times, then twice conditionally reduced
 * by subtracting the constant whose limbs are 0xffed, 0xffff (x14), 0x7fff,
 * i.e. 2^255 - 19. Temporaries are wiped before returning.
 */
static void pack(uint8_t *o, const fe n)
{
	int i, j, b;
	fe m, t;

	memcpy(t, n, sizeof(t));
	carry(t);
	carry(t);
	carry(t);
	for (j = 0; j < 2; ++j) {
		/* m = t - (2^255 - 19), limb by limb; bit 16 of each
		 * difference is the borrow taken from the next limb. */
		m[0] = t[0] - 0xffed;
		for (i = 1; i < 15; ++i) {
			m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
			m[i - 1] &= 0xffff;
		}
		m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
		/* b is the final borrow: 1 when t was smaller than the
		 * constant, in which case t must be kept. */
		b = (m[15] >> 16) & 1;
		m[14] &= 0xffff;
		/* Take the reduced value m only when there was no borrow. */
		cswap(t, m, 1 - b);
	}
	for (i = 0; i < 16; ++i) {
		o[2 * i] = t[i] & 0xff;
		o[2 * i + 1] = t[i] >> 8;
	}

	memzero_explicit(m, sizeof(m));
	memzero_explicit(t, sizeof(t));
	memzero_explicit(&b, sizeof(b));
}
1624 
/* Limb-wise sum o = a + b; no carrying is performed. o may alias a or b. */
static void add(fe o, const fe a, const fe b)
{
	int limb = 0;

	while (limb < 16) {
		o[limb] = a[limb] + b[limb];
		++limb;
	}
}
1632 
/*
 * Limb-wise difference o = a - b; no borrowing is performed. o may alias a
 * or b.
 */
static void subtract(fe o, const fe a, const fe b)
{
	int limb = 0;

	while (limb < 16) {
		o[limb] = a[limb] - b[limb];
		++limb;
	}
}
1640 
/*
 * Field multiplication o = a * b. The 31-limb schoolbook product is built in
 * a scratch buffer (so o may alias a or b), its upper limbs are folded into
 * the lower 16 with a factor of 38, and two carry passes follow. The scratch
 * buffer is wiped before returning.
 */
static void multmod(fe o, const fe a, const fe b)
{
	int64_t prod[31] = { 0 };
	int row, col, k;

	for (row = 0; row < 16; ++row) {
		for (col = 0; col < 16; ++col)
			prod[row + col] += a[row] * b[col];
	}
	/* Limb k + 16 carries weight 2^256 relative to limb k, which is
	 * congruent to 38 modulo 2^255 - 19. */
	for (k = 0; k < 15; ++k)
		prod[k] += 38 * prod[k + 16];

	memcpy(o, prod, sizeof(fe));
	carry(o);
	carry(o);

	memzero_explicit(prod, sizeof(prod));
}
1658 
/*
 * Field inversion by exponentiation: o = i^(2^255 - 21). The exponent is
 * processed with square-and-multiply; it has every bit set except bits 2 and
 * 4, so the multiply step is skipped for exactly those two positions. The
 * working copy is wiped before returning. o may alias i.
 */
static void invert(fe o, const fe i)
{
	fe acc;
	int bit;

	memcpy(acc, i, sizeof(acc));
	for (bit = 253; bit >= 0; --bit) {
		multmod(acc, acc, acc);
		if (bit == 2 || bit == 4)
			continue;
		multmod(acc, acc, i);
	}
	memcpy(o, acc, sizeof(fe));

	memzero_explicit(acc, sizeof(acc));
}
1674 
/*
 * Clamps the 32-byte scalar at z in place: clears the top bit and sets the
 * second-highest bit of the last byte, and clears the three lowest bits of
 * the first byte.
 */
static void clamp_key(uint8_t *z)
{
	uint8_t top = z[31];

	top &= 0x7f;
	top |= 0x40;
	z[31] = top;
	z[0] &= 0xf8;
}
1680 
/*
 * Derives public_key from private_key.
 *
 * A clamped copy of the private key is used as a 255-bit scalar and walked
 * from the top bit down. Each step conditionally swaps two pairs of field
 * elements on the current bit, runs a fixed sequence of field operations and
 * swaps back. The starting point has coordinate 9. This has the shape of the
 * Curve25519 Montgomery ladder on the base point -- NOTE(review): confirm
 * against the reference implementation before relying on that name.
 *
 * The result a / c is computed with invert() and serialised with pack(). All
 * intermediate values, including the scalar copy, are wiped before returning.
 */
void wg_generate_public_key(wg_key public_key, const wg_key private_key)
{
	int i, r;
	uint8_t z[32];
	fe a = { 1 }, b = { 9 }, c = { 0 }, d = { 1 }, e, f;

	memcpy(z, private_key, sizeof(z));
	clamp_key(z);

	for (i = 254; i >= 0; --i) {
		/* r is bit i of the scalar. */
		r = (z[i >> 3] >> (i & 7)) & 1;
		cswap(a, b, r);
		cswap(c, d, r);
		add(e, a, c);
		subtract(a, a, c);
		add(c, b, d);
		subtract(b, b, d);
		multmod(d, e, e);
		multmod(f, a, a);
		multmod(a, c, a);
		multmod(c, b, e);
		add(e, a, c);
		subtract(a, a, c);
		multmod(b, a, a);
		subtract(c, d, f);
		/* Limbs { 0xdb41, 1 } encode the constant 121665. */
		multmod(a, c, (const fe){ 0xdb41, 1 });
		add(a, a, d);
		multmod(c, c, a);
		multmod(a, d, f);
		/* Multiplication by the coordinate of the starting point. */
		multmod(d, b, (const fe){ 9 });
		multmod(b, e, e);
		cswap(a, b, r);
		cswap(c, d, r);
	}
	/* Convert the projective result a / c to a single field element. */
	invert(c, c);
	multmod(a, a, c);
	pack(public_key, a);

	memzero_explicit(&r, sizeof(r));
	memzero_explicit(z, sizeof(z));
	memzero_explicit(a, sizeof(a));
	memzero_explicit(b, sizeof(b));
	memzero_explicit(c, sizeof(c));
	memzero_explicit(d, sizeof(d));
	memzero_explicit(e, sizeof(e));
	memzero_explicit(f, sizeof(f));
}
1728 
/*
 * Fills private_key with random bytes using wg_generate_preshared_key() and
 * then clamps it with clamp_key().
 */
void wg_generate_private_key(wg_key private_key)
{
	wg_generate_preshared_key(private_key);
	clamp_key(private_key);
}
1734 
/*
 * Fills preshared_key with sizeof(wg_key) random bytes.
 *
 * Sources are tried in order: getentropy() where the platform check enables
 * it, the getrandom system call where its number is defined, and finally
 * reads from /dev/urandom.
 *
 * The function cannot report failure to its caller. If the fallback source
 * cannot be opened or read, the process is aborted instead of returning with
 * a key that is not random. This holds whether or not NDEBUG is defined.
 */
void wg_generate_preshared_key(wg_key preshared_key)
{
	ssize_t ret;
	size_t i;
	int fd;
#if defined(__OpenBSD__) || (defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12) || (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)))
	if (!getentropy(preshared_key, sizeof(wg_key)))
		return;
#endif
#if defined(__NR_getrandom) && defined(__linux__)
	if (syscall(__NR_getrandom, preshared_key, sizeof(wg_key), 0) == sizeof(wg_key))
		return;
#endif
	do {
		fd = open("/dev/urandom", O_RDONLY);
	} while (fd < 0 && errno == EINTR);
	if (fd < 0)
		abort();
	for (i = 0; i < sizeof(wg_key); ) {
		ret = read(fd, preshared_key + i, sizeof(wg_key) - i);
		/* An interrupted read is retried, not treated as fatal. */
		if (ret < 0 && errno == EINTR)
			continue;
		/* Error or unexpected end of file: never add a non-positive
		 * count to the index, and never return a partial key. */
		if (ret <= 0)
			abort();
		i += (size_t)ret;
	}
	close(fd);
}
1756