1 /*
2  * virnetserver.c: generic network RPC server
3  *
4  * Copyright (C) 2006-2015 Red Hat, Inc.
5  * Copyright (C) 2006 Daniel P. Berrange
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library.  If not, see
19  * <http://www.gnu.org/licenses/>.
20  */
21 
22 #include <config.h>
23 
24 #include "virnetserver.h"
25 #include "virlog.h"
26 #include "viralloc.h"
27 #include "virerror.h"
28 #include "virthread.h"
29 #include "virthreadpool.h"
30 #include "virstring.h"
31 #include "virutil.h"
32 
33 #define VIR_FROM_THIS VIR_FROM_RPC
34 
35 VIR_LOG_INIT("rpc.netserver");
36 
37 
/* A unit of work handed to the RPC worker thread pool: one incoming
 * message, the client it arrived on, and the program (if any) that
 * matched it. Built in virNetServerDispatchNewMessage and consumed and
 * freed by virNetServerHandleJob. */
typedef struct _virNetServerJob virNetServerJob;
struct _virNetServerJob {
    virNetServerClient *client;  /* job holds a reference on the client */
    virNetMessage *msg;          /* message to process */
    virNetServerProgram *prog;   /* referenced program, or NULL if none matched */
};
44 
/* Generic RPC server: a set of listening services, registered RPC
 * programs, connected clients and a worker thread pool. Mutable fields
 * below are accessed with the lockable parent object held
 * (virObjectLock(srv)) throughout this file. */
struct _virNetServer {
    virObjectLockable parent;

    char *name;

    /* Immutable pointer, self-locking APIs */
    virThreadPool *workers;

    /* Listening services; each holds a server reference to the service */
    size_t nservices;
    virNetServerService **services;

    /* Registered RPC programs, searched by virNetServerGetProgramLocked */
    size_t nprograms;
    virNetServerProgram **programs;

    size_t nclients;                    /* Current clients count */
    virNetServerClient **clients;     /* Clients */
    unsigned long long next_client_id;  /* next client ID */
    size_t nclients_max;                /* Max allowed clients count */
    size_t nclients_unauth;             /* Unauthenticated clients count */
    size_t nclients_unauth_max;         /* Max allowed unauth clients count */

    /* Passed to virNetServerClientInitKeepAlive for every new client */
    int keepaliveInterval;
    unsigned int keepaliveCount;

    virNetTLSContext *tls;

    /* Client private-data callbacks, forwarded to virNetServerClientNew */
    virNetServerClientPrivNew clientPrivNew;
    virNetServerClientPrivPreExecRestart clientPrivPreExecRestart;
    virFreeCallback clientPrivFree;
    void *clientPrivOpaque;
};
76 
77 
/* Class handle for virNetServer objects; presumably assigned by
 * VIR_CLASS_NEW() in virNetServerOnceInit — confirm against the macro. */
static virClass *virNetServerClass;
static void virNetServerDispose(void *obj);
static void virNetServerUpdateServicesLocked(virNetServer *srv,
                                             bool enabled);
/* Bookkeeping for clients whose authentication is still pending; the
 * *Locked suffix means @srv must already be locked by the caller. */
static inline size_t virNetServerTrackPendingAuthLocked(virNetServer *srv);
static inline size_t virNetServerTrackCompletedAuthLocked(virNetServer *srv);
84 
virNetServerOnceInit(void)85 static int virNetServerOnceInit(void)
86 {
87     if (!VIR_CLASS_NEW(virNetServer, virClassForObjectLockable()))
88         return -1;
89 
90     return 0;
91 }
92 
93 VIR_ONCE_GLOBAL_INIT(virNetServer);
94 
virNetServerNextClientID(virNetServer * srv)95 unsigned long long virNetServerNextClientID(virNetServer *srv)
96 {
97     unsigned long long val;
98 
99     virObjectLock(srv);
100     val = srv->next_client_id++;
101     virObjectUnlock(srv);
102 
103     return val;
104 }
105 
virNetServerProcessMsg(virNetServer * srv,virNetServerClient * client,virNetServerProgram * prog,virNetMessage * msg)106 static int virNetServerProcessMsg(virNetServer *srv,
107                                   virNetServerClient *client,
108                                   virNetServerProgram *prog,
109                                   virNetMessage *msg)
110 {
111     if (!prog) {
112         /* Only send back an error for type == CALL. Other
113          * message types are not expecting replies, so we
114          * must just log it & drop them
115          */
116         if (msg->header.type == VIR_NET_CALL ||
117             msg->header.type == VIR_NET_CALL_WITH_FDS) {
118             if (virNetServerProgramUnknownError(client,
119                                                 msg,
120                                                 &msg->header) < 0)
121                 return -1;
122         } else {
123             VIR_INFO("Dropping client message, unknown program %d version %d type %d proc %d",
124                      msg->header.prog, msg->header.vers,
125                      msg->header.type, msg->header.proc);
126             /* Send a dummy reply to free up 'msg' & unblock client rx */
127             virNetMessageClear(msg);
128             msg->header.type = VIR_NET_REPLY;
129             if (virNetServerClientSendMessage(client, msg) < 0)
130                 return -1;
131         }
132         return 0;
133     }
134 
135     if (virNetServerProgramDispatch(prog,
136                                     srv,
137                                     client,
138                                     msg) < 0)
139         return -1;
140 
141     return 0;
142 }
143 
virNetServerHandleJob(void * jobOpaque,void * opaque)144 static void virNetServerHandleJob(void *jobOpaque, void *opaque)
145 {
146     virNetServer *srv = opaque;
147     virNetServerJob *job = jobOpaque;
148 
149     VIR_DEBUG("server=%p client=%p message=%p prog=%p",
150               srv, job->client, job->msg, job->prog);
151 
152     if (virNetServerProcessMsg(srv, job->client, job->prog, job->msg) < 0)
153         goto error;
154 
155     virObjectUnref(job->prog);
156     virObjectUnref(job->client);
157     VIR_FREE(job);
158     return;
159 
160  error:
161     virObjectUnref(job->prog);
162     virNetMessageFree(job->msg);
163     virNetServerClientClose(job->client);
164     virObjectUnref(job->client);
165     VIR_FREE(job);
166 }
167 
168 /**
169  * virNetServerGetProgramLocked:
170  * @srv: server (must be locked by the caller)
171  * @msg: message
172  *
173  * Searches @srv for the right program for a given message @msg.
174  *
175  * Returns a pointer to the server program or NULL if not found.
176  */
177 static virNetServerProgram *
virNetServerGetProgramLocked(virNetServer * srv,virNetMessage * msg)178 virNetServerGetProgramLocked(virNetServer *srv,
179                              virNetMessage *msg)
180 {
181     size_t i;
182     for (i = 0; i < srv->nprograms; i++) {
183         if (virNetServerProgramMatches(srv->programs[i], msg))
184             return srv->programs[i];
185     }
186     return NULL;
187 }
188 
189 static void
virNetServerDispatchNewMessage(virNetServerClient * client,virNetMessage * msg,void * opaque)190 virNetServerDispatchNewMessage(virNetServerClient *client,
191                                virNetMessage *msg,
192                                void *opaque)
193 {
194     virNetServer *srv = opaque;
195     virNetServerProgram *prog = NULL;
196     unsigned int priority = 0;
197 
198     VIR_DEBUG("server=%p client=%p message=%p",
199               srv, client, msg);
200 
201     virObjectLock(srv);
202     prog = virNetServerGetProgramLocked(srv, msg);
203     /* we can unlock @srv since @prog can only become invalid in case
204      * of disposing @srv, but let's grab a ref first to ensure nothing
205      * disposes of it before we use it. */
206     virObjectRef(srv);
207     virObjectUnlock(srv);
208 
209     if (virThreadPoolGetMaxWorkers(srv->workers) > 0)  {
210         virNetServerJob *job;
211 
212         job = g_new0(virNetServerJob, 1);
213 
214         job->client = virObjectRef(client);
215         job->msg = msg;
216 
217         if (prog) {
218             job->prog = virObjectRef(prog);
219             priority = virNetServerProgramGetPriority(prog, msg->header.proc);
220         }
221 
222         if (virThreadPoolSendJob(srv->workers, priority, job) < 0) {
223             virObjectUnref(client);
224             VIR_FREE(job);
225             virObjectUnref(prog);
226             goto error;
227         }
228     } else {
229         if (virNetServerProcessMsg(srv, client, prog, msg) < 0)
230             goto error;
231     }
232 
233     virObjectUnref(srv);
234     return;
235 
236  error:
237     virNetMessageFree(msg);
238     virNetServerClientClose(client);
239     virObjectUnref(srv);
240 }
241 
242 
243 /**
244  * virNetServerCheckLimits:
245  * @srv: server to check limits on
246  *
247  * Check if limits like max_clients or max_anonymous_clients
248  * are satisfied. If so, re-enable accepting new clients. If these are violated
249  * however, temporarily disable accepting new clients.
250  * The @srv must be locked when this function is called.
251  */
252 static void
virNetServerCheckLimits(virNetServer * srv)253 virNetServerCheckLimits(virNetServer *srv)
254 {
255     size_t i;
256 
257     for (i = 0; i < srv->nservices; i++) {
258         if (virNetServerServiceTimerActive(srv->services[i])) {
259             VIR_DEBUG("Skipping client-related limits evaluation");
260             return;
261         }
262     }
263 
264     VIR_DEBUG("Checking client-related limits to re-enable or temporarily "
265               "suspend services: nclients=%zu nclients_max=%zu "
266               "nclients_unauth=%zu nclients_unauth_max=%zu",
267               srv->nclients, srv->nclients_max,
268               srv->nclients_unauth, srv->nclients_unauth_max);
269 
270     /* Check the max_anonymous_clients and max_clients limits so that we can
271      * decide whether the services should be temporarily suspended, thus not
272      * accepting any more clients for a while or re-enabling the previously
273      * suspended services in order to accept new clients again.
274      * A new client can only be accepted if both max_clients and
275      * max_anonymous_clients wouldn't get overcommitted by accepting it.
276      */
277     if (srv->nclients >= srv->nclients_max ||
278         (srv->nclients_unauth_max &&
279          srv->nclients_unauth >= srv->nclients_unauth_max)) {
280         /* Temporarily stop accepting new clients */
281         VIR_INFO("Temporarily suspending services");
282         virNetServerUpdateServicesLocked(srv, false);
283     } else if (srv->nclients < srv->nclients_max &&
284                (!srv->nclients_unauth_max ||
285                 srv->nclients_unauth < srv->nclients_unauth_max)) {
286         /* Now it makes sense to accept() a new client. */
287         VIR_INFO("Re-enabling services");
288         virNetServerUpdateServicesLocked(srv, true);
289     }
290 }
291 
virNetServerAddClient(virNetServer * srv,virNetServerClient * client)292 int virNetServerAddClient(virNetServer *srv,
293                           virNetServerClient *client)
294 {
295     virObjectLock(srv);
296 
297     if (virNetServerClientInit(client) < 0)
298         goto error;
299 
300     VIR_EXPAND_N(srv->clients, srv->nclients, 1);
301     srv->clients[srv->nclients-1] = virObjectRef(client);
302 
303     virObjectLock(client);
304     if (virNetServerClientIsAuthPendingLocked(client))
305         virNetServerTrackPendingAuthLocked(srv);
306     virObjectUnlock(client);
307 
308     virNetServerCheckLimits(srv);
309 
310     virNetServerClientSetDispatcher(client,
311                                     virNetServerDispatchNewMessage,
312                                     srv);
313 
314     if (virNetServerClientInitKeepAlive(client, srv->keepaliveInterval,
315                                         srv->keepaliveCount) < 0)
316         goto error;
317 
318     virObjectUnlock(srv);
319     return 0;
320 
321  error:
322     virObjectUnlock(srv);
323     return -1;
324 }
325 
virNetServerDispatchNewClient(virNetServerService * svc,virNetSocket * clientsock,void * opaque)326 static int virNetServerDispatchNewClient(virNetServerService *svc,
327                                          virNetSocket *clientsock,
328                                          void *opaque)
329 {
330     virNetServer *srv = opaque;
331     virNetServerClient *client;
332 
333     if (!(client = virNetServerClientNew(virNetServerNextClientID(srv),
334                                          clientsock,
335                                          virNetServerServiceGetAuth(svc),
336                                          virNetServerServiceIsReadonly(svc),
337                                          virNetServerServiceGetMaxRequests(svc),
338                                          virNetServerServiceGetTLSContext(svc),
339                                          srv->clientPrivNew,
340                                          srv->clientPrivPreExecRestart,
341                                          srv->clientPrivFree,
342                                          srv->clientPrivOpaque)))
343         return -1;
344 
345     if (virNetServerAddClient(srv, client) < 0) {
346         virNetServerClientClose(client);
347         virObjectUnref(client);
348         return -1;
349     }
350     virObjectUnref(client);
351     return 0;
352 }
353 
354 
virNetServerNew(const char * name,unsigned long long next_client_id,size_t min_workers,size_t max_workers,size_t priority_workers,size_t max_clients,size_t max_anonymous_clients,int keepaliveInterval,unsigned int keepaliveCount,virNetServerClientPrivNew clientPrivNew,virNetServerClientPrivPreExecRestart clientPrivPreExecRestart,virFreeCallback clientPrivFree,void * clientPrivOpaque)355 virNetServer *virNetServerNew(const char *name,
356                                 unsigned long long next_client_id,
357                                 size_t min_workers,
358                                 size_t max_workers,
359                                 size_t priority_workers,
360                                 size_t max_clients,
361                                 size_t max_anonymous_clients,
362                                 int keepaliveInterval,
363                                 unsigned int keepaliveCount,
364                                 virNetServerClientPrivNew clientPrivNew,
365                                 virNetServerClientPrivPreExecRestart clientPrivPreExecRestart,
366                                 virFreeCallback clientPrivFree,
367                                 void *clientPrivOpaque)
368 {
369     virNetServer *srv;
370 
371     if (virNetServerInitialize() < 0)
372         return NULL;
373 
374     if (!(srv = virObjectLockableNew(virNetServerClass)))
375         return NULL;
376 
377     if (!(srv->workers = virThreadPoolNewFull(min_workers, max_workers,
378                                               priority_workers,
379                                               virNetServerHandleJob,
380                                               "rpc-worker",
381                                               NULL,
382                                               srv)))
383         goto error;
384 
385     srv->name = g_strdup(name);
386 
387     srv->next_client_id = next_client_id;
388     srv->nclients_max = max_clients;
389     srv->nclients_unauth_max = max_anonymous_clients;
390     srv->keepaliveInterval = keepaliveInterval;
391     srv->keepaliveCount = keepaliveCount;
392     srv->clientPrivNew = clientPrivNew;
393     srv->clientPrivPreExecRestart = clientPrivPreExecRestart;
394     srv->clientPrivFree = clientPrivFree;
395     srv->clientPrivOpaque = clientPrivOpaque;
396 
397     return srv;
398  error:
399     virObjectUnref(srv);
400     return NULL;
401 }
402 
403 
virNetServerNewPostExecRestart(virJSONValue * object,const char * name,virNetServerClientPrivNew clientPrivNew,virNetServerClientPrivNewPostExecRestart clientPrivNewPostExecRestart,virNetServerClientPrivPreExecRestart clientPrivPreExecRestart,virFreeCallback clientPrivFree,void * clientPrivOpaque)404 virNetServer *virNetServerNewPostExecRestart(virJSONValue *object,
405                                                const char *name,
406                                                virNetServerClientPrivNew clientPrivNew,
407                                                virNetServerClientPrivNewPostExecRestart clientPrivNewPostExecRestart,
408                                                virNetServerClientPrivPreExecRestart clientPrivPreExecRestart,
409                                                virFreeCallback clientPrivFree,
410                                                void *clientPrivOpaque)
411 {
412     virNetServer *srv = NULL;
413     virJSONValue *clients;
414     virJSONValue *services;
415     size_t i;
416     unsigned int min_workers;
417     unsigned int max_workers;
418     unsigned int priority_workers;
419     unsigned int max_clients;
420     unsigned int max_anonymous_clients;
421     unsigned int keepaliveInterval;
422     unsigned int keepaliveCount;
423     unsigned long long next_client_id;
424 
425     if (virJSONValueObjectGetNumberUint(object, "min_workers", &min_workers) < 0) {
426         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
427                        _("Missing min_workers data in JSON document"));
428         goto error;
429     }
430     if (virJSONValueObjectGetNumberUint(object, "max_workers", &max_workers) < 0) {
431         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
432                        _("Missing max_workers data in JSON document"));
433         goto error;
434     }
435     if (virJSONValueObjectGetNumberUint(object, "priority_workers", &priority_workers) < 0) {
436         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
437                        _("Missing priority_workers data in JSON document"));
438         goto error;
439     }
440     if (virJSONValueObjectGetNumberUint(object, "max_clients", &max_clients) < 0) {
441         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
442                        _("Missing max_clients data in JSON document"));
443         goto error;
444     }
445     if (virJSONValueObjectHasKey(object, "max_anonymous_clients")) {
446         if (virJSONValueObjectGetNumberUint(object, "max_anonymous_clients",
447                                             &max_anonymous_clients) < 0) {
448             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
449                            _("Malformed max_anonymous_clients data in JSON document"));
450             goto error;
451         }
452     } else {
453         max_anonymous_clients = max_clients;
454     }
455     if (virJSONValueObjectGetNumberUint(object, "keepaliveInterval", &keepaliveInterval) < 0) {
456         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
457                        _("Missing keepaliveInterval data in JSON document"));
458         goto error;
459     }
460     if (virJSONValueObjectGetNumberUint(object, "keepaliveCount", &keepaliveCount) < 0) {
461         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
462                        _("Missing keepaliveCount data in JSON document"));
463         goto error;
464     }
465 
466     if (virJSONValueObjectGetNumberUlong(object, "next_client_id",
467                                          &next_client_id) < 0) {
468         VIR_WARN("Missing next_client_id data in JSON document");
469         next_client_id = 1;
470     }
471 
472     if (!(srv = virNetServerNew(name, next_client_id,
473                                 min_workers, max_workers,
474                                 priority_workers, max_clients,
475                                 max_anonymous_clients,
476                                 keepaliveInterval, keepaliveCount,
477                                 clientPrivNew, clientPrivPreExecRestart,
478                                 clientPrivFree, clientPrivOpaque)))
479         goto error;
480 
481     if (!(services = virJSONValueObjectGet(object, "services"))) {
482         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
483                        _("Missing services data in JSON document"));
484         goto error;
485     }
486 
487     if (!virJSONValueIsArray(services)) {
488         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
489                        _("Malformed services array"));
490         goto error;
491     }
492 
493     for (i = 0; i < virJSONValueArraySize(services); i++) {
494         virNetServerService *service;
495         virJSONValue *child = virJSONValueArrayGet(services, i);
496         if (!child) {
497             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
498                            _("Missing service data in JSON document"));
499             goto error;
500         }
501 
502         if (!(service = virNetServerServiceNewPostExecRestart(child)))
503             goto error;
504 
505         if (virNetServerAddService(srv, service) < 0) {
506             virObjectUnref(service);
507             goto error;
508         }
509     }
510 
511 
512     if (!(clients = virJSONValueObjectGet(object, "clients"))) {
513         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
514                        _("Missing clients data in JSON document"));
515         goto error;
516     }
517 
518     if (!virJSONValueIsArray(clients)) {
519         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
520                        _("Malformed clients array"));
521         goto error;
522     }
523 
524     for (i = 0; i < virJSONValueArraySize(clients); i++) {
525         virNetServerClient *client;
526         virJSONValue *child = virJSONValueArrayGet(clients, i);
527         if (!child) {
528             virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
529                            _("Missing client data in JSON document"));
530             goto error;
531         }
532 
533         if (!(client = virNetServerClientNewPostExecRestart(srv,
534                                                             child,
535                                                             clientPrivNewPostExecRestart,
536                                                             clientPrivPreExecRestart,
537                                                             clientPrivFree,
538                                                             clientPrivOpaque)))
539             goto error;
540 
541         if (virNetServerAddClient(srv, client) < 0) {
542             virObjectUnref(client);
543             goto error;
544         }
545         virObjectUnref(client);
546     }
547 
548     return srv;
549 
550  error:
551     virObjectUnref(srv);
552     return NULL;
553 }
554 
555 
virNetServerPreExecRestart(virNetServer * srv)556 virJSONValue *virNetServerPreExecRestart(virNetServer *srv)
557 {
558     g_autoptr(virJSONValue) object = virJSONValueNewObject();
559     g_autoptr(virJSONValue) clients = virJSONValueNewArray();
560     g_autoptr(virJSONValue) services = virJSONValueNewArray();
561     size_t i;
562 
563     virObjectLock(srv);
564 
565     if (virJSONValueObjectAppendNumberUint(object, "min_workers",
566                                            virThreadPoolGetMinWorkers(srv->workers)) < 0)
567         goto error;
568     if (virJSONValueObjectAppendNumberUint(object, "max_workers",
569                                            virThreadPoolGetMaxWorkers(srv->workers)) < 0)
570         goto error;
571     if (virJSONValueObjectAppendNumberUint(object, "priority_workers",
572                                            virThreadPoolGetPriorityWorkers(srv->workers)) < 0)
573         goto error;
574 
575     if (virJSONValueObjectAppendNumberUint(object, "max_clients", srv->nclients_max) < 0)
576         goto error;
577     if (virJSONValueObjectAppendNumberUint(object, "max_anonymous_clients",
578                                            srv->nclients_unauth_max) < 0)
579         goto error;
580 
581     if (virJSONValueObjectAppendNumberUint(object, "keepaliveInterval", srv->keepaliveInterval) < 0)
582         goto error;
583     if (virJSONValueObjectAppendNumberUint(object, "keepaliveCount", srv->keepaliveCount) < 0)
584         goto error;
585 
586     if (virJSONValueObjectAppendNumberUlong(object, "next_client_id",
587                                             srv->next_client_id) < 0)
588         goto error;
589 
590     for (i = 0; i < srv->nservices; i++) {
591         g_autoptr(virJSONValue) child = NULL;
592         if (!(child = virNetServerServicePreExecRestart(srv->services[i])))
593             goto error;
594 
595         if (virJSONValueArrayAppend(services, &child) < 0)
596             goto error;
597     }
598 
599     if (virJSONValueObjectAppend(object, "services", &services) < 0)
600         goto error;
601 
602     for (i = 0; i < srv->nclients; i++) {
603         g_autoptr(virJSONValue) child = NULL;
604         if (!(child = virNetServerClientPreExecRestart(srv->clients[i])))
605             goto error;
606 
607         if (virJSONValueArrayAppend(clients, &child) < 0)
608             goto error;
609     }
610 
611     if (virJSONValueObjectAppend(object, "clients", &clients) < 0)
612         goto error;
613 
614     virObjectUnlock(srv);
615 
616     return g_steal_pointer(&object);
617 
618  error:
619     virObjectUnlock(srv);
620     return NULL;
621 }
622 
623 
624 
virNetServerAddService(virNetServer * srv,virNetServerService * svc)625 int virNetServerAddService(virNetServer *srv,
626                            virNetServerService *svc)
627 {
628     virObjectLock(srv);
629 
630     VIR_EXPAND_N(srv->services, srv->nservices, 1);
631     srv->services[srv->nservices-1] = virObjectRef(svc);
632 
633     virNetServerServiceSetDispatcher(svc,
634                                      virNetServerDispatchNewClient,
635                                      srv);
636 
637     virObjectUnlock(srv);
638     return 0;
639 }
640 
641 
642 static int
virNetServerAddServiceActivation(virNetServer * srv,virSystemdActivation * act,const char * actname,int auth,virNetTLSContext * tls,bool readonly,size_t max_queued_clients,size_t nrequests_client_max)643 virNetServerAddServiceActivation(virNetServer *srv,
644                                  virSystemdActivation *act,
645                                  const char *actname,
646                                  int auth,
647                                  virNetTLSContext *tls,
648                                  bool readonly,
649                                  size_t max_queued_clients,
650                                  size_t nrequests_client_max)
651 {
652     int *fds;
653     size_t nfds;
654 
655     if (act == NULL)
656         return 0;
657 
658     virSystemdActivationClaimFDs(act, actname, &fds, &nfds);
659 
660     if (nfds) {
661         virNetServerService *svc;
662 
663         svc = virNetServerServiceNewFDs(fds,
664                                         nfds,
665                                         false,
666                                         auth,
667                                         tls,
668                                         readonly,
669                                         max_queued_clients,
670                                         nrequests_client_max);
671         if (!svc)
672             return -1;
673 
674         if (virNetServerAddService(srv, svc) < 0) {
675             virObjectUnref(svc);
676             return -1;
677         }
678     }
679 
680     /* Intentionally return 1 any time activation is present,
681      * even if we didn't find any sockets with the matching
682      * name. The user needs to be free to disable some of the
683      * services via unit files without causing us to fallback
684      * to creating the service manually.
685      */
686     return 1;
687 }
688 
689 
virNetServerAddServiceTCP(virNetServer * srv,virSystemdActivation * act,const char * actname,const char * nodename,const char * service,int family,int auth,virNetTLSContext * tls,bool readonly,size_t max_queued_clients,size_t nrequests_client_max)690 int virNetServerAddServiceTCP(virNetServer *srv,
691                               virSystemdActivation *act,
692                               const char *actname,
693                               const char *nodename,
694                               const char *service,
695                               int family,
696                               int auth,
697                               virNetTLSContext *tls,
698                               bool readonly,
699                               size_t max_queued_clients,
700                               size_t nrequests_client_max)
701 {
702     virNetServerService *svc = NULL;
703     int ret;
704 
705     ret = virNetServerAddServiceActivation(srv, act, actname,
706                                            auth,
707                                            tls,
708                                            readonly,
709                                            max_queued_clients,
710                                            nrequests_client_max);
711     if (ret < 0)
712         return -1;
713 
714     if (ret == 1)
715         return 0;
716 
717     if (!(svc = virNetServerServiceNewTCP(nodename,
718                                           service,
719                                           family,
720                                           auth,
721                                           tls,
722                                           readonly,
723                                           max_queued_clients,
724                                           nrequests_client_max)))
725         return -1;
726 
727     if (virNetServerAddService(srv, svc) < 0) {
728         virObjectUnref(svc);
729         return -1;
730     }
731 
732     virObjectUnref(svc);
733 
734     return 0;
735 }
736 
737 
virNetServerAddServiceUNIX(virNetServer * srv,virSystemdActivation * act,const char * actname,const char * path,mode_t mask,gid_t grp,int auth,virNetTLSContext * tls,bool readonly,size_t max_queued_clients,size_t nrequests_client_max)738 int virNetServerAddServiceUNIX(virNetServer *srv,
739                                virSystemdActivation *act,
740                                const char *actname,
741                                const char *path,
742                                mode_t mask,
743                                gid_t grp,
744                                int auth,
745                                virNetTLSContext *tls,
746                                bool readonly,
747                                size_t max_queued_clients,
748                                size_t nrequests_client_max)
749 {
750     virNetServerService *svc = NULL;
751     int ret;
752 
753     ret = virNetServerAddServiceActivation(srv, act, actname,
754                                            auth,
755                                            tls,
756                                            readonly,
757                                            max_queued_clients,
758                                            nrequests_client_max);
759     if (ret < 0)
760         return -1;
761 
762     if (ret == 1)
763         return 0;
764 
765     if (!(svc = virNetServerServiceNewUNIX(path,
766                                            mask,
767                                            grp,
768                                            auth,
769                                            tls,
770                                            readonly,
771                                            max_queued_clients,
772                                            nrequests_client_max)))
773         return -1;
774 
775     if (virNetServerAddService(srv, svc) < 0) {
776         virObjectUnref(svc);
777         return -1;
778     }
779 
780     virObjectUnref(svc);
781 
782     return 0;
783 }
784 
785 
virNetServerAddProgram(virNetServer * srv,virNetServerProgram * prog)786 int virNetServerAddProgram(virNetServer *srv,
787                            virNetServerProgram *prog)
788 {
789     virObjectLock(srv);
790 
791     VIR_EXPAND_N(srv->programs, srv->nprograms, 1);
792     srv->programs[srv->nprograms-1] = virObjectRef(prog);
793 
794     virObjectUnlock(srv);
795     return 0;
796 }
797 
virNetServerSetTLSContext(virNetServer * srv,virNetTLSContext * tls)798 int virNetServerSetTLSContext(virNetServer *srv,
799                               virNetTLSContext *tls)
800 {
801     srv->tls = virObjectRef(tls);
802     return 0;
803 }
804 
805 
806 /**
807  * virNetServerSetClientAuthCompletedLocked:
808  * @srv: server must be locked by the caller
809  * @client: client must be locked by the caller
810  *
811  * If the client authentication was pending, clear that pending and
812  * update the server tracking.
813  */
814 static void
virNetServerSetClientAuthCompletedLocked(virNetServer * srv,virNetServerClient * client)815 virNetServerSetClientAuthCompletedLocked(virNetServer *srv,
816                                          virNetServerClient *client)
817 {
818     if (virNetServerClientIsAuthPendingLocked(client)) {
819         virNetServerClientSetAuthPendingLocked(client, false);
820         virNetServerTrackCompletedAuthLocked(srv);
821     }
822 }
823 
824 
825 /**
826  * virNetServerSetClientAuthenticated:
827  * @srv: server must be unlocked
828  * @client: client must be unlocked
829  *
830  * Mark @client as authenticated and tracks on @srv that the
831  * authentication of this @client has been completed. Also it checks
832  * the limits of @srv.
833  */
834 void
virNetServerSetClientAuthenticated(virNetServer * srv,virNetServerClient * client)835 virNetServerSetClientAuthenticated(virNetServer *srv,
836                                    virNetServerClient *client)
837 {
838     virObjectLock(srv);
839     virObjectLock(client);
840     virNetServerClientSetAuthLocked(client, VIR_NET_SERVER_SERVICE_AUTH_NONE);
841     virNetServerSetClientAuthCompletedLocked(srv, client);
842     virNetServerCheckLimits(srv);
843     virObjectUnlock(client);
844     virObjectUnlock(srv);
845 }
846 
847 
848 static void
virNetServerUpdateServicesLocked(virNetServer * srv,bool enabled)849 virNetServerUpdateServicesLocked(virNetServer *srv,
850                                  bool enabled)
851 {
852     size_t i;
853 
854     for (i = 0; i < srv->nservices; i++)
855         virNetServerServiceToggle(srv->services[i], enabled);
856 }
857 
858 
virNetServerUpdateServices(virNetServer * srv,bool enabled)859 void virNetServerUpdateServices(virNetServer *srv,
860                                 bool enabled)
861 {
862     virObjectLock(srv);
863     virNetServerUpdateServicesLocked(srv, enabled);
864     virObjectUnlock(srv);
865 }
866 
virNetServerDispose(void * obj)867 void virNetServerDispose(void *obj)
868 {
869     virNetServer *srv = obj;
870     size_t i;
871 
872     g_free(srv->name);
873 
874     virThreadPoolFree(srv->workers);
875 
876     for (i = 0; i < srv->nservices; i++)
877         virObjectUnref(srv->services[i]);
878     g_free(srv->services);
879 
880     for (i = 0; i < srv->nprograms; i++)
881         virObjectUnref(srv->programs[i]);
882     g_free(srv->programs);
883 
884     for (i = 0; i < srv->nclients; i++)
885         virObjectUnref(srv->clients[i]);
886     g_free(srv->clients);
887 }
888 
virNetServerClose(virNetServer * srv)889 void virNetServerClose(virNetServer *srv)
890 {
891     size_t i;
892 
893     if (!srv)
894         return;
895 
896     virObjectLock(srv);
897 
898     for (i = 0; i < srv->nservices; i++)
899         virNetServerServiceClose(srv->services[i]);
900 
901     for (i = 0; i < srv->nclients; i++)
902         virNetServerClientClose(srv->clients[i]);
903 
904     virThreadPoolStop(srv->workers);
905 
906     virObjectUnlock(srv);
907 }
908 
909 void
virNetServerShutdownWait(virNetServer * srv)910 virNetServerShutdownWait(virNetServer *srv)
911 {
912     virThreadPoolDrain(srv->workers);
913 }
914 
915 static inline size_t
virNetServerTrackPendingAuthLocked(virNetServer * srv)916 virNetServerTrackPendingAuthLocked(virNetServer *srv)
917 {
918     return ++srv->nclients_unauth;
919 }
920 
921 static inline size_t
virNetServerTrackCompletedAuthLocked(virNetServer * srv)922 virNetServerTrackCompletedAuthLocked(virNetServer *srv)
923 {
924     return --srv->nclients_unauth;
925 }
926 
927 
928 bool
virNetServerHasClients(virNetServer * srv)929 virNetServerHasClients(virNetServer *srv)
930 {
931     bool ret;
932 
933     virObjectLock(srv);
934     ret = !!srv->nclients;
935     virObjectUnlock(srv);
936 
937     return ret;
938 }
939 
/**
 * virNetServerProcessClients:
 * @srv: server, must be unlocked
 *
 * Walk the client list of @srv. Any client that wants to be closed is
 * closed. Every closed client is then removed from the list, its pending
 * authentication (if any) is accounted as completed, the server limits
 * are re-checked, and the reference the list held on it is released.
 */
void
virNetServerProcessClients(virNetServer *srv)
{
    size_t i;
    virNetServerClient *client;

    virObjectLock(srv);

 reprocess:
    for (i = 0; i < srv->nclients; i++) {
        client = srv->clients[i];
        virObjectLock(client);
        if (virNetServerClientWantCloseLocked(client))
            virNetServerClientCloseLocked(client);

        if (virNetServerClientIsClosedLocked(client)) {
            /* The reference the array held is now owned by @client. */
            VIR_DELETE_ELEMENT(srv->clients, i, srv->nclients);

            /* Update server authentication tracking */
            virNetServerSetClientAuthCompletedLocked(srv, client);
            virObjectUnlock(client);

            virNetServerCheckLimits(srv);

            /* Drop the array's reference with the server unlocked.
             * NOTE(review): presumably because releasing the last reference
             * may run client teardown that must not happen under srv's
             * lock — confirm. */
            virObjectUnlock(srv);
            virObjectUnref(client);
            virObjectLock(srv);

            /* The list may have changed while srv was unlocked, so
             * restart the scan from the beginning. */
            goto reprocess;
        } else {
            virObjectUnlock(client);
        }
    }

    virObjectUnlock(srv);
}
976 
977 const char *
virNetServerGetName(virNetServer * srv)978 virNetServerGetName(virNetServer *srv)
979 {
980     return srv->name;
981 }
982 
983 int
virNetServerGetThreadPoolParameters(virNetServer * srv,size_t * minWorkers,size_t * maxWorkers,size_t * nWorkers,size_t * freeWorkers,size_t * nPrioWorkers,size_t * jobQueueDepth)984 virNetServerGetThreadPoolParameters(virNetServer *srv,
985                                     size_t *minWorkers,
986                                     size_t *maxWorkers,
987                                     size_t *nWorkers,
988                                     size_t *freeWorkers,
989                                     size_t *nPrioWorkers,
990                                     size_t *jobQueueDepth)
991 {
992     virObjectLock(srv);
993 
994     *minWorkers = virThreadPoolGetMinWorkers(srv->workers);
995     *maxWorkers = virThreadPoolGetMaxWorkers(srv->workers);
996     *freeWorkers = virThreadPoolGetFreeWorkers(srv->workers);
997     *nWorkers = virThreadPoolGetCurrentWorkers(srv->workers);
998     *nPrioWorkers = virThreadPoolGetPriorityWorkers(srv->workers);
999     *jobQueueDepth = virThreadPoolGetJobQueueDepth(srv->workers);
1000 
1001     virObjectUnlock(srv);
1002     return 0;
1003 }
1004 
1005 int
virNetServerSetThreadPoolParameters(virNetServer * srv,long long int minWorkers,long long int maxWorkers,long long int prioWorkers)1006 virNetServerSetThreadPoolParameters(virNetServer *srv,
1007                                     long long int minWorkers,
1008                                     long long int maxWorkers,
1009                                     long long int prioWorkers)
1010 {
1011     int ret;
1012 
1013     virObjectLock(srv);
1014     ret = virThreadPoolSetParameters(srv->workers, minWorkers,
1015                                      maxWorkers, prioWorkers);
1016     virObjectUnlock(srv);
1017 
1018     return ret;
1019 }
1020 
1021 size_t
virNetServerGetMaxClients(virNetServer * srv)1022 virNetServerGetMaxClients(virNetServer *srv)
1023 {
1024     size_t ret;
1025 
1026     virObjectLock(srv);
1027     ret = srv->nclients_max;
1028     virObjectUnlock(srv);
1029 
1030     return ret;
1031 }
1032 
1033 size_t
virNetServerGetCurrentClients(virNetServer * srv)1034 virNetServerGetCurrentClients(virNetServer *srv)
1035 {
1036     size_t ret;
1037 
1038     virObjectLock(srv);
1039     ret = srv->nclients;
1040     virObjectUnlock(srv);
1041 
1042     return ret;
1043 }
1044 
1045 size_t
virNetServerGetMaxUnauthClients(virNetServer * srv)1046 virNetServerGetMaxUnauthClients(virNetServer *srv)
1047 {
1048     size_t ret;
1049 
1050     virObjectLock(srv);
1051     ret = srv->nclients_unauth_max;
1052     virObjectUnlock(srv);
1053 
1054     return ret;
1055 }
1056 
1057 size_t
virNetServerGetCurrentUnauthClients(virNetServer * srv)1058 virNetServerGetCurrentUnauthClients(virNetServer *srv)
1059 {
1060     size_t ret;
1061 
1062     virObjectLock(srv);
1063     ret = srv->nclients_unauth;
1064     virObjectUnlock(srv);
1065 
1066     return ret;
1067 }
1068 
1069 
virNetServerNeedsAuth(virNetServer * srv,int auth)1070 bool virNetServerNeedsAuth(virNetServer *srv,
1071                            int auth)
1072 {
1073     bool ret = false;
1074     size_t i;
1075 
1076     virObjectLock(srv);
1077     for (i = 0; i < srv->nservices; i++) {
1078         if (virNetServerServiceGetAuth(srv->services[i]) == auth)
1079             ret = true;
1080     }
1081     virObjectUnlock(srv);
1082 
1083     return ret;
1084 }
1085 
1086 int
virNetServerGetClients(virNetServer * srv,virNetServerClient *** clts)1087 virNetServerGetClients(virNetServer *srv,
1088                        virNetServerClient ***clts)
1089 {
1090     size_t i;
1091     size_t nclients = 0;
1092     virNetServerClient **list = NULL;
1093 
1094     virObjectLock(srv);
1095 
1096     for (i = 0; i < srv->nclients; i++) {
1097         virNetServerClient *client = virObjectRef(srv->clients[i]);
1098         VIR_APPEND_ELEMENT(list, nclients, client);
1099     }
1100 
1101     *clts = g_steal_pointer(&list);
1102 
1103     virObjectUnlock(srv);
1104 
1105     return nclients;
1106 }
1107 
1108 virNetServerClient *
virNetServerGetClient(virNetServer * srv,unsigned long long id)1109 virNetServerGetClient(virNetServer *srv,
1110                       unsigned long long id)
1111 {
1112     size_t i;
1113     virNetServerClient *ret = NULL;
1114 
1115     virObjectLock(srv);
1116 
1117     for (i = 0; i < srv->nclients; i++) {
1118         virNetServerClient *client = srv->clients[i];
1119         if (virNetServerClientGetID(client) == id)
1120             ret = virObjectRef(client);
1121     }
1122 
1123     virObjectUnlock(srv);
1124 
1125     if (!ret)
1126         virReportError(VIR_ERR_NO_CLIENT,
1127                        _("No client with matching ID '%llu'"), id);
1128     return ret;
1129 }
1130 
1131 int
virNetServerSetClientLimits(virNetServer * srv,long long int maxClients,long long int maxClientsUnauth)1132 virNetServerSetClientLimits(virNetServer *srv,
1133                             long long int maxClients,
1134                             long long int maxClientsUnauth)
1135 {
1136     int ret = -1;
1137     size_t max, max_unauth;
1138 
1139     virObjectLock(srv);
1140 
1141     max = maxClients >= 0 ? maxClients : srv->nclients_max;
1142     max_unauth = maxClientsUnauth >= 0 ?
1143         maxClientsUnauth : srv->nclients_unauth_max;
1144 
1145     if (max < max_unauth) {
1146         virReportError(VIR_ERR_INVALID_ARG, "%s",
1147                        _("The overall maximum number of clients must be "
1148                          "greater than the maximum number of clients waiting "
1149                          "for authentication"));
1150         goto cleanup;
1151     }
1152 
1153     if (maxClients >= 0)
1154         srv->nclients_max = maxClients;
1155 
1156     if (maxClientsUnauth >= 0)
1157         srv->nclients_unauth_max = maxClientsUnauth;
1158 
1159     virNetServerCheckLimits(srv);
1160 
1161     ret = 0;
1162  cleanup:
1163     virObjectUnlock(srv);
1164     return ret;
1165 }
1166 
1167 static virNetTLSContext *
virNetServerGetTLSContext(virNetServer * srv)1168 virNetServerGetTLSContext(virNetServer *srv)
1169 {
1170     size_t i;
1171     virNetTLSContext *ctxt = NULL;
1172     virNetServerService *svc = NULL;
1173 
1174     /* find svcTLS from srv, get svcTLS->tls */
1175     for (i = 0; i < srv->nservices; i++) {
1176         svc = srv->services[i];
1177         ctxt = virNetServerServiceGetTLSContext(svc);
1178         if (ctxt != NULL)
1179             break;
1180     }
1181 
1182     return ctxt;
1183 }
1184 
1185 int
virNetServerUpdateTlsFiles(virNetServer * srv)1186 virNetServerUpdateTlsFiles(virNetServer *srv)
1187 {
1188     int ret = -1;
1189     virNetTLSContext *ctxt = NULL;
1190     bool privileged = geteuid() == 0;
1191 
1192     ctxt = virNetServerGetTLSContext(srv);
1193     if (!ctxt) {
1194         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
1195                        _("no tls service found, unable to update tls files"));
1196         return -1;
1197     }
1198 
1199     virObjectLock(srv);
1200     virObjectLock(ctxt);
1201 
1202     if (virNetTLSContextReloadForServer(ctxt, !privileged)) {
1203         VIR_DEBUG("failed to reload server's tls context");
1204         goto cleanup;
1205     }
1206 
1207     VIR_DEBUG("update tls files success");
1208     ret = 0;
1209 
1210  cleanup:
1211     virObjectUnlock(ctxt);
1212     virObjectUnlock(srv);
1213     return ret;
1214 }
1215