1 /*
2  * virnetserverclient.c: generic network RPC server client
3  *
4  * Copyright (C) 2006-2014 Red Hat, Inc.
5  * Copyright (C) 2006 Daniel P. Berrange
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library.  If not, see
19  * <http://www.gnu.org/licenses/>.
20  */
21 
22 #include <config.h>
23 
24 #include "internal.h"
25 #if WITH_SASL
26 # include <sasl/sasl.h>
27 #endif
28 
29 #include "virnetserver.h"
30 #include "virnetserverclient.h"
31 
32 #include "virlog.h"
33 #include "virerror.h"
34 #include "viralloc.h"
35 #include "virthread.h"
36 #include "virkeepalive.h"
37 #include "virprobe.h"
38 #include "virstring.h"
39 #include "virutil.h"
40 
41 #define VIR_FROM_THIS VIR_FROM_RPC
42 
43 VIR_LOG_INIT("rpc.netserverclient");
44 
/* Allow for filtering of incoming messages to a custom
 * dispatch processing queue, instead of the workers.
 * This allows for certain types of messages to be handled
 * strictly "in order".
 */

/* One entry of a client's singly linked filter list, appended by
 * virNetServerClientAddFilter() and unlinked/freed by
 * virNetServerClientRemoveFilter(). */
typedef struct _virNetServerClientFilter virNetServerClientFilter;
struct _virNetServerClientFilter {
    int id;                            /* handle returned by AddFilter, matched by RemoveFilter */
    virNetServerClientFilterFunc func; /* callback passed to AddFilter */
    void *opaque;                      /* caller data passed to AddFilter alongside @func */

    virNetServerClientFilter *next;    /* next filter, NULL at the tail of the list */
};
59 
60 
/* Per-connection state of one RPC client.  The object is a
 * virObjectLockable; the accessors in this file take the object lock
 * around reads/writes of these fields unless their name ends in "Locked". */
struct _virNetServerClient
{
    virObjectLockable parent;

    unsigned long long id;        /* client ID, restored from JSON after exec-restart when present */
    bool wantClose;               /* when set, CalculateHandleMode watches no events */
    bool delayedClose;            /* when set (non-TLS path), readability is no longer watched */
    virNetSocket *sock;
    int auth;                     /* VIR_NET_SERVER_SERVICE_AUTH_* method */
    bool auth_pending;            /* true while the auth method still requires authentication */
    bool readonly;
    virNetTLSContext *tlsCtxt;
    virNetTLSSession *tls;
#if WITH_SASL
    virNetSASLSession *sasl;
#endif
    int sockTimer; /* Timer to be fired upon cached data,
                    * so we jump out from poll() immediately */


    virIdentity *identity;        /* created lazily by virNetServerClientGetIdentity */

    /* Connection timestamp, i.e. when a client connected to the daemon (UTC).
     * For old clients restored by post-exec-restart, which did not have this
     * attribute, value of 0 (epoch time) is used to indicate we have no
     * information about their connection time.
     */
    long long conn_time;

    /* Count of messages in the 'tx' queue,
     * and the server worker pool queue
     * ie RPC calls in progress. Does not count
     * async events which are not used for
     * throttling calculations */
    size_t nrequests;
    size_t nrequests_max;
    /* Zero or one messages being received. Zero if
     * nrequests has reached the limit and the client is being throttled.
     * NOTE(review): the limit is presumably nrequests_max; the throttling
     * code is not visible here. */
    virNetMessage *rx;
    /* Zero or many messages waiting for transmit
     * back to client, including async events */
    virNetMessage *tx;

    /* Filters to capture incoming messages before they reach the normal
     * dispatcher.  NOTE(review): older versions of this comment referred to
     * a 'dx' queue, which no longer exists in this struct. */
    virNetServerClientFilter *filters;
    int nextFilterID;             /* ID handed to the next added filter */

    /* Set at most once (see virNetServerClientSetDispatcher), which is what
     * makes calling it without the lock held safe */
    virNetServerClientDispatchFunc dispatchFunc;
    void *dispatchOpaque;

    void *privateData;
    virFreeCallback privateDataFreeFunc;
    virNetServerClientPrivPreExecRestart privateDataPreExecRestart;
    virNetServerClientCloseFunc privateDataCloseFunc;

    virKeepAlive *keepalive;
};
119 
120 
121 static virClass *virNetServerClientClass;
122 static void virNetServerClientDispose(void *obj);
123 
virNetServerClientOnceInit(void)124 static int virNetServerClientOnceInit(void)
125 {
126     if (!VIR_CLASS_NEW(virNetServerClient, virClassForObjectLockable()))
127         return -1;
128 
129     return 0;
130 }
131 
132 VIR_ONCE_GLOBAL_INIT(virNetServerClient);
133 
134 
/* Forward declarations for static helpers that are used before their
 * definitions further down in this file. */
static void virNetServerClientDispatchEvent(virNetSocket *sock, int events, void *opaque);
static void virNetServerClientUpdateEvent(virNetServerClient *client);
static virNetMessage *virNetServerClientDispatchRead(virNetServerClient *client);
static int virNetServerClientSendMessageLocked(virNetServerClient *client,
                                               virNetMessage *msg);
140 
141 /*
142  * @client: a locked client object
143  */
144 static int
virNetServerClientCalculateHandleMode(virNetServerClient * client)145 virNetServerClientCalculateHandleMode(virNetServerClient *client)
146 {
147     int mode = 0;
148 
149 
150     VIR_DEBUG("tls=%p hs=%d, rx=%p tx=%p",
151               client->tls,
152               client->tls ? virNetTLSSessionGetHandshakeStatus(client->tls) : -1,
153               client->rx,
154               client->tx);
155     if (!client->sock || client->wantClose)
156         return 0;
157 
158     if (client->tls) {
159         switch (virNetTLSSessionGetHandshakeStatus(client->tls)) {
160         case VIR_NET_TLS_HANDSHAKE_RECVING:
161             mode |= VIR_EVENT_HANDLE_READABLE;
162             break;
163         case VIR_NET_TLS_HANDSHAKE_SENDING:
164             mode |= VIR_EVENT_HANDLE_WRITABLE;
165             break;
166         default:
167         case VIR_NET_TLS_HANDSHAKE_COMPLETE:
168             if (client->rx)
169                 mode |= VIR_EVENT_HANDLE_READABLE;
170             if (client->tx)
171                 mode |= VIR_EVENT_HANDLE_WRITABLE;
172         }
173     } else {
174         /* If there is a message on the rx queue, and
175          * we're not in middle of a delayedClose, then
176          * we're wanting more input */
177         if (client->rx && !client->delayedClose)
178             mode |= VIR_EVENT_HANDLE_READABLE;
179 
180         /* If there are one or more messages to send back to client,
181            then monitor for writability on socket */
182         if (client->tx)
183             mode |= VIR_EVENT_HANDLE_WRITABLE;
184     }
185     VIR_DEBUG("mode=0%o", mode);
186     return mode;
187 }
188 
189 /*
190  * @server: a locked or unlocked server object
191  * @client: a locked client object
192  */
virNetServerClientRegisterEvent(virNetServerClient * client)193 static int virNetServerClientRegisterEvent(virNetServerClient *client)
194 {
195     int mode = virNetServerClientCalculateHandleMode(client);
196 
197     if (!client->sock)
198         return -1;
199 
200     virObjectRef(client);
201     VIR_DEBUG("Registering client event callback %d", mode);
202     if (virNetSocketAddIOCallback(client->sock,
203                                   mode,
204                                   virNetServerClientDispatchEvent,
205                                   client,
206                                   virObjectFreeCallback) < 0) {
207         virObjectUnref(client);
208         return -1;
209     }
210 
211     return 0;
212 }
213 
214 /*
215  * @client: a locked client object
216  */
virNetServerClientUpdateEvent(virNetServerClient * client)217 static void virNetServerClientUpdateEvent(virNetServerClient *client)
218 {
219     int mode;
220 
221     if (!client->sock)
222         return;
223 
224     mode = virNetServerClientCalculateHandleMode(client);
225 
226     virNetSocketUpdateIOCallback(client->sock, mode);
227 
228     if (client->rx && virNetSocketHasCachedData(client->sock))
229         virEventUpdateTimeout(client->sockTimer, 0);
230 }
231 
232 
/*
 * Append a message filter to the end of the client's filter list.
 *
 * @client: the client object (locked internally)
 * @func: filter callback
 * @opaque: data stored with @func
 *
 * Returns the filter ID to pass to virNetServerClientRemoveFilter().
 */
int virNetServerClientAddFilter(virNetServerClient *client,
                                virNetServerClientFilterFunc func,
                                void *opaque)
{
    virNetServerClientFilter *entry = g_new0(virNetServerClientFilter, 1);
    virNetServerClientFilter *tail;
    int filterID;

    /* @entry is private until linked in, so it can be filled unlocked */
    entry->func = func;
    entry->opaque = opaque;

    virObjectLock(client);

    entry->id = client->nextFilterID++;
    filterID = entry->id;

    if (client->filters == NULL) {
        client->filters = entry;
    } else {
        tail = client->filters;
        while (tail->next)
            tail = tail->next;
        tail->next = entry;
    }

    virObjectUnlock(client);

    return filterID;
}
260 
/*
 * Unlink and free the filter identified by @filterID, if present.
 *
 * @client: the client object (locked internally)
 * @filterID: ID previously returned by virNetServerClientAddFilter()
 *
 * Unknown IDs are silently ignored.
 */
void virNetServerClientRemoveFilter(virNetServerClient *client,
                                    int filterID)
{
    virNetServerClientFilter **link;

    virObjectLock(client);

    /* Walk the "next" pointers themselves so head and interior
     * removals are handled the same way */
    for (link = &client->filters; *link; link = &(*link)->next) {
        if ((*link)->id == filterID) {
            virNetServerClientFilter *victim = *link;

            *link = victim->next;
            VIR_FREE(victim);
            break;
        }
    }

    virObjectUnlock(client);
}
287 
288 
289 /* Check the client's access. */
290 static int
virNetServerClientCheckAccess(virNetServerClient * client)291 virNetServerClientCheckAccess(virNetServerClient *client)
292 {
293     virNetMessage *confirm;
294 
295     /* Verify client certificate. */
296     if (virNetTLSContextCheckCertificate(client->tlsCtxt, client->tls) < 0)
297         return -1;
298 
299     if (client->tx) {
300         VIR_DEBUG("client had unexpected data pending tx after access check");
301         return -1;
302     }
303 
304     if (!(confirm = virNetMessageNew(false)))
305         return -1;
306 
307     /* Checks have succeeded.  Write a '\1' byte back to the client to
308      * indicate this (otherwise the socket is abruptly closed).
309      * (NB. The '\1' byte is sent in an encrypted record).
310      */
311     confirm->bufferLength = 1;
312     confirm->buffer = g_new0(char, confirm->bufferLength);
313     confirm->bufferOffset = 0;
314     confirm->buffer[0] = '\1';
315 
316     client->tx = confirm;
317 
318     return 0;
319 }
320 
321 
/*
 * Hand a fully received message to the client's dispatcher.
 *
 * Without a dispatcher the message is freed and the client is marked
 * for closing.  The dispatcher is invoked with the client unlocked.
 */
static void virNetServerClientDispatchMessage(virNetServerClient *client,
                                              virNetMessage *msg)
{
    virObjectLock(client);

    if (client->dispatchFunc) {
        virObjectUnlock(client);
        /* Reading 'client' unlocked is safe: virNetServerClientSetDispatcher
         * only sets 'dispatchFunc' once, so a non-NULL value never changes */
        client->dispatchFunc(client, msg, client->dispatchOpaque);
        return;
    }

    virNetMessageFree(msg);
    client->wantClose = true;
    virObjectUnlock(client);
}
339 
340 
virNetServerClientSockTimerFunc(int timer,void * opaque)341 static void virNetServerClientSockTimerFunc(int timer,
342                                             void *opaque)
343 {
344     virNetServerClient *client = opaque;
345     virNetMessage *msg = NULL;
346     virObjectLock(client);
347     virEventUpdateTimeout(timer, -1);
348     /* Although client->rx != NULL when this timer is enabled, it might have
349      * changed since the client was unlocked in the meantime. */
350     if (client->rx)
351         msg = virNetServerClientDispatchRead(client);
352     virObjectUnlock(client);
353 
354     if (msg)
355         virNetServerClientDispatchMessage(client, msg);
356 }
357 
358 
359 /**
360  * virNetServerClientAuthMethodImpliesAuthenticated:
361  * @auth: authentication method to check
362  *
363  * Check if the passed authentication method implies that a client is
364  * automatically authenticated.
365  *
366  * Returns true if @auth implies that a client is automatically
367  * authenticated, otherwise false.
368  */
369 static bool
virNetServerClientAuthMethodImpliesAuthenticated(int auth)370 virNetServerClientAuthMethodImpliesAuthenticated(int auth)
371 {
372     return auth == VIR_NET_SERVER_SERVICE_AUTH_NONE;
373 }
374 
375 
376 static virNetServerClient *
virNetServerClientNewInternal(unsigned long long id,virNetSocket * sock,int auth,bool auth_pending,virNetTLSContext * tls,bool readonly,size_t nrequests_max,long long timestamp)377 virNetServerClientNewInternal(unsigned long long id,
378                               virNetSocket *sock,
379                               int auth,
380                               bool auth_pending,
381                               virNetTLSContext *tls,
382                               bool readonly,
383                               size_t nrequests_max,
384                               long long timestamp)
385 {
386     virNetServerClient *client;
387 
388     if (virNetServerClientInitialize() < 0)
389         return NULL;
390 
391     if (!(client = virObjectLockableNew(virNetServerClientClass)))
392         return NULL;
393 
394     client->id = id;
395     client->sock = virObjectRef(sock);
396     client->auth = auth;
397     client->auth_pending = auth_pending;
398     client->readonly = readonly;
399     client->tlsCtxt = virObjectRef(tls);
400     client->nrequests_max = nrequests_max;
401     client->conn_time = timestamp;
402 
403     client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
404                                            client, NULL);
405     if (client->sockTimer < 0)
406         goto error;
407 
408     /* Prepare one for packet receive */
409     if (!(client->rx = virNetMessageNew(true)))
410         goto error;
411     client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
412     client->rx->buffer = g_new0(char, client->rx->bufferLength);
413     client->nrequests = 1;
414 
415     PROBE(RPC_SERVER_CLIENT_NEW,
416           "client=%p sock=%p",
417           client, client->sock);
418 
419     return client;
420 
421  error:
422     virObjectUnref(client);
423     return NULL;
424 }
425 
426 
virNetServerClientNew(unsigned long long id,virNetSocket * sock,int auth,bool readonly,size_t nrequests_max,virNetTLSContext * tls,virNetServerClientPrivNew privNew,virNetServerClientPrivPreExecRestart privPreExecRestart,virFreeCallback privFree,void * privOpaque)427 virNetServerClient *virNetServerClientNew(unsigned long long id,
428                                             virNetSocket *sock,
429                                             int auth,
430                                             bool readonly,
431                                             size_t nrequests_max,
432                                             virNetTLSContext *tls,
433                                             virNetServerClientPrivNew privNew,
434                                             virNetServerClientPrivPreExecRestart privPreExecRestart,
435                                             virFreeCallback privFree,
436                                             void *privOpaque)
437 {
438     virNetServerClient *client;
439     time_t now;
440     bool auth_pending = !virNetServerClientAuthMethodImpliesAuthenticated(auth);
441 
442     VIR_DEBUG("sock=%p auth=%d tls=%p", sock, auth, tls);
443 
444     if ((now = time(NULL)) == (time_t)-1) {
445         virReportSystemError(errno, "%s", _("failed to get current time"));
446         return NULL;
447     }
448 
449     if (!(client = virNetServerClientNewInternal(id, sock, auth, auth_pending,
450                                                  tls, readonly, nrequests_max,
451                                                  now)))
452         return NULL;
453 
454     if (!(client->privateData = privNew(client, privOpaque))) {
455         virObjectUnref(client);
456         return NULL;
457     }
458     client->privateDataFreeFunc = privFree;
459     client->privateDataPreExecRestart = privPreExecRestart;
460 
461     return client;
462 }
463 
464 
virNetServerClientNewPostExecRestart(virNetServer * srv,virJSONValue * object,virNetServerClientPrivNewPostExecRestart privNew,virNetServerClientPrivPreExecRestart privPreExecRestart,virFreeCallback privFree,void * privOpaque)465 virNetServerClient *virNetServerClientNewPostExecRestart(virNetServer *srv,
466                                                            virJSONValue *object,
467                                                            virNetServerClientPrivNewPostExecRestart privNew,
468                                                            virNetServerClientPrivPreExecRestart privPreExecRestart,
469                                                            virFreeCallback privFree,
470                                                            void *privOpaque)
471 {
472     virJSONValue *child;
473     virNetServerClient *client = NULL;
474     virNetSocket *sock;
475     int auth;
476     bool readonly, auth_pending;
477     unsigned int nrequests_max;
478     unsigned long long id;
479     long long timestamp;
480 
481     if (virJSONValueObjectGetNumberInt(object, "auth", &auth) < 0) {
482         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
483                        _("Missing auth field in JSON state document"));
484         return NULL;
485     }
486 
487     if (!virJSONValueObjectHasKey(object, "auth_pending")) {
488         auth_pending = !virNetServerClientAuthMethodImpliesAuthenticated(auth);
489     } else {
490         if (virJSONValueObjectGetBoolean(object, "auth_pending", &auth_pending) < 0) {
491             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
492                            _("Malformed auth_pending field in JSON state document"));
493             return NULL;
494         }
495 
496         /* If the used authentication method implies that the new
497          * client is automatically authenticated, the authentication
498          * cannot be pending */
499         if (auth_pending && virNetServerClientAuthMethodImpliesAuthenticated(auth)) {
500             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
501                            _("Invalid auth_pending and auth combination in JSON state document"));
502             return NULL;
503         }
504     }
505 
506     if (virJSONValueObjectGetBoolean(object, "readonly", &readonly) < 0) {
507         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
508                        _("Missing readonly field in JSON state document"));
509         return NULL;
510     }
511     if (virJSONValueObjectGetNumberUint(object, "nrequests_max",
512                                         &nrequests_max) < 0) {
513         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
514                        _("Missing nrequests_client_max field in JSON state document"));
515         return NULL;
516     }
517 
518     if (!(child = virJSONValueObjectGet(object, "sock"))) {
519         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
520                        _("Missing sock field in JSON state document"));
521         return NULL;
522     }
523 
524     if (!virJSONValueObjectHasKey(object, "id")) {
525         /* no ID found in, a new one must be generated */
526         id = virNetServerNextClientID(srv);
527     } else {
528         if (virJSONValueObjectGetNumberUlong(object, "id", &id) < 0) {
529             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
530                            _("Malformed id field in JSON state document"));
531             return NULL;
532         }
533     }
534 
535     if (!virJSONValueObjectHasKey(object, "conn_time")) {
536         timestamp = 0;
537     } else {
538         if (virJSONValueObjectGetNumberLong(object, "conn_time", &timestamp) < 0) {
539             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
540                            _("Malformed conn_time field in JSON "
541                              "state document"));
542             return NULL;
543         }
544     }
545 
546     if (!(sock = virNetSocketNewPostExecRestart(child))) {
547         virObjectUnref(sock);
548         return NULL;
549     }
550 
551     if (!(client = virNetServerClientNewInternal(id,
552                                                  sock,
553                                                  auth,
554                                                  auth_pending,
555                                                  NULL,
556                                                  readonly,
557                                                  nrequests_max,
558                                                  timestamp))) {
559         virObjectUnref(sock);
560         return NULL;
561     }
562     virObjectUnref(sock);
563 
564     if (!(child = virJSONValueObjectGet(object, "privateData"))) {
565         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
566                        _("Missing privateData field in JSON state document"));
567         goto error;
568     }
569 
570     if (!(client->privateData = privNew(client, child, privOpaque)))
571         goto error;
572 
573     client->privateDataFreeFunc = privFree;
574     client->privateDataPreExecRestart = privPreExecRestart;
575 
576 
577     return client;
578 
579  error:
580     virObjectUnref(client);
581     return NULL;
582 }
583 
584 
virNetServerClientPreExecRestart(virNetServerClient * client)585 virJSONValue *virNetServerClientPreExecRestart(virNetServerClient *client)
586 {
587     g_autoptr(virJSONValue) object = virJSONValueNewObject();
588     g_autoptr(virJSONValue) sock = NULL;
589     g_autoptr(virJSONValue) priv = NULL;
590 
591     virObjectLock(client);
592 
593     if (virJSONValueObjectAppendNumberUlong(object, "id", client->id) < 0)
594         goto error;
595     if (virJSONValueObjectAppendNumberInt(object, "auth", client->auth) < 0)
596         goto error;
597     if (virJSONValueObjectAppendBoolean(object, "auth_pending", client->auth_pending) < 0)
598         goto error;
599     if (virJSONValueObjectAppendBoolean(object, "readonly", client->readonly) < 0)
600         goto error;
601     if (virJSONValueObjectAppendNumberUint(object, "nrequests_max", client->nrequests_max) < 0)
602         goto error;
603 
604     if (client->conn_time &&
605         virJSONValueObjectAppendNumberLong(object, "conn_time",
606                                            client->conn_time) < 0)
607         goto error;
608 
609     if (!(sock = virNetSocketPreExecRestart(client->sock)))
610         goto error;
611 
612     if (virJSONValueObjectAppend(object, "sock", &sock) < 0)
613         goto error;
614 
615     if (!(priv = client->privateDataPreExecRestart(client, client->privateData)))
616         goto error;
617 
618     if (virJSONValueObjectAppend(object, "privateData", &priv) < 0)
619         goto error;
620 
621     virObjectUnlock(client);
622     return g_steal_pointer(&object);
623 
624  error:
625     virObjectUnlock(client);
626     return NULL;
627 }
628 
629 
/* Return the client's VIR_NET_SERVER_SERVICE_AUTH_* method (locks the client). */
int virNetServerClientGetAuth(virNetServerClient *client)
{
    int value;

    virObjectLock(client);
    value = client->auth;
    virObjectUnlock(client);

    return value;
}


/* Set the client's auth method.  This function takes no lock; as the
 * "Locked" suffix indicates, the caller is expected to hold the client lock. */
void
virNetServerClientSetAuthLocked(virNetServerClient *client,
                                int auth)
{
    client->auth = auth;
}


/* Return whether the client is read-only (locks the client). */
bool virNetServerClientGetReadonly(virNetServerClient *client)
{
    bool value;

    virObjectLock(client);
    value = client->readonly;
    virObjectUnlock(client);

    return value;
}


/* Mark the client read-only or read-write (locks the client). */
void
virNetServerClientSetReadonly(virNetServerClient *client,
                              bool readonly)
{
    virObjectLock(client);
    client->readonly = readonly;
    virObjectUnlock(client);
}
666 
667 
/* Return the client ID.  Read without the client lock; in the visible
 * code @id is only assigned when the object is constructed. */
unsigned long long virNetServerClientGetID(virNetServerClient *client)
{
    const unsigned long long value = client->id;
    return value;
}

/* Return the connection time (0 if unknown for clients restored from old
 * state).  Read without the client lock, like the ID. */
long long virNetServerClientGetTimestamp(virNetServerClient *client)
{
    const long long value = client->conn_time;
    return value;
}
677 
virNetServerClientHasTLSSession(virNetServerClient * client)678 bool virNetServerClientHasTLSSession(virNetServerClient *client)
679 {
680     bool has;
681     virObjectLock(client);
682     has = client->tls ? true : false;
683     virObjectUnlock(client);
684     return has;
685 }
686 
687 
virNetServerClientGetTLSSession(virNetServerClient * client)688 virNetTLSSession *virNetServerClientGetTLSSession(virNetServerClient *client)
689 {
690     virNetTLSSession *tls;
691     virObjectLock(client);
692     tls = client->tls;
693     virObjectUnlock(client);
694     return tls;
695 }
696 
virNetServerClientGetTLSKeySize(virNetServerClient * client)697 int virNetServerClientGetTLSKeySize(virNetServerClient *client)
698 {
699     int size = 0;
700     virObjectLock(client);
701     if (client->tls)
702         size = virNetTLSSessionGetKeySize(client->tls);
703     virObjectUnlock(client);
704     return size;
705 }
706 
/* Return the socket's file descriptor, or -1 when the client has no socket. */
int virNetServerClientGetFD(virNetServerClient *client)
{
    int fd;

    virObjectLock(client);
    fd = client->sock ? virNetSocketGetFD(client->sock) : -1;
    virObjectUnlock(client);

    return fd;
}


/* Return whether the client's socket is local; false without a socket. */
bool virNetServerClientIsLocal(virNetServerClient *client)
{
    bool isLocal;

    virObjectLock(client);
    isLocal = client->sock ? virNetSocketIsLocal(client->sock) : false;
    virObjectUnlock(client);

    return isLocal;
}


/*
 * Fetch the peer's UNIX credentials from the socket.
 *
 * Returns the result of virNetSocketGetUNIXIdentity(), or -1 when the
 * client has no socket (outputs are then left untouched).
 */
int virNetServerClientGetUNIXIdentity(virNetServerClient *client,
                                      uid_t *uid, gid_t *gid, pid_t *pid,
                                      unsigned long long *timestamp)
{
    int rc;

    virObjectLock(client);
    rc = client->sock ?
        virNetSocketGetUNIXIdentity(client->sock, uid, gid, pid, timestamp) :
        -1;
    virObjectUnlock(client);

    return rc;
}
742 
743 
744 static virIdentity *
virNetServerClientCreateIdentity(virNetServerClient * client)745 virNetServerClientCreateIdentity(virNetServerClient *client)
746 {
747     g_autofree char *username = NULL;
748     g_autofree char *groupname = NULL;
749     g_autofree char *seccontext = NULL;
750     g_autoptr(virIdentity) ret = virIdentityNew();
751 
752     if (client->sock && virNetSocketIsLocal(client->sock)) {
753         gid_t gid;
754         uid_t uid;
755         pid_t pid;
756         unsigned long long timestamp;
757         if (virNetSocketGetUNIXIdentity(client->sock,
758                                         &uid, &gid, &pid,
759                                         &timestamp) < 0)
760             return NULL;
761 
762         if (!(username = virGetUserName(uid)))
763             return NULL;
764         if (virIdentitySetUserName(ret, username) < 0)
765             return NULL;
766         if (virIdentitySetUNIXUserID(ret, uid) < 0)
767             return NULL;
768 
769         if (!(groupname = virGetGroupName(gid)))
770             return NULL;
771         if (virIdentitySetGroupName(ret, groupname) < 0)
772             return NULL;
773         if (virIdentitySetUNIXGroupID(ret, gid) < 0)
774             return NULL;
775 
776         if (virIdentitySetProcessID(ret, pid) < 0)
777             return NULL;
778         if (virIdentitySetProcessTime(ret, timestamp) < 0)
779             return NULL;
780     }
781 
782 #if WITH_SASL
783     if (client->sasl) {
784         const char *identity = virNetSASLSessionGetIdentity(client->sasl);
785         if (virIdentitySetSASLUserName(ret, identity) < 0)
786             return NULL;
787     }
788 #endif
789 
790     if (client->tls) {
791         const char *identity = virNetTLSSessionGetX509DName(client->tls);
792         if (virIdentitySetX509DName(ret, identity) < 0)
793             return NULL;
794     }
795 
796     if (client->sock &&
797         virNetSocketGetSELinuxContext(client->sock, &seccontext) < 0)
798         return NULL;
799     if (seccontext &&
800         virIdentitySetSELinuxContext(ret, seccontext) < 0)
801         return NULL;
802 
803     return g_steal_pointer(&ret);
804 }
805 
806 
virNetServerClientGetIdentity(virNetServerClient * client)807 virIdentity *virNetServerClientGetIdentity(virNetServerClient *client)
808 {
809     virIdentity *ret = NULL;
810     virObjectLock(client);
811     if (!client->identity)
812         client->identity = virNetServerClientCreateIdentity(client);
813     if (client->identity)
814         ret = g_object_ref(client->identity);
815     virObjectUnlock(client);
816     return ret;
817 }
818 
819 
/**
 * virNetServerClientSetIdentity:
 * @client: the client object (locked internally)
 * @identity: identity to associate with @client, or NULL to clear it
 *
 * Replace the client's identity.  The client takes its own reference on
 * @identity (the caller keeps its reference) and drops the reference it
 * held on the previous identity, if any.
 */
void virNetServerClientSetIdentity(virNetServerClient *client,
                                   virIdentity *identity)
{
    virIdentity *old;

    virObjectLock(client);

    /* Reference the new identity before releasing the old one.  Otherwise,
     * re-setting the identity the client already holds would free it when
     * the client's reference is the last one, and then ref freed memory. */
    old = client->identity;
    if (identity)
        g_object_ref(identity);
    client->identity = identity;
    g_clear_object(&old);

    virObjectUnlock(client);
}
830 
831 
/*
 * Retrieve the SELinux context of the client's socket peer into @context.
 *
 * @context is always reset to NULL first.  Without a socket, this
 * returns 0 and leaves @context NULL; otherwise it returns the result of
 * virNetSocketGetSELinuxContext().
 */
int virNetServerClientGetSELinuxContext(virNetServerClient *client,
                                        char **context)
{
    int rc;

    *context = NULL;

    virObjectLock(client);
    rc = client->sock ? virNetSocketGetSELinuxContext(client->sock, context) : 0;
    virObjectUnlock(client);

    return rc;
}
843 
844 
/*
 * Tell whether the connection is considered secure: it uses TLS, has a
 * SASL session (when built with SASL), or runs over a local socket.
 */
bool virNetServerClientIsSecure(virNetServerClient *client)
{
    bool viaTLS;
    bool viaSASL = false;
    bool viaLocal;

    virObjectLock(client);
    viaTLS = client->tls != NULL;
#if WITH_SASL
    viaSASL = client->sasl != NULL;
#endif
    viaLocal = client->sock && virNetSocketIsLocal(client->sock);
    virObjectUnlock(client);

    return viaTLS || viaSASL || viaLocal;
}
860 
861 
#if WITH_SASL
/**
 * virNetServerClientSetSASLSession:
 * @client: the client object (locked internally)
 * @sasl: SASL session to attach, or NULL
 *
 * Attach a SASL session to the client.  The client takes its own reference
 * on @sasl and releases the one it held on any previously attached session.
 */
void virNetServerClientSetSASLSession(virNetServerClient *client,
                                      virNetSASLSession *sasl)
{
    virNetSASLSession *old;

    /* We don't set the sasl session on the socket here
     * because we need to send out the auth confirmation
     * in the clear. Only once we complete the next 'tx'
     * operation do we switch to SASL mode
     */
    virObjectLock(client);
    /* Take the new reference before dropping the old one.  This makes
     * re-attaching the same session safe, and a replaced session is no
     * longer leaked. */
    old = client->sasl;
    client->sasl = virObjectRef(sasl);
    virObjectUnref(old);
    virObjectUnlock(client);
}


/* Return the client's SASL session, or NULL.  No reference is taken. */
virNetSASLSession *virNetServerClientGetSASLSession(virNetServerClient *client)
{
    virNetSASLSession *sasl;
    virObjectLock(client);
    sasl = client->sasl;
    virObjectUnlock(client);
    return sasl;
}

/* Return true if a SASL session is attached to the client. */
bool virNetServerClientHasSASLSession(virNetServerClient *client)
{
    bool has = false;
    virObjectLock(client);
    has = !!client->sasl;
    virObjectUnlock(client);
    return has;
}
#endif
895 
896 
/*
 * virNetServerClientGetPrivateData:
 * @client: the client to query
 *
 * Returns the opaque private data attached to @client (may be NULL).
 */
void *virNetServerClientGetPrivateData(virNetServerClient *client)
{
    void *priv = NULL;

    virObjectLock(client);
    priv = client->privateData;
    virObjectUnlock(client);

    return priv;
}
905 
906 
/*
 * virNetServerClientSetCloseHook:
 * @client: the client
 * @cf: callback to run (without the client lock held) when the client
 *      is closed
 *
 * Replaces the close hook stored on @client.
 */
void virNetServerClientSetCloseHook(virNetServerClient *client,
                                    virNetServerClientCloseFunc cf)
{
    virObjectLock(client);
    client->privateDataCloseFunc = cf;
    virObjectUnlock(client);
}
914 
915 
/*
 * virNetServerClientSetDispatcher:
 * @client: the client
 * @func: dispatch callback
 * @opaque: data passed to @func
 *
 * Installs the message dispatcher. The first caller wins; later calls
 * leave the existing dispatcher untouched.
 */
void virNetServerClientSetDispatcher(virNetServerClient *client,
                                     virNetServerClientDispatchFunc func,
                                     void *opaque)
{
    bool alreadySet;

    virObjectLock(client);
    /* Never replace an installed dispatcher: the dispatch code reads it
     * without holding the lock, so changing it would be racy. */
    alreadySet = client->dispatchFunc != NULL;
    if (!alreadySet) {
        client->dispatchFunc = func;
        client->dispatchOpaque = opaque;
    }
    virObjectUnlock(client);
}
930 
931 
/*
 * virNetServerClientLocalAddrStringSASL:
 * @client: the client to query
 *
 * Returns the socket's local address in SASL format, or NULL if the
 * client no longer has a socket.
 */
const char *virNetServerClientLocalAddrStringSASL(virNetServerClient *client)
{
    return client->sock ? virNetSocketLocalAddrStringSASL(client->sock) : NULL;
}
938 
939 
/*
 * virNetServerClientRemoteAddrStringSASL:
 * @client: the client to query
 *
 * Returns the socket's remote address in SASL format, or NULL if the
 * client no longer has a socket.
 */
const char *virNetServerClientRemoteAddrStringSASL(virNetServerClient *client)
{
    return client->sock ? virNetSocketRemoteAddrStringSASL(client->sock) : NULL;
}
946 
/*
 * virNetServerClientRemoteAddrStringURI:
 * @client: the client to query
 *
 * Returns the socket's remote address in URI format, or NULL if the
 * client no longer has a socket.
 */
const char *virNetServerClientRemoteAddrStringURI(virNetServerClient *client)
{
    return client->sock ? virNetSocketRemoteAddrStringURI(client->sock) : NULL;
}
953 
virNetServerClientDispose(void * obj)954 void virNetServerClientDispose(void *obj)
955 {
956     virNetServerClient *client = obj;
957 
958     PROBE(RPC_SERVER_CLIENT_DISPOSE,
959           "client=%p", client);
960 
961     if (client->privateData)
962         client->privateDataFreeFunc(client->privateData);
963 
964     g_clear_object(&client->identity);
965 
966 #if WITH_SASL
967     virObjectUnref(client->sasl);
968 #endif
969     if (client->sockTimer > 0)
970         virEventRemoveTimeout(client->sockTimer);
971     virObjectUnref(client->tls);
972     virObjectUnref(client->tlsCtxt);
973     virObjectUnref(client->sock);
974 }
975 
976 
/*
 * virNetServerClientCloseLocked:
 * @client: the client; the caller must hold its lock
 *
 * We don't free stuff here, merely disconnect the client's
 * network socket & resources.
 *
 * Full free of the client is done later in a safe point
 * where it can be guaranteed it is no longer in use
 *
 * Does nothing if the socket is already gone. Otherwise it stops and
 * drops the keepalive, runs the close hook, removes the socket I/O
 * callback, drops the TLS session, marks the client as wanting close,
 * discards every queued rx/tx message and finally releases the socket.
 *
 * NB: the client lock is temporarily released (with an extra reference
 * held on @client) while the keepalive is unref'd and while the close
 * hook runs, so other threads may touch @client in those windows.
 */
void
virNetServerClientCloseLocked(virNetServerClient *client)
{
    virNetServerClientCloseFunc cf;
    virKeepAlive *ka;

    VIR_DEBUG("client=%p", client);
    /* A NULL socket means the client has already been closed */
    if (!client->sock)
        return;

    if (client->keepalive) {
        virKeepAliveStop(client->keepalive);
        ka = g_steal_pointer(&client->keepalive);
        /* The keepalive object holds a reference on @client (see
         * virNetServerClientInitKeepAlive), so releasing it may re-enter
         * @client; do it unlocked. The temporary ref keeps @client alive
         * across the unlocked window. */
        virObjectRef(client);
        virObjectUnlock(client);
        virObjectUnref(ka);
        virObjectLock(client);
        virObjectUnref(client);
    }

    if (client->privateDataCloseFunc) {
        cf = client->privateDataCloseFunc;
        /* Run the close hook without the client lock held, pinning
         * @client with a temporary reference meanwhile */
        virObjectRef(client);
        virObjectUnlock(client);
        (cf)(client);
        virObjectLock(client);
        virObjectUnref(client);
    }

    /* Do now, even though we don't close the socket
     * until end, to ensure we don't get invoked
     * again due to tls shutdown */
    /* The socket is re-checked because the lock may have been dropped
     * above */
    if (client->sock)
        virNetSocketRemoveIOCallback(client->sock);

    if (client->tls) {
        virObjectUnref(client->tls);
        client->tls = NULL;
    }
    client->wantClose = true;

    /* Discard any partially-read or not-yet-sent messages */
    while (client->rx) {
        virNetMessage *msg
            = virNetMessageQueueServe(&client->rx);
        virNetMessageFree(msg);
    }
    while (client->tx) {
        virNetMessage *msg
            = virNetMessageQueueServe(&client->tx);
        virNetMessageFree(msg);
    }

    /* Dropping the socket is what marks the client as closed
     * (see virNetServerClientIsClosedLocked) */
    if (client->sock) {
        virObjectUnref(client->sock);
        client->sock = NULL;
    }
}
1042 
1043 
1044 void
virNetServerClientClose(virNetServerClient * client)1045 virNetServerClientClose(virNetServerClient *client)
1046 {
1047     virObjectLock(client);
1048     virNetServerClientCloseLocked(client);
1049     virObjectUnlock(client);
1050 }
1051 
1052 
1053 bool
virNetServerClientIsClosedLocked(virNetServerClient * client)1054 virNetServerClientIsClosedLocked(virNetServerClient *client)
1055 {
1056     return client->sock == NULL;
1057 }
1058 
1059 
/*
 * virNetServerClientDelayedClose:
 * @client: the client
 *
 * Requests that the client be closed once its current outgoing message
 * has been fully written (see virNetServerClientDispatchWrite).
 */
void virNetServerClientDelayedClose(virNetServerClient *client)
{
    virObjectLock(client);
    client->delayedClose = true;
    virObjectUnlock(client);
}
1066 
/*
 * virNetServerClientImmediateClose:
 * @client: the client
 *
 * Flags the client as wanting to be closed without waiting for
 * pending output.
 */
void virNetServerClientImmediateClose(virNetServerClient *client)
{
    virObjectLock(client);
    client->wantClose = true;
    virObjectUnlock(client);
}
1073 
1074 
1075 bool
virNetServerClientWantCloseLocked(virNetServerClient * client)1076 virNetServerClientWantCloseLocked(virNetServerClient *client)
1077 {
1078     return client->wantClose;
1079 }
1080 
1081 
/*
 * virNetServerClientInit:
 * @client: the client to start
 *
 * Starts I/O on a freshly created client. For a plain socket this just
 * registers for events. With a TLS context, a TLS session is created,
 * attached to the socket and its handshake started. If the handshake
 * completes immediately, the peer is access-checked before registering
 * for events.
 *
 * Returns 0 on success, -1 on failure (with wantClose set).
 */
int virNetServerClientInit(virNetServerClient *client)
{
    int handshake;

    virObjectLock(client);

    if (client->tlsCtxt) {
        client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL);
        if (!client->tls)
            goto error;

        virNetSocketSetTLSSession(client->sock, client->tls);

        /* Kick off the TLS handshake */
        virObjectLock(client->tlsCtxt);
        handshake = virNetTLSSessionHandshake(client->tls);
        virObjectUnlock(client->tlsCtxt);

        if (handshake < 0)
            goto error;

        /* Completing on the first try is unlikely, but if it happens the
         * certificate must be checked before any message is read. A
         * positive result means more handshake data must be exchanged. */
        if (handshake == 0 &&
            virNetServerClientCheckAccess(client) < 0)
            goto error;
    }

    /* Plain socket, finished handshake or handshake in progress: in every
     * case the next step is to watch the socket for events. */
    if (virNetServerClientRegisterEvent(client) < 0)
        goto error;

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}
1129 
1130 
1131 
1132 /*
1133  * Read data into buffer using wire decoding (plain or TLS)
1134  *
1135  * Returns:
1136  *   -1 on error or EOF
1137  *    0 on EAGAIN
1138  *    n number of bytes
1139  */
virNetServerClientRead(virNetServerClient * client)1140 static ssize_t virNetServerClientRead(virNetServerClient *client)
1141 {
1142     ssize_t ret;
1143 
1144     if (client->rx->bufferLength <= client->rx->bufferOffset) {
1145         virReportError(VIR_ERR_RPC,
1146                        _("unexpected zero/negative length request %lld"),
1147                        (long long int)(client->rx->bufferLength - client->rx->bufferOffset));
1148         client->wantClose = true;
1149         return -1;
1150     }
1151 
1152     ret = virNetSocketRead(client->sock,
1153                            client->rx->buffer + client->rx->bufferOffset,
1154                            client->rx->bufferLength - client->rx->bufferOffset);
1155 
1156     if (ret <= 0)
1157         return ret;
1158 
1159     client->rx->bufferOffset += ret;
1160     return ret;
1161 }
1162 
1163 
/*
 * Read data until we get a complete message to process.
 * If a complete message is available, it will be returned
 * from this method, for dispatch by the caller.
 *
 * The visible caller (virNetServerClientDispatchEvent) holds the client
 * lock and only calls this when client->rx is non-NULL.
 *
 * The same client->rx buffer goes through several stages:
 *  1. read the length word (bufferLength == VIR_NET_MESSAGE_LEN_MAX),
 *     decode it, then immediately try to read the payload;
 *  2. once the payload is complete, decode the header;
 *  3. for VIR_NET_CALL_WITH_FDS messages, receive the attached file
 *     descriptors, possibly across several invocations;
 *  4. dequeue the message, let keepalive and then the registered
 *     filters consume it, and allocate a fresh rx buffer if the client
 *     is below nrequests_max (otherwise reading is throttled until
 *     virNetServerClientDispatchWrite recycles a buffer).
 *
 * Returns a complete message for dispatch, or NULL if none is
 * yet available, or an error occurred. On error, the wantClose
 * flag will be set.
 */
static virNetMessage *virNetServerClientDispatchRead(virNetServerClient *client)
{
 readmore:
    /* A non-zero nfds means the data is complete and we are only still
     * waiting for file descriptors, so skip the data read */
    if (client->rx->nfds == 0) {
        if (virNetServerClientRead(client) < 0) {
            client->wantClose = true;
            return NULL; /* Error */
        }
    }

    if (client->rx->bufferOffset < client->rx->bufferLength)
        return NULL; /* Still not read enough */

    /* Either done with length word header */
    if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) {
        if (virNetMessageDecodeLength(client->rx) < 0) {
            client->wantClose = true;
            return NULL;
        }

        virNetServerClientUpdateEvent(client);

        /* Try and read payload immediately instead of going back
           into poll() because chances are the data is already
           waiting for us */
        goto readmore;
    } else {
        /* Grab the completed message */
        virNetMessage *msg = client->rx;
        virNetMessage *response = NULL;
        virNetServerClientFilter *filter;
        size_t i;

        /* Decode the header so we can use it for routing decisions */
        if (virNetMessageDecodeHeader(msg) < 0) {
            virNetMessageQueueServe(&client->rx);
            virNetMessageFree(msg);
            client->wantClose = true;
            return NULL;
        }

        /* Now figure out if we need to read more data to get some
         * file descriptors */
        if (msg->header.type == VIR_NET_CALL_WITH_FDS) {
            if (virNetMessageDecodeNumFDs(msg) < 0) {
                virNetMessageQueueServe(&client->rx);
                virNetMessageFree(msg);
                client->wantClose = true;
                return NULL; /* Error */
            }

            /* Try getting the file descriptors (may fail if blocking) */
            /* Resume from donefds so FDs received on an earlier call are
             * not requested again */
            for (i = msg->donefds; i < msg->nfds; i++) {
                int rv;
                if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                    virNetMessageQueueServe(&client->rx);
                    virNetMessageFree(msg);
                    client->wantClose = true;
                    return NULL;
                }
                if (rv == 0) /* Blocking */
                    break;
                msg->donefds++;
            }

            /* Need to poll() until FDs arrive */
            if (msg->donefds < msg->nfds) {
                /* Because DecodeHeader/NumFDs reset bufferOffset, we
                 * put it back to what it was, so everything works
                 * again next time we run this method
                 */
                client->rx->bufferOffset = client->rx->bufferLength;
                return NULL;
            }
        }

        /* Definitely finished reading, so remove from queue */
        virNetMessageQueueServe(&client->rx);
        PROBE(RPC_SERVER_CLIENT_MSG_RX,
              "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
              client, msg->bufferLength,
              msg->header.prog, msg->header.vers, msg->header.proc,
              msg->header.type, msg->header.status, msg->header.serial);

        /* Keepalive traffic is consumed here and does not count as an
         * outstanding request. Any reply is queued directly; if queueing
         * fails we still own it and must free it. */
        if (virKeepAliveCheckMessage(client->keepalive, msg, &response)) {
            virNetMessageFree(msg);
            client->nrequests--;
            msg = NULL;

            if (response &&
                virNetServerClientSendMessageLocked(client, response) < 0)
                virNetMessageFree(response);
        }

        /* Maybe send off for queue against a filter */
        /* Filter result: < 0 error (message freed, client closing),
         * > 0 the filter took the message (not freed here),
         * 0 try the next filter */
        if (msg) {
            filter = client->filters;
            while (filter) {
                int ret = filter->func(client, msg, filter->opaque);
                if (ret < 0) {
                    virNetMessageFree(msg);
                    msg = NULL;
                    client->wantClose = true;
                    break;
                }
                if (ret > 0) {
                    msg = NULL;
                    break;
                }

                filter = filter->next;
            }
        }

        /* Possibly need to create another receive buffer */
        /* If at the limit, client->rx stays NULL, which throttles reads
         * until a tracked reply finishes in DispatchWrite */
        if (client->nrequests < client->nrequests_max) {
            if (!(client->rx = virNetMessageNew(true))) {
                client->wantClose = true;
            } else {
                client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                client->rx->buffer = g_new0(char, client->rx->bufferLength);
                client->nrequests++;
            }
        }
        virNetServerClientUpdateEvent(client);

        return msg;
    }
}
1302 
1303 
1304 /*
1305  * Send client->tx using no encoding
1306  *
1307  * Returns:
1308  *   -1 on error or EOF
1309  *    0 on EAGAIN
1310  *    n number of bytes
1311  */
virNetServerClientWrite(virNetServerClient * client)1312 static ssize_t virNetServerClientWrite(virNetServerClient *client)
1313 {
1314     ssize_t ret;
1315 
1316     if (client->tx->bufferLength < client->tx->bufferOffset) {
1317         virReportError(VIR_ERR_RPC,
1318                        _("unexpected zero/negative length request %lld"),
1319                        (long long int)(client->tx->bufferLength - client->tx->bufferOffset));
1320         client->wantClose = true;
1321         return -1;
1322     }
1323 
1324     if (client->tx->bufferLength == client->tx->bufferOffset)
1325         return 1;
1326 
1327     ret = virNetSocketWrite(client->sock,
1328                             client->tx->buffer + client->tx->bufferOffset,
1329                             client->tx->bufferLength - client->tx->bufferOffset);
1330     if (ret <= 0)
1331         return ret; /* -1 error, 0 = egain */
1332 
1333     client->tx->bufferOffset += ret;
1334     return ret;
1335 }
1336 
1337 
/*
 * Process all queued client->tx messages until
 * we would block on I/O
 *
 * For each message at the head of the queue: write the remaining data,
 * then send any attached file descriptors, then (WITH_SASL) switch the
 * socket to SASL mode if a session is pending, then dequeue it. A
 * finished tracked message may be recycled as the new client->rx buffer
 * if reading was throttled. Errors set wantClose; EAGAIN (or a blocking
 * FD send) returns early, leaving the message at the head of the queue.
 */
static void
virNetServerClientDispatchWrite(virNetServerClient *client)
{
    while (client->tx) {
        if (client->tx->bufferOffset < client->tx->bufferLength) {
            ssize_t ret;
            ret = virNetServerClientWrite(client);
            if (ret < 0) {
                client->wantClose = true;
                return;
            }
            if (ret == 0)
                return; /* Would block on write EAGAIN */
        }

        if (client->tx->bufferOffset == client->tx->bufferLength) {
            virNetMessage *msg;
            size_t i;

            /* Resume from donefds so FDs already sent on an earlier pass
             * are not sent twice */
            for (i = client->tx->donefds; i < client->tx->nfds; i++) {
                int rv;
                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                    client->wantClose = true;
                    return;
                }
                if (rv == 0) /* Blocking */
                    return;
                client->tx->donefds++;
            }

#if WITH_SASL
            /* Completed this 'tx' operation, so now read for all
             * future rx/tx to be under a SASL SSF layer
             */
            if (client->sasl) {
                virNetSocketSetSASLSession(client->sock, client->sasl);
                virObjectUnref(client->sasl);
                client->sasl = NULL;
            }
#endif

            /* Get finished msg from head of tx queue */
            msg = virNetMessageQueueServe(&client->tx);

            if (msg->tracked) {
                client->nrequests--;
                /* See if the recv queue is currently throttled */
                if (!client->rx &&
                    client->nrequests < client->nrequests_max) {
                    /* Ready to recv more messages */
                    /* Reuse this message as the new rx buffer; msg becomes
                     * NULL, so the free below is then a no-op */
                    virNetMessageClear(msg);
                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                    msg->buffer = g_new0(char, msg->bufferLength);
                    client->rx = g_steal_pointer(&msg);
                    client->nrequests++;
                }
            }

            virNetMessageFree(msg);

            virNetServerClientUpdateEvent(client);

            /* A delayed close takes effect once a message has been fully
             * sent */
            if (client->delayedClose)
                client->wantClose = true;
        }
    }
}
1409 
1410 
1411 static void
virNetServerClientDispatchHandshake(virNetServerClient * client)1412 virNetServerClientDispatchHandshake(virNetServerClient *client)
1413 {
1414     int ret;
1415     /* Continue the handshake. */
1416     virObjectLock(client->tlsCtxt);
1417     ret = virNetTLSSessionHandshake(client->tls);
1418     virObjectUnlock(client->tlsCtxt);
1419     if (ret == 0) {
1420         /* Finished.  Next step is to check the certificate. */
1421         if (virNetServerClientCheckAccess(client) < 0)
1422             client->wantClose = true;
1423         else
1424             virNetServerClientUpdateEvent(client);
1425     } else if (ret > 0) {
1426         /* Carry on waiting for more handshake. Update
1427            the events just in case handshake data flow
1428            direction has changed */
1429         virNetServerClientUpdateEvent(client);
1430     } else {
1431         /* Fatal error in handshake */
1432         client->wantClose = true;
1433     }
1434 }
1435 
1436 
/*
 * Socket I/O callback for a client.
 *
 * Events from a socket that is no longer the client's current socket are
 * ignored and that socket's callback is removed. While a TLS handshake
 * is in progress, readable/writable events drive the handshake;
 * afterwards they trigger queued writes and, if a receive buffer exists,
 * reads. A complete message obtained by the read is dispatched only after
 * the client lock has been released. Error or hangup events flag the
 * client for closing.
 */
static void
virNetServerClientDispatchEvent(virNetSocket *sock, int events, void *opaque)
{
    virNetServerClient *client = opaque;
    virNetMessage *msg = NULL;

    virObjectLock(client);

    /* Stale event from a socket the client no longer owns */
    if (client->sock != sock) {
        virNetSocketRemoveIOCallback(sock);
        virObjectUnlock(client);
        return;
    }

    if (events & (VIR_EVENT_HANDLE_WRITABLE |
                  VIR_EVENT_HANDLE_READABLE)) {
        if (client->tls &&
            virNetTLSSessionGetHandshakeStatus(client->tls) !=
            VIR_NET_TLS_HANDSHAKE_COMPLETE) {
            virNetServerClientDispatchHandshake(client);
        } else {
            if (events & VIR_EVENT_HANDLE_WRITABLE)
                virNetServerClientDispatchWrite(client);
            /* client->rx is NULL while reading is throttled */
            if (events & VIR_EVENT_HANDLE_READABLE &&
                client->rx)
                msg = virNetServerClientDispatchRead(client);
        }
    }

    /* NB, will get HANGUP + READABLE at same time upon
     * disconnect */
    if (events & (VIR_EVENT_HANDLE_ERROR |
                  VIR_EVENT_HANDLE_HANGUP))
        client->wantClose = true;

    virObjectUnlock(client);

    /* Dispatch outside the client lock */
    if (msg)
        virNetServerClientDispatchMessage(client, msg);
}
1477 
1478 
1479 static int
virNetServerClientSendMessageLocked(virNetServerClient * client,virNetMessage * msg)1480 virNetServerClientSendMessageLocked(virNetServerClient *client,
1481                                     virNetMessage *msg)
1482 {
1483     int ret = -1;
1484     VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu",
1485               msg, msg->header.proc,
1486               msg->bufferLength, msg->bufferOffset);
1487 
1488     msg->donefds = 0;
1489     if (client->sock && !client->wantClose) {
1490         PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
1491               "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
1492               client, msg->bufferLength,
1493               msg->header.prog, msg->header.vers, msg->header.proc,
1494               msg->header.type, msg->header.status, msg->header.serial);
1495         virNetMessageQueuePush(&client->tx, msg);
1496 
1497         virNetServerClientUpdateEvent(client);
1498         ret = 0;
1499     }
1500 
1501     return ret;
1502 }
1503 
/*
 * virNetServerClientSendMessage:
 * @client: the client, unlocked on entry
 * @msg: the message to queue
 *
 * Locking wrapper around virNetServerClientSendMessageLocked().
 *
 * Returns 0 if @msg was queued, -1 otherwise (caller keeps @msg).
 */
int virNetServerClientSendMessage(virNetServerClient *client,
                                  virNetMessage *msg)
{
    int rc;

    virObjectLock(client);
    rc = virNetServerClientSendMessageLocked(client, msg);
    virObjectUnlock(client);

    return rc;
}
1515 
1516 
1517 bool
virNetServerClientIsAuthenticated(virNetServerClient * client)1518 virNetServerClientIsAuthenticated(virNetServerClient *client)
1519 {
1520     bool authenticated;
1521     virObjectLock(client);
1522     authenticated = virNetServerClientAuthMethodImpliesAuthenticated(client->auth);
1523     virObjectUnlock(client);
1524     return authenticated;
1525 }
1526 
1527 
1528 /* The caller must hold the lock for @client */
1529 void
virNetServerClientSetAuthPendingLocked(virNetServerClient * client,bool auth_pending)1530 virNetServerClientSetAuthPendingLocked(virNetServerClient *client,
1531                                        bool auth_pending)
1532 {
1533     client->auth_pending = auth_pending;
1534 }
1535 
1536 
1537 /* The caller must hold the lock for @client */
1538 bool
virNetServerClientIsAuthPendingLocked(virNetServerClient * client)1539 virNetServerClientIsAuthPendingLocked(virNetServerClient *client)
1540 {
1541     return client->auth_pending;
1542 }
1543 
1544 
1545 static void
virNetServerClientKeepAliveDeadCB(void * opaque)1546 virNetServerClientKeepAliveDeadCB(void *opaque)
1547 {
1548     virNetServerClientImmediateClose(opaque);
1549 }
1550 
1551 static int
virNetServerClientKeepAliveSendCB(void * opaque,virNetMessage * msg)1552 virNetServerClientKeepAliveSendCB(void *opaque,
1553                                   virNetMessage *msg)
1554 {
1555     return virNetServerClientSendMessage(opaque, msg);
1556 }
1557 
1558 
1559 int
virNetServerClientInitKeepAlive(virNetServerClient * client,int interval,unsigned int count)1560 virNetServerClientInitKeepAlive(virNetServerClient *client,
1561                                 int interval,
1562                                 unsigned int count)
1563 {
1564     virKeepAlive *ka;
1565     int ret = -1;
1566 
1567     virObjectLock(client);
1568 
1569     if (!(ka = virKeepAliveNew(interval, count, client,
1570                                virNetServerClientKeepAliveSendCB,
1571                                virNetServerClientKeepAliveDeadCB,
1572                                virObjectFreeCallback)))
1573         goto cleanup;
1574     /* keepalive object has a reference to client */
1575     virObjectRef(client);
1576 
1577     client->keepalive = ka;
1578     ret = 0;
1579  cleanup:
1580     virObjectUnlock(client);
1581 
1582     return ret;
1583 }
1584 
1585 int
virNetServerClientStartKeepAlive(virNetServerClient * client)1586 virNetServerClientStartKeepAlive(virNetServerClient *client)
1587 {
1588     int ret = -1;
1589 
1590     virObjectLock(client);
1591 
1592     /* The connection might have been closed before we got here and thus the
1593      * keepalive object could have been removed too.
1594      */
1595     if (!client->keepalive) {
1596         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1597                        _("connection not open"));
1598         goto cleanup;
1599     }
1600 
1601     ret = virKeepAliveStart(client->keepalive, 0, 0);
1602 
1603  cleanup:
1604     virObjectUnlock(client);
1605     return ret;
1606 }
1607 
1608 int
virNetServerClientGetTransport(virNetServerClient * client)1609 virNetServerClientGetTransport(virNetServerClient *client)
1610 {
1611     int ret = -1;
1612 
1613     virObjectLock(client);
1614 
1615     if (client->sock && virNetSocketIsLocal(client->sock))
1616         ret = VIR_CLIENT_TRANS_UNIX;
1617     else
1618         ret = VIR_CLIENT_TRANS_TCP;
1619 
1620     if (client->tls)
1621         ret = VIR_CLIENT_TRANS_TLS;
1622 
1623     virObjectUnlock(client);
1624 
1625     return ret;
1626 }
1627 
1628 int
virNetServerClientGetInfo(virNetServerClient * client,bool * readonly,char ** sock_addr,virIdentity ** identity)1629 virNetServerClientGetInfo(virNetServerClient *client,
1630                           bool *readonly, char **sock_addr,
1631                           virIdentity **identity)
1632 {
1633     int ret = -1;
1634     const char *addr;
1635 
1636     virObjectLock(client);
1637     *readonly = client->readonly;
1638 
1639     if (!(addr = virNetServerClientRemoteAddrStringURI(client))) {
1640         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1641                        _("No network socket associated with client"));
1642         goto cleanup;
1643     }
1644 
1645     *sock_addr = g_strdup(addr);
1646 
1647     if (!client->identity) {
1648         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1649                        _("No identity information available for client"));
1650         goto cleanup;
1651     }
1652 
1653     *identity = g_object_ref(client->identity);
1654 
1655     ret = 0;
1656  cleanup:
1657     virObjectUnlock(client);
1658     return ret;
1659 }
1660 
1661 
1662 /**
1663  * virNetServerClientSetQuietEOF:
1664  *
1665  * Don't report errors for protocols that close connection by hangup of the
1666  * socket rather than calling an API to close it.
1667  */
1668 void
virNetServerClientSetQuietEOF(virNetServerClient * client)1669 virNetServerClientSetQuietEOF(virNetServerClient *client)
1670 {
1671     virNetSocketSetQuietEOF(client->sock);
1672 }
1673