1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <aliguori@us.ibm.com>
9  *  Marc-André Lureau <mlureau@redhat.com>
10  *  Victor Kaplansky <victork@redhat.com>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15 
/* This code avoids a GLib dependency. */
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <stdarg.h>
21 #include <errno.h>
22 #include <string.h>
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/eventfd.h>
28 #include <sys/mman.h>
29 #include <endian.h>
30 
31 #if defined(__linux__)
32 #include <sys/syscall.h>
33 #include <fcntl.h>
34 #include <sys/ioctl.h>
35 #include <linux/vhost.h>
36 
37 #ifdef __NR_userfaultfd
38 #include <linux/userfaultfd.h>
39 #endif
40 
41 #endif
42 
43 #include "include/atomic.h"
44 
45 #include "libvhost-user.h"
46 
47 /* usually provided by GLib */
48 #if     __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
#if !defined(__clang__) && \
    (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 4))
50 #define G_GNUC_PRINTF(format_idx, arg_idx) \
51   __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
52 #else
53 #define G_GNUC_PRINTF(format_idx, arg_idx) \
54   __attribute__((__format__(__printf__, format_idx, arg_idx)))
55 #endif
56 #else   /* !__GNUC__ */
57 #define G_GNUC_PRINTF(format_idx, arg_idx)
58 #endif  /* !__GNUC__ */
59 #ifndef MIN
60 #define MIN(x, y) ({                            \
61             typeof(x) _min1 = (x);              \
62             typeof(y) _min2 = (y);              \
63             (void) (&_min1 == &_min2);          \
64             _min1 < _min2 ? _min1 : _min2; })
65 #endif
66 
67 /* Round number down to multiple */
68 #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
69 
70 /* Round number up to multiple */
71 #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
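/*
 * Example: ALIGN_DOWN(13, 8) == 8 and ALIGN_UP(13, 8) == 16.  Since plain
 * integer division is used rather than bit masking, m does not have to be
 * a power of two, only a non-zero value.
 */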
72 
73 #ifndef unlikely
74 #define unlikely(x)   __builtin_expect(!!(x), 0)
75 #endif
76 
77 /* Align each region to cache line size in inflight buffer */
78 #define INFLIGHT_ALIGNMENT 64
79 
80 /* The version of inflight buffer */
81 #define INFLIGHT_VERSION 1
82 
83 /* The version of the protocol we support */
84 #define VHOST_USER_VERSION 1
85 #define LIBVHOST_USER_DEBUG 0
86 
87 #define DPRINT(...)                             \
88     do {                                        \
89         if (LIBVHOST_USER_DEBUG) {              \
90             fprintf(stderr, __VA_ARGS__);        \
91         }                                       \
92     } while (0)
93 
94 static inline
95 bool has_feature(uint64_t features, unsigned int fbit)
96 {
97     assert(fbit < 64);
98     return !!(features & (1ULL << fbit));
99 }
100 
101 static inline
102 bool vu_has_feature(VuDev *dev,
103                     unsigned int fbit)
104 {
105     return has_feature(dev->features, fbit);
106 }
107 
108 static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
109 {
110     return has_feature(dev->protocol_features, fbit);
111 }
112 
113 const char *
114 vu_request_to_string(unsigned int req)
115 {
116 #define REQ(req) [req] = #req
117     static const char *vu_request_str[] = {
118         REQ(VHOST_USER_NONE),
119         REQ(VHOST_USER_GET_FEATURES),
120         REQ(VHOST_USER_SET_FEATURES),
121         REQ(VHOST_USER_SET_OWNER),
122         REQ(VHOST_USER_RESET_OWNER),
123         REQ(VHOST_USER_SET_MEM_TABLE),
124         REQ(VHOST_USER_SET_LOG_BASE),
125         REQ(VHOST_USER_SET_LOG_FD),
126         REQ(VHOST_USER_SET_VRING_NUM),
127         REQ(VHOST_USER_SET_VRING_ADDR),
128         REQ(VHOST_USER_SET_VRING_BASE),
129         REQ(VHOST_USER_GET_VRING_BASE),
130         REQ(VHOST_USER_SET_VRING_KICK),
131         REQ(VHOST_USER_SET_VRING_CALL),
132         REQ(VHOST_USER_SET_VRING_ERR),
133         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
134         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
135         REQ(VHOST_USER_GET_QUEUE_NUM),
136         REQ(VHOST_USER_SET_VRING_ENABLE),
137         REQ(VHOST_USER_SEND_RARP),
138         REQ(VHOST_USER_NET_SET_MTU),
139         REQ(VHOST_USER_SET_SLAVE_REQ_FD),
140         REQ(VHOST_USER_IOTLB_MSG),
141         REQ(VHOST_USER_SET_VRING_ENDIAN),
142         REQ(VHOST_USER_GET_CONFIG),
143         REQ(VHOST_USER_SET_CONFIG),
144         REQ(VHOST_USER_POSTCOPY_ADVISE),
145         REQ(VHOST_USER_POSTCOPY_LISTEN),
146         REQ(VHOST_USER_POSTCOPY_END),
147         REQ(VHOST_USER_GET_INFLIGHT_FD),
148         REQ(VHOST_USER_SET_INFLIGHT_FD),
149         REQ(VHOST_USER_GPU_SET_SOCKET),
150         REQ(VHOST_USER_VRING_KICK),
151         REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
152         REQ(VHOST_USER_ADD_MEM_REG),
153         REQ(VHOST_USER_REM_MEM_REG),
154         REQ(VHOST_USER_MAX),
155     };
156 #undef REQ
157 
158     if (req < VHOST_USER_MAX) {
159         return vu_request_str[req];
160     } else {
161         return "unknown";
162     }
163 }
164 
165 static void G_GNUC_PRINTF(2, 3)
166 vu_panic(VuDev *dev, const char *msg, ...)
167 {
168     char *buf = NULL;
169     va_list ap;
170 
171     va_start(ap, msg);
172     if (vasprintf(&buf, msg, ap) < 0) {
173         buf = NULL;
174     }
175     va_end(ap);
176 
177     dev->broken = true;
178     dev->panic(dev, buf);
179     free(buf);
180 
181     /*
182      * FIXME:
183      * find a way to call virtio_error, or perhaps close the connection?
184      */
185 }
186 
187 /* Translate guest physical address to our virtual address.  */
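/*
 * On return, *plen is reduced if the range [guest_addr, guest_addr + *plen)
 * crosses the end of the containing region, so callers can detect and
 * handle partial mappings.
 */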
188 void *
189 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
190 {
191     int i;
192 
193     if (*plen == 0) {
194         return NULL;
195     }
196 
197     /* Find matching memory region.  */
198     for (i = 0; i < dev->nregions; i++) {
199         VuDevRegion *r = &dev->regions[i];
200 
201         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
202             if ((guest_addr + *plen) > (r->gpa + r->size)) {
203                 *plen = r->gpa + r->size - guest_addr;
204             }
205             return (void *)(uintptr_t)
206                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
207         }
208     }
209 
210     return NULL;
211 }
212 
213 /* Translate qemu virtual address to our virtual address.  */
214 static void *
215 qva_to_va(VuDev *dev, uint64_t qemu_addr)
216 {
217     int i;
218 
219     /* Find matching memory region.  */
220     for (i = 0; i < dev->nregions; i++) {
221         VuDevRegion *r = &dev->regions[i];
222 
223         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
224             return (void *)(uintptr_t)
225                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
226         }
227     }
228 
229     return NULL;
230 }
231 
232 static void
233 vmsg_close_fds(VhostUserMsg *vmsg)
234 {
235     int i;
236 
237     for (i = 0; i < vmsg->fd_num; i++) {
238         close(vmsg->fds[i]);
239     }
240 }
241 
242 /* Set reply payload.u64 and clear request flags and fd_num */
243 static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
244 {
245     vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
246     vmsg->size = sizeof(vmsg->payload.u64);
247     vmsg->payload.u64 = val;
248     vmsg->fd_num = 0;
249 }
250 
251 /* A test to see if we have userfault available */
252 static bool
253 have_userfault(void)
254 {
255 #if defined(__linux__) && defined(__NR_userfaultfd) &&\
256         defined(UFFD_FEATURE_MISSING_SHMEM) &&\
257         defined(UFFD_FEATURE_MISSING_HUGETLBFS)
258     /* Now test the kernel we're running on really has the features */
259     int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
260     struct uffdio_api api_struct;
261     if (ufd < 0) {
262         return false;
263     }
264 
265     api_struct.api = UFFD_API;
266     api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
267                           UFFD_FEATURE_MISSING_HUGETLBFS;
268     if (ioctl(ufd, UFFDIO_API, &api_struct)) {
269         close(ufd);
270         return false;
271     }
272     close(ufd);
273     return true;
274 
275 #else
276     return false;
277 #endif
278 }
279 
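/*
 * Read a vhost-user message from conn_fd: first the fixed-size header plus
 * any SCM_RIGHTS file descriptors via recvmsg(), then the variable-size
 * payload indicated by vmsg->size.  Returns false on error, closing any
 * file descriptors that were received.
 */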
280 static bool
281 vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
282 {
283     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
284     struct iovec iov = {
285         .iov_base = (char *)vmsg,
286         .iov_len = VHOST_USER_HDR_SIZE,
287     };
288     struct msghdr msg = {
289         .msg_iov = &iov,
290         .msg_iovlen = 1,
291         .msg_control = control,
292         .msg_controllen = sizeof(control),
293     };
294     size_t fd_size;
295     struct cmsghdr *cmsg;
296     int rc;
297 
298     do {
299         rc = recvmsg(conn_fd, &msg, 0);
300     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
301 
302     if (rc < 0) {
303         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
304         return false;
305     }
306 
307     vmsg->fd_num = 0;
308     for (cmsg = CMSG_FIRSTHDR(&msg);
309          cmsg != NULL;
310          cmsg = CMSG_NXTHDR(&msg, cmsg))
311     {
312         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
313             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
314             vmsg->fd_num = fd_size / sizeof(int);
315             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
316             break;
317         }
318     }
319 
320     if (vmsg->size > sizeof(vmsg->payload)) {
321         vu_panic(dev,
322                  "Error: too big message request: %d, size: vmsg->size: %u, "
323                  "while sizeof(vmsg->payload) = %zu\n",
324                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
325         goto fail;
326     }
327 
328     if (vmsg->size) {
329         do {
330             rc = read(conn_fd, &vmsg->payload, vmsg->size);
331         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
332 
333         if (rc <= 0) {
334             vu_panic(dev, "Error while reading: %s", strerror(errno));
335             goto fail;
336         }
337 
338         assert(rc == vmsg->size);
339     }
340 
341     return true;
342 
343 fail:
344     vmsg_close_fds(vmsg);
345 
346     return false;
347 }
348 
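/*
 * Write a vhost-user message to conn_fd: the header (plus any file
 * descriptors as SCM_RIGHTS ancillary data) is sent with sendmsg(), then
 * the payload is written from vmsg->data if set, otherwise from the
 * message body itself.
 */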
349 static bool
350 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
351 {
352     int rc;
353     uint8_t *p = (uint8_t *)vmsg;
354     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
355     struct iovec iov = {
356         .iov_base = (char *)vmsg,
357         .iov_len = VHOST_USER_HDR_SIZE,
358     };
359     struct msghdr msg = {
360         .msg_iov = &iov,
361         .msg_iovlen = 1,
362         .msg_control = control,
363     };
364     struct cmsghdr *cmsg;
365 
366     memset(control, 0, sizeof(control));
367     assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
368     if (vmsg->fd_num > 0) {
369         size_t fdsize = vmsg->fd_num * sizeof(int);
370         msg.msg_controllen = CMSG_SPACE(fdsize);
371         cmsg = CMSG_FIRSTHDR(&msg);
372         cmsg->cmsg_len = CMSG_LEN(fdsize);
373         cmsg->cmsg_level = SOL_SOCKET;
374         cmsg->cmsg_type = SCM_RIGHTS;
375         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
376     } else {
377         msg.msg_controllen = 0;
378     }
379 
380     do {
381         rc = sendmsg(conn_fd, &msg, 0);
382     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
383 
384     if (vmsg->size) {
385         do {
386             if (vmsg->data) {
387                 rc = write(conn_fd, vmsg->data, vmsg->size);
388             } else {
389                 rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
390             }
391         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
392     }
393 
394     if (rc <= 0) {
395         vu_panic(dev, "Error while writing: %s", strerror(errno));
396         return false;
397     }
398 
399     return true;
400 }
401 
402 static bool
403 vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
404 {
405     /* Set the version in the flags when sending the reply */
406     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
407     vmsg->flags |= VHOST_USER_VERSION;
408     vmsg->flags |= VHOST_USER_REPLY_MASK;
409 
410     return vu_message_write(dev, conn_fd, vmsg);
411 }
412 
413 /*
414  * Processes a reply on the slave channel.
415  * Entered with slave_mutex held and releases it before exit.
416  * Returns true on success.
417  */
418 static bool
419 vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
420 {
421     VhostUserMsg msg_reply;
422     bool result = false;
423 
424     if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
425         result = true;
426         goto out;
427     }
428 
429     if (!vu_message_read_default(dev, dev->slave_fd, &msg_reply)) {
430         goto out;
431     }
432 
433     if (msg_reply.request != vmsg->request) {
434         DPRINT("Received unexpected msg type. Expected %d received %d",
435                vmsg->request, msg_reply.request);
436         goto out;
437     }
438 
439     result = msg_reply.payload.u64 == 0;
440 
441 out:
442     pthread_mutex_unlock(&dev->slave_mutex);
443     return result;
444 }
445 
446 /* Kick the log_call_fd if required. */
447 static void
448 vu_log_kick(VuDev *dev)
449 {
450     if (dev->log_call_fd != -1) {
        DPRINT("Kicking QEMU's log...\n");
452         if (eventfd_write(dev->log_call_fd, 1) < 0) {
453             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
454         }
455     }
456 }
457 
458 static void
459 vu_log_page(uint8_t *log_table, uint64_t page)
460 {
461     DPRINT("Logged dirty guest page: %"PRId64"\n", page);
462     qatomic_or(&log_table[page / 8], 1 << (page % 8));
463 }
464 
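/*
 * Mark the guest physical range [address, address + length) as dirty in
 * the shared log (one bit per VHOST_LOG_PAGE-sized page) and notify QEMU
 * through the log eventfd.  This is a no-op unless VHOST_F_LOG_ALL has
 * been negotiated and a log table has been set.
 */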
465 static void
466 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
467 {
468     uint64_t page;
469 
470     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
471         !dev->log_table || !length) {
472         return;
473     }
474 
475     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
476 
477     page = address / VHOST_LOG_PAGE;
478     while (page * VHOST_LOG_PAGE < address + length) {
479         vu_log_page(dev->log_table, page);
480         page += 1;
481     }
482 
483     vu_log_kick(dev);
484 }
485 
486 static void
487 vu_kick_cb(VuDev *dev, int condition, void *data)
488 {
489     int index = (intptr_t)data;
490     VuVirtq *vq = &dev->vq[index];
491     int sock = vq->kick_fd;
492     eventfd_t kick_data;
493     ssize_t rc;
494 
495     rc = eventfd_read(sock, &kick_data);
496     if (rc == -1) {
497         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
498         dev->remove_watch(dev, dev->vq[index].kick_fd);
499     } else {
500         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
501                kick_data, vq->handler, index);
502         if (vq->handler) {
503             vq->handler(dev, index);
504         }
505     }
506 }
507 
508 static bool
509 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
510 {
511     vmsg->payload.u64 =
512         /*
513          * The following VIRTIO feature bits are supported by our virtqueue
514          * implementation:
515          */
516         1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
517         1ULL << VIRTIO_RING_F_INDIRECT_DESC |
518         1ULL << VIRTIO_RING_F_EVENT_IDX |
519         1ULL << VIRTIO_F_VERSION_1 |
520 
521         /* vhost-user feature bits */
522         1ULL << VHOST_F_LOG_ALL |
523         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
524 
525     if (dev->iface->get_features) {
526         vmsg->payload.u64 |= dev->iface->get_features(dev);
527     }
528 
529     vmsg->size = sizeof(vmsg->payload.u64);
530     vmsg->fd_num = 0;
531 
532     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
533 
534     return true;
535 }
536 
537 static void
538 vu_set_enable_all_rings(VuDev *dev, bool enabled)
539 {
540     uint16_t i;
541 
542     for (i = 0; i < dev->max_queues; i++) {
543         dev->vq[i].enable = enabled;
544     }
545 }
546 
547 static bool
548 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
549 {
550     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
551 
552     dev->features = vmsg->payload.u64;
553     if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
554         /*
555          * We only support devices conforming to VIRTIO 1.0 or
556          * later
557          */
558         vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
559         return false;
560     }
561 
    if (!vu_has_feature(dev, VHOST_USER_F_PROTOCOL_FEATURES)) {
563         vu_set_enable_all_rings(dev, true);
564     }
565 
566     if (dev->iface->set_features) {
567         dev->iface->set_features(dev, dev->features);
568     }
569 
570     return false;
571 }
572 
573 static bool
574 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
575 {
576     return false;
577 }
578 
579 static void
580 vu_close_log(VuDev *dev)
581 {
582     if (dev->log_table) {
583         if (munmap(dev->log_table, dev->log_size) != 0) {
584             perror("close log munmap() error");
585         }
586 
587         dev->log_table = NULL;
588     }
589     if (dev->log_call_fd != -1) {
590         close(dev->log_call_fd);
591         dev->log_call_fd = -1;
592     }
593 }
594 
595 static bool
596 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
597 {
598     vu_set_enable_all_rings(dev, false);
599 
600     return false;
601 }
602 
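/*
 * Translate the desc/used/avail addresses of a virtqueue into our address
 * space.  Returns true on failure, i.e. when any of the three rings could
 * not be translated.
 */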
603 static bool
604 map_ring(VuDev *dev, VuVirtq *vq)
605 {
606     vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
607     vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
608     vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
609 
610     DPRINT("Setting virtq addresses:\n");
611     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
612     DPRINT("    vring_used  at %p\n", vq->vring.used);
613     DPRINT("    vring_avail at %p\n", vq->vring.avail);
614 
615     return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
616 }
617 
618 static bool
619 generate_faults(VuDev *dev) {
620     int i;
621     for (i = 0; i < dev->nregions; i++) {
622         VuDevRegion *dev_region = &dev->regions[i];
623         int ret;
624 #ifdef UFFDIO_REGISTER
        /*
         * We should already have an open userfaultfd. Register each
         * memory range with it.
         * Discard any mapping we have here; note we can't use MADV_REMOVE
         * or fallocate to make the hole since we don't want to lose
         * data that's already arrived in the shared process.
         * TODO: how to handle hugepages?
         */
633         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
634                       dev_region->size + dev_region->mmap_offset,
635                       MADV_DONTNEED);
636         if (ret) {
637             fprintf(stderr,
638                     "%s: Failed to madvise(DONTNEED) region %d: %s\n",
639                     __func__, i, strerror(errno));
640         }
        /*
         * Turn off transparent hugepages so we don't lose wakeups
         * in neighbouring pages.
         * TODO: Turn this back on later.
         */
646         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
647                       dev_region->size + dev_region->mmap_offset,
648                       MADV_NOHUGEPAGE);
649         if (ret) {
650             /*
651              * Note: This can happen legally on kernels that are configured
652              * without madvise'able hugepages
653              */
654             fprintf(stderr,
655                     "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
656                     __func__, i, strerror(errno));
657         }
658         struct uffdio_register reg_struct;
659         reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
660         reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
661         reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
662 
663         if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
664             vu_panic(dev, "%s: Failed to userfault region %d "
665                           "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
666                           ": (ufd=%d)%s\n",
667                      __func__, i,
668                      dev_region->mmap_addr,
669                      dev_region->size, dev_region->mmap_offset,
670                      dev->postcopy_ufd, strerror(errno));
671             return false;
672         }
673         if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
674             vu_panic(dev, "%s Region (%d) doesn't support COPY",
675                      __func__, i);
676             return false;
677         }
678         DPRINT("%s: region %d: Registered userfault for %"
679                PRIx64 " + %" PRIx64 "\n", __func__, i,
680                (uint64_t)reg_struct.range.start,
681                (uint64_t)reg_struct.range.len);
682         /* Now it's registered we can let the client at it */
683         if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
684                      dev_region->size + dev_region->mmap_offset,
685                      PROT_READ | PROT_WRITE)) {
686             vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
687                      i, strerror(errno));
688             return false;
689         }
690         /* TODO: Stash 'zero' support flags somewhere */
691 #endif
692     }
693 
694     return true;
695 }
696 
697 static bool
698 vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
699     int i;
700     bool track_ramblocks = dev->postcopy_listening;
701     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
702     VuDevRegion *dev_region = &dev->regions[dev->nregions];
703     void *mmap_addr;
704 
705     if (vmsg->fd_num != 1) {
706         vmsg_close_fds(vmsg);
707         vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
708                       "should be sent for this message type", vmsg->fd_num);
709         return false;
710     }
711 
712     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
713         close(vmsg->fds[0]);
714         vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
715                       "least %zu bytes and only %d bytes were received",
716                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
717         return false;
718     }
719 
720     if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
721         close(vmsg->fds[0]);
722         vu_panic(dev, "failing attempt to hot add memory via "
723                       "VHOST_USER_ADD_MEM_REG message because the backend has "
724                       "no free ram slots available");
725         return false;
726     }
727 
728     /*
729      * If we are in postcopy mode and we receive a u64 payload with a 0 value
730      * we know all the postcopy client bases have been received, and we
731      * should start generating faults.
732      */
733     if (track_ramblocks &&
734         vmsg->size == sizeof(vmsg->payload.u64) &&
735         vmsg->payload.u64 == 0) {
736         (void)generate_faults(dev);
737         return false;
738     }
739 
740     DPRINT("Adding region: %u\n", dev->nregions);
741     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
742            msg_region->guest_phys_addr);
743     DPRINT("    memory_size:     0x%016"PRIx64"\n",
744            msg_region->memory_size);
745     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
746            msg_region->userspace_addr);
747     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
748            msg_region->mmap_offset);
749 
750     dev_region->gpa = msg_region->guest_phys_addr;
751     dev_region->size = msg_region->memory_size;
752     dev_region->qva = msg_region->userspace_addr;
753     dev_region->mmap_offset = msg_region->mmap_offset;
754 
755     /*
756      * We don't use offset argument of mmap() since the
757      * mapped address has to be page aligned, and we use huge
758      * pages.
759      */
760     if (track_ramblocks) {
761         /*
762          * In postcopy we're using PROT_NONE here to catch anyone
763          * accessing it before we userfault.
764          */
765         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
766                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
767                          vmsg->fds[0], 0);
768     } else {
769         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
770                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
771                          vmsg->fds[0], 0);
772     }
773 
774     if (mmap_addr == MAP_FAILED) {
775         vu_panic(dev, "region mmap error: %s", strerror(errno));
776     } else {
777         dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
778         DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
779                dev_region->mmap_addr);
780     }
781 
782     close(vmsg->fds[0]);
783 
784     if (track_ramblocks) {
785         /*
786          * Return the address to QEMU so that it can translate the ufd
787          * fault addresses back.
788          */
789         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
790                                                  dev_region->mmap_offset);
791 
792         /* Send the message back to qemu with the addresses filled in. */
793         vmsg->fd_num = 0;
794         DPRINT("Successfully added new region in postcopy\n");
795         dev->nregions++;
796         return true;
797     } else {
798         for (i = 0; i < dev->max_queues; i++) {
799             if (dev->vq[i].vring.desc) {
800                 if (map_ring(dev, &dev->vq[i])) {
801                     vu_panic(dev, "remapping queue %d for new memory region",
802                              i);
803                 }
804             }
805         }
806 
807         DPRINT("Successfully added new region\n");
808         dev->nregions++;
809         return false;
810     }
811 }
812 
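/*
 * A device region matches a message region when the guest physical address,
 * the userspace (QEMU virtual) address and the size are all identical.
 */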
813 static inline bool reg_equal(VuDevRegion *vudev_reg,
814                              VhostUserMemoryRegion *msg_reg)
815 {
816     if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
817         vudev_reg->qva == msg_reg->userspace_addr &&
818         vudev_reg->size == msg_reg->memory_size) {
819         return true;
820     }
821 
822     return false;
823 }
824 
825 static bool
826 vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
827     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
828     int i;
829     bool found = false;
830 
831     if (vmsg->fd_num > 1) {
832         vmsg_close_fds(vmsg);
833         vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
834                       "should be sent for this message type", vmsg->fd_num);
835         return false;
836     }
837 
838     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
839         vmsg_close_fds(vmsg);
840         vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
841                       "least %zu bytes and only %d bytes were received",
842                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
843         return false;
844     }
845 
846     DPRINT("Removing region:\n");
847     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
848            msg_region->guest_phys_addr);
849     DPRINT("    memory_size:     0x%016"PRIx64"\n",
850            msg_region->memory_size);
851     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
852            msg_region->userspace_addr);
853     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
854            msg_region->mmap_offset);
855 
856     for (i = 0; i < dev->nregions; i++) {
857         if (reg_equal(&dev->regions[i], msg_region)) {
858             VuDevRegion *r = &dev->regions[i];
859             void *m = (void *) (uintptr_t) r->mmap_addr;
860 
861             if (m) {
862                 munmap(m, r->size + r->mmap_offset);
863             }
864 
865             /*
866              * Shift all affected entries by 1 to close the hole at index i and
867              * zero out the last entry.
868              */
869             memmove(dev->regions + i, dev->regions + i + 1,
870                     sizeof(VuDevRegion) * (dev->nregions - i - 1));
871             memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
872             DPRINT("Successfully removed a region\n");
873             dev->nregions--;
874             i--;
875 
876             found = true;
877 
878             /* Continue the search for eventual duplicates. */
879         }
880     }
881 
882     if (!found) {
883         vu_panic(dev, "Specified region not found\n");
884     }
885 
886     vmsg_close_fds(vmsg);
887 
888     return false;
889 }
890 
891 static bool
892 vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
893 {
894     int i;
895     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
896     dev->nregions = memory->nregions;
897 
898     DPRINT("Nregions: %u\n", memory->nregions);
899     for (i = 0; i < dev->nregions; i++) {
900         void *mmap_addr;
901         VhostUserMemoryRegion *msg_region = &memory->regions[i];
902         VuDevRegion *dev_region = &dev->regions[i];
903 
904         DPRINT("Region %d\n", i);
905         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
906                msg_region->guest_phys_addr);
907         DPRINT("    memory_size:     0x%016"PRIx64"\n",
908                msg_region->memory_size);
909         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
910                msg_region->userspace_addr);
911         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
912                msg_region->mmap_offset);
913 
914         dev_region->gpa = msg_region->guest_phys_addr;
915         dev_region->size = msg_region->memory_size;
916         dev_region->qva = msg_region->userspace_addr;
917         dev_region->mmap_offset = msg_region->mmap_offset;
918 
919         /* We don't use offset argument of mmap() since the
920          * mapped address has to be page aligned, and we use huge
921          * pages.
922          * In postcopy we're using PROT_NONE here to catch anyone
923          * accessing it before we userfault
924          */
925         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
926                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
927                          vmsg->fds[i], 0);
928 
929         if (mmap_addr == MAP_FAILED) {
930             vu_panic(dev, "region mmap error: %s", strerror(errno));
931         } else {
932             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
933             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
934                    dev_region->mmap_addr);
935         }
936 
937         /* Return the address to QEMU so that it can translate the ufd
938          * fault addresses back.
939          */
940         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
941                                                  dev_region->mmap_offset);
942         close(vmsg->fds[i]);
943     }
944 
945     /* Send the message back to qemu with the addresses filled in */
946     vmsg->fd_num = 0;
947     if (!vu_send_reply(dev, dev->sock, vmsg)) {
948         vu_panic(dev, "failed to respond to set-mem-table for postcopy");
949         return false;
950     }
951 
952     /* Wait for QEMU to confirm that it's registered the handler for the
953      * faults.
954      */
955     if (!dev->read_msg(dev, dev->sock, vmsg) ||
956         vmsg->size != sizeof(vmsg->payload.u64) ||
957         vmsg->payload.u64 != 0) {
958         vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
959         return false;
960     }
961 
962     /* OK, now we can go and register the memory and generate faults */
963     (void)generate_faults(dev);
964 
965     return false;
966 }
967 
968 static bool
969 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
970 {
971     int i;
972     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
973 
974     for (i = 0; i < dev->nregions; i++) {
975         VuDevRegion *r = &dev->regions[i];
976         void *m = (void *) (uintptr_t) r->mmap_addr;
977 
978         if (m) {
979             munmap(m, r->size + r->mmap_offset);
980         }
981     }
982     dev->nregions = memory->nregions;
983 
984     if (dev->postcopy_listening) {
985         return vu_set_mem_table_exec_postcopy(dev, vmsg);
986     }
987 
988     DPRINT("Nregions: %u\n", memory->nregions);
989     for (i = 0; i < dev->nregions; i++) {
990         void *mmap_addr;
991         VhostUserMemoryRegion *msg_region = &memory->regions[i];
992         VuDevRegion *dev_region = &dev->regions[i];
993 
994         DPRINT("Region %d\n", i);
995         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
996                msg_region->guest_phys_addr);
997         DPRINT("    memory_size:     0x%016"PRIx64"\n",
998                msg_region->memory_size);
999         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
1000                msg_region->userspace_addr);
1001         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
1002                msg_region->mmap_offset);
1003 
1004         dev_region->gpa = msg_region->guest_phys_addr;
1005         dev_region->size = msg_region->memory_size;
1006         dev_region->qva = msg_region->userspace_addr;
1007         dev_region->mmap_offset = msg_region->mmap_offset;
1008 
1009         /* We don't use offset argument of mmap() since the
1010          * mapped address has to be page aligned, and we use huge
1011          * pages.  */
1012         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
1013                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
1014                          vmsg->fds[i], 0);
1015 
1016         if (mmap_addr == MAP_FAILED) {
1017             vu_panic(dev, "region mmap error: %s", strerror(errno));
1018         } else {
1019             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
1020             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
1021                    dev_region->mmap_addr);
1022         }
1023 
1024         close(vmsg->fds[i]);
1025     }
1026 
1027     for (i = 0; i < dev->max_queues; i++) {
1028         if (dev->vq[i].vring.desc) {
1029             if (map_ring(dev, &dev->vq[i])) {
1030                 vu_panic(dev, "remapping queue %d during setmemtable", i);
1031             }
1032         }
1033     }
1034 
1035     return false;
1036 }
1037 
1038 static bool
1039 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1040 {
1041     int fd;
1042     uint64_t log_mmap_size, log_mmap_offset;
1043     void *rc;
1044 
1045     if (vmsg->fd_num != 1 ||
1046         vmsg->size != sizeof(vmsg->payload.log)) {
1047         vu_panic(dev, "Invalid log_base message");
1048         return true;
1049     }
1050 
1051     fd = vmsg->fds[0];
1052     log_mmap_offset = vmsg->payload.log.mmap_offset;
1053     log_mmap_size = vmsg->payload.log.mmap_size;
1054     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1055     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1056 
1057     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1058               log_mmap_offset);
1059     close(fd);
1060     if (rc == MAP_FAILED) {
1061         perror("log mmap error");
1062     }
1063 
1064     if (dev->log_table) {
1065         munmap(dev->log_table, dev->log_size);
1066     }
1067     dev->log_table = rc;
1068     dev->log_size = log_mmap_size;
1069 
1070     vmsg->size = sizeof(vmsg->payload.u64);
1071     vmsg->fd_num = 0;
1072 
1073     return true;
1074 }
1075 
1076 static bool
1077 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1078 {
1079     if (vmsg->fd_num != 1) {
1080         vu_panic(dev, "Invalid log_fd message");
1081         return false;
1082     }
1083 
1084     if (dev->log_call_fd != -1) {
1085         close(dev->log_call_fd);
1086     }
1087     dev->log_call_fd = vmsg->fds[0];
1088     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1089 
1090     return false;
1091 }
1092 
1093 static bool
1094 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1095 {
1096     unsigned int index = vmsg->payload.state.index;
1097     unsigned int num = vmsg->payload.state.num;
1098 
1099     DPRINT("State.index: %u\n", index);
1100     DPRINT("State.num:   %u\n", num);
1101     dev->vq[index].vring.num = num;
1102 
1103     return false;
1104 }
1105 
1106 static bool
1107 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1108 {
1109     struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1110     unsigned int index = vra->index;
1111     VuVirtq *vq = &dev->vq[index];
1112 
1113     DPRINT("vhost_vring_addr:\n");
1114     DPRINT("    index:  %d\n", vra->index);
1115     DPRINT("    flags:  %d\n", vra->flags);
1116     DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1117     DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1118     DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1119     DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1120 
1121     vq->vra = *vra;
1122     vq->vring.flags = vra->flags;
1123     vq->vring.log_guest_addr = vra->log_guest_addr;
1124 
1125 
1126     if (map_ring(dev, vq)) {
1127         vu_panic(dev, "Invalid vring_addr message");
1128         return false;
1129     }
1130 
1131     vq->used_idx = le16toh(vq->vring.used->idx);
1132 
1133     if (vq->last_avail_idx != vq->used_idx) {
1134         bool resume = dev->iface->queue_is_processed_in_order &&
1135             dev->iface->queue_is_processed_in_order(dev, index);
1136 
1137         DPRINT("Last avail index != used index: %u != %u%s\n",
1138                vq->last_avail_idx, vq->used_idx,
1139                resume ? ", resuming" : "");
1140 
1141         if (resume) {
1142             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1143         }
1144     }
1145 
1146     return false;
1147 }
1148 
1149 static bool
1150 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1151 {
1152     unsigned int index = vmsg->payload.state.index;
1153     unsigned int num = vmsg->payload.state.num;
1154 
1155     DPRINT("State.index: %u\n", index);
1156     DPRINT("State.num:   %u\n", num);
1157     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1158 
1159     return false;
1160 }
1161 
1162 static bool
1163 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1164 {
1165     unsigned int index = vmsg->payload.state.index;
1166 
1167     DPRINT("State.index: %u\n", index);
1168     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1169     vmsg->size = sizeof(vmsg->payload.state);
1170 
1171     dev->vq[index].started = false;
1172     if (dev->iface->queue_set_started) {
1173         dev->iface->queue_set_started(dev, index, false);
1174     }
1175 
1176     if (dev->vq[index].call_fd != -1) {
1177         close(dev->vq[index].call_fd);
1178         dev->vq[index].call_fd = -1;
1179     }
1180     if (dev->vq[index].kick_fd != -1) {
1181         dev->remove_watch(dev, dev->vq[index].kick_fd);
1182         close(dev->vq[index].kick_fd);
1183         dev->vq[index].kick_fd = -1;
1184     }
1185 
1186     return true;
1187 }
1188 
1189 static bool
1190 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1191 {
1192     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1193     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1194 
1195     if (index >= dev->max_queues) {
1196         vmsg_close_fds(vmsg);
1197         vu_panic(dev, "Invalid queue index: %u", index);
1198         return false;
1199     }
1200 
1201     if (nofd) {
1202         vmsg_close_fds(vmsg);
1203         return true;
1204     }
1205 
1206     if (vmsg->fd_num != 1) {
1207         vmsg_close_fds(vmsg);
1208         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1209         return false;
1210     }
1211 
1212     return true;
1213 }
1214 
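/*
 * qsort() comparator for the resubmit list: entries with larger submission
 * counters sort first, so resubmit_list[0] ends up holding the most
 * recently submitted descriptor.
 */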
1215 static int
1216 inflight_desc_compare(const void *a, const void *b)
1217 {
1218     VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1219                         *desc1 = (VuVirtqInflightDesc *)b;
1220 
1221     if (desc1->counter > desc0->counter &&
1222         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1223         return 1;
1224     }
1225 
1226     return -1;
1227 }
1228 
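/*
 * Restore the in-flight state of a queue from the shared inflight buffer
 * after a (re)connection: count the descriptors still marked inflight,
 * build the resubmit list ordered by submission counter, and kick the
 * queue in case I/O was left hanging across the reconnect.  Returns 0 on
 * success, -1 on failure.
 */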
1229 static int
1230 vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1231 {
1232     int i = 0;
1233 
1234     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1235         return 0;
1236     }
1237 
1238     if (unlikely(!vq->inflight)) {
1239         return -1;
1240     }
1241 
1242     if (unlikely(!vq->inflight->version)) {
1243         /* initialize the buffer */
1244         vq->inflight->version = INFLIGHT_VERSION;
1245         return 0;
1246     }
1247 
1248     vq->used_idx = le16toh(vq->vring.used->idx);
1249     vq->resubmit_num = 0;
1250     vq->resubmit_list = NULL;
1251     vq->counter = 0;
1252 
1253     if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1254         vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1255 
1256         barrier();
1257 
1258         vq->inflight->used_idx = vq->used_idx;
1259     }
1260 
1261     for (i = 0; i < vq->inflight->desc_num; i++) {
1262         if (vq->inflight->desc[i].inflight == 1) {
1263             vq->inuse++;
1264         }
1265     }
1266 
1267     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1268 
1269     if (vq->inuse) {
1270         vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1271         if (!vq->resubmit_list) {
1272             return -1;
1273         }
1274 
1275         for (i = 0; i < vq->inflight->desc_num; i++) {
1276             if (vq->inflight->desc[i].inflight) {
1277                 vq->resubmit_list[vq->resubmit_num].index = i;
1278                 vq->resubmit_list[vq->resubmit_num].counter =
1279                                         vq->inflight->desc[i].counter;
1280                 vq->resubmit_num++;
1281             }
1282         }
1283 
1284         if (vq->resubmit_num > 1) {
1285             qsort(vq->resubmit_list, vq->resubmit_num,
1286                   sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1287         }
1288         vq->counter = vq->resubmit_list[0].counter + 1;
1289     }
1290 
1291     /* in case of I/O hang after reconnecting */
1292     if (eventfd_write(vq->kick_fd, 1)) {
1293         return -1;
1294     }
1295 
1296     return 0;
1297 }
1298 
1299 static bool
1300 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1301 {
1302     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1303     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1304 
1305     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1306 
1307     if (!vu_check_queue_msg_file(dev, vmsg)) {
1308         return false;
1309     }
1310 
1311     if (dev->vq[index].kick_fd != -1) {
1312         dev->remove_watch(dev, dev->vq[index].kick_fd);
1313         close(dev->vq[index].kick_fd);
1314         dev->vq[index].kick_fd = -1;
1315     }
1316 
1317     dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1318     DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1319 
1320     dev->vq[index].started = true;
1321     if (dev->iface->queue_set_started) {
1322         dev->iface->queue_set_started(dev, index, true);
1323     }
1324 
1325     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1326         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1327                        vu_kick_cb, (void *)(long)index);
1328 
1329         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1330                dev->vq[index].kick_fd, index);
1331     }
1332 
1333     if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1334         vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1335     }
1336 
1337     return false;
1338 }
1339 
1340 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1341                           vu_queue_handler_cb handler)
1342 {
1343     int qidx = vq - dev->vq;
1344 
1345     vq->handler = handler;
1346     if (vq->kick_fd >= 0) {
1347         if (handler) {
1348             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1349                            vu_kick_cb, (void *)(long)qidx);
1350         } else {
1351             dev->remove_watch(dev, vq->kick_fd);
1352         }
1353     }
1354 }
1355 
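/*
 * Ask the master, over the slave channel, to use the given fd (with the
 * given size and offset) as a host notifier for this queue via
 * VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG; fd == -1 removes the notifier.
 * Requires the SLAVE_SEND_FD protocol feature and returns true only when
 * the master acknowledges success.
 */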
1356 bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1357                                 int size, int offset)
1358 {
1359     int qidx = vq - dev->vq;
1360     int fd_num = 0;
1361     VhostUserMsg vmsg = {
1362         .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1363         .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1364         .size = sizeof(vmsg.payload.area),
1365         .payload.area = {
1366             .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1367             .size = size,
1368             .offset = offset,
1369         },
1370     };
1371 
1372     if (fd == -1) {
1373         vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1374     } else {
1375         vmsg.fds[fd_num++] = fd;
1376     }
1377 
1378     vmsg.fd_num = fd_num;
1379 
1380     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1381         return false;
1382     }
1383 
1384     pthread_mutex_lock(&dev->slave_mutex);
1385     if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1386         pthread_mutex_unlock(&dev->slave_mutex);
1387         return false;
1388     }
1389 
1390     /* Also unlocks the slave_mutex */
1391     return vu_process_message_reply(dev, &vmsg);
1392 }
1393 
1394 static bool
1395 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1396 {
1397     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1398     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1399 
1400     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1401 
1402     if (!vu_check_queue_msg_file(dev, vmsg)) {
1403         return false;
1404     }
1405 
1406     if (dev->vq[index].call_fd != -1) {
1407         close(dev->vq[index].call_fd);
1408         dev->vq[index].call_fd = -1;
1409     }
1410 
1411     dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1412 
1413     /* in case of I/O hang after reconnecting */
    if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
        return false;
    }
1417 
1418     DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1419 
1420     return false;
1421 }
1422 
1423 static bool
1424 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1425 {
1426     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1427     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1428 
1429     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1430 
1431     if (!vu_check_queue_msg_file(dev, vmsg)) {
1432         return false;
1433     }
1434 
1435     if (dev->vq[index].err_fd != -1) {
1436         close(dev->vq[index].err_fd);
1437         dev->vq[index].err_fd = -1;
1438     }
1439 
1440     dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1441 
1442     return false;
1443 }
1444 
1445 static bool
1446 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1447 {
1448     /*
1449      * Note that we support, but intentionally do not set,
1450      * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1451      * a device implementation can return it in its callback
1452      * (get_protocol_features) if it wants to use this for
     * simulation, but it is otherwise not desirable (if it is even
     * implemented by the master).
1455      */
1456     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1457                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1458                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1459                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1460                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1461                         1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1462                         1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1463 
1464     if (have_userfault()) {
1465         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1466     }
1467 
1468     if (dev->iface->get_config && dev->iface->set_config) {
1469         features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1470     }
1471 
1472     if (dev->iface->get_protocol_features) {
1473         features |= dev->iface->get_protocol_features(dev);
1474     }
1475 
1476     vmsg_set_reply_u64(vmsg, features);
1477     return true;
1478 }
1479 
1480 static bool
1481 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1482 {
1483     uint64_t features = vmsg->payload.u64;
1484 
1485     DPRINT("u64: 0x%016"PRIx64"\n", features);
1486 
1487     dev->protocol_features = vmsg->payload.u64;
1488 
1489     if (vu_has_protocol_feature(dev,
1490                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1491         (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1492          !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1493         /*
1494          * The use case for using messages for kick/call is simulation, to make
1495          * the kick and call synchronous. To actually get that behaviour, both
1496          * of the other features are required.
1497          * Theoretically, one could use only kick messages, or do them without
1498          * having F_REPLY_ACK, but too many (possibly pending) messages on the
         * socket will eventually cause the master to hang. To avoid this in
         * scenarios where that is not desired, enforce that the settings are
         * negotiated in a way that actually enables the simulation case.
1502          */
1503         vu_panic(dev,
1504                  "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1505         return false;
1506     }
1507 
1508     if (dev->iface->set_protocol_features) {
1509         dev->iface->set_protocol_features(dev, features);
1510     }
1511 
1512     return false;
1513 }
1514 
1515 static bool
1516 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1517 {
1518     vmsg_set_reply_u64(vmsg, dev->max_queues);
1519     return true;
1520 }
1521 
1522 static bool
1523 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1524 {
1525     unsigned int index = vmsg->payload.state.index;
1526     unsigned int enable = vmsg->payload.state.num;
1527 
1528     DPRINT("State.index: %u\n", index);
1529     DPRINT("State.enable:   %u\n", enable);
1530 
1531     if (index >= dev->max_queues) {
1532         vu_panic(dev, "Invalid vring_enable index: %u", index);
1533         return false;
1534     }
1535 
1536     dev->vq[index].enable = enable;
1537     return false;
1538 }
1539 
1540 static bool
1541 vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1542 {
1543     if (vmsg->fd_num != 1) {
        vu_panic(dev, "Invalid slave_req_fd message (%d fds)", vmsg->fd_num);
1545         return false;
1546     }
1547 
1548     if (dev->slave_fd != -1) {
1549         close(dev->slave_fd);
1550     }
1551     dev->slave_fd = vmsg->fds[0];
1552     DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1553 
1554     return false;
1555 }
1556 
1557 static bool
1558 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1559 {
1560     int ret = -1;
1561 
1562     if (dev->iface->get_config) {
1563         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1564                                      vmsg->payload.config.size);
1565     }
1566 
1567     if (ret) {
1568         /* resize to zero to indicate an error to master */
1569         vmsg->size = 0;
1570     }
1571 
1572     return true;
1573 }
1574 
1575 static bool
1576 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1577 {
1578     int ret = -1;
1579 
1580     if (dev->iface->set_config) {
1581         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1582                                      vmsg->payload.config.offset,
1583                                      vmsg->payload.config.size,
1584                                      vmsg->payload.config.flags);
1585         if (ret) {
1586             vu_panic(dev, "Set virtio configuration space failed");
1587         }
1588     }
1589 
1590     return false;
1591 }
1592 
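/*
 * Open a userfaultfd, negotiate the UFFD API with the kernel and return
 * the fd to QEMU so that it can register the postcopy fault handler.
 * On failure the device is marked broken and -1 is returned as the fd.
 */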
1593 static bool
1594 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1595 {
1596     dev->postcopy_ufd = -1;
1597 #ifdef UFFDIO_API
1598     struct uffdio_api api_struct;
1599 
1600     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1601     vmsg->size = 0;
1602 #endif
1603 
1604     if (dev->postcopy_ufd == -1) {
1605         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1606         goto out;
1607     }
1608 
1609 #ifdef UFFDIO_API
1610     api_struct.api = UFFD_API;
1611     api_struct.features = 0;
1612     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1613         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1614         close(dev->postcopy_ufd);
1615         dev->postcopy_ufd = -1;
1616         goto out;
1617     }
1618     /* TODO: Stash feature flags somewhere */
1619 #endif
1620 
1621 out:
1622     /* Return a ufd to the QEMU */
1623     vmsg->fd_num = 1;
1624     vmsg->fds[0] = dev->postcopy_ufd;
1625     return true; /* = send a reply */
1626 }
1627 
1628 static bool
1629 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1630 {
1631     if (dev->nregions) {
1632         vu_panic(dev, "Regions already registered at postcopy-listen");
1633         vmsg_set_reply_u64(vmsg, -1);
1634         return true;
1635     }
1636     dev->postcopy_listening = true;
1637 
1638     vmsg_set_reply_u64(vmsg, 0);
1639     return true;
1640 }
1641 
1642 static bool
1643 vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1644 {
1645     DPRINT("%s: Entry\n", __func__);
1646     dev->postcopy_listening = false;
1647     if (dev->postcopy_ufd > 0) {
1648         close(dev->postcopy_ufd);
1649         dev->postcopy_ufd = -1;
1650         DPRINT("%s: Done close\n", __func__);
1651     }
1652 
1653     vmsg_set_reply_u64(vmsg, 0);
1654     DPRINT("%s: exit\n", __func__);
1655     return true;
1656 }
1657 
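/*
 * Size of the per-queue tracking area inside the shared inflight buffer:
 * one VuDescStateSplit entry per descriptor plus one 16-bit field, rounded
 * up to INFLIGHT_ALIGNMENT so each queue starts on a cache-line boundary.
 */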
1658 static inline uint64_t
1659 vu_inflight_queue_size(uint16_t queue_size)
1660 {
1661     return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1662            sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1663 }
1664 
1665 #ifdef MFD_ALLOW_SEALING
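/*
 * Create an anonymous memfd of the given size, apply the requested seals
 * and map it read/write.  On success the mapping is returned and *fd holds
 * the descriptor; on failure NULL is returned and any descriptor that was
 * created is closed.
 */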
1666 static void *
1667 memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
1668 {
1669     void *ptr;
1670     int ret;
1671 
1672     *fd = memfd_create(name, MFD_ALLOW_SEALING);
1673     if (*fd < 0) {
1674         return NULL;
1675     }
1676 
1677     ret = ftruncate(*fd, size);
1678     if (ret < 0) {
1679         close(*fd);
1680         return NULL;
1681     }
1682 
1683     ret = fcntl(*fd, F_ADD_SEALS, flags);
1684     if (ret < 0) {
1685         close(*fd);
1686         return NULL;
1687     }
1688 
1689     ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
1690     if (ptr == MAP_FAILED) {
1691         close(*fd);
1692         return NULL;
1693     }
1694 
1695     return ptr;
1696 }
1697 #endif
1698 
1699 static bool
1700 vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1701 {
1702     int fd = -1;
1703     void *addr = NULL;
1704     uint64_t mmap_size;
1705     uint16_t num_queues, queue_size;
1706 
1707     if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1708         vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1709         vmsg->payload.inflight.mmap_size = 0;
1710         return true;
1711     }
1712 
1713     num_queues = vmsg->payload.inflight.num_queues;
1714     queue_size = vmsg->payload.inflight.queue_size;
1715 
    DPRINT("get_inflight_fd num_queues: %"PRId16"\n", num_queues);
    DPRINT("get_inflight_fd queue_size: %"PRId16"\n", queue_size);
1718 
1719     mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1720 
1721 #ifdef MFD_ALLOW_SEALING
1722     addr = memfd_alloc("vhost-inflight", mmap_size,
1723                        F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1724                        &fd);
1725 #else
1726     vu_panic(dev, "Not implemented: memfd support is missing");
1727 #endif
1728 
1729     if (!addr) {
1730         vu_panic(dev, "Failed to alloc vhost inflight area");
1731         vmsg->payload.inflight.mmap_size = 0;
1732         return true;
1733     }
1734 
1735     memset(addr, 0, mmap_size);
1736 
1737     dev->inflight_info.addr = addr;
1738     dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1739     dev->inflight_info.fd = vmsg->fds[0] = fd;
1740     vmsg->fd_num = 1;
1741     vmsg->payload.inflight.mmap_offset = 0;
1742 
1743     DPRINT("send inflight mmap_size: %"PRId64"\n",
1744            vmsg->payload.inflight.mmap_size);
1745     DPRINT("send inflight mmap offset: %"PRId64"\n",
1746            vmsg->payload.inflight.mmap_offset);
1747 
1748     return true;
1749 }
1750 
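/*
 * VHOST_USER_SET_INFLIGHT_FD: map the inflight buffer handed back by QEMU
 * (for instance after a reconnect) and point each virtqueue at its slice
 * of it.  No reply is sent.
 */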
1751 static bool
1752 vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1753 {
1754     int fd, i;
1755     uint64_t mmap_size, mmap_offset;
1756     uint16_t num_queues, queue_size;
1757     void *rc;
1758 
1759     if (vmsg->fd_num != 1 ||
1760         vmsg->size != sizeof(vmsg->payload.inflight)) {
1761         vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1762                  vmsg->size, vmsg->fd_num);
1763         return false;
1764     }
1765 
1766     fd = vmsg->fds[0];
1767     mmap_size = vmsg->payload.inflight.mmap_size;
1768     mmap_offset = vmsg->payload.inflight.mmap_offset;
1769     num_queues = vmsg->payload.inflight.num_queues;
1770     queue_size = vmsg->payload.inflight.queue_size;
1771 
1772     DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1773     DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1774     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1775     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1776 
1777     rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1778               fd, mmap_offset);
1779 
1780     if (rc == MAP_FAILED) {
1781         vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1782         return false;
1783     }
1784 
1785     if (dev->inflight_info.fd) {
1786         close(dev->inflight_info.fd);
1787     }
1788 
1789     if (dev->inflight_info.addr) {
1790         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1791     }
1792 
1793     dev->inflight_info.fd = fd;
1794     dev->inflight_info.addr = rc;
1795     dev->inflight_info.size = mmap_size;
1796 
1797     for (i = 0; i < num_queues; i++) {
1798         dev->vq[i].inflight = (VuVirtqInflight *)rc;
1799         dev->vq[i].inflight->desc_num = queue_size;
1800         rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1801     }
1802 
1803     return false;
1804 }
1805 
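/*
 * VHOST_USER_VRING_KICK: in-band kick, used when no kick eventfd has been
 * set up; mark the ring as started if necessary and run its handler.
 */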
1806 static bool
1807 vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1808 {
1809     unsigned int index = vmsg->payload.state.index;
1810 
1811     if (index >= dev->max_queues) {
1812         vu_panic(dev, "Invalid queue index: %u", index);
1813         return false;
1814     }
1815 
1816     DPRINT("Got kick message: handler:%p idx:%u\n",
1817            dev->vq[index].handler, index);
1818 
1819     if (!dev->vq[index].started) {
1820         dev->vq[index].started = true;
1821 
1822         if (dev->iface->queue_set_started) {
1823             dev->iface->queue_set_started(dev, index, true);
1824         }
1825     }
1826 
1827     if (dev->vq[index].handler) {
1828         dev->vq[index].handler(dev, index);
1829     }
1830 
1831     return false;
1832 }
1833 
1834 static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1835 {
1836     vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
1837 
1838     DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1839 
1840     return true;
1841 }
1842 
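/*
 * Dispatch one decoded vhost-user message.  The device callbacks get the
 * first chance to handle it via iface->process_msg.  Returns true when a
 * reply has been prepared in @vmsg and needs to be sent back.
 */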
1843 static bool
1844 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1845 {
1846     int do_reply = 0;
1847 
1848     /* Print out generic part of the request. */
1849     DPRINT("================ Vhost user message ================\n");
1850     DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1851            vmsg->request);
1852     DPRINT("Flags:   0x%x\n", vmsg->flags);
1853     DPRINT("Size:    %u\n", vmsg->size);
1854 
1855     if (vmsg->fd_num) {
1856         int i;
1857         DPRINT("Fds:");
1858         for (i = 0; i < vmsg->fd_num; i++) {
1859             DPRINT(" %d", vmsg->fds[i]);
1860         }
1861         DPRINT("\n");
1862     }
1863 
1864     if (dev->iface->process_msg &&
1865         dev->iface->process_msg(dev, vmsg, &do_reply)) {
1866         return do_reply;
1867     }
1868 
1869     switch (vmsg->request) {
1870     case VHOST_USER_GET_FEATURES:
1871         return vu_get_features_exec(dev, vmsg);
1872     case VHOST_USER_SET_FEATURES:
1873         return vu_set_features_exec(dev, vmsg);
1874     case VHOST_USER_GET_PROTOCOL_FEATURES:
1875         return vu_get_protocol_features_exec(dev, vmsg);
1876     case VHOST_USER_SET_PROTOCOL_FEATURES:
1877         return vu_set_protocol_features_exec(dev, vmsg);
1878     case VHOST_USER_SET_OWNER:
1879         return vu_set_owner_exec(dev, vmsg);
1880     case VHOST_USER_RESET_OWNER:
1881         return vu_reset_device_exec(dev, vmsg);
1882     case VHOST_USER_SET_MEM_TABLE:
1883         return vu_set_mem_table_exec(dev, vmsg);
1884     case VHOST_USER_SET_LOG_BASE:
1885         return vu_set_log_base_exec(dev, vmsg);
1886     case VHOST_USER_SET_LOG_FD:
1887         return vu_set_log_fd_exec(dev, vmsg);
1888     case VHOST_USER_SET_VRING_NUM:
1889         return vu_set_vring_num_exec(dev, vmsg);
1890     case VHOST_USER_SET_VRING_ADDR:
1891         return vu_set_vring_addr_exec(dev, vmsg);
1892     case VHOST_USER_SET_VRING_BASE:
1893         return vu_set_vring_base_exec(dev, vmsg);
1894     case VHOST_USER_GET_VRING_BASE:
1895         return vu_get_vring_base_exec(dev, vmsg);
1896     case VHOST_USER_SET_VRING_KICK:
1897         return vu_set_vring_kick_exec(dev, vmsg);
1898     case VHOST_USER_SET_VRING_CALL:
1899         return vu_set_vring_call_exec(dev, vmsg);
1900     case VHOST_USER_SET_VRING_ERR:
1901         return vu_set_vring_err_exec(dev, vmsg);
1902     case VHOST_USER_GET_QUEUE_NUM:
1903         return vu_get_queue_num_exec(dev, vmsg);
1904     case VHOST_USER_SET_VRING_ENABLE:
1905         return vu_set_vring_enable_exec(dev, vmsg);
1906     case VHOST_USER_SET_SLAVE_REQ_FD:
1907         return vu_set_slave_req_fd(dev, vmsg);
1908     case VHOST_USER_GET_CONFIG:
1909         return vu_get_config(dev, vmsg);
1910     case VHOST_USER_SET_CONFIG:
1911         return vu_set_config(dev, vmsg);
1912     case VHOST_USER_NONE:
1913         /* if you need processing before exit, override iface->process_msg */
1914         exit(0);
1915     case VHOST_USER_POSTCOPY_ADVISE:
1916         return vu_set_postcopy_advise(dev, vmsg);
1917     case VHOST_USER_POSTCOPY_LISTEN:
1918         return vu_set_postcopy_listen(dev, vmsg);
1919     case VHOST_USER_POSTCOPY_END:
1920         return vu_set_postcopy_end(dev, vmsg);
1921     case VHOST_USER_GET_INFLIGHT_FD:
1922         return vu_get_inflight_fd(dev, vmsg);
1923     case VHOST_USER_SET_INFLIGHT_FD:
1924         return vu_set_inflight_fd(dev, vmsg);
1925     case VHOST_USER_VRING_KICK:
1926         return vu_handle_vring_kick(dev, vmsg);
1927     case VHOST_USER_GET_MAX_MEM_SLOTS:
1928         return vu_handle_get_max_memslots(dev, vmsg);
1929     case VHOST_USER_ADD_MEM_REG:
1930         return vu_add_mem_reg(dev, vmsg);
1931     case VHOST_USER_REM_MEM_REG:
1932         return vu_rem_mem_reg(dev, vmsg);
1933     default:
1934         vmsg_close_fds(vmsg);
1935         vu_panic(dev, "Unhandled request: %d", vmsg->request);
1936     }
1937 
1938     return false;
1939 }
1940 
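/*
 * Read and process a single message from the vhost-user socket, sending a
 * reply when one is required.  Returns false on a read or reply failure,
 * which usually means the connection should be torn down.
 *
 * A minimal event-loop callback might look like this (a sketch only; the
 * callback and handler names are the application's, not this library's):
 *
 *     static void vu_sock_ready(void *opaque)
 *     {
 *         VuDev *dev = opaque;
 *
 *         if (!vu_dispatch(dev)) {
 *             // disconnected: remove watches, vu_deinit(dev), close fds
 *         }
 *     }
 */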
1941 bool
1942 vu_dispatch(VuDev *dev)
1943 {
1944     VhostUserMsg vmsg = { 0, };
1945     int reply_requested;
1946     bool need_reply, success = false;
1947 
1948     if (!dev->read_msg(dev, dev->sock, &vmsg)) {
1949         goto end;
1950     }
1951 
1952     need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1953 
1954     reply_requested = vu_process_message(dev, &vmsg);
1955     if (!reply_requested && need_reply) {
1956         vmsg_set_reply_u64(&vmsg, 0);
1957         reply_requested = 1;
1958     }
1959 
1960     if (!reply_requested) {
1961         success = true;
1962         goto end;
1963     }
1964 
1965     if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1966         goto end;
1967     }
1968 
1969     success = true;
1970 
1971 end:
1972     free(vmsg.data);
1973     return success;
1974 }
1975 
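/*
 * Release everything owned by @dev: memory region mappings, the per-queue
 * call/kick/err eventfds and resubmit lists, the inflight buffer, the
 * dirty log, the slave channel and the socket.  Typically called on
 * disconnect, before re-initialising with vu_init().
 */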
1976 void
1977 vu_deinit(VuDev *dev)
1978 {
1979     int i;
1980 
1981     for (i = 0; i < dev->nregions; i++) {
1982         VuDevRegion *r = &dev->regions[i];
1983         void *m = (void *) (uintptr_t) r->mmap_addr;
1984         if (m != MAP_FAILED) {
1985             munmap(m, r->size + r->mmap_offset);
1986         }
1987     }
1988     dev->nregions = 0;
1989 
1990     for (i = 0; i < dev->max_queues; i++) {
1991         VuVirtq *vq = &dev->vq[i];
1992 
1993         if (vq->call_fd != -1) {
1994             close(vq->call_fd);
1995             vq->call_fd = -1;
1996         }
1997 
1998         if (vq->kick_fd != -1) {
1999             dev->remove_watch(dev, vq->kick_fd);
2000             close(vq->kick_fd);
2001             vq->kick_fd = -1;
2002         }
2003 
2004         if (vq->err_fd != -1) {
2005             close(vq->err_fd);
2006             vq->err_fd = -1;
2007         }
2008 
2009         if (vq->resubmit_list) {
2010             free(vq->resubmit_list);
2011             vq->resubmit_list = NULL;
2012         }
2013 
2014         vq->inflight = NULL;
2015     }
2016 
2017     if (dev->inflight_info.addr) {
2018         munmap(dev->inflight_info.addr, dev->inflight_info.size);
2019         dev->inflight_info.addr = NULL;
2020     }
2021 
2022     if (dev->inflight_info.fd > 0) {
2023         close(dev->inflight_info.fd);
2024         dev->inflight_info.fd = -1;
2025     }
2026 
2027     vu_close_log(dev);
2028     if (dev->slave_fd != -1) {
2029         close(dev->slave_fd);
2030         dev->slave_fd = -1;
2031     }
2032     pthread_mutex_destroy(&dev->slave_mutex);
2033 
2034     if (dev->sock != -1) {
2035         close(dev->sock);
2036     }
2037 
2038     free(dev->vq);
2039     dev->vq = NULL;
2040 }
2041 
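/*
 * Initialise @dev for a freshly accepted vhost-user connection.  @socket
 * is the connected socket fd; @read_msg may be NULL to use the default
 * reader.  Returns false only if the virtqueue array cannot be allocated.
 *
 * Typical call (a sketch; the callbacks are whatever the application's
 * event loop provides):
 *
 *     if (!vu_init(dev, num_queues, conn_fd, panic_cb, NULL,
 *                  set_watch_cb, remove_watch_cb, &iface)) {
 *         // out of memory
 *     }
 */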
2042 bool
2043 vu_init(VuDev *dev,
2044         uint16_t max_queues,
2045         int socket,
2046         vu_panic_cb panic,
2047         vu_read_msg_cb read_msg,
2048         vu_set_watch_cb set_watch,
2049         vu_remove_watch_cb remove_watch,
2050         const VuDevIface *iface)
2051 {
2052     uint16_t i;
2053 
2054     assert(max_queues > 0);
2055     assert(socket >= 0);
2056     assert(set_watch);
2057     assert(remove_watch);
2058     assert(iface);
2059     assert(panic);
2060 
2061     memset(dev, 0, sizeof(*dev));
2062 
2063     dev->sock = socket;
2064     dev->panic = panic;
2065     dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2066     dev->set_watch = set_watch;
2067     dev->remove_watch = remove_watch;
2068     dev->iface = iface;
2069     dev->log_call_fd = -1;
2070     pthread_mutex_init(&dev->slave_mutex, NULL);
2071     dev->slave_fd = -1;
2072     dev->max_queues = max_queues;
2073 
2074     dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2075     if (!dev->vq) {
2076         DPRINT("%s: failed to malloc virtqueues\n", __func__);
2077         return false;
2078     }
2079 
2080     for (i = 0; i < max_queues; i++) {
2081         dev->vq[i] = (VuVirtq) {
2082             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2083             .notification = true,
2084         };
2085     }
2086 
2087     return true;
2088 }
2089 
2090 VuVirtq *
2091 vu_get_queue(VuDev *dev, int qidx)
2092 {
2093     assert(qidx < dev->max_queues);
2094     return &dev->vq[qidx];
2095 }
2096 
2097 bool
2098 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2099 {
2100     return vq->enable;
2101 }
2102 
2103 bool
2104 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2105 {
2106     return vq->started;
2107 }
2108 
2109 static inline uint16_t
2110 vring_avail_flags(VuVirtq *vq)
2111 {
2112     return le16toh(vq->vring.avail->flags);
2113 }
2114 
2115 static inline uint16_t
2116 vring_avail_idx(VuVirtq *vq)
2117 {
2118     vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2119 
2120     return vq->shadow_avail_idx;
2121 }
2122 
2123 static inline uint16_t
2124 vring_avail_ring(VuVirtq *vq, int i)
2125 {
2126     return le16toh(vq->vring.avail->ring[i]);
2127 }
2128 
2129 static inline uint16_t
2130 vring_get_used_event(VuVirtq *vq)
2131 {
2132     return vring_avail_ring(vq, vq->vring.num);
2133 }
2134 
2135 static int
2136 virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2137 {
2138     uint16_t num_heads = vring_avail_idx(vq) - idx;
2139 
2140     /* Check it isn't doing very strange things with descriptor numbers. */
2141     if (num_heads > vq->vring.num) {
2142         vu_panic(dev, "Guest moved used index from %u to %u",
2143                  idx, vq->shadow_avail_idx);
2144         return -1;
2145     }
2146     if (num_heads) {
2147         /* On success, callers read a descriptor at vq->last_avail_idx.
2148          * Make sure descriptor read does not bypass avail index read. */
2149         smp_rmb();
2150     }
2151 
2152     return num_heads;
2153 }
2154 
2155 static bool
2156 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2157                    unsigned int idx, unsigned int *head)
2158 {
2159     /* Grab the next descriptor number they're advertising, and increment
2160      * the index we've seen. */
2161     *head = vring_avail_ring(vq, idx % vq->vring.num);
2162 
2163     /* If their number is silly, that's a fatal mistake. */
2164     if (*head >= vq->vring.num) {
2165         vu_panic(dev, "Guest says index %u is available", *head);
2166         return false;
2167     }
2168 
2169     return true;
2170 }
2171 
2172 static int
2173 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2174                              uint64_t addr, size_t len)
2175 {
2176     struct vring_desc *ori_desc;
2177     uint64_t read_len;
2178 
2179     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2180         return -1;
2181     }
2182 
2183     if (len == 0) {
2184         return -1;
2185     }
2186 
2187     while (len) {
2188         read_len = len;
2189         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2190         if (!ori_desc) {
2191             return -1;
2192         }
2193 
2194         memcpy(desc, ori_desc, read_len);
2195         len -= read_len;
2196         addr += read_len;
2197         desc = (struct vring_desc *)((char *)desc + read_len);
2198     }
2199 
2200     return 0;
2201 }
2202 
2203 enum {
2204     VIRTQUEUE_READ_DESC_ERROR = -1,
2205     VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
2206     VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
2207 };
2208 
2209 static int
2210 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2211                          int i, unsigned int max, unsigned int *next)
2212 {
2213     /* If this descriptor says it doesn't chain, we're done. */
2214     if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
2215         return VIRTQUEUE_READ_DESC_DONE;
2216     }
2217 
2218     /* Check they're not leading us off end of descriptors. */
2219     *next = le16toh(desc[i].next);
2220     /* Make sure compiler knows to grab that: we don't want it changing! */
2221     smp_wmb();
2222 
2223     if (*next >= max) {
2224         vu_panic(dev, "Desc next is %u", *next);
2225         return VIRTQUEUE_READ_DESC_ERROR;
2226     }
2227 
2228     return VIRTQUEUE_READ_DESC_MORE;
2229 }
2230 
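/*
 * Walk the available ring and report how many bytes of device-writable
 * (in) and device-readable (out) buffer space the guest has queued.  The
 * walk stops early once both max_in_bytes and max_out_bytes are satisfied;
 * on a malformed ring both counts are reported as 0.
 */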
2231 void
2232 vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
2233                          unsigned int *out_bytes,
2234                          unsigned max_in_bytes, unsigned max_out_bytes)
2235 {
2236     unsigned int idx;
2237     unsigned int total_bufs, in_total, out_total;
2238     int rc;
2239 
2240     idx = vq->last_avail_idx;
2241 
2242     total_bufs = in_total = out_total = 0;
2243     if (unlikely(dev->broken) ||
2244         unlikely(!vq->vring.avail)) {
2245         goto done;
2246     }
2247 
2248     while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
2249         unsigned int max, desc_len, num_bufs, indirect = 0;
2250         uint64_t desc_addr, read_len;
2251         struct vring_desc *desc;
2252         struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2253         unsigned int i;
2254 
2255         max = vq->vring.num;
2256         num_bufs = total_bufs;
2257         if (!virtqueue_get_head(dev, vq, idx++, &i)) {
2258             goto err;
2259         }
2260         desc = vq->vring.desc;
2261 
2262         if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2263             if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2264                 vu_panic(dev, "Invalid size for indirect buffer table");
2265                 goto err;
2266             }
2267 
2268             /* If we've got too many, that implies a descriptor loop. */
2269             if (num_bufs >= max) {
2270                 vu_panic(dev, "Looped descriptor");
2271                 goto err;
2272             }
2273 
2274             /* loop over the indirect descriptor table */
2275             indirect = 1;
2276             desc_addr = le64toh(desc[i].addr);
2277             desc_len = le32toh(desc[i].len);
2278             max = desc_len / sizeof(struct vring_desc);
2279             read_len = desc_len;
2280             desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2281             if (unlikely(desc && read_len != desc_len)) {
2282                 /* Failed to use zero copy */
2283                 desc = NULL;
2284                 if (!virtqueue_read_indirect_desc(dev, desc_buf,
2285                                                   desc_addr,
2286                                                   desc_len)) {
2287                     desc = desc_buf;
2288                 }
2289             }
2290             if (!desc) {
2291                 vu_panic(dev, "Invalid indirect buffer table");
2292                 goto err;
2293             }
2294             num_bufs = i = 0;
2295         }
2296 
2297         do {
2298             /* If we've got too many, that implies a descriptor loop. */
2299             if (++num_bufs > max) {
2300                 vu_panic(dev, "Looped descriptor");
2301                 goto err;
2302             }
2303 
2304             if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2305                 in_total += le32toh(desc[i].len);
2306             } else {
2307                 out_total += le32toh(desc[i].len);
2308             }
2309             if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
2310                 goto done;
2311             }
2312             rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2313         } while (rc == VIRTQUEUE_READ_DESC_MORE);
2314 
2315         if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2316             goto err;
2317         }
2318 
2319         if (!indirect) {
2320             total_bufs = num_bufs;
2321         } else {
2322             total_bufs++;
2323         }
2324     }
2325     if (rc < 0) {
2326         goto err;
2327     }
2328 done:
2329     if (in_bytes) {
2330         *in_bytes = in_total;
2331     }
2332     if (out_bytes) {
2333         *out_bytes = out_total;
2334     }
2335     return;
2336 
2337 err:
2338     in_total = out_total = 0;
2339     goto done;
2340 }
2341 
2342 bool
2343 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2344                      unsigned int out_bytes)
2345 {
2346     unsigned int in_total, out_total;
2347 
2348     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2349                              in_bytes, out_bytes);
2350 
2351     return in_bytes <= in_total && out_bytes <= out_total;
2352 }
2353 
2354 /* Fetch avail_idx from VQ memory only when we really need to know if
2355  * guest has added some buffers. */
2356 bool
2357 vu_queue_empty(VuDev *dev, VuVirtq *vq)
2358 {
2359     if (unlikely(dev->broken) ||
2360         unlikely(!vq->vring.avail)) {
2361         return true;
2362     }
2363 
2364     if (vq->shadow_avail_idx != vq->last_avail_idx) {
2365         return false;
2366     }
2367 
2368     return vring_avail_idx(vq) == vq->last_avail_idx;
2369 }
2370 
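/*
 * Decide whether the guest needs a notification, honouring
 * VIRTIO_F_NOTIFY_ON_EMPTY, the legacy VRING_AVAIL_F_NO_INTERRUPT flag
 * and, when VIRTIO_RING_F_EVENT_IDX is negotiated, the used_event index.
 */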
2371 static bool
2372 vring_notify(VuDev *dev, VuVirtq *vq)
2373 {
2374     uint16_t old, new;
2375     bool v;
2376 
2377     /* We need to expose used array entries before checking used event. */
2378     smp_mb();
2379 
2380     /* Always notify when queue is empty (if the feature was acknowledged) */
2381     if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2382         !vq->inuse && vu_queue_empty(dev, vq)) {
2383         return true;
2384     }
2385 
2386     if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2387         return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
2388     }
2389 
2390     v = vq->signalled_used_valid;
2391     vq->signalled_used_valid = true;
2392     old = vq->signalled_used;
2393     new = vq->signalled_used = vq->used_idx;
2394     return !v || vring_need_event(vring_get_used_event(vq), new, old);
2395 }
2396 
2397 static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2398 {
2399     if (unlikely(dev->broken) ||
2400         unlikely(!vq->vring.avail)) {
2401         return;
2402     }
2403 
2404     if (!vring_notify(dev, vq)) {
2405         DPRINT("skipped notify...\n");
2406         return;
2407     }
2408 
2409     if (vq->call_fd < 0 &&
2410         vu_has_protocol_feature(dev,
2411                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2412         vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
2413         VhostUserMsg vmsg = {
2414             .request = VHOST_USER_SLAVE_VRING_CALL,
2415             .flags = VHOST_USER_VERSION,
2416             .size = sizeof(vmsg.payload.state),
2417             .payload.state = {
2418                 .index = vq - dev->vq,
2419             },
2420         };
2421         bool ack = sync &&
2422                    vu_has_protocol_feature(dev,
2423                                            VHOST_USER_PROTOCOL_F_REPLY_ACK);
2424 
2425         if (ack) {
2426             vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2427         }
2428 
2429         vu_message_write(dev, dev->slave_fd, &vmsg);
2430         if (ack) {
2431             vu_message_read_default(dev, dev->slave_fd, &vmsg);
2432         }
2433         return;
2434     }
2435 
2436     if (eventfd_write(vq->call_fd, 1) < 0) {
2437         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2438     }
2439 }
2440 
2441 void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2442 {
2443     _vu_queue_notify(dev, vq, false);
2444 }
2445 
2446 void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2447 {
2448     _vu_queue_notify(dev, vq, true);
2449 }
2450 
2451 static inline void
2452 vring_used_flags_set_bit(VuVirtq *vq, int mask)
2453 {
2454     uint16_t *flags;
2455 
2456     flags = (uint16_t *)((char *)vq->vring.used +
2457                          offsetof(struct vring_used, flags));
2458     *flags = htole16(le16toh(*flags) | mask);
2459 }
2460 
2461 static inline void
2462 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2463 {
2464     uint16_t *flags;
2465 
2466     flags = (uint16_t *)((char *)vq->vring.used +
2467                          offsetof(struct vring_used, flags));
2468     *flags = htole16(le16toh(*flags) & ~mask);
2469 }
2470 
2471 static inline void
2472 vring_set_avail_event(VuVirtq *vq, uint16_t val)
2473 {
2474     uint16_t *avail;
2475 
2476     if (!vq->notification) {
2477         return;
2478     }
2479 
2480     avail = (uint16_t *)&vq->vring.used->ring[vq->vring.num];
2481     *avail = htole16(val);
2482 }
2483 
2484 void
2485 vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2486 {
2487     vq->notification = enable;
2488     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2489         vring_set_avail_event(vq, vring_avail_idx(vq));
2490     } else if (enable) {
2491         vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2492     } else {
2493         vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2494     }
2495     if (enable) {
2496         /* Expose avail event/used flags before caller checks the avail idx. */
2497         smp_mb();
2498     }
2499 }
2500 
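/*
 * Translate a guest-physical buffer into iovec entries, splitting it
 * whenever it crosses a memory region boundary.  Returns false (and marks
 * the device broken) on zero-sized buffers, unmappable addresses or iovec
 * overflow.
 */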
2501 static bool
2502 virtqueue_map_desc(VuDev *dev,
2503                    unsigned int *p_num_sg, struct iovec *iov,
2504                    unsigned int max_num_sg, bool is_write,
2505                    uint64_t pa, size_t sz)
2506 {
2507     unsigned num_sg = *p_num_sg;
2508 
2509     assert(num_sg <= max_num_sg);
2510 
2511     if (!sz) {
2512         vu_panic(dev, "virtio: zero sized buffers are not allowed");
2513         return false;
2514     }
2515 
2516     while (sz) {
2517         uint64_t len = sz;
2518 
2519         if (num_sg == max_num_sg) {
2520             vu_panic(dev, "virtio: too many descriptors in indirect table");
2521             return false;
2522         }
2523 
2524         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2525         if (iov[num_sg].iov_base == NULL) {
2526             vu_panic(dev, "virtio: invalid address for buffers");
2527             return false;
2528         }
2529         iov[num_sg].iov_len = len;
2530         num_sg++;
2531         sz -= len;
2532         pa += len;
2533     }
2534 
2535     *p_num_sg = num_sg;
2536     return true;
2537 }
2538 
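/*
 * Allocate a VuVirtqElement together with its in/out scatter-gather arrays
 * in a single block; @sz lets callers embed the element at the start of a
 * larger, device-specific structure.
 */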
2539 static void *
2540 virtqueue_alloc_element(size_t sz,
2541                         unsigned out_num, unsigned in_num)
2542 {
2543     VuVirtqElement *elem;
2544     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2545     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2546     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2547 
2548     assert(sz >= sizeof(VuVirtqElement));
2549     elem = malloc(out_sg_end);
    if (!elem) {
        DPRINT("%s: failed to malloc virtqueue element\n", __func__);
        return NULL;
    }
2550     elem->out_num = out_num;
2551     elem->in_num = in_num;
2552     elem->in_sg = (void *)elem + in_sg_ofs;
2553     elem->out_sg = (void *)elem + out_sg_ofs;
2554     return elem;
2555 }
2556 
2557 static void *
2558 vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2559 {
2560     struct vring_desc *desc = vq->vring.desc;
2561     uint64_t desc_addr, read_len;
2562     unsigned int desc_len;
2563     unsigned int max = vq->vring.num;
2564     unsigned int i = idx;
2565     VuVirtqElement *elem;
2566     unsigned int out_num = 0, in_num = 0;
2567     struct iovec iov[VIRTQUEUE_MAX_SIZE];
2568     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2569     int rc;
2570 
2571     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2572         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2573             vu_panic(dev, "Invalid size for indirect buffer table");
2574             return NULL;
2575         }
2576 
2577         /* loop over the indirect descriptor table */
2578         desc_addr = le64toh(desc[i].addr);
2579         desc_len = le32toh(desc[i].len);
2580         max = desc_len / sizeof(struct vring_desc);
2581         read_len = desc_len;
2582         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2583         if (unlikely(desc && read_len != desc_len)) {
2584             /* Failed to use zero copy */
2585             desc = NULL;
2586             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2587                                               desc_addr,
2588                                               desc_len)) {
2589                 desc = desc_buf;
2590             }
2591         }
2592         if (!desc) {
2593             vu_panic(dev, "Invalid indirect buffer table");
2594             return NULL;
2595         }
2596         i = 0;
2597     }
2598 
2599     /* Collect all the descriptors */
2600     do {
2601         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2602             if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2603                                VIRTQUEUE_MAX_SIZE - out_num, true,
2604                                le64toh(desc[i].addr),
2605                                le32toh(desc[i].len))) {
2606                 return NULL;
2607             }
2608         } else {
2609             if (in_num) {
2610                 vu_panic(dev, "Incorrect order for descriptors");
2611                 return NULL;
2612             }
2613             if (!virtqueue_map_desc(dev, &out_num, iov,
2614                                VIRTQUEUE_MAX_SIZE, false,
2615                                le64toh(desc[i].addr),
2616                                le32toh(desc[i].len))) {
2617                 return NULL;
2618             }
2619         }
2620 
2621         /* If we've got too many, that implies a descriptor loop. */
2622         if ((in_num + out_num) > max) {
2623             vu_panic(dev, "Looped descriptor");
2624             return NULL;
2625         }
2626         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2627     } while (rc == VIRTQUEUE_READ_DESC_MORE);
2628 
2629     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2630         vu_panic(dev, "read descriptor error");
2631         return NULL;
2632     }
2633 
2634     /* Now copy what we have collected and mapped */
2635     elem = virtqueue_alloc_element(sz, out_num, in_num);
    if (!elem) {
        return NULL;
    }
2636     elem->index = idx;
2637     for (i = 0; i < out_num; i++) {
2638         elem->out_sg[i] = iov[i];
2639     }
2640     for (i = 0; i < in_num; i++) {
2641         elem->in_sg[i] = iov[out_num + i];
2642     }
2643 
2644     return elem;
2645 }
2646 
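/*
 * Inflight bookkeeping, active only when
 * VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD has been negotiated: mark a
 * descriptor as in flight when it is popped, record it as the batch head
 * before the used index is published, and clear it afterwards so that a
 * restarted backend knows which requests to resubmit.
 */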
2647 static int
2648 vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2649 {
2650     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2651         return 0;
2652     }
2653 
2654     if (unlikely(!vq->inflight)) {
2655         return -1;
2656     }
2657 
2658     vq->inflight->desc[desc_idx].counter = vq->counter++;
2659     vq->inflight->desc[desc_idx].inflight = 1;
2660 
2661     return 0;
2662 }
2663 
2664 static int
2665 vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2666 {
2667     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2668         return 0;
2669     }
2670 
2671     if (unlikely(!vq->inflight)) {
2672         return -1;
2673     }
2674 
2675     vq->inflight->last_batch_head = desc_idx;
2676 
2677     return 0;
2678 }
2679 
2680 static int
2681 vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2682 {
2683     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2684         return 0;
2685     }
2686 
2687     if (unlikely(!vq->inflight)) {
2688         return -1;
2689     }
2690 
2691     barrier();
2692 
2693     vq->inflight->desc[desc_idx].inflight = 0;
2694 
2695     barrier();
2696 
2697     vq->inflight->used_idx = vq->used_idx;
2698 
2699     return 0;
2700 }
2701 
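/*
 * Pop the next available element, draining any resubmit entries left over
 * from a previous backend incarnation first.  Returns a buffer of at least
 * @sz bytes that starts with a VuVirtqElement and must be freed by the
 * caller, or NULL if the ring is empty or the device is broken.
 *
 * The usual processing loop is roughly (a sketch; the length handling is
 * device specific):
 *
 *     while ((elem = vu_queue_pop(dev, vq, sizeof(VuVirtqElement)))) {
 *         // read elem->out_sg, fill elem->in_sg, set len to bytes written
 *         vu_queue_push(dev, vq, elem, len);
 *         vu_queue_notify(dev, vq);
 *         free(elem);
 *     }
 */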
2702 void *
2703 vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
2704 {
2705     int i;
2706     unsigned int head;
2707     VuVirtqElement *elem;
2708 
2709     if (unlikely(dev->broken) ||
2710         unlikely(!vq->vring.avail)) {
2711         return NULL;
2712     }
2713 
2714     if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
2715         i = (--vq->resubmit_num);
2716         elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
2717 
2718         if (!vq->resubmit_num) {
2719             free(vq->resubmit_list);
2720             vq->resubmit_list = NULL;
2721         }
2722 
2723         return elem;
2724     }
2725 
2726     if (vu_queue_empty(dev, vq)) {
2727         return NULL;
2728     }
2729     /*
2730      * Needed after vu_queue_empty(), see comment in
2731      * virtqueue_num_heads().
2732      */
2733     smp_rmb();
2734 
2735     if (vq->inuse >= vq->vring.num) {
2736         vu_panic(dev, "Virtqueue size exceeded");
2737         return NULL;
2738     }
2739 
2740     if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
2741         return NULL;
2742     }
2743 
2744     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2745         vring_set_avail_event(vq, vq->last_avail_idx);
2746     }
2747 
2748     elem = vu_queue_map_desc(dev, vq, head, sz);
2749 
2750     if (!elem) {
2751         return NULL;
2752     }
2753 
2754     vq->inuse++;
2755 
2756     vu_queue_inflight_get(dev, vq, head);
2757 
2758     return elem;
2759 }
2760 
2761 static void
2762 vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2763                         size_t len)
2764 {
2765     vq->inuse--;
2766     /* unmap, when DMA support is added */
2767 }
2768 
2769 void
2770 vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2771                size_t len)
2772 {
2773     vq->last_avail_idx--;
2774     vu_queue_detach_element(dev, vq, elem, len);
2775 }
2776 
2777 bool
2778 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2779 {
2780     if (num > vq->inuse) {
2781         return false;
2782     }
2783     vq->last_avail_idx -= num;
2784     vq->inuse -= num;
2785     return true;
2786 }
2787 
2788 static inline
2789 void vring_used_write(VuDev *dev, VuVirtq *vq,
2790                       struct vring_used_elem *uelem, int i)
2791 {
2792     struct vring_used *used = vq->vring.used;
2793 
2794     used->ring[i] = *uelem;
2795     vu_log_write(dev, vq->vring.log_guest_addr +
2796                  offsetof(struct vring_used, ring[i]),
2797                  sizeof(used->ring[i]));
2798 }
2799 
2801 static void
2802 vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
2803                   const VuVirtqElement *elem,
2804                   unsigned int len)
2805 {
2806     struct vring_desc *desc = vq->vring.desc;
2807     unsigned int i, max, min, desc_len;
2808     uint64_t desc_addr, read_len;
2809     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2810     unsigned num_bufs = 0;
2811 
2812     max = vq->vring.num;
2813     i = elem->index;
2814 
2815     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2816         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2817             vu_panic(dev, "Invalid size for indirect buffer table");
2818             return;
2819         }
2820 
2821         /* loop over the indirect descriptor table */
2822         desc_addr = le64toh(desc[i].addr);
2823         desc_len = le32toh(desc[i].len);
2824         max = desc_len / sizeof(struct vring_desc);
2825         read_len = desc_len;
2826         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2827         if (unlikely(desc && read_len != desc_len)) {
2828             /* Failed to use zero copy */
2829             desc = NULL;
2830             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2831                                               desc_addr,
2832                                               desc_len)) {
2833                 desc = desc_buf;
2834             }
2835         }
2836         if (!desc) {
2837             vu_panic(dev, "Invalid indirect buffer table");
2838             return;
2839         }
2840         i = 0;
2841     }
2842 
2843     do {
2844         if (++num_bufs > max) {
2845             vu_panic(dev, "Looped descriptor");
2846             return;
2847         }
2848 
2849         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2850             min = MIN(le32toh(desc[i].len), len);
2851             vu_log_write(dev, le64toh(desc[i].addr), min);
2852             len -= min;
2853         }
2854 
2855     } while (len > 0 &&
2856              (virtqueue_read_next_desc(dev, desc, i, max, &i)
2857               == VIRTQUEUE_READ_DESC_MORE));
2858 }
2859 
2860 void
2861 vu_queue_fill(VuDev *dev, VuVirtq *vq,
2862               const VuVirtqElement *elem,
2863               unsigned int len, unsigned int idx)
2864 {
2865     struct vring_used_elem uelem;
2866 
2867     if (unlikely(dev->broken) ||
2868         unlikely(!vq->vring.avail)) {
2869         return;
2870     }
2871 
2872     vu_log_queue_fill(dev, vq, elem, len);
2873 
2874     idx = (idx + vq->used_idx) % vq->vring.num;
2875 
2876     uelem.id = htole32(elem->index);
2877     uelem.len = htole32(len);
2878     vring_used_write(dev, vq, &uelem, idx);
2879 }
2880 
2881 static inline
2882 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2883 {
2884     vq->vring.used->idx = htole16(val);
2885     vu_log_write(dev,
2886                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2887                  sizeof(vq->vring.used->idx));
2888 
2889     vq->used_idx = val;
2890 }
2891 
2892 void
2893 vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2894 {
2895     uint16_t old, new;
2896 
2897     if (unlikely(dev->broken) ||
2898         unlikely(!vq->vring.avail)) {
2899         return;
2900     }
2901 
2902     /* Make sure buffer is written before we update index. */
2903     smp_wmb();
2904 
2905     old = vq->used_idx;
2906     new = old + count;
2907     vring_used_idx_set(dev, vq, new);
2908     vq->inuse -= count;
2909     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2910         vq->signalled_used_valid = false;
2911     }
2912 }
2913 
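/*
 * Publish @elem as used, with @len bytes written into its in_sg buffers,
 * wrapping the visible used-index update with the inflight pre/post hooks
 * so the operation can be replayed consistently after a crash.
 */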
2914 void
2915 vu_queue_push(VuDev *dev, VuVirtq *vq,
2916               const VuVirtqElement *elem, unsigned int len)
2917 {
2918     vu_queue_fill(dev, vq, elem, len, 0);
2919     vu_queue_inflight_pre_put(dev, vq, elem->index);
2920     vu_queue_flush(dev, vq, 1);
2921     vu_queue_inflight_post_put(dev, vq, elem->index);
2922 }
2923