1 /*
2 * Copyright (C) 2020 Microsoft Corporation
3 *
4 * This file is part of ocserv.
5 *
6 * ocserv is free software: you can redistribute it and/or modify it
7 * under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * ocserv is distributed in the hope that it will be useful, but
12 * WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include <config.h>
21 #include <errno.h>
22 #include <stdio.h>
23 #include <string.h>
24 #include <unistd.h>
25
26 #if defined(ENABLE_ADAPTIVE_RATE_LIMIT_SUPPORT)
27
28 #include <sys/socket.h>
29 #include <sys/un.h>
30 #include <linux/netlink.h>
31 #include <linux/rtnetlink.h>
32 #include <linux/sock_diag.h>
33 #include <linux/unix_diag.h>
34 #include <netinet/tcp.h>
35 #include <sys/syslog.h>
36
/*
 * Send one sock_diag request for AF_UNIX sockets on the netlink socket @fd.
 *
 * @fd:     netlink socket the request is written to
 * @inode:  inode of a single socket to query, or 0 to request a dump
 *          (NLM_F_DUMP) of every socket matching @states
 * @states: bitmask of socket states to include (1 << state)
 * @show:   UDIAG_SHOW_* flags selecting the attributes in the reply
 *
 * A send interrupted by a signal (EINTR) is retried. Returns 0 once the
 * request has been sent, or -1 after logging the sendmsg error.
 */
static int send_query(int fd, int inode, int states, int show)
{
	struct sockaddr_nl kernel_addr = {
		.nl_family = AF_NETLINK
	};
	struct {
		struct nlmsghdr nlh;
		struct unix_diag_req udr;
	} request = {
		.nlh = {
			.nlmsg_len = sizeof(request),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | (inode ? 0 : NLM_F_DUMP),
		},
		.udr = {
			.sdiag_family = AF_UNIX,
			.udiag_states = states,
			.udiag_show = show,
			.udiag_ino = inode,
		},
	};
	struct iovec vec = {
		.iov_base = &request,
		.iov_len = sizeof(request)
	};
	struct msghdr hdr = {
		.msg_name = (void *)&kernel_addr,
		.msg_namelen = sizeof(kernel_addr),
		.msg_iov = &vec,
		.msg_iovlen = 1
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &hdr, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent >= 0)
		return 0;

	syslog(LOG_ERR, "sendmsg failed %s", strerror(errno));
	return -1;
}
80
/*
 * Callback type used by receive_responses: it is invoked once per
 * SOCK_DIAG_BY_FAMILY message with the unix_diag_msg payload (@diag),
 * that message's nlmsg_len (@len, netlink header included) and the
 * caller's opaque @context. A non-zero return makes receive_responses
 * stop and return -1.
 */
typedef int (*process_response)(const struct unix_diag_msg * diag,
				 unsigned int len, void *context);
83
/*
 * Search state passed to match_name through receive_responses' opaque
 * context pointer.
 */
struct match_name_context {
	/* socket name to look for; compared with strcmp against the
	 * UNIX_DIAG_NAME the kernel reports */
	const char *name;
	/* udiag_ino of the (last) socket whose name matched; left at the
	 * caller's initial value when nothing matched */
	int inode;
	/* UNIX_DIAG_RQLEN queue lengths of that same matching socket */
	struct unix_diag_rqlen rqlen;
};
89
match_name(const struct unix_diag_msg * diag,unsigned int len,void * context)90 static int match_name(const struct unix_diag_msg *diag, unsigned int len,
91 void *context)
92 {
93 struct match_name_context *ctx = (struct match_name_context *)context;
94
95 struct rtattr *attr;
96 unsigned int rta_len = len - NLMSG_LENGTH(sizeof(*diag));
97 size_t path_len = 0;
98 char path[sizeof(((struct sockaddr_un *) 0)->sun_path) + 1];
99 struct unix_diag_rqlen rqlen;
100 int rqlen_valid = 0;
101
102 for (attr = (struct rtattr *)(diag + 1);
103 RTA_OK(attr, rta_len); attr = RTA_NEXT(attr, rta_len)) {
104 switch (attr->rta_type) {
105 case UNIX_DIAG_NAME:
106 if (!path_len) {
107 path_len = RTA_PAYLOAD(attr);
108 if (path_len > sizeof(path) - 1)
109 path_len = sizeof(path) - 1;
110 memcpy(path, RTA_DATA(attr), path_len);
111 path[path_len] = '\0';
112 }
113 break;
114 case UNIX_DIAG_RQLEN:
115 if (RTA_PAYLOAD(attr) != sizeof(rqlen))
116 return -1;
117 memcpy(&rqlen, RTA_DATA(attr), sizeof(rqlen));
118 rqlen_valid = 1;
119 break;
120 }
121 }
122
123 if (path_len == 0) {
124 syslog(LOG_ERR, "UNIX_DIAG_NAME not present in response");
125 return -1;
126 }
127
128 if (rqlen_valid == 0) {
129 syslog(LOG_ERR, "UNIX_DIAG_RQLEN not present in response");
130 return -1;
131 }
132
133 if (strcmp(path, ctx->name) == 0) {
134 ctx->inode = diag->udiag_ino;
135 ctx->rqlen = rqlen;
136 }
137
138 return 0;
139 }
140
/*
 * Read sock_diag replies from the netlink socket @fd and pass each
 * SOCK_DIAG_BY_FAMILY message to @process.
 *
 * @fd:      netlink socket a query was previously sent on
 * @process: callback invoked with the unix_diag_msg payload, the message's
 *           nlmsg_len and @context
 * @context: opaque pointer forwarded unchanged to @process
 *
 * Datagrams are read until an NLMSG_DONE message is seen, so a reply that
 * spans several datagrams is consumed completely. Returns 0 on
 * NLMSG_DONE. Returns -1 (after logging) in any of these cases:
 *   - recvmsg fails;
 *   - the datagram is empty or truncated;
 *   - a message is malformed or of an unexpected type or family;
 *   - the kernel replies with NLMSG_ERROR;
 *   - @process returns non-zero.
 */
static int receive_responses(int fd, process_response process, void *context)
{
	/* long-typed so the buffer is suitably aligned for struct nlmsghdr */
	long buf[8192 / sizeof(long)];
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = sizeof(buf)
	};

	for (;;) {
		struct msghdr msg = {
			.msg_name = (void *)&nladdr,
			.msg_namelen = sizeof(nladdr),
			.msg_iov = &iov,
			.msg_iovlen = 1
		};
		ssize_t ret = recvmsg(fd, &msg, 0);
		const struct nlmsghdr *h;

		if (ret < 0) {
			int err = errno;

			if (err == EINTR)
				continue;
			syslog(LOG_ERR, "recvmsg failed %s", strerror(err));
			return -1;
		}

		if (ret == 0) {
			syslog(LOG_ERR, "recvmsg returned empty response");
			return -1;
		}

		/*
		 * The part of a datagram that does not fit in buf is discarded.
		 * Parsing the remainder would silently miss messages, possibly
		 * including NLMSG_DONE.
		 */
		if (msg.msg_flags & MSG_TRUNC) {
			syslog(LOG_ERR, "recvmsg returned truncated response");
			return -1;
		}

		h = (const struct nlmsghdr *)buf;

		if (!NLMSG_OK(h, ret)) {
			syslog(LOG_ERR, "!NLMSG_OK");
			return -1;
		}

		for (; NLMSG_OK(h, ret); h = NLMSG_NEXT(h, ret)) {
			const struct unix_diag_msg *diag;

			if (h->nlmsg_type == NLMSG_DONE)
				return 0;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *nlerr = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*nlerr))) {
					syslog(LOG_ERR,
					       "nlmsg_type NLMSG_ERROR has short nlmsg_len %u",
					       h->nlmsg_len);
				} else {
					/* the kernel reports a negative errno value */
					syslog(LOG_ERR, "NLM query failed %s",
					       strerror(-nlerr->error));
				}

				return -1;
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				syslog(LOG_ERR, "unexpected nlmsg_type %u\n",
				       (unsigned)h->nlmsg_type);
				return -1;
			}

			diag = (const struct unix_diag_msg *)NLMSG_DATA(h);

			if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*diag))) {
				syslog(LOG_ERR,
				       "nlmsg_type SOCK_DIAG_BY_FAMILY has short nlmsg_len %u",
				       h->nlmsg_len);
				return -1;
			}

			if (diag->udiag_family != AF_UNIX) {
				syslog(LOG_ERR, "unexpected family %u\n",
				       diag->udiag_family);
				return -1;
			}

			if (process(diag, h->nlmsg_len, context))
				return -1;
		}
	}
}
231
sockdiag_query_unix_domain_socket_queue_length(const char * socket_name,int * sock_rqueue,int * sock_wqueue)232 int sockdiag_query_unix_domain_socket_queue_length(const char *socket_name,
233 int *sock_rqueue,
234 int *sock_wqueue)
235 {
236 int err;
237 int ret = -1;
238 struct match_name_context ctx = {
239 .name = socket_name,
240 .inode = 0
241 };
242
243 int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
244
245 if (fd < 0) {
246 err = errno;
247 syslog(LOG_ERR, "socket failed %s", strerror(err));
248 goto cleanup;
249 }
250
251 if (send_query
252 (fd, 0, 1 << TCP_LISTEN, UDIAG_SHOW_NAME | UDIAG_SHOW_RQLEN))
253 goto cleanup;
254
255 if (receive_responses(fd, match_name, &ctx))
256 goto cleanup;
257
258 *sock_rqueue = ctx.rqlen.udiag_rqueue;
259 *sock_wqueue = ctx.rqlen.udiag_wqueue;
260
261 ret = 0;
262
263 cleanup:
264 if (fd >= 0) {
265 close(fd);
266 }
267 return ret;
268 }
269 #else
/*
 * Fallback built when ENABLE_ADAPTIVE_RATE_LIMIT_SUPPORT is not defined.
 * Socket queue lengths cannot be queried in this configuration, so the
 * call always fails with -1 and leaves @sock_rqueue and @sock_wqueue
 * untouched.
 */
int sockdiag_query_unix_domain_socket_queue_length(const char *socket_name,
						   int *sock_rqueue,
						   int *sock_wqueue)
{
	(void)socket_name;
	(void)sock_rqueue;
	(void)sock_wqueue;

	return -1;
}
276 #endif
277