1 /*
2  * virnetserverprogram.c: generic network RPC server program
3  *
4  * Copyright (C) 2006-2012 Red Hat, Inc.
5  * Copyright (C) 2006 Daniel P. Berrange
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library.  If not, see
19  * <http://www.gnu.org/licenses/>.
20  */
21 
22 #include <config.h>
23 
24 #include "virnetserverprogram.h"
25 #include "virnetserverclient.h"
26 
27 #include "viralloc.h"
28 #include "virerror.h"
29 #include "virlog.h"
30 #include "virfile.h"
31 #include "virthread.h"
32 
33 #define VIR_FROM_THIS VIR_FROM_RPC
34 
35 VIR_LOG_INIT("rpc.netserverprogram");
36 
37 struct _virNetServerProgram {
38     virObject parent;
39 
40     unsigned program;
41     unsigned version;
42     virNetServerProgramProc *procs;
43     size_t nprocs;
44 };
45 
46 
47 static virClass *virNetServerProgramClass;
48 static void virNetServerProgramDispose(void *obj);
49 
virNetServerProgramOnceInit(void)50 static int virNetServerProgramOnceInit(void)
51 {
52     if (!VIR_CLASS_NEW(virNetServerProgram, virClassForObject()))
53         return -1;
54 
55     return 0;
56 }
57 
58 VIR_ONCE_GLOBAL_INIT(virNetServerProgram);
59 
60 
virNetServerProgramNew(unsigned program,unsigned version,virNetServerProgramProc * procs,size_t nprocs)61 virNetServerProgram *virNetServerProgramNew(unsigned program,
62                                               unsigned version,
63                                               virNetServerProgramProc *procs,
64                                               size_t nprocs)
65 {
66     virNetServerProgram *prog;
67 
68     if (virNetServerProgramInitialize() < 0)
69         return NULL;
70 
71     if (!(prog = virObjectNew(virNetServerProgramClass)))
72         return NULL;
73 
74     prog->program = program;
75     prog->version = version;
76     prog->procs = procs;
77     prog->nprocs = nprocs;
78 
79     VIR_DEBUG("prog=%p", prog);
80 
81     return prog;
82 }
83 
84 
/* Return the RPC program number @prog was created with. */
int
virNetServerProgramGetID(virNetServerProgram *prog)
{
    return (int)prog->program;
}
89 
90 
/* Return the RPC program version @prog was created with. */
int
virNetServerProgramGetVersion(virNetServerProgram *prog)
{
    return (int)prog->version;
}
95 
96 
/*
 * Check whether @msg is addressed to @prog.
 *
 * Returns 1 if both the program number and the version in the message
 * header match @prog, 0 otherwise.
 */
int
virNetServerProgramMatches(virNetServerProgram *prog,
                           virNetMessage *msg)
{
    return msg->header.prog == prog->program &&
           msg->header.vers == prog->version;
}
105 
106 
virNetServerProgramGetProc(virNetServerProgram * prog,int procedure)107 static virNetServerProgramProc *virNetServerProgramGetProc(virNetServerProgram *prog,
108                                                              int procedure)
109 {
110     virNetServerProgramProc *proc;
111 
112     if (procedure < 0)
113         return NULL;
114     if (procedure >= prog->nprocs)
115         return NULL;
116 
117     proc = &prog->procs[procedure];
118 
119     if (!proc->func)
120         return NULL;
121 
122     return proc;
123 }
124 
125 unsigned int
virNetServerProgramGetPriority(virNetServerProgram * prog,int procedure)126 virNetServerProgramGetPriority(virNetServerProgram *prog,
127                                int procedure)
128 {
129     virNetServerProgramProc *proc = virNetServerProgramGetProc(prog, procedure);
130 
131     if (!proc)
132         return 0;
133 
134     return proc->priority;
135 }
136 
137 static int
virNetServerProgramSendError(unsigned program,unsigned version,virNetServerClient * client,virNetMessage * msg,struct virNetMessageError * rerr,int procedure,int type,unsigned int serial)138 virNetServerProgramSendError(unsigned program,
139                              unsigned version,
140                              virNetServerClient *client,
141                              virNetMessage *msg,
142                              struct virNetMessageError *rerr,
143                              int procedure,
144                              int type,
145                              unsigned int serial)
146 {
147     VIR_DEBUG("prog=%d ver=%d proc=%d type=%d serial=%u msg=%p rerr=%p",
148               program, version, procedure, type, serial, msg, rerr);
149 
150     virNetMessageSaveError(rerr);
151 
152     /* Return header. */
153     msg->header.prog = program;
154     msg->header.vers = version;
155     msg->header.proc = procedure;
156     msg->header.type = type;
157     msg->header.serial = serial;
158     msg->header.status = VIR_NET_ERROR;
159 
160     if (virNetMessageEncodeHeader(msg) < 0)
161         goto error;
162 
163     if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, rerr) < 0)
164         goto error;
165     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr);
166 
167     /* Put reply on end of tx queue to send out  */
168     if (virNetServerClientSendMessage(client, msg) < 0)
169         return -1;
170 
171     return 0;
172 
173  error:
174     VIR_WARN("Failed to serialize remote error '%p'", rerr);
175     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr);
176     return -1;
177 }
178 
179 
180 /*
181  * @client: the client to send the error to
182  * @req: the message this error is in reply to
183  *
184  * Send an error message to the client
185  *
186  * Returns 0 if the error was sent, -1 upon fatal error
187  */
188 int
virNetServerProgramSendReplyError(virNetServerProgram * prog,virNetServerClient * client,virNetMessage * msg,struct virNetMessageError * rerr,struct virNetMessageHeader * req)189 virNetServerProgramSendReplyError(virNetServerProgram *prog,
190                                   virNetServerClient *client,
191                                   virNetMessage *msg,
192                                   struct virNetMessageError *rerr,
193                                   struct virNetMessageHeader *req)
194 {
195     /*
196      * For data streams, errors are sent back as data streams
197      * For method calls, errors are sent back as method replies
198      */
199     return virNetServerProgramSendError(prog->program,
200                                         prog->version,
201                                         client,
202                                         msg,
203                                         rerr,
204                                         req->proc,
205                                         req->type == VIR_NET_STREAM ? VIR_NET_STREAM : VIR_NET_REPLY,
206                                         req->serial);
207 }
208 
209 
/*
 * Send @rerr to @client as a VIR_NET_STREAM error packet for the stream
 * identified by @procedure and @serial, reusing @msg. The XDR contents of
 * @rerr are released.
 *
 * Returns 0 if the error was sent, -1 on failure.
 */
int
virNetServerProgramSendStreamError(virNetServerProgram *prog,
                                   virNetServerClient *client,
                                   virNetMessage *msg,
                                   struct virNetMessageError *rerr,
                                   int procedure,
                                   unsigned int serial)
{
    return virNetServerProgramSendError(prog->program, prog->version,
                                        client, msg, rerr,
                                        procedure, VIR_NET_STREAM, serial);
}
226 
227 
/*
 * Report that no registered program matches the program/version in @req
 * and send that error back to @client as a VIR_NET_REPLY.
 *
 * The reply echoes the request's own program, version, procedure and
 * serial, since there is no program object to take them from.
 *
 * Returns 0 if the error was sent, -1 on failure.
 */
int
virNetServerProgramUnknownError(virNetServerClient *client,
                                virNetMessage *msg,
                                struct virNetMessageHeader *req)
{
    virNetMessageError rerr;

    memset(&rerr, 0, sizeof(rerr));

    virReportError(VIR_ERR_RPC,
                   _("Cannot find program %d version %d"), req->prog, req->vers);

    return virNetServerProgramSendError(req->prog, req->vers,
                                        client, msg, &rerr,
                                        req->proc, VIR_NET_REPLY, req->serial);
}
247 
248 
249 static int
250 virNetServerProgramDispatchCall(virNetServerProgram *prog,
251                                 virNetServer *server,
252                                 virNetServerClient *client,
253                                 virNetMessage *msg);
254 
255 /*
256  * @server: the unlocked server object
257  * @client: the unlocked client object
258  * @msg: the complete incoming message packet, with header already decoded
259  *
260  * This function is intended to be called from worker threads
261  * when an incoming message is ready to be dispatched for
262  * execution.
263  *
264  * Upon successful return the '@msg' instance will be released
265  * by this function (or more often, reused to send a reply).
266  * Upon failure, the '@msg' must be freed by the caller.
267  *
268  * Returns 0 if the message was dispatched, -1 upon fatal error
269  */
virNetServerProgramDispatch(virNetServerProgram * prog,virNetServer * server,virNetServerClient * client,virNetMessage * msg)270 int virNetServerProgramDispatch(virNetServerProgram *prog,
271                                 virNetServer *server,
272                                 virNetServerClient *client,
273                                 virNetMessage *msg)
274 {
275     int ret = -1;
276     virNetMessageError rerr;
277 
278     memset(&rerr, 0, sizeof(rerr));
279 
280     VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%u proc=%d",
281               msg->header.prog, msg->header.vers, msg->header.type,
282               msg->header.status, msg->header.serial, msg->header.proc);
283 
284     /* Check version, etc. */
285     if (msg->header.prog != prog->program) {
286         virReportError(VIR_ERR_RPC,
287                        _("program mismatch (actual %x, expected %x)"),
288                        msg->header.prog, prog->program);
289         goto error;
290     }
291 
292     if (msg->header.vers != prog->version) {
293         virReportError(VIR_ERR_RPC,
294                        _("version mismatch (actual %x, expected %x)"),
295                        msg->header.vers, prog->version);
296         goto error;
297     }
298 
299     switch (msg->header.type) {
300     case VIR_NET_CALL:
301     case VIR_NET_CALL_WITH_FDS:
302         ret = virNetServerProgramDispatchCall(prog, server, client, msg);
303         break;
304 
305     case VIR_NET_STREAM:
306         /* Since stream data is non-acked, async, we may continue to receive
307          * stream packets after we closed down a stream. Just drop & ignore
308          * these.
309          */
310         VIR_INFO("Ignoring unexpected stream data serial=%u proc=%d status=%d",
311                  msg->header.serial, msg->header.proc, msg->header.status);
312         /* Send a dummy reply to free up 'msg' & unblock client rx */
313         virNetMessageClear(msg);
314         msg->header.type = VIR_NET_REPLY;
315         if (virNetServerClientSendMessage(client, msg) < 0)
316             return -1;
317         ret = 0;
318         break;
319 
320     case VIR_NET_REPLY:
321     case VIR_NET_REPLY_WITH_FDS:
322     case VIR_NET_MESSAGE:
323     case VIR_NET_STREAM_HOLE:
324     default:
325         virReportError(VIR_ERR_RPC,
326                        _("Unexpected message type %u"),
327                        msg->header.type);
328         goto error;
329     }
330 
331     return ret;
332 
333  error:
334     if (msg->header.type == VIR_NET_CALL ||
335         msg->header.type == VIR_NET_CALL_WITH_FDS) {
336         ret = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header);
337     } else {
338         /* Send a dummy reply to free up 'msg' & unblock client rx */
339         virNetMessageClear(msg);
340         msg->header.type = VIR_NET_REPLY;
341         if (virNetServerClientSendMessage(client, msg) < 0)
342             return -1;
343         ret = 0;
344     }
345 
346     return ret;
347 }
348 
349 
350 /*
351  * @server: the unlocked server object
352  * @client: the unlocked client object
353  * @msg: the complete incoming method call, with header already decoded
354  *
355  * This method is used to dispatch a message representing an
356  * incoming method call from a client. It decodes the payload
357  * to obtain method call arguments, invokes the method and
358  * then sends a reply packet with the return values
359  *
360  * Returns 0 if the reply was sent, or -1 upon fatal error
361  */
362 static int
virNetServerProgramDispatchCall(virNetServerProgram * prog,virNetServer * server,virNetServerClient * client,virNetMessage * msg)363 virNetServerProgramDispatchCall(virNetServerProgram *prog,
364                                 virNetServer *server,
365                                 virNetServerClient *client,
366                                 virNetMessage *msg)
367 {
368     g_autofree char *arg = NULL;
369     g_autofree char *ret = NULL;
370     int rv = -1;
371     virNetServerProgramProc *dispatcher;
372     virNetMessageError rerr;
373     size_t i;
374     g_autoptr(virIdentity) identity = NULL;
375 
376     memset(&rerr, 0, sizeof(rerr));
377 
378     if (msg->header.status != VIR_NET_OK) {
379         virReportError(VIR_ERR_RPC,
380                        _("Unexpected message status %u"),
381                        msg->header.status);
382         goto error;
383     }
384 
385     dispatcher = virNetServerProgramGetProc(prog, msg->header.proc);
386 
387     if (!dispatcher) {
388         virReportError(VIR_ERR_RPC,
389                        _("unknown procedure: %d"),
390                        msg->header.proc);
391         goto error;
392     }
393 
394     /* If the client is not authenticated, don't allow any RPC ops
395      * which are except for authentication ones */
396     if (dispatcher->needAuth &&
397         !virNetServerClientIsAuthenticated(client)) {
398         /* Explicitly *NOT* calling  remoteDispatchAuthError() because
399            we want back-compatibility with libvirt clients which don't
400            support the VIR_ERR_AUTH_FAILED error code */
401         virReportError(VIR_ERR_RPC,
402                        "%s", _("authentication required"));
403         goto error;
404     }
405 
406     arg = g_new0(char, dispatcher->arg_len);
407     ret = g_new0(char, dispatcher->ret_len);
408 
409     if (virNetMessageDecodePayload(msg, dispatcher->arg_filter, arg) < 0)
410         goto error;
411 
412     if (!(identity = virNetServerClientGetIdentity(client)))
413         goto error;
414 
415     if (virIdentitySetCurrent(identity) < 0)
416         goto error;
417 
418     /*
419      * When the RPC handler is called:
420      *
421      *  - Server object is unlocked
422      *  - Client object is unlocked
423      *
424      * Without locking, it is safe to use:
425      *
426      *   'args and 'ret'
427      */
428     rv = (dispatcher->func)(server, client, msg, &rerr, arg, ret);
429 
430     if (virIdentitySetCurrent(NULL) < 0)
431         goto error;
432 
433     /*
434      * If rv == 1, this indicates the dispatch func has
435      * populated 'msg' with a list of FDs to return to
436      * the caller.
437      *
438      * Otherwise we must clear out the FDs we got from
439      * the client originally.
440      *
441      */
442     if (rv != 1) {
443         for (i = 0; i < msg->nfds; i++)
444             VIR_FORCE_CLOSE(msg->fds[i]);
445         VIR_FREE(msg->fds);
446         msg->nfds = 0;
447     }
448 
449     xdr_free(dispatcher->arg_filter, arg);
450 
451     if (rv < 0)
452         goto error;
453 
454     /* Return header. We're re-using same message object, so
455      * only need to tweak type/status fields */
456     /*msg->header.prog = msg->header.prog;*/
457     /*msg->header.vers = msg->header.vers;*/
458     /*msg->header.proc = msg->header.proc;*/
459     msg->header.type = msg->nfds ? VIR_NET_REPLY_WITH_FDS : VIR_NET_REPLY;
460     /*msg->header.serial = msg->header.serial;*/
461     msg->header.status = VIR_NET_OK;
462 
463     if (virNetMessageEncodeHeader(msg) < 0) {
464         xdr_free(dispatcher->ret_filter, ret);
465         goto error;
466     }
467 
468     if (msg->nfds &&
469         virNetMessageEncodeNumFDs(msg) < 0) {
470         xdr_free(dispatcher->ret_filter, ret);
471         goto error;
472     }
473 
474     if (virNetMessageEncodePayload(msg, dispatcher->ret_filter, ret) < 0) {
475         xdr_free(dispatcher->ret_filter, ret);
476         goto error;
477     }
478 
479     xdr_free(dispatcher->ret_filter, ret);
480 
481     /* Put reply on end of tx queue to send out  */
482     return virNetServerClientSendMessage(client, msg);
483 
484  error:
485     /* Bad stuff (de-)serializing message, but we have an
486      * RPC error message we can send back to the client */
487     rv = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header);
488 
489     return rv;
490 }
491 
492 
/*
 * Send a VIR_NET_STREAM packet for the stream identified by @procedure and
 * @serial, reusing @msg. Every header field is (re)written here.
 *
 * The packet status depends on @data and @len:
 *   data != NULL, len > 0   => VIR_NET_CONTINUE, raw payload (sending data)
 *   data != NULL, len == 0  => VIR_NET_CONTINUE, empty payload (read EOF)
 *   data == NULL            => VIR_NET_OK, empty payload (finish handshake
 *                              confirmation)
 *
 * Returns the result of handing the packet to the client, or -1 if
 * encoding fails.
 */
int
virNetServerProgramSendStreamData(virNetServerProgram *prog,
                                  virNetServerClient *client,
                                  virNetMessage *msg,
                                  int procedure,
                                  unsigned int serial,
                                  const char *data,
                                  size_t len)
{
    int rc;

    VIR_DEBUG("client=%p msg=%p data=%p len=%zu", client, msg, data, len);

    msg->header.prog = prog->program;
    msg->header.vers = prog->version;
    msg->header.proc = procedure;
    msg->header.type = VIR_NET_STREAM;
    msg->header.serial = serial;
    msg->header.status = data == NULL ? VIR_NET_OK : VIR_NET_CONTINUE;

    if (virNetMessageEncodeHeader(msg) < 0)
        return -1;

    if (data != NULL && len > 0)
        rc = virNetMessageEncodePayloadRaw(msg, data, len);
    else
        rc = virNetMessageEncodePayloadEmpty(msg);

    if (rc < 0)
        return -1;

    VIR_DEBUG("Total %zu", msg->bufferLength);

    return virNetServerClientSendMessage(client, msg);
}
533 
534 
/*
 * Send a VIR_NET_STREAM_HOLE packet (status VIR_NET_CONTINUE) for the
 * stream identified by @procedure and @serial, reusing @msg. The payload
 * is an XDR-encoded virNetStreamHole carrying @length and @flags.
 *
 * Returns the result of handing the packet to the client, or -1 if
 * encoding fails.
 */
int
virNetServerProgramSendStreamHole(virNetServerProgram *prog,
                                  virNetServerClient *client,
                                  virNetMessage *msg,
                                  int procedure,
                                  unsigned int serial,
                                  long long length,
                                  unsigned int flags)
{
    virNetStreamHole hole;

    VIR_DEBUG("client=%p msg=%p length=%lld", client, msg, length);

    memset(&hole, 0, sizeof(hole));
    hole.flags = flags;
    hole.length = length;

    msg->header.prog = prog->program;
    msg->header.vers = prog->version;
    msg->header.proc = procedure;
    msg->header.type = VIR_NET_STREAM_HOLE;
    msg->header.serial = serial;
    msg->header.status = VIR_NET_CONTINUE;

    if (virNetMessageEncodeHeader(msg) < 0 ||
        virNetMessageEncodePayload(msg,
                                   (xdrproc_t)xdr_virNetStreamHole,
                                   &hole) < 0)
        return -1;

    return virNetServerClientSendMessage(client, msg);
}
568 
569 
void
virNetServerProgramDispose(void *obj G_GNUC_UNUSED)
{
    /* Nothing to release: the procedure table handed to
     * virNetServerProgramNew() is only referenced, never copied. */
}
573