Skip to content

Instantly share code, notes, and snippets.

@apconole
Last active August 7, 2025 12:58
Show Gist options
  • Select an option

  • Save apconole/70ebb980de70ccd06e6f3059de66e4b9 to your computer and use it in GitHub Desktop.

Select an option

Save apconole/70ebb980de70ccd06e6f3059de66e4b9 to your computer and use it in GitHub Desktop.
Start of a pythonic vhost user library
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (C) 2025, Red Hat, Inc.
"""VHOST User Implementation
"""
import array
import collections
import mmap
import os
import select
import socket
import stat
import struct
# Vhost Memory Region
class VhostMemoryRegion:
    """One memory region taken from a SET_MEM_TABLE payload.

    Holds the four 64-bit fields of the region plus the file descriptor
    that backs it.  ``mapped`` stays ``None`` until :meth:`mmap` is called.
    """

    def __init__(self, guest_phys_addr, memory_size, userspace_addr, mmap_off,
                 mmap_fd):
        self.guest_phys_addr = guest_phys_addr
        self.memory_size = memory_size
        self.userspace_addr = userspace_addr
        self.mmap_offset = mmap_off
        self.mmap_fd = mmap_fd
        self.mapped = None

    def mmap(self):
        """Map the region read/write and shared, and return the mapping.

        The mapping covers ``memory_size`` bytes of ``mmap_fd`` starting at
        ``mmap_offset``; it is also stored in ``self.mapped``.

        Raises:
            RuntimeError: if the region has no file descriptor.
        """
        # Inside a method the bare name ``mmap`` is the module, not this
        # method, because class scope is not searched from function bodies.
        if self.mmap_fd is None:
            raise RuntimeError("No FD for mmap region.")

        protection = mmap.PROT_READ | mmap.PROT_WRITE
        self.mapped = mmap.mmap(
            self.mmap_fd,
            self.memory_size,
            flags=mmap.MAP_SHARED,
            prot=protection,
            offset=self.mmap_offset,
        )
        return self.mapped

    @staticmethod
    def parse_regions(data, fds):
        """Parse the payload of SET_MEM_TABLE and associated FDs.

        ``data`` is a sequence of 32-byte records, each four little-endian
        u64 values (guest physical address, size, userspace address, mmap
        offset).  The i-th record is paired with ``fds[i]``.

        Raises:
            ValueError: if ``data`` is not a whole number of records, or
                there are fewer descriptors than records.
        """
        record = struct.Struct("<QQQQ")

        if len(data) % record.size:
            raise ValueError("Invalid SET_MEM_TABLE payload length")
        if len(fds) < len(data) // record.size:
            raise ValueError("Not enough FDs passed for memory regions")

        starts = range(0, len(data), record.size)
        return [
            VhostMemoryRegion(*record.unpack_from(data, start), mmap_fd=fd)
            for start, fd in zip(starts, fds)
        ]
# Vhost Message
class VhostMsg:
    """A message as framed by this library: 16-byte header plus payload.

    The header is four little-endian u32 values: request type, flags,
    payload size and a zero pad word.  The ``VHOST_USER_*`` class constants
    are the known request types.
    """

    VHOST_MSG_HDRLEN = 16  # VhostMsg header

    # Pulled from lib/vhost/vhost_user.h
    VHOST_USER_GET_FEATURES = 1
    VHOST_USER_SET_FEATURES = 2
    VHOST_USER_SET_OWNER = 3
    VHOST_USER_RESET_OWNER = 4
    VHOST_USER_SET_MEM_TABLE = 5
    VHOST_USER_SET_LOG_BASE = 6
    VHOST_USER_SET_LOG_FD = 7
    VHOST_USER_SET_VRING_NUM = 8
    VHOST_USER_SET_VRING_ADDR = 9
    VHOST_USER_SET_VRING_BASE = 10
    VHOST_USER_GET_VRING_BASE = 11
    VHOST_USER_SET_VRING_KICK = 12
    VHOST_USER_SET_VRING_CALL = 13
    VHOST_USER_SET_VRING_ERR = 14
    VHOST_USER_GET_PROTOCOL_FEATURES = 15
    VHOST_USER_SET_PROTOCOL_FEATURES = 16
    VHOST_USER_GET_QUEUE_NUM = 17
    VHOST_USER_SET_VRING_ENABLE = 18
    VHOST_USER_SEND_RARP = 19
    VHOST_USER_NET_SET_MTU = 20
    VHOST_USER_SET_BACKEND_REQ_FD = 21
    VHOST_USER_IOTLB_MSG = 22
    VHOST_USER_GET_CONFIG = 24
    VHOST_USER_SET_CONFIG = 25
    VHOST_USER_CRYPTO_CREATE_SESS = 26
    VHOST_USER_CRYPTO_CLOSE_SESS = 27
    VHOST_USER_POSTCOPY_ADVISE = 28
    VHOST_USER_POSTCOPY_LISTEN = 29
    VHOST_USER_POSTCOPY_END = 30
    VHOST_USER_GET_INFLIGHT_FD = 31
    VHOST_USER_SET_INFLIGHT_FD = 32
    VHOST_USER_SET_STATUS = 39
    VHOST_USER_GET_STATUS = 40

    def __init__(self, type, flags, size, data=b''):
        """Store the header fields and the payload bytes as given."""
        self.type = type
        self.flags = flags
        self.size = size
        self.data = data

    def encode(self):
        """Return the header followed by ``self.data`` as bytes.

        The size field is taken from ``self.size``, not ``len(self.data)``.
        """
        return (struct.pack("<IIII", self.type, self.flags, self.size, 0) +
                self.data)

    @staticmethod
    def decode(data):
        """Build a VhostMsg from bytes that start with a 16-byte header.

        At most ``size`` bytes following the header become the payload;
        fewer are kept when ``data`` is shorter than the header claims.

        Raises:
            struct.error: if ``data`` holds fewer than 16 bytes.
        """
        hdrlen = VhostMsg.VHOST_MSG_HDRLEN
        type, flags, size, _ = struct.unpack("<IIII", data[:hdrlen])
        return VhostMsg(type, flags, size, data[hdrlen:hdrlen + size])

    @staticmethod
    def type_name(type):
        """Return the ``VHOST_USER_*`` constant name for a request number.

        Only request constants are searched, so helper constants that share
        a value with a request (``VHOST_MSG_HDRLEN`` is 16, as is
        ``VHOST_USER_SET_PROTOCOL_FEATURES``) can never be returned.

        Raises:
            ValueError: if no request constant has that value.
        """
        for name, value in vars(VhostMsg).items():
            if name.startswith("VHOST_USER_") and value == type:
                return name
        raise ValueError(f"No type for {type} found.")

    @staticmethod
    def type_val(name):
        """Return the integer value of the class constant called ``name``.

        Raises:
            ValueError: if ``name`` is not an integer attribute of VhostMsg
                (unknown names, methods and dunders are all rejected).
        """
        value = getattr(VhostMsg, name, None)
        if not isinstance(value, int):
            # "from None": the failed lookup adds nothing to the message.
            raise ValueError(f"No such type {name} found.") from None
        return value
class VhostRing:
    """Bookkeeping and I/O helpers for one split virtqueue.

    All ``*_addr`` values are used directly as byte offsets into the
    ``memory_region`` object handed to each method.  That object only needs
    to support slice reads and same-length slice writes (an ``mmap`` or a
    ``bytearray`` both work).

    Ring layout follows the split-virtqueue structures of the VIRTIO
    specification:

    * descriptor: u64 addr, u32 len, u16 flags, u16 next (16 bytes)
    * avail ring: u16 flags, u16 idx, then u16 descriptor heads
    * used ring:  u16 flags, u16 idx, then (u32 id, u32 len) elements

    Ring indices are free-running 16-bit counters, so they are masked to
    16 bits and compared for inequality rather than ordered.
    """

    _DESC = struct.Struct("<QIHH")     # one descriptor table entry
    _USED_ELEM = struct.Struct("<II")  # one used ring element
    _U16 = struct.Struct("<H")
    _RING_HDR = 4                      # flags + idx in front of each ring
    _IDX_MASK = 0xFFFF

    def __init__(self, index):
        self.index = index
        self.num = 0
        # Empty until set_num() tells us how many descriptors exist.
        self.free_descs = collections.deque()
        self.desc_addr = 0
        self.avail_addr = 0
        self.used_addr = 0
        self.flags = 0
        self.log_guest_addr = 0
        self.kick_fd = None
        self.call_fd = None
        self.avail_idx = 0
        self.last_used_idx = 0
        self.next_data_offset = 0x4000

    def set_num(self, num):
        """Set the ring size and mark every descriptor as free."""
        self.num = num
        self.free_descs = collections.deque(range(self.num))

    def set_addr(self, desc, avail, used, flags, log_guest_addr=0):
        """Record where the descriptor table and both rings live."""
        self.desc_addr = desc
        self.avail_addr = avail
        self.used_addr = used
        self.flags = flags
        self.log_guest_addr = log_guest_addr

    def set_kick_fd(self, fd):
        """Remember the descriptor written to by write_packets()."""
        self.kick_fd = fd

    def set_call_fd(self, fd):
        """Remember the descriptor written to by notify()."""
        self.call_fd = fd

    @staticmethod
    def _load(memory_region, layout, offset):
        """Unpack one ``layout`` record found at ``offset``."""
        return layout.unpack(memory_region[offset:offset + layout.size])

    @staticmethod
    def _store(memory_region, layout, offset, *values):
        """Pack ``values`` with ``layout`` and write them at ``offset``."""
        memory_region[offset:offset + layout.size] = layout.pack(*values)

    def _read_used_idx(self, memory_region):
        """Return the 16-bit ``idx`` field of the used ring."""
        return self._load(memory_region, self._U16, self.used_addr + 2)[0]

    def _read_avail_idx(self, memory_region):
        """Return the 16-bit ``idx`` field of the avail ring."""
        return self._load(memory_region, self._U16, self.avail_addr + 2)[0]

    def reclaim_used_descriptors(self, memory_region):
        """Scan the used ring and reclaim descriptors.

        Every used element between ``last_used_idx`` and the ring's current
        ``idx`` has its descriptor id appended to ``free_descs``.  Does
        nothing while the ring size is still zero.
        """
        if not self.num:
            return
        self.last_used_idx &= self._IDX_MASK
        used_idx = self._read_used_idx(memory_region)
        while self.last_used_idx != used_idx:
            slot = self.last_used_idx % self.num
            elem_off = (self.used_addr + self._RING_HDR +
                        slot * self._USED_ELEM.size)
            desc_id, _ = self._load(memory_region, self._USED_ELEM, elem_off)
            self.free_descs.append(desc_id)
            self.last_used_idx = (self.last_used_idx + 1) & self._IDX_MASK

    def write_packets(self, memory_region, packets, guest_phys_base=0x100000):
        """Submit packets to the vring, safely reusing descriptors.

        For each packet a free descriptor is filled in, the payload is
        copied to ``next_data_offset`` and the descriptor is published in
        the avail ring.  Submission stops at the first packet for which no
        descriptor is free.  ``avail.idx`` is then written and, if anything
        was queued and a kick fd is set, an 8-byte ``1`` is written to it.
        """
        self.reclaim_used_descriptors(memory_region)
        submitted = 0

        for pkt in packets:
            if not self.free_descs:
                print("No free descriptors available — skipping packet")
                break
            desc_index = self.free_descs.popleft()
            pkt_len = len(pkt)

            # NOTE(review): the data offset only ever grows, one 0x1000
            # slot per packet; nothing here keeps it inside memory_region.
            data_off = self.next_data_offset

            # Write descriptor
            self._store(memory_region, self._DESC,
                        self.desc_addr + desc_index * self._DESC.size,
                        guest_phys_base + data_off, pkt_len, 0, 0)

            # Write packet data
            memory_region[data_off:data_off + pkt_len] = pkt
            self.next_data_offset += 0x1000

            # Write to avail ring
            slot = (self.avail_idx & self._IDX_MASK) % self.num
            self._store(memory_region, self._U16,
                        self.avail_addr + self._RING_HDR + slot * 2,
                        desc_index)
            self.avail_idx = (self.avail_idx + 1) & self._IDX_MASK
            submitted += 1

        # Write avail.idx
        self._store(memory_region, self._U16, self.avail_addr + 2,
                    self.avail_idx & self._IDX_MASK)

        # "is not None" so that descriptor 0 still counts as a valid fd.
        if submitted and self.kick_fd is not None:
            os.write(self.kick_fd, struct.pack("Q", 1))
        print(f"Vring {self.index}: Submitted {submitted} packets")

    def read_packets(self, memory_region, guest_phys_base=0x100000):
        """Consume every pending avail entry and return the payloads.

        Each consumed descriptor is recorded in the used ring with its
        length, ``used.idx`` is updated, and notify() is called when a call
        fd is set.  Returns a list of ``bytes``; an empty list if nothing
        is pending or the ring size is still zero.
        """
        packets = []
        if not self.num:
            return packets

        self.last_used_idx &= self._IDX_MASK
        # Read avail.idx from guest
        avail_idx = self._read_avail_idx(memory_region)

        while self.last_used_idx != avail_idx:
            slot = self.last_used_idx % self.num
            desc_index = self._load(
                memory_region, self._U16,
                self.avail_addr + self._RING_HDR + slot * 2)[0]

            # Read descriptor
            addr, length, _flags, _next = self._load(
                memory_region, self._DESC,
                self.desc_addr + desc_index * self._DESC.size)
            pkt_offset = addr - guest_phys_base
            pkt = bytes(memory_region[pkt_offset:pkt_offset + length])
            print(f"Vring {self.index}: Received packet: {pkt.hex()}")

            # Add the packet to the list.
            packets.append(pkt)

            # Write to used ring
            self._store(memory_region, self._USED_ELEM,
                        self.used_addr + self._RING_HDR +
                        slot * self._USED_ELEM.size,
                        desc_index, length)
            self.last_used_idx = (self.last_used_idx + 1) & self._IDX_MASK

        # Write updated used.idx
        self._store(memory_region, self._U16, self.used_addr + 2,
                    self.last_used_idx)

        # Notify client via call fd
        if self.call_fd is not None:
            self.notify()
        return packets

    def wait(self):
        """Return True if the kick fd becomes readable within 0.1 s.

        Raises:
            RuntimeError: if no kick fd has been set.
        """
        if self.kick_fd is None:
            raise RuntimeError("Need a kick fd")
        readable, _, _ = select.select([self.kick_fd], [], [], 0.1)
        return bool(readable)

    def notify(self):
        """Write an 8-byte ``1`` to the call fd.

        Raises:
            RuntimeError: if no call fd has been set.
        """
        if self.call_fd is None:
            raise RuntimeError("Need a call fd")
        os.write(self.call_fd, struct.pack("Q", 1))
class VhostSocket:
    """AF_UNIX stream socket wrapper used by both client and server.

    A server binds ``path`` in listen() and talks over the connection
    returned by accept(); a client talks over the socket it connect()s.
    active_sock() picks the right one.
    """

    _MAX_FDS = 8  # descriptors accepted per message

    def __init__(self, path="", mode=stat.S_IRUSR | stat.S_IWUSR, server=False):
        self.path = path
        # NOTE(review): mode is stored but never applied to the socket file
        # anywhere in this class - confirm whether a chmod was intended.
        self.mode = mode
        self.is_server = server
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.accepted = None

    def listen(self):
        """Bind to ``path`` (replacing a stale socket file) and listen.

        Any previously accepted connection is closed first.

        Raises:
            ValueError: if the path is empty or this is not a server.
            OSError: if an existing file at ``path`` cannot be removed.
        """
        if self.path == "" or not self.is_server:
            raise ValueError("Cannot listen on a blank socket.")
        if self.accepted is not None:
            self.accepted.close()
            self.accepted = None
        try:
            os.unlink(self.path)
        except OSError:
            # Only a failure to remove something that is really there is
            # an error; a missing file is the normal first-run case.
            if os.path.exists(self.path):
                raise
        self.socket.bind(self.path)
        self.socket.listen(1)

    def accept(self):
        """Block until a peer connects and keep that connection."""
        self.accepted, _ = self.socket.accept()

    def connect(self):
        """Connect the socket to ``path``."""
        self.socket.connect(self.path)

    def active_sock(self):
        """Return the accepted connection if any, else the base socket."""
        return self.accepted if self.accepted is not None else self.socket

    def close(self):
        """Close the accepted connection (if any) and the base socket."""
        if self.accepted is not None:
            self.accepted.close()
            self.accepted = None
        self.socket.close()

    @staticmethod
    def _recv_exact(sock, count):
        """Read up to ``count`` bytes, stopping early only at end of stream."""
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def recv_msg(self):
        """Receive one message and any descriptors passed with it.

        Returns:
            ``(msg, fds)`` where ``msg`` is a VhostMsg whose ``data`` holds
            the payload bytes that arrived and ``fds`` is a list of ints,
            or ``(None, [])`` when the peer has closed the connection.

        Raises:
            struct.error: if the stream ends in the middle of a header.
        """
        sock = self.active_sock()
        fds = array.array("i")  # to store received FDs
        hdrlen = VhostMsg.VHOST_MSG_HDRLEN

        # Ancillary data buffer (enough for _MAX_FDS fds)
        cmsg_buffer_size = socket.CMSG_LEN(self._MAX_FDS * fds.itemsize)
        header, ancdata, _, _ = sock.recvmsg(hdrlen, cmsg_buffer_size)
        if not header:
            return None, []

        # Parse ancillary FDs
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level == socket.SOL_SOCKET and \
               cmsg_type == socket.SCM_RIGHTS:
                # Truncated control data can end in a partial int, which
                # frombytes() rejects; keep whole descriptors only.
                whole = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
                fds.frombytes(cmsg_data[:whole])

        # A stream socket may hand back fewer bytes than asked for.
        if len(header) < hdrlen:
            header += self._recv_exact(sock, hdrlen - len(header))
        msg_obj = VhostMsg.decode(header)

        # Now receive message body (if needed)
        msg_obj.data = self._recv_exact(sock, msg_obj.size)
        return msg_obj, list(fds)
class VhostClient:
    """Front-end side: connects to ``path`` and sends setup requests.

    Every request is a 16-byte header (type, flags=0, payload length, 0)
    followed by the payload; requests that carry a file descriptor attach
    it as SCM_RIGHTS ancillary data.
    """

    def __init__(self, path):
        self.socket = VhostSocket(path=path, server=False)
        self.socket.connect()
        self.memfd = None
        self.mem_size = 0x10000
        self.mmap_region = None

    @staticmethod
    def _frame(request, payload):
        """Return the header bytes that announce ``payload``."""
        return struct.pack("<IIII", request, 0, len(payload), 0)

    def _send_plain(self, request, payload):
        """Send header and payload in one sendall() call."""
        self.socket.socket.sendall(self._frame(request, payload) + payload)

    def _send_with_fd(self, request, payload, fd):
        """Send header and payload with ``fd`` attached via SCM_RIGHTS."""
        rights = array.array("i", [fd]).tobytes()
        self.socket.socket.sendmsg(
            [self._frame(request, payload), payload],
            [(socket.SOL_SOCKET, socket.SCM_RIGHTS, rights)],
        )

    def create_memory(self, memfd_name="vhost-user-shmem"):
        """Create a memfd of ``mem_size`` bytes and map it shared, read/write."""
        self.memfd = os.memfd_create(memfd_name, os.MFD_CLOEXEC)
        os.ftruncate(self.memfd, self.mem_size)
        protection = mmap.PROT_READ | mmap.PROT_WRITE
        self.mmap_region = mmap.mmap(self.memfd, self.mem_size,
                                     mmap.MAP_SHARED, protection)

    def send_set_mem_table(self):
        """Send SET_MEM_TABLE describing one region backed by ``memfd``.

        The region uses guest physical address 0x100000, ``mem_size``
        bytes, userspace address 0 and mmap offset 0.
        """
        payload = struct.pack("<QQQQ", 0x100000, self.mem_size, 0x0, 0)
        self._send_with_fd(VhostMsg.VHOST_USER_SET_MEM_TABLE, payload,
                           self.memfd)

    def send_set_vring_num(self, index, num):
        """Send SET_VRING_NUM with the ring index and its size."""
        self._send_plain(VhostMsg.VHOST_USER_SET_VRING_NUM,
                         struct.pack("<II", index, num))

    def send_set_vring_addr(self, index, desc_addr, avail_addr, used_addr):
        """Send SET_VRING_ADDR; flags and log address are sent as zero."""
        payload = struct.pack("<IIQQQQ", index, 0, desc_addr, avail_addr,
                              used_addr, 0)
        self._send_plain(VhostMsg.VHOST_USER_SET_VRING_ADDR, payload)

    def send_set_vring_kick(self, index):
        """Create a non-blocking eventfd, send it as the kick fd, return it."""
        kick_fd = os.eventfd(0, os.EFD_NONBLOCK)
        self._send_with_fd(VhostMsg.VHOST_USER_SET_VRING_KICK,
                           struct.pack("<I", index), kick_fd)
        return kick_fd

    def send_set_vring_call(self, index):
        """Create a non-blocking eventfd, send it as the call fd, return it."""
        call_fd = os.eventfd(0, os.EFD_NONBLOCK)
        self._send_with_fd(VhostMsg.VHOST_USER_SET_VRING_CALL,
                           struct.pack("<I", index), call_fd)
        return call_fd
class VhostServer:
    """Back-end side: accepts one client and applies its setup requests.

    ``vrings`` maps a ring index to its VhostRing; ``memory`` collects the
    VhostMemoryRegion objects received through SET_MEM_TABLE.
    """

    def __init__(self, path):
        self.socket = VhostSocket(path=path, server=True)
        self.vrings = {}
        self.memory = []

    def start(self):
        """Listen on the socket path and block until a client connects."""
        self.socket.listen()
        self.socket.accept()

    def _vring(self, index):
        """Return the ring for ``index``, creating it on first use."""
        ring = self.vrings.get(index)
        if ring is None:
            ring = self.vrings[index] = VhostRing(index)
        return ring

    def handle_message(self):
        """Receive one request and apply it.

        Returns:
            False when the peer has closed the connection, True otherwise
            (including for requests that are not handled).
        """
        msg, fds = self.socket.recv_msg()
        if msg is None:
            return False

        # An unknown request number must still reach the "unhandled"
        # branch below instead of aborting on the name lookup.
        try:
            name = VhostMsg.type_name(msg.type)
        except ValueError:
            name = f"unknown request {msg.type}"
        print(f"Server: Received {name}")

        if msg.type == VhostMsg.VHOST_USER_GET_FEATURES:
            print("Server: features requested.")
            # write up the GET FEATURES
        elif msg.type == VhostMsg.VHOST_USER_SET_MEM_TABLE:
            regions = VhostMemoryRegion.parse_regions(msg.data, fds)
            for region in regions:
                region.mmap()
            self.memory.extend(regions)
            print(f"Server: Mapped {len(regions)} memory region(s)")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_NUM:
            index, num = struct.unpack("<II", msg.data[:8])
            self._vring(index).set_num(num)
            print(f"Server: Set VRING {index} num={num}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_ADDR:
            index, flags = struct.unpack("<II", msg.data[0:8])
            desc, avail, used, log_guest = struct.unpack("<QQQQ",
                                                         msg.data[8:40])
            self._vring(index).set_addr(desc, avail, used, flags, log_guest)
            print(f"Server: Set VRING {index} desc=0x{desc:X}, avail=0x{avail:X}, used=0x{used:X}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_KICK:
            index = struct.unpack("<I", msg.data[:4])[0]
            self._vring(index).set_kick_fd(fds[0] if fds else None)
            print(f"Server: Set KICK FD for VRING {index}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_CALL:
            index = struct.unpack("<I", msg.data[:4])[0]
            self._vring(index).set_call_fd(fds[0] if fds else None)
            print(f"Server: Set CALL FD for VRING {index}")
        else:
            print(f"Server: Unhandled message type {msg.type}")
        return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment