1 /*
2 * virnetserverprogram.c: generic network RPC server program
3 *
4 * Copyright (C) 2006-2012 Red Hat, Inc.
5 * Copyright (C) 2006 Daniel P. Berrange
6 *
7 * This library is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU Lesser General Public
9 * License as published by the Free Software Foundation; either
10 * version 2.1 of the License, or (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public
18 * License along with this library. If not, see
19 * <http://www.gnu.org/licenses/>.
20 */
21
22 #include <config.h>
23
24 #include "virnetserverprogram.h"
25 #include "virnetserverclient.h"
26
27 #include "viralloc.h"
28 #include "virerror.h"
29 #include "virlog.h"
30 #include "virfile.h"
31 #include "virthread.h"
32
33 #define VIR_FROM_THIS VIR_FROM_RPC
34
35 VIR_LOG_INIT("rpc.netserverprogram");
36
37 struct _virNetServerProgram {
38 virObject parent;
39
40 unsigned program;
41 unsigned version;
42 virNetServerProgramProc *procs;
43 size_t nprocs;
44 };
45
46
47 static virClass *virNetServerProgramClass;
48 static void virNetServerProgramDispose(void *obj);
49
virNetServerProgramOnceInit(void)50 static int virNetServerProgramOnceInit(void)
51 {
52 if (!VIR_CLASS_NEW(virNetServerProgram, virClassForObject()))
53 return -1;
54
55 return 0;
56 }
57
58 VIR_ONCE_GLOBAL_INIT(virNetServerProgram);
59
60
virNetServerProgramNew(unsigned program,unsigned version,virNetServerProgramProc * procs,size_t nprocs)61 virNetServerProgram *virNetServerProgramNew(unsigned program,
62 unsigned version,
63 virNetServerProgramProc *procs,
64 size_t nprocs)
65 {
66 virNetServerProgram *prog;
67
68 if (virNetServerProgramInitialize() < 0)
69 return NULL;
70
71 if (!(prog = virObjectNew(virNetServerProgramClass)))
72 return NULL;
73
74 prog->program = program;
75 prog->version = version;
76 prog->procs = procs;
77 prog->nprocs = nprocs;
78
79 VIR_DEBUG("prog=%p", prog);
80
81 return prog;
82 }
83
84
/*
 * Returns the RPC program ID that @prog serves.
 */
int
virNetServerProgramGetID(virNetServerProgram *prog)
{
    int id = prog->program;

    return id;
}
89
90
/*
 * Returns the RPC program version that @prog serves.
 */
int
virNetServerProgramGetVersion(virNetServerProgram *prog)
{
    int vers = prog->version;

    return vers;
}
95
96
/*
 * Check whether @msg is addressed to @prog.
 *
 * Returns 1 if both the program ID and version in @msg's header
 * match @prog, 0 otherwise.
 */
int
virNetServerProgramMatches(virNetServerProgram *prog,
                           virNetMessage *msg)
{
    return msg->header.prog == prog->program &&
           msg->header.vers == prog->version;
}
105
106
virNetServerProgramGetProc(virNetServerProgram * prog,int procedure)107 static virNetServerProgramProc *virNetServerProgramGetProc(virNetServerProgram *prog,
108 int procedure)
109 {
110 virNetServerProgramProc *proc;
111
112 if (procedure < 0)
113 return NULL;
114 if (procedure >= prog->nprocs)
115 return NULL;
116
117 proc = &prog->procs[procedure];
118
119 if (!proc->func)
120 return NULL;
121
122 return proc;
123 }
124
125 unsigned int
virNetServerProgramGetPriority(virNetServerProgram * prog,int procedure)126 virNetServerProgramGetPriority(virNetServerProgram *prog,
127 int procedure)
128 {
129 virNetServerProgramProc *proc = virNetServerProgramGetProc(prog, procedure);
130
131 if (!proc)
132 return 0;
133
134 return proc->priority;
135 }
136
137 static int
virNetServerProgramSendError(unsigned program,unsigned version,virNetServerClient * client,virNetMessage * msg,struct virNetMessageError * rerr,int procedure,int type,unsigned int serial)138 virNetServerProgramSendError(unsigned program,
139 unsigned version,
140 virNetServerClient *client,
141 virNetMessage *msg,
142 struct virNetMessageError *rerr,
143 int procedure,
144 int type,
145 unsigned int serial)
146 {
147 VIR_DEBUG("prog=%d ver=%d proc=%d type=%d serial=%u msg=%p rerr=%p",
148 program, version, procedure, type, serial, msg, rerr);
149
150 virNetMessageSaveError(rerr);
151
152 /* Return header. */
153 msg->header.prog = program;
154 msg->header.vers = version;
155 msg->header.proc = procedure;
156 msg->header.type = type;
157 msg->header.serial = serial;
158 msg->header.status = VIR_NET_ERROR;
159
160 if (virNetMessageEncodeHeader(msg) < 0)
161 goto error;
162
163 if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, rerr) < 0)
164 goto error;
165 xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr);
166
167 /* Put reply on end of tx queue to send out */
168 if (virNetServerClientSendMessage(client, msg) < 0)
169 return -1;
170
171 return 0;
172
173 error:
174 VIR_WARN("Failed to serialize remote error '%p'", rerr);
175 xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)rerr);
176 return -1;
177 }
178
179
180 /*
181 * @client: the client to send the error to
182 * @req: the message this error is in reply to
183 *
184 * Send an error message to the client
185 *
186 * Returns 0 if the error was sent, -1 upon fatal error
187 */
188 int
virNetServerProgramSendReplyError(virNetServerProgram * prog,virNetServerClient * client,virNetMessage * msg,struct virNetMessageError * rerr,struct virNetMessageHeader * req)189 virNetServerProgramSendReplyError(virNetServerProgram *prog,
190 virNetServerClient *client,
191 virNetMessage *msg,
192 struct virNetMessageError *rerr,
193 struct virNetMessageHeader *req)
194 {
195 /*
196 * For data streams, errors are sent back as data streams
197 * For method calls, errors are sent back as method replies
198 */
199 return virNetServerProgramSendError(prog->program,
200 prog->version,
201 client,
202 msg,
203 rerr,
204 req->proc,
205 req->type == VIR_NET_STREAM ? VIR_NET_STREAM : VIR_NET_REPLY,
206 req->serial);
207 }
208
209
/*
 * @prog: the program owning the stream
 * @client: the client to send the error to
 * @msg: message object reused to carry the error packet
 * @rerr: the error to encode; its contents are released by this call
 * @procedure: procedure number of the stream
 * @serial: serial number of the stream
 *
 * Send @rerr to @client as a VIR_NET_STREAM error packet.
 *
 * Returns 0 if the error was sent, -1 upon fatal error
 */
int
virNetServerProgramSendStreamError(virNetServerProgram *prog,
                                   virNetServerClient *client,
                                   virNetMessage *msg,
                                   struct virNetMessageError *rerr,
                                   int procedure,
                                   unsigned int serial)
{
    return virNetServerProgramSendError(prog->program, prog->version,
                                        client, msg, rerr,
                                        procedure, VIR_NET_STREAM, serial);
}
226
227
/*
 * @client: the client to send the error to
 * @msg: message object reused to carry the error packet
 * @req: header of the request naming the missing program
 *
 * Report that no program matches @req's program/version, then send
 * an error reply (VIR_NET_REPLY) that echoes @req's program, version,
 * procedure and serial.
 *
 * Returns 0 if the error was sent, -1 upon fatal error
 */
int
virNetServerProgramUnknownError(virNetServerClient *client,
                                virNetMessage *msg,
                                struct virNetMessageHeader *req)
{
    virNetMessageError err;

    memset(&err, 0, sizeof(err));

    virReportError(VIR_ERR_RPC,
                   _("Cannot find program %d version %d"), req->prog, req->vers);

    return virNetServerProgramSendError(req->prog, req->vers,
                                        client, msg, &err,
                                        req->proc, VIR_NET_REPLY, req->serial);
}
247
248
249 static int
250 virNetServerProgramDispatchCall(virNetServerProgram *prog,
251 virNetServer *server,
252 virNetServerClient *client,
253 virNetMessage *msg);
254
255 /*
256 * @server: the unlocked server object
257 * @client: the unlocked client object
258 * @msg: the complete incoming message packet, with header already decoded
259 *
260 * This function is intended to be called from worker threads
261 * when an incoming message is ready to be dispatched for
262 * execution.
263 *
264 * Upon successful return the '@msg' instance will be released
265 * by this function (or more often, reused to send a reply).
266 * Upon failure, the '@msg' must be freed by the caller.
267 *
268 * Returns 0 if the message was dispatched, -1 upon fatal error
269 */
virNetServerProgramDispatch(virNetServerProgram * prog,virNetServer * server,virNetServerClient * client,virNetMessage * msg)270 int virNetServerProgramDispatch(virNetServerProgram *prog,
271 virNetServer *server,
272 virNetServerClient *client,
273 virNetMessage *msg)
274 {
275 int ret = -1;
276 virNetMessageError rerr;
277
278 memset(&rerr, 0, sizeof(rerr));
279
280 VIR_DEBUG("prog=%d ver=%d type=%d status=%d serial=%u proc=%d",
281 msg->header.prog, msg->header.vers, msg->header.type,
282 msg->header.status, msg->header.serial, msg->header.proc);
283
284 /* Check version, etc. */
285 if (msg->header.prog != prog->program) {
286 virReportError(VIR_ERR_RPC,
287 _("program mismatch (actual %x, expected %x)"),
288 msg->header.prog, prog->program);
289 goto error;
290 }
291
292 if (msg->header.vers != prog->version) {
293 virReportError(VIR_ERR_RPC,
294 _("version mismatch (actual %x, expected %x)"),
295 msg->header.vers, prog->version);
296 goto error;
297 }
298
299 switch (msg->header.type) {
300 case VIR_NET_CALL:
301 case VIR_NET_CALL_WITH_FDS:
302 ret = virNetServerProgramDispatchCall(prog, server, client, msg);
303 break;
304
305 case VIR_NET_STREAM:
306 /* Since stream data is non-acked, async, we may continue to receive
307 * stream packets after we closed down a stream. Just drop & ignore
308 * these.
309 */
310 VIR_INFO("Ignoring unexpected stream data serial=%u proc=%d status=%d",
311 msg->header.serial, msg->header.proc, msg->header.status);
312 /* Send a dummy reply to free up 'msg' & unblock client rx */
313 virNetMessageClear(msg);
314 msg->header.type = VIR_NET_REPLY;
315 if (virNetServerClientSendMessage(client, msg) < 0)
316 return -1;
317 ret = 0;
318 break;
319
320 case VIR_NET_REPLY:
321 case VIR_NET_REPLY_WITH_FDS:
322 case VIR_NET_MESSAGE:
323 case VIR_NET_STREAM_HOLE:
324 default:
325 virReportError(VIR_ERR_RPC,
326 _("Unexpected message type %u"),
327 msg->header.type);
328 goto error;
329 }
330
331 return ret;
332
333 error:
334 if (msg->header.type == VIR_NET_CALL ||
335 msg->header.type == VIR_NET_CALL_WITH_FDS) {
336 ret = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header);
337 } else {
338 /* Send a dummy reply to free up 'msg' & unblock client rx */
339 virNetMessageClear(msg);
340 msg->header.type = VIR_NET_REPLY;
341 if (virNetServerClientSendMessage(client, msg) < 0)
342 return -1;
343 ret = 0;
344 }
345
346 return ret;
347 }
348
349
350 /*
351 * @server: the unlocked server object
352 * @client: the unlocked client object
353 * @msg: the complete incoming method call, with header already decoded
354 *
355 * This method is used to dispatch a message representing an
356 * incoming method call from a client. It decodes the payload
357 * to obtain method call arguments, invokes the method and
358 * then sends a reply packet with the return values
359 *
360 * Returns 0 if the reply was sent, or -1 upon fatal error
361 */
362 static int
virNetServerProgramDispatchCall(virNetServerProgram * prog,virNetServer * server,virNetServerClient * client,virNetMessage * msg)363 virNetServerProgramDispatchCall(virNetServerProgram *prog,
364 virNetServer *server,
365 virNetServerClient *client,
366 virNetMessage *msg)
367 {
368 g_autofree char *arg = NULL;
369 g_autofree char *ret = NULL;
370 int rv = -1;
371 virNetServerProgramProc *dispatcher;
372 virNetMessageError rerr;
373 size_t i;
374 g_autoptr(virIdentity) identity = NULL;
375
376 memset(&rerr, 0, sizeof(rerr));
377
378 if (msg->header.status != VIR_NET_OK) {
379 virReportError(VIR_ERR_RPC,
380 _("Unexpected message status %u"),
381 msg->header.status);
382 goto error;
383 }
384
385 dispatcher = virNetServerProgramGetProc(prog, msg->header.proc);
386
387 if (!dispatcher) {
388 virReportError(VIR_ERR_RPC,
389 _("unknown procedure: %d"),
390 msg->header.proc);
391 goto error;
392 }
393
394 /* If the client is not authenticated, don't allow any RPC ops
395 * which are except for authentication ones */
396 if (dispatcher->needAuth &&
397 !virNetServerClientIsAuthenticated(client)) {
398 /* Explicitly *NOT* calling remoteDispatchAuthError() because
399 we want back-compatibility with libvirt clients which don't
400 support the VIR_ERR_AUTH_FAILED error code */
401 virReportError(VIR_ERR_RPC,
402 "%s", _("authentication required"));
403 goto error;
404 }
405
406 arg = g_new0(char, dispatcher->arg_len);
407 ret = g_new0(char, dispatcher->ret_len);
408
409 if (virNetMessageDecodePayload(msg, dispatcher->arg_filter, arg) < 0)
410 goto error;
411
412 if (!(identity = virNetServerClientGetIdentity(client)))
413 goto error;
414
415 if (virIdentitySetCurrent(identity) < 0)
416 goto error;
417
418 /*
419 * When the RPC handler is called:
420 *
421 * - Server object is unlocked
422 * - Client object is unlocked
423 *
424 * Without locking, it is safe to use:
425 *
426 * 'args and 'ret'
427 */
428 rv = (dispatcher->func)(server, client, msg, &rerr, arg, ret);
429
430 if (virIdentitySetCurrent(NULL) < 0)
431 goto error;
432
433 /*
434 * If rv == 1, this indicates the dispatch func has
435 * populated 'msg' with a list of FDs to return to
436 * the caller.
437 *
438 * Otherwise we must clear out the FDs we got from
439 * the client originally.
440 *
441 */
442 if (rv != 1) {
443 for (i = 0; i < msg->nfds; i++)
444 VIR_FORCE_CLOSE(msg->fds[i]);
445 VIR_FREE(msg->fds);
446 msg->nfds = 0;
447 }
448
449 xdr_free(dispatcher->arg_filter, arg);
450
451 if (rv < 0)
452 goto error;
453
454 /* Return header. We're re-using same message object, so
455 * only need to tweak type/status fields */
456 /*msg->header.prog = msg->header.prog;*/
457 /*msg->header.vers = msg->header.vers;*/
458 /*msg->header.proc = msg->header.proc;*/
459 msg->header.type = msg->nfds ? VIR_NET_REPLY_WITH_FDS : VIR_NET_REPLY;
460 /*msg->header.serial = msg->header.serial;*/
461 msg->header.status = VIR_NET_OK;
462
463 if (virNetMessageEncodeHeader(msg) < 0) {
464 xdr_free(dispatcher->ret_filter, ret);
465 goto error;
466 }
467
468 if (msg->nfds &&
469 virNetMessageEncodeNumFDs(msg) < 0) {
470 xdr_free(dispatcher->ret_filter, ret);
471 goto error;
472 }
473
474 if (virNetMessageEncodePayload(msg, dispatcher->ret_filter, ret) < 0) {
475 xdr_free(dispatcher->ret_filter, ret);
476 goto error;
477 }
478
479 xdr_free(dispatcher->ret_filter, ret);
480
481 /* Put reply on end of tx queue to send out */
482 return virNetServerClientSendMessage(client, msg);
483
484 error:
485 /* Bad stuff (de-)serializing message, but we have an
486 * RPC error message we can send back to the client */
487 rv = virNetServerProgramSendReplyError(prog, client, msg, &rerr, &msg->header);
488
489 return rv;
490 }
491
492
/*
 * @prog: the program owning the stream
 * @client: the client to send the packet to
 * @msg: message object reused to carry the packet
 * @procedure: procedure number of the stream
 * @serial: serial number of the stream
 * @data: payload bytes, or NULL
 * @len: number of bytes in @data
 *
 * Fill in every header field of @msg for a VIR_NET_STREAM packet and
 * queue it for @client. The status/payload combinations are:
 *
 *   data != NULL, len > 0   => VIR_NET_CONTINUE, raw payload (sending data)
 *   data != NULL, len == 0  => VIR_NET_CONTINUE, empty payload (read EOF)
 *   data == NULL            => VIR_NET_OK, empty payload (finish handshake
 *                              confirmation)
 *
 * Returns the result of virNetServerClientSendMessage(), or -1 if
 * encoding failed
 */
int virNetServerProgramSendStreamData(virNetServerProgram *prog,
                                      virNetServerClient *client,
                                      virNetMessage *msg,
                                      int procedure,
                                      unsigned int serial,
                                      const char *data,
                                      size_t len)
{
    int encodeRc;

    VIR_DEBUG("client=%p msg=%p data=%p len=%zu", client, msg, data, len);

    msg->header.prog = prog->program;
    msg->header.vers = prog->version;
    msg->header.proc = procedure;
    msg->header.type = VIR_NET_STREAM;
    msg->header.serial = serial;
    if (data != NULL)
        msg->header.status = VIR_NET_CONTINUE;
    else
        msg->header.status = VIR_NET_OK;

    if (virNetMessageEncodeHeader(msg) < 0)
        return -1;

    if (data != NULL && len > 0)
        encodeRc = virNetMessageEncodePayloadRaw(msg, data, len);
    else
        encodeRc = virNetMessageEncodePayloadEmpty(msg);

    if (encodeRc < 0)
        return -1;

    VIR_DEBUG("Total %zu", msg->bufferLength);

    return virNetServerClientSendMessage(client, msg);
}
533
534
/*
 * @prog: the program owning the stream
 * @client: the client to send the packet to
 * @msg: message object reused to carry the packet
 * @procedure: procedure number of the stream
 * @serial: serial number of the stream
 * @length: hole length, copied into the payload unchanged
 * @flags: flags, copied into the payload unchanged
 *
 * Fill in every header field of @msg for a VIR_NET_STREAM_HOLE packet
 * (status VIR_NET_CONTINUE), encode a virNetStreamHole payload, and
 * queue the packet for @client.
 *
 * Returns the result of virNetServerClientSendMessage(), or -1 if
 * encoding failed
 */
int virNetServerProgramSendStreamHole(virNetServerProgram *prog,
                                      virNetServerClient *client,
                                      virNetMessage *msg,
                                      int procedure,
                                      unsigned int serial,
                                      long long length,
                                      unsigned int flags)
{
    virNetStreamHole hole = {
        .length = length,
        .flags = flags,
    };

    VIR_DEBUG("client=%p msg=%p length=%lld", client, msg, length);

    msg->header.prog = prog->program;
    msg->header.vers = prog->version;
    msg->header.proc = procedure;
    msg->header.type = VIR_NET_STREAM_HOLE;
    msg->header.serial = serial;
    msg->header.status = VIR_NET_CONTINUE;

    if (virNetMessageEncodeHeader(msg) < 0 ||
        virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetStreamHole, &hole) < 0)
        return -1;

    return virNetServerClientSendMessage(client, msg);
}
568
569
/*
 * Class dispose callback for virNetServerProgram objects.
 *
 * Nothing to release: the procedure table is borrowed from the caller
 * of virNetServerProgramNew(), and the remaining fields own no memory.
 */
static void
virNetServerProgramDispose(void *obj G_GNUC_UNUSED)
{
}
573