Last active
August 7, 2025 12:58
-
-
Save apconole/70ebb980de70ccd06e6f3059de66e4b9 to your computer and use it in GitHub Desktop.
Start of a Pythonic vhost-user library
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # SPDX-License-Identifier: BSD-3-Clause | |
| # Copyright (C) 2025, Red Hat, Inc. | |
| """VHOST User Implementation | |
| """ | |
import array
import collections
import mmap
import os
import select
import socket
import stat
import struct
# Vhost Memory Region
class VhostMemoryRegion:
    """One memory region announced by VHOST_USER_SET_MEM_TABLE.

    Holds the fields of a region record (guest physical address, size,
    front-end userspace address, offset into the passed fd) together with
    the file descriptor received alongside the message.
    """

    # Per-region record: guest_phys_addr, memory_size, userspace_addr, mmap_offset.
    REGION_FMT = "<QQQQ"
    REGION_SIZE = struct.calcsize(REGION_FMT)
    # Table prefix used by the vhost-user spec (VhostUserMemory): nregions, padding.
    TABLE_HDR_FMT = "<II"
    TABLE_HDR_SIZE = struct.calcsize(TABLE_HDR_FMT)

    def __init__(self, guest_phys_addr, memory_size, userspace_addr, mmap_off,
                 mmap_fd):
        self.guest_phys_addr = guest_phys_addr
        self.memory_size = memory_size
        self.userspace_addr = userspace_addr
        self.mmap_offset = mmap_off
        self.mmap_fd = mmap_fd
        self.mapped = None  # mmap.mmap object once mmap() has been called

    def mmap(self):
        """Map the region's fd shared and read/write.

        Repeated calls return the mapping created by the first call instead
        of creating (and leaking) a new one.

        :returns: the ``mmap.mmap`` object, also stored in ``self.mapped``.
        :raises RuntimeError: if the region has no file descriptor.
        """
        if self.mapped is not None:
            return self.mapped
        if self.mmap_fd is None:
            raise RuntimeError("No FD for mmap region.")
        # NOTE(review): mmap() requires offset to be a multiple of
        # mmap.ALLOCATIONGRANULARITY; unaligned mmap_offset values will fail.
        self.mapped = mmap.mmap(self.mmap_fd, self.memory_size,
                                flags=mmap.MAP_SHARED,
                                prot=mmap.PROT_READ | mmap.PROT_WRITE,
                                offset=self.mmap_offset)
        return self.mapped

    @staticmethod
    def parse_regions(data, fds):
        """Parse the payload of SET_MEM_TABLE and associated FDs.

        Two payload layouts are accepted, distinguished by length:

        * a bare array of 32-byte region records (length % 32 == 0), and
        * the spec ``VhostUserMemory`` layout: ``u32 nregions, u32 padding``
          followed by region records (length % 32 == 8); only the first
          ``nregions`` records are used.

        :param data: raw message payload.
        :param fds: file descriptors received with the message, one per region.
        :returns: list of ``VhostMemoryRegion`` (not yet mapped).
        :raises ValueError: on a malformed payload or too few fds.
        """
        rec_size = VhostMemoryRegion.REGION_SIZE
        hdr_size = VhostMemoryRegion.TABLE_HDR_SIZE
        remainder = len(data) % rec_size
        if remainder == 0:
            start = 0
            count = len(data) // rec_size
        elif remainder == hdr_size:
            start = hdr_size
            available = (len(data) - hdr_size) // rec_size
            count = struct.unpack_from(VhostMemoryRegion.TABLE_HDR_FMT,
                                       data, 0)[0]
            if count > available:
                raise ValueError(
                    "SET_MEM_TABLE nregions exceeds payload length")
        else:
            raise ValueError("Invalid SET_MEM_TABLE payload length")
        if count > len(fds):
            raise ValueError("Not enough FDs passed for memory regions")
        return [
            VhostMemoryRegion(
                *struct.unpack_from(VhostMemoryRegion.REGION_FMT, data,
                                    start + i * rec_size),
                mmap_fd=fds[i])
            for i in range(count)
        ]
# Vhost Message
class VhostMsg:
    """A vhost-user message: fixed header plus ``size`` bytes of payload.

    NOTE(review): this file frames messages with a 16-byte header
    (request, flags, size, pad); the vhost-user spec header is 12 bytes
    (request, flags, size). Confirm before talking to a real peer.
    """

    VHOST_MSG_HDRLEN = 16  # VhostMsg header

    # Pulled from lib/vhost/vhost_user.h
    VHOST_USER_GET_FEATURES = 1
    VHOST_USER_SET_FEATURES = 2
    VHOST_USER_SET_OWNER = 3
    VHOST_USER_RESET_OWNER = 4
    VHOST_USER_SET_MEM_TABLE = 5
    VHOST_USER_SET_LOG_BASE = 6
    VHOST_USER_SET_LOG_FD = 7
    VHOST_USER_SET_VRING_NUM = 8
    VHOST_USER_SET_VRING_ADDR = 9
    VHOST_USER_SET_VRING_BASE = 10
    VHOST_USER_GET_VRING_BASE = 11
    VHOST_USER_SET_VRING_KICK = 12
    VHOST_USER_SET_VRING_CALL = 13
    VHOST_USER_SET_VRING_ERR = 14
    VHOST_USER_GET_PROTOCOL_FEATURES = 15
    VHOST_USER_SET_PROTOCOL_FEATURES = 16
    VHOST_USER_GET_QUEUE_NUM = 17
    VHOST_USER_SET_VRING_ENABLE = 18
    VHOST_USER_SEND_RARP = 19
    VHOST_USER_NET_SET_MTU = 20
    VHOST_USER_SET_BACKEND_REQ_FD = 21
    VHOST_USER_IOTLB_MSG = 22
    VHOST_USER_GET_CONFIG = 24
    VHOST_USER_SET_CONFIG = 25
    VHOST_USER_CRYPTO_CREATE_SESS = 26
    VHOST_USER_CRYPTO_CLOSE_SESS = 27
    VHOST_USER_POSTCOPY_ADVISE = 28
    VHOST_USER_POSTCOPY_LISTEN = 29
    VHOST_USER_POSTCOPY_END = 30
    VHOST_USER_GET_INFLIGHT_FD = 31
    VHOST_USER_SET_INFLIGHT_FD = 32
    VHOST_USER_SET_STATUS = 39
    VHOST_USER_GET_STATUS = 40

    # Only attributes with this prefix are message-type constants.
    _TYPE_PREFIX = "VHOST_USER_"

    def __init__(self, type, flags, size, data=b''):
        self.type = type
        self.flags = flags
        self.size = size
        self.data = data

    def encode(self):
        """Serialize to header (type, flags, size, 0) followed by ``data``."""
        return (struct.pack("<IIII", self.type, self.flags, self.size, 0) +
                self.data)

    @staticmethod
    def decode(data):
        """Parse a header and up to ``size`` bytes of payload from ``data``."""
        hdr_len = VhostMsg.VHOST_MSG_HDRLEN
        msg_type, flags, size, _ = struct.unpack("<IIII", data[:hdr_len])
        return VhostMsg(msg_type, flags, size, data[hdr_len:hdr_len + size])

    @staticmethod
    def type_name(type):
        """Return the constant name for message type ``type``.

        Only ``VHOST_USER_*`` constants are considered, so e.g. 16 maps to
        VHOST_USER_SET_PROTOCOL_FEATURES rather than VHOST_MSG_HDRLEN.

        :raises ValueError: if no message type has that value.
        """
        for name, value in VhostMsg.__dict__.items():
            if name.startswith(VhostMsg._TYPE_PREFIX) and value == type:
                return name
        raise ValueError(f"No type for {type} found.")

    @staticmethod
    def type_val(name):
        """Return the numeric value of message-type constant ``name``.

        :raises ValueError: if ``name`` is not a ``VHOST_USER_*`` constant.
        """
        if isinstance(name, str) and name.startswith(VhostMsg._TYPE_PREFIX):
            value = getattr(VhostMsg, name, None)
            if isinstance(value, int):
                return value
        raise ValueError(f"No such type {name} found.")
class VhostRing:
    """State of one split virtqueue plus helpers to drive it in memory.

    ``memory_region`` arguments are writable byte buffers (an ``mmap`` of
    the shared memory or any bytes-like object supporting slice assignment).
    Ring addresses set through ``set_addr`` are used directly as byte offsets
    into that buffer, while descriptor buffer addresses are translated by
    subtracting ``guest_phys_base``.
    NOTE(review): real vhost-user front-ends send their own virtual
    addresses in SET_VRING_ADDR, which would need translation through the
    memory table before use here -- confirm the intended addressing.

    The class serves as driver (``write_packets`` /
    ``reclaim_used_descriptors``) and as device (``read_packets``); both
    roles advance ``last_used_idx``, so use a separate instance per role.
    """

    DESC_FMT = "<QIHH"      # virtq_desc: addr u64, len u32, flags u16, next u16
    DESC_SIZE = struct.calcsize(DESC_FMT)   # 16 bytes
    RING_HDR_SIZE = 4       # flags (u16) + idx (u16) before avail/used ring[]
    USED_ELEM_SIZE = 8      # virtq_used_elem: id u32, len u32
    IDX_MASK = 0xFFFF       # avail.idx / used.idx are free-running u16 counters
    BUF_SIZE = 0x1000       # packet buffer bytes reserved per descriptor

    def __init__(self, index):
        self.index = index
        self.num = 0
        self.free_descs = collections.deque(range(self.num))
        self.desc_addr = 0
        self.avail_addr = 0
        self.used_addr = 0
        self.flags = 0
        self.log_guest_addr = 0
        self.kick_fd = None
        self.call_fd = None
        # Driver role: next avail.idx value to publish (u16).
        self.avail_idx = 0
        # Used-ring progress (u16): entries reclaimed (driver) or produced (device).
        self.last_used_idx = 0
        # Start of the packet-buffer area inside memory_region; descriptor i
        # owns the BUF_SIZE-byte slot at next_data_offset + i * BUF_SIZE.
        self.next_data_offset = 0x4000

    def set_num(self, num):
        """Set the ring size and mark every descriptor as free."""
        self.num = num
        self.free_descs = collections.deque(range(self.num))

    def set_addr(self, desc, avail, used, flags, log_guest_addr=0):
        self.desc_addr = desc
        self.avail_addr = avail
        self.used_addr = used
        self.flags = flags
        self.log_guest_addr = log_guest_addr

    def set_kick_fd(self, fd):
        self.kick_fd = fd

    def set_call_fd(self, fd):
        self.call_fd = fd

    def _read_used_idx(self, memory_region):
        """Return used.idx (u16 at used_addr + 2)."""
        return struct.unpack_from("<H", memory_region, self.used_addr + 2)[0]

    def _read_avail_idx(self, memory_region):
        """Return avail.idx (u16 at avail_addr + 2)."""
        return struct.unpack_from("<H", memory_region, self.avail_addr + 2)[0]

    def reclaim_used_descriptors(self, memory_region):
        """Scan the used ring and return completed descriptors to the free list."""
        if self.num == 0:
            return
        used_idx = self._read_used_idx(memory_region)
        while self.last_used_idx != used_idx:
            elem_off = (self.used_addr + self.RING_HDR_SIZE +
                        (self.last_used_idx % self.num) * self.USED_ELEM_SIZE)
            desc_id = struct.unpack_from("<I", memory_region, elem_off)[0]
            self.free_descs.append(desc_id)
            self.last_used_idx = (self.last_used_idx + 1) & self.IDX_MASK

    def write_packets(self, memory_region, packets, guest_phys_base=0x100000):
        """Submit ``packets`` on the avail ring (driver role) and kick the device.

        Used descriptors are reclaimed first. Each packet is copied into the
        buffer slot owned by its descriptor, so a buffer is only reused once
        the device has returned that descriptor. Submission stops at the
        first packet that cannot be placed.

        :param memory_region: writable buffer holding rings and packet data.
        :param packets: iterable of bytes-like packets.
        :param guest_phys_base: value added to buffer offsets in descriptors.
        """
        avail_ring_offset = self.avail_addr + self.RING_HDR_SIZE
        submitted = 0
        self.reclaim_used_descriptors(memory_region)
        for pkt in packets:
            if not self.free_descs:
                print("No free descriptors available — skipping packet")
                break
            pkt_len = len(pkt)
            if pkt_len > self.BUF_SIZE:
                print(f"Packet of {pkt_len} bytes exceeds {self.BUF_SIZE}-byte "
                      "buffer — stopping submission")
                break
            desc_index = self.free_descs.popleft()
            data_off = self.next_data_offset + desc_index * self.BUF_SIZE
            if data_off + pkt_len > len(memory_region):
                # Slot lies outside the shared memory; keep the descriptor free.
                self.free_descs.appendleft(desc_index)
                print(f"Buffer for descriptor {desc_index} lies outside memory "
                      "region — stopping submission")
                break
            desc_off = self.desc_addr + desc_index * self.DESC_SIZE
            memory_region[desc_off:desc_off + self.DESC_SIZE] = struct.pack(
                self.DESC_FMT, guest_phys_base + data_off, pkt_len, 0, 0)
            memory_region[data_off:data_off + pkt_len] = pkt
            avail_entry_off = avail_ring_offset + (self.avail_idx % self.num) * 2
            memory_region[avail_entry_off:avail_entry_off + 2] = struct.pack(
                "<H", desc_index)
            self.avail_idx = (self.avail_idx + 1) & self.IDX_MASK
            submitted += 1
        # Publish avail.idx after the ring entries are written.
        memory_region[self.avail_addr + 2:self.avail_addr + 4] = struct.pack(
            "<H", self.avail_idx)
        if submitted and self.kick_fd is not None:
            os.write(self.kick_fd, struct.pack("Q", 1))
        print(f"Vring {self.index}: Submitted {submitted} packets")

    def read_packets(self, memory_region, guest_phys_base=0x100000):
        """Consume new avail entries (device role) and return their payloads.

        Each consumed descriptor is written to the used ring with its length,
        used.idx is published, and the call fd is signalled if any packet
        was consumed.
        """
        packets = []
        if self.num == 0:
            return packets
        avail_idx = self._read_avail_idx(memory_region)
        while self.last_used_idx != avail_idx:
            ring_idx = self.last_used_idx % self.num
            avail_entry_off = self.avail_addr + self.RING_HDR_SIZE + ring_idx * 2
            desc_index = struct.unpack_from("<H", memory_region, avail_entry_off)[0]
            desc_off = self.desc_addr + desc_index * self.DESC_SIZE
            addr, length, _flags, _next = struct.unpack_from(
                self.DESC_FMT, memory_region, desc_off)
            pkt_offset = addr - guest_phys_base
            pkt = bytes(memory_region[pkt_offset:pkt_offset + length])
            print(f"Vring {self.index}: Received packet: {pkt.hex()}")
            packets.append(pkt)
            used_elem_off = (self.used_addr + self.RING_HDR_SIZE +
                             ring_idx * self.USED_ELEM_SIZE)
            memory_region[used_elem_off:used_elem_off + self.USED_ELEM_SIZE] = \
                struct.pack("<II", desc_index, length)
            self.last_used_idx = (self.last_used_idx + 1) & self.IDX_MASK
        memory_region[self.used_addr + 2:self.used_addr + 4] = struct.pack(
            "<H", self.last_used_idx)
        if packets and self.call_fd is not None:
            self.notify()
        return packets

    def wait(self, timeout=0.1):
        """Wait up to ``timeout`` seconds for a kick and consume it.

        The kick counter is read so the fd does not stay readable forever.

        :returns: True if a kick was consumed, False on timeout.
        :raises RuntimeError: if no kick fd is set.
        """
        if self.kick_fd is None:
            raise RuntimeError("Need a kick fd")
        rlist, _, _ = select.select([self.kick_fd], [], [], timeout)
        if not rlist:
            return False
        try:
            os.read(self.kick_fd, 8)
        except BlockingIOError:
            return False
        return True

    def notify(self):
        """Signal the peer by writing 1 to the call fd."""
        if self.call_fd is None:
            raise RuntimeError("Need a call fd")
        os.write(self.call_fd, struct.pack("Q", 1))
class VhostSocket:
    """AF_UNIX stream socket wrapper used for vhost-user messaging.

    In server mode ``self.socket`` listens and ``self.accepted`` is the
    single accepted peer; in client mode ``self.socket`` is connected
    directly. ``active_sock()`` returns whichever one carries messages.
    """

    # Maximum number of file descriptors accepted with one message.
    MAX_FDS = 8

    def __init__(self, path="", mode=stat.S_IRUSR | stat.S_IWUSR, server=False):
        """
        :param path: filesystem path of the UNIX socket.
        :param mode: permission bits applied to the socket file by listen().
        :param server: True to listen/accept, False to connect.
        """
        self.path = path
        self.mode = mode
        self.is_server = server
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.accepted = None

    def listen(self):
        """Bind to ``path`` (replacing a stale socket file) and listen.

        :raises ValueError: if no path is set or this is not a server socket.
        """
        if self.path == "" or not self.is_server:
            raise ValueError("Cannot listen on a blank socket.")
        if self.accepted is not None:
            self.accepted.close()
            self.accepted = None
        try:
            os.unlink(self.path)
        except OSError:
            if os.path.exists(self.path):
                raise
        self.socket.bind(self.path)
        # Apply the requested permissions (previously stored but never used).
        os.chmod(self.path, self.mode)
        self.socket.listen(1)

    def accept(self):
        self.accepted, _ = self.socket.accept()

    def connect(self):
        self.socket.connect(self.path)

    def active_sock(self):
        return self.accepted if self.accepted is not None else self.socket

    def close(self):
        """Close the accepted connection (if any) and the main socket."""
        if self.accepted is not None:
            self.accepted.close()
            self.accepted = None
        self.socket.close()

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the active socket.

        :raises ConnectionError: if the peer closes before ``size`` bytes arrive.
        """
        sock = self.active_sock()
        buf = bytearray()
        while len(buf) < size:
            chunk = sock.recv(size - len(buf))
            if not chunk:
                raise ConnectionError(
                    f"Peer closed connection after {len(buf)} of {size} bytes")
            buf += chunk
        return bytes(buf)

    def recv_msg(self):
        """Receive one message and any file descriptors sent with it.

        :returns: ``(VhostMsg, [fd, ...])``, or ``(None, [])`` on EOF before
            any header byte.
        :raises ConnectionError: if the peer closes mid-message; fds received
            with the header are closed first.
        """
        fds = array.array("i")  # to store received FDs
        # Buffer sized with CMSG_SPACE, as recommended for recvmsg() ancillary data.
        cmsg_buffer_size = socket.CMSG_SPACE(self.MAX_FDS * fds.itemsize)
        msg, ancdata, _, _ = self.active_sock().recvmsg(
            VhostMsg.VHOST_MSG_HDRLEN, cmsg_buffer_size)
        if not msg:
            return None, []
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if (cmsg_level == socket.SOL_SOCKET and
                    cmsg_type == socket.SCM_RIGHTS):
                # Drop a trailing partial fd so frombytes() cannot fail.
                usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
                fds.frombytes(cmsg_data[:usable])
        try:
            # A stream recvmsg() may return only part of the header.
            if len(msg) < VhostMsg.VHOST_MSG_HDRLEN:
                msg += self._recv_exact(VhostMsg.VHOST_MSG_HDRLEN - len(msg))
            msg_obj = VhostMsg.decode(msg)
            msg_obj.data = self._recv_exact(msg_obj.size)
        except ConnectionError:
            for fd in fds:
                os.close(fd)
            raise
        return msg_obj, list(fds)
class VhostClient:
    """Minimal vhost-user front-end that sets up memory and vrings on a server.

    NOTE(review): framing follows VhostMsg in this file (16-byte header).
    SET_MEM_TABLE is sent as a bare region record without the spec's
    ``nregions``/padding prefix, and SET_VRING_ADDR is sent as desc, avail,
    used. The vhost-user spec uses a 12-byte header, a ``nregions`` prefix,
    and desc, used, avail order. Confirm before using against a real back-end.
    """

    def __init__(self, path):
        self.socket = VhostSocket(path=path, server=False)
        self.socket.connect()
        self.memfd = None
        self.mem_size = 0x10000
        self.mmap_region = None

    def _send(self, msg_type, data, fds=()):
        """Send header + ``data`` for ``msg_type``, passing ``fds`` via SCM_RIGHTS."""
        payload = struct.pack("<IIII", int(msg_type), 0, len(data), 0) + data
        sock = self.socket.socket
        if not fds:
            sock.sendall(payload)
            return
        ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS,
                    array.array("i", fds).tobytes())]
        sent = sock.sendmsg([payload], ancdata)
        # The fds went with this sendmsg(); finish any remainder without them.
        if sent < len(payload):
            sock.sendall(payload[sent:])

    def create_memory(self, memfd_name="vhost-user-shmem"):
        """Create a ``mem_size``-byte memfd and map it shared read/write."""
        self.memfd = os.memfd_create(memfd_name, os.MFD_CLOEXEC)
        os.ftruncate(self.memfd, self.mem_size)
        self.mmap_region = mmap.mmap(self.memfd, self.mem_size,
                                     mmap.MAP_SHARED,
                                     mmap.PROT_READ | mmap.PROT_WRITE)

    def send_set_mem_table(self, guest_phys_addr=0x100000):
        """Announce the memfd as a single region at ``guest_phys_addr``.

        :raises RuntimeError: if create_memory() has not been called.
        """
        if self.memfd is None:
            raise RuntimeError(
                "create_memory() must be called before send_set_mem_table()")
        userspace_addr = 0x0
        mmap_offset = 0
        region = struct.pack("<QQQQ", guest_phys_addr, self.mem_size,
                             userspace_addr, mmap_offset)
        self._send(VhostMsg.VHOST_USER_SET_MEM_TABLE, region, [self.memfd])

    def send_set_vring_num(self, index, num):
        self._send(VhostMsg.VHOST_USER_SET_VRING_NUM,
                   struct.pack("<II", index, num))

    def send_set_vring_addr(self, index, desc_addr, avail_addr, used_addr):
        flags = 0
        log_guest_addr = 0
        data = struct.pack("<IIQQQQ", index, flags, desc_addr, avail_addr,
                           used_addr, log_guest_addr)
        self._send(VhostMsg.VHOST_USER_SET_VRING_ADDR, data)

    def _send_vring_eventfd(self, msg_type, index):
        """Create a non-blocking eventfd and send it for vring ``index``."""
        event_fd = os.eventfd(0, os.EFD_NONBLOCK)
        try:
            self._send(msg_type, struct.pack("<I", index), [event_fd])
        except OSError:
            os.close(event_fd)
            raise
        return event_fd

    def send_set_vring_kick(self, index):
        """Register a new kick eventfd for vring ``index``; returns the fd."""
        return self._send_vring_eventfd(VhostMsg.VHOST_USER_SET_VRING_KICK,
                                        index)

    def send_set_vring_call(self, index):
        """Register a new call eventfd for vring ``index``; returns the fd."""
        return self._send_vring_eventfd(VhostMsg.VHOST_USER_SET_VRING_CALL,
                                        index)

    def close(self):
        """Release the shared memory mapping, the memfd and the socket."""
        if self.mmap_region is not None:
            self.mmap_region.close()
            self.mmap_region = None
        if self.memfd is not None:
            os.close(self.memfd)
            self.memfd = None
        self.socket.socket.close()
class VhostServer:
    """Minimal vhost-user back-end that records memory and vring setup."""

    # SET_VRING_KICK/CALL payload (vhost-user spec): the low 8 bits are the
    # vring index and bit 8 means "no file descriptor passed".
    VRING_IDX_MASK = 0xFF
    VRING_NOFD_MASK = 1 << 8

    def __init__(self, path):
        self.socket = VhostSocket(path=path, server=True)
        self.vrings = {}
        self.memory = []

    def start(self):
        """Listen on the socket path and block until one peer connects."""
        self.socket.listen()
        self.socket.accept()

    def _get_vring(self, index):
        """Return the vring for ``index``, creating and registering it on first use."""
        vring = self.vrings.get(index)
        if vring is None:
            vring = VhostRing(index)
            self.vrings[index] = vring
        return vring

    def _parse_vring_fd_msg(self, msg, fds):
        """Return ``(index, fd)`` from a SET_VRING_KICK/CALL message.

        ``fd`` is None when the NOFD bit is set or no fd was received.
        """
        value = struct.unpack("<I", msg.data[:4])[0]
        index = value & self.VRING_IDX_MASK
        if value & self.VRING_NOFD_MASK or not fds:
            return index, None
        return index, fds[0]

    def handle_message(self):
        """Receive and apply one message.

        :returns: False when the peer has closed the connection, else True.
        """
        msg, fds = self.socket.recv_msg()
        if msg is None:
            return False
        # type_name() raises for unknown types; don't let that crash the loop.
        try:
            name = VhostMsg.type_name(msg.type)
        except ValueError:
            name = f"unknown({msg.type})"
        print(f"Server: Received {name}")
        if msg.type == VhostMsg.VHOST_USER_GET_FEATURES:
            # TODO: send a reply carrying the supported feature bits.
            print("Server: features requested.")
        elif msg.type == VhostMsg.VHOST_USER_SET_MEM_TABLE:
            regions = VhostMemoryRegion.parse_regions(msg.data, fds)
            for region in regions:
                region.mmap()
            # SET_MEM_TABLE replaces the whole table rather than adding to it.
            self.memory[:] = regions
            print(f"Server: Mapped {len(regions)} memory region(s)")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_NUM:
            index, num = struct.unpack("<II", msg.data[:8])
            self._get_vring(index).set_num(num)
            print(f"Server: Set VRING {index} num={num}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_ADDR:
            index, flags = struct.unpack("<II", msg.data[0:8])
            # NOTE(review): the spec's vhost_vring_addr orders desc, used,
            # avail; this parses desc, avail, used to match this file's client.
            desc, avail, used, log_guest = struct.unpack("<QQQQ",
                                                         msg.data[8:40])
            self._get_vring(index).set_addr(desc, avail, used, flags, log_guest)
            print(f"Server: Set VRING {index} desc=0x{desc:X}, avail=0x{avail:X}, used=0x{used:X}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_KICK:
            index, fd = self._parse_vring_fd_msg(msg, fds)
            self._get_vring(index).set_kick_fd(fd)
            print(f"Server: Set KICK FD for VRING {index}")
        elif msg.type == VhostMsg.VHOST_USER_SET_VRING_CALL:
            index, fd = self._parse_vring_fd_msg(msg, fds)
            self._get_vring(index).set_call_fd(fd)
            print(f"Server: Set CALL FD for VRING {index}")
        else:
            print(f"Server: Unhandled message type {msg.type}")
        return True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment