1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <aliguori@us.ibm.com>
9  *  Marc-André Lureau <mlureau@redhat.com>
10  *  Victor Kaplansky <victork@redhat.com>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15 
/* This code avoids a GLib dependency. */
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <stdarg.h>
21 #include <errno.h>
22 #include <string.h>
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/eventfd.h>
28 #include <sys/mman.h>
29 #include <endian.h>
30 
31 #if defined(__linux__)
32 #include <sys/syscall.h>
33 #include <fcntl.h>
34 #include <sys/ioctl.h>
35 #include <linux/vhost.h>
36 
37 #ifdef __NR_userfaultfd
38 #include <linux/userfaultfd.h>
39 #endif
40 
41 #endif
42 
43 #include "include/atomic.h"
44 
45 #include "libvhost-user.h"
46 
47 /* usually provided by GLib */
48 #ifndef MIN
49 #define MIN(x, y) ({                            \
50             typeof(x) _min1 = (x);              \
51             typeof(y) _min2 = (y);              \
52             (void) (&_min1 == &_min2);          \
53             _min1 < _min2 ? _min1 : _min2; })
54 #endif
55 
56 /* Round number down to multiple */
57 #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
58 
59 /* Round number up to multiple */
60 #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
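/*
 * For example: ALIGN_DOWN(5, 4) == 4 and ALIGN_UP(5, 4) == 8, while values
 * that are already a multiple are returned unchanged: ALIGN_UP(8, 4) == 8.
 */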
61 
62 #ifndef unlikely
63 #define unlikely(x)   __builtin_expect(!!(x), 0)
64 #endif
65 
/* Align each region in the inflight buffer to the cache-line size */
67 #define INFLIGHT_ALIGNMENT 64
68 
/* The version of the inflight buffer */
70 #define INFLIGHT_VERSION 1
71 
72 /* The version of the protocol we support */
73 #define VHOST_USER_VERSION 1
74 #define LIBVHOST_USER_DEBUG 0
75 
76 #define DPRINT(...)                             \
77     do {                                        \
78         if (LIBVHOST_USER_DEBUG) {              \
79             fprintf(stderr, __VA_ARGS__);        \
80         }                                       \
81     } while (0)
82 
83 static inline
84 bool has_feature(uint64_t features, unsigned int fbit)
85 {
86     assert(fbit < 64);
87     return !!(features & (1ULL << fbit));
88 }
89 
90 static inline
91 bool vu_has_feature(VuDev *dev,
92                     unsigned int fbit)
93 {
94     return has_feature(dev->features, fbit);
95 }
96 
97 static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
98 {
99     return has_feature(dev->protocol_features, fbit);
100 }
101 
102 const char *
103 vu_request_to_string(unsigned int req)
104 {
105 #define REQ(req) [req] = #req
106     static const char *vu_request_str[] = {
107         REQ(VHOST_USER_NONE),
108         REQ(VHOST_USER_GET_FEATURES),
109         REQ(VHOST_USER_SET_FEATURES),
110         REQ(VHOST_USER_SET_OWNER),
111         REQ(VHOST_USER_RESET_OWNER),
112         REQ(VHOST_USER_SET_MEM_TABLE),
113         REQ(VHOST_USER_SET_LOG_BASE),
114         REQ(VHOST_USER_SET_LOG_FD),
115         REQ(VHOST_USER_SET_VRING_NUM),
116         REQ(VHOST_USER_SET_VRING_ADDR),
117         REQ(VHOST_USER_SET_VRING_BASE),
118         REQ(VHOST_USER_GET_VRING_BASE),
119         REQ(VHOST_USER_SET_VRING_KICK),
120         REQ(VHOST_USER_SET_VRING_CALL),
121         REQ(VHOST_USER_SET_VRING_ERR),
122         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
123         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
124         REQ(VHOST_USER_GET_QUEUE_NUM),
125         REQ(VHOST_USER_SET_VRING_ENABLE),
126         REQ(VHOST_USER_SEND_RARP),
127         REQ(VHOST_USER_NET_SET_MTU),
128         REQ(VHOST_USER_SET_SLAVE_REQ_FD),
129         REQ(VHOST_USER_IOTLB_MSG),
130         REQ(VHOST_USER_SET_VRING_ENDIAN),
131         REQ(VHOST_USER_GET_CONFIG),
132         REQ(VHOST_USER_SET_CONFIG),
133         REQ(VHOST_USER_POSTCOPY_ADVISE),
134         REQ(VHOST_USER_POSTCOPY_LISTEN),
135         REQ(VHOST_USER_POSTCOPY_END),
136         REQ(VHOST_USER_GET_INFLIGHT_FD),
137         REQ(VHOST_USER_SET_INFLIGHT_FD),
138         REQ(VHOST_USER_GPU_SET_SOCKET),
139         REQ(VHOST_USER_VRING_KICK),
140         REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
141         REQ(VHOST_USER_ADD_MEM_REG),
142         REQ(VHOST_USER_REM_MEM_REG),
143         REQ(VHOST_USER_MAX),
144     };
145 #undef REQ
146 
147     if (req < VHOST_USER_MAX) {
148         return vu_request_str[req];
149     } else {
150         return "unknown";
151     }
152 }
153 
154 static void
155 vu_panic(VuDev *dev, const char *msg, ...)
156 {
157     char *buf = NULL;
158     va_list ap;
159 
160     va_start(ap, msg);
161     if (vasprintf(&buf, msg, ap) < 0) {
162         buf = NULL;
163     }
164     va_end(ap);
165 
166     dev->broken = true;
167     dev->panic(dev, buf);
168     free(buf);
169 
170     /*
171      * FIXME:
172      * find a way to call virtio_error, or perhaps close the connection?
173      */
174 }
175 
176 /* Translate guest physical address to our virtual address.  */
177 void *
178 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
179 {
180     int i;
181 
182     if (*plen == 0) {
183         return NULL;
184     }
185 
186     /* Find matching memory region.  */
187     for (i = 0; i < dev->nregions; i++) {
188         VuDevRegion *r = &dev->regions[i];
189 
190         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
191             if ((guest_addr + *plen) > (r->gpa + r->size)) {
192                 *plen = r->gpa + r->size - guest_addr;
193             }
194             return (void *)(uintptr_t)
195                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
196         }
197     }
198 
199     return NULL;
200 }
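
/*
 * Illustrative sketch (an assumption about typical usage, not part of the
 * library API beyond what the function above documents): because a buffer
 * may span several memory regions, callers usually translate it in pieces,
 * letting *plen clamp each chunk to the current region:
 *
 *     uint64_t len = buf_len;             // hypothetical buffer length
 *     uint64_t gpa = buf_gpa;             // hypothetical guest address
 *     while (len) {
 *         uint64_t chunk = len;
 *         void *va = vu_gpa_to_va(dev, &chunk, gpa);
 *         if (!va) {
 *             break;                      // not covered by any region
 *         }
 *         ... process 'chunk' bytes at 'va' ...
 *         gpa += chunk;
 *         len -= chunk;
 *     }
 */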
201 
202 /* Translate qemu virtual address to our virtual address.  */
203 static void *
204 qva_to_va(VuDev *dev, uint64_t qemu_addr)
205 {
206     int i;
207 
208     /* Find matching memory region.  */
209     for (i = 0; i < dev->nregions; i++) {
210         VuDevRegion *r = &dev->regions[i];
211 
212         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
213             return (void *)(uintptr_t)
214                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
215         }
216     }
217 
218     return NULL;
219 }
220 
221 static void
222 vmsg_close_fds(VhostUserMsg *vmsg)
223 {
224     int i;
225 
226     for (i = 0; i < vmsg->fd_num; i++) {
227         close(vmsg->fds[i]);
228     }
229 }
230 
231 /* Set reply payload.u64 and clear request flags and fd_num */
232 static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
233 {
234     vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
235     vmsg->size = sizeof(vmsg->payload.u64);
236     vmsg->payload.u64 = val;
237     vmsg->fd_num = 0;
238 }
239 
240 /* A test to see if we have userfault available */
241 static bool
242 have_userfault(void)
243 {
244 #if defined(__linux__) && defined(__NR_userfaultfd) &&\
245         defined(UFFD_FEATURE_MISSING_SHMEM) &&\
246         defined(UFFD_FEATURE_MISSING_HUGETLBFS)
    /* Now test that the kernel we're running on really has the features */
248     int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
249     struct uffdio_api api_struct;
250     if (ufd < 0) {
251         return false;
252     }
253 
254     api_struct.api = UFFD_API;
255     api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
256                           UFFD_FEATURE_MISSING_HUGETLBFS;
257     if (ioctl(ufd, UFFDIO_API, &api_struct)) {
258         close(ufd);
259         return false;
260     }
261     close(ufd);
262     return true;
263 
264 #else
265     return false;
266 #endif
267 }
268 
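/*
 * Read a vhost-user message from conn_fd: the fixed-size header arrives
 * first (with any passed file descriptors attached as SCM_RIGHTS ancillary
 * data), followed by vmsg->size bytes of payload.  Returns false and marks
 * the device broken on error.
 */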
269 static bool
270 vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
271 {
272     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
273     struct iovec iov = {
274         .iov_base = (char *)vmsg,
275         .iov_len = VHOST_USER_HDR_SIZE,
276     };
277     struct msghdr msg = {
278         .msg_iov = &iov,
279         .msg_iovlen = 1,
280         .msg_control = control,
281         .msg_controllen = sizeof(control),
282     };
283     size_t fd_size;
284     struct cmsghdr *cmsg;
285     int rc;
286 
287     do {
288         rc = recvmsg(conn_fd, &msg, 0);
289     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
290 
291     if (rc < 0) {
292         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
293         return false;
294     }
295 
296     vmsg->fd_num = 0;
297     for (cmsg = CMSG_FIRSTHDR(&msg);
298          cmsg != NULL;
299          cmsg = CMSG_NXTHDR(&msg, cmsg))
300     {
301         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
302             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
303             vmsg->fd_num = fd_size / sizeof(int);
304             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
305             break;
306         }
307     }
308 
309     if (vmsg->size > sizeof(vmsg->payload)) {
310         vu_panic(dev,
                 "Error: oversized message: request %d, vmsg->size %u "
                 "exceeds sizeof(vmsg->payload) %zu\n",
313                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
314         goto fail;
315     }
316 
317     if (vmsg->size) {
318         do {
319             rc = read(conn_fd, &vmsg->payload, vmsg->size);
320         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
321 
322         if (rc <= 0) {
323             vu_panic(dev, "Error while reading: %s", strerror(errno));
324             goto fail;
325         }
326 
327         assert(rc == vmsg->size);
328     }
329 
330     return true;
331 
332 fail:
333     vmsg_close_fds(vmsg);
334 
335     return false;
336 }
337 
338 static bool
339 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
340 {
341     int rc;
342     uint8_t *p = (uint8_t *)vmsg;
343     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
344     struct iovec iov = {
345         .iov_base = (char *)vmsg,
346         .iov_len = VHOST_USER_HDR_SIZE,
347     };
348     struct msghdr msg = {
349         .msg_iov = &iov,
350         .msg_iovlen = 1,
351         .msg_control = control,
352     };
353     struct cmsghdr *cmsg;
354 
355     memset(control, 0, sizeof(control));
356     assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
357     if (vmsg->fd_num > 0) {
358         size_t fdsize = vmsg->fd_num * sizeof(int);
359         msg.msg_controllen = CMSG_SPACE(fdsize);
360         cmsg = CMSG_FIRSTHDR(&msg);
361         cmsg->cmsg_len = CMSG_LEN(fdsize);
362         cmsg->cmsg_level = SOL_SOCKET;
363         cmsg->cmsg_type = SCM_RIGHTS;
364         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
365     } else {
366         msg.msg_controllen = 0;
367     }
368 
369     do {
370         rc = sendmsg(conn_fd, &msg, 0);
371     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
372 
373     if (vmsg->size) {
374         do {
375             if (vmsg->data) {
376                 rc = write(conn_fd, vmsg->data, vmsg->size);
377             } else {
378                 rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
379             }
380         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
381     }
382 
383     if (rc <= 0) {
384         vu_panic(dev, "Error while writing: %s", strerror(errno));
385         return false;
386     }
387 
388     return true;
389 }
390 
391 static bool
392 vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
393 {
394     /* Set the version in the flags when sending the reply */
395     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
396     vmsg->flags |= VHOST_USER_VERSION;
397     vmsg->flags |= VHOST_USER_REPLY_MASK;
398 
399     return vu_message_write(dev, conn_fd, vmsg);
400 }
401 
402 /*
403  * Processes a reply on the slave channel.
404  * Entered with slave_mutex held and releases it before exit.
405  * Returns true on success.
406  */
407 static bool
408 vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
409 {
410     VhostUserMsg msg_reply;
411     bool result = false;
412 
413     if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
414         result = true;
415         goto out;
416     }
417 
418     if (!vu_message_read_default(dev, dev->slave_fd, &msg_reply)) {
419         goto out;
420     }
421 
422     if (msg_reply.request != vmsg->request) {
423         DPRINT("Received unexpected msg type. Expected %d received %d",
424                vmsg->request, msg_reply.request);
425         goto out;
426     }
427 
428     result = msg_reply.payload.u64 == 0;
429 
430 out:
431     pthread_mutex_unlock(&dev->slave_mutex);
432     return result;
433 }
434 
435 /* Kick the log_call_fd if required. */
436 static void
437 vu_log_kick(VuDev *dev)
438 {
439     if (dev->log_call_fd != -1) {
        DPRINT("Kicking QEMU's log...\n");
441         if (eventfd_write(dev->log_call_fd, 1) < 0) {
442             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
443         }
444     }
445 }
446 
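/*
 * The dirty log is a bitmap shared with QEMU: one bit per VHOST_LOG_PAGE
 * sized page of memory, so page N maps to bit (N % 8) of byte (N / 8).
 * For example, marking pages 8..15 dirty sets all eight bits of
 * log_table[1].
 */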
447 static void
448 vu_log_page(uint8_t *log_table, uint64_t page)
449 {
    DPRINT("Logged dirty guest page: %"PRIu64"\n", page);
451     qatomic_or(&log_table[page / 8], 1 << (page % 8));
452 }
453 
454 static void
455 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
456 {
457     uint64_t page;
458 
459     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
460         !dev->log_table || !length) {
461         return;
462     }
463 
464     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
465 
466     page = address / VHOST_LOG_PAGE;
467     while (page * VHOST_LOG_PAGE < address + length) {
468         vu_log_page(dev->log_table, page);
469         page += 1;
470     }
471 
472     vu_log_kick(dev);
473 }
474 
475 static void
476 vu_kick_cb(VuDev *dev, int condition, void *data)
477 {
478     int index = (intptr_t)data;
479     VuVirtq *vq = &dev->vq[index];
480     int sock = vq->kick_fd;
481     eventfd_t kick_data;
482     ssize_t rc;
483 
484     rc = eventfd_read(sock, &kick_data);
485     if (rc == -1) {
486         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
487         dev->remove_watch(dev, dev->vq[index].kick_fd);
488     } else {
489         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
490                kick_data, vq->handler, index);
491         if (vq->handler) {
492             vq->handler(dev, index);
493         }
494     }
495 }
496 
497 static bool
498 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
499 {
500     vmsg->payload.u64 =
501         /*
502          * The following VIRTIO feature bits are supported by our virtqueue
503          * implementation:
504          */
505         1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
506         1ULL << VIRTIO_RING_F_INDIRECT_DESC |
507         1ULL << VIRTIO_RING_F_EVENT_IDX |
508         1ULL << VIRTIO_F_VERSION_1 |
509 
510         /* vhost-user feature bits */
511         1ULL << VHOST_F_LOG_ALL |
512         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
513 
514     if (dev->iface->get_features) {
515         vmsg->payload.u64 |= dev->iface->get_features(dev);
516     }
517 
518     vmsg->size = sizeof(vmsg->payload.u64);
519     vmsg->fd_num = 0;
520 
521     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
522 
523     return true;
524 }
525 
526 static void
527 vu_set_enable_all_rings(VuDev *dev, bool enabled)
528 {
529     uint16_t i;
530 
531     for (i = 0; i < dev->max_queues; i++) {
532         dev->vq[i].enable = enabled;
533     }
534 }
535 
536 static bool
537 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
538 {
539     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
540 
541     dev->features = vmsg->payload.u64;
542     if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
543         /*
544          * We only support devices conforming to VIRTIO 1.0 or
545          * later
546          */
547         vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
548         return false;
549     }
550 
    if (!vu_has_feature(dev, VHOST_USER_F_PROTOCOL_FEATURES)) {
552         vu_set_enable_all_rings(dev, true);
553     }
554 
555     if (dev->iface->set_features) {
556         dev->iface->set_features(dev, dev->features);
557     }
558 
559     return false;
560 }
561 
562 static bool
563 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
564 {
565     return false;
566 }
567 
568 static void
569 vu_close_log(VuDev *dev)
570 {
571     if (dev->log_table) {
572         if (munmap(dev->log_table, dev->log_size) != 0) {
573             perror("close log munmap() error");
574         }
575 
576         dev->log_table = NULL;
577     }
578     if (dev->log_call_fd != -1) {
579         close(dev->log_call_fd);
580         dev->log_call_fd = -1;
581     }
582 }
583 
584 static bool
585 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
586 {
587     vu_set_enable_all_rings(dev, false);
588 
589     return false;
590 }
591 
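/*
 * Translate the vring addresses supplied by the master into our own
 * virtual addresses.  Note the inverted return convention used by the
 * callers: true means failure (one of the desc/used/avail addresses could
 * not be translated).
 */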
592 static bool
593 map_ring(VuDev *dev, VuVirtq *vq)
594 {
595     vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
596     vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
597     vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
598 
599     DPRINT("Setting virtq addresses:\n");
600     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
601     DPRINT("    vring_used  at %p\n", vq->vring.used);
602     DPRINT("    vring_avail at %p\n", vq->vring.avail);
603 
604     return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
605 }
606 
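/*
 * Register every memory region with the postcopy userfaultfd, so that
 * guest accesses to pages that have not been migrated yet fault and can be
 * resolved by QEMU, which holds the other end of the userfaultfd.
 */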
607 static bool
generate_faults(VuDev *dev)
{
609     int i;
610     for (i = 0; i < dev->nregions; i++) {
611         VuDevRegion *dev_region = &dev->regions[i];
612         int ret;
613 #ifdef UFFDIO_REGISTER
        /*
         * We should already have an open ufd. Register each memory
         * range with it.
         * Discard any mapping we have here; note we can't use MADV_REMOVE
         * or fallocate to punch the hole, since we don't want to lose
         * data that has already arrived in the shared process.
         * TODO: How to handle hugepages?
         */
622         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
623                       dev_region->size + dev_region->mmap_offset,
624                       MADV_DONTNEED);
625         if (ret) {
626             fprintf(stderr,
627                     "%s: Failed to madvise(DONTNEED) region %d: %s\n",
628                     __func__, i, strerror(errno));
629         }
        /*
         * Turn off transparent hugepages so we don't lose wakeups
         * in neighbouring pages.
         * TODO: Turn this back on later.
         */
635         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
636                       dev_region->size + dev_region->mmap_offset,
637                       MADV_NOHUGEPAGE);
638         if (ret) {
639             /*
640              * Note: This can happen legally on kernels that are configured
641              * without madvise'able hugepages
642              */
643             fprintf(stderr,
644                     "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
645                     __func__, i, strerror(errno));
646         }
647         struct uffdio_register reg_struct;
648         reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
649         reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
650         reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
651 
652         if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
653             vu_panic(dev, "%s: Failed to userfault region %d "
654                           "@%p + size:%zx offset: %zx: (ufd=%d)%s\n",
655                      __func__, i,
656                      dev_region->mmap_addr,
657                      dev_region->size, dev_region->mmap_offset,
658                      dev->postcopy_ufd, strerror(errno));
659             return false;
660         }
661         if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
662             vu_panic(dev, "%s Region (%d) doesn't support COPY",
663                      __func__, i);
664             return false;
665         }
666         DPRINT("%s: region %d: Registered userfault for %"
667                PRIx64 " + %" PRIx64 "\n", __func__, i,
668                (uint64_t)reg_struct.range.start,
669                (uint64_t)reg_struct.range.len);
        /* Now that it's registered, we can let the client at it */
671         if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
672                      dev_region->size + dev_region->mmap_offset,
673                      PROT_READ | PROT_WRITE)) {
674             vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
675                      i, strerror(errno));
676             return false;
677         }
678         /* TODO: Stash 'zero' support flags somewhere */
679 #endif
680     }
681 
682     return true;
683 }
684 
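/*
 * VHOST_USER_ADD_MEM_REG: map one additional memory region.  This is the
 * incremental counterpart of VHOST_USER_SET_MEM_TABLE, used once the
 * VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS protocol feature has been
 * negotiated.
 */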
685 static bool
vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg)
{
687     int i;
688     bool track_ramblocks = dev->postcopy_listening;
689     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
690     VuDevRegion *dev_region = &dev->regions[dev->nregions];
691     void *mmap_addr;
692 
693     if (vmsg->fd_num != 1) {
694         vmsg_close_fds(vmsg);
695         vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
696                       "should be sent for this message type", vmsg->fd_num);
697         return false;
698     }
699 
700     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
701         close(vmsg->fds[0]);
702         vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
703                       "least %d bytes and only %d bytes were received",
704                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
705         return false;
706     }
707 
708     if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
709         close(vmsg->fds[0]);
710         vu_panic(dev, "failing attempt to hot add memory via "
711                       "VHOST_USER_ADD_MEM_REG message because the backend has "
712                       "no free ram slots available");
713         return false;
714     }
715 
716     /*
717      * If we are in postcopy mode and we receive a u64 payload with a 0 value
718      * we know all the postcopy client bases have been received, and we
719      * should start generating faults.
720      */
721     if (track_ramblocks &&
722         vmsg->size == sizeof(vmsg->payload.u64) &&
723         vmsg->payload.u64 == 0) {
724         (void)generate_faults(dev);
725         return false;
726     }
727 
728     DPRINT("Adding region: %u\n", dev->nregions);
729     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
730            msg_region->guest_phys_addr);
731     DPRINT("    memory_size:     0x%016"PRIx64"\n",
732            msg_region->memory_size);
733     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
734            msg_region->userspace_addr);
735     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
736            msg_region->mmap_offset);
737 
738     dev_region->gpa = msg_region->guest_phys_addr;
739     dev_region->size = msg_region->memory_size;
740     dev_region->qva = msg_region->userspace_addr;
741     dev_region->mmap_offset = msg_region->mmap_offset;
742 
    /*
     * We don't use the offset argument of mmap() since the
     * mapped address has to be page aligned, and we use huge
     * pages.
     */
748     if (track_ramblocks) {
749         /*
750          * In postcopy we're using PROT_NONE here to catch anyone
751          * accessing it before we userfault.
752          */
753         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
754                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
755                          vmsg->fds[0], 0);
756     } else {
757         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
758                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
759                          vmsg->fds[0], 0);
760     }
761 
762     if (mmap_addr == MAP_FAILED) {
763         vu_panic(dev, "region mmap error: %s", strerror(errno));
764     } else {
765         dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
766         DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
767                dev_region->mmap_addr);
768     }
769 
770     close(vmsg->fds[0]);
771 
772     if (track_ramblocks) {
773         /*
774          * Return the address to QEMU so that it can translate the ufd
775          * fault addresses back.
776          */
777         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
778                                                  dev_region->mmap_offset);
779 
780         /* Send the message back to qemu with the addresses filled in. */
781         vmsg->fd_num = 0;
782         DPRINT("Successfully added new region in postcopy\n");
783         dev->nregions++;
784         return true;
785     } else {
786         for (i = 0; i < dev->max_queues; i++) {
787             if (dev->vq[i].vring.desc) {
788                 if (map_ring(dev, &dev->vq[i])) {
789                     vu_panic(dev, "remapping queue %d for new memory region",
790                              i);
791                 }
792             }
793         }
794 
795         DPRINT("Successfully added new region\n");
796         dev->nregions++;
797         return false;
798     }
799 }
800 
801 static inline bool reg_equal(VuDevRegion *vudev_reg,
802                              VhostUserMemoryRegion *msg_reg)
803 {
804     if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
805         vudev_reg->qva == msg_reg->userspace_addr &&
806         vudev_reg->size == msg_reg->memory_size) {
807         return true;
808     }
809 
810     return false;
811 }
812 
813 static bool
vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg)
{
815     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
816     int i;
817     bool found = false;
818 
819     if (vmsg->fd_num > 1) {
820         vmsg_close_fds(vmsg);
821         vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
822                       "should be sent for this message type", vmsg->fd_num);
823         return false;
824     }
825 
826     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
827         vmsg_close_fds(vmsg);
828         vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
829                       "least %d bytes and only %d bytes were received",
830                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
831         return false;
832     }
833 
834     DPRINT("Removing region:\n");
835     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
836            msg_region->guest_phys_addr);
837     DPRINT("    memory_size:     0x%016"PRIx64"\n",
838            msg_region->memory_size);
839     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
840            msg_region->userspace_addr);
841     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
842            msg_region->mmap_offset);
843 
844     for (i = 0; i < dev->nregions; i++) {
845         if (reg_equal(&dev->regions[i], msg_region)) {
846             VuDevRegion *r = &dev->regions[i];
            void *mm = (void *) (uintptr_t) r->mmap_addr;

            if (mm) {
                munmap(mm, r->size + r->mmap_offset);
851             }
852 
853             /*
854              * Shift all affected entries by 1 to close the hole at index i and
855              * zero out the last entry.
856              */
857             memmove(dev->regions + i, dev->regions + i + 1,
858                     sizeof(VuDevRegion) * (dev->nregions - i - 1));
859             memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
860             DPRINT("Successfully removed a region\n");
861             dev->nregions--;
862             i--;
863 
864             found = true;
865 
            /* Continue the search in case there are duplicate regions. */
867         }
868     }
869 
870     if (!found) {
871         vu_panic(dev, "Specified region not found\n");
872     }
873 
874     vmsg_close_fds(vmsg);
875 
876     return false;
877 }
878 
879 static bool
880 vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
881 {
882     int i;
883     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
884     dev->nregions = memory->nregions;
885 
886     DPRINT("Nregions: %u\n", memory->nregions);
887     for (i = 0; i < dev->nregions; i++) {
888         void *mmap_addr;
889         VhostUserMemoryRegion *msg_region = &memory->regions[i];
890         VuDevRegion *dev_region = &dev->regions[i];
891 
892         DPRINT("Region %d\n", i);
893         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
894                msg_region->guest_phys_addr);
895         DPRINT("    memory_size:     0x%016"PRIx64"\n",
896                msg_region->memory_size);
897         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
898                msg_region->userspace_addr);
899         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
900                msg_region->mmap_offset);
901 
902         dev_region->gpa = msg_region->guest_phys_addr;
903         dev_region->size = msg_region->memory_size;
904         dev_region->qva = msg_region->userspace_addr;
905         dev_region->mmap_offset = msg_region->mmap_offset;
906 
        /*
         * We don't use the offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages.
         * In postcopy we're using PROT_NONE here to catch anyone
         * accessing it before we userfault.
         */
913         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
914                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
915                          vmsg->fds[i], 0);
916 
917         if (mmap_addr == MAP_FAILED) {
918             vu_panic(dev, "region mmap error: %s", strerror(errno));
919         } else {
920             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
921             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
922                    dev_region->mmap_addr);
923         }
924 
925         /* Return the address to QEMU so that it can translate the ufd
926          * fault addresses back.
927          */
928         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
929                                                  dev_region->mmap_offset);
930         close(vmsg->fds[i]);
931     }
932 
933     /* Send the message back to qemu with the addresses filled in */
934     vmsg->fd_num = 0;
935     if (!vu_send_reply(dev, dev->sock, vmsg)) {
936         vu_panic(dev, "failed to respond to set-mem-table for postcopy");
937         return false;
938     }
939 
940     /* Wait for QEMU to confirm that it's registered the handler for the
941      * faults.
942      */
943     if (!dev->read_msg(dev, dev->sock, vmsg) ||
944         vmsg->size != sizeof(vmsg->payload.u64) ||
945         vmsg->payload.u64 != 0) {
946         vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
947         return false;
948     }
949 
950     /* OK, now we can go and register the memory and generate faults */
951     (void)generate_faults(dev);
952 
953     return false;
954 }
955 
956 static bool
957 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
958 {
959     int i;
960     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
961 
962     for (i = 0; i < dev->nregions; i++) {
963         VuDevRegion *r = &dev->regions[i];
        void *mm = (void *) (uintptr_t) r->mmap_addr;

        if (mm) {
            munmap(mm, r->size + r->mmap_offset);
968         }
969     }
970     dev->nregions = memory->nregions;
971 
972     if (dev->postcopy_listening) {
973         return vu_set_mem_table_exec_postcopy(dev, vmsg);
974     }
975 
976     DPRINT("Nregions: %u\n", memory->nregions);
977     for (i = 0; i < dev->nregions; i++) {
978         void *mmap_addr;
979         VhostUserMemoryRegion *msg_region = &memory->regions[i];
980         VuDevRegion *dev_region = &dev->regions[i];
981 
982         DPRINT("Region %d\n", i);
983         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
984                msg_region->guest_phys_addr);
985         DPRINT("    memory_size:     0x%016"PRIx64"\n",
986                msg_region->memory_size);
987         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
988                msg_region->userspace_addr);
989         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
990                msg_region->mmap_offset);
991 
992         dev_region->gpa = msg_region->guest_phys_addr;
993         dev_region->size = msg_region->memory_size;
994         dev_region->qva = msg_region->userspace_addr;
995         dev_region->mmap_offset = msg_region->mmap_offset;
996 
        /*
         * We don't use the offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages.
         */
1000         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
1001                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
1002                          vmsg->fds[i], 0);
1003 
1004         if (mmap_addr == MAP_FAILED) {
1005             vu_panic(dev, "region mmap error: %s", strerror(errno));
1006         } else {
1007             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
1008             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
1009                    dev_region->mmap_addr);
1010         }
1011 
1012         close(vmsg->fds[i]);
1013     }
1014 
1015     for (i = 0; i < dev->max_queues; i++) {
1016         if (dev->vq[i].vring.desc) {
1017             if (map_ring(dev, &dev->vq[i])) {
1018                 vu_panic(dev, "remapping queue %d during setmemtable", i);
1019             }
1020         }
1021     }
1022 
1023     return false;
1024 }
1025 
1026 static bool
1027 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1028 {
1029     int fd;
1030     uint64_t log_mmap_size, log_mmap_offset;
1031     void *rc;
1032 
1033     if (vmsg->fd_num != 1 ||
1034         vmsg->size != sizeof(vmsg->payload.log)) {
1035         vu_panic(dev, "Invalid log_base message");
1036         return true;
1037     }
1038 
1039     fd = vmsg->fds[0];
1040     log_mmap_offset = vmsg->payload.log.mmap_offset;
1041     log_mmap_size = vmsg->payload.log.mmap_size;
1042     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1043     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1044 
1045     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1046               log_mmap_offset);
1047     close(fd);
1048     if (rc == MAP_FAILED) {
1049         perror("log mmap error");
1050     }
1051 
1052     if (dev->log_table) {
1053         munmap(dev->log_table, dev->log_size);
1054     }
1055     dev->log_table = rc;
1056     dev->log_size = log_mmap_size;
1057 
1058     vmsg->size = sizeof(vmsg->payload.u64);
1059     vmsg->fd_num = 0;
1060 
1061     return true;
1062 }
1063 
1064 static bool
1065 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1066 {
1067     if (vmsg->fd_num != 1) {
1068         vu_panic(dev, "Invalid log_fd message");
1069         return false;
1070     }
1071 
1072     if (dev->log_call_fd != -1) {
1073         close(dev->log_call_fd);
1074     }
1075     dev->log_call_fd = vmsg->fds[0];
1076     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1077 
1078     return false;
1079 }
1080 
1081 static bool
1082 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1083 {
1084     unsigned int index = vmsg->payload.state.index;
1085     unsigned int num = vmsg->payload.state.num;
1086 
1087     DPRINT("State.index: %u\n", index);
1088     DPRINT("State.num:   %u\n", num);
1089     dev->vq[index].vring.num = num;
1090 
1091     return false;
1092 }
1093 
1094 static bool
1095 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1096 {
1097     struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1098     unsigned int index = vra->index;
1099     VuVirtq *vq = &dev->vq[index];
1100 
1101     DPRINT("vhost_vring_addr:\n");
1102     DPRINT("    index:  %d\n", vra->index);
1103     DPRINT("    flags:  %d\n", vra->flags);
1104     DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1105     DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1106     DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1107     DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1108 
1109     vq->vra = *vra;
1110     vq->vring.flags = vra->flags;
1111     vq->vring.log_guest_addr = vra->log_guest_addr;
1112 
1113 
1114     if (map_ring(dev, vq)) {
1115         vu_panic(dev, "Invalid vring_addr message");
1116         return false;
1117     }
1118 
1119     vq->used_idx = le16toh(vq->vring.used->idx);
1120 
1121     if (vq->last_avail_idx != vq->used_idx) {
1122         bool resume = dev->iface->queue_is_processed_in_order &&
1123             dev->iface->queue_is_processed_in_order(dev, index);
1124 
1125         DPRINT("Last avail index != used index: %u != %u%s\n",
1126                vq->last_avail_idx, vq->used_idx,
1127                resume ? ", resuming" : "");
1128 
1129         if (resume) {
1130             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1131         }
1132     }
1133 
1134     return false;
1135 }
1136 
1137 static bool
1138 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1139 {
1140     unsigned int index = vmsg->payload.state.index;
1141     unsigned int num = vmsg->payload.state.num;
1142 
1143     DPRINT("State.index: %u\n", index);
1144     DPRINT("State.num:   %u\n", num);
1145     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1146 
1147     return false;
1148 }
1149 
1150 static bool
1151 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1152 {
1153     unsigned int index = vmsg->payload.state.index;
1154 
1155     DPRINT("State.index: %u\n", index);
1156     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1157     vmsg->size = sizeof(vmsg->payload.state);
1158 
1159     dev->vq[index].started = false;
1160     if (dev->iface->queue_set_started) {
1161         dev->iface->queue_set_started(dev, index, false);
1162     }
1163 
1164     if (dev->vq[index].call_fd != -1) {
1165         close(dev->vq[index].call_fd);
1166         dev->vq[index].call_fd = -1;
1167     }
1168     if (dev->vq[index].kick_fd != -1) {
1169         dev->remove_watch(dev, dev->vq[index].kick_fd);
1170         close(dev->vq[index].kick_fd);
1171         dev->vq[index].kick_fd = -1;
1172     }
1173 
1174     return true;
1175 }
1176 
1177 static bool
1178 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1179 {
1180     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1181     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1182 
1183     if (index >= dev->max_queues) {
1184         vmsg_close_fds(vmsg);
1185         vu_panic(dev, "Invalid queue index: %u", index);
1186         return false;
1187     }
1188 
1189     if (nofd) {
1190         vmsg_close_fds(vmsg);
1191         return true;
1192     }
1193 
1194     if (vmsg->fd_num != 1) {
1195         vmsg_close_fds(vmsg);
1196         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1197         return false;
1198     }
1199 
1200     return true;
1201 }
1202 
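/*
 * qsort() comparator for resubmitted descriptors: counters are treated as
 * wrapping sequence numbers, where a difference of less than
 * 2 * VIRTQUEUE_MAX_SIZE means "newer".  The list therefore ends up sorted
 * newest-first, so vu_check_queue_inflights() can resume the counter from
 * resubmit_list[0].counter + 1.
 */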
1203 static int
1204 inflight_desc_compare(const void *a, const void *b)
1205 {
1206     VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1207                         *desc1 = (VuVirtqInflightDesc *)b;
1208 
1209     if (desc1->counter > desc0->counter &&
1210         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1211         return 1;
1212     }
1213 
1214     return -1;
1215 }
1216 
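/*
 * Rebuild the in-flight state of a queue after a reconnect: count the
 * descriptors that were submitted but never marked used, queue them for
 * resubmission sorted by submission counter, and kick the queue once so
 * that processing resumes even if no new requests arrive.
 */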
1217 static int
1218 vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1219 {
1220     int i = 0;
1221 
1222     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1223         return 0;
1224     }
1225 
1226     if (unlikely(!vq->inflight)) {
1227         return -1;
1228     }
1229 
1230     if (unlikely(!vq->inflight->version)) {
1231         /* initialize the buffer */
1232         vq->inflight->version = INFLIGHT_VERSION;
1233         return 0;
1234     }
1235 
1236     vq->used_idx = le16toh(vq->vring.used->idx);
1237     vq->resubmit_num = 0;
1238     vq->resubmit_list = NULL;
1239     vq->counter = 0;
1240 
1241     if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1242         vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1243 
1244         barrier();
1245 
1246         vq->inflight->used_idx = vq->used_idx;
1247     }
1248 
1249     for (i = 0; i < vq->inflight->desc_num; i++) {
1250         if (vq->inflight->desc[i].inflight == 1) {
1251             vq->inuse++;
1252         }
1253     }
1254 
1255     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1256 
1257     if (vq->inuse) {
1258         vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1259         if (!vq->resubmit_list) {
1260             return -1;
1261         }
1262 
1263         for (i = 0; i < vq->inflight->desc_num; i++) {
1264             if (vq->inflight->desc[i].inflight) {
1265                 vq->resubmit_list[vq->resubmit_num].index = i;
1266                 vq->resubmit_list[vq->resubmit_num].counter =
1267                                         vq->inflight->desc[i].counter;
1268                 vq->resubmit_num++;
1269             }
1270         }
1271 
1272         if (vq->resubmit_num > 1) {
1273             qsort(vq->resubmit_list, vq->resubmit_num,
1274                   sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1275         }
1276         vq->counter = vq->resubmit_list[0].counter + 1;
1277     }
1278 
1279     /* in case of I/O hang after reconnecting */
1280     if (eventfd_write(vq->kick_fd, 1)) {
1281         return -1;
1282     }
1283 
1284     return 0;
1285 }
1286 
1287 static bool
1288 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1289 {
1290     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1291     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1292 
1293     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1294 
1295     if (!vu_check_queue_msg_file(dev, vmsg)) {
1296         return false;
1297     }
1298 
1299     if (dev->vq[index].kick_fd != -1) {
1300         dev->remove_watch(dev, dev->vq[index].kick_fd);
1301         close(dev->vq[index].kick_fd);
1302         dev->vq[index].kick_fd = -1;
1303     }
1304 
1305     dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1306     DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1307 
1308     dev->vq[index].started = true;
1309     if (dev->iface->queue_set_started) {
1310         dev->iface->queue_set_started(dev, index, true);
1311     }
1312 
1313     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1314         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1315                        vu_kick_cb, (void *)(long)index);
1316 
1317         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1318                dev->vq[index].kick_fd, index);
1319     }
1320 
1321     if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1322         vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1323     }
1324 
1325     return false;
1326 }
1327 
1328 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1329                           vu_queue_handler_cb handler)
1330 {
1331     int qidx = vq - dev->vq;
1332 
1333     vq->handler = handler;
1334     if (vq->kick_fd >= 0) {
1335         if (handler) {
1336             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1337                            vu_kick_cb, (void *)(long)qidx);
1338         } else {
1339             dev->remove_watch(dev, vq->kick_fd);
1340         }
1341     }
1342 }
1343 
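/*
 * Ask the master, over the slave channel, to mmap the given fd/offset/size
 * as the host notifier for this queue, so that guest notifications can be
 * written directly into device-provided memory; fd == -1 withdraws a
 * previously installed notifier.  Requires the
 * VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD protocol feature and waits for the
 * master's reply.
 */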
1344 bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1345                                 int size, int offset)
1346 {
1347     int qidx = vq - dev->vq;
1348     int fd_num = 0;
1349     VhostUserMsg vmsg = {
1350         .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1351         .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1352         .size = sizeof(vmsg.payload.area),
1353         .payload.area = {
1354             .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1355             .size = size,
1356             .offset = offset,
1357         },
1358     };
1359 
1360     if (fd == -1) {
1361         vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1362     } else {
1363         vmsg.fds[fd_num++] = fd;
1364     }
1365 
1366     vmsg.fd_num = fd_num;
1367 
1368     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1369         return false;
1370     }
1371 
1372     pthread_mutex_lock(&dev->slave_mutex);
1373     if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1374         pthread_mutex_unlock(&dev->slave_mutex);
1375         return false;
1376     }
1377 
1378     /* Also unlocks the slave_mutex */
1379     return vu_process_message_reply(dev, &vmsg);
1380 }
1381 
1382 static bool
1383 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1384 {
1385     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1386     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1387 
1388     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1389 
1390     if (!vu_check_queue_msg_file(dev, vmsg)) {
1391         return false;
1392     }
1393 
1394     if (dev->vq[index].call_fd != -1) {
1395         close(dev->vq[index].call_fd);
1396         dev->vq[index].call_fd = -1;
1397     }
1398 
1399     dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1400 
1401     /* in case of I/O hang after reconnecting */
1402     if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
        return false;
1404     }
1405 
1406     DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1407 
1408     return false;
1409 }
1410 
1411 static bool
1412 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1413 {
1414     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1415     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1416 
1417     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1418 
1419     if (!vu_check_queue_msg_file(dev, vmsg)) {
1420         return false;
1421     }
1422 
1423     if (dev->vq[index].err_fd != -1) {
1424         close(dev->vq[index].err_fd);
1425         dev->vq[index].err_fd = -1;
1426     }
1427 
1428     dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1429 
1430     return false;
1431 }
1432 
1433 static bool
1434 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1435 {
    /*
     * Note that we support, but intentionally do not set,
     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
     * a device implementation can return it in its callback
     * (get_protocol_features) if it wants to use this for
     * simulation, but it is otherwise not desirable (even if
     * implemented by the master).
     */
1444     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1445                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1446                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1447                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1448                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1449                         1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1450                         1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1451 
1452     if (have_userfault()) {
1453         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1454     }
1455 
1456     if (dev->iface->get_config && dev->iface->set_config) {
1457         features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1458     }
1459 
1460     if (dev->iface->get_protocol_features) {
1461         features |= dev->iface->get_protocol_features(dev);
1462     }
1463 
1464     vmsg_set_reply_u64(vmsg, features);
1465     return true;
1466 }
1467 
1468 static bool
1469 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1470 {
1471     uint64_t features = vmsg->payload.u64;
1472 
1473     DPRINT("u64: 0x%016"PRIx64"\n", features);
1474 
1475     dev->protocol_features = vmsg->payload.u64;
1476 
1477     if (vu_has_protocol_feature(dev,
1478                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1479         (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1480          !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
        /*
         * The use case for using messages for kick/call is simulation, to make
         * the kick and call synchronous. To actually get that behaviour, both
         * of the other features are required.
         * Theoretically, one could use only kick messages, or send them
         * without F_REPLY_ACK, but too many (possibly pending) messages on
         * the socket would eventually cause the master to hang. To avoid
         * this in scenarios where it is not desired, enforce that the
         * settings actually enable the simulation case.
         */
1491         vu_panic(dev,
1492                  "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1493         return false;
1494     }
1495 
1496     if (dev->iface->set_protocol_features) {
1497         dev->iface->set_protocol_features(dev, features);
1498     }
1499 
1500     return false;
1501 }
1502 
1503 static bool
1504 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1505 {
1506     vmsg_set_reply_u64(vmsg, dev->max_queues);
1507     return true;
1508 }
1509 
1510 static bool
1511 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1512 {
1513     unsigned int index = vmsg->payload.state.index;
1514     unsigned int enable = vmsg->payload.state.num;
1515 
1516     DPRINT("State.index: %u\n", index);
1517     DPRINT("State.enable:   %u\n", enable);
1518 
1519     if (index >= dev->max_queues) {
1520         vu_panic(dev, "Invalid vring_enable index: %u", index);
1521         return false;
1522     }
1523 
1524     dev->vq[index].enable = enable;
1525     return false;
1526 }
1527 
1528 static bool
1529 vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1530 {
1531     if (vmsg->fd_num != 1) {
1532         vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1533         return false;
1534     }
1535 
1536     if (dev->slave_fd != -1) {
1537         close(dev->slave_fd);
1538     }
1539     dev->slave_fd = vmsg->fds[0];
1540     DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1541 
1542     return false;
1543 }
1544 
1545 static bool
1546 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1547 {
1548     int ret = -1;
1549 
1550     if (dev->iface->get_config) {
1551         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1552                                      vmsg->payload.config.size);
1553     }
1554 
1555     if (ret) {
        /* resize to zero to indicate an error to the master */
1557         vmsg->size = 0;
1558     }
1559 
1560     return true;
1561 }
1562 
1563 static bool
1564 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1565 {
1566     int ret = -1;
1567 
1568     if (dev->iface->set_config) {
1569         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1570                                      vmsg->payload.config.offset,
1571                                      vmsg->payload.config.size,
1572                                      vmsg->payload.config.flags);
1573         if (ret) {
1574             vu_panic(dev, "Set virtio configuration space failed");
1575         }
1576     }
1577 
1578     return false;
1579 }
1580 
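/*
 * Postcopy advise: open a userfaultfd, negotiate the UFFD API with the
 * kernel, and hand the fd back to QEMU, which reads the page-fault events
 * from it during postcopy migration.
 */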
1581 static bool
1582 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1583 {
1584     dev->postcopy_ufd = -1;
1585 #ifdef UFFDIO_API
1586     struct uffdio_api api_struct;
1587 
1588     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1589     vmsg->size = 0;
1590 #endif
1591 
1592     if (dev->postcopy_ufd == -1) {
1593         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1594         goto out;
1595     }
1596 
1597 #ifdef UFFDIO_API
1598     api_struct.api = UFFD_API;
1599     api_struct.features = 0;
1600     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1601         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1602         close(dev->postcopy_ufd);
1603         dev->postcopy_ufd = -1;
1604         goto out;
1605     }
1606     /* TODO: Stash feature flags somewhere */
1607 #endif
1608 
1609 out:
    /* Return a ufd to QEMU */
1611     vmsg->fd_num = 1;
1612     vmsg->fds[0] = dev->postcopy_ufd;
1613     return true; /* = send a reply */
1614 }
1615 
1616 static bool
1617 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1618 {
1619     if (dev->nregions) {
1620         vu_panic(dev, "Regions already registered at postcopy-listen");
1621         vmsg_set_reply_u64(vmsg, -1);
1622         return true;
1623     }
1624     dev->postcopy_listening = true;
1625 
1626     vmsg_set_reply_u64(vmsg, 0);
1627     return true;
1628 }
1629 
1630 static bool
1631 vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1632 {
1633     DPRINT("%s: Entry\n", __func__);
1634     dev->postcopy_listening = false;
1635     if (dev->postcopy_ufd > 0) {
1636         close(dev->postcopy_ufd);
1637         dev->postcopy_ufd = -1;
1638         DPRINT("%s: Done close\n", __func__);
1639     }
1640 
1641     vmsg_set_reply_u64(vmsg, 0);
1642     DPRINT("%s: exit\n", __func__);
1643     return true;
1644 }
1645 
1646 static inline uint64_t
1647 vu_inflight_queue_size(uint16_t queue_size)
1648 {
1649     return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1650            sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1651 }
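
/*
 * Example: for a 128-entry queue the per-queue area above is
 * ALIGN_UP(sizeof(VuDescStateSplit) * 128 + sizeof(uint16_t),
 *          INFLIGHT_ALIGNMENT) bytes, so each queue's slice of the shared
 * buffer starts on a cache-line boundary; vu_get_inflight_fd() multiplies
 * this by the number of queues.
 */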
1652 
1653 #ifdef MFD_ALLOW_SEALING
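/*
 * Allocate a shared memory area backed by a sealed memfd: create the memfd,
 * size it with ftruncate(), apply the requested F_SEAL_* flags so the
 * mapping cannot be resized underneath us, and map it read/write.  On
 * success *fd holds the descriptor to pass to the master.
 */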
1654 static void *
1655 memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
1656 {
1657     void *ptr;
1658     int ret;
1659 
1660     *fd = memfd_create(name, MFD_ALLOW_SEALING);
1661     if (*fd < 0) {
1662         return NULL;
1663     }
1664 
1665     ret = ftruncate(*fd, size);
1666     if (ret < 0) {
1667         close(*fd);
1668         return NULL;
1669     }
1670 
1671     ret = fcntl(*fd, F_ADD_SEALS, flags);
1672     if (ret < 0) {
1673         close(*fd);
1674         return NULL;
1675     }
1676 
1677     ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
1678     if (ptr == MAP_FAILED) {
1679         close(*fd);
1680         return NULL;
1681     }
1682 
1683     return ptr;
1684 }
1685 #endif
1686 
1687 static bool
1688 vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1689 {
1690     int fd = -1;
1691     void *addr = NULL;
1692     uint64_t mmap_size;
1693     uint16_t num_queues, queue_size;
1694 
1695     if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1696         vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1697         vmsg->payload.inflight.mmap_size = 0;
1698         return true;
1699     }
1700 
1701     num_queues = vmsg->payload.inflight.num_queues;
1702     queue_size = vmsg->payload.inflight.queue_size;
1703 
    DPRINT("get_inflight_fd num_queues: %"PRId16"\n", num_queues);
    DPRINT("get_inflight_fd queue_size: %"PRId16"\n", queue_size);
1706 
1707     mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1708 
1709 #ifdef MFD_ALLOW_SEALING
1710     addr = memfd_alloc("vhost-inflight", mmap_size,
1711                        F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1712                        &fd);
1713 #else
1714     vu_panic(dev, "Not implemented: memfd support is missing");
1715 #endif
1716 
1717     if (!addr) {
1718         vu_panic(dev, "Failed to alloc vhost inflight area");
1719         vmsg->payload.inflight.mmap_size = 0;
1720         return true;
1721     }
1722 
1723     memset(addr, 0, mmap_size);
1724 
1725     dev->inflight_info.addr = addr;
1726     dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1727     dev->inflight_info.fd = vmsg->fds[0] = fd;
1728     vmsg->fd_num = 1;
1729     vmsg->payload.inflight.mmap_offset = 0;
1730 
1731     DPRINT("send inflight mmap_size: %"PRId64"\n",
1732            vmsg->payload.inflight.mmap_size);
1733     DPRINT("send inflight mmap offset: %"PRId64"\n",
1734            vmsg->payload.inflight.mmap_offset);
1735 
1736     return true;
1737 }
1738 
1739 static bool
1740 vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1741 {
1742     int fd, i;
1743     uint64_t mmap_size, mmap_offset;
1744     uint16_t num_queues, queue_size;
1745     void *rc;
1746 
1747     if (vmsg->fd_num != 1 ||
1748         vmsg->size != sizeof(vmsg->payload.inflight)) {
1749         vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1750                  vmsg->size, vmsg->fd_num);
1751         return false;
1752     }
1753 
1754     fd = vmsg->fds[0];
1755     mmap_size = vmsg->payload.inflight.mmap_size;
1756     mmap_offset = vmsg->payload.inflight.mmap_offset;
1757     num_queues = vmsg->payload.inflight.num_queues;
1758     queue_size = vmsg->payload.inflight.queue_size;
1759 
1760     DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1761     DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1762     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1763     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1764 
1765     rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1766               fd, mmap_offset);
1767 
1768     if (rc == MAP_FAILED) {
1769         vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1770         return false;
1771     }
1772 
1773     if (dev->inflight_info.fd) {
1774         close(dev->inflight_info.fd);
1775     }
1776 
1777     if (dev->inflight_info.addr) {
1778         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1779     }
1780 
1781     dev->inflight_info.fd = fd;
1782     dev->inflight_info.addr = rc;
1783     dev->inflight_info.size = mmap_size;
1784 
1785     for (i = 0; i < num_queues; i++) {
1786         dev->vq[i].inflight = (VuVirtqInflight *)rc;
1787         dev->vq[i].inflight->desc_num = queue_size;
1788         rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1789     }
1790 
1791     return false;
1792 }
1793 
1794 static bool
1795 vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1796 {
1797     unsigned int index = vmsg->payload.state.index;
1798 
1799     if (index >= dev->max_queues) {
1800         vu_panic(dev, "Invalid queue index: %u", index);
1801         return false;
1802     }
1803 
1804     DPRINT("Got kick message: handler:%p idx:%u\n",
1805            dev->vq[index].handler, index);
1806 
1807     if (!dev->vq[index].started) {
1808         dev->vq[index].started = true;
1809 
1810         if (dev->iface->queue_set_started) {
1811             dev->iface->queue_set_started(dev, index, true);
1812         }
1813     }
1814 
1815     if (dev->vq[index].handler) {
1816         dev->vq[index].handler(dev, index);
1817     }
1818 
1819     return false;
1820 }
1821 
1822 static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1823 {
1824     vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
1825 
1826     DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1827 
1828     return true;
1829 }
1830 
1831 static bool
1832 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1833 {
1834     int do_reply = 0;
1835 
1836     /* Print out generic part of the request. */
1837     DPRINT("================ Vhost user message ================\n");
1838     DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1839            vmsg->request);
1840     DPRINT("Flags:   0x%x\n", vmsg->flags);
1841     DPRINT("Size:    %u\n", vmsg->size);
1842 
1843     if (vmsg->fd_num) {
1844         int i;
1845         DPRINT("Fds:");
1846         for (i = 0; i < vmsg->fd_num; i++) {
1847             DPRINT(" %d", vmsg->fds[i]);
1848         }
1849         DPRINT("\n");
1850     }
1851 
1852     if (dev->iface->process_msg &&
1853         dev->iface->process_msg(dev, vmsg, &do_reply)) {
1854         return do_reply;
1855     }
1856 
1857     switch (vmsg->request) {
1858     case VHOST_USER_GET_FEATURES:
1859         return vu_get_features_exec(dev, vmsg);
1860     case VHOST_USER_SET_FEATURES:
1861         return vu_set_features_exec(dev, vmsg);
1862     case VHOST_USER_GET_PROTOCOL_FEATURES:
1863         return vu_get_protocol_features_exec(dev, vmsg);
1864     case VHOST_USER_SET_PROTOCOL_FEATURES:
1865         return vu_set_protocol_features_exec(dev, vmsg);
1866     case VHOST_USER_SET_OWNER:
1867         return vu_set_owner_exec(dev, vmsg);
1868     case VHOST_USER_RESET_OWNER:
1869         return vu_reset_device_exec(dev, vmsg);
1870     case VHOST_USER_SET_MEM_TABLE:
1871         return vu_set_mem_table_exec(dev, vmsg);
1872     case VHOST_USER_SET_LOG_BASE:
1873         return vu_set_log_base_exec(dev, vmsg);
1874     case VHOST_USER_SET_LOG_FD:
1875         return vu_set_log_fd_exec(dev, vmsg);
1876     case VHOST_USER_SET_VRING_NUM:
1877         return vu_set_vring_num_exec(dev, vmsg);
1878     case VHOST_USER_SET_VRING_ADDR:
1879         return vu_set_vring_addr_exec(dev, vmsg);
1880     case VHOST_USER_SET_VRING_BASE:
1881         return vu_set_vring_base_exec(dev, vmsg);
1882     case VHOST_USER_GET_VRING_BASE:
1883         return vu_get_vring_base_exec(dev, vmsg);
1884     case VHOST_USER_SET_VRING_KICK:
1885         return vu_set_vring_kick_exec(dev, vmsg);
1886     case VHOST_USER_SET_VRING_CALL:
1887         return vu_set_vring_call_exec(dev, vmsg);
1888     case VHOST_USER_SET_VRING_ERR:
1889         return vu_set_vring_err_exec(dev, vmsg);
1890     case VHOST_USER_GET_QUEUE_NUM:
1891         return vu_get_queue_num_exec(dev, vmsg);
1892     case VHOST_USER_SET_VRING_ENABLE:
1893         return vu_set_vring_enable_exec(dev, vmsg);
1894     case VHOST_USER_SET_SLAVE_REQ_FD:
1895         return vu_set_slave_req_fd(dev, vmsg);
1896     case VHOST_USER_GET_CONFIG:
1897         return vu_get_config(dev, vmsg);
1898     case VHOST_USER_SET_CONFIG:
1899         return vu_set_config(dev, vmsg);
1900     case VHOST_USER_NONE:
1901         /* if you need processing before exit, override iface->process_msg */
1902         exit(0);
1903     case VHOST_USER_POSTCOPY_ADVISE:
1904         return vu_set_postcopy_advise(dev, vmsg);
1905     case VHOST_USER_POSTCOPY_LISTEN:
1906         return vu_set_postcopy_listen(dev, vmsg);
1907     case VHOST_USER_POSTCOPY_END:
1908         return vu_set_postcopy_end(dev, vmsg);
1909     case VHOST_USER_GET_INFLIGHT_FD:
1910         return vu_get_inflight_fd(dev, vmsg);
1911     case VHOST_USER_SET_INFLIGHT_FD:
1912         return vu_set_inflight_fd(dev, vmsg);
1913     case VHOST_USER_VRING_KICK:
1914         return vu_handle_vring_kick(dev, vmsg);
1915     case VHOST_USER_GET_MAX_MEM_SLOTS:
1916         return vu_handle_get_max_memslots(dev, vmsg);
1917     case VHOST_USER_ADD_MEM_REG:
1918         return vu_add_mem_reg(dev, vmsg);
1919     case VHOST_USER_REM_MEM_REG:
1920         return vu_rem_mem_reg(dev, vmsg);
1921     default:
1922         vmsg_close_fds(vmsg);
1923         vu_panic(dev, "Unhandled request: %d", vmsg->request);
1924     }
1925 
1926     return false;
1927 }
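
/*
 * Sketch (hypothetical backend code; my_cleanup() is not part of this
 * library): a backend can intercept messages before the switch in
 * vu_process_message() runs by providing VuDevIface.process_msg.  Returning
 * a non-zero value claims the message, and *do_reply tells vu_dispatch()
 * whether to send a reply.  For example, to run device teardown instead of
 * the bare exit(0) above:
 *
 *     static int my_process_msg(VuDev *dev, VhostUserMsg *vmsg, int *do_reply)
 *     {
 *         if (vmsg->request == VHOST_USER_NONE) {
 *             my_cleanup();   // device-specific teardown, then exit
 *             *do_reply = 0;
 *             return 1;       // handled, skip the default switch
 *         }
 *         return 0;           // not handled, fall through
 *     }
 */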
1928 
1929 bool
1930 vu_dispatch(VuDev *dev)
1931 {
1932     VhostUserMsg vmsg = { 0, };
1933     int reply_requested;
1934     bool need_reply, success = false;
1935 
1936     if (!dev->read_msg(dev, dev->sock, &vmsg)) {
1937         goto end;
1938     }
1939 
1940     need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1941 
1942     reply_requested = vu_process_message(dev, &vmsg);
1943     if (!reply_requested && need_reply) {
1944         vmsg_set_reply_u64(&vmsg, 0);
1945         reply_requested = 1;
1946     }
1947 
1948     if (!reply_requested) {
1949         success = true;
1950         goto end;
1951     }
1952 
1953     if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1954         goto end;
1955     }
1956 
1957     success = true;
1958 
1959 end:
1960     free(vmsg.data);
1961     return success;
1962 }
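
/*
 * Usage sketch (event-loop integration is the application's job; the code
 * below is illustrative): a backend typically watches the socket it passed
 * to vu_init() with its own poll/epoll/GLib machinery and calls
 * vu_dispatch() once per readable event; a false return normally means a
 * message could not be read or answered and the device should be torn down:
 *
 *     if (!vu_dispatch(&dev)) {
 *         vu_deinit(&dev);
 *         // stop watching the socket, close it, etc.
 *     }
 */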
1963 
1964 void
1965 vu_deinit(VuDev *dev)
1966 {
1967     int i;
1968 
1969     for (i = 0; i < dev->nregions; i++) {
1970         VuDevRegion *r = &dev->regions[i];
1971         void *m = (void *) (uintptr_t) r->mmap_addr;
1972         if (m != MAP_FAILED) {
1973             munmap(m, r->size + r->mmap_offset);
1974         }
1975     }
1976     dev->nregions = 0;
1977 
1978     for (i = 0; i < dev->max_queues; i++) {
1979         VuVirtq *vq = &dev->vq[i];
1980 
1981         if (vq->call_fd != -1) {
1982             close(vq->call_fd);
1983             vq->call_fd = -1;
1984         }
1985 
1986         if (vq->kick_fd != -1) {
1987             dev->remove_watch(dev, vq->kick_fd);
1988             close(vq->kick_fd);
1989             vq->kick_fd = -1;
1990         }
1991 
1992         if (vq->err_fd != -1) {
1993             close(vq->err_fd);
1994             vq->err_fd = -1;
1995         }
1996 
1997         if (vq->resubmit_list) {
1998             free(vq->resubmit_list);
1999             vq->resubmit_list = NULL;
2000         }
2001 
2002         vq->inflight = NULL;
2003     }
2004 
2005     if (dev->inflight_info.addr) {
2006         munmap(dev->inflight_info.addr, dev->inflight_info.size);
2007         dev->inflight_info.addr = NULL;
2008     }
2009 
2010     if (dev->inflight_info.fd > 0) {
2011         close(dev->inflight_info.fd);
2012         dev->inflight_info.fd = -1;
2013     }
2014 
2015     vu_close_log(dev);
2016     if (dev->slave_fd != -1) {
2017         close(dev->slave_fd);
2018         dev->slave_fd = -1;
2019     }
2020     pthread_mutex_destroy(&dev->slave_mutex);
2021 
2022     if (dev->sock != -1) {
2023         close(dev->sock);
2024     }
2025 
2026     free(dev->vq);
2027     dev->vq = NULL;
2028 }
2029 
2030 bool
2031 vu_init(VuDev *dev,
2032         uint16_t max_queues,
2033         int socket,
2034         vu_panic_cb panic,
2035         vu_read_msg_cb read_msg,
2036         vu_set_watch_cb set_watch,
2037         vu_remove_watch_cb remove_watch,
2038         const VuDevIface *iface)
2039 {
2040     uint16_t i;
2041 
2042     assert(max_queues > 0);
2043     assert(socket >= 0);
2044     assert(set_watch);
2045     assert(remove_watch);
2046     assert(iface);
2047     assert(panic);
2048 
2049     memset(dev, 0, sizeof(*dev));
2050 
2051     dev->sock = socket;
2052     dev->panic = panic;
2053     dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2054     dev->set_watch = set_watch;
2055     dev->remove_watch = remove_watch;
2056     dev->iface = iface;
2057     dev->log_call_fd = -1;
2058     pthread_mutex_init(&dev->slave_mutex, NULL);
2059     dev->slave_fd = -1;
2060     dev->max_queues = max_queues;
2061 
2062     dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2063     if (!dev->vq) {
2064         DPRINT("%s: failed to malloc virtqueues\n", __func__);
2065         return false;
2066     }
2067 
2068     for (i = 0; i < max_queues; i++) {
2069         dev->vq[i] = (VuVirtq) {
2070             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2071             .notification = true,
2072         };
2073     }
2074 
2075     return true;
2076 }
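
/*
 * Initialization sketch (the my_* callbacks and my_iface are hypothetical
 * application code, not provided by this file): a backend with four queues
 * and an already-connected vhost-user socket might do:
 *
 *     static VuDev dev;
 *
 *     if (!vu_init(&dev, 4, sock_fd, my_panic, NULL,
 *                  my_set_watch, my_remove_watch, &my_iface)) {
 *         return;    // virtqueue allocation failed
 *     }
 *
 * Passing NULL for read_msg selects vu_message_read_default(), which reads
 * messages directly from the socket.
 */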
2077 
2078 VuVirtq *
2079 vu_get_queue(VuDev *dev, int qidx)
2080 {
2081     assert(qidx < dev->max_queues);
2082     return &dev->vq[qidx];
2083 }
2084 
2085 bool
2086 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2087 {
2088     return vq->enable;
2089 }
2090 
2091 bool
2092 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2093 {
2094     return vq->started;
2095 }
2096 
2097 static inline uint16_t
2098 vring_avail_flags(VuVirtq *vq)
2099 {
2100     return le16toh(vq->vring.avail->flags);
2101 }
2102 
2103 static inline uint16_t
2104 vring_avail_idx(VuVirtq *vq)
2105 {
2106     vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2107 
2108     return vq->shadow_avail_idx;
2109 }
2110 
2111 static inline uint16_t
2112 vring_avail_ring(VuVirtq *vq, int i)
2113 {
2114     return le16toh(vq->vring.avail->ring[i]);
2115 }
2116 
2117 static inline uint16_t
2118 vring_get_used_event(VuVirtq *vq)
2119 {
2120     return vring_avail_ring(vq, vq->vring.num);
2121 }
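
/*
 * Descriptive note: when VIRTIO_RING_F_EVENT_IDX is negotiated, the driver
 * stores used_event in the slot just past the avail ring (index
 * vq->vring.num), which is why it is fetched through vring_avail_ring()
 * above.
 */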
2122 
2123 static int
2124 virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2125 {
2126     uint16_t num_heads = vring_avail_idx(vq) - idx;
2127 
2128     /* Check it isn't doing very strange things with descriptor numbers. */
2129     if (num_heads > vq->vring.num) {
2130         vu_panic(dev, "Guest moved avail index from %u to %u",
2131                  idx, vq->shadow_avail_idx);
2132         return -1;
2133     }
2134     if (num_heads) {
2135         /* On success, callers read a descriptor at vq->last_avail_idx.
2136          * Make sure descriptor read does not bypass avail index read. */
2137         smp_rmb();
2138     }
2139 
2140     return num_heads;
2141 }
2142 
2143 static bool
2144 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2145                    unsigned int idx, unsigned int *head)
2146 {
2147     /* Grab the next descriptor number they're advertising, and increment
2148      * the index we've seen. */
2149     *head = vring_avail_ring(vq, idx % vq->vring.num);
2150 
2151     /* If their number is silly, that's a fatal mistake. */
2152     if (*head >= vq->vring.num) {
2153         vu_panic(dev, "Guest says index %u is available", *head);
2154         return false;
2155     }
2156 
2157     return true;
2158 }
2159 
2160 static int
2161 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2162                              uint64_t addr, size_t len)
2163 {
2164     struct vring_desc *ori_desc;
2165     uint64_t read_len;
2166 
2167     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2168         return -1;
2169     }
2170 
2171     if (len == 0) {
2172         return -1;
2173     }
2174 
2175     while (len) {
2176         read_len = len;
2177         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2178         if (!ori_desc) {
2179             return -1;
2180         }
2181 
2182         memcpy(desc, ori_desc, read_len);
2183         len -= read_len;
2184         addr += read_len;
             /* read_len is a byte count, so advance the pointer by bytes */
2185         desc = (struct vring_desc *)((char *)desc + read_len);
2186     }
2187 
2188     return 0;
2189 }
2190 
2191 enum {
2192     VIRTQUEUE_READ_DESC_ERROR = -1,
2193     VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
2194     VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
2195 };
2196 
2197 static int
2198 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2199                          int i, unsigned int max, unsigned int *next)
2200 {
2201     /* If this descriptor says it doesn't chain, we're done. */
2202     if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
2203         return VIRTQUEUE_READ_DESC_DONE;
2204     }
2205 
2206     /* Check they're not leading us off the end of the descriptors. */
2207     *next = le16toh(desc[i].next);
2208     /* Make sure compiler knows to grab that: we don't want it changing! */
2209     smp_wmb();
2210 
2211     if (*next >= max) {
2212         vu_panic(dev, "Desc next is %u", *next);
2213         return VIRTQUEUE_READ_DESC_ERROR;
2214     }
2215 
2216     return VIRTQUEUE_READ_DESC_MORE;
2217 }
2218 
2219 void
2220 vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
2221                          unsigned int *out_bytes,
2222                          unsigned max_in_bytes, unsigned max_out_bytes)
2223 {
2224     unsigned int idx;
2225     unsigned int total_bufs, in_total, out_total;
2226     int rc;
2227 
2228     idx = vq->last_avail_idx;
2229 
2230     total_bufs = in_total = out_total = 0;
2231     if (unlikely(dev->broken) ||
2232         unlikely(!vq->vring.avail)) {
2233         goto done;
2234     }
2235 
2236     while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
2237         unsigned int max, desc_len, num_bufs, indirect = 0;
2238         uint64_t desc_addr, read_len;
2239         struct vring_desc *desc;
2240         struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2241         unsigned int i;
2242 
2243         max = vq->vring.num;
2244         num_bufs = total_bufs;
2245         if (!virtqueue_get_head(dev, vq, idx++, &i)) {
2246             goto err;
2247         }
2248         desc = vq->vring.desc;
2249 
2250         if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2251             if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2252                 vu_panic(dev, "Invalid size for indirect buffer table");
2253                 goto err;
2254             }
2255 
2256             /* If we've got too many, that implies a descriptor loop. */
2257             if (num_bufs >= max) {
2258                 vu_panic(dev, "Looped descriptor");
2259                 goto err;
2260             }
2261 
2262             /* loop over the indirect descriptor table */
2263             indirect = 1;
2264             desc_addr = le64toh(desc[i].addr);
2265             desc_len = le32toh(desc[i].len);
2266             max = desc_len / sizeof(struct vring_desc);
2267             read_len = desc_len;
2268             desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2269             if (unlikely(desc && read_len != desc_len)) {
2270                 /* Failed to use zero copy */
2271                 desc = NULL;
2272                 if (!virtqueue_read_indirect_desc(dev, desc_buf,
2273                                                   desc_addr,
2274                                                   desc_len)) {
2275                     desc = desc_buf;
2276                 }
2277             }
2278             if (!desc) {
2279                 vu_panic(dev, "Invalid indirect buffer table");
2280                 goto err;
2281             }
2282             num_bufs = i = 0;
2283         }
2284 
2285         do {
2286             /* If we've got too many, that implies a descriptor loop. */
2287             if (++num_bufs > max) {
2288                 vu_panic(dev, "Looped descriptor");
2289                 goto err;
2290             }
2291 
2292             if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2293                 in_total += le32toh(desc[i].len);
2294             } else {
2295                 out_total += le32toh(desc[i].len);
2296             }
2297             if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
2298                 goto done;
2299             }
2300             rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2301         } while (rc == VIRTQUEUE_READ_DESC_MORE);
2302 
2303         if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2304             goto err;
2305         }
2306 
2307         if (!indirect) {
2308             total_bufs = num_bufs;
2309         } else {
2310             total_bufs++;
2311         }
2312     }
2313     if (rc < 0) {
2314         goto err;
2315     }
2316 done:
2317     if (in_bytes) {
2318         *in_bytes = in_total;
2319     }
2320     if (out_bytes) {
2321         *out_bytes = out_total;
2322     }
2323     return;
2324 
2325 err:
2326     in_total = out_total = 0;
2327     goto done;
2328 }
2329 
2330 bool
2331 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2332                      unsigned int out_bytes)
2333 {
2334     unsigned int in_total, out_total;
2335 
2336     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2337                              in_bytes, out_bytes);
2338 
2339     return in_bytes <= in_total && out_bytes <= out_total;
2340 }
2341 
2342 /* Fetch avail_idx from VQ memory only when we really need to know if
2343  * guest has added some buffers. */
2344 bool
2345 vu_queue_empty(VuDev *dev, VuVirtq *vq)
2346 {
2347     if (unlikely(dev->broken) ||
2348         unlikely(!vq->vring.avail)) {
2349         return true;
2350     }
2351 
2352     if (vq->shadow_avail_idx != vq->last_avail_idx) {
2353         return false;
2354     }
2355 
2356     return vring_avail_idx(vq) == vq->last_avail_idx;
2357 }
2358 
2359 static bool
2360 vring_notify(VuDev *dev, VuVirtq *vq)
2361 {
2362     uint16_t old, new;
2363     bool v;
2364 
2365     /* We need to expose used array entries before checking used event. */
2366     smp_mb();
2367 
2368     /* Always notify when the queue is empty, if VIRTIO_F_NOTIFY_ON_EMPTY was acknowledged */
2369     if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2370         !vq->inuse && vu_queue_empty(dev, vq)) {
2371         return true;
2372     }
2373 
2374     if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2375         return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
2376     }
2377 
2378     v = vq->signalled_used_valid;
2379     vq->signalled_used_valid = true;
2380     old = vq->signalled_used;
2381     new = vq->signalled_used = vq->used_idx;
2382     return !v || vring_need_event(vring_get_used_event(vq), new, old);
2383 }
2384 
2385 static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2386 {
2387     if (unlikely(dev->broken) ||
2388         unlikely(!vq->vring.avail)) {
2389         return;
2390     }
2391 
2392     if (!vring_notify(dev, vq)) {
2393         DPRINT("skipped notify...\n");
2394         return;
2395     }
2396 
2397     if (vq->call_fd < 0 &&
2398         vu_has_protocol_feature(dev,
2399                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2400         vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
2401         VhostUserMsg vmsg = {
2402             .request = VHOST_USER_SLAVE_VRING_CALL,
2403             .flags = VHOST_USER_VERSION,
2404             .size = sizeof(vmsg.payload.state),
2405             .payload.state = {
2406                 .index = vq - dev->vq,
2407             },
2408         };
2409         bool ack = sync &&
2410                    vu_has_protocol_feature(dev,
2411                                            VHOST_USER_PROTOCOL_F_REPLY_ACK);
2412 
2413         if (ack) {
2414             vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2415         }
2416 
2417         vu_message_write(dev, dev->slave_fd, &vmsg);
2418         if (ack) {
2419             vu_message_read_default(dev, dev->slave_fd, &vmsg);
2420         }
2421         return;
2422     }
2423 
2424     if (eventfd_write(vq->call_fd, 1) < 0) {
2425         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2426     }
2427 }
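
/*
 * Descriptive note: when no call fd is available (call_fd < 0) and both
 * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS and
 * VHOST_USER_PROTOCOL_F_SLAVE_REQ were negotiated, the notification above is
 * sent in-band as a VHOST_USER_SLAVE_VRING_CALL message on the slave channel
 * instead of an eventfd write.  vu_queue_notify_sync() additionally sets
 * VHOST_USER_NEED_REPLY_MASK so it can wait for the acknowledgement when
 * VHOST_USER_PROTOCOL_F_REPLY_ACK is also available.
 */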
2428 
2429 void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2430 {
2431     _vu_queue_notify(dev, vq, false);
2432 }
2433 
2434 void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2435 {
2436     _vu_queue_notify(dev, vq, true);
2437 }
2438 
2439 static inline void
2440 vring_used_flags_set_bit(VuVirtq *vq, int mask)
2441 {
2442     uint16_t *flags;
2443 
2444     flags = (uint16_t *)((char *)vq->vring.used +
2445                          offsetof(struct vring_used, flags));
2446     *flags = htole16(le16toh(*flags) | mask);
2447 }
2448 
2449 static inline void
2450 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2451 {
2452     uint16_t *flags;
2453 
2454     flags = (uint16_t *)((char *)vq->vring.used +
2455                          offsetof(struct vring_used, flags));
2456     *flags = htole16(le16toh(*flags) & ~mask);
2457 }
2458 
2459 static inline void
2460 vring_set_avail_event(VuVirtq *vq, uint16_t val)
2461 {
2462     uint16_t *avail;
2463 
2464     if (!vq->notification) {
2465         return;
2466     }
2467 
2468     avail = (uint16_t *)&vq->vring.used->ring[vq->vring.num];
2469     *avail = htole16(val);
2470 }
2471 
2472 void
2473 vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2474 {
2475     vq->notification = enable;
2476     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2477         vring_set_avail_event(vq, vring_avail_idx(vq));
2478     } else if (enable) {
2479         vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2480     } else {
2481         vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2482     }
2483     if (enable) {
2484         /* Expose avail event/used flags before caller checks the avail idx. */
2485         smp_mb();
2486     }
2487 }
2488 
2489 static bool
2490 virtqueue_map_desc(VuDev *dev,
2491                    unsigned int *p_num_sg, struct iovec *iov,
2492                    unsigned int max_num_sg, bool is_write,
2493                    uint64_t pa, size_t sz)
2494 {
2495     unsigned num_sg = *p_num_sg;
2496 
2497     assert(num_sg <= max_num_sg);
2498 
2499     if (!sz) {
2500         vu_panic(dev, "virtio: zero sized buffers are not allowed");
2501         return false;
2502     }
2503 
2504     while (sz) {
2505         uint64_t len = sz;
2506 
2507         if (num_sg == max_num_sg) {
2508             vu_panic(dev, "virtio: too many descriptors in indirect table");
2509             return false;
2510         }
2511 
2512         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2513         if (iov[num_sg].iov_base == NULL) {
2514             vu_panic(dev, "virtio: invalid address for buffers");
2515             return false;
2516         }
2517         iov[num_sg].iov_len = len;
2518         num_sg++;
2519         sz -= len;
2520         pa += len;
2521     }
2522 
2523     *p_num_sg = num_sg;
2524     return true;
2525 }
2526 
2527 static void *
2528 virtqueue_alloc_element(size_t sz,
2529                         unsigned out_num, unsigned in_num)
2530 {
2531     VuVirtqElement *elem;
2532     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2533     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2534     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2535 
2536     assert(sz >= sizeof(VuVirtqElement));
2537     elem = malloc(out_sg_end);
         if (!elem) {
             return NULL;
         }
2538     elem->out_num = out_num;
2539     elem->in_num = in_num;
2540     elem->in_sg = (void *)elem + in_sg_ofs;
2541     elem->out_sg = (void *)elem + out_sg_ofs;
2542     return elem;
2543 }
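
/*
 * Layout note: the element is one contiguous allocation, so freeing the
 * pointer returned by vu_queue_pop() releases everything:
 *
 *     [ caller struct of size sz | in_sg[in_num] | out_sg[out_num] ]
 *
 * The caller embeds VuVirtqElement as the first member of its own element
 * struct and passes that struct's size as sz.
 */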
2544 
2545 static void *
2546 vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2547 {
2548     struct vring_desc *desc = vq->vring.desc;
2549     uint64_t desc_addr, read_len;
2550     unsigned int desc_len;
2551     unsigned int max = vq->vring.num;
2552     unsigned int i = idx;
2553     VuVirtqElement *elem;
2554     unsigned int out_num = 0, in_num = 0;
2555     struct iovec iov[VIRTQUEUE_MAX_SIZE];
2556     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2557     int rc;
2558 
2559     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2560         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2561             vu_panic(dev, "Invalid size for indirect buffer table");
2562             return NULL;
2563         }
2564 
2565         /* loop over the indirect descriptor table */
2566         desc_addr = le64toh(desc[i].addr);
2567         desc_len = le32toh(desc[i].len);
2568         max = desc_len / sizeof(struct vring_desc);
2569         read_len = desc_len;
2570         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2571         if (unlikely(desc && read_len != desc_len)) {
2572             /* Failed to use zero copy */
2573             desc = NULL;
2574             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2575                                               desc_addr,
2576                                               desc_len)) {
2577                 desc = desc_buf;
2578             }
2579         }
2580         if (!desc) {
2581             vu_panic(dev, "Invalid indirect buffer table");
2582             return NULL;
2583         }
2584         i = 0;
2585     }
2586 
2587     /* Collect all the descriptors */
2588     do {
2589         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2590             if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2591                                VIRTQUEUE_MAX_SIZE - out_num, true,
2592                                le64toh(desc[i].addr),
2593                                le32toh(desc[i].len))) {
2594                 return NULL;
2595             }
2596         } else {
2597             if (in_num) {
2598                 vu_panic(dev, "Incorrect order for descriptors");
2599                 return NULL;
2600             }
2601             if (!virtqueue_map_desc(dev, &out_num, iov,
2602                                VIRTQUEUE_MAX_SIZE, false,
2603                                le64toh(desc[i].addr),
2604                                le32toh(desc[i].len))) {
2605                 return NULL;
2606             }
2607         }
2608 
2609         /* If we've got too many, that implies a descriptor loop. */
2610         if ((in_num + out_num) > max) {
2611             vu_panic(dev, "Looped descriptor");
2612             return NULL;
2613         }
2614         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2615     } while (rc == VIRTQUEUE_READ_DESC_MORE);
2616 
2617     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2618         vu_panic(dev, "read descriptor error");
2619         return NULL;
2620     }
2621 
2622     /* Now copy what we have collected and mapped */
2623     elem = virtqueue_alloc_element(sz, out_num, in_num);
         if (!elem) {
             return NULL;
         }
2624     elem->index = idx;
2625     for (i = 0; i < out_num; i++) {
2626         elem->out_sg[i] = iov[i];
2627     }
2628     for (i = 0; i < in_num; i++) {
2629         elem->in_sg[i] = iov[out_num + i];
2630     }
2631 
2632     return elem;
2633 }
2634 
2635 static int
2636 vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2637 {
2638     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2639         return 0;
2640     }
2641 
2642     if (unlikely(!vq->inflight)) {
2643         return -1;
2644     }
2645 
2646     vq->inflight->desc[desc_idx].counter = vq->counter++;
2647     vq->inflight->desc[desc_idx].inflight = 1;
2648 
2649     return 0;
2650 }
2651 
2652 static int
2653 vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2654 {
2655     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2656         return 0;
2657     }
2658 
2659     if (unlikely(!vq->inflight)) {
2660         return -1;
2661     }
2662 
2663     vq->inflight->last_batch_head = desc_idx;
2664 
2665     return 0;
2666 }
2667 
2668 static int
2669 vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2670 {
2671     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2672         return 0;
2673     }
2674 
2675     if (unlikely(!vq->inflight)) {
2676         return -1;
2677     }
2678 
2679     barrier();
2680 
2681     vq->inflight->desc[desc_idx].inflight = 0;
2682 
2683     barrier();
2684 
2685     vq->inflight->used_idx = vq->used_idx;
2686 
2687     return 0;
2688 }
2689 
2690 void *
2691 vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
2692 {
2693     int i;
2694     unsigned int head;
2695     VuVirtqElement *elem;
2696 
2697     if (unlikely(dev->broken) ||
2698         unlikely(!vq->vring.avail)) {
2699         return NULL;
2700     }
2701 
2702     if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
2703         i = (--vq->resubmit_num);
2704         elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
2705 
2706         if (!vq->resubmit_num) {
2707             free(vq->resubmit_list);
2708             vq->resubmit_list = NULL;
2709         }
2710 
2711         return elem;
2712     }
2713 
2714     if (vu_queue_empty(dev, vq)) {
2715         return NULL;
2716     }
2717     /*
2718      * Needed after vu_queue_empty(), see comment in
2719      * virtqueue_num_heads().
2720      */
2721     smp_rmb();
2722 
2723     if (vq->inuse >= vq->vring.num) {
2724         vu_panic(dev, "Virtqueue size exceeded");
2725         return NULL;
2726     }
2727 
2728     if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
2729         return NULL;
2730     }
2731 
2732     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2733         vring_set_avail_event(vq, vq->last_avail_idx);
2734     }
2735 
2736     elem = vu_queue_map_desc(dev, vq, head, sz);
2737 
2738     if (!elem) {
2739         return NULL;
2740     }
2741 
2742     vq->inuse++;
2743 
2744     vu_queue_inflight_get(dev, vq, head);
2745 
2746     return elem;
2747 }
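
/*
 * Processing-loop sketch (hypothetical backend code; response_len stands in
 * for however many bytes the device wrote): a request handler typically pops
 * elements until the ring is empty, reads the out_sg buffers, fills the
 * in_sg buffers, completes each element with vu_queue_push() and then sends
 * one notification:
 *
 *     VuVirtqElement *elem;
 *
 *     while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *         // parse the request from elem->out_sg[0 .. out_num - 1],
 *         // write the reply into elem->in_sg[0 .. in_num - 1] ...
 *         vu_queue_push(dev, vq, elem, response_len);
 *         free(elem);
 *     }
 *     vu_queue_notify(dev, vq);
 */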
2748 
2749 static void
2750 vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2751                         size_t len)
2752 {
2753     vq->inuse--;
2754     /* unmap, when DMA support is added */
2755 }
2756 
2757 void
2758 vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2759                size_t len)
2760 {
2761     vq->last_avail_idx--;
2762     vu_queue_detach_element(dev, vq, elem, len);
2763 }
2764 
2765 bool
2766 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2767 {
2768     if (num > vq->inuse) {
2769         return false;
2770     }
2771     vq->last_avail_idx -= num;
2772     vq->inuse -= num;
2773     return true;
2774 }
2775 
2776 static inline
2777 void vring_used_write(VuDev *dev, VuVirtq *vq,
2778                       struct vring_used_elem *uelem, int i)
2779 {
2780     struct vring_used *used = vq->vring.used;
2781 
2782     used->ring[i] = *uelem;
2783     vu_log_write(dev, vq->vring.log_guest_addr +
2784                  offsetof(struct vring_used, ring[i]),
2785                  sizeof(used->ring[i]));
2786 }
2787 
2789 static void
2790 vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
2791                   const VuVirtqElement *elem,
2792                   unsigned int len)
2793 {
2794     struct vring_desc *desc = vq->vring.desc;
2795     unsigned int i, max, min, desc_len;
2796     uint64_t desc_addr, read_len;
2797     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2798     unsigned num_bufs = 0;
2799 
2800     max = vq->vring.num;
2801     i = elem->index;
2802 
2803     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2804         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2805             vu_panic(dev, "Invalid size for indirect buffer table");
2806             return;
2807         }
2808 
2809         /* loop over the indirect descriptor table */
2810         desc_addr = le64toh(desc[i].addr);
2811         desc_len = le32toh(desc[i].len);
2812         max = desc_len / sizeof(struct vring_desc);
2813         read_len = desc_len;
2814         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2815         if (unlikely(desc && read_len != desc_len)) {
2816             /* Failed to use zero copy */
2817             desc = NULL;
2818             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2819                                               desc_addr,
2820                                               desc_len)) {
2821                 desc = desc_buf;
2822             }
2823         }
2824         if (!desc) {
2825             vu_panic(dev, "Invalid indirect buffer table");
2826             return;
2827         }
2828         i = 0;
2829     }
2830 
2831     do {
2832         if (++num_bufs > max) {
2833             vu_panic(dev, "Looped descriptor");
2834             return;
2835         }
2836 
2837         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2838             min = MIN(le32toh(desc[i].len), len);
2839             vu_log_write(dev, le64toh(desc[i].addr), min);
2840             len -= min;
2841         }
2842 
2843     } while (len > 0 &&
2844              (virtqueue_read_next_desc(dev, desc, i, max, &i)
2845               == VIRTQUEUE_READ_DESC_MORE));
2846 }
2847 
2848 void
2849 vu_queue_fill(VuDev *dev, VuVirtq *vq,
2850               const VuVirtqElement *elem,
2851               unsigned int len, unsigned int idx)
2852 {
2853     struct vring_used_elem uelem;
2854 
2855     if (unlikely(dev->broken) ||
2856         unlikely(!vq->vring.avail)) {
2857         return;
2858     }
2859 
2860     vu_log_queue_fill(dev, vq, elem, len);
2861 
2862     idx = (idx + vq->used_idx) % vq->vring.num;
2863 
2864     uelem.id = htole32(elem->index);
2865     uelem.len = htole32(len);
2866     vring_used_write(dev, vq, &uelem, idx);
2867 }
2868 
2869 static inline
2870 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2871 {
2872     vq->vring.used->idx = htole16(val);
2873     vu_log_write(dev,
2874                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2875                  sizeof(vq->vring.used->idx));
2876 
2877     vq->used_idx = val;
2878 }
2879 
2880 void
2881 vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2882 {
2883     uint16_t old, new;
2884 
2885     if (unlikely(dev->broken) ||
2886         unlikely(!vq->vring.avail)) {
2887         return;
2888     }
2889 
2890     /* Make sure buffer is written before we update index. */
2891     smp_wmb();
2892 
2893     old = vq->used_idx;
2894     new = old + count;
2895     vring_used_idx_set(dev, vq, new);
2896     vq->inuse -= count;
2897     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2898         vq->signalled_used_valid = false;
2899     }
2900 }
2901 
2902 void
2903 vu_queue_push(VuDev *dev, VuVirtq *vq,
2904               const VuVirtqElement *elem, unsigned int len)
2905 {
2906     vu_queue_fill(dev, vq, elem, len, 0);
2907     vu_queue_inflight_pre_put(dev, vq, elem->index);
2908     vu_queue_flush(dev, vq, 1);
2909     vu_queue_inflight_post_put(dev, vq, elem->index);
2910 }
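
/*
 * Ordering note: with VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD negotiated,
 * vu_queue_push() records the completed descriptor in the shared inflight
 * region before publishing it (pre_put sets last_batch_head ahead of the
 * used index update) and clears its inflight flag only afterwards
 * (post_put), so the inflight resubmission logic of a reconnecting backend
 * can work out whether this request was actually completed if the process
 * dies between the two steps.
 */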
2911