qemu-patch-raspberry4/contrib/libvhost-user/libvhost-user.c
Dr. David Alan Gilbert 9bb3801994 vhost+postcopy: Send address back to qemu
We need a better way, but at the moment we need the address of the
mappings sent back to qemu so it can interpret the messages on the
userfaultfd it reads.

This is done as a 3 stage set:
   QEMU -> client
      set_mem_table

   mmap stuff, get addresses

   client -> qemu
       here are the addresses

   qemu -> client
       OK - now you can use them

That ensures that qemu has registered the new addresses in its
userfault code before the client starts accessing them.

Note: We don't ask for the default 'ack' reply since we've got our own.

Signed-off-by: Dr. David Alan Gilbert <dgilbert@redhat.com>
Reviewed-by: Marc-André Lureau <marcandre.lureau@redhat.com>
Reviewed-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
2018-03-20 05:03:28 +02:00

1951 lines
51 KiB
C

/*
* Vhost User library
*
* Copyright IBM, Corp. 2007
* Copyright (c) 2016 Red Hat, Inc.
*
* Authors:
* Anthony Liguori <aliguori@us.ibm.com>
* Marc-André Lureau <mlureau@redhat.com>
* Victor Kaplansky <victork@redhat.com>
*
* This work is licensed under the terms of the GNU GPL, version 2 or
* later. See the COPYING file in the top-level directory.
*/
/* this code avoids GLib dependency */
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <stdarg.h>
#include <errno.h>
#include <string.h>
#include <assert.h>
#include <inttypes.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/eventfd.h>
#include <sys/mman.h>
#include "qemu/compiler.h"
#if defined(__linux__)
#include <sys/syscall.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>
#ifdef __NR_userfaultfd
#include <linux/userfaultfd.h>
#endif
#endif
#include "qemu/atomic.h"
#include "libvhost-user.h"
/* usually provided by GLib */
#ifndef MIN
/*
 * Type-checked minimum of two values (GNU statement expression).
 * Each argument is evaluated exactly once into a temporary.  The discarded
 * comparison of the two temporaries' addresses exists only so that the
 * compiler can diagnose a call whose arguments have different types.
 */
#define MIN(x, y) ({ \
typeof(x) _min1 = (x); \
typeof(y) _min2 = (y); \
(void) (&_min1 == &_min2); \
_min1 < _min2 ? _min1 : _min2; })
#endif
/* Number of bytes of a VhostUserMsg that precede the payload union */
#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
/* Set to a non-zero value to make DPRINT() emit its output */
#define LIBVHOST_USER_DEBUG 0
/*
 * printf-style debug trace written to stderr.  The call is guarded by a
 * plain if () on LIBVHOST_USER_DEBUG rather than by #if, so the arguments
 * are still compiled (and type-checked) when debugging is switched off.
 */
#define DPRINT(...) \
do { \
if (LIBVHOST_USER_DEBUG) { \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
/*
 * Map a vhost-user request code to its symbolic name (used for debug output).
 *
 * Returns a pointer to a static string and never NULL: codes that are out of
 * range, or in range but absent from the table below, yield "unknown".
 */
static const char *
vu_request_to_string(unsigned int req)
{
#define REQ(req) [req] = #req
    static const char *vu_request_str[] = {
        REQ(VHOST_USER_NONE),
        REQ(VHOST_USER_GET_FEATURES),
        REQ(VHOST_USER_SET_FEATURES),
        REQ(VHOST_USER_SET_OWNER),
        REQ(VHOST_USER_RESET_OWNER),
        REQ(VHOST_USER_SET_MEM_TABLE),
        REQ(VHOST_USER_SET_LOG_BASE),
        REQ(VHOST_USER_SET_LOG_FD),
        REQ(VHOST_USER_SET_VRING_NUM),
        REQ(VHOST_USER_SET_VRING_ADDR),
        REQ(VHOST_USER_SET_VRING_BASE),
        REQ(VHOST_USER_GET_VRING_BASE),
        REQ(VHOST_USER_SET_VRING_KICK),
        REQ(VHOST_USER_SET_VRING_CALL),
        REQ(VHOST_USER_SET_VRING_ERR),
        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_GET_QUEUE_NUM),
        REQ(VHOST_USER_SET_VRING_ENABLE),
        REQ(VHOST_USER_SEND_RARP),
        REQ(VHOST_USER_NET_SET_MTU),
        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
        REQ(VHOST_USER_IOTLB_MSG),
        REQ(VHOST_USER_SET_VRING_ENDIAN),
        REQ(VHOST_USER_GET_CONFIG),
        REQ(VHOST_USER_SET_CONFIG),
        REQ(VHOST_USER_POSTCOPY_ADVISE),
        REQ(VHOST_USER_POSTCOPY_LISTEN),
        REQ(VHOST_USER_MAX),
    };
#undef REQ

    /*
     * The table is built with designated initializers, so any request code
     * below VHOST_USER_MAX that has no REQ() line is a NULL slot.  Callers
     * feed the result to a "%s" conversion, which must never receive NULL.
     */
    if (req < VHOST_USER_MAX && vu_request_str[req]) {
        return vu_request_str[req];
    }

    return "unknown";
}
/*
 * Report an unrecoverable device error.
 *
 * Formats the printf-style message, marks the device as broken and passes
 * the formatted text to the application's panic callback.  If formatting
 * fails the callback is invoked with a NULL message.
 */
static void
vu_panic(VuDev *dev, const char *msg, ...)
{
    char *text = NULL;
    va_list args;
    int rc;

    va_start(args, msg);
    rc = vasprintf(&text, msg, args);
    va_end(args);

    if (rc < 0) {
        /* vasprintf() leaves the output pointer undefined on failure */
        text = NULL;
    }

    dev->broken = true;
    dev->panic(dev, text);
    free(text);

    /* FIXME: find a way to call virtio_error? */
}
/*
 * Translate a guest physical address to a pointer in our address space.
 *
 * @dev:        device whose memory region table is searched
 * @plen:       in: number of bytes the caller wants to access at @guest_addr;
 *              out: clamped to the bytes that remain in the matching region
 * @guest_addr: guest physical address to translate
 *
 * Returns the local address, or NULL when *plen is 0 or no region contains
 * @guest_addr.
 */
void *
vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
{
    int i;

    if (*plen == 0) {
        return NULL;
    }

    /* Find matching memory region. */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        uint64_t offset;
        uint64_t remaining;

        /*
         * Compare offsets rather than end addresses so that neither
         * gpa + size nor guest_addr + *plen can wrap around 2^64 and
         * defeat the range check or the clamp below.
         */
        if (guest_addr < r->gpa || guest_addr - r->gpa >= r->size) {
            continue;
        }

        offset = guest_addr - r->gpa;
        remaining = r->size - offset;
        if (*plen > remaining) {
            *plen = remaining;
        }

        /* Do the arithmetic on integers; convert to a pointer only once. */
        return (void *)(uintptr_t)(r->mmap_addr + r->mmap_offset + offset);
    }

    return NULL;
}
/*
 * Translate an address in QEMU's virtual address space to the matching
 * address inside our own mapping of the same memory region.
 * Returns NULL when no region covers @qemu_addr.
 */
static void *
qva_to_va(VuDev *dev, uint64_t qemu_addr)
{
    int idx = 0;

    /* Walk the region table until one region contains the address. */
    while (idx < dev->nregions) {
        VuDevRegion *region = &dev->regions[idx++];
        bool below = qemu_addr < region->qva;
        bool beyond = qemu_addr >= (region->qva + region->size);

        if (below || beyond) {
            continue;
        }

        return (void *)(uintptr_t)
            qemu_addr - region->qva + region->mmap_addr + region->mmap_offset;
    }

    return NULL;
}
/* Close every file descriptor that was received along with @vmsg. */
static void
vmsg_close_fds(VhostUserMsg *vmsg)
{
    int idx = 0;

    while (idx < vmsg->fd_num) {
        close(vmsg->fds[idx]);
        idx++;
    }
}
/*
 * Receive one vhost-user message from @conn_fd into @vmsg.
 *
 * The fixed-size header is read with recvmsg() so that file descriptors
 * passed as SCM_RIGHTS ancillary data are collected into vmsg->fds; the
 * payload (vmsg->size bytes) is then read separately.
 *
 * Returns true on success.  On failure vu_panic() has been called and, once
 * descriptors may have been received, they have been closed.
 *
 * NOTE(review): a recvmsg() that returns 0 (peer closed the socket) is not
 * treated as an error here, so the header fields of @vmsg keep whatever the
 * caller had stored in them - confirm whether callers rely on this.
 */
static bool
vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
        .msg_controllen = sizeof(control),
    };
    size_t fd_size;
    struct cmsghdr *cmsg;
    ssize_t rc;

    do {
        rc = recvmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc < 0) {
        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
        return false;
    }

    /* Pick up the descriptors from the first SCM_RIGHTS control message. */
    vmsg->fd_num = 0;
    for (cmsg = CMSG_FIRSTHDR(&msg);
         cmsg != NULL;
         cmsg = CMSG_NXTHDR(&msg, cmsg))
    {
        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
            vmsg->fd_num = fd_size / sizeof(int);
            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
            break;
        }
    }

    if (vmsg->size > sizeof(vmsg->payload)) {
        vu_panic(dev,
                 "Error: too big message request: %d, size: vmsg->size: %u, "
                 "while sizeof(vmsg->payload) = %zu\n",
                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
        goto fail;
    }

    if (vmsg->size) {
        uint8_t *dst = (uint8_t *)&vmsg->payload;
        size_t remaining = vmsg->size;

        /*
         * read() on a stream socket may return fewer bytes than requested.
         * Keep reading until the whole payload has arrived instead of
         * asserting on (or, with NDEBUG, silently accepting) a short read.
         */
        while (remaining) {
            do {
                rc = read(conn_fd, dst, remaining);
            } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

            if (rc <= 0) {
                vu_panic(dev, "Error while reading: %s", strerror(errno));
                goto fail;
            }

            dst += rc;
            remaining -= rc;
        }
    }

    return true;

fail:
    vmsg_close_fds(vmsg);

    return false;
}
/*
 * Send @vmsg as a reply on @conn_fd.
 *
 * The header goes out with sendmsg(), carrying vmsg->fd_num descriptors as
 * SCM_RIGHTS ancillary data; the payload follows with write(), taken from
 * vmsg->data when that pointer is set and from the in-message payload
 * otherwise.  The version and reply bits are forced in vmsg->flags.
 *
 * Returns true on success; on failure vu_panic() has been called.
 */
static bool
vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    ssize_t rc;
    uint8_t *p = (uint8_t *)vmsg;
    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
    };
    struct cmsghdr *cmsg;

    assert(vmsg->fd_num <= VHOST_MEMORY_MAX_NREGIONS);

    if (vmsg->fd_num > 0) {
        size_t fdsize = vmsg->fd_num * sizeof(int);

        msg.msg_controllen = CMSG_SPACE(fdsize);
        cmsg = CMSG_FIRSTHDR(&msg);
        cmsg->cmsg_len = CMSG_LEN(fdsize);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type = SCM_RIGHTS;
        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
    } else {
        msg.msg_controllen = 0;
    }

    /* Set the version in the flags when sending the reply */
    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
    vmsg->flags |= VHOST_USER_VERSION;
    vmsg->flags |= VHOST_USER_REPLY_MASK;

    do {
        rc = sendmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    /* Never follow a header that failed to go out with a stray payload. */
    if (rc <= 0) {
        vu_panic(dev, "Error while writing: %s", strerror(errno));
        return false;
    }

    /*
     * Replies may legitimately have an empty payload.  write() of zero
     * bytes returns 0, which must not be mistaken for a failure.
     */
    if (vmsg->size) {
        do {
            if (vmsg->data) {
                rc = write(conn_fd, vmsg->data, vmsg->size);
            } else {
                rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
            }
        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

        if (rc <= 0) {
            vu_panic(dev, "Error while writing: %s", strerror(errno));
            return false;
        }
    }

    return true;
}
/*
 * Signal the dirty-log eventfd, if one has been set (log_call_fd != -1).
 * A failed eventfd write is reported through vu_panic().
 */
static void
vu_log_kick(VuDev *dev)
{
    if (dev->log_call_fd == -1) {
        return;
    }

    DPRINT("Kicking the QEMU's log...\n");

    if (eventfd_write(dev->log_call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}
/*
 * Set the bit for @page in the dirty bitmap @log_table (one bit per page,
 * eight pages per byte), using an atomic OR.
 */
static void
vu_log_page(uint8_t *log_table, uint64_t page)
{
    uint8_t *slot = &log_table[page / 8];
    int mask = 1 << (page % 8);

    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
    atomic_or(slot, mask);
}
/*
 * Record a write of @length bytes at guest address @address in the dirty
 * log and kick the log eventfd.
 *
 * Does nothing when VHOST_F_LOG_ALL has not been negotiated, when no log
 * table is mapped, or when @length is 0.
 */
static void
vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
{
    uint64_t page;

    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
        !dev->log_table || !length) {
        return;
    }

    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));

    page = address / VHOST_LOG_PAGE;
    while (page * VHOST_LOG_PAGE < address + length) {
        vu_log_page(dev->log_table, page);
        /*
         * 'page' is a page index, not a byte address: step to the next
         * page.  Advancing by VHOST_LOG_PAGE would leave every page after
         * the first one of a multi-page write unlogged.
         */
        page += 1;
    }

    vu_log_kick(dev);
}
/*
 * Watch callback for a queue's kick eventfd.
 *
 * @data carries the queue index.  The eventfd counter is drained and the
 * queue's handler, if any, is invoked.  When the read fails the device is
 * panicked and the watch on the kick fd is removed.
 */
static void
vu_kick_cb(VuDev *dev, int condition, void *data)
{
    int qidx = (intptr_t)data;
    VuVirtq *queue = &dev->vq[qidx];
    eventfd_t counter;

    if (eventfd_read(queue->kick_fd, &counter) == -1) {
        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
        dev->remove_watch(dev, dev->vq[qidx].kick_fd);
        return;
    }

    DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
           counter, queue->handler, qidx);

    if (queue->handler) {
        queue->handler(dev, qidx);
    }
}
/*
 * VHOST_USER_GET_FEATURES: build the feature mask reply.
 *
 * The library always advertises VHOST_F_LOG_ALL and
 * VHOST_USER_F_PROTOCOL_FEATURES and ORs in whatever the device's
 * get_features callback reports.  Returns true so that a reply is sent.
 */
static bool
vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t mask = (1ULL << VHOST_F_LOG_ALL) |
                    (1ULL << VHOST_USER_F_PROTOCOL_FEATURES);

    if (dev->iface->get_features) {
        mask |= dev->iface->get_features(dev);
    }

    vmsg->payload.u64 = mask;
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    return true;
}
/* Set the 'enable' flag of all VHOST_MAX_NR_VIRTQUEUE queues to @enabled. */
static void
vu_set_enable_all_rings(VuDev *dev, bool enabled)
{
    int qidx = 0;

    while (qidx < VHOST_MAX_NR_VIRTQUEUE) {
        dev->vq[qidx].enable = enabled;
        qidx++;
    }
}
/*
 * VHOST_USER_SET_FEATURES: store the negotiated feature mask.
 *
 * When VHOST_USER_F_PROTOCOL_FEATURES was not negotiated, all rings are
 * enabled here.  The device's set_features callback, if any, is then
 * informed.  Returns false: no reply is sent.
 */
static bool
vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    dev->features = vmsg->payload.u64;

    /*
     * VHOST_USER_F_PROTOCOL_FEATURES is a bit number (it is used as a shift
     * count when the features are advertised), so it has to be turned into
     * a mask before it is tested against the feature word.
     */
    if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
        vu_set_enable_all_rings(dev, true);
    }

    if (dev->iface->set_features) {
        dev->iface->set_features(dev, dev->features);
    }

    return false;
}
/*
 * VHOST_USER_SET_OWNER: nothing to do for this library.
 * Returns false: no reply is sent.
 */
static bool
vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const bool send_reply = false;

    return send_reply;
}
/*
 * Release the dirty-log resources: unmap the log table (if mapped) and
 * close the log eventfd (if open), resetting both fields afterwards.
 */
static void
vu_close_log(VuDev *dev)
{
    if (dev->log_table != NULL) {
        int rc = munmap(dev->log_table, dev->log_size);

        if (rc != 0) {
            perror("close log munmap() error");
        }
        dev->log_table = NULL;
    }

    if (dev->log_call_fd == -1) {
        return;
    }

    close(dev->log_call_fd);
    dev->log_call_fd = -1;
}
/*
 * VHOST_USER_RESET_OWNER: disable every ring.
 * Returns false: no reply is sent.
 */
static bool
vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const bool rings_enabled = false;

    vu_set_enable_all_rings(dev, rings_enabled);

    return false;
}
/*
 * VHOST_USER_SET_MEM_TABLE while postcopy-listening.
 *
 * Three-stage handshake:
 *   1. mmap every region received from QEMU;
 *   2. send the message back with userspace_addr replaced by our mapping
 *      address, so QEMU can translate userfault addresses;
 *   3. wait for QEMU's ack, then register each mapping with the userfaultfd.
 *
 * Returns false in all cases: the reply is sent from inside this function,
 * so the dispatcher must not send another one.
 */
static bool
vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
{
    int i;
    VhostUserMemory *memory = &vmsg->payload.memory;

    /*
     * The region count comes straight from the message.  Descriptors are
     * limited to VHOST_MEMORY_MAX_NREGIONS per message (see the control
     * buffer used when receiving); the region arrays are presumably sized
     * the same way - confirm in libvhost-user.h.  Refuse anything larger
     * rather than indexing past the arrays.
     */
    if (memory->nregions > VHOST_MEMORY_MAX_NREGIONS) {
        dev->nregions = 0;
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Too many memory regions in set-mem-table: %u",
                 (unsigned int)memory->nregions);
        return false;
    }

    dev->nregions = memory->nregions;
    /* TODO: Postcopy specific code */
    DPRINT("Nregions: %d\n", memory->nregions);
    for (i = 0; i < dev->nregions; i++) {
        void *mmap_addr;
        VhostUserMemoryRegion *msg_region = &memory->regions[i];
        VuDevRegion *dev_region = &dev->regions[i];

        DPRINT("Region %d\n", i);
        DPRINT(" guest_phys_addr: 0x%016"PRIx64"\n",
               msg_region->guest_phys_addr);
        DPRINT(" memory_size: 0x%016"PRIx64"\n",
               msg_region->memory_size);
        DPRINT(" userspace_addr 0x%016"PRIx64"\n",
               msg_region->userspace_addr);
        DPRINT(" mmap_offset 0x%016"PRIx64"\n",
               msg_region->mmap_offset);

        dev_region->gpa = msg_region->guest_phys_addr;
        dev_region->size = msg_region->memory_size;
        dev_region->qva = msg_region->userspace_addr;
        dev_region->mmap_offset = msg_region->mmap_offset;

        /* We don't use offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages. */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_READ | PROT_WRITE, MAP_SHARED,
                         vmsg->fds[i], 0);

        if (mmap_addr == MAP_FAILED) {
            vu_panic(dev, "region mmap error: %s", strerror(errno));
            /* Do not keep an address left over from an earlier table. */
            dev_region->mmap_addr = 0;
        } else {
            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
            DPRINT(" mmap_addr: 0x%016"PRIx64"\n",
                   dev_region->mmap_addr);
        }

        /* Return the address to QEMU so that it can translate the ufd
         * fault addresses back.
         */
        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
                                                 dev_region->mmap_offset);
        close(vmsg->fds[i]);
    }

    /* Send the message back to qemu with the addresses filled in */
    vmsg->fd_num = 0;
    if (!vu_message_write(dev, dev->sock, vmsg)) {
        vu_panic(dev, "failed to respond to set-mem-table for postcopy");
        return false;
    }

    /* Wait for QEMU to confirm that it's registered the handler for the
     * faults.
     */
    if (!vu_message_read(dev, dev->sock, vmsg) ||
        vmsg->size != sizeof(vmsg->payload.u64) ||
        vmsg->payload.u64 != 0) {
        vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
        return false;
    }

    /* OK, now we can go and register the memory and generate faults */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *dev_region = &dev->regions[i];
#ifdef UFFDIO_REGISTER
        /* We should already have an open ufd. Mark each memory
         * range as ufd.
         * Note: Do we need any madvises? Well it's not been accessed
         * yet, still probably need no THP to be safe, discard to be safe?
         */
        struct uffdio_register reg_struct;
        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;

        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
            /* The region fields are 64-bit integers: print them as such. */
            vu_panic(dev, "%s: Failed to userfault region %d "
                          "@%"PRIx64" + size:%"PRIx64" offset: %"PRIx64
                          ": (ufd=%d)%s\n",
                     __func__, i,
                     dev_region->mmap_addr,
                     dev_region->size, dev_region->mmap_offset,
                     dev->postcopy_ufd, strerror(errno));
            return false;
        }
        if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
            vu_panic(dev, "%s Region (%d) doesn't support COPY",
                     __func__, i);
            return false;
        }
        DPRINT("%s: region %d: Registered userfault for %llx + %llx\n",
               __func__, i, reg_struct.range.start, reg_struct.range.len);
        /* TODO: Stash 'zero' support flags somewhere */
#endif
    }

    return false;
}
/*
 * VHOST_USER_SET_MEM_TABLE: replace the guest memory region table.
 *
 * Existing mappings are released, then every region in the message is
 * mapped from its file descriptor (vmsg->fds[i]), which is closed
 * afterwards.  While postcopy-listening the work is delegated to
 * vu_set_mem_table_exec_postcopy().
 *
 * Returns false: no reply is sent by the dispatcher.
 */
static bool
vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int i;
    VhostUserMemory *memory = &vmsg->payload.memory;

    /*
     * The region count comes straight from the message.  Descriptors are
     * limited to VHOST_MEMORY_MAX_NREGIONS per message (see the control
     * buffer used when receiving); the region arrays are presumably sized
     * the same way - confirm in libvhost-user.h.  Reject larger counts
     * before touching the current table.
     */
    if (memory->nregions > VHOST_MEMORY_MAX_NREGIONS) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Too many memory regions in set-mem-table: %u",
                 (unsigned int)memory->nregions);
        return false;
    }

    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        void *m = (void *) (uintptr_t) r->mmap_addr;

        if (m) {
            munmap(m, r->size + r->mmap_offset);
        }
    }
    dev->nregions = memory->nregions;

    if (dev->postcopy_listening) {
        return vu_set_mem_table_exec_postcopy(dev, vmsg);
    }

    DPRINT("Nregions: %d\n", memory->nregions);
    for (i = 0; i < dev->nregions; i++) {
        void *mmap_addr;
        VhostUserMemoryRegion *msg_region = &memory->regions[i];
        VuDevRegion *dev_region = &dev->regions[i];

        DPRINT("Region %d\n", i);
        DPRINT(" guest_phys_addr: 0x%016"PRIx64"\n",
               msg_region->guest_phys_addr);
        DPRINT(" memory_size: 0x%016"PRIx64"\n",
               msg_region->memory_size);
        DPRINT(" userspace_addr 0x%016"PRIx64"\n",
               msg_region->userspace_addr);
        DPRINT(" mmap_offset 0x%016"PRIx64"\n",
               msg_region->mmap_offset);

        dev_region->gpa = msg_region->guest_phys_addr;
        dev_region->size = msg_region->memory_size;
        dev_region->qva = msg_region->userspace_addr;
        dev_region->mmap_offset = msg_region->mmap_offset;

        /* We don't use offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages. */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_READ | PROT_WRITE, MAP_SHARED,
                         vmsg->fds[i], 0);

        if (mmap_addr == MAP_FAILED) {
            vu_panic(dev, "region mmap error: %s", strerror(errno));
            /*
             * The slot may still hold the address of a mapping that was
             * released above; clear it so that it is never unmapped again.
             */
            dev_region->mmap_addr = 0;
        } else {
            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
            DPRINT(" mmap_addr: 0x%016"PRIx64"\n",
                   dev_region->mmap_addr);
        }

        close(vmsg->fds[i]);
    }

    return false;
}
/*
 * VHOST_USER_SET_LOG_BASE: map the dirty-log area passed by file descriptor
 * and install it as the device's log table, replacing any previous one.
 *
 * Returns true: a reply is always sent, also for a malformed message
 * (after vu_panic()).
 */
static bool
vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd;
    uint64_t log_mmap_size, log_mmap_offset;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.log)) {
        vu_panic(dev, "Invalid log_base message");
        return true;
    }

    fd = vmsg->fds[0];
    log_mmap_offset = vmsg->payload.log.mmap_offset;
    log_mmap_size = vmsg->payload.log.mmap_size;
    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
    DPRINT("Log mmap_size: %"PRId64"\n", log_mmap_size);

    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
              log_mmap_offset);
    if (rc == MAP_FAILED) {
        /* Report before close() gets a chance to overwrite errno. */
        perror("log mmap error");
        /*
         * MAP_FAILED is not NULL: storing it as the log table would make
         * the logging code treat it as a valid mapping and write through
         * it.  Leave the device without a log table instead.
         */
        rc = NULL;
        log_mmap_size = 0;
    }
    close(fd);

    if (dev->log_table) {
        munmap(dev->log_table, dev->log_size);
    }

    dev->log_table = rc;
    dev->log_size = log_mmap_size;

    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    return true;
}
/*
 * VHOST_USER_SET_LOG_FD: adopt the eventfd used to signal dirty-log updates,
 * closing the previous one if any.  The message must carry exactly one fd.
 * Returns false: no reply is sent.
 */
static bool
vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    bool has_single_fd = (vmsg->fd_num == 1);

    if (!has_single_fd) {
        vu_panic(dev, "Invalid log_fd message");
        return false;
    }

    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
    }

    dev->log_call_fd = vmsg->fds[0];
    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);

    return false;
}
/*
 * VHOST_USER_SET_VRING_NUM: record the ring size of one queue.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.num: %d\n", num);

    /* The index comes from the message: never write outside dev->vq[]. */
    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    dev->vq[index].vring.num = num;

    return false;
}
/*
 * VHOST_USER_SET_VRING_ADDR: install the descriptor, used and available
 * ring addresses of one queue.
 *
 * The addresses in the message are QEMU virtual addresses and are translated
 * with qva_to_va(); the device is panicked when any of them is not covered
 * by a known memory region.  used_idx is then loaded from the used ring and,
 * if it differs from last_avail_idx and the device reports in-order
 * processing, the avail indexes are resynchronised to it.
 *
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    struct vhost_vring_addr *vra = &vmsg->payload.addr;
    unsigned int index = vra->index;
    VuVirtq *vq;

    DPRINT("vhost_vring_addr:\n");
    DPRINT(" index: %d\n", vra->index);
    DPRINT(" flags: %d\n", vra->flags);
    DPRINT(" desc_user_addr: 0x%016llx\n", vra->desc_user_addr);
    DPRINT(" used_user_addr: 0x%016llx\n", vra->used_user_addr);
    DPRINT(" avail_user_addr: 0x%016llx\n", vra->avail_user_addr);
    DPRINT(" log_guest_addr: 0x%016llx\n", vra->log_guest_addr);

    /* The index comes from the message: never touch memory past dev->vq[]. */
    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }
    vq = &dev->vq[index];

    vq->vring.flags = vra->flags;
    vq->vring.desc = qva_to_va(dev, vra->desc_user_addr);
    vq->vring.used = qva_to_va(dev, vra->used_user_addr);
    vq->vring.avail = qva_to_va(dev, vra->avail_user_addr);
    vq->vring.log_guest_addr = vra->log_guest_addr;

    DPRINT("Setting virtq addresses:\n");
    DPRINT(" vring_desc at %p\n", vq->vring.desc);
    DPRINT(" vring_used at %p\n", vq->vring.used);
    DPRINT(" vring_avail at %p\n", vq->vring.avail);

    if (!(vq->vring.desc && vq->vring.used && vq->vring.avail)) {
        vu_panic(dev, "Invalid vring_addr message");
        return false;
    }

    vq->used_idx = vq->vring.used->idx;

    if (vq->last_avail_idx != vq->used_idx) {
        bool resume = dev->iface->queue_is_processed_in_order &&
            dev->iface->queue_is_processed_in_order(dev, index);

        DPRINT("Last avail index != used index: %u != %u%s\n",
               vq->last_avail_idx, vq->used_idx,
               resume ? ", resuming" : "");

        if (resume) {
            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
        }
    }

    return false;
}
/*
 * VHOST_USER_SET_VRING_BASE: set both last_avail_idx and shadow_avail_idx
 * of one queue to the value carried in the message.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.num: %d\n", num);

    /* The index comes from the message: never write outside dev->vq[]. */
    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;

    return false;
}
/*
 * VHOST_USER_GET_VRING_BASE: report last_avail_idx of one queue and stop it.
 *
 * Stopping the queue marks it as not started, notifies the device's
 * queue_set_started callback, and closes the call and kick descriptors
 * (removing the kick watch first).
 *
 * Returns true so that the reply is sent; an out-of-range queue index
 * panics the device and returns false (no reply).
 */
static bool
vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;

    DPRINT("State.index: %d\n", index);

    /* The index comes from the message: never touch memory past dev->vq[]. */
    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
    vmsg->size = sizeof(vmsg->payload.state);

    dev->vq[index].started = false;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, false);
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }
    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    return true;
}
/*
 * Validate a "queue index + one file descriptor" message.
 *
 * The low bits of payload.u64 select the queue; the message is accepted
 * only if that index is below VHOST_MAX_NR_VIRTQUEUE, the NOFD flag is
 * clear and exactly one descriptor was received.  On rejection all received
 * descriptors are closed and the device is panicked.
 */
static bool
vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;

    if (index < VHOST_MAX_NR_VIRTQUEUE) {
        bool nofd = (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK) != 0;

        if (!nofd && vmsg->fd_num == 1) {
            return true;
        }

        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
        return false;
    }

    vmsg_close_fds(vmsg);
    vu_panic(dev, "Invalid queue index: %u", index);
    return false;
}
/*
 * VHOST_USER_SET_VRING_KICK: install the kick eventfd of one queue.
 *
 * Any previous kick fd is unwatched and closed, the queue is marked as
 * started (the device's queue_set_started callback is informed) and, when
 * the queue has a handler, a watch is set on the new fd.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    VuVirtq *queue;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    queue = &dev->vq[index];

    /* Drop the old descriptor before taking the new one. */
    if (queue->kick_fd != -1) {
        dev->remove_watch(dev, queue->kick_fd);
        close(queue->kick_fd);
        queue->kick_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        queue->kick_fd = vmsg->fds[0];
        DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
    }

    queue->started = true;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, true);
    }

    if (queue->kick_fd != -1 && queue->handler) {
        dev->set_watch(dev, queue->kick_fd, VU_WATCH_IN,
                       vu_kick_cb, (void *)(long)index);

        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
               queue->kick_fd, index);
    }

    return false;
}
/*
 * Set (or clear, with a NULL @handler) the kick handler of @vq.
 *
 * When the queue already has a kick fd, the watch on it is installed or
 * removed to match.
 */
void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
                          vu_queue_handler_cb handler)
{
    int qidx = vq - dev->vq;

    vq->handler = handler;

    if (vq->kick_fd < 0) {
        /* No kick fd yet: nothing to watch. */
        return;
    }

    if (!handler) {
        dev->remove_watch(dev, vq->kick_fd);
        return;
    }

    dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
                   vu_kick_cb, (void *)(long)qidx);
}
/*
 * VHOST_USER_SET_VRING_CALL: install the call eventfd of one queue, closing
 * the previous one if any.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    VuVirtq *queue;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    queue = &dev->vq[index];

    if (queue->call_fd != -1) {
        close(queue->call_fd);
        queue->call_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        queue->call_fd = vmsg->fds[0];
    }

    DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);

    return false;
}
/*
 * VHOST_USER_SET_VRING_ERR: install the error eventfd of one queue, closing
 * the previous one if any.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    VuVirtq *queue;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    queue = &dev->vq[index];

    if (queue->err_fd != -1) {
        close(queue->err_fd);
        queue->err_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        queue->err_fd = vmsg->fds[0];
    }

    return false;
}
/*
 * VHOST_USER_GET_PROTOCOL_FEATURES: build the protocol feature mask reply.
 *
 * LOG_SHMFD and SLAVE_REQ are always offered; the device's
 * get_protocol_features callback may add more.  Returns true so that the
 * reply is sent.
 */
static bool
vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t mask = (1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD) |
                    (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ);

    if (dev->iface->get_protocol_features) {
        mask |= dev->iface->get_protocol_features(dev);
    }

    vmsg->fd_num = 0;
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->payload.u64 = mask;

    return true;
}
/*
 * VHOST_USER_SET_PROTOCOL_FEATURES: store the negotiated protocol features
 * and pass them to the device's set_protocol_features callback, if any.
 * Returns false: no reply is sent.
 */
static bool
vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t negotiated = vmsg->payload.u64;

    DPRINT("u64: 0x%016"PRIx64"\n", negotiated);

    dev->protocol_features = vmsg->payload.u64;

    if (!dev->iface->set_protocol_features) {
        return false;
    }

    dev->iface->set_protocol_features(dev, negotiated);

    return false;
}
/*
 * VHOST_USER_GET_QUEUE_NUM: not implemented; only a debug trace is emitted.
 * Returns false: no reply is sent.
 */
static bool
vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const bool send_reply = false;

    DPRINT("Function %s() not implemented yet.\n", __func__);

    return send_reply;
}
/*
 * VHOST_USER_SET_VRING_ENABLE: set the enable flag of one queue.
 * An out-of-range queue index panics the device.
 * Returns false: no reply is sent.
 */
static bool
vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int enable = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.enable: %d\n", enable);

    if (index < VHOST_MAX_NR_VIRTQUEUE) {
        dev->vq[index].enable = enable;
    } else {
        vu_panic(dev, "Invalid vring_enable index: %u", index);
    }

    return false;
}
/*
 * VHOST_USER_SET_SLAVE_REQ_FD: adopt the descriptor of the slave request
 * channel, closing the previous one if any.  The message must carry exactly
 * one fd.  Returns false: no reply is sent.
 */
static bool
vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    bool has_single_fd = (vmsg->fd_num == 1);

    if (!has_single_fd) {
        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
        return false;
    }

    if (dev->slave_fd != -1) {
        close(dev->slave_fd);
    }

    dev->slave_fd = vmsg->fds[0];
    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);

    return false;
}
/*
 * VHOST_USER_GET_CONFIG: let the device fill the configuration region of
 * the message.
 *
 * A missing get_config callback or a non-zero result from it is reported to
 * the master by shrinking the reply to size 0.  Returns true: a reply is
 * always sent.
 */
static bool
vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
{
    bool failed = true;

    if (dev->iface->get_config) {
        int rc = dev->iface->get_config(dev, vmsg->payload.config.region,
                                        vmsg->payload.config.size);
        failed = (rc != 0);
    }

    if (failed) {
        /* resize to zero to indicate an error to master */
        vmsg->size = 0;
    }

    return true;
}
/*
 * VHOST_USER_SET_CONFIG: hand the configuration update to the device's
 * set_config callback, if it has one; a non-zero result panics the device.
 * Returns false: no reply is sent.
 */
static bool
vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
{
    int rc;

    if (!dev->iface->set_config) {
        return false;
    }

    rc = dev->iface->set_config(dev, vmsg->payload.config.region,
                                vmsg->payload.config.offset,
                                vmsg->payload.config.size,
                                vmsg->payload.config.flags);
    if (rc) {
        vu_panic(dev, "Set virtio configuration space failed");
    }

    return false;
}
/*
 * VHOST_USER_POSTCOPY_ADVISE: open a userfaultfd and return it to QEMU.
 *
 * When the build has UFFDIO_API, a userfaultfd is created with the
 * userfaultfd system call and the API is negotiated with the UFFDIO_API
 * ioctl (no optional features requested).  Without UFFDIO_API, or when the
 * system call fails, the device is panicked.
 *
 * Returns true: a reply is always sent.  The reply always announces one
 * descriptor, which is -1 on every failure path.
 */
static bool
vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
{
/* Start from "no userfaultfd"; only the UFFDIO_API build can change that. */
dev->postcopy_ufd = -1;
#ifdef UFFDIO_API
struct uffdio_api api_struct;
dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
/* The reply carries no payload, only the descriptor. */
vmsg->size = 0;
#endif
if (dev->postcopy_ufd == -1) {
vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
goto out;
}
#ifdef UFFDIO_API
/* API handshake on the new userfaultfd. */
api_struct.api = UFFD_API;
api_struct.features = 0;
if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
close(dev->postcopy_ufd);
dev->postcopy_ufd = -1;
goto out;
}
/* TODO: Stash feature flags somewhere */
#endif
out:
/* Return a ufd to the QEMU */
vmsg->fd_num = 1;
vmsg->fds[0] = dev->postcopy_ufd;
return true; /* = send a reply */
}
/*
 * VHOST_USER_POSTCOPY_LISTEN: switch the device into postcopy-listening
 * mode.
 *
 * This is only accepted while no memory regions are registered; otherwise
 * the device is panicked and the reply carries -1.  On success the reply
 * carries 0.  Returns true: a reply is always sent.
 */
static bool
vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
{
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->payload.u64 = -1;

    if (!dev->nregions) {
        dev->postcopy_listening = true;
        vmsg->flags = VHOST_USER_VERSION | VHOST_USER_REPLY_MASK;
        vmsg->payload.u64 = 0; /* Success */
    } else {
        vu_panic(dev, "Regions already registered at postcopy-listen");
    }

    return true;
}
/*
 * Handle one received vhost-user message.
 *
 * The device's process_msg callback gets the first chance: when it reports
 * the message as handled, its do_reply value decides whether a reply is
 * sent.  Otherwise the request is dispatched to the built-in handler.
 *
 * Returns true when @vmsg has been turned into a reply that the caller
 * should send, false otherwise.  An unknown request closes the received
 * descriptors and panics the device.
 */
static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
    int do_reply = 0;
    bool reply = false;

    /* Print out generic part of the request. */
    DPRINT("================ Vhost user message ================\n");
    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
           vmsg->request);
    DPRINT("Flags: 0x%x\n", vmsg->flags);
    DPRINT("Size: %d\n", vmsg->size);

    if (vmsg->fd_num) {
        int n;

        DPRINT("Fds:");
        for (n = 0; n < vmsg->fd_num; n++) {
            DPRINT(" %d", vmsg->fds[n]);
        }
        DPRINT("\n");
    }

    /* Device-specific override comes first. */
    if (dev->iface->process_msg &&
        dev->iface->process_msg(dev, vmsg, &do_reply)) {
        return do_reply;
    }

    switch (vmsg->request) {
    case VHOST_USER_GET_FEATURES:
        reply = vu_get_features_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_FEATURES:
        reply = vu_set_features_exec(dev, vmsg);
        break;
    case VHOST_USER_GET_PROTOCOL_FEATURES:
        reply = vu_get_protocol_features_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_PROTOCOL_FEATURES:
        reply = vu_set_protocol_features_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_OWNER:
        reply = vu_set_owner_exec(dev, vmsg);
        break;
    case VHOST_USER_RESET_OWNER:
        reply = vu_reset_device_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_MEM_TABLE:
        reply = vu_set_mem_table_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_LOG_BASE:
        reply = vu_set_log_base_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_LOG_FD:
        reply = vu_set_log_fd_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_NUM:
        reply = vu_set_vring_num_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_ADDR:
        reply = vu_set_vring_addr_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_BASE:
        reply = vu_set_vring_base_exec(dev, vmsg);
        break;
    case VHOST_USER_GET_VRING_BASE:
        reply = vu_get_vring_base_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_KICK:
        reply = vu_set_vring_kick_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_CALL:
        reply = vu_set_vring_call_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_ERR:
        reply = vu_set_vring_err_exec(dev, vmsg);
        break;
    case VHOST_USER_GET_QUEUE_NUM:
        reply = vu_get_queue_num_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_VRING_ENABLE:
        reply = vu_set_vring_enable_exec(dev, vmsg);
        break;
    case VHOST_USER_SET_SLAVE_REQ_FD:
        reply = vu_set_slave_req_fd(dev, vmsg);
        break;
    case VHOST_USER_GET_CONFIG:
        reply = vu_get_config(dev, vmsg);
        break;
    case VHOST_USER_SET_CONFIG:
        reply = vu_set_config(dev, vmsg);
        break;
    case VHOST_USER_NONE:
        /* Nothing to do and nothing to answer. */
        break;
    case VHOST_USER_POSTCOPY_ADVISE:
        reply = vu_set_postcopy_advise(dev, vmsg);
        break;
    case VHOST_USER_POSTCOPY_LISTEN:
        reply = vu_set_postcopy_listen(dev, vmsg);
        break;
    default:
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Unhandled request: %d", vmsg->request);
        break;
    }

    return reply;
}
/*
 * Read one message from the device socket, process it and, when the handler
 * asks for it, send the reply.
 *
 * Returns false when reading the request or writing the reply failed, true
 * otherwise.  Any buffer attached to the message (vmsg.data) is freed before
 * returning.
 */
bool
vu_dispatch(VuDev *dev)
{
    VhostUserMsg vmsg = { 0, };
    bool ok = false;

    if (vu_message_read(dev, dev->sock, &vmsg)) {
        int wants_reply = vu_process_message(dev, &vmsg);

        if (!wants_reply) {
            ok = true;
        } else if (vu_message_write(dev, dev->sock, &vmsg)) {
            ok = true;
        }
    }

    free(vmsg.data);

    return ok;
}
/*
 * Release everything the device holds: the guest memory mappings, the
 * call/kick/err descriptors of every queue, the dirty-log resources, the
 * slave request descriptor and the device socket.
 */
void
vu_deinit(VuDev *dev)
{
    int idx;

    /* Guest memory mappings. */
    for (idx = 0; idx < dev->nregions; idx++) {
        VuDevRegion *region = &dev->regions[idx];
        void *mapping = (void *) (uintptr_t) region->mmap_addr;

        if (mapping == MAP_FAILED) {
            continue;
        }
        munmap(mapping, region->size + region->mmap_offset);
    }
    dev->nregions = 0;

    /* Per-queue descriptors, closed in call, kick, err order. */
    for (idx = 0; idx < VHOST_MAX_NR_VIRTQUEUE; idx++) {
        VuVirtq *queue = &dev->vq[idx];
        int *queue_fds[] = { &queue->call_fd, &queue->kick_fd,
                             &queue->err_fd };
        size_t n;

        for (n = 0; n < sizeof(queue_fds) / sizeof(queue_fds[0]); n++) {
            if (*queue_fds[n] != -1) {
                close(*queue_fds[n]);
                *queue_fds[n] = -1;
            }
        }
    }

    vu_close_log(dev);

    if (dev->slave_fd != -1) {
        close(dev->slave_fd);
        dev->slave_fd = -1;
    }

    if (dev->sock != -1) {
        close(dev->sock);
    }
}
/*
 * Initialise @dev for use on an already connected @socket.
 *
 * The whole structure is zeroed, the socket and the callbacks are stored,
 * and every descriptor field (log call fd, slave fd, per-queue call/kick/err
 * fds) is set to -1.  All queues start with notifications enabled.
 * The socket must be valid and all callbacks non-NULL (asserted).
 */
void
vu_init(VuDev *dev,
        int socket,
        vu_panic_cb panic,
        vu_set_watch_cb set_watch,
        vu_remove_watch_cb remove_watch,
        const VuDevIface *iface)
{
    const VuVirtq unused_queue = {
        .call_fd = -1, .kick_fd = -1, .err_fd = -1,
        .notification = true,
    };
    int qidx;

    assert(socket >= 0);
    assert(set_watch);
    assert(remove_watch);
    assert(iface);
    assert(panic);

    memset(dev, 0, sizeof(*dev));

    dev->sock = socket;
    dev->iface = iface;
    dev->panic = panic;
    dev->set_watch = set_watch;
    dev->remove_watch = remove_watch;
    dev->log_call_fd = -1;
    dev->slave_fd = -1;

    for (qidx = 0; qidx < VHOST_MAX_NR_VIRTQUEUE; qidx++) {
        dev->vq[qidx] = unused_queue;
    }
}
/*
 * Return the queue with index @qidx.
 * @qidx must be in the range [0, VHOST_MAX_NR_VIRTQUEUE) (asserted).
 */
VuVirtq *
vu_get_queue(VuDev *dev, int qidx)
{
    /* qidx is signed: a negative value must not slip past the check. */
    assert(qidx >= 0 && qidx < VHOST_MAX_NR_VIRTQUEUE);
    return &dev->vq[qidx];
}
/* Report the 'enable' flag of @vq. */
bool
vu_queue_enabled(VuDev *dev, VuVirtq *vq)
{
    bool is_enabled = vq->enable;

    return is_enabled;
}
/* Report the 'started' flag of @vq. */
bool
vu_queue_started(const VuDev *dev, const VuVirtq *vq)
{
    bool is_started = vq->started;

    return is_started;
}
/* Read the flags field of the queue's available ring. */
static inline uint16_t
vring_avail_flags(VuVirtq *vq)
{
    uint16_t flags = vq->vring.avail->flags;

    return flags;
}
/*
 * Read the idx field of the queue's available ring, cache it in
 * vq->shadow_avail_idx and return the cached value.
 */
static inline uint16_t
vring_avail_idx(VuVirtq *vq)
{
    return vq->shadow_avail_idx = vq->vring.avail->idx;
}
/* Read entry @i of the queue's available ring. */
static inline uint16_t
vring_avail_ring(VuVirtq *vq, int i)
{
    uint16_t entry = vq->vring.avail->ring[i];

    return entry;
}
/*
 * Read the used_event value, which is stored in the available ring slot
 * right after the last regular entry (index vring.num).
 */
static inline uint16_t
vring_get_used_event(VuVirtq *vq)
{
    int slot = vq->vring.num;

    return vring_avail_ring(vq, slot);
}
/*
 * Number of available ring entries between @idx and the current avail
 * index.
 *
 * Returns -1, after panicking the device, when that number exceeds the ring
 * size.  When entries are pending, a read barrier is issued so that the
 * later descriptor reads are ordered after the avail index read.
 */
static int
virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
{
    uint16_t pending = vring_avail_idx(vq) - idx;

    /* Check it isn't doing very strange things with descriptor numbers. */
    if (pending <= vq->vring.num) {
        if (pending != 0) {
            /* On success, callers read a descriptor at vq->last_avail_idx.
             * Make sure descriptor read does not bypass avail index read. */
            smp_rmb();
        }
        return pending;
    }

    vu_panic(dev, "Guest moved used index from %u to %u",
             idx, vq->shadow_avail_idx);
    return -1;
}
/*
 * Fetch the descriptor head stored at position @idx of the available ring.
 *
 * @idx is reduced modulo the ring size.  The head is stored in *@head;
 * returns false, after panicking the device, when it is not a valid
 * descriptor number for this ring.
 */
static bool
virtqueue_get_head(VuDev *dev, VuVirtq *vq,
                   unsigned int idx, unsigned int *head)
{
    /* Grab the next descriptor number they're advertising, and increment
     * the index we've seen. */
    *head = vring_avail_ring(vq, idx % vq->vring.num);

    /* If their number is silly, that's a fatal mistake. */
    if (*head >= vq->vring.num) {
        /* Report the offending value, not the address it is stored at. */
        vu_panic(dev, "Guest says index %u is available", *head);
        return false;
    }

    return true;
}
/*
 * Copy an indirect descriptor table of @len bytes from guest address @addr
 * into the buffer @desc.
 *
 * The table may span several memory regions, so it is copied piece by
 * piece.  Returns 0 on success, -1 when @len is 0, larger than
 * VIRTQUEUE_MAX_SIZE descriptors, or when part of the range is not mapped.
 */
static int
virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
                             uint64_t addr, size_t len)
{
    uint8_t *dst = (uint8_t *)desc;
    struct vring_desc *ori_desc;
    uint64_t read_len;

    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
        return -1;
    }

    if (len == 0) {
        return -1;
    }

    while (len) {
        read_len = len;
        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
        if (!ori_desc) {
            return -1;
        }

        memcpy(dst, ori_desc, read_len);
        len -= read_len;
        addr += read_len;
        /*
         * read_len counts bytes.  Advance a byte pointer: stepping the
         * descriptor pointer itself by read_len would move it by read_len
         * whole descriptors and overrun the destination buffer.
         */
        dst += read_len;
    }

    return 0;
}
/* Result codes of virtqueue_read_next_desc(). */
enum {
VIRTQUEUE_READ_DESC_ERROR = -1,
VIRTQUEUE_READ_DESC_DONE = 0, /* end of chain */
VIRTQUEUE_READ_DESC_MORE = 1, /* more buffers in chain */
};
/*
 * Follow the chain from descriptor @i of the table @desc.
 *
 * Returns VIRTQUEUE_READ_DESC_DONE when descriptor @i does not chain,
 * VIRTQUEUE_READ_DESC_MORE with the index of the next descriptor stored in
 * *@next, or VIRTQUEUE_READ_DESC_ERROR (after panicking the device) when the
 * next index is not below @max.
 */
static int
virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
                         int i, unsigned int max, unsigned int *next)
{
    /* If this descriptor says it doesn't chain, we're done. */
    if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
        return VIRTQUEUE_READ_DESC_DONE;
    }

    /* Check they're not leading us off end of descriptors. */
    *next = desc[i].next;
    /* Make sure compiler knows to grab that: we don't want it changing! */
    smp_wmb();

    if (*next >= max) {
        /* Report the offending index, not the address it is stored at. */
        vu_panic(dev, "Desc next is %u", *next);
        return VIRTQUEUE_READ_DESC_ERROR;
    }

    return VIRTQUEUE_READ_DESC_MORE;
}
/*
 * Count the bytes offered by the descriptor chains that are available in
 * @vq but not yet consumed.
 *
 * Chains are walked starting at vq->last_avail_idx, using a local copy of
 * the index.  The length of each descriptor flagged VRING_DESC_F_WRITE is
 * added to the "in" total, every other length to the "out" total.  The walk
 * stops early as soon as the in total has reached @max_in_bytes and the out
 * total has reached @max_out_bytes.
 *
 * @in_bytes / @out_bytes: where to store the totals; either may be NULL.
 *
 * On any error (broken chain, invalid indirect table, ...) both totals are
 * reported as 0.  Nothing is counted when the device is broken or the
 * available ring is not set up.
 */
void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
unsigned int *out_bytes,
unsigned max_in_bytes, unsigned max_out_bytes)
{
unsigned int idx;
unsigned int total_bufs, in_total, out_total;
int rc;
idx = vq->last_avail_idx;
total_bufs = in_total = out_total = 0;
if (unlikely(dev->broken) ||
unlikely(!vq->vring.avail)) {
goto done;
}
/* One iteration per available descriptor chain. */
while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
unsigned int max, desc_len, num_bufs, indirect = 0;
uint64_t desc_addr, read_len;
struct vring_desc *desc;
/* Bounce buffer for an indirect table that cannot be used in place */
struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
unsigned int i;
max = vq->vring.num;
num_bufs = total_bufs;
if (!virtqueue_get_head(dev, vq, idx++, &i)) {
goto err;
}
desc = vq->vring.desc;
if (desc[i].flags & VRING_DESC_F_INDIRECT) {
if (desc[i].len % sizeof(struct vring_desc)) {
vu_panic(dev, "Invalid size for indirect buffer table");
goto err;
}
/* If we've got too many, that implies a descriptor loop. */
if (num_bufs >= max) {
vu_panic(dev, "Looped descriptor");
goto err;
}
/* loop over the indirect descriptor table */
indirect = 1;
desc_addr = desc[i].addr;
desc_len = desc[i].len;
max = desc_len / sizeof(struct vring_desc);
read_len = desc_len;
desc = vu_gpa_to_va(dev, &read_len, desc_addr);
/* A shortened read_len means the translation could not cover the
 * whole table in one piece: copy it into desc_buf instead. */
if (unlikely(desc && read_len != desc_len)) {
/* Failed to use zero copy */
desc = NULL;
if (!virtqueue_read_indirect_desc(dev, desc_buf,
desc_addr,
desc_len)) {
desc = desc_buf;
}
}
if (!desc) {
vu_panic(dev, "Invalid indirect buffer table");
goto err;
}
/* From here on, walk the indirect table from its first entry. */
num_bufs = i = 0;
}
do {
/* If we've got too many, that implies a descriptor loop. */
if (++num_bufs > max) {
vu_panic(dev, "Looped descriptor");
goto err;
}
if (desc[i].flags & VRING_DESC_F_WRITE) {
in_total += desc[i].len;
} else {
out_total += desc[i].len;
}
/* Both limits reached: the caller does not need more. */
if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
goto done;
}
rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
} while (rc == VIRTQUEUE_READ_DESC_MORE);
if (rc == VIRTQUEUE_READ_DESC_ERROR) {
goto err;
}
/* An indirect chain counts as a single buffer of the main ring. */
if (!indirect) {
total_bufs = num_bufs;
} else {
total_bufs++;
}
}
/* virtqueue_num_heads() reported an error. */
if (rc < 0) {
goto err;
}
done:
if (in_bytes) {
*in_bytes = in_total;
}
if (out_bytes) {
*out_bytes = out_total;
}
return;
err:
in_total = out_total = 0;
goto done;
}
/*
 * Report whether the queue currently offers at least @in_bytes of "in" and
 * @out_bytes of "out" buffer space, as tallied by
 * vu_queue_get_avail_bytes().
 */
bool
vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
                     unsigned int out_bytes)
{
    unsigned int avail_in, avail_out;

    vu_queue_get_avail_bytes(dev, vq, &avail_in, &avail_out,
                             in_bytes, out_bytes);

    if (avail_in < in_bytes) {
        return false;
    }

    return avail_out >= out_bytes;
}
/*
 * Tell whether the guest has no buffers pending for us.
 *
 * avail_idx is fetched from VQ memory only when the shadow index cannot
 * already prove that buffers were added.  A broken device, or a queue whose
 * avail ring is not set up, is reported as empty.
 */
bool
vu_queue_empty(VuDev *dev, VuVirtq *vq)
{
    bool unusable = unlikely(dev->broken) || unlikely(!vq->vring.avail);

    if (unusable) {
        return true;
    }

    if (vq->shadow_avail_idx == vq->last_avail_idx) {
        return vring_avail_idx(vq) == vq->last_avail_idx;
    }

    return false;
}
/* Test whether bit @fbit (which must be below 64) is set in @features. */
static inline
bool has_feature(uint64_t features, unsigned int fbit)
{
    assert(fbit < 64);

    return ((features >> fbit) & 1ULL) != 0;
}
/* Test feature bit @fbit against the device's feature mask, dev->features. */
static inline
bool vu_has_feature(VuDev *dev,
                    unsigned int fbit)
{
    uint64_t mask = dev->features;

    return has_feature(mask, fbit);
}
/*
 * Decide whether the guest has to be notified about used buffers.
 *
 * Returns true when a notification should be sent.  When
 * VIRTIO_RING_F_EVENT_IDX is in use this also records the current used
 * index as the last signalled one.
 */
static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t prev_signalled, cur_used;
    bool was_valid;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when feature acknowledge) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    /* Without event indexes the guest's avail flags decide. */
    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    was_valid = vq->signalled_used_valid;
    vq->signalled_used_valid = true;

    prev_signalled = vq->signalled_used;
    vq->signalled_used = vq->used_idx;
    cur_used = vq->signalled_used;

    /* No trustworthy previous value: notify unconditionally. */
    if (!was_valid) {
        return true;
    }

    return vring_need_event(vring_get_used_event(vq), cur_used,
                            prev_signalled);
}
/*
 * Notify the guest about used buffers by writing 1 to the queue's call
 * eventfd, unless vring_notify() says the notification can be skipped.
 *
 * Does nothing on a broken device or a queue without an avail ring.  A
 * failed eventfd write is reported through vu_panic().
 */
void
vu_queue_notify(VuDev *dev, VuVirtq *vq)
{
    int rc;

    if (unlikely(dev->broken) || unlikely(!vq->vring.avail)) {
        return;
    }

    if (vring_notify(dev, vq)) {
        rc = eventfd_write(vq->call_fd, 1);
        if (rc < 0) {
            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
        }
        return;
    }

    DPRINT("skipped notify...\n");
}
/* OR @mask into the 16-bit flags field of the queue's used ring. */
static inline void
vring_used_flags_set_bit(VuVirtq *vq, int mask)
{
    char *used_base = (char *)vq->vring.used;
    uint16_t *flags_p =
        (uint16_t *)(used_base + offsetof(struct vring_used, flags));

    *flags_p = *flags_p | mask;
}
/* Clear the bits of @mask in the 16-bit flags field of the used ring. */
static inline void
vring_used_flags_unset_bit(VuVirtq *vq, int mask)
{
    char *used_base = (char *)vq->vring.used;
    uint16_t *flags_p =
        (uint16_t *)(used_base + offsetof(struct vring_used, flags));

    *flags_p = *flags_p & ~mask;
}
/*
 * Store @val in the 16-bit slot that directly follows the last used-ring
 * entry (ring[vring.num]).  Nothing is written while notifications are
 * disabled for the queue.
 */
static inline void
vring_set_avail_event(VuVirtq *vq, uint16_t val)
{
    uint16_t *avail_event;

    if (vq->notification) {
        avail_event = (uint16_t *) &vq->vring.used->ring[vq->vring.num];
        *avail_event = val;
    }
}
/*
 * Enable (@enable != 0) or disable guest-to-device notifications for @vq.
 *
 * With VIRTIO_RING_F_EVENT_IDX the avail event is set to the current avail
 * index; otherwise VRING_USED_F_NO_NOTIFY in the used flags is cleared or
 * set.  When enabling, a full barrier is issued afterwards.
 */
void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    } else if (!enable) {
        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
    } else {
        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
    }

    if (!enable) {
        return;
    }

    /* Expose avail event/used flags before caller checks the avail idx. */
    smp_mb();
}
/*
 * Translate the guest buffer [@pa, @pa + @sz) into one or more iovec
 * entries appended to @iov.
 *
 * @p_num_sg:   in/out count of entries already used in @iov.  It is only
 *              updated when the whole buffer was mapped.
 * @max_num_sg: capacity of @iov.
 * @is_write:   not used by this function.
 *
 * A buffer may need several entries because vu_gpa_to_va() can return a
 * mapping shorter than requested.  Failures (zero-sized buffer, @iov full,
 * untranslatable address) are reported through vu_panic() only.
 */
static void
virtqueue_map_desc(VuDev *dev,
                   unsigned int *p_num_sg, struct iovec *iov,
                   unsigned int max_num_sg, bool is_write,
                   uint64_t pa, size_t sz)
{
    unsigned count = *p_num_sg;
    size_t remaining = sz;
    uint64_t gpa = pa;

    assert(count <= max_num_sg);

    if (remaining == 0) {
        vu_panic(dev, "virtio: zero sized buffers are not allowed");
        return;
    }

    /* remaining is non-zero here, so the first pass always runs. */
    do {
        uint64_t chunk = remaining;
        void *hva;

        if (count == max_num_sg) {
            vu_panic(dev, "virtio: too many descriptors in indirect table");
            return;
        }

        hva = vu_gpa_to_va(dev, &chunk, gpa);
        iov[count].iov_base = hva;
        if (hva == NULL) {
            vu_panic(dev, "virtio: invalid address for buffers");
            return;
        }

        iov[count].iov_len = chunk;
        count++;
        remaining -= chunk;
        gpa += chunk;
    } while (remaining);

    *p_num_sg = count;
}
/* Round number down to multiple.
 * Relies on truncating integer division; both arguments are evaluated more
 * than once, so they must be free of side effects. */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
/* Round number up to multiple.
 * Built on ALIGN_DOWN, so the same side-effect caveat applies. */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
/*
 * Allocate a virtqueue element together with its scatter-gather arrays in
 * one malloc() block laid out as:
 *
 *   [ caller's element struct, @sz bytes, padded to iovec alignment ]
 *   [ in_sg[@in_num]  ]
 *   [ out_sg[@out_num] ]
 *
 * @sz must be at least sizeof(VuVirtqElement) (asserted).  Only out_num,
 * in_num, in_sg and out_sg are initialised; everything else in the block is
 * left as returned by malloc().
 *
 * Returns the new element, or NULL when the allocation fails.
 */
static void *
virtqueue_alloc_element(size_t sz,
                        unsigned out_num, unsigned in_num)
{
    VuVirtqElement *elem;
    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);

    assert(sz >= sizeof(VuVirtqElement));

    elem = malloc(out_sg_end);
    if (!elem) {
        /* Out of memory: let the caller cope instead of crashing here. */
        return NULL;
    }

    elem->out_num = out_num;
    elem->in_num = in_num;
    /* Offset through char * because arithmetic on void * is a GNU extension. */
    elem->in_sg = (void *)((char *)elem + in_sg_ofs);
    elem->out_sg = (void *)((char *)elem + out_sg_ofs);

    return elem;
}
/*
 * Pop the next available descriptor chain from @vq and return it as a
 * freshly allocated element of @sz bytes (see virtqueue_alloc_element()),
 * with the chain's buffers mapped into elem->out_sg / elem->in_sg.
 *
 * Descriptors without VRING_DESC_F_WRITE go to out_sg and must all come
 * before the VRING_DESC_F_WRITE ones, which go to in_sg.
 *
 * Returns NULL when the device is broken, the queue is not set up or is
 * empty, the chain is malformed (reported through vu_panic()), or the
 * element cannot be allocated.  On success vq->inuse is incremented.
 *
 * Note that vq->last_avail_idx has already been advanced when a later step
 * fails.
 */
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    unsigned int i, head, max, desc_len;
    uint64_t desc_addr, read_len;
    VuVirtqElement *elem;
    unsigned out_num, in_num;
    struct iovec iov[VIRTQUEUE_MAX_SIZE];
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    struct vring_desc *desc;
    int rc;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /* Needed after virtio_queue_empty(), see comment in
     * virtqueue_num_heads(). */
    smp_rmb();

    /* When we start there are none of either input nor output. */
    out_num = in_num = 0;

    max = vq->vring.num;
    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    i = head;
    desc = vq->vring.desc;
    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
        if (desc[i].len % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            /* Do not go on to walk a table of bogus size. */
            return NULL;
        }

        /* loop over the indirect descriptor table */
        desc_addr = desc[i].addr;
        desc_len = desc[i].len;
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            /* virtqueue_read_indirect_desc() returns 0 on success. */
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return NULL;
        }
        i = 0;
    }

    /* Collect all the descriptors */
    do {
        if (desc[i].flags & VRING_DESC_F_WRITE) {
            /* Writable buffers are stored after the readable ones in iov. */
            virtqueue_map_desc(dev, &in_num, iov + out_num,
                               VIRTQUEUE_MAX_SIZE - out_num, true,
                               desc[i].addr, desc[i].len);
        } else {
            if (in_num) {
                vu_panic(dev, "Incorrect order for descriptors");
                return NULL;
            }
            virtqueue_map_desc(dev, &out_num, iov,
                               VIRTQUEUE_MAX_SIZE, false,
                               desc[i].addr, desc[i].len);
        }

        /*
         * virtqueue_map_desc() reports failure only through vu_panic().
         * dev->broken was clear on entry, so seeing it set means mapping
         * failed and the chain must be abandoned.
         * NOTE(review): relies on vu_panic() setting dev->broken - confirm.
         */
        if (dev->broken) {
            return NULL;
        }

        /* If we've got too many, that implies a descriptor loop. */
        if ((in_num + out_num) > max) {
            vu_panic(dev, "Looped descriptor");
            /* Bail out, otherwise a looped chain is followed again. */
            return NULL;
        }
        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
    } while (rc == VIRTQUEUE_READ_DESC_MORE);

    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
        return NULL;
    }

    /* Now copy what we have collected and mapped */
    elem = virtqueue_alloc_element(sz, out_num, in_num);
    if (!elem) {
        return NULL;
    }
    elem->index = head;
    for (i = 0; i < out_num; i++) {
        elem->out_sg[i] = iov[i];
    }
    for (i = 0; i < in_num; i++) {
        elem->in_sg[i] = iov[out_num + i];
    }

    vq->inuse++;

    return elem;
}
/*
 * Give @num popped-but-unfinished elements back to the queue by moving
 * last_avail_idx and inuse back by @num.
 *
 * Returns false, changing nothing, when @num exceeds vq->inuse.
 */
bool
vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
{
    bool can_rewind = num <= vq->inuse;

    if (can_rewind) {
        vq->last_avail_idx -= num;
        vq->inuse -= num;
    }

    return can_rewind;
}
/*
 * Copy *@uelem into used-ring slot @i and pass the touched range, based at
 * vring.log_guest_addr, to vu_log_write().
 */
static inline
void vring_used_write(VuDev *dev, VuVirtq *vq,
                      struct vring_used_elem *uelem, int i)
{
    struct vring_used_elem *slot = &vq->vring.used->ring[i];

    *slot = *uelem;

    vu_log_write(dev,
                 vq->vring.log_guest_addr +
                 offsetof(struct vring_used, ring[i]),
                 sizeof(*slot));
}
/*
 * Pass to vu_log_write() the guest buffers of @elem's descriptor chain that
 * the device wrote to.
 *
 * Walks the chain starting at elem->index (following an indirect table if
 * the head is flagged VRING_DESC_F_INDIRECT) and, for each descriptor
 * flagged VRING_DESC_F_WRITE, logs min(descriptor length, remaining @len)
 * bytes until @len bytes are accounted for or the chain ends.
 *
 * Malformed chains are reported through vu_panic() and end the walk.
 */
static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
        if (desc[i].len % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            /* Do not go on to walk a table of bogus size. */
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = desc[i].addr;
        desc_len = desc[i].len;
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            /* virtqueue_read_indirect_desc() returns 0 on success. */
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        i = 0;
    }

    do {
        /* If we've got too many, that implies a descriptor loop. */
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        if (desc[i].flags & VRING_DESC_F_WRITE) {
            min = MIN(desc[i].len, len);
            vu_log_write(dev, desc[i].addr, min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}
/*
 * Record @elem as used with @len bytes written, in the used-ring slot
 * @idx entries past the current used index (wrapped to the ring size).
 * The used index itself is not changed here.
 *
 * Does nothing on a broken device or a queue without an avail ring.
 */
void
vu_queue_fill(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem,
              unsigned int len, unsigned int idx)
{
    struct vring_used_elem uelem;
    unsigned int slot;

    if (unlikely(dev->broken) || unlikely(!vq->vring.avail)) {
        return;
    }

    vu_log_queue_fill(dev, vq, elem, len);

    slot = (idx + vq->used_idx) % vq->vring.num;

    uelem.id = elem->index;
    uelem.len = len;
    vring_used_write(dev, vq, &uelem, slot);
}
/*
 * Write @val to the used ring's idx field, pass that field's range (based
 * at vring.log_guest_addr) to vu_log_write(), and update the cached
 * vq->used_idx.
 */
static inline
void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
{
    struct vring_used *used = vq->vring.used;

    used->idx = val;
    vu_log_write(dev,
                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
                 sizeof(used->idx));

    vq->used_idx = val;
}
/*
 * Advance the used index by @count and reduce vq->inuse by the same amount.
 *
 * Does nothing on a broken device or a queue without an avail ring.
 */
void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t prev_idx, next_idx;

    if (unlikely(dev->broken) || unlikely(!vq->vring.avail)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    prev_idx = vq->used_idx;
    next_idx = prev_idx + count;
    vring_used_idx_set(dev, vq, next_idx);
    vq->inuse -= count;

    /* Invalidate signalled_used when this comparison on the indexes holds. */
    if (unlikely((int16_t)(next_idx - vq->signalled_used) <
                 (uint16_t)(next_idx - prev_idx))) {
        vq->signalled_used_valid = false;
    }
}
/*
 * Return a single element to the guest: fill it into the used ring at
 * offset 0 with @len bytes written, then flush one entry.
 */
void
vu_queue_push(VuDev *dev, VuVirtq *vq, const VuVirtqElement *elem,
              unsigned int len)
{
    const unsigned int first_slot = 0;
    const unsigned int one_elem = 1;

    vu_queue_fill(dev, vq, elem, len, first_slot);
    vu_queue_flush(dev, vq, one_elem);
}