1 /*
2  * virnetsocket.c: generic network socket handling
3  *
4  * Copyright (C) 2006-2014 Red Hat, Inc.
5  * Copyright (C) 2006 Daniel P. Berrange
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library.  If not, see
19  * <http://www.gnu.org/licenses/>.
20  */
21 
22 #include <config.h>
23 
24 #include <sys/stat.h>
25 #include <unistd.h>
26 #include <signal.h>
27 #include <fcntl.h>
28 #ifdef WITH_IFADDRS_H
29 # include <ifaddrs.h>
30 #endif
31 
32 #ifdef WITH_SYS_UCRED_H
33 # include <sys/ucred.h>
34 #endif
35 
36 #ifdef WITH_SELINUX
37 # include <selinux/selinux.h>
38 #endif
39 
40 #include "virsocket.h"
41 #include "virnetsocket.h"
42 #include "virutil.h"
43 #include "viralloc.h"
44 #include "virerror.h"
45 #include "virlog.h"
46 #include "virfile.h"
47 #include "virthread.h"
48 #include "virpidfile.h"
49 #include "virprobe.h"
50 #include "virprocess.h"
51 #include "virstring.h"
52 
53 #if WITH_SSH2
54 # include "virnetsshsession.h"
55 #endif
56 
57 #if WITH_LIBSSH
58 # include "virnetlibsshsession.h"
59 #endif
60 
61 #define VIR_FROM_THIS VIR_FROM_RPC
62 
63 VIR_LOG_INIT("rpc.netsocket");
64 
/*
 * Private state of a virNetSocket object. Instances are created by
 * virNetSocketNew() via virObjectLockableNew(virNetSocketClass).
 */
struct _virNetSocket {
    /* Lockable object base; NOTE(review): presumably must remain the first
     * member for the VIR_CLASS_NEW/virObject machinery -- confirm. */
    virObjectLockable parent;

    /* The socket file descriptor handed to virNetSocketNew(). */
    int fd;
    /* Initialised to -1 by virNetSocketNew(); presumably an event-loop
     * watch id (-1 = none) -- confirm against the watch helpers. */
    int watch;
    /* Pid of the helper process for command-tunnelled sockets
     * (virNetSocketNewConnectCommand); 0 for ordinary sockets. */
    pid_t pid;
    /* Read end of the helper process' stderr pipe, or -1 if none. */
    int errfd;
    /* true for sockets created by the NewConnect* constructors, false for
     * the server-side NewListen* sockets. */
    bool isClient;
    /* Set to true at creation; NOTE(review): presumably controls whether
     * @fd is closed when the object is disposed -- confirm. */
    bool ownsFd;
    /* Not set anywhere in the visible code; NOTE(review): presumably
     * suppresses error reporting on end-of-file -- confirm. */
    bool quietEOF;
    /* true for sockets created by virNetSocketNewListenUNIX (or requested
     * via virNetSocketNewListenFD); presumably the socket path is unlinked
     * on disposal -- confirm. */
    bool unlinkUNIX;

    /* Event callback fields */
    virNetSocketIOFunc func;
    void *opaque;
    virFreeCallback ff;

    /* Copies of the local/remote addresses passed at creation (zeroed when
     * none was given). */
    virSocketAddr localAddr;
    virSocketAddr remoteAddr;
    /* Addresses formatted with a ";" separator, as produced by
     * virSocketAddrFormatFull(addr, true, ";"); the name suggests the form
     * expected by SASL. NULL when the address is unknown. */
    char *localAddrStrSASL;
    char *remoteAddrStrSASL;
    /* Remote address formatted with virSocketAddrFormatFull(addr, true,
     * NULL), i.e. the default separator; NULL when unknown. */
    char *remoteAddrStrURI;

    /* Optional TLS session layered on this socket. */
    virNetTLSSession *tlsSession;
#if WITH_SASL
    /* Optional SASL session layered on this socket. */
    virNetSASLSession *saslSession;

    /* Buffered data already decoded by SASL and the amount consumed so
     * far -- NOTE(review): inferred from names, confirm in the I/O code. */
    const char *saslDecoded;
    size_t saslDecodedLength;
    size_t saslDecodedOffset;

    /* Buffered SASL-encoded output, its original (raw) length and how
     * much has been written -- NOTE(review): inferred from names. */
    const char *saslEncoded;
    size_t saslEncodedLength;
    size_t saslEncodedRawLength;
    size_t saslEncodedOffset;
#endif
#if WITH_SSH2
    /* libssh2-based SSH session, when this socket tunnels over SSH. */
    virNetSSHSession *sshSession;
#endif
#if WITH_LIBSSH
    /* libssh-based SSH session, when this socket tunnels over SSH. */
    virNetLibsshSession *libsshSession;
#endif
};
108 
109 
110 static virClass *virNetSocketClass;
111 static void virNetSocketDispose(void *obj);
112 
virNetSocketOnceInit(void)113 static int virNetSocketOnceInit(void)
114 {
115     if (!VIR_CLASS_NEW(virNetSocket, virClassForObjectLockable()))
116         return -1;
117 
118     return 0;
119 }
120 
121 VIR_ONCE_GLOBAL_INIT(virNetSocket);
122 
123 
124 #ifndef WIN32
virNetSocketForkDaemon(const char * binary)125 static int virNetSocketForkDaemon(const char *binary)
126 {
127     g_autoptr(virCommand) cmd = virCommandNewArgList(binary,
128                                                      "--timeout=120",
129                                                      NULL);
130 
131     virCommandAddEnvPassCommon(cmd);
132     virCommandAddEnvPass(cmd, "XDG_CACHE_HOME");
133     virCommandAddEnvPass(cmd, "XDG_CONFIG_HOME");
134     virCommandAddEnvPass(cmd, "XDG_RUNTIME_DIR");
135     virCommandClearCaps(cmd);
136     virCommandDaemonize(cmd);
137     return virCommandRun(cmd, NULL);
138 }
139 #endif
140 
141 
142 static int G_GNUC_UNUSED
virNetSocketCheckProtocolByLookup(const char * address,int family,bool * hasFamily)143 virNetSocketCheckProtocolByLookup(const char *address,
144                                   int family,
145                                   bool *hasFamily)
146 {
147     struct addrinfo hints;
148     struct addrinfo *ai = NULL;
149     int gaierr;
150 
151     memset(&hints, 0, sizeof(hints));
152     hints.ai_family = family;
153     hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
154     hints.ai_socktype = SOCK_STREAM;
155 
156     if ((gaierr = getaddrinfo(address, NULL, &hints, &ai)) != 0) {
157         *hasFamily = false;
158 
159         if (gaierr == EAI_FAMILY ||
160 #ifdef EAI_ADDRFAMILY
161             gaierr == EAI_ADDRFAMILY ||
162 #endif
163             gaierr == EAI_NONAME) {
164         } else {
165             virReportError(VIR_ERR_INTERNAL_ERROR,
166                            _("Cannot resolve %s address: %s"),
167                            address,
168                            gai_strerror(gaierr));
169             return -1;
170         }
171     } else {
172         *hasFamily = true;
173     }
174 
175     freeaddrinfo(ai);
176     return 0;
177 }
178 
virNetSocketCheckProtocols(bool * hasIPv4,bool * hasIPv6)179 int virNetSocketCheckProtocols(bool *hasIPv4,
180                                bool *hasIPv6)
181 {
182 #ifdef WITH_IFADDRS_H
183     struct ifaddrs *ifaddr = NULL, *ifa;
184 
185     *hasIPv4 = *hasIPv6 = false;
186 
187     if (getifaddrs(&ifaddr) < 0) {
188         virReportSystemError(errno, "%s",
189                              _("Cannot get host interface addresses"));
190         return -1;
191     }
192 
193     for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
194         if (!ifa->ifa_addr)
195             continue;
196 
197         if (ifa->ifa_addr->sa_family == AF_INET)
198             *hasIPv4 = true;
199         if (ifa->ifa_addr->sa_family == AF_INET6)
200             *hasIPv6 = true;
201     }
202 
203     freeifaddrs(ifaddr);
204 
205     if (*hasIPv4 &&
206         virNetSocketCheckProtocolByLookup("127.0.0.1", AF_INET, hasIPv4) < 0)
207         return -1;
208 
209     if (*hasIPv6 &&
210         virNetSocketCheckProtocolByLookup("::1", AF_INET6, hasIPv6) < 0)
211         return -1;
212 
213     VIR_DEBUG("Protocols: v4 %d v6 %d", *hasIPv4, *hasIPv6);
214 
215     return 0;
216 #else
217     *hasIPv4 = *hasIPv6 = false;
218     virReportError(VIR_ERR_NO_SUPPORT, "%s",
219                    _("Cannot check address family on this platform"));
220     return -1;
221 #endif
222 }
223 
224 
225 static virNetSocket *
virNetSocketNew(virSocketAddr * localAddr,virSocketAddr * remoteAddr,bool isClient,int fd,int errfd,pid_t pid,bool unlinkUNIX)226 virNetSocketNew(virSocketAddr *localAddr,
227                 virSocketAddr *remoteAddr,
228                 bool isClient,
229                 int fd,
230                 int errfd,
231                 pid_t pid,
232                 bool unlinkUNIX)
233 {
234     g_autoptr(virNetSocket) sock = NULL;
235     int no_slow_start = 1;
236 
237     if (virNetSocketInitialize() < 0)
238         return NULL;
239 
240     VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%lld",
241               localAddr, remoteAddr,
242               fd, errfd, (long long)pid);
243 
244     if (virSetCloseExec(fd) < 0) {
245         virReportSystemError(errno, "%s",
246                              _("Unable to set close-on-exec flag"));
247        return NULL;
248     }
249     if (virSetNonBlock(fd) < 0) {
250         virReportSystemError(errno, "%s",
251                              _("Unable to enable non-blocking flag"));
252         return NULL;
253     }
254 
255     if (!(sock = virObjectLockableNew(virNetSocketClass)))
256         return NULL;
257 
258     if (localAddr)
259         sock->localAddr = *localAddr;
260     if (remoteAddr)
261         sock->remoteAddr = *remoteAddr;
262     sock->fd = fd;
263     sock->errfd = errfd;
264     sock->pid = pid;
265     sock->watch = -1;
266     sock->ownsFd = true;
267     sock->isClient = isClient;
268     sock->unlinkUNIX = unlinkUNIX;
269 
270     /* Disable nagle for TCP sockets */
271     if (sock->localAddr.data.sa.sa_family == AF_INET ||
272         sock->localAddr.data.sa.sa_family == AF_INET6) {
273         if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
274                        &no_slow_start,
275                        sizeof(no_slow_start)) < 0) {
276             virReportSystemError(errno, "%s",
277                                  _("Unable to disable nagle algorithm"));
278             goto error;
279         }
280     }
281 
282 
283     if (localAddr &&
284         !(sock->localAddrStrSASL = virSocketAddrFormatFull(localAddr, true, ";")))
285         goto error;
286 
287     if (remoteAddr &&
288         !(sock->remoteAddrStrSASL = virSocketAddrFormatFull(remoteAddr, true, ";")))
289         goto error;
290 
291     if (remoteAddr &&
292         !(sock->remoteAddrStrURI = virSocketAddrFormatFull(remoteAddr, true, NULL)))
293         goto error;
294 
295     PROBE(RPC_SOCKET_NEW,
296           "sock=%p fd=%d errfd=%d pid=%lld localAddr=%s, remoteAddr=%s",
297           sock, fd, errfd, (long long)pid,
298           NULLSTR(sock->localAddrStrSASL), NULLSTR(sock->remoteAddrStrSASL));
299 
300     return g_steal_pointer(&sock);
301 
302  error:
303     sock->fd = sock->errfd = -1; /* Caller owns fd/errfd on failure */
304     return NULL;
305 }
306 
307 
/*
 * Create server-side TCP sockets bound to every address that
 * @nodename/@service resolves to.
 *
 * @nodename: host name or address to bind to; NULL means the wildcard
 *            address
 * @service: port number or service name; NULL lets the kernel choose a
 *           port, and the port chosen for the first bound address is then
 *           reused for every further address
 * @family: address family hint passed to getaddrinfo()
 * @retsocks: set to a newly allocated array of sockets on success
 * @nretsocks: set to the number of entries in @retsocks
 *
 * Addresses whose family the host does not support (socket() fails with
 * EAFNOSUPPORT) or that cannot be bound because they are in use or not
 * available (EADDRINUSE / EADDRNOTAVAIL) are skipped. The call fails if no
 * socket at all could be set up, or on any other error.
 *
 * Returns 0 on success, -1 on error with an error reported.
 */
int virNetSocketNewListenTCP(const char *nodename,
                             const char *service,
                             int family,
                             virNetSocket ***retsocks,
                             size_t *nretsocks)
{
    virNetSocket **socks = NULL;
    size_t nsocks = 0;
    struct addrinfo *ai = NULL;
    struct addrinfo hints;
    int fd = -1;
    size_t i;
    int socketErrno = 0;
    int bindErrno = 0;
    virSocketAddr tmp_addr;
    int port = 0;
    int e;
    struct addrinfo *runp;

    *retsocks = NULL;
    *nretsocks = 0;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = family;
    hints.ai_flags = AI_PASSIVE;
    hints.ai_socktype = SOCK_STREAM;

    /* Don't use ADDRCONFIG for binding to the wildcard address.
     * Just catch the error returned by socket() if the system has
     * no IPv6 support.
     *
     * This allows libvirtd to be started in parallel with the network
     * startup in most cases.
     */
    if (nodename &&
        !(virSocketAddrParseAny(&tmp_addr, nodename, AF_UNSPEC, false) > 0 &&
          virSocketAddrIsWildcard(&tmp_addr)))
        hints.ai_flags |= AI_ADDRCONFIG;

    e = getaddrinfo(nodename, service, &hints, &ai);
    if (e != 0) {
        /* Both nodename and service may legitimately be NULL here, and
         * passing NULL for %s is undefined behaviour. */
        virReportError(VIR_ERR_SYSTEM_ERROR,
                       _("Unable to resolve address '%s' service '%s': %s"),
                       NULLSTR(nodename), NULLSTR(service), gai_strerror(e));
        return -1;
    }

    runp = ai;
    while (runp) {
        virSocketAddr addr;

        memset(&addr, 0, sizeof(addr));

        if ((fd = socket(runp->ai_family, runp->ai_socktype,
                         runp->ai_protocol)) < 0) {
            if (errno == EAFNOSUPPORT) {
                /* Family unsupported on this host: try the next address */
                socketErrno = errno;
                runp = runp->ai_next;
                continue;
            }
            virReportSystemError(errno, "%s", _("Unable to create socket"));
            goto error;
        }

        if (virSetSockReuseAddr(fd, true) < 0)
            goto error;

#ifdef IPV6_V6ONLY
        if (runp->ai_family == PF_INET6) {
            int on = 1;
            /*
             * Normally on Linux an INET6 socket will bind to the INET4
             * address too. If getaddrinfo returns results with INET4
             * first though, this will result in INET6 binding failing.
             * We can trivially cope with multiple server sockets, so
             * we force it to only listen on IPv6
             */
            if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY,
                           (void*)&on, sizeof(on)) < 0) {
                virReportSystemError(errno, "%s",
                                     _("Unable to force bind to IPv6 only"));
                goto error;
            }
        }
#endif

        addr.len = runp->ai_addrlen;
        memcpy(&addr.data.sa, runp->ai_addr, runp->ai_addrlen);

        /* When service is NULL, we let the kernel auto-select the
         * port. Once we've selected a port for one IP protocol
         * though, we want to ensure we pick the same port for the
         * other IP protocol
         */
        if (port != 0 && service == NULL) {
            if (addr.data.sa.sa_family == AF_INET) {
                addr.data.inet4.sin_port = port;
            } else if (addr.data.sa.sa_family == AF_INET6) {
                addr.data.inet6.sin6_port = port;
            }
            VIR_DEBUG("Used saved port %d", port);
        }

        if (bind(fd, &addr.data.sa, addr.len) < 0) {
            if (errno != EADDRINUSE && errno != EADDRNOTAVAIL) {
                virReportSystemError(errno, "%s", _("Unable to bind to port"));
                goto error;
            }
            /* Address busy/unavailable: remember why and try the next */
            bindErrno = errno;
            closesocket(fd);
            fd = -1;
            runp = runp->ai_next;
            continue;
        }

        /* Read back the bound address (learns an auto-selected port) */
        addr.len = sizeof(addr.data);
        if (getsockname(fd, &addr.data.sa, &addr.len) < 0) {
            virReportSystemError(errno, "%s", _("Unable to get local socket name"));
            goto error;
        }

        if (port == 0 && service == NULL) {
            /* Stored as-is, i.e. in network byte order */
            if (addr.data.sa.sa_family == AF_INET)
                port = addr.data.inet4.sin_port;
            else if (addr.data.sa.sa_family == AF_INET6)
                port = addr.data.inet6.sin6_port;
            VIR_DEBUG("Saved port %d", port);
        }

        VIR_DEBUG("%p f=%d f=%d", &addr, runp->ai_family, addr.data.sa.sa_family);

        VIR_EXPAND_N(socks, nsocks, 1);

        /* On failure the slot stays NULL and fd is still ours to close */
        if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 0, false)))
            goto error;
        runp = runp->ai_next;
        fd = -1;
    }

    if (nsocks == 0) {
        if (bindErrno)
            virReportSystemError(bindErrno, "%s", _("Unable to bind to port"));
        else if (socketErrno)
            virReportSystemError(socketErrno, "%s", _("Unable to create socket"));
        else
            virReportError(VIR_ERR_INTERNAL_ERROR, "%s", _("No addresses to bind to"));
        goto error;
    }

    freeaddrinfo(ai);

    *retsocks = socks;
    *nretsocks = nsocks;
    return 0;

 error:
    for (i = 0; i < nsocks; i++)
        virObjectUnref(socks[i]);
    VIR_FREE(socks);
    freeaddrinfo(ai);
    if (fd != -1)
        closesocket(fd);
    return -1;
}
472 
473 
474 #ifndef WIN32
/*
 * Create a server-side UNIX stream socket bound to @path.
 *
 * @path: filesystem path of the socket; a leading '@' selects an abstract
 *        socket name (the '@' is replaced by a NUL byte). Any existing
 *        file at a filesystem @path is unlinked before binding.
 * @mask: permission bits the socket file should get; the process umask is
 *        temporarily set to ~@mask around bind()
 * @user: owner to chown() the socket file to (only when @grp != 0)
 * @grp: group to chown() the socket file to; 0 skips the chown()
 * @retsock: set to the new socket on success, NULL otherwise
 *
 * Returns 0 on success, -1 on error with an error reported. On error the
 * socket file is removed only if this call created it, so an unrelated
 * file (e.g. at an over-long path that never reached bind()) is never
 * deleted.
 */
int virNetSocketNewListenUNIX(const char *path,
                              mode_t mask,
                              uid_t user,
                              gid_t grp,
                              virNetSocket **retsock)
{
    virSocketAddr addr;
    mode_t oldmask;
    int fd = -1;
    bool bound = false;

    *retsock = NULL;

    memset(&addr, 0, sizeof(addr));

    addr.len = sizeof(addr.data.un);

    if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
        virReportSystemError(errno, "%s", _("Failed to create socket"));
        goto error;
    }

    addr.data.un.sun_family = AF_UNIX;
    if (virStrcpyStatic(addr.data.un.sun_path, path) < 0) {
        virReportSystemError(ENAMETOOLONG,
                             _("Path %s too long for unix socket"), path);
        goto error;
    }
    if (addr.data.un.sun_path[0] == '@')
        addr.data.un.sun_path[0] = '\0';
    else
        unlink(addr.data.un.sun_path); /* remove any stale socket file */

    oldmask = umask(~mask);

    if (bind(fd, &addr.data.sa, addr.len) < 0) {
        umask(oldmask);
        virReportSystemError(errno,
                             _("Failed to bind socket to '%s'"),
                             path);
        goto error;
    }
    umask(oldmask);
    /* From here on the socket file at @path is ours to clean up */
    bound = true;

    /* chown() doesn't work for abstract sockets but we use them only
     * if libvirtd runs unprivileged
     */
    if (grp != 0 && chown(path, user, grp)) {
        virReportSystemError(errno,
                             _("Failed to change ownership of '%s' to %d:%d"),
                             path, (int)user, (int)grp);
        goto error;
    }

    if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0, true)))
        goto error;

    return 0;

 error:
    if (bound && path[0] != '@')
        unlink(path);
    if (fd != -1)
        closesocket(fd);
    return -1;
}
540 #else
/*
 * Fallback for platforms without AF_UNIX support: always reports ENOSYS
 * and fails with -1.
 */
int virNetSocketNewListenUNIX(const char *path G_GNUC_UNUSED,
                              mode_t mask G_GNUC_UNUSED,
                              uid_t user G_GNUC_UNUSED,
                              gid_t grp G_GNUC_UNUSED,
                              virNetSocket **retsock G_GNUC_UNUSED)
{
    const char *msg = _("UNIX sockets are not supported on this platform");

    virReportSystemError(ENOSYS, "%s", msg);
    return -1;
}
551 #endif
552 
/*
 * Wrap an existing socket descriptor @fd as a server-side virNetSocket.
 *
 * @fd: an open socket; its local address is read with getsockname()
 * @unlinkUNIX: recorded on the new object (see virNetSocketNew)
 * @retsock: set to the new socket on success, NULL otherwise
 *
 * Returns 0 on success, -1 on error with an error reported. On failure
 * the caller still owns @fd.
 */
int virNetSocketNewListenFD(int fd,
                            bool unlinkUNIX,
                            virNetSocket **retsock)
{
    virSocketAddr local;

    *retsock = NULL;

    memset(&local, 0, sizeof(local));
    local.len = sizeof(local.data);

    if (getsockname(fd, &local.data.sa, &local.len) < 0) {
        virReportSystemError(errno, "%s", _("Unable to get local socket name"));
        return -1;
    }

    *retsock = virNetSocketNew(&local, NULL, false, fd, -1, 0, unlinkUNIX);
    return *retsock ? 0 : -1;
}
573 
574 
/*
 * Open a client TCP connection to @nodename/@service.
 *
 * @nodename: host name or address of the server
 * @service: port number or service name
 * @family: address family hint passed to getaddrinfo()
 * @retsock: set to the connected socket on success, NULL otherwise
 *
 * Every address returned by getaddrinfo() is tried in order until one
 * connects. Addresses whose family the host cannot create sockets for
 * (EAFNOSUPPORT) are skipped, matching virNetSocketNewListenTCP.
 *
 * Returns 0 on success, -1 on error with an error reported.
 */
int virNetSocketNewConnectTCP(const char *nodename,
                              const char *service,
                              int family,
                              virNetSocket **retsock)
{
    struct addrinfo *ai = NULL;
    struct addrinfo hints;
    int fd = -1;
    virSocketAddr localAddr;
    virSocketAddr remoteAddr;
    struct addrinfo *runp;
    int savedErrno = ENOENT;
    int e;

    *retsock = NULL;

    memset(&localAddr, 0, sizeof(localAddr));
    memset(&remoteAddr, 0, sizeof(remoteAddr));

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = family;
    hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG | AI_V4MAPPED;
    hints.ai_socktype = SOCK_STREAM;

    e = getaddrinfo(nodename, service, &hints, &ai);
    if (e != 0) {
        /* Guard %s against NULL arguments (undefined behaviour) */
        virReportError(VIR_ERR_SYSTEM_ERROR,
                       _("Unable to resolve address '%s' service '%s': %s"),
                       NULLSTR(nodename), NULLSTR(service), gai_strerror(e));
        return -1;
    }

    runp = ai;
    while (runp) {
        if ((fd = socket(runp->ai_family, runp->ai_socktype,
                         runp->ai_protocol)) < 0) {
            if (errno == EAFNOSUPPORT) {
                /* Family unsupported on this host: try the next address */
                savedErrno = errno;
                runp = runp->ai_next;
                continue;
            }
            virReportSystemError(errno, "%s", _("Unable to create socket"));
            goto error;
        }

        if (virSetSockReuseAddr(fd, false) < 0)
            VIR_WARN("Unable to enable port reuse");

        if (connect(fd, runp->ai_addr, runp->ai_addrlen) >= 0)
            break;

        savedErrno = errno;
        closesocket(fd);
        fd = -1;
        runp = runp->ai_next;
    }

    if (fd == -1) {
        virReportSystemError(savedErrno,
                             _("unable to connect to server at '%s:%s'"),
                             NULLSTR(nodename), NULLSTR(service));
        goto error;
    }

    localAddr.len = sizeof(localAddr.data);
    if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
        virReportSystemError(errno, "%s", _("Unable to get local socket name"));
        goto error;
    }

    remoteAddr.len = sizeof(remoteAddr.data);
    if (getpeername(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) {
        virReportSystemError(errno, "%s", _("Unable to get remote socket name"));
        goto error;
    }

    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0, false)))
        goto error;

    freeaddrinfo(ai);

    return 0;

 error:
    freeaddrinfo(ai);
    if (fd != -1)
        closesocket(fd);
    return -1;
}
659 
660 
661 #ifndef WIN32
/*
 * Connect to the UNIX stream socket at @path.
 *
 * @path: filesystem path of the socket; a leading '@' selects an abstract
 *        socket name (the '@' is replaced by a NUL byte)
 * @spawnDaemonPath: optional binary to auto-start when nobody is listening
 * @retsock: set to the connected socket on success, NULL otherwise
 *
 * When @spawnDaemonPath is given:
 *  - the user runtime directory is created if needed and the lock file
 *    "<runtime dir>/<basename of binary>.lock" is locked until the function
 *    returns, then unlinked (NOTE(review): presumably so concurrent clients
 *    don't each spawn a daemon -- confirm);
 *  - if connect() fails with ENOENT or ECONNREFUSED, the binary is launched
 *    once via virNetSocketForkDaemon() and connect() is retried every 10ms,
 *    for at most 500 attempts in total.
 *
 * Returns 0 on success, -1 on error with an error reported.
 */
int virNetSocketNewConnectUNIX(const char *path,
                               const char *spawnDaemonPath,
                               virNetSocket **retsock)
{
    g_autofree char *lockpath = NULL;
    VIR_AUTOCLOSE lockfd = -1;
    int fd = -1;
    int retries = 500;
    virSocketAddr localAddr;
    virSocketAddr remoteAddr;
    g_autofree char *rundir = NULL;
    int ret = -1;
    bool daemonLaunched = false;

    *retsock = NULL;

    VIR_DEBUG("path=%s spawnDaemonPath=%s", path, NULLSTR(spawnDaemonPath));

    memset(&localAddr, 0, sizeof(localAddr));
    memset(&remoteAddr, 0, sizeof(remoteAddr));

    remoteAddr.len = sizeof(remoteAddr.data.un);

    if (spawnDaemonPath) {
        g_autofree char *binname = g_path_get_basename(spawnDaemonPath);
        rundir = virGetUserRuntimeDirectory();

        if (g_mkdir_with_parents(rundir, 0700) < 0) {
            virReportSystemError(errno,
                                 _("Cannot create user runtime directory '%s'"),
                                 rundir);
            goto cleanup;
        }

        lockpath = g_strdup_printf("%s/%s.lock", rundir, binname);

        if ((lockfd = open(lockpath, O_RDWR | O_CREAT, 0600)) < 0 ||
            virSetCloseExec(lockfd) < 0) {
            virReportSystemError(errno, _("Unable to create lock '%s'"), lockpath);
            goto cleanup;
        }

        if (virFileLock(lockfd, false, 0, 1, true) < 0) {
            virReportSystemError(errno, _("Unable to lock '%s'"), lockpath);
            goto cleanup;
        }
    }

    if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
        virReportSystemError(errno, "%s", _("Failed to create socket"));
        goto cleanup;
    }

    remoteAddr.data.un.sun_family = AF_UNIX;
    if (virStrcpyStatic(remoteAddr.data.un.sun_path, path) < 0) {
        /* Same errno as virNetSocketNewListenUNIX; ENOMEM was misleading */
        virReportSystemError(ENAMETOOLONG,
                             _("Path %s too long for unix socket"), path);
        goto cleanup;
    }
    if (remoteAddr.data.un.sun_path[0] == '@')
        remoteAddr.data.un.sun_path[0] = '\0';

    while (retries) {
        int connErrno;

        if (connect(fd, &remoteAddr.data.sa, remoteAddr.len) == 0) {
            VIR_DEBUG("connect() succeeded");
            break;
        }
        /* Capture errno before any other call can clobber it */
        connErrno = errno;
        VIR_DEBUG("connect() failed: retries=%d errno=%d", retries, connErrno);

        retries--;
        if (!spawnDaemonPath ||
            retries == 0 ||
            (connErrno != ENOENT && connErrno != ECONNREFUSED)) {
            virReportSystemError(connErrno, _("Failed to connect socket to '%s'"),
                                 path);
            goto cleanup;
        }

        if (!daemonLaunched) {
            if (virNetSocketForkDaemon(spawnDaemonPath) < 0)
                goto cleanup;

            daemonLaunched = true;
        }

        g_usleep(10000);
    }

    localAddr.len = sizeof(localAddr.data);
    if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
        virReportSystemError(errno, "%s", _("Unable to get local socket name"));
        goto cleanup;
    }

    if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0, false)))
        goto cleanup;

    ret = 0;

 cleanup:
    if (lockfd != -1) {
        unlink(lockpath);
    }

    if (ret < 0 && fd != -1)
        closesocket(fd);

    return ret;
}
768 #else
/*
 * Fallback for platforms without AF_UNIX support: always reports ENOSYS
 * and fails with -1.
 */
int virNetSocketNewConnectUNIX(const char *path G_GNUC_UNUSED,
                               const char *spawnDaemonPath G_GNUC_UNUSED,
                               virNetSocket **retsock G_GNUC_UNUSED)
{
    const char *msg = _("UNIX sockets are not supported on this platform");

    virReportSystemError(ENOSYS, "%s", msg);
    return -1;
}
777 #endif
778 
779 
780 #ifndef WIN32
/*
 * Run @cmd asynchronously and talk to it over its standard streams.
 *
 * One end of a UNIX socketpair becomes the child's stdin AND stdout, so a
 * single bidirectional descriptor carries the traffic; the child's stderr
 * goes to a separate pipe whose read end is kept on the returned socket.
 *
 * @cmd: the command to spawn; aborted on failure
 * @retsock: set to the new client socket on success, NULL otherwise
 *
 * Returns 0 on success, -1 on error with an error reported.
 */
int virNetSocketNewConnectCommand(virCommand *cmd,
                                  virNetSocket **retsock)
{
    int sockpair[2] = { -1, -1 };
    int errpipe[2] = { -1, -1 };
    pid_t child = 0;

    *retsock = NULL;

    if (socketpair(PF_UNIX, SOCK_STREAM, 0, sockpair) < 0) {
        virReportSystemError(errno, "%s",
                             _("unable to create socket pair"));
        goto error;
    }

    if (virPipe(errpipe) < 0)
        goto error;

    /* Child side: sockpair[1] for stdin/stdout, errpipe[1] for stderr */
    virCommandSetInputFD(cmd, sockpair[1]);
    virCommandSetOutputFD(cmd, &sockpair[1]);
    virCommandSetErrorFD(cmd, &errpipe[1]);

    if (virCommandRunAsync(cmd, &child) < 0)
        goto error;

    /* Back in the parent: the child-side ends are no longer needed here */
    VIR_FORCE_CLOSE(sockpair[1]);
    VIR_FORCE_CLOSE(errpipe[1]);

    *retsock = virNetSocketNew(NULL, NULL, true, sockpair[0], errpipe[0], child, false);
    if (*retsock == NULL)
        goto error;

    return 0;

 error:
    VIR_FORCE_CLOSE(sockpair[0]);
    VIR_FORCE_CLOSE(sockpair[1]);
    VIR_FORCE_CLOSE(errpipe[0]);
    VIR_FORCE_CLOSE(errpipe[1]);

    virCommandAbort(cmd);

    return -1;
}
829 #else
/*
 * Fallback for platforms without command tunnelling support.
 *
 * Always fails with -1, reporting ENOSYS like the other
 * unsupported-platform stubs in this file. The previous code passed the
 * current errno, which is unrelated stale state at this point.
 */
int virNetSocketNewConnectCommand(virCommand *cmd G_GNUC_UNUSED,
                                  virNetSocket **retsock G_GNUC_UNUSED)
{
    virReportSystemError(ENOSYS, "%s",
                         _("Tunnelling sockets not supported on this platform"));
    return -1;
}
837 #endif
838 
/*
 * Reach @nodename by spawning an ssh client and tunnelling over its stdio
 * (see virNetSocketNewConnectCommand).
 *
 * @nodename: remote host, placed after "--" so it is never parsed as an option
 * @service: optional port, passed as "-p"
 * @binary: ssh client to run; "ssh" when NULL
 * @username: optional login name, passed as "-l"
 * @noTTY: when true, add "-o BatchMode=yes"
 * @noVerify: when true, add "-o StrictHostKeyChecking=no"
 * @keyfile: optional identity file, passed as "-i"
 * @command: remote command appended after @nodename
 * @retsock: set to the new client socket on success, NULL otherwise
 *
 * "-T -e none" is always passed. A fixed set of environment variables
 * (agent, Kerberos, askpass, X display, runtime dir) is passed through
 * and capabilities are cleared.
 *
 * Returns 0 on success, -1 on error with an error reported.
 */
int virNetSocketNewConnectSSH(const char *nodename,
                              const char *service,
                              const char *binary,
                              const char *username,
                              bool noTTY,
                              bool noVerify,
                              const char *keyfile,
                              const char *command,
                              virNetSocket **retsock)
{
    const char *const envPassthrough[] = {
        "XDG_RUNTIME_DIR",
        "KRB5CCNAME",
        "SSH_AUTH_SOCK",
        "SSH_ASKPASS",
        "DISPLAY",
        "XAUTHORITY",
    };
    const struct {
        const char *flag;
        const char *value;
    } optionalArgs[] = {
        { "-p", service },
        { "-l", username },
        { "-i", keyfile },
    };
    g_autoptr(virCommand) cmd = NULL;
    size_t k;

    *retsock = NULL;

    cmd = virCommandNew(binary != NULL ? binary : "ssh");

    virCommandAddEnvPassCommon(cmd);
    for (k = 0; k < sizeof(envPassthrough) / sizeof(envPassthrough[0]); k++)
        virCommandAddEnvPass(cmd, envPassthrough[k]);
    virCommandClearCaps(cmd);

    for (k = 0; k < sizeof(optionalArgs) / sizeof(optionalArgs[0]); k++) {
        if (optionalArgs[k].value != NULL)
            virCommandAddArgList(cmd, optionalArgs[k].flag,
                                 optionalArgs[k].value, NULL);
    }

    virCommandAddArgList(cmd, "-T", "-e", "none", NULL);

    if (noTTY)
        virCommandAddArgList(cmd, "-o", "BatchMode=yes", NULL);
    if (noVerify)
        virCommandAddArgList(cmd, "-o", "StrictHostKeyChecking=no", NULL);

    virCommandAddArgList(cmd, "--", nodename, command, NULL);

    return virNetSocketNewConnectCommand(cmd, retsock);
}
879 
880 #if WITH_SSH2
881 int
virNetSocketNewConnectLibSSH2(const char *host,
                              const char *port,
                              int family,
                              const char *username,
                              const char *privkey,
                              const char *knownHosts,
                              const char *knownHostsVerify,
                              const char *authMethods,
                              const char *command,
                              virConnectAuthPtr auth,
                              virURI *uri,
                              virNetSocket **retsock)
{
    /*
     * Open a TCP connection to @host:@port and run a libssh2 session on it.
     *
     * @knownHostsVerify selects host key checking ("auto", "ignore" or
     * "normal", case-insensitive); @authMethods is a comma separated list of
     * "keyboard-interactive", "password", "privkey" and "agent".
     *
     * On success the new socket, owning the session, is stored in @retsock
     * and 0 is returned. On failure the last non-zero status is returned.
     */
    g_autoptr(virNetSocket) sock = NULL;
    g_auto(GStrv) methods = NULL;
    virNetSSHSession *session = NULL;
    unsigned int hostKeyMode;
    int portNum;
    int rc = -1;
    size_t i;

    /* Only parsed here; the TCP connect below validates the port for real. */
    if (virStrToLong_i(port, NULL, 10, &portNum) < 0) {
        virReportError(VIR_ERR_SSH, "%s",
                       _("Failed to parse port number"));
        goto error;
    }

    session = virNetSSHSessionNew();
    if (!session)
        goto error;

    if (virNetSSHSessionAuthSetCallback(session, auth) != 0)
        goto error;

    if (STRCASEEQ("auto", knownHostsVerify)) {
        hostKeyMode = VIR_NET_SSH_HOSTKEY_VERIFY_AUTO_ADD;
    } else if (STRCASEEQ("ignore", knownHostsVerify)) {
        hostKeyMode = VIR_NET_SSH_HOSTKEY_VERIFY_IGNORE;
    } else if (STRCASEEQ("normal", knownHostsVerify)) {
        hostKeyMode = VIR_NET_SSH_HOSTKEY_VERIFY_NORMAL;
    } else {
        virReportError(VIR_ERR_INVALID_ARG,
                       _("Invalid host key verification method: '%s'"),
                       knownHostsVerify);
        goto error;
    }

    if (virNetSSHSessionSetHostKeyVerification(session, host, portNum,
                                               knownHosts, hostKeyMode,
                                               VIR_NET_SSH_HOSTKEY_FILE_CREATE) != 0)
        goto error;

    virNetSSHSessionSetChannelCommand(session, command);

    methods = g_strsplit(authMethods, ",", 0);
    if (!methods)
        goto error;

    /* Register every requested authentication method, in the given order. */
    for (i = 0; methods[i] != NULL; i++) {
        const char *method = methods[i];

        if (STRCASEEQ(method, "keyboard-interactive")) {
            rc = virNetSSHSessionAuthAddKeyboardAuth(session, username, -1);
        } else if (STRCASEEQ(method, "password")) {
            rc = virNetSSHSessionAuthAddPasswordAuth(session, uri, username);
        } else if (STRCASEEQ(method, "privkey")) {
            rc = virNetSSHSessionAuthAddPrivKeyAuth(session, username,
                                                    privkey, NULL);
        } else if (STRCASEEQ(method, "agent")) {
            rc = virNetSSHSessionAuthAddAgentAuth(session, username);
        } else {
            virReportError(VIR_ERR_INVALID_ARG,
                           _("Invalid authentication method: '%s'"),
                           method);
            rc = -1;
            goto error;
        }

        if (rc != 0)
            goto error;
    }

    rc = virNetSocketNewConnectTCP(host, port, family, &sock);
    if (rc < 0)
        goto error;

    rc = virNetSSHSessionConnect(session, virNetSocketGetFD(sock));
    if (rc != 0)
        goto error;

    /* The socket inherits our session reference (released in Dispose). */
    sock->sshSession = session;
    *retsock = g_steal_pointer(&sock);
    return 0;

 error:
    virObjectUnref(session);
    return rc;
}
990 #else
991 int
virNetSocketNewConnectLibSSH2(const char * host G_GNUC_UNUSED,const char * port G_GNUC_UNUSED,int family G_GNUC_UNUSED,const char * username G_GNUC_UNUSED,const char * privkey G_GNUC_UNUSED,const char * knownHosts G_GNUC_UNUSED,const char * knownHostsVerify G_GNUC_UNUSED,const char * authMethods G_GNUC_UNUSED,const char * command G_GNUC_UNUSED,virConnectAuthPtr auth G_GNUC_UNUSED,virURI * uri G_GNUC_UNUSED,virNetSocket ** retsock G_GNUC_UNUSED)992 virNetSocketNewConnectLibSSH2(const char *host G_GNUC_UNUSED,
993                               const char *port G_GNUC_UNUSED,
994                               int family G_GNUC_UNUSED,
995                               const char *username G_GNUC_UNUSED,
996                               const char *privkey G_GNUC_UNUSED,
997                               const char *knownHosts G_GNUC_UNUSED,
998                               const char *knownHostsVerify G_GNUC_UNUSED,
999                               const char *authMethods G_GNUC_UNUSED,
1000                               const char *command G_GNUC_UNUSED,
1001                               virConnectAuthPtr auth G_GNUC_UNUSED,
1002                               virURI *uri G_GNUC_UNUSED,
1003                               virNetSocket **retsock G_GNUC_UNUSED)
1004 {
1005     virReportSystemError(ENOSYS, "%s",
1006                          _("libssh2 transport support was not enabled"));
1007     return -1;
1008 }
1009 #endif /* WITH_SSH2 */
1010 
1011 #if WITH_LIBSSH
1012 int
virNetSocketNewConnectLibssh(const char * host,const char * port,int family,const char * username,const char * privkey,const char * knownHosts,const char * knownHostsVerify,const char * authMethods,const char * command,virConnectAuthPtr auth,virURI * uri,virNetSocket ** retsock)1013 virNetSocketNewConnectLibssh(const char *host,
1014                              const char *port,
1015                              int family,
1016                              const char *username,
1017                              const char *privkey,
1018                              const char *knownHosts,
1019                              const char *knownHostsVerify,
1020                              const char *authMethods,
1021                              const char *command,
1022                              virConnectAuthPtr auth,
1023                              virURI *uri,
1024                              virNetSocket **retsock)
1025 {
1026     g_autoptr(virNetSocket) sock = NULL;
1027     virNetLibsshSession *sess = NULL;
1028     unsigned int verify;
1029     int ret = -1;
1030     int portN;
1031 
1032     g_auto(GStrv) authMethodList = NULL;
1033     char **authMethodNext;
1034 
1035     /* port number will be verified while opening the socket */
1036     if (virStrToLong_i(port, NULL, 10, &portN) < 0) {
1037         virReportError(VIR_ERR_LIBSSH, "%s",
1038                        _("Failed to parse port number"));
1039         goto error;
1040     }
1041 
1042     /* create ssh session context */
1043     if (!(sess = virNetLibsshSessionNew(username)))
1044         goto error;
1045 
1046     /* set ssh session parameters */
1047     if (virNetLibsshSessionAuthSetCallback(sess, auth) != 0)
1048         goto error;
1049 
1050     if (STRCASEEQ("auto", knownHostsVerify)) {
1051         verify = VIR_NET_LIBSSH_HOSTKEY_VERIFY_AUTO_ADD;
1052     } else if (STRCASEEQ("ignore", knownHostsVerify)) {
1053         verify = VIR_NET_LIBSSH_HOSTKEY_VERIFY_IGNORE;
1054     } else if (STRCASEEQ("normal", knownHostsVerify)) {
1055         verify = VIR_NET_LIBSSH_HOSTKEY_VERIFY_NORMAL;
1056     } else {
1057         virReportError(VIR_ERR_INVALID_ARG,
1058                        _("Invalid host key verification method: '%s'"),
1059                        knownHostsVerify);
1060         goto error;
1061     }
1062 
1063     if (virNetLibsshSessionSetHostKeyVerification(sess,
1064                                                   host,
1065                                                   portN,
1066                                                   knownHosts,
1067                                                   verify) != 0)
1068         goto error;
1069 
1070     virNetLibsshSessionSetChannelCommand(sess, command);
1071 
1072     if (!(authMethodList = g_strsplit(authMethods, ",", 0)))
1073         goto error;
1074 
1075     for (authMethodNext = authMethodList; *authMethodNext; authMethodNext++) {
1076         const char *authMethod = *authMethodNext;
1077 
1078         if (STRCASEEQ(authMethod, "keyboard-interactive")) {
1079             ret = virNetLibsshSessionAuthAddKeyboardAuth(sess, -1);
1080         } else if (STRCASEEQ(authMethod, "password")) {
1081             ret = virNetLibsshSessionAuthAddPasswordAuth(sess, uri);
1082         } else if (STRCASEEQ(authMethod, "privkey")) {
1083             ret = virNetLibsshSessionAuthAddPrivKeyAuth(sess,
1084                                                         privkey,
1085                                                         NULL);
1086         } else if (STRCASEEQ(authMethod, "agent")) {
1087             ret = virNetLibsshSessionAuthAddAgentAuth(sess);
1088         } else {
1089             virReportError(VIR_ERR_INVALID_ARG,
1090                            _("Invalid authentication method: '%s'"),
1091                            authMethod);
1092             ret = -1;
1093             goto error;
1094         }
1095 
1096         if (ret != 0)
1097             goto error;
1098     }
1099 
1100     /* connect to remote server */
1101     if ((ret = virNetSocketNewConnectTCP(host, port, family, &sock)) < 0)
1102         goto error;
1103 
1104     /* connect to the host using ssh */
1105     if ((ret = virNetLibsshSessionConnect(sess, virNetSocketGetFD(sock))) != 0)
1106         goto error;
1107 
1108     sock->libsshSession = sess;
1109     /* libssh owns the FD and closes it on its own, and thus
1110      * we must not close it (otherwise there are warnings about
1111      * trying to close an invalid FD).
1112      */
1113     sock->ownsFd = false;
1114     *retsock = g_steal_pointer(&sock);
1115 
1116     return 0;
1117 
1118  error:
1119     virObjectUnref(sess);
1120     return ret;
1121 }
1122 #else
1123 int
virNetSocketNewConnectLibssh(const char * host G_GNUC_UNUSED,const char * port G_GNUC_UNUSED,int family G_GNUC_UNUSED,const char * username G_GNUC_UNUSED,const char * privkey G_GNUC_UNUSED,const char * knownHosts G_GNUC_UNUSED,const char * knownHostsVerify G_GNUC_UNUSED,const char * authMethods G_GNUC_UNUSED,const char * command G_GNUC_UNUSED,virConnectAuthPtr auth G_GNUC_UNUSED,virURI * uri G_GNUC_UNUSED,virNetSocket ** retsock G_GNUC_UNUSED)1124 virNetSocketNewConnectLibssh(const char *host G_GNUC_UNUSED,
1125                              const char *port G_GNUC_UNUSED,
1126                              int family G_GNUC_UNUSED,
1127                              const char *username G_GNUC_UNUSED,
1128                              const char *privkey G_GNUC_UNUSED,
1129                              const char *knownHosts G_GNUC_UNUSED,
1130                              const char *knownHostsVerify G_GNUC_UNUSED,
1131                              const char *authMethods G_GNUC_UNUSED,
1132                              const char *command G_GNUC_UNUSED,
1133                              virConnectAuthPtr auth G_GNUC_UNUSED,
1134                              virURI *uri G_GNUC_UNUSED,
1135                              virNetSocket **retsock G_GNUC_UNUSED)
1136 {
1137     virReportSystemError(ENOSYS, "%s",
1138                          _("libssh transport support was not enabled"));
1139     return -1;
1140 }
1141 #endif /* WITH_LIBSSH */
1142 
/*
 * Connect to an external helper: build a command from @cmdargv, pass the
 * common environment through, clear capabilities, and hand it to
 * virNetSocketNewConnectCommand(), whose result is returned.
 */
int virNetSocketNewConnectExternal(const char **cmdargv,
                                   virNetSocket **retsock)
{
    g_autoptr(virCommand) child = NULL;

    /* Start from a NULL result before attempting anything. */
    *retsock = NULL;

    child = virCommandNewArgs(cmdargv);
    virCommandAddEnvPassCommon(child);
    virCommandClearCaps(child);

    return virNetSocketNewConnectCommand(child, retsock);
}
1156 
1157 
/*
 * Wrap an already connected socket @sockfd in a new client virNetSocket.
 *
 * Only the local address is recorded; no remote address is passed on.
 * Returns 0 and stores the socket in @retsock on success, -1 (with an
 * error reported) on failure.
 */
int virNetSocketNewConnectSockFD(int sockfd,
                                 virNetSocket **retsock)
{
    virSocketAddr localAddr;

    /* getsockname() may fill in only part of the address: for an unnamed
     * UNIX socket (e.g. one end of a socketpair) only the family is
     * returned. Zero everything first so that sun_path and friends never
     * carry stack garbage into the new socket object. */
    memset(&localAddr, 0, sizeof(localAddr));
    localAddr.len = sizeof(localAddr.data.stor);
    if (getsockname(sockfd, &localAddr.data.sa, &localAddr.len) < 0) {
        virReportSystemError(errno, "%s", _("Unable to get local socket name"));
        return -1;
    }

    if (!(*retsock = virNetSocketNew(&localAddr, NULL, true, sockfd, -1, -1, false)))
        return -1;

    return 0;
}
1174 
1175 
virNetSocketNewPostExecRestart(virJSONValue * object)1176 virNetSocket *virNetSocketNewPostExecRestart(virJSONValue *object)
1177 {
1178     virSocketAddr localAddr;
1179     virSocketAddr remoteAddr;
1180     int fd, thepid, errfd;
1181     bool isClient;
1182     bool unlinkUNIX;
1183 
1184     if (virJSONValueObjectGetNumberInt(object, "fd", &fd) < 0) {
1185         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1186                        _("Missing fd data in JSON document"));
1187         return NULL;
1188     }
1189 
1190     if (virJSONValueObjectGetNumberInt(object, "pid", &thepid) < 0) {
1191         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1192                        _("Missing pid data in JSON document"));
1193         return NULL;
1194     }
1195 
1196     if (virJSONValueObjectGetNumberInt(object, "errfd", &errfd) < 0) {
1197         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1198                        _("Missing errfd data in JSON document"));
1199         return NULL;
1200     }
1201 
1202     if (virJSONValueObjectGetBoolean(object, "isClient", &isClient) < 0) {
1203         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1204                        _("Missing isClient data in JSON document"));
1205         return NULL;
1206     }
1207 
1208     if (virJSONValueObjectGetBoolean(object, "unlinkUNIX", &unlinkUNIX) < 0)
1209         unlinkUNIX = !isClient;
1210 
1211     memset(&localAddr, 0, sizeof(localAddr));
1212     memset(&remoteAddr, 0, sizeof(remoteAddr));
1213 
1214     remoteAddr.len = sizeof(remoteAddr.data.stor);
1215     if (getsockname(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) {
1216         virReportSystemError(errno, "%s", _("Unable to get peer socket name"));
1217         return NULL;
1218     }
1219 
1220     localAddr.len = sizeof(localAddr.data.stor);
1221     if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
1222         virReportSystemError(errno, "%s", _("Unable to get local socket name"));
1223         return NULL;
1224     }
1225 
1226     return virNetSocketNew(&localAddr, &remoteAddr, isClient,
1227                            fd, errfd, thepid, unlinkUNIX);
1228 }
1229 
1230 
virNetSocketPreExecRestart(virNetSocket * sock)1231 virJSONValue *virNetSocketPreExecRestart(virNetSocket *sock)
1232 {
1233     g_autoptr(virJSONValue) object = NULL;
1234 
1235     virObjectLock(sock);
1236 
1237 #if WITH_SASL
1238     if (sock->saslSession) {
1239         virReportError(VIR_ERR_OPERATION_INVALID, "%s",
1240                        _("Unable to save socket state when SASL session is active"));
1241         goto error;
1242     }
1243 #endif
1244     if (sock->tlsSession) {
1245         virReportError(VIR_ERR_OPERATION_INVALID, "%s",
1246                        _("Unable to save socket state when TLS session is active"));
1247         goto error;
1248     }
1249 
1250     object = virJSONValueNewObject();
1251 
1252     if (virJSONValueObjectAppendNumberInt(object, "fd", sock->fd) < 0)
1253         goto error;
1254 
1255     if (virJSONValueObjectAppendNumberInt(object, "errfd", sock->errfd) < 0)
1256         goto error;
1257 
1258     if (virJSONValueObjectAppendNumberInt(object, "pid", sock->pid) < 0)
1259         goto error;
1260 
1261     if (virJSONValueObjectAppendBoolean(object, "isClient", sock->isClient) < 0)
1262         goto error;
1263 
1264     if (virJSONValueObjectAppendBoolean(object, "unlinkUNIX", sock->unlinkUNIX) < 0)
1265         goto error;
1266 
1267     if (virSetInherit(sock->fd, true) < 0) {
1268         virReportSystemError(errno,
1269                              _("Cannot disable close-on-exec flag on socket %d"),
1270                              sock->fd);
1271         goto error;
1272     }
1273     if (sock->errfd != -1 &&
1274         virSetInherit(sock->errfd, true) < 0) {
1275         virReportSystemError(errno,
1276                              _("Cannot disable close-on-exec flag on pipe %d"),
1277                              sock->errfd);
1278         goto error;
1279     }
1280 
1281     virObjectUnlock(sock);
1282     return g_steal_pointer(&object);
1283 
1284  error:
1285     virObjectUnlock(sock);
1286     return NULL;
1287 }
1288 
1289 
virNetSocketDispose(void * obj)1290 void virNetSocketDispose(void *obj)
1291 {
1292     virNetSocket *sock = obj;
1293 
1294     PROBE(RPC_SOCKET_DISPOSE,
1295           "sock=%p", sock);
1296 
1297     if (sock->watch >= 0) {
1298         virEventRemoveHandle(sock->watch);
1299         sock->watch = -1;
1300     }
1301 
1302 #ifndef WIN32
1303     /* If a server socket, then unlink UNIX path */
1304     if (sock->unlinkUNIX &&
1305         sock->localAddr.data.sa.sa_family == AF_UNIX &&
1306         sock->localAddr.data.un.sun_path[0] != '\0')
1307         unlink(sock->localAddr.data.un.sun_path);
1308 #endif
1309 
1310     /* Make sure it can't send any more I/O during shutdown */
1311     if (sock->tlsSession)
1312         virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
1313     virObjectUnref(sock->tlsSession);
1314 #if WITH_SASL
1315     virObjectUnref(sock->saslSession);
1316 #endif
1317 
1318 #if WITH_SSH2
1319     virObjectUnref(sock->sshSession);
1320 #endif
1321 
1322 #if WITH_LIBSSH
1323     virObjectUnref(sock->libsshSession);
1324 #endif
1325 
1326     if (sock->ownsFd && sock->fd != -1) {
1327         closesocket(sock->fd);
1328         sock->fd = -1;
1329     }
1330     VIR_FORCE_CLOSE(sock->errfd);
1331 
1332     virProcessAbort(sock->pid);
1333 
1334     g_free(sock->localAddrStrSASL);
1335     g_free(sock->remoteAddrStrSASL);
1336     g_free(sock->remoteAddrStrURI);
1337 }
1338 
1339 
/* Return the socket's file descriptor, read under the object lock. */
int virNetSocketGetFD(virNetSocket *sock)
{
    int result;

    virObjectLock(sock);
    result = sock->fd;
    virObjectUnlock(sock);

    return result;
}
1348 
/*
 * Duplicate the socket's file descriptor.
 *
 * @cloexec: whether the copy should have close-on-exec set.
 *
 * Returns the new descriptor (owned by the caller), or -1 with an error
 * reported and errno set.
 */
int virNetSocketDupFD(virNetSocket *sock, bool cloexec)
{
    int fd;

#ifdef F_DUPFD_CLOEXEC
    if (cloexec)
        fd = fcntl(sock->fd, F_DUPFD_CLOEXEC, 0);
    else
#endif /* F_DUPFD_CLOEXEC */
        fd = dup(sock->fd);
    if (fd < 0) {
        virReportSystemError(errno, "%s",
                             _("Unable to copy socket file handle"));
        return -1;
    }
#ifndef F_DUPFD_CLOEXEC
    /* No atomic dup-with-cloexec on this platform: set the flag afterwards.
     * virSetCloseExec() does not report errors itself, so report here. */
    if (cloexec &&
        virSetCloseExec(fd) < 0) {
        int saveerr = errno;
        virReportSystemError(saveerr, "%s",
                             _("Unable to set close-on-exec flag on copied socket"));
        closesocket(fd);
        errno = saveerr;
        return -1;
    }
#endif /* F_DUPFD_CLOEXEC */

    return fd;
}
1376 
1377 
/* True when the socket's local address is an AF_UNIX address. */
bool virNetSocketIsLocal(virNetSocket *sock)
{
    bool local;

    virObjectLock(sock);
    local = (sock->localAddr.data.sa.sa_family == AF_UNIX);
    virObjectUnlock(sock);

    return local;
}
1387 
1388 
/* FD passing is only offered on AF_UNIX sockets. */
bool virNetSocketHasPassFD(virNetSocket *sock)
{
    bool canPass;

    virObjectLock(sock);
    canPass = (sock->localAddr.data.sa.sa_family == AF_UNIX);
    virObjectUnlock(sock);

    return canPass;
}
1398 
/* Return the local address path as produced by virSocketAddrGetPath(). */
char *virNetSocketGetPath(virNetSocket *sock)
{
    char *result;

    virObjectLock(sock);
    result = virSocketAddrGetPath(&sock->localAddr);
    virObjectUnlock(sock);

    return result;
}
1407 
/* Return the local port as produced by virSocketAddrGetPort(). */
int virNetSocketGetPort(virNetSocket *sock)
{
    int result;

    virObjectLock(sock);
    result = virSocketAddrGetPort(&sock->localAddr);
    virObjectUnlock(sock);

    return result;
}
1416 
1417 
1418 #if defined(SO_PEERCRED)
/*
 * SO_PEERCRED implementation: fetch the peer's uid, gid and pid and, when a
 * pid is known, its process start time.
 *
 * *timestamp is set to -1 before the start time lookup. *pid is -1 when the
 * kernel reports pid 0. Returns 0 on success, -1 with an error reported.
 */
int virNetSocketGetUNIXIdentity(virNetSocket *sock,
                                uid_t *uid,
                                gid_t *gid,
                                pid_t *pid,
                                unsigned long long *timestamp)
{
# if defined(WITH_STRUCT_SOCKPEERCRED)
    struct sockpeercred peer;
# else
    struct ucred peer;
# endif
    socklen_t peerLen = sizeof(peer);
    int rc = -1;

    virObjectLock(sock);

    if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &peer, &peerLen) < 0) {
        virReportSystemError(errno, "%s",
                             _("Failed to get client socket identity"));
    } else {
        *timestamp = -1;
        /* A zero pid means "unknown": skip the start time lookup. */
        if (peer.pid == 0 || virProcessGetStartTime(peer.pid, timestamp) >= 0) {
            *pid = peer.pid ? peer.pid : -1;
            *uid = peer.uid;
            *gid = peer.gid;
            rc = 0;
        }
    }

    virObjectUnlock(sock);
    return rc;
}
1458 #elif defined(LOCAL_PEERCRED)
1459 
1460 /* VIR_SOL_PEERCRED - the value needed to let getsockopt() work with
1461  * LOCAL_PEERCRED
1462  */
1463 
1464 /* Mac OS X 10.8 provides SOL_LOCAL for LOCAL_PEERCRED */
1465 # ifdef SOL_LOCAL
1466 #  define VIR_SOL_PEERCRED SOL_LOCAL
1467 # else
/* On FreeBSD and older Mac OS X releases SOL_LOCAL is not defined, and
 * callers are expected to pass 0 as the level (second) argument of
 * getsockopt() when using LOCAL_PEERCRED. NB SOL_SOCKET cannot be used
 * in place of SOL_LOCAL.
 */
1473 #  define VIR_SOL_PEERCRED 0
1474 # endif
1475 
/*
 * LOCAL_PEERCRED implementation: fetch the peer's uid and gid.
 *
 * The credentials structure is validated (version and group count) before
 * use. *pid and *timestamp are set to -1, because LOCAL_PEERCRED does not
 * provide them. Where LOCAL_PEERPID exists, *pid is filled from it;
 * EOPNOTSUPP from it is tolerated. Returns 0 on success, -1 with an error
 * reported.
 */
int virNetSocketGetUNIXIdentity(virNetSocket *sock,
                                uid_t *uid,
                                gid_t *gid,
                                pid_t *pid,
                                unsigned long long *timestamp)
{
    struct xucred peer;
    socklen_t optlen = sizeof(peer);
    int rc = -1;

    virObjectLock(sock);

    /* Preset an invalid group count so an unfilled structure fails below. */
    peer.cr_ngroups = -1;
    if (getsockopt(sock->fd, VIR_SOL_PEERCRED, LOCAL_PEERCRED, &peer, &optlen) < 0) {
        virReportSystemError(errno, "%s",
                             _("Failed to get client socket identity"));
        goto done;
    } else if (peer.cr_version != XUCRED_VERSION) {
        virReportError(VIR_ERR_SYSTEM_ERROR, "%s",
                       _("Failed to get valid client socket identity"));
        goto done;
    } else if (peer.cr_ngroups <= 0 || peer.cr_ngroups > NGROUPS) {
        virReportError(VIR_ERR_SYSTEM_ERROR, "%s",
                       _("Failed to get valid client socket identity groups"));
        goto done;
    }

    *pid = -1;
    *timestamp = -1;
    *uid = peer.cr_uid;
    *gid = peer.cr_gid;

# ifdef LOCAL_PEERPID
    /* Exists on Mac OS X 10.8 for retrieving the peer's PID */
    optlen = sizeof(*pid);
    if (getsockopt(sock->fd, VIR_SOL_PEERCRED, LOCAL_PEERPID, pid, &optlen) < 0) {
        int err = errno;

        /* Nothing guarantees what getsockopt() left in *pid on failure. */
        *pid = -1;

        /* Built with LOCAL_PEERPID but running on a kernel without it gives
         * EOPNOTSUPP, which is tolerated; any other error is fatal. */
        if (err != EOPNOTSUPP) {
            virReportSystemError(err, "%s",
                                 _("Failed to get client socket PID"));
            goto done;
        }
    }
# endif

    rc = 0;

 done:
    virObjectUnlock(sock);
    return rc;
}
1543 #else
int virNetSocketGetUNIXIdentity(virNetSocket *sock G_GNUC_UNUSED,
                                uid_t *uid G_GNUC_UNUSED,
                                gid_t *gid G_GNUC_UNUSED,
                                pid_t *pid G_GNUC_UNUSED,
                                unsigned long long *timestamp G_GNUC_UNUSED)
{
    /* Neither SO_PEERCRED nor LOCAL_PEERCRED is available here. XXX many
     * other OSes expose UNIX socket credentials that could be ported (see
     * dbus for examples). */
    const int unsupported = ENOSYS;

    virReportSystemError(unsupported, "%s",
                         _("Client socket identity not available"));
    return -1;
}
1555 #endif
1556 
1557 #ifdef WITH_SELINUX
/*
 * Query the peer's SELinux context via getpeercon().
 *
 * *context receives a g_strdup'ed copy, or NULL when none is available.
 * ENOSYS and ENOPROTOOPT from getpeercon() are treated as "no context"
 * rather than errors. Returns 0 on success, -1 with an error reported.
 */
int virNetSocketGetSELinuxContext(virNetSocket *sock,
                                  char **context)
{
    char *peerCon = NULL;
    int rc = 0;

    *context = NULL;

    virObjectLock(sock);
    if (getpeercon(sock->fd, &peerCon) >= 0) {
        *context = g_strdup(peerCon);
    } else if (errno != ENOSYS && errno != ENOPROTOOPT) {
        virReportSystemError(errno, "%s",
                             _("Unable to query peer security context"));
        rc = -1;
    }

    freecon(peerCon);
    virObjectUnlock(sock);
    return rc;
}
1585 #else
/* Built without SELinux: there is never a peer context, and never an error. */
int virNetSocketGetSELinuxContext(virNetSocket *sock G_GNUC_UNUSED,
                                  char **context)
{
    const int ok = 0;

    *context = NULL;
    return ok;
}
1592 #endif
1593 
1594 
/* Toggle blocking mode on the socket FD; returns virSetBlocking()'s result. */
int virNetSocketSetBlocking(virNetSocket *sock,
                            bool blocking)
{
    int rc;

    virObjectLock(sock);
    rc = virSetBlocking(sock->fd, blocking);
    virObjectUnlock(sock);

    return rc;
}
1604 
1605 
/* Borrowed pointer to the cached local address string (freed in Dispose). */
const char *virNetSocketLocalAddrStringSASL(virNetSocket *sock)
{
    const char *addr = sock->localAddrStrSASL;

    return addr;
}
1610 
/* Borrowed pointer to the cached remote address string (freed in Dispose). */
const char *virNetSocketRemoteAddrStringSASL(virNetSocket *sock)
{
    const char *addr = sock->remoteAddrStrSASL;

    return addr;
}
1615 
/* Borrowed pointer to the cached remote address in URI form (freed in Dispose). */
const char *virNetSocketRemoteAddrStringURI(virNetSocket *sock)
{
    const char *addr = sock->remoteAddrStrURI;

    return addr;
}
1620 
virNetSocketTLSSessionWrite(const char * buf,size_t len,void * opaque)1621 static ssize_t virNetSocketTLSSessionWrite(const char *buf,
1622                                            size_t len,
1623                                            void *opaque)
1624 {
1625     virNetSocket *sock = opaque;
1626     return write(sock->fd, buf, len);
1627 }
1628 
1629 
virNetSocketTLSSessionRead(char * buf,size_t len,void * opaque)1630 static ssize_t virNetSocketTLSSessionRead(char *buf,
1631                                           size_t len,
1632                                           void *opaque)
1633 {
1634     virNetSocket *sock = opaque;
1635     return read(sock->fd, buf, len);
1636 }
1637 
1638 
/*
 * Attach @sess as the socket's TLS session, taking a new reference, and
 * route the session's raw IO through this socket's FD.
 *
 * The new reference is taken before the previous session is released, so
 * re-installing the current session cannot drop its last reference. A
 * replaced session has its IO callbacks cleared, as Dispose does, so it
 * can no longer call back into this socket.
 */
void virNetSocketSetTLSSession(virNetSocket *sock,
                               virNetTLSSession *sess)
{
    virNetTLSSession *oldSess;

    virObjectLock(sock);
    oldSess = sock->tlsSession;
    sock->tlsSession = virObjectRef(sess);
    if (oldSess && oldSess != sess)
        virNetTLSSessionSetIOCallbacks(oldSess, NULL, NULL, NULL);
    virObjectUnref(oldSess);
    virNetTLSSessionSetIOCallbacks(sess,
                                   virNetSocketTLSSessionWrite,
                                   virNetSocketTLSSessionRead,
                                   sock);
    virObjectUnlock(sock);
}
1651 
1652 #if WITH_SASL
/*
 * Attach @sess as the socket's SASL session, taking a new reference and
 * releasing the previous one.
 *
 * The new reference is taken first, so re-installing the current session
 * cannot drop its last reference.
 */
void virNetSocketSetSASLSession(virNetSocket *sock,
                                virNetSASLSession *sess)
{
    virNetSASLSession *oldSess;

    virObjectLock(sock);
    oldSess = sock->saslSession;
    sock->saslSession = virObjectRef(sess);
    virObjectUnref(oldSess);
    virObjectUnlock(sock);
}
1661 #endif
1662 
1663 
/*
 * True when data has already been pulled off the wire but not yet consumed:
 * buffered in an SSH session, or SASL-decoded bytes waiting to be read.
 * (@sock is always used, for locking, so it is not marked G_GNUC_UNUSED.)
 */
bool virNetSocketHasCachedData(virNetSocket *sock)
{
    bool hasCached = false;
    virObjectLock(sock);

#if WITH_SSH2
    if (virNetSSHSessionHasCachedData(sock->sshSession))
        hasCached = true;
#endif

#if WITH_LIBSSH
    if (virNetLibsshSessionHasCachedData(sock->libsshSession))
        hasCached = true;
#endif

#if WITH_SASL
    if (sock->saslDecoded)
        hasCached = true;
#endif
    virObjectUnlock(sock);
    return hasCached;
}
1686 
1687 #if WITH_SSH2
/* Read from the libssh2 channel attached to @sock. */
static ssize_t virNetSocketLibSSH2Read(virNetSocket *sock,
                                       char *buf,
                                       size_t len)
{
    virNetSSHSession *session = sock->sshSession;

    return virNetSSHChannelRead(session, buf, len);
}
1694 
/* Write to the libssh2 channel attached to @sock. */
static ssize_t virNetSocketLibSSH2Write(virNetSocket *sock,
                                        const char *buf,
                                        size_t len)
{
    virNetSSHSession *session = sock->sshSession;

    return virNetSSHChannelWrite(session, buf, len);
}
1701 #endif
1702 
1703 #if WITH_LIBSSH
/* Read from the libssh channel attached to @sock. */
static ssize_t virNetSocketLibsshRead(virNetSocket *sock,
                                      char *buf,
                                      size_t len)
{
    virNetLibsshSession *session = sock->libsshSession;

    return virNetLibsshChannelRead(session, buf, len);
}
1710 
/* Write to the libssh channel attached to @sock. */
static ssize_t virNetSocketLibsshWrite(virNetSocket *sock,
                                       const char *buf,
                                       size_t len)
{
    virNetLibsshSession *session = sock->libsshSession;

    return virNetLibsshChannelWrite(session, buf, len);
}
1717 #endif
1718 
/*
 * True when SASL-encoded output is still waiting to be written.
 * (@sock is always used, for locking, so it is not marked G_GNUC_UNUSED.)
 */
bool virNetSocketHasPendingData(virNetSocket *sock)
{
    bool hasPending = false;
    virObjectLock(sock);
#if WITH_SASL
    if (sock->saslEncoded)
        hasPending = true;
#endif
    virObjectUnlock(sock);
    return hasPending;
}
1730 
1731 
/*
 * Read up to @len bytes from the socket's transport into @buf.
 *
 * An attached libssh2 / libssh channel is used directly and its result is
 * returned as-is. Otherwise data is read through the TLS session once its
 * handshake is complete, or from the raw FD. EINTR is retried and EAGAIN
 * yields 0.
 *
 * On failure or EOF, if an errfd pipe exists, up to 1024 bytes are read
 * from it (trailing whitespace stripped) and included in the error message.
 *
 * Returns the byte count, 0 if the read would block, -2 on EOF when
 * quietEOF is set (nothing reported), or -1 with an error reported.
 */
static ssize_t virNetSocketReadWire(virNetSocket *sock, char *buf, size_t len)
{
    g_autofree char *errout = NULL;
    ssize_t ret;

#if WITH_SSH2
    if (sock->sshSession)
        return virNetSocketLibSSH2Read(sock, buf, len);
#endif

#if WITH_LIBSSH
    if (sock->libsshSession)
        return virNetSocketLibsshRead(sock, buf, len);
#endif

 reread:
    if (sock->tlsSession &&
        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
    } else {
        ret = read(sock->fd, buf, len);
    }

    if ((ret < 0) && (errno == EINTR))
        goto reread;
    if ((ret < 0) && (errno == EAGAIN))
        return 0;

    if (ret <= 0 &&
        sock->errfd != -1 &&
        virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
        errout != NULL) {
        size_t elen = strlen(errout);
        /* remove trailing whitespace */
        while (elen && g_ascii_isspace(errout[elen - 1]))
            errout[--elen] = '\0';
    }

    if (ret < 0) {
        if (errout)
            virReportSystemError(errno,
                                 _("Cannot recv data: %s"), errout);
        else
            virReportSystemError(errno, "%s",
                                 _("Cannot recv data"));
        ret = -1;
    } else if (ret == 0) {
        if (sock->quietEOF) {
            /* Log this socket object, not the libc socket() function. */
            VIR_DEBUG("socket='%p' EOF while reading: errout='%s'",
                      sock, NULLSTR(errout));

            ret = -2;
        } else {
            if (errout)
                virReportSystemError(EIO,
                                     _("End of file while reading data: %s"),
                                     errout);
            else
                virReportSystemError(EIO, "%s",
                                     _("End of file while reading data"));

            ret = -1;
        }
    }

    return ret;
}
1800 
/*
 * virNetSocketWriteWire:
 * @sock: socket to write to; callers in this file hold its object lock
 * @buf: data to send
 * @len: number of bytes in @buf
 *
 * Writes raw bytes to the transport underneath @sock. An attached SSH
 * session takes precedence and its result is returned unchanged. Otherwise
 * the TLS session is used once its handshake has completed, falling back
 * to a plain write() on the file descriptor. Interrupted writes (EINTR)
 * are retried.
 *
 * Returns the number of bytes written, 0 if the write would block
 * (EAGAIN), or -1 on error or a zero-length write (error reported).
 */
static ssize_t virNetSocketWriteWire(virNetSocket *sock, const char *buf, size_t len)
{
    ssize_t written;

#if WITH_SSH2
    if (sock->sshSession)
        return virNetSocketLibSSH2Write(sock, buf, len);
#endif

#if WITH_LIBSSH
    if (sock->libsshSession)
        return virNetSocketLibsshWrite(sock, buf, len);
#endif

    for (;;) {
        if (sock->tlsSession &&
            virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
            VIR_NET_TLS_HANDSHAKE_COMPLETE)
            written = virNetTLSSessionWrite(sock->tlsSession, buf, len);
        else
            written = write(sock->fd, buf, len);

        /* Only an interrupted write is worth retrying immediately */
        if (written >= 0 || errno != EINTR)
            break;
    }

    if (written < 0 && errno == EAGAIN)
        return 0;

    if (written < 0) {
        virReportSystemError(errno, "%s",
                             _("Cannot write data"));
        return -1;
    }

    if (written == 0) {
        virReportSystemError(EIO, "%s",
                             _("End of file while writing data"));
        return -1;
    }

    return written;
}
1842 
1843 
1844 #if WITH_SASL
virNetSocketReadSASL(virNetSocket * sock,char * buf,size_t len)1845 static ssize_t virNetSocketReadSASL(virNetSocket *sock, char *buf, size_t len)
1846 {
1847     ssize_t got;
1848 
1849     /* Need to read some more data off the wire */
1850     if (sock->saslDecoded == NULL) {
1851         ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
1852         g_autofree char *encoded = g_new0(char, encodedLen);
1853 
1854         encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
1855 
1856         if (encodedLen <= 0)
1857             return encodedLen;
1858 
1859         if (virNetSASLSessionDecode(sock->saslSession,
1860                                     encoded, encodedLen,
1861                                     &sock->saslDecoded, &sock->saslDecodedLength) < 0) {
1862             return -1;
1863         }
1864 
1865         sock->saslDecodedOffset = 0;
1866     }
1867 
1868     /* Some buffered decoded data to return now */
1869     got = sock->saslDecodedLength - sock->saslDecodedOffset;
1870 
1871     if (len > got)
1872         len = got;
1873 
1874     memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
1875     sock->saslDecodedOffset += len;
1876 
1877     if (sock->saslDecodedOffset == sock->saslDecodedLength) {
1878         sock->saslDecoded = NULL;
1879         sock->saslDecodedOffset = sock->saslDecodedLength = 0;
1880     }
1881 
1882     return len;
1883 }
1884 
1885 
virNetSocketWriteSASL(virNetSocket * sock,const char * buf,size_t len)1886 static ssize_t virNetSocketWriteSASL(virNetSocket *sock, const char *buf, size_t len)
1887 {
1888     int ret;
1889     size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
1890 
1891     /* SASL doesn't necessarily let us send the whole
1892        buffer at once */
1893     if (tosend > len)
1894         tosend = len;
1895 
1896     /* Not got any pending encoded data, so we need to encode raw stuff */
1897     if (sock->saslEncoded == NULL) {
1898         if (virNetSASLSessionEncode(sock->saslSession,
1899                                     buf, tosend,
1900                                     &sock->saslEncoded,
1901                                     &sock->saslEncodedLength) < 0)
1902             return -1;
1903 
1904         sock->saslEncodedRawLength = tosend;
1905         sock->saslEncodedOffset = 0;
1906     }
1907 
1908     /* Send some of the encoded stuff out on the wire */
1909     ret = virNetSocketWriteWire(sock,
1910                                 sock->saslEncoded + sock->saslEncodedOffset,
1911                                 sock->saslEncodedLength - sock->saslEncodedOffset);
1912 
1913     if (ret <= 0)
1914         return ret; /* -1 error, 0 == egain */
1915 
1916     /* Note how much we sent */
1917     sock->saslEncodedOffset += ret;
1918 
1919     /* Sent all encoded, so update raw buffer to indicate completion */
1920     if (sock->saslEncodedOffset == sock->saslEncodedLength) {
1921         ssize_t done = sock->saslEncodedRawLength;
1922         sock->saslEncoded = NULL;
1923         sock->saslEncodedOffset = sock->saslEncodedLength = sock->saslEncodedRawLength = 0;
1924 
1925         /* Mark as complete, so caller detects completion.
1926          *
1927          * Note that 'done' is possibly less than our current
1928          * 'tosend' value, since if virNetSocketWriteWire
1929          * only partially sent the data, we might have been
1930          * called a 2nd time to write remaining cached
1931          * encoded data. This means that the caller might
1932          * also have further raw data pending that's included
1933          * in 'tosend' */
1934         return done;
1935     } else {
1936         /* Still have stuff pending in saslEncoded buffer.
1937          * Pretend to caller that we didn't send any yet.
1938          * The caller will then retry with same buffer
1939          * shortly, which lets us finish saslEncoded.
1940          */
1941         return 0;
1942     }
1943 }
1944 #endif
1945 
/*
 * virNetSocketRead:
 * @sock: socket to read from
 * @buf: buffer to fill
 * @len: capacity of @buf
 *
 * Locks @sock and reads from it. When the socket has a SASL session, the
 * data is decoded through it; otherwise the raw transport data is read.
 * The return value is that of the underlying SASL or wire reader.
 */
ssize_t virNetSocketRead(virNetSocket *sock, char *buf, size_t len)
{
    ssize_t nread;

    virObjectLock(sock);
#if WITH_SASL
    nread = sock->saslSession ?
        virNetSocketReadSASL(sock, buf, len) :
        virNetSocketReadWire(sock, buf, len);
#else
    nread = virNetSocketReadWire(sock, buf, len);
#endif
    virObjectUnlock(sock);

    return nread;
}
1959 
/*
 * virNetSocketWrite:
 * @sock: socket to write to
 * @buf: data to send
 * @len: number of bytes in @buf
 *
 * Locks @sock and writes to it. When the socket has a SASL session, the
 * data is encoded through it first; otherwise it goes straight to the
 * transport. The return value is that of the underlying SASL or wire
 * writer.
 */
ssize_t virNetSocketWrite(virNetSocket *sock, const char *buf, size_t len)
{
    ssize_t nwritten;

    virObjectLock(sock);
#if WITH_SASL
    nwritten = sock->saslSession ?
        virNetSocketWriteSASL(sock, buf, len) :
        virNetSocketWriteWire(sock, buf, len);
#else
    nwritten = virNetSocketWriteWire(sock, buf, len);
#endif
    virObjectUnlock(sock);

    return nwritten;
}
1974 
1975 
1976 /*
1977  * Returns 1 if an FD was sent, 0 if it would block, -1 on error
1978  */
virNetSocketSendFD(virNetSocket * sock,int fd)1979 int virNetSocketSendFD(virNetSocket *sock, int fd)
1980 {
1981     int ret = -1;
1982     if (!virNetSocketHasPassFD(sock)) {
1983         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1984                        _("Sending file descriptors is not supported on this socket"));
1985         return -1;
1986     }
1987     virObjectLock(sock);
1988     PROBE(RPC_SOCKET_SEND_FD,
1989           "sock=%p fd=%d", sock, fd);
1990     if (virSocketSendFD(sock->fd, fd) < 0) {
1991         if (errno == EAGAIN)
1992             ret = 0;
1993         else
1994             virReportSystemError(errno,
1995                                  _("Failed to send file descriptor %d"),
1996                                  fd);
1997         goto cleanup;
1998     }
1999     ret = 1;
2000 
2001  cleanup:
2002     virObjectUnlock(sock);
2003     return ret;
2004 }
2005 
2006 
2007 /*
2008  * Returns 1 if an FD was read, 0 if it would block, -1 on error
2009  */
virNetSocketRecvFD(virNetSocket * sock,int * fd)2010 int virNetSocketRecvFD(virNetSocket *sock, int *fd)
2011 {
2012     int ret = -1;
2013 
2014     *fd = -1;
2015 
2016     if (!virNetSocketHasPassFD(sock)) {
2017         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
2018                        _("Receiving file descriptors is not supported on this socket"));
2019         return -1;
2020     }
2021     virObjectLock(sock);
2022 
2023     if ((*fd = virSocketRecvFD(sock->fd, O_CLOEXEC)) < 0) {
2024         if (errno == EAGAIN)
2025             ret = 0;
2026         else
2027             virReportSystemError(errno, "%s",
2028                                  _("Failed to recv file descriptor"));
2029         goto cleanup;
2030     }
2031     PROBE(RPC_SOCKET_RECV_FD,
2032           "sock=%p fd=%d", sock, *fd);
2033     ret = 1;
2034 
2035  cleanup:
2036     virObjectUnlock(sock);
2037     return ret;
2038 }
2039 
2040 
/*
 * virNetSocketListen:
 * @sock: socket to start listening on
 * @backlog: pending-connection queue length; values <= 0 select 30
 *
 * Returns 0 on success, -1 on error (reported).
 */
int virNetSocketListen(virNetSocket *sock, int backlog)
{
    int queueLen = backlog > 0 ? backlog : 30;
    int rc = 0;

    virObjectLock(sock);
    if (listen(sock->fd, queueLen) < 0) {
        virReportSystemError(errno, "%s", _("Unable to listen on socket"));
        rc = -1;
    }
    virObjectUnlock(sock);

    return rc;
}
2052 
2053 
2054 /**
2055  * virNetSocketAccept:
2056  * @sock: socket to accept connection on
2057  * @clientsock: returned client socket
2058  *
2059  * For given socket @sock accept incoming connection and create
2060  * @clientsock representation of the new accepted connection.
2061  *
2062  * Returns: 0 on success,
2063  *         -2 if accepting failed due to EMFILE error,
2064  *         -1 otherwise.
2065  */
virNetSocketAccept(virNetSocket * sock,virNetSocket ** clientsock)2066 int virNetSocketAccept(virNetSocket *sock, virNetSocket **clientsock)
2067 {
2068     int fd = -1;
2069     virSocketAddr localAddr;
2070     virSocketAddr remoteAddr;
2071     int ret = -1;
2072 
2073     virObjectLock(sock);
2074 
2075     *clientsock = NULL;
2076 
2077     memset(&localAddr, 0, sizeof(localAddr));
2078     memset(&remoteAddr, 0, sizeof(remoteAddr));
2079 
2080     remoteAddr.len = sizeof(remoteAddr.data.stor);
2081     if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) {
2082         if (errno == ECONNABORTED ||
2083             errno == EAGAIN) {
2084             ret = 0;
2085             goto cleanup;
2086         } else if (errno == EMFILE) {
2087             ret = -2;
2088         }
2089 
2090         virReportSystemError(errno, "%s",
2091                              _("Unable to accept client"));
2092         goto cleanup;
2093     }
2094 
2095     localAddr.len = sizeof(localAddr.data);
2096     if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
2097         virReportSystemError(errno, "%s", _("Unable to get local socket name"));
2098         goto cleanup;
2099     }
2100 
2101     if (!(*clientsock = virNetSocketNew(&localAddr,
2102                                         &remoteAddr,
2103                                         true,
2104                                         fd, -1, 0,
2105                                         false)))
2106         goto cleanup;
2107 
2108     fd = -1;
2109     ret = 0;
2110 
2111  cleanup:
2112     if (fd != -1)
2113         closesocket(fd);
2114     virObjectUnlock(sock);
2115     return ret;
2116 }
2117 
2118 
/*
 * virNetSocketEventHandle:
 *
 * Event-loop trampoline registered by virNetSocketAddIOCallback. It reads
 * the user callback and its data under the socket lock, then invokes the
 * callback (if any) with the lock released.
 */
static void virNetSocketEventHandle(int watch G_GNUC_UNUSED,
                                    int fd G_GNUC_UNUSED,
                                    int events,
                                    void *opaque)
{
    virNetSocket *sock = opaque;
    virNetSocketIOFunc cb = NULL;
    void *cbOpaque = NULL;

    /* Snapshot the callback under the lock, but call it unlocked */
    virObjectLock(sock);
    cb = sock->func;
    cbOpaque = sock->opaque;
    virObjectUnlock(sock);

    if (!cb)
        return;

    cb(sock, events, cbOpaque);
}
2136 
2137 
virNetSocketEventFree(void * opaque)2138 static void virNetSocketEventFree(void *opaque)
2139 {
2140     g_autoptr(virNetSocket) sock = opaque;
2141     virFreeCallback ff;
2142     void *eopaque;
2143 
2144     virObjectLock(sock);
2145     ff = sock->ff;
2146     eopaque = g_steal_pointer(&sock->opaque);
2147     sock->func = NULL;
2148     sock->ff = NULL;
2149     virObjectUnlock(sock);
2150 
2151     if (ff)
2152         ff(eopaque);
2153 }
2154 
/*
 * virNetSocketAddIOCallback:
 * @sock: socket to watch
 * @events: events to watch for
 * @func: callback invoked when events fire
 * @opaque: data passed to @func
 * @ff: optional free function for @opaque
 *
 * Registers @sock with the event loop. A reference on @sock is handed to
 * the event-loop handle and released by virNetSocketEventFree. It is
 * dropped here instead if registration does not happen.
 *
 * Returns 0 on success, -1 if a watch already exists or registration
 * failed.
 */
int virNetSocketAddIOCallback(virNetSocket *sock,
                              int events,
                              virNetSocketIOFunc func,
                              void *opaque,
                              virFreeCallback ff)
{
    bool registered = false;

    virObjectRef(sock);
    virObjectLock(sock);

    if (sock->watch >= 0) {
        VIR_DEBUG("Watch already registered on socket %p", sock);
    } else if ((sock->watch = virEventAddHandle(sock->fd,
                                                events,
                                                virNetSocketEventHandle,
                                                sock,
                                                virNetSocketEventFree)) < 0) {
        VIR_DEBUG("Failed to register watch on socket %p", sock);
    } else {
        sock->func = func;
        sock->opaque = opaque;
        sock->ff = ff;
        registered = true;
    }

    virObjectUnlock(sock);

    if (!registered) {
        virObjectUnref(sock);
        return -1;
    }

    return 0;
}
2190 
/*
 * virNetSocketUpdateIOCallback:
 * @sock: socket whose watch to update
 * @events: new set of events to watch for
 *
 * Changes the event mask of the registered watch. If no watch is
 * registered, this only logs a debug message.
 */
void virNetSocketUpdateIOCallback(virNetSocket *sock,
                                  int events)
{
    virObjectLock(sock);

    if (sock->watch >= 0)
        virEventUpdateHandle(sock->watch, events);
    else
        VIR_DEBUG("Watch not registered on socket %p", sock);

    virObjectUnlock(sock);
}
2205 
/*
 * virNetSocketRemoveIOCallback:
 * @sock: socket whose watch to remove
 *
 * Removes the registered event-loop watch, if any, and marks the socket
 * as having no watch. If none is registered, this only logs a debug
 * message.
 */
void virNetSocketRemoveIOCallback(virNetSocket *sock)
{
    virObjectLock(sock);

    if (sock->watch >= 0) {
        virEventRemoveHandle(sock->watch);
        /* Don't unref @sock, it's done via eventloop callback. */
        sock->watch = -1;
    } else {
        VIR_DEBUG("Watch not registered on socket %p", sock);
    }

    virObjectUnlock(sock);
}
2222 
/*
 * virNetSocketClose:
 * @sock: socket to close (NULL is accepted and ignored)
 *
 * Closes the underlying descriptor if still open. On non-Windows builds,
 * when the socket is flagged to unlink its UNIX path, the bound path is
 * removed. The stored path is cleared only after a successful unlink.
 */
void virNetSocketClose(virNetSocket *sock)
{
    if (sock == NULL)
        return;

    virObjectLock(sock);

    if (sock->fd != -1) {
        closesocket(sock->fd);
        sock->fd = -1;
    }

#ifndef WIN32
    /* If a server socket, then unlink UNIX path */
    if (sock->unlinkUNIX &&
        sock->localAddr.data.sa.sa_family == AF_UNIX) {
        char *path = sock->localAddr.data.un.sun_path;

        if (path[0] != '\0' && unlink(path) == 0)
            path[0] = '\0';
    }
#endif

    virObjectUnlock(sock);
}
2247 
2248 
2249 /**
2250  * virNetSocketSetQuietEOF:
2251  * @sock: socket object pointer
2252  *
2253  * Disables reporting I/O errors as a virError when @socket is closed while
2254  * reading data.
2255  */
2256 void
virNetSocketSetQuietEOF(virNetSocket * sock)2257 virNetSocketSetQuietEOF(virNetSocket *sock)
2258 {
2259     sock->quietEOF = true;
2260 }
2261