1 /*
2  * Copyright (C) 2020 Microsoft Corporation
3  *
4  * This file is part of ocserv.
5  *
6  * ocserv is free software: you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * ocserv is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include <config.h>
21 #include <errno.h>
22 #include <stdio.h>
23 #include <string.h>
24 #include <unistd.h>
25 
26 #if defined(ENABLE_ADAPTIVE_RATE_LIMIT_SUPPORT)
27 
28 #include <sys/socket.h>
29 #include <sys/un.h>
30 #include <linux/netlink.h>
31 #include <linux/rtnetlink.h>
32 #include <linux/sock_diag.h>
33 #include <linux/unix_diag.h>
34 #include <netinet/tcp.h>
35 #include <sys/syslog.h>
36 
/*
 * Send one SOCK_DIAG_BY_FAMILY request for AF_UNIX sockets over the
 * netlink socket @fd, addressed to the kernel (nl_pid 0).
 *
 * @inode:  0 requests a dump (NLM_F_DUMP) of every socket whose state bit
 *          is set in @states; any other value is placed in udiag_ino and
 *          the request is sent without NLM_F_DUMP.
 * @states: bitmask of socket states to include (udiag_states).
 * @show:   UDIAG_SHOW_* flags selecting the attributes to report.
 *
 * sendmsg() is retried while it fails with EINTR.
 * Returns 0 once the request is sent, or -1 after logging any other error.
 */
static int send_query(int fd, int inode, int states, int show)
{
	struct sockaddr_nl kernel_addr;
	struct {
		struct nlmsghdr nlh;
		struct unix_diag_req udr;
	} request;
	struct iovec vec;
	struct msghdr hdr;
	int saved_errno;

	memset(&kernel_addr, 0, sizeof(kernel_addr));
	kernel_addr.nl_family = AF_NETLINK;

	memset(&request, 0, sizeof(request));
	request.nlh.nlmsg_len = sizeof(request);
	request.nlh.nlmsg_type = SOCK_DIAG_BY_FAMILY;
	request.nlh.nlmsg_flags = NLM_F_REQUEST;
	if (inode == 0)
		request.nlh.nlmsg_flags |= NLM_F_DUMP;
	request.udr.sdiag_family = AF_UNIX;
	request.udr.udiag_states = states;
	request.udr.udiag_show = show;
	request.udr.udiag_ino = inode;

	vec.iov_base = &request;
	vec.iov_len = sizeof(request);

	memset(&hdr, 0, sizeof(hdr));
	hdr.msg_name = (void *)&kernel_addr;
	hdr.msg_namelen = sizeof(kernel_addr);
	hdr.msg_iov = &vec;
	hdr.msg_iovlen = 1;

	while (sendmsg(fd, &hdr, 0) < 0) {
		if (errno != EINTR) {
			saved_errno = errno;
			syslog(LOG_ERR, "sendmsg failed %s", strerror(saved_errno));
			return -1;
		}
	}

	return 0;
}
80 
/*
 * Per-message callback invoked by receive_responses() for every
 * SOCK_DIAG_BY_FAMILY reply: it gets the unix_diag_msg payload, the
 * reply's nlmsg_len and the caller's opaque context. A non-zero return
 * makes receive_responses() stop and fail.
 */
typedef int (*process_response)(const struct unix_diag_msg *msg,
				unsigned int nlmsg_len, void *opaque);
83 
/* Lookup state passed through receive_responses() to match_name(). */
struct match_name_context {
	const char *name;		/* socket path being searched for */
	int inode;			/* udiag_ino of the match; untouched if none */
	struct unix_diag_rqlen rqlen;	/* UNIX_DIAG_RQLEN of the match */
};
89 
match_name(const struct unix_diag_msg * diag,unsigned int len,void * context)90 static int match_name(const struct unix_diag_msg *diag, unsigned int len,
91 		      void *context)
92 {
93 	struct match_name_context *ctx = (struct match_name_context *)context;
94 
95 	struct rtattr *attr;
96 	unsigned int rta_len = len - NLMSG_LENGTH(sizeof(*diag));
97 	size_t path_len = 0;
98 	char path[sizeof(((struct sockaddr_un *) 0)->sun_path) + 1];
99 	struct unix_diag_rqlen rqlen;
100 	int rqlen_valid = 0;
101 
102 	for (attr = (struct rtattr *)(diag + 1);
103 	     RTA_OK(attr, rta_len); attr = RTA_NEXT(attr, rta_len)) {
104 		switch (attr->rta_type) {
105 		case UNIX_DIAG_NAME:
106 			if (!path_len) {
107 				path_len = RTA_PAYLOAD(attr);
108 				if (path_len > sizeof(path) - 1)
109 					path_len = sizeof(path) - 1;
110 				memcpy(path, RTA_DATA(attr), path_len);
111 				path[path_len] = '\0';
112 			}
113 			break;
114 		case UNIX_DIAG_RQLEN:
115 			if (RTA_PAYLOAD(attr) != sizeof(rqlen))
116 				return -1;
117 			memcpy(&rqlen, RTA_DATA(attr), sizeof(rqlen));
118 			rqlen_valid = 1;
119 			break;
120 		}
121 	}
122 
123 	if (path_len == 0) {
124 		syslog(LOG_ERR, "UNIX_DIAG_NAME not present in response");
125 		return -1;
126 	}
127 
128 	if (rqlen_valid == 0) {
129 		syslog(LOG_ERR, "UNIX_DIAG_RQLEN not present in response");
130 		return -1;
131 	}
132 
133 	if (strcmp(path, ctx->name) == 0) {
134 		ctx->inode = diag->udiag_ino;
135 		ctx->rqlen = rqlen;
136 	}
137 
138 	return 0;
139 }
140 
/*
 * Read sock_diag replies from the netlink socket @fd and pass each
 * SOCK_DIAG_BY_FAMILY message to @process as (payload, nlmsg_len, @context).
 *
 * Datagrams are read until an NLMSG_DONE message is seen; the loop has no
 * other successful exit.
 * NOTE(review): a non-dump request (send_query() with a non-zero inode)
 * presumably gets no NLMSG_DONE and would leave this blocked in recvmsg().
 * Confirm before using that mode.
 *
 * Returns 0 on NLMSG_DONE. Returns -1, after logging, when any of these
 * happens:
 *   - recvmsg() fails with anything other than EINTR;
 *   - recvmsg() returns 0 bytes or reports a truncated datagram;
 *   - a message is malformed;
 *   - a message is an NLMSG_ERROR (whatever its error code);
 *   - a message has an unexpected type or family;
 *   - @process returns non-zero.
 */
static int receive_responses(int fd, process_response process, void *context)
{
	/* Array of long, presumably so netlink headers are suitably aligned. */
	long buf[8192 / sizeof(long)];
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = sizeof(buf)
	};

	for (;;) {
		struct msghdr msg = {
			.msg_name = (void *)&nladdr,
			.msg_namelen = sizeof(nladdr),
			.msg_iov = &iov,
			.msg_iovlen = 1
		};
		const struct nlmsghdr *h;
		ssize_t ret = recvmsg(fd, &msg, 0);

		if (ret < 0) {
			int err;

			if (errno == EINTR)
				continue;
			err = errno;
			syslog(LOG_ERR, "recvmsg failed %s", strerror(err));
			return -1;
		}

		if (ret == 0) {
			syslog(LOG_ERR, "recvmsg returned empty response");
			return -1;
		}

		/*
		 * The tail of a datagram larger than buf is discarded and
		 * MSG_TRUNC is set. Parsing only the head would silently skip
		 * sockets, so treat truncation as a failure.
		 */
		if (msg.msg_flags & MSG_TRUNC) {
			syslog(LOG_ERR, "recvmsg truncated response");
			return -1;
		}

		h = (const struct nlmsghdr *)buf;

		if (!NLMSG_OK(h, ret)) {
			syslog(LOG_ERR, "!NLMSG_OK");
			return -1;
		}

		/* One datagram may carry several netlink messages. */
		for (; NLMSG_OK(h, ret); h = NLMSG_NEXT(h, ret)) {
			const struct unix_diag_msg *diag;

			if (h->nlmsg_type == NLMSG_DONE)
				return 0;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *nlerr = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*nlerr))) {
					syslog(LOG_ERR,
					       "nlmsg_type NLMSG_ERROR has short nlmsg_len %u",
					       (unsigned)h->nlmsg_len);
				} else {
					syslog(LOG_ERR, "NLM query failed %s",
					       strerror(-nlerr->error));
				}

				return -1;
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				syslog(LOG_ERR, "unexpected nlmsg_type %u\n",
				       (unsigned)h->nlmsg_type);
				return -1;
			}

			diag = (const struct unix_diag_msg *)NLMSG_DATA(h);

			if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*diag))) {
				syslog(LOG_ERR,
				       "nlmsg_type SOCK_DIAG_BY_FAMILY has short nlmsg_len %u",
				       (unsigned)h->nlmsg_len);
				return -1;
			}

			if (diag->udiag_family != AF_UNIX) {
				syslog(LOG_ERR, "unexpected family %u\n",
				       (unsigned)diag->udiag_family);
				return -1;
			}

			if (process(diag, h->nlmsg_len, context))
				return -1;
		}
	}
}
231 
sockdiag_query_unix_domain_socket_queue_length(const char * socket_name,int * sock_rqueue,int * sock_wqueue)232 int sockdiag_query_unix_domain_socket_queue_length(const char *socket_name,
233 						   int *sock_rqueue,
234 						   int *sock_wqueue)
235 {
236 	int err;
237 	int ret = -1;
238 	struct match_name_context ctx = {
239 		.name = socket_name,
240 		.inode = 0
241 	};
242 
243 	int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
244 
245 	if (fd < 0) {
246 		err = errno;
247 		syslog(LOG_ERR, "socket failed %s", strerror(err));
248 		goto cleanup;
249 	}
250 
251 	if (send_query
252 	    (fd, 0, 1 << TCP_LISTEN, UDIAG_SHOW_NAME | UDIAG_SHOW_RQLEN))
253 		goto cleanup;
254 
255 	if (receive_responses(fd, match_name, &ctx))
256 		goto cleanup;
257 
258 	*sock_rqueue = ctx.rqlen.udiag_rqueue;
259 	*sock_wqueue = ctx.rqlen.udiag_wqueue;
260 
261 	ret = 0;
262 
263  cleanup:
264 	if (fd >= 0) {
265 		close(fd);
266 	}
267 	return ret;
268 }
269 #else
/*
 * Fallback used when adaptive rate limit support is not compiled in.
 * No query is performed, the output arguments are left untouched, and -1
 * is always returned.
 */
int sockdiag_query_unix_domain_socket_queue_length(const char *socket_name,
						   int *sock_rqueue,
						   int *sock_wqueue)
{
	(void)socket_name;
	(void)sock_rqueue;
	(void)sock_wqueue;

	return -1;
}
276 #endif
277