diff --git a/Cargo.lock b/Cargo.lock
index a8330328a..32585a9fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5350,6 +5350,7 @@ dependencies = [
  "dlpi 0.2.0 (git+https://github.com/oxidecomputer/dlpi-sys?branch=main)",
  "erased-serde 0.4.5",
  "futures",
+ "iddqd",
  "ispf",
  "lazy_static",
  "libc",
diff --git a/Cargo.toml b/Cargo.toml
index a3c16a5ed..385d76541 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -132,6 +132,7 @@ hex = "0.4.3"
 http = "1.1.0"
 hyper = "1.0"
 linkme = "0.3.33"
+iddqd = "0.3"
 itertools = "0.13.0"
 kstat-rs = "0.2.4"
 lazy_static = "1.4"
diff --git a/bin/propolis-standalone/src/config.rs b/bin/propolis-standalone/src/config.rs
index 58b485fd1..3aac56a73 100644
--- a/bin/propolis-standalone/src/config.rs
+++ b/bin/propolis-standalone/src/config.rs
@@ -10,6 +10,7 @@ use std::sync::Arc;
 use anyhow::Context;
 use cpuid_utils::CpuidSet;
+use propolis::vsock::proxy::VsockPortMapping;
 use propolis_types::CpuidIdent;
 use propolis_types::CpuidValues;
 use propolis_types::CpuidVendor;
@@ -170,6 +171,20 @@ impl VionaDeviceParams {
     }
 }
 
+#[derive(Deserialize)]
+pub struct VsockDevice {
+    pub guest_cid: u32,
+    pub port_mappings: Vec<VsockPortMapping>,
+}
+
+impl VsockDevice {
+    pub fn from_opts(
+        opts: &BTreeMap<String, toml::Value>,
+    ) -> Result<Self, anyhow::Error> {
+        opt_deser(opts)
+    }
+}
+
 // Try to turn unmatched flattened options into a config struct
 fn opt_deser<'de, T: Deserialize<'de>>(
     value: &BTreeMap<String, toml::Value>,
diff --git a/bin/propolis-standalone/src/main.rs b/bin/propolis-standalone/src/main.rs
index 42a97319a..070f73a06 100644
--- a/bin/propolis-standalone/src/main.rs
+++ b/bin/propolis-standalone/src/main.rs
@@ -1308,6 +1308,19 @@ fn setup_instance(
                 guard.inventory.register(&pvpanic);
             }
         }
+        "pci-virtio-vsock" => {
+            // XXX MTZ: add the vsock device
+            let config = config::VsockDevice::from_opts(&dev.options)?;
+            let bdf = bdf.unwrap();
+            let vsock = hw::virtio::PciVirtioSock::new(
+                512,
+                config.guest_cid,
+                log.new(slog::o!("dev" => "vsock")),
+                config.port_mappings,
+            );
+            guard.inventory.register(&vsock);
+            chipset_pci_attach(bdf, vsock);
+        }
         _ => {
             slog::error!(log, "unrecognized driver {driver}"; "name" => name);
             return Err(Error::new(
diff --git a/lib/propolis/Cargo.toml b/lib/propolis/Cargo.toml
index 0139a804c..195852d9b 100644
--- a/lib/propolis/Cargo.toml
+++ b/lib/propolis/Cargo.toml
@@ -39,6 +39,7 @@ crucible = { workspace = true, optional = true }
 oximeter = { workspace = true, optional = true }
 nexus-client = { workspace = true, optional = true }
 async-trait.workspace = true
+iddqd.workspace = true
 
 # falcon
 libloading = { workspace = true, optional = true }
diff --git a/lib/propolis/src/hw/virtio/mod.rs b/lib/propolis/src/hw/virtio/mod.rs
index a1a72111a..af3b31f86 100644
--- a/lib/propolis/src/hw/virtio/mod.rs
+++ b/lib/propolis/src/hw/virtio/mod.rs
@@ -21,6 +21,10 @@ mod queue;
 #[cfg(feature = "falcon")]
 pub mod softnpu;
 pub mod viona;
+pub mod vsock;
+
+#[cfg(test)]
+pub mod testutil;
 
 use crate::common::RWOp;
 use crate::hw::pci as pci_hw;
@@ -29,6 +33,7 @@ use queue::VirtQueue;
 
 pub use block::PciVirtioBlock;
 pub use viona::PciVirtioViona;
+pub use vsock::PciVirtioSock;
 
 bitflags! {
     pub struct LegacyFeatures: u64 {
@@ -165,6 +170,7 @@ impl DeviceId {
         match self {
             Self::Network => Ok(pci_hw::bits::CLASS_NETWORK),
             Self::Block | Self::NineP => Ok(pci_hw::bits::CLASS_STORAGE),
+            Self::Socket => Ok(pci_hw::bits::CLASS_UNCLASSIFIED),
             _ => Err(self),
         }
     }
@@ -228,6 +234,7 @@ pub trait VirtioIntr: Send + 'static {
     fn read(&self) -> VqIntr;
 }
 
+#[derive(Debug)]
 pub enum VqChange {
     /// Underlying virtio device has been reset
     Reset,
diff --git a/lib/propolis/src/hw/virtio/queue.rs b/lib/propolis/src/hw/virtio/queue.rs
index 2965c8746..545f83e17 100644
--- a/lib/propolis/src/hw/virtio/queue.rs
+++ b/lib/propolis/src/hw/virtio/queue.rs
@@ -94,7 +94,7 @@ impl VqAvail {
         }
         if let Some(idx) = mem.read::<u16>(self.gpa_idx) {
             let ndesc = Wrapping(*idx) - self.cur_avail_idx;
-            if ndesc.0 != 0 && ndesc.0 < rsize {
+            if ndesc.0 != 0 && ndesc.0 <= rsize {
                 let avail_idx = self.cur_avail_idx.0 & (rsize - 1);
                 self.cur_avail_idx += Wrapping(1);
diff --git a/lib/propolis/src/hw/virtio/testutil.rs b/lib/propolis/src/hw/virtio/testutil.rs
new file mode 100644
index 000000000..42c8821e1
--- /dev/null
+++ b/lib/propolis/src/hw/virtio/testutil.rs
@@ -0,0 +1,629 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+//! Test utilities for constructing fake virtqueues backed by real guest memory.
+//!
+//! This module provides [`TestVirtQueue`] for single-queue tests and
+//! [`TestVirtQueues`] for multi-queue devices. Both allocate guest memory via
+//! a tempfile-backed [`PhysMap`], lay out virtio ring structures, and provide
+//! helpers to enqueue descriptor chains, simulating a guest driver writing
+//! to the available ring.
+
+use std::sync::Arc;
+
+use zerocopy::FromBytes;
+
+use crate::accessors::MemAccessor;
+use crate::common::GuestAddr;
+use crate::vmm::mem::PhysMap;
+use crate::vmm::MemCtx;
+
+// Re-export queue types so tests outside this module can access them
+// without requiring `queue` to be pub(crate).
+pub use super::queue::{Chain, DescFlag, VirtQueue, VirtQueues, VqSize};
+
+/// Page size for alignment (4 KiB).
+const PAGE_SIZE: u64 = 0x1000;
+
+/// Size in bytes of a virtio descriptor (addr: u64, len: u32, flags: u16, next: u16).
+const DESC_SIZE: u64 = 16;
+
+/// Size in bytes of a used ring element (id: u32, len: u32).
+const USED_ELEM_SIZE: u64 = 8;
+
+/// Size in bytes of an available ring entry (descriptor index: u16).
+const AVAIL_ELEM_SIZE: u64 = 2;
+
+/// Size in bytes of the ring header (flags: u16, idx: u16).
+const RING_HEADER_SIZE: u64 = 4;
+
+/// Number of pages to allocate for the data area in tests.
+const DATA_AREA_PAGES: u64 = 64;
+
+/// Align `val` up to the next multiple of `align` (must be power of 2).
+pub const fn align_up(val: u64, align: u64) -> u64 {
+    (val + align - 1) & !(align - 1)
+}
+
+/// 16-byte virtio descriptor, matching the on-wire/in-memory layout.
+#[repr(C)]
+#[derive(Copy, Clone, Default, FromBytes)]
+pub struct RawDesc {
+    pub addr: u64,
+    pub len: u32,
+    pub flags: u16,
+    pub next: u16,
+}
+
+/// 8-byte used ring element.
+#[repr(C)]
+#[derive(Copy, Clone, Default, FromBytes)]
+pub struct RawUsedElem {
+    pub id: u32,
+    pub len: u32,
+}
+
+/// Guest physical address layout for a single virtqueue's ring structures.
+#[derive(Copy, Clone, Debug)]
+pub struct QueueLayout {
+    pub desc_base: u64,
+    pub avail_base: u64,
+    pub used_base: u64,
+    /// First GPA after this queue's structures.
+    pub end: u64,
+}
+
+impl QueueLayout {
+    /// Compute the ring layout for a queue of `size` entries starting at
+    /// `base`.
+    ///
+    /// Layout follows the virtio 1.0 split virtqueue format:
+    /// - Descriptor table: `size * DESC_SIZE` bytes
+    /// - Available ring: header (4 bytes) + `size * 2` bytes for entries
+    /// - Used ring: page-aligned, header (4 bytes) + `size * 8` bytes
+    pub fn new(base: u64, size: u16) -> Self {
+        let qsz = size as u64;
+        let desc_base = base;
+        let avail_base = desc_base + DESC_SIZE * qsz;
+        let used_base = align_up(
+            avail_base + RING_HEADER_SIZE + AVAIL_ELEM_SIZE * qsz,
+            PAGE_SIZE,
+        );
+        let end = align_up(
+            used_base + RING_HEADER_SIZE + USED_ELEM_SIZE * qsz,
+            PAGE_SIZE,
+        );
+        Self { desc_base, avail_base, used_base, end }
+    }
+}
+
+/// Per-queue writer for injecting descriptors into a virtqueue's rings.
+pub struct QueueWriter {
+    layout: QueueLayout,
+    size: u16,
+    /// Next free descriptor index.
+    next_desc: u16,
+    /// Start of data area for this queue.
+    data_start: u64,
+    /// Next free data area offset (GPA).
+    data_cursor: u64,
+    /// Avail ring index we've published up to.
+    avail_idx: u16,
+}
+
+impl QueueWriter {
+    /// Create a new QueueWriter for a queue with the given layout.
+    pub fn new(layout: QueueLayout, size: u16, data_start: u64) -> Self {
+        Self {
+            layout,
+            size,
+            next_desc: 0,
+            data_start,
+            data_cursor: data_start,
+            avail_idx: 0,
+        }
+    }
+
+    /// Reset descriptor and data cursors to allow reusing slots.
+    pub fn reset_cursors(&mut self) {
+        self.next_desc = 0;
+        self.data_cursor = self.data_start;
+    }
+
+    /// Write a descriptor and return its index.
+    pub fn write_desc(
+        &mut self,
+        mem_acc: &MemAccessor,
+        addr: u64,
+        len: u32,
+        flags: u16,
+        next: u16,
+    ) -> u16 {
+        let idx = self.next_desc;
+        assert!(idx < self.size, "descriptor table exhausted");
+        self.next_desc += 1;
+
+        let desc = RawDesc { addr, len, flags, next };
+        let gpa = self.layout.desc_base + u64::from(idx) * DESC_SIZE;
+        let mem = mem_acc.access().unwrap();
+        mem.write(GuestAddr(gpa), &desc);
+        idx
+    }
+
+    /// Allocate data space and write bytes into it. Returns the GPA.
+    pub fn write_data(&mut self, mem_acc: &MemAccessor, data: &[u8]) -> u64 {
+        let gpa = self.data_cursor;
+        self.data_cursor += data.len() as u64;
+        let mem = mem_acc.access().unwrap();
+        mem.write_from(GuestAddr(gpa), data, data.len());
+        gpa
+    }
+
+    /// Allocate data space without writing. Returns the GPA.
+    pub fn alloc_data(&mut self, len: u32) -> u64 {
+        let gpa = self.data_cursor;
+        self.data_cursor += u64::from(len);
+        gpa
+    }
+
+    /// Add a readable descriptor with the given data.
+    pub fn add_readable(&mut self, mem_acc: &MemAccessor, data: &[u8]) -> u16 {
+        let gpa = self.write_data(mem_acc, data);
+        self.write_desc(mem_acc, gpa, data.len() as u32, 0, 0)
+    }
+
+    /// Add a writable descriptor of the given size.
+    pub fn add_writable(&mut self, mem_acc: &MemAccessor, len: u32) -> u16 {
+        let gpa = self.alloc_data(len);
+        self.write_desc(mem_acc, gpa, len, DescFlag::WRITE.bits(), 0)
+    }
+
+    /// Chain two descriptors together via NEXT flag.
+    pub fn chain(&self, mem_acc: &MemAccessor, from: u16, to: u16) {
+        let gpa = self.layout.desc_base + u64::from(from) * DESC_SIZE;
+        let mem = mem_acc.access().unwrap();
+        let mut raw: RawDesc = *mem.read(GuestAddr(gpa)).unwrap();
+        raw.flags |= DescFlag::NEXT.bits();
+        raw.next = to;
+        mem.write(GuestAddr(gpa), &raw);
+    }
+
+    /// Publish a descriptor chain head on the available ring.
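+    ///
+    /// Two stores are involved: the head index goes into the next free ring
+    /// slot, and then the ring's `idx` field is advanced so the device sees
+    /// the new entry. (These tests are single-threaded, so no write barrier
+    /// is needed between the two stores, unlike a real guest driver.)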
+    pub fn publish_avail(&mut self, mem_acc: &MemAccessor, head: u16) {
+        // Available ring layout:
+        // flags (u16) | idx (u16) | ring[size] (u16 each)
+        let slot = self.layout.avail_base
+            + RING_HEADER_SIZE
+            + u64::from(self.avail_idx % self.size) * AVAIL_ELEM_SIZE;
+        self.avail_idx += 1;
+        let new_idx = self.avail_idx;
+        let mem = mem_acc.access().unwrap();
+        mem.write(GuestAddr(slot), &head);
+        // Write new index at offset 2 (after flags u16)
+        mem.write(GuestAddr(self.layout.avail_base + 2), &new_idx);
+    }
+
+    /// Read the used ring index.
+    pub fn used_idx(&self, mem_acc: &MemAccessor) -> u16 {
+        let mem = mem_acc.access().unwrap();
+        // Used ring idx is at offset 2 (after flags u16)
+        *mem.read(GuestAddr(self.layout.used_base + 2)).unwrap()
+    }
+
+    /// Read a used ring entry by index, returning (desc_id, len).
+    pub fn read_used_elem(
+        &self,
+        mem_acc: &MemAccessor,
+        used_index: u16,
+    ) -> RawUsedElem {
+        let mem = mem_acc.access().unwrap();
+        // Used ring layout:
+        // flags (u16) | idx (u16) | ring[size] (RawUsedElem each)
+        let entry_gpa = self.layout.used_base
+            + RING_HEADER_SIZE
+            + u64::from(used_index % self.size) * USED_ELEM_SIZE;
+        *mem.read(GuestAddr(entry_gpa)).unwrap()
+    }
+
+    /// Read raw bytes from the buffer of a descriptor.
+    pub fn read_desc_data(
+        &self,
+        mem_acc: &MemAccessor,
+        desc_id: u16,
+        len: usize,
+    ) -> Vec<u8> {
+        let mem = mem_acc.access().unwrap();
+        let desc_gpa = self.layout.desc_base + u64::from(desc_id) * DESC_SIZE;
+        let raw_desc: RawDesc = *mem.read(GuestAddr(desc_gpa)).unwrap();
+
+        let mut data = vec![0u8; len];
+        mem.read_into(
+            GuestAddr(raw_desc.addr),
+            &mut crate::common::GuestData::from(data.as_mut_slice()),
+            len,
+        );
+        data
+    }
+}
+
+/// Multi-queue test harness for virtio devices that use multiple queues.
+pub struct TestVirtQueues {
+    /// Must stay alive to keep memory mappings valid.
+    _phys: PhysMap,
+    mem_acc: MemAccessor,
+    queues: VirtQueues,
+    layouts: Vec<QueueLayout>,
+    sizes: Vec<u16>,
+    /// Start of data area (after all queue structures).
+    data_start: u64,
+}
+
+impl TestVirtQueues {
+    /// Create a new multi-queue test harness.
+    ///
+    /// `sizes` specifies the size of each queue (must be powers of 2).
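+    ///
+    /// For example, a vsock-shaped device could be modeled as
+    /// `TestVirtQueues::new(&[VqSize::new(256), VqSize::new(256), VqSize::new(1)])`,
+    /// i.e. rx, tx, and a single-entry event queue.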
+    pub fn new(sizes: &[VqSize]) -> Self {
+        // Compute layouts for all queues sequentially
+        let mut layouts = Vec::with_capacity(sizes.len());
+        let mut size_vals = Vec::with_capacity(sizes.len());
+        let mut offset = 0u64;
+        for &size in sizes {
+            let size_u16: u16 = size.into();
+            let layout = QueueLayout::new(offset, size_u16);
+            offset = layout.end;
+            layouts.push(layout);
+            size_vals.push(size_u16);
+        }
+
+        // Data area after all rings
+        let data_start = offset;
+        let data_area_size = PAGE_SIZE * DATA_AREA_PAGES;
+        let total_size =
+            align_up(data_start + data_area_size, PAGE_SIZE) as usize;
+
+        let mut phys = PhysMap::new_test(total_size);
+        phys.add_test_mem("test-vqs".to_string(), 0, total_size)
+            .expect("add test mem");
+        let mem_acc = phys.finalize();
+
+        // Create VirtQueues
+        let queues = VirtQueues::new(sizes);
+
+        // Initialize each queue
+        for (i, layout) in layouts.iter().enumerate() {
+            let vq = queues.get(i as u16).unwrap();
+            mem_acc.adopt(&vq.acc_mem, Some(format!("test-vq-{i}")));
+            vq.map_virtqueue(
+                layout.desc_base,
+                layout.avail_base,
+                layout.used_base,
+            );
+            vq.live.store(true, std::sync::atomic::Ordering::Release);
+            vq.enabled.store(true, std::sync::atomic::Ordering::Release);
+
+            // Zero out avail and used ring headers
+            let mem = mem_acc.access().unwrap();
+            mem.write(GuestAddr(layout.avail_base), &0u16);
+            mem.write(GuestAddr(layout.avail_base + 2), &0u16);
+            mem.write(GuestAddr(layout.used_base), &0u16);
+            mem.write(GuestAddr(layout.used_base + 2), &0u16);
+        }
+
+        Self {
+            _phys: phys,
+            mem_acc,
+            queues,
+            layouts,
+            sizes: size_vals,
+            data_start,
+        }
+    }
+
+    /// Get the memory accessor.
+    pub fn mem_acc(&self) -> &MemAccessor {
+        &self.mem_acc
+    }
+
+    /// Get the underlying VirtQueues.
+    pub fn queues(&self) -> &VirtQueues {
+        &self.queues
+    }
+
+    /// Get the VirtQueue at the given index.
+    pub fn vq(&self, idx: u16) -> &Arc<VirtQueue> {
+        self.queues.get(idx).unwrap()
+    }
+
+    /// Create a QueueWriter for the given queue index.
+    ///
+    /// `data_offset` is an offset from the shared data area start,
+    /// allowing different queues to use different regions.
+    pub fn writer(&self, queue_idx: usize, data_offset: u64) -> QueueWriter {
+        let layout = self.layouts[queue_idx];
+        let size = self.sizes[queue_idx];
+        QueueWriter::new(layout, size, self.data_start + data_offset)
+    }
+
+    /// Get the layout for a queue.
+    pub fn layout(&self, queue_idx: usize) -> QueueLayout {
+        self.layouts[queue_idx]
+    }
+}
+
+/// A test harness wrapping guest memory and a single virtqueue.
+///
+/// For multi-queue tests, use [`TestVirtQueues`] instead.
+pub struct TestVirtQueue {
+    inner: TestVirtQueues,
+    writer: QueueWriter,
+}
+
+impl TestVirtQueue {
+    /// Create a new test virtqueue.
+    ///
+    /// `queue_size` must be a power of 2.
+    pub fn new(queue_size: u16) -> Self {
+        let inner = TestVirtQueues::new(&[VqSize::new(queue_size)]);
+        let writer = inner.writer(0, 0);
+        Self { inner, writer }
+    }
+
+    /// Get the underlying `VirtQueue`.
+    pub fn vq(&self) -> &Arc<VirtQueue> {
+        self.inner.vq(0)
+    }
+
+    /// Get a `MemCtx` guard for directly reading/writing guest memory.
+    pub fn mem(&self) -> impl std::ops::Deref<Target = MemCtx> + '_ {
+        self.inner.mem_acc().access().expect("test mem accessible")
+    }
+
+    /// Add a readable descriptor containing `data`.
+    ///
+    /// Returns the descriptor index.
+    pub fn add_readable(&mut self, data: &[u8]) -> u16 {
+        self.writer.add_readable(self.inner.mem_acc(), data)
+    }
+
+    /// Add a writable descriptor of `len` bytes.
+    ///
+    /// Returns the descriptor index.
+    pub fn add_writable(&mut self, len: u32) -> u16 {
+        self.writer.add_writable(self.inner.mem_acc(), len)
+    }
+
+    /// Link descriptors into a chain by setting NEXT flags.
+    ///
+    /// `descs` should be in order: `[head, ..., tail]`.
+    pub fn chain_descriptors(&mut self, descs: &[u16]) {
+        for i in 0..descs.len().saturating_sub(1) {
+            self.writer.chain(self.inner.mem_acc(), descs[i], descs[i + 1]);
+        }
+    }
+
+    /// Publish a descriptor chain head on the available ring.
+    pub fn publish_avail(&mut self, head: u16) {
+        self.writer.publish_avail(self.inner.mem_acc(), head);
+    }
+
+    /// Read all entries from the used ring.
+    ///
+    /// Returns `(descriptor_id, bytes_written)` pairs.
+    pub fn read_used(&self) -> Vec<(u32, u32)> {
+        let used_idx = self.writer.used_idx(self.inner.mem_acc());
+        (0..used_idx)
+            .map(|i| {
+                let elem = self.writer.read_used_elem(self.inner.mem_acc(), i);
+                (elem.id, elem.len)
+            })
+            .collect()
+    }
+
+    /// Pop a chain from the available ring and return it.
+    pub fn pop_chain(&self) -> Option<(Chain, u16, u32)> {
+        let mem = self.inner.mem_acc().access()?;
+        let mut chain = Chain::with_capacity(64);
+        let (avail_idx, len) = self.vq().pop_avail(&mut chain, &mem)?;
+        Some((chain, avail_idx, len))
+    }
+
+    /// Push a chain back to the used ring.
+    pub fn push_used(&self, chain: &mut Chain) {
+        let mem = self.inner.mem_acc().access().unwrap();
+        self.vq().push_used(chain, &mem);
+    }
+
+    /// Get the GPA of a descriptor's buffer.
+    pub fn desc_addr(&self, idx: u16) -> u64 {
+        let mem = self.inner.mem_acc().access().unwrap();
+        let desc_gpa =
+            self.inner.layout(0).desc_base + u64::from(idx) * DESC_SIZE;
+        let raw: RawDesc = *mem.read(GuestAddr(desc_gpa)).unwrap();
+        raw.addr
+    }
+
+    /// Read raw bytes from guest memory at a given GPA.
+    pub fn read_guest_mem(&self, addr: u64, len: usize) -> Vec<u8> {
+        let mem = self.inner.mem_acc().access().unwrap();
+        let mut buf = vec![0u8; len];
+        let mut guest_buf = crate::common::GuestData::from(buf.as_mut_slice());
+        mem.read_into(GuestAddr(addr), &mut guest_buf, len);
+        buf
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn smoke_pop_avail_readable() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let data = b"hello virtqueue";
+        let d0 = tvq.add_readable(data);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, data.len() as u32);
+
+        let mem = tvq.mem();
+        let mut buf = [0u8; 15];
+        assert!(chain.read(&mut buf, &mem));
+        assert_eq!(&buf, data);
+    }
+
+    #[test]
+    fn smoke_pop_avail_writable() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let d0 = tvq.add_writable(64);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 64);
+
+        let mem = tvq.mem();
+        let payload = b"written by device";
+        assert!(chain.write(payload, &mem));
+        drop(mem);
+
+        tvq.push_used(&mut chain);
+
+        let used = tvq.read_used();
+        assert_eq!(used.len(), 1);
+        assert_eq!(used[0].0, d0 as u32);
+        assert_eq!(used[0].1, payload.len() as u32);
+
+        let addr = tvq.desc_addr(d0);
+        let read_back = tvq.read_guest_mem(addr, payload.len());
+        assert_eq!(read_back, payload);
+    }
+
+    #[test]
+    fn smoke_chained_descriptors() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let header_data = [0xAA; 8];
+        let body_data = [0xBB; 32];
+        let d0 = tvq.add_readable(&header_data);
+        let d1 = tvq.add_readable(&body_data);
+        tvq.chain_descriptors(&[d0, d1]);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) =
+            tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 40);
+
+        let mem = tvq.mem();
+        let mut hdr = [0u8; 8];
+        assert!(chain.read(&mut hdr, &mem));
+        assert_eq!(hdr, header_data);
+
+        let mut body = [0u8; 32];
+        assert!(chain.read(&mut body, &mem));
+        assert_eq!(body, body_data);
+    }
+
+    #[test]
+    fn smoke_mixed_chain() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let req_data = [0x01, 0x02, 0x03, 0x04];
+        let d0 = tvq.add_readable(&req_data);
+        let d1 = tvq.add_writable(128);
+        tvq.chain_descriptors(&[d0, d1]);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 4 + 128);
+
+        let mem = tvq.mem();
+
+        let mut req = [0u8; 4];
+        assert!(chain.read(&mut req, &mem));
+        assert_eq!(req, req_data);
+
+        let resp = [0xFF; 16];
+        assert!(chain.write(&resp, &mem));
+        drop(mem);
+
+        tvq.push_used(&mut chain);
+
+        let addr = tvq.desc_addr(d1);
+        let read_back = tvq.read_guest_mem(addr, 16);
+        assert_eq!(read_back, &resp);
+    }
+
+    #[test]
+    fn empty_avail_ring_returns_none() {
+        let tvq = TestVirtQueue::new(16);
+        assert!(tvq.pop_chain().is_none());
+    }
+
+    #[test]
+    fn multiple_chains() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let d0 = tvq.add_readable(b"first");
+        tvq.publish_avail(d0);
+
+        let d1 = tvq.add_readable(b"second");
+        tvq.publish_avail(d1);
+
+        let (chain0, _, _) = tvq.pop_chain().unwrap();
+        let (chain1, _, _) = tvq.pop_chain().unwrap();
+        assert!(tvq.pop_chain().is_none());
+
+        assert_ne!(chain0.remain_read_bytes(), chain1.remain_read_bytes());
+    }
+
+    #[test]
+    fn multi_queue_smoke() {
+        let tvqs = TestVirtQueues::new(&[
+            VqSize::new(64),
+            VqSize::new(64),
+            VqSize::new(1),
+        ]);
+
+        let mut writer0 = tvqs.writer(0, 0);
+        let mut writer1 = tvqs.writer(1, PAGE_SIZE);
+
+        let d0 = writer0.add_readable(tvqs.mem_acc(), b"queue0");
+        writer0.publish_avail(tvqs.mem_acc(), d0);
+
+        let d1 = writer1.add_readable(tvqs.mem_acc(), b"queue1");
+        writer1.publish_avail(tvqs.mem_acc(), d1);
+
+        // Pop from each queue
+        let mem = tvqs.mem_acc().access().unwrap();
+        let mut chain0 = Chain::with_capacity(64);
+        let mut chain1 = Chain::with_capacity(64);
+
+        assert!(tvqs.vq(0).pop_avail(&mut chain0, &mem).is_some());
+        assert!(tvqs.vq(1).pop_avail(&mut chain1, &mem).is_some());
+
+        assert_eq!(chain0.remain_read_bytes(), 6);
+        assert_eq!(chain1.remain_read_bytes(), 6);
+    }
+
+    #[test]
+    fn queue_writer_reset_cursors() {
+        let tvqs = TestVirtQueues::new(&[VqSize::new(16)]);
+        let mut writer = tvqs.writer(0, 0);
+
+        // Add some descriptors
+        let d0 = writer.add_readable(tvqs.mem_acc(), b"first");
+        writer.publish_avail(tvqs.mem_acc(), d0);
+
+        // Reset and reuse
+        writer.reset_cursors();
+
+        let d1 = writer.add_readable(tvqs.mem_acc(), b"second");
+        assert_eq!(d1, 0, "descriptor index should reset to 0");
+        writer.publish_avail(tvqs.mem_acc(), d1);
+
+        // Both publishes should have worked
+        assert_eq!(writer.used_idx(tvqs.mem_acc()), 0); // Nothing consumed yet
+    }
+}
diff --git a/lib/propolis/src/hw/virtio/vsock.rs b/lib/propolis/src/hw/virtio/vsock.rs
new file mode 100644
index 000000000..3cac84ff9
--- /dev/null
+++ b/lib/propolis/src/hw/virtio/vsock.rs
@@ -0,0 +1,335 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
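+
+//! Virtio socket (vsock) device emulation.
+//!
+//! The device exposes the three virtqueues required by the virtio spec (rx,
+//! tx, and event) and hands guest connections to a [`VsockProxy`] backend.
+//! Device config space consists solely of the 64-bit guest CID, of which
+//! only the low 32 bits are meaningful. Construction mirrors the
+//! propolis-standalone wiring added in this change (a sketch; `log` and the
+//! port mappings come from device config):
+//!
+//! ```ignore
+//! let vsock = PciVirtioSock::new(512, guest_cid, log, port_mappings);
+//! ```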
+
+use lazy_static::lazy_static;
+use slog::Logger;
+use std::sync::Arc;
+
+use crate::accessors::MemAccessor;
+use crate::common::*;
+use crate::hw::pci;
+use crate::hw::virtio;
+use crate::hw::virtio::queue::Chain;
+use crate::hw::virtio::queue::VirtQueue;
+use crate::hw::virtio::queue::VqSize;
+use crate::migrate::*;
+use crate::util::regmap::RegMap;
+use crate::vmm::MemCtx;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketError;
+use crate::vsock::packet::VsockPacketHeader;
+use crate::vsock::proxy::VsockPortMapping;
+use crate::vsock::VsockBackend;
+use crate::vsock::VsockProxy;
+
+use super::pci::PciVirtio;
+use super::pci::PciVirtioState;
+use super::queue::VirtQueues;
+use super::VirtioDevice;
+
+// virtio queue index numbers for virtio socket devices
+pub const VSOCK_RX_QUEUE: u16 = 0x0;
+pub const VSOCK_TX_QUEUE: u16 = 0x1;
+pub const VSOCK_EVENT_QUEUE: u16 = 0x2;
+
+/// A permit representing a reserved rx queue descriptor chain.
+///
+/// This guarantees we have space to send a packet to the guest before reading
+/// data from a host socket, preventing data loss if the queue is full.
+///
+/// The permit holds a mutable reference to `VsockVq`, ensuring only one permit
+/// can exist at a time (enforced at compile time). If dropped without calling
+/// `write`, the chain is retained in `VsockVq` for reuse.
+pub struct RxPermit<'a> {
+    available_data_space: usize,
+    vq: &'a mut VsockVq,
+}
+
+impl RxPermit<'_> {
+    /// Returns the maximum data payload that can fit in this descriptor chain.
+    pub fn available_data_space(&self) -> usize {
+        self.available_data_space
+    }
+
+    pub fn write(self, header: &VsockPacketHeader, data: &[u8]) {
+        let mem = self.vq.acc_mem.access().expect("mem access for write");
+        let queue =
+            self.vq.queues.get(VSOCK_RX_QUEUE as usize).expect("rx queue");
+        let mut chain = self.vq.rx_chain.take().expect("rx_chain should exist");
+
+        chain.write(header, &mem);
+
+        if !data.is_empty() {
+            let mut done = 0;
+            chain.for_remaining_type(false, |addr, len| {
+                let to_write = &data[done..];
+                if let Some(copied) = mem.write_from(addr, to_write, len) {
+                    let need_more = copied != to_write.len();
+                    done += copied;
+                    (copied, need_more)
+                } else {
+                    (0, false)
+                }
+            });
+        }
+
+        queue.push_used(&mut chain, &mem);
+    }
+}
+
+pub struct VsockVq {
+    queues: Vec<Arc<VirtQueue>>,
+    acc_mem: MemAccessor,
+    /// Cached rx chain for permit reuse when dropped without `write`
+    rx_chain: Option<Chain>,
+}
+
+impl VsockVq {
+    pub(crate) fn new(
+        queues: Vec<Arc<VirtQueue>>,
+        acc_mem: MemAccessor,
+    ) -> Self {
+        Self { queues, acc_mem, rx_chain: None }
+    }
+
+    /// Try to acquire a permit for sending a packet to the guest.
+    ///
+    /// Returns `Some(RxPermit)` if a descriptor chain is available,
+    /// `None` if the rx queue is full.
+    pub fn try_rx_permit(&mut self) -> Option<RxPermit<'_>> {
+        // Reuse cached chain or pop a new one
+        if self.rx_chain.is_none() {
+            let mem = self.acc_mem.access()?;
+            let vq = self.queues.get(VSOCK_RX_QUEUE as usize)?;
+            let mut chain = Chain::with_capacity(10);
+            vq.pop_avail(&mut chain, &mem)?;
+            self.rx_chain = Some(chain);
+        }
+
+        let header_size = std::mem::size_of::<VsockPacketHeader>();
+        let available_data_space = self
+            .rx_chain
+            .as_ref()
+            .unwrap()
+            .remain_write_bytes()
+            .saturating_sub(header_size);
+
+        Some(RxPermit { available_data_space, vq: self })
+    }
+
+    /// Receive the next available packet from the TX queue, if any.
+    ///
+    /// Returns the parse result for a single packet. In the future this may
+    /// be refactored to return an iterator over GuestRegions to avoid copying
+    /// packet data.
+    pub fn recv_packet(
+        &self,
+    ) -> Option<Result<VsockPacket, VsockPacketError>> {
+        let mem = self.acc_mem.access()?;
+        let vq = self
+            .queues
+            .get(VSOCK_TX_QUEUE as usize)
+            .expect("vsock has tx queue");
+
+        let mut chain = Chain::with_capacity(10);
+        let Some((_idx, _clen)) = vq.pop_avail(&mut chain, &mem) else {
+            return None;
+        };
+
+        let packet = VsockPacket::parse(&mut chain, &mem);
+        vq.push_used(&mut chain, &mem);
+
+        Some(packet)
+    }
+}
+
+pub struct PciVirtioSock {
+    cid: u32,
+    backend: VsockProxy,
+    virtio_state: PciVirtioState,
+    pci_state: pci::DeviceState,
+}
+
+impl PciVirtioSock {
+    pub fn new(
+        queue_size: u16,
+        cid: u32,
+        log: Logger,
+        port_mappings: Vec<VsockPortMapping>,
+    ) -> Arc<Self> {
+        let queues = VirtQueues::new(&[
+            // VSOCK_RX_QUEUE
+            VqSize::new(queue_size),
+            // VSOCK_TX_QUEUE
+            VqSize::new(queue_size),
+            // VSOCK_EVENT_QUEUE
+            VqSize::new(1),
+        ]);
+
+        // One for rx, tx, event
+        let msix_count = Some(3);
+        let (virtio_state, pci_state) = PciVirtioState::new(
+            virtio::Mode::Transitional,
+            queues,
+            msix_count,
+            virtio::DeviceId::Socket,
+            VIRTIO_VSOCK_CFG_SIZE,
+        );
+
+        let vvq = VsockVq::new(
+            virtio_state.queues.iter().map(Clone::clone).collect(),
+            pci_state.acc_mem.child(Some("vsock rx queue".to_string())),
+        );
+        let port_mappings = port_mappings.into_iter().collect();
+
+        let backend = VsockProxy::new(cid, vvq, log, port_mappings);
+
+        Arc::new(Self { cid, backend, virtio_state, pci_state })
+    }
+
+    // fn _send_transport_reset(&self) {
+    //     let vq = &self.virtio_state.queues.get(VSOCK_EVENT_QUEUE).unwrap();
+    //     let mem = vq.acc_mem.access().unwrap();
+    //     let mut chain = Chain::with_capacity(1);
+
+    //     // Pop a buffer from the event queue
+    //     if let Some((_idx, _clen)) = vq.pop_avail(&mut chain, &mem) {
+    //         // Write the transport reset event
+    //         let event =
+    //             VirtioVsockEvent { id: VIRTIO_VSOCK_EVENT_TRANSPORT_RESET };
+    //         chain.write(&event, &mem);
+
+    //         // Push to used ring (this will also send interrupt to guest)
+    //         vq.push_used(&mut chain, &mem);
+    //     } else {
+    //         eprintln!("no event queue buffer available for transport reset");
+    //     }
+    // }
+}
+
+impl VirtioDevice for PciVirtioSock {
+    fn rw_dev_config(&self, mut rwo: crate::common::RWOp) {
+        VSOCK_DEV_REGS.process(&mut rwo, |id, rwo| match rwo {
+            RWOp::Read(ro) => match id {
+                VsockReg::GuestCid => {
+                    ro.write_u32(self.cid);
+                    // The upper 32 bits are reserved and zeroed.
+                    ro.fill(0);
+                }
+            },
+            RWOp::Write(_) => {}
+        })
+    }
+
+    fn features(&self) -> u64 {
+        VIRTIO_VSOCK_F_STREAM
+    }
+
+    fn set_features(&self, _feat: u64) -> Result<(), ()> {
+        Ok(())
+    }
+
+    fn mode(&self) -> virtio::Mode {
+        virtio::Mode::Transitional
+    }
+
+    fn queue_notify(&self, vq: &VirtQueue) {
+        let _ = self.backend.queue_notify(vq.id);
+    }
+}
+
+// #[repr(C, packed)]
+// #[derive(Copy, Clone, Default, Debug)]
+// struct VirtioVsockEvent {
+//     id: u32,
+// }
+
+impl PciVirtio for PciVirtioSock {
+    fn virtio_state(&self) -> &PciVirtioState {
+        &self.virtio_state
+    }
+    fn pci_state(&self) -> &pci::DeviceState {
+        &self.pci_state
+    }
+}
+
+impl Lifecycle for PciVirtioSock {
+    fn type_name(&self) -> &'static str {
+        "pci-virtio-vsock"
+    }
+    fn reset(&self) {
+        self.virtio_state.reset(self);
+    }
+    fn migrate(&'_ self) -> Migrator<'_> {
+        Migrator::NonMigratable
+    }
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+enum VsockReg {
+    GuestCid,
+}
+
+lazy_static! {
+    static ref VSOCK_DEV_REGS: RegMap<VsockReg> = {
+        let layout = [(VsockReg::GuestCid, 8)];
+        RegMap::create_packed(VIRTIO_VSOCK_CFG_SIZE, &layout, None)
+    };
+}
+
+mod bits {
+    pub const VIRTIO_VSOCK_CFG_SIZE: usize = 0x8;
+
+    pub const VIRTIO_VSOCK_F_STREAM: u64 = 0;
+
+    #[allow(unused)]
+    pub const VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: u32 = 0;
+}
+use bits::*;
+
+impl VsockPacket {
+    // TODO: We may want to consider operating on `Vec<GuestRegion>` to avoid
+    // double copying the packet contents. For now we are reading all of the
+    // packet data at once because it's convenient.
+    fn parse(
+        chain: &mut Chain,
+        mem: &MemCtx,
+    ) -> Result<VsockPacket, VsockPacketError> {
+        let mut packet = VsockPacket::default();
+
+        // Attempt to read the vsock packet header from the descriptor chain
+        // before we can process the full packet.
+        if !chain.read(&mut packet.header, mem) {
+            return Err(VsockPacketError::ChainHeaderRead);
+        }
+
+        // If the packet header indicates there is no data in this packet, then
+        // there's no point in attempting to continue reading from the chain.
+        if packet.header.len() == 0 {
+            return Ok(packet);
+        }
+
+        let len = usize::try_from(packet.header.len())
+            .expect("running on a 64bit platform");
+        packet.data.resize(len, 0);
+
+        let mut done = 0;
+        let copied = chain.for_remaining_type(true, |addr, len| {
+            let mut remain = GuestData::from(&mut packet.data[done..]);
+            if let Some(copied) = mem.read_into(addr, &mut remain, len) {
+                let need_more = copied != remain.len();
+                done += copied;
+                (copied, need_more)
+            } else {
+                (0, false)
+            }
+        });
+
+        if copied != len {
+            return Err(VsockPacketError::InsufficientBytes {
+                expected: len,
+                remaining: copied,
+            });
+        }
+
+        Ok(packet)
+    }
+}
diff --git a/lib/propolis/src/lib.rs b/lib/propolis/src/lib.rs
index c608a816c..5f07fa1bd 100644
--- a/lib/propolis/src/lib.rs
+++ b/lib/propolis/src/lib.rs
@@ -34,6 +34,7 @@ pub mod tasks;
 pub mod util;
 pub mod vcpu;
 pub mod vmm;
+pub mod vsock;
 
 pub use exits::{VmEntry, VmExit};
 pub use vmm::Machine;
diff --git a/lib/propolis/src/vsock/buffer.rs b/lib/propolis/src/vsock/buffer.rs
new file mode 100644
index 000000000..a5599e738
--- /dev/null
+++ b/lib/propolis/src/vsock/buffer.rs
@@ -0,0 +1,210 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
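+
+//! Byte ring buffer used to stage guest -> host stream data.
+//!
+//! A minimal usage sketch (the `Cursor` stands in for a non-blocking host
+//! socket; see the tests below for more):
+//!
+//! ```ignore
+//! use std::{io::Cursor, num::NonZeroUsize};
+//!
+//! let mut vb = VsockBuf::new(NonZeroUsize::new(16).unwrap());
+//! vb.push(b"guest data".to_vec()).expect("fits in buffer");
+//! let mut sink = [0u8; 16];
+//! let written = vb.write_to(&mut Cursor::new(&mut sink[..])).unwrap();
+//! assert_eq!(written, 10);
+//! ```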
+
+use std::num::NonZeroUsize;
+use std::num::Wrapping;
+
+#[derive(Debug, thiserror::Error)]
+pub enum VsockBufError {
+    #[error(
+        "VsockBuf has {remaining} bytes available but tried to push {pushed}"
+    )]
+    InsufficientSpace { pushed: usize, remaining: usize },
+}
+
+/// A ringbuffer used to store guest -> host data
+pub struct VsockBuf {
+    buf: Box<[u8]>,
+    head: Wrapping<usize>,
+    tail: Wrapping<usize>,
+}
+
+impl std::fmt::Debug for VsockBuf {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("VsockBuf")
+            .field("capacity", &self.capacity())
+            .field("head", &self.head)
+            .field("tail", &self.tail)
+            .field("in_use", &self.len())
+            .field("free", &self.free())
+            .finish()
+    }
+}
+
+impl VsockBuf {
+    /// Create a new `VsockBuf`
+    pub fn new(capacity: NonZeroUsize) -> Self {
+        let capacity = capacity.get();
+        Self {
+            buf: vec![0; capacity].into_boxed_slice(),
+            head: Wrapping(0),
+            tail: Wrapping(0),
+        }
+    }
+
+    pub fn capacity(&self) -> usize {
+        self.buf.len()
+    }
+
+    pub fn len(&self) -> usize {
+        (self.head - self.tail).0
+    }
+
+    fn free(&self) -> usize {
+        self.capacity() - self.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.head == self.tail
+    }
+
+    pub fn push(&mut self, data: Vec<u8>) -> Result<(), VsockBufError> {
+        if data.len() > self.free() {
+            return Err(VsockBufError::InsufficientSpace {
+                pushed: data.len(),
+                remaining: self.free(),
+            });
+        }
+
+        let head_offset = self.head.0 % self.buf.len();
+        let available_len = self.buf.len() - head_offset;
+
+        // If the data can fit in the remaining space of the ring buffer,
+        // copy it in one go.
+        if data.len() <= available_len {
+            self.buf[head_offset..head_offset + data.len()]
+                .copy_from_slice(&data);
+        // Otherwise split it and write the remaining data to the front.
+        } else {
+            let (fits, wrapped) = data.split_at(available_len);
+            self.buf[head_offset..].copy_from_slice(fits);
+            self.buf[..wrapped.len()].copy_from_slice(wrapped);
+        }
+
+        self.head += Wrapping(data.len());
+        Ok(())
+    }
+
+    pub fn write_to<W: std::io::Write>(
+        &mut self,
+        writer: &mut W,
+    ) -> std::io::Result<usize> {
+        // If we have no data to write bail early
+        if self.is_empty() {
+            return Ok(0);
+        }
+
+        let tail_offset = self.tail.0 % self.buf.len();
+        let head_offset = self.head.0 % self.buf.len();
+
+        // If the data is contiguous write it in one go
+        let nwritten = if tail_offset < head_offset {
+            writer.write(&self.buf[tail_offset..head_offset])?
+        // Data wraps around so try to write it in batches
+        } else {
+            let available_len = self.buf.len() - tail_offset;
+            let nwritten = writer.write(&self.buf[tail_offset..])?;
+
+            // If we failed to write the entire first segment return early
+            if nwritten < available_len {
+                self.tail += Wrapping(nwritten);
+                return Ok(nwritten);
+            }
+
+            // If we were successful, attempt to continue writing the wrapped
+            // around segment
+            let second_nwritten = writer.write(&self.buf[..head_offset])?;
+            nwritten + second_nwritten
+        };
+
+        self.tail += Wrapping(nwritten);
+        Ok(nwritten)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{io::Cursor, num::NonZeroUsize};
+
+    use crate::vsock::buffer::VsockBuf;
+
+    #[test]
+    fn test_capacity_and_len() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        assert_eq!(vb.capacity(), 10);
+        assert!(vb.is_empty());
+
+        let data = vec![1; 8];
+        let data_len = data.len();
+        assert!(vb.push(data).is_ok());
+        assert!(!vb.is_empty());
+        assert_eq!(vb.capacity(), 10);
+        assert_eq!(vb.len(), data_len);
+    }
+
+    #[test]
+    fn test_push_less_than_capacity() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+    }
+
+    #[test]
+    fn test_push_more_than_capacity() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_err());
+    }
+
+    #[test]
+    fn test_write_to() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 10];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 10];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 10));
+    }
+
+    #[test]
+    fn test_partial_write_to() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 10];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 5];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 5));
+        assert_eq!(vb.len(), 5, "5 bytes remain");
+
+        // reset the cursor and read another chunk
+        cursor.set_position(0);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 5));
+        assert!(vb.is_empty());
+    }
+
+    #[test]
+    fn test_wrap_around() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 4];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 4));
+        assert_eq!(some_socket, [1u8; 4]);
+
+        let data = vec![2; 4];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 8];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 8));
+        assert_eq!(some_socket, [1, 1, 1, 1, 2, 2, 2, 2]);
+    }
+}
diff --git a/lib/propolis/src/vsock/mod.rs b/lib/propolis/src/vsock/mod.rs
new file mode 100644
index 000000000..d6485dfe7
--- /dev/null
+++ b/lib/propolis/src/vsock/mod.rs
@@ -0,0 +1,29 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
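+
+//! Host-side vsock support: packet definitions, the guest -> host ring
+//! buffer, the event-port poller, and the proxy that ties them together.
+//!
+//! A trivial [`VsockBackend`] implementation matching the trait below (a
+//! hypothetical stub, useful for tests):
+//!
+//! ```ignore
+//! struct NullBackend;
+//!
+//! impl VsockBackend for NullBackend {
+//!     fn queue_notify(&self, queue_id: u16) -> Result<(), VsockError> {
+//!         // A real backend wakes its event loop here.
+//!         let _ = queue_id;
+//!         Ok(())
+//!     }
+//! }
+//! ```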
+
+pub mod buffer;
+pub mod packet;
+
+#[cfg(target_os = "illumos")]
+pub mod poller;
+
+#[cfg(not(target_os = "illumos"))]
+#[path = "poller_stub.rs"]
+pub mod poller;
+
+pub mod proxy;
+pub use proxy::VsockProxy;
+
+/// Well-known CID for the host
+pub(crate) const VSOCK_HOST_CID: u64 = 2;
+
+#[derive(Debug, thiserror::Error)]
+pub enum VsockError {
+    #[error("failed to send virt queue notification for queue {0}")]
+    QueueNotify(u16),
+}
+
+pub trait VsockBackend: Send + Sync + 'static {
+    fn queue_notify(&self, queue_id: u16) -> Result<(), VsockError>;
+}
diff --git a/lib/propolis/src/vsock/packet.rs b/lib/propolis/src/vsock/packet.rs
new file mode 100644
index 000000000..92e566d71
--- /dev/null
+++ b/lib/propolis/src/vsock/packet.rs
@@ -0,0 +1,259 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use crate::vsock::VSOCK_HOST_CID;
+
+pub const VIRTIO_VSOCK_OP_REQUEST: VsockPacketOp = 1;
+pub const VIRTIO_VSOCK_OP_RESPONSE: VsockPacketOp = 2;
+pub const VIRTIO_VSOCK_OP_RST: VsockPacketOp = 3;
+pub const VIRTIO_VSOCK_OP_SHUTDOWN: VsockPacketOp = 4;
+pub const VIRTIO_VSOCK_OP_RW: VsockPacketOp = 5;
+pub const VIRTIO_VSOCK_OP_CREDIT_UPDATE: VsockPacketOp = 6;
+pub const VIRTIO_VSOCK_OP_CREDIT_REQUEST: VsockPacketOp = 7;
+type VsockPacketOp = u16;
+
+pub(crate) const VIRTIO_VSOCK_TYPE_STREAM: VsockSocketType = 1;
+type VsockSocketType = u16;
+
+/// Shutdown flags for VIRTIO_VSOCK_OP_SHUTDOWN
+pub const VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE: u32 = 1;
+pub const VIRTIO_VSOCK_SHUTDOWN_F_SEND: u32 = 2;
+
+#[derive(thiserror::Error, Debug)]
+pub enum VsockPacketError {
+    #[error("Failed to read packet header from descriptor chain")]
+    ChainHeaderRead,
+    #[error("Packet only contained {remaining} bytes out of {expected} bytes")]
+    InsufficientBytes { expected: usize, remaining: usize },
+}
+
+#[repr(C, packed)]
+#[derive(Copy, Clone, Default, Debug)]
+pub struct VsockPacketHeader {
+    src_cid: u64,
+    dst_cid: u64,
+    src_port: u32,
+    dst_port: u32,
+    len: u32,
+    // Note this is "type" in the spec
+    socket_type: u16,
+    op: u16,
+    flags: u32,
+    buf_alloc: u32,
+    fwd_cnt: u32,
+}
+
+impl VsockPacketHeader {
+    pub fn src_cid(&self) -> u64 {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
+        u64::from_le(self.src_cid) & u64::from(u32::MAX)
+    }
+
+    pub fn dst_cid(&self) -> u64 {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
+        u64::from_le(self.dst_cid) & u64::from(u32::MAX)
+    }
+
+    pub fn src_port(&self) -> u32 {
+        u32::from_le(self.src_port)
+    }
+
+    pub fn dst_port(&self) -> u32 {
+        u32::from_le(self.dst_port)
+    }
+
+    pub fn len(&self) -> u32 {
+        u32::from_le(self.len)
+    }
+
+    pub fn socket_type(&self) -> u16 {
+        u16::from_le(self.socket_type)
+    }
+
+    pub fn op(&self) -> u16 {
+        u16::from_le(self.op)
+    }
+
+    pub fn flags(&self) -> u32 {
+        u32::from_le(self.flags)
+    }
+
+    pub fn buf_alloc(&self) -> u32 {
+        u32::from_le(self.buf_alloc)
+    }
+
+    pub fn fwd_cnt(&self) -> u32 {
+        u32::from_le(self.fwd_cnt)
+    }
+
+    pub fn set_src_cid(&mut self, cid: u32) -> &mut Self {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
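+        // (The `u32` -> `u64` widening below zero-extends, which keeps the
+        // upper bits zero.)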
+        self.src_cid = cid.to_le() as u64;
+        self
+    }
+
+    pub fn set_dst_cid(&mut self, cid: u32) -> &mut Self {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
+        self.dst_cid = cid.to_le() as u64;
+        self
+    }
+
+    pub fn set_src_port(&mut self, port: u32) -> &mut Self {
+        self.src_port = port.to_le();
+        self
+    }
+
+    pub fn set_dst_port(&mut self, port: u32) -> &mut Self {
+        self.dst_port = port.to_le();
+        self
+    }
+
+    pub fn set_len(&mut self, len: u32) -> &mut Self {
+        self.len = len.to_le();
+        self
+    }
+
+    pub fn set_socket_type(&mut self, socket_type: u16) -> &mut Self {
+        self.socket_type = socket_type.to_le();
+        self
+    }
+
+    pub fn set_op(&mut self, op: u16) -> &mut Self {
+        self.op = op.to_le();
+        self
+    }
+
+    pub fn set_flags(&mut self, flags: u32) -> &mut Self {
+        self.flags = flags.to_le();
+        self
+    }
+
+    pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self {
+        self.buf_alloc = buf_alloc.to_le();
+        self
+    }
+
+    pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self {
+        self.fwd_cnt = fwd_cnt.to_le();
+        self
+    }
+}
+
+#[derive(Default, Debug)]
+pub struct VsockPacket {
+    pub(crate) header: VsockPacketHeader,
+    pub(crate) data: Vec<u8>,
+}
+
+impl VsockPacket {
+    pub fn new_reset(guest_cid: u32, src_port: u32, dst_port: u32) -> Self {
+        let mut header = VsockPacketHeader::default();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_RST)
+            .set_buf_alloc(0)
+            .set_fwd_cnt(0);
+
+        VsockPacket { header, data: Vec::new() }
+    }
+
+    pub fn new_response(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        buf_alloc: u32,
+    ) -> Self {
+        let mut header = VsockPacketHeader::default();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_RESPONSE)
+            .set_buf_alloc(buf_alloc)
+            .set_fwd_cnt(0);
+
+        VsockPacket { header, data: Vec::new() }
+    }
+
+    pub fn new_rw(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        buf_alloc: u32,
+        fwd_cnt: u32,
+        data: Vec<u8>,
+    ) -> Self {
+        let mut header = VsockPacketHeader::default();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_len(data.len() as u32)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_RW)
+            .set_buf_alloc(buf_alloc)
+            .set_fwd_cnt(fwd_cnt);
+
+        VsockPacket { header, data }
+    }
+
+    pub fn new_credit_update(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        buf_alloc: u32,
+        fwd_cnt: u32,
+    ) -> Self {
+        let mut header = VsockPacketHeader::default();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_CREDIT_UPDATE)
+            .set_buf_alloc(buf_alloc)
+            .set_fwd_cnt(fwd_cnt);
+
+        VsockPacket { header, data: Vec::new() }
+    }
+
+    pub fn new_shutdown(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        flags: u32,
+    ) -> Self {
+        let mut header = VsockPacketHeader::default();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_SHUTDOWN)
+            .set_flags(flags);
+
+        VsockPacket { header, data: Vec::new() }
+    }
+}
diff --git a/lib/propolis/src/vsock/poller.rs b/lib/propolis/src/vsock/poller.rs
new file mode 100644
index 000000000..56376b637
--- /dev/null
+++ b/lib/propolis/src/vsock/poller.rs
@@ -0,0 +1,1732 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::collections::VecDeque;
+use std::ffi::c_void;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::mem::MaybeUninit;
+use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
+use std::sync::Arc;
+use std::thread::JoinHandle;
+
+use bitflags::bitflags;
+use iddqd::IdHashMap;
+use slog::{debug, error, info, warn, Logger};
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::hw::virtio::vsock::VSOCK_RX_QUEUE;
+use crate::hw::virtio::vsock::VSOCK_TX_QUEUE;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketHeader;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_CREDIT_REQUEST;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_CREDIT_UPDATE;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_REQUEST;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_RST;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_RW;
+use crate::vsock::packet::VIRTIO_VSOCK_OP_SHUTDOWN;
+use crate::vsock::packet::VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE;
+use crate::vsock::packet::VIRTIO_VSOCK_SHUTDOWN_F_SEND;
+use crate::vsock::packet::VIRTIO_VSOCK_TYPE_STREAM;
+use crate::vsock::proxy::ConnKey;
+use crate::vsock::proxy::VsockPortMapping;
+use crate::vsock::proxy::VsockProxyConn;
+use crate::vsock::proxy::CONN_TX_BUF_SIZE;
+use crate::vsock::VSOCK_HOST_CID;
+
+#[repr(usize)]
+enum VsockEvent {
+    TxQueue = 0,
+    RxQueue,
+    Shutdown,
+}
+
+pub struct VsockPollerNotify {
+    port_fd: Arc<OwnedFd>,
+}
+
+impl VsockPollerNotify {
+    fn port_fd(&self) -> BorrowedFd<'_> {
+        self.port_fd.as_fd()
+    }
+
+    fn port_send(&self, event: VsockEvent) -> std::io::Result<()> {
+        let ret = unsafe {
+            libc::port_send(self.port_fd().as_raw_fd(), 0, event as usize as _)
+        };
+
+        if ret == 0 {
+            Ok(())
+        } else {
+            Err(std::io::Error::last_os_error())
+        }
+    }
+
+    pub fn queue_notify(&self, id: u16) -> std::io::Result<()> {
+        match id {
+            VSOCK_RX_QUEUE => self.port_send(VsockEvent::RxQueue),
+            VSOCK_TX_QUEUE => self.port_send(VsockEvent::TxQueue),
+            _ => Ok(()),
+        }
+    }
+
+    pub fn shutdown(&self) -> std::io::Result<()> {
+        self.port_send(VsockEvent::Shutdown)
+    }
+}
+
+bitflags! {
+    #[repr(transparent)]
+    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
+    pub struct PollEvents: i32 {
+        const IN = libc::POLLIN as i32;
+        const PRI = libc::POLLPRI as i32;
+        const OUT = libc::POLLOUT as i32;
+        const ERR = libc::POLLERR as i32;
+        const HUP = libc::POLLHUP as i32;
+    }
+}
+
+/// Set of `PollEvents` that signifies a readable event.
+fn fd_readable() -> PollEvents {
+    PollEvents::IN | PollEvents::HUP | PollEvents::ERR | PollEvents::PRI
+}
+
+/// Set of `PollEvents` that signifies a writable event.
+fn fd_writable() -> PollEvents {
+    PollEvents::OUT | PollEvents::HUP | PollEvents::ERR
+}
+
+#[derive(Debug)]
+enum RxEvent {
+    /// Vsock RST packet
+    Reset(ConnKey),
+    /// Vsock RESPONSE packet
+    NewConnection(ConnKey),
+    /// Vsock CREDIT_UPDATE packet
+    CreditUpdate(ConnKey),
+}
+
+pub struct VsockPoller {
+    log: Logger,
+    /// The guest context id
+    guest_cid: u32,
+    /// Port mappings we are proxying packets to and from
+    port_mappings: IdHashMap<VsockPortMapping>,
+    /// The event port fd.
+    port_fd: Arc<OwnedFd>,
+    /// The virtqueues associated with the vsock device
+    queues: VsockVq,
+    /// The connection map of guest connected streams
+    connections: HashMap<ConnKey, VsockProxyConn>,
+    /// Queue of vsock packets that need to be sent to the guest
+    rx: VecDeque<RxEvent>,
+    /// Connections blocked waiting for rx queue descriptors
+    rx_blocked: Vec<ConnKey>,
+}
+
+impl VsockPoller {
+    /// Create a new `VsockPoller`.
+    ///
+    /// This poller is responsible for driving virtio-socket connections
+    /// between the guest VM and host sockets.
+    pub fn new(
+        cid: u32,
+        queues: VsockVq,
+        log: Logger,
+        port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> std::io::Result<Self> {
+        let port_fd = unsafe {
+            let fd = match libc::port_create() {
+                -1 => return Err(std::io::Error::last_os_error()),
+                fd => fd,
+            };
+
+            // Set CLOEXEC on the event port fd
+            if libc::fcntl(
+                fd,
+                libc::F_SETFD,
+                libc::fcntl(fd, libc::F_GETFD) | libc::FD_CLOEXEC,
+            ) < 0
+            {
+                return Err(std::io::Error::last_os_error());
+            };
+
+            fd
+        };
+
+        info!(
+            &log,
+            "vsock poller configured with";
+            "mappings" => ?port_mappings,
+        );
+
+        Ok(Self {
+            log,
+            guest_cid: cid,
+            port_mappings,
+            port_fd: Arc::new(unsafe { OwnedFd::from_raw_fd(port_fd) }),
+            queues,
+            connections: Default::default(),
+            rx: Default::default(),
+            rx_blocked: Default::default(),
+        })
+    }
+
+    /// Get a handle to a `VsockPollerNotify`.
+    pub fn notify_handle(&self) -> VsockPollerNotify {
+        VsockPollerNotify { port_fd: Arc::clone(&self.port_fd) }
+    }
+
+    /// Start the event loop.
+    pub fn run(mut self) -> JoinHandle<()> {
+        std::thread::Builder::new()
+            .name("vsock-event-loop".to_string())
+            .spawn(move || self.handle_events())
+            .expect("failed to spawn vsock event loop")
+    }
+
+    /// Handle the guest's VIRTIO_VSOCK_OP_REQUEST packet.
+    fn handle_connection_request(&mut self, key: ConnKey, packet: VsockPacket) {
+        if self.connections.contains_key(&key) {
+            // Connection already exists
+            self.send_conn_rst(key);
+            return;
+        }
+
+        let Some(mapping) = self.port_mappings.get(&packet.header.dst_port())
+        else {
+            // Drop the unknown connection so that it times out in the guest.
+            debug!(
+                &self.log,
+                "dropping connect request to unknown mapping";
+                "packet" => ?packet,
+            );
+            return;
+        };
+
+        match VsockProxyConn::new(mapping.addr()) {
+            Ok(mut conn) => {
+                conn.update_peer_credit(&packet.header);
+                self.connections.insert(key, conn);
+                self.rx.push_back(RxEvent::NewConnection(key));
+            }
+            Err(e) => {
+                self.send_conn_rst(key);
+                error!(self.log, "{e}");
+            }
+        };
+    }
+
+    /// Handle the guest's VIRTIO_VSOCK_OP_SHUTDOWN packet.
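+    ///
+    /// `flags` may combine `VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE` (the guest will
+    /// accept no more data) and `VIRTIO_VSOCK_SHUTDOWN_F_SEND` (the guest
+    /// will send no more data). Once both directions are shut down and no
+    /// guest data remains buffered, the connection is removed and a RST is
+    /// queued.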
+    fn handle_shutdown(&mut self, key: ConnKey, flags: u32) {
+        if let Entry::Occupied(mut entry) = self.connections.entry(key) {
+            let conn = entry.get_mut();
+
+            // Guest won't receive more data
+            if flags & VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE != 0 {
+                if let Err(e) = conn.shutdown_guest_read() {
+                    error!(
+                        &self.log,
+                        "cannot transition vsock connection state: {e}";
+                        "conn" => ?conn,
+                    );
+                    entry.remove();
+                    self.send_conn_rst(key);
+                    return;
+                };
+            }
+            // Guest won't send more data
+            if flags & VIRTIO_VSOCK_SHUTDOWN_F_SEND != 0 {
+                if let Err(e) = conn.shutdown_guest_write() {
+                    error!(
+                        &self.log,
+                        "cannot transition vsock connection state: {e}";
+                        "conn" => ?conn,
+                    );
+                    entry.remove();
+                    self.send_conn_rst(key);
+                    return;
+                };
+            }
+            // XXX how do we register this for future cleanup if there is data
+            // we have not synced locally yet? We need a cleanup loop...
+            if conn.should_close() {
+                if !conn.has_buffered_data() {
+                    self.connections.remove(&key);
+                    // virtio spec states:
+                    //
+                    // Clean disconnect is achieved by one or more
+                    // VIRTIO_VSOCK_OP_SHUTDOWN packets that indicate no
+                    // more data will be sent and received, followed by a
+                    // VIRTIO_VSOCK_OP_RST response from the peer.
+                    self.send_conn_rst(key);
+                }
+            }
+        }
+    }
+
+    /// Handle the guest's VIRTIO_VSOCK_OP_RW packet.
+    fn handle_rw_packet(&mut self, key: ConnKey, packet: VsockPacket) {
+        if let Entry::Occupied(mut entry) = self.connections.entry(key) {
+            let conn = entry.get_mut();
+
+            // If we have a valid connection attempt to consume the guest's
+            // packet.
+            if let Err(e) = conn.recv_packet(packet) {
+                error!(
+                    &self.log,
+                    "failed to push vsock packet data into the conn vbuf: {e}";
+                    "conn" => ?conn,
+                );
+
+                entry.remove();
+                self.send_conn_rst(key);
+                return;
+            }
+
+            if let Some(interests) = conn.poll_interests() {
+                let fd = conn.get_fd();
+                self.associate_fd(key, fd, interests);
+            }
+        };
+    }
+
+    /// Handle the guest's tx virtqueue.
+    fn handle_tx_queue_event(&mut self) {
+        loop {
+            let packet = match self.queues.recv_packet().transpose() {
+                Ok(Some(packet)) => packet,
+                // No more packets on the guest's tx queue
+                Ok(None) => break,
+                Err(e) => {
+                    warn!(&self.log, "dropping invalid vsock packet: {e}");
+                    continue;
+                }
+            };
+
+            // If the packet is not destined for the host drop it.
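+            // (CID 2 is the well-known host address; we do not route
+            // guest-to-guest traffic.)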
+            if packet.header.dst_cid() != VSOCK_HOST_CID {
+                debug!(
+                    &self.log,
+                    "dropping vsock packet not destined for the host";
+                    "packet" => ?packet,
+                );
+                continue;
+            }
+
+            let key = ConnKey {
+                host_port: packet.header.dst_port(),
+                guest_port: packet.header.src_port(),
+            };
+
+            // We only support stream connections
+            if packet.header.socket_type() != VIRTIO_VSOCK_TYPE_STREAM {
+                self.send_conn_rst(key);
+                warn!(&self.log,
+                    "received invalid vsock packet";
+                    "type" => %packet.header.socket_type(),
+                );
+                continue;
+            }
+
+            if let Some(conn) = self.connections.get_mut(&key) {
+                // Regardless of the vsock operation we need to record the
+                // peer's credit info
+                conn.update_peer_credit(&packet.header);
+                match packet.header.op() {
+                    VIRTIO_VSOCK_OP_RST => {
+                        self.connections.remove(&key);
+                    }
+                    VIRTIO_VSOCK_OP_SHUTDOWN => {
+                        self.handle_shutdown(key, packet.header.flags());
+                    }
+                    // Handled above for every packet
+                    VIRTIO_VSOCK_OP_CREDIT_UPDATE => continue,
+                    VIRTIO_VSOCK_OP_CREDIT_REQUEST => {
+                        if self.connections.contains_key(&key) {
+                            self.rx.push_back(RxEvent::CreditUpdate(key));
+                        }
+                    }
+                    VIRTIO_VSOCK_OP_RW => {
+                        self.handle_rw_packet(key, packet);
+                    }
+                    _ => {
+                        warn!(
+                            &self.log,
+                            "received vsock packet with unknown op code";
+                            "packet" => ?packet,
+                        );
+                    }
+                }
+            } else {
+                match packet.header.op() {
+                    VIRTIO_VSOCK_OP_REQUEST => {
+                        self.handle_connection_request(key, packet)
+                    }
+                    VIRTIO_VSOCK_OP_RST => {}
+                    _ => {
+                        warn!(
+                            &self.log,
+                            "received a vsock packet for an unknown connection \
+                            that was not a REQUEST or RST";
+                            "packet" => ?packet,
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    /// Process the rx virtqueue (host -> guest).
+    fn handle_rx_queue_event(&mut self) {
+        // Now that more descriptors have become available for sending vsock
+        // packets attempt to drain pending packets
+        self.process_pending_rx();
+
+        // Re-register connections that were blocked waiting for rx queue
+        // space. It would be nice if we had a hint of how many descriptors
+        // became available but that's not the case today.
+        for key in std::mem::take(&mut self.rx_blocked).drain(..) {
+            if let Some(conn) = self.connections.get(&key) {
+                if let Some(interests) = conn.poll_interests() {
+                    let fd = conn.get_fd();
+                    self.associate_fd(key, fd, interests);
+                }
+            }
+        }
+    }
+
+    // Attempt to send any queued rx packets destined for the guest.
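+    // Each permit corresponds to one guest-posted rx descriptor chain, so we
+    // stop as soon as either the backlog or the rx queue runs dry.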
+    fn process_pending_rx(&mut self) {
+        while let Some(permit) = self.queues.try_rx_permit() {
+            let Some(rx_event) = self.rx.pop_front() else {
+                break;
+            };
+
+            match rx_event {
+                RxEvent::Reset(key) => {
+                    let packet = VsockPacket::new_reset(
+                        self.guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                    );
+                    permit.write(&packet.header, &packet.data);
+                }
+                RxEvent::NewConnection(key) => {
+                    let packet = VsockPacket::new_response(
+                        self.guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                        CONN_TX_BUF_SIZE as u32,
+                    );
+                    permit.write(&packet.header, &packet.data);
+
+                    if let Entry::Occupied(mut entry) =
+                        self.connections.entry(key)
+                    {
+                        let conn = entry.get_mut();
+                        if let Err(e) = conn.set_established() {
+                            error!(
+                                &self.log,
+                                "cannot transition vsock connection state: {e}";
+                                "conn" => ?conn,
+                            );
+                            entry.remove();
+                            self.send_conn_rst(key);
+                            continue;
+                        };
+
+                        if let Some(interests) = conn.poll_interests() {
+                            let fd = conn.get_fd();
+                            self.associate_fd(key, fd, interests);
+                        }
+                    }
+                }
+                RxEvent::CreditUpdate(key) => {
+                    if let Some(conn) = self.connections.get_mut(&key) {
+                        let packet = VsockPacket::new_credit_update(
+                            self.guest_cid,
+                            key.host_port,
+                            key.guest_port,
+                            conn.buf_alloc(),
+                            conn.fwd_cnt(),
+                        );
+                        permit.write(&packet.header, &packet.data);
+                        conn.mark_credit_sent();
+                    }
+                }
+            }
+        }
+    }
+
+    /// Handle a user event. Returns `true` if the event loop should shut down.
+    fn handle_user_event(&mut self, event: PortEvent) -> bool {
+        match event.user {
+            val if val == VsockEvent::TxQueue as usize => {
+                self.handle_tx_queue_event()
+            }
+            val if val == VsockEvent::RxQueue as usize => {
+                self.handle_rx_queue_event()
+            }
+            val if val == VsockEvent::Shutdown as usize => return true,
+            _ => (),
+        }
+        false
+    }
+
+    /// Handle an fd event by flushing data to the underlying socket from the
+    /// connection's [`VsockBuf`], and by reading data from the socket and
+    /// sending it to the guest as a `VIRTIO_VSOCK_OP_RW` packet.
+    fn handle_fd_event(&mut self, event: PortEvent, read_buf: &mut [u8]) {
+        let key = ConnKey::from_portev_user(event.user);
+        let events = PollEvents::from_bits_retain(event.events);
+
+        if fd_writable().intersects(events) {
+            self.handle_writable_fd(key);
+        }
+
+        if fd_readable().intersects(events) {
+            self.handle_readable_fd(key, read_buf);
+        }
+    }
+
+    /// When an fd is writable, drain buffered guest data to the host socket.
+    fn handle_writable_fd(&mut self, key: ConnKey) {
+        let Some(conn) = self.connections.get_mut(&key) else {
+            return;
+        };
+
+        loop {
+            match conn.flush() {
+                Ok(0) => break,
+                Ok(nbytes) => {
+                    conn.update_fwd_cnt(nbytes as u32);
+                    if conn.needs_credit_update() {
+                        self.rx.push_back(RxEvent::CreditUpdate(key));
+                    }
+                }
+                Err(e) if e.kind() == ErrorKind::WouldBlock => break,
+                Err(e) => {
+                    error!(&self.log, "error writing to socket: {e}");
+                    break;
+                }
+            }
+        }
+
+        // We have finished draining our buffered data to the host, so check if
+        // we should remove ourselves from the active connections.
+        if conn.should_close() && !conn.has_buffered_data() {
+            self.connections.remove(&key);
+            self.send_conn_rst(key);
+            return;
+        }
+
+        if let Some(interests) = conn.poll_interests() {
+            let fd = conn.get_fd();
+            self.associate_fd(key, fd, interests);
+        }
+    }
+
+    /// When an fd is readable, read from host socket and send to guest.
+    fn handle_readable_fd(&mut self, key: ConnKey, read_buf: &mut [u8]) {
+        let VsockPoller { queues, connections, guest_cid, rx_blocked, .. } =
+            self;
+
+        let Some(conn) = connections.get_mut(&key) else {
+            return;
+        };
+
+        // The guest is no longer expecting any data
+        if !conn.guest_can_read() {
+            return;
+        }
+
+        loop {
+            let Some(permit) = queues.try_rx_permit() else {
+                rx_blocked.push(key);
+                break;
+            };
+
+            let credit = conn.peer_credit();
+            if credit == 0 {
+                // TODO: when this happens under sufficient load there's the
+                // possibility we wake up the event loop repeatedly, and we
+                // should defer associating this fd again until there's enough
+                // credit. This is similar to the `rx_blocked` queue but
+                // slightly different.
+                break;
+            }
+
+            let max_read = std::cmp::min(
+                permit.available_data_space(),
+                std::cmp::min(credit as usize, read_buf.len()),
+            );
+
+            match conn.socket.read(&mut read_buf[..max_read]) {
+                Ok(0) => {
+                    // TODO: the guest is supposed to send us a RST to finalize
+                    // the shutdown. We need to put this on a quiesce queue so
+                    // that we don't leave a half-open connection lying around
+                    // in our connection map.
+                    let packet = VsockPacket::new_shutdown(
+                        *guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                        VIRTIO_VSOCK_SHUTDOWN_F_SEND
+                            | VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE,
+                    );
+                    permit.write(&packet.header, &packet.data);
+                    return;
+                }
+                Ok(nbytes) => {
+                    conn.update_tx_cnt(nbytes as u32);
+                    let mut header = VsockPacketHeader::default();
+                    header
+                        .set_src_cid(VSOCK_HOST_CID as u32)
+                        .set_dst_cid(*guest_cid)
+                        .set_src_port(key.host_port)
+                        .set_dst_port(key.guest_port)
+                        .set_len(nbytes as u32)
+                        .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+                        .set_op(VIRTIO_VSOCK_OP_RW)
+                        .set_buf_alloc(conn.buf_alloc())
+                        .set_fwd_cnt(conn.fwd_cnt());
+                    permit.write(&header, &read_buf[..nbytes]);
+                }
+                Err(e) if e.kind() == ErrorKind::WouldBlock => break,
+                Err(e) => {
+                    error!(
+                        &self.log,
+                        "vsock backend socket read failed: {e}";
+                        "key" => ?key,
+                        "conn" => ?conn,
+                    );
+
+                    connections.remove(&key);
+                    let packet = VsockPacket::new_reset(
+                        *guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                    );
+                    permit.write(&packet.header, &packet.data);
+                    return;
+                }
+            }
+        }
+
+        if let Some(interests) = conn.poll_interests() {
+            let fd = conn.get_fd();
+            self.associate_fd(key, fd, interests);
+        }
+    }
+
+    /// Associate a connection's underlying socket fd with our event port fd.
+    fn associate_fd(&mut self, key: ConnKey, fd: RawFd, interests: PollEvents) {
+        let ret = unsafe {
+            libc::port_associate(
+                self.port_fd.as_raw_fd(),
+                libc::PORT_SOURCE_FD,
+                fd as usize,
+                interests.bits(),
+                key.to_portev_user() as *mut c_void,
+            )
+        };
+
+        if ret < 0 {
+            let err = std::io::Error::last_os_error();
+            if let Some(conn) = self.connections.remove(&key) {
+                error!(
+                    &self.log,
+                    "vsock port_associate failed: {err}";
+                    "key" => ?key,
+                    "conn" => ?conn,
+                );
+                self.send_conn_rst(key);
+            }
+        }
+    }
+
+    /// Enqueue a RST packet for the provided [`ConnKey`].
+    fn send_conn_rst(&mut self, key: ConnKey) {
+        self.rx.push_back(RxEvent::Reset(key));
+    }
+
+    /// This is the vsock event loop. It's responsible for handling vsock
+    /// packets to and from the guest.
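+    ///
+    /// Each iteration blocks in `port_getn(3C)` waiting for user events
+    /// (queue notifications and shutdown) and fd events on the proxied host
+    /// sockets, dispatches them, and then drains any rx packets that became
+    /// ready as a result.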
+    fn handle_events(&mut self) {
+        const MAX_EVENTS: u32 = 32;
+
+        let mut events: [MaybeUninit<libc::port_event>; MAX_EVENTS as usize] =
+            [const { MaybeUninit::uninit() }; MAX_EVENTS as usize];
+        let mut read_buf = vec![0u8; 1024 * 64];
+
+        loop {
+            let mut nget = 1;
+
+            let ret = unsafe {
+                libc::port_getn(
+                    self.port_fd.as_raw_fd(),
+                    events.as_mut_ptr() as *mut libc::port_event,
+                    MAX_EVENTS,
+                    &mut nget,
+                    std::ptr::null_mut(),
+                )
+            };
+
+            if ret < 0 {
+                let err = std::io::Error::last_os_error();
+                // The docs state that `raw_os_error` will always return a
+                // `Some` variant when obtained via `last_os_error`.
+                match err.raw_os_error().unwrap() {
+                    // A signal was caught, so process the loop again
+                    libc::EINTR => continue,
+                    libc::EBADF | libc::EBADFD => {
+                        // This means our event loop is effectively no
+                        // longer serviceable and the vsock device is useless.
+                        error!(
+                            &self.log,
+                            "vsock port fd is no longer valid: {err}"
+                        );
+                        return;
+                    }
+                    _ => {
+                        error!(&self.log, "vsock port_getn failed: {err}");
+                        continue;
+                    }
+                }
+            }
+
+            for i in 0..nget as usize {
+                let event = PortEvent::from_raw(unsafe {
+                    events[i].assume_init_read()
+                });
+
+                match event.source {
+                    EventSource::User => {
+                        let should_shutdown = self.handle_user_event(event);
+                        if should_shutdown {
+                            return;
+                        }
+                    }
+                    EventSource::Fd => {
+                        self.handle_fd_event(event, &mut read_buf);
+                    }
+                    _ => {}
+                }
+            }
+
+            // Process any pending rx events
+            self.process_pending_rx();
+        }
+    }
+}
+
+/// The source of a port event.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum EventSource {
+    /// User event, i.e. `port_send(3C)`
+    User,
+    /// File descriptor event
+    Fd,
+    /// Unknown source for the vsock backend
+    Unknown(u16),
+}
+
+impl EventSource {
+    fn from_raw(source: u16) -> Self {
+        match source as i32 {
+            libc::PORT_SOURCE_USER => EventSource::User,
+            libc::PORT_SOURCE_FD => EventSource::Fd,
+            _ => EventSource::Unknown(source),
+        }
+    }
+}
+
+/// A port event retrieved from an event port.
+///
+/// This represents an event from one of the various event sources (file
+/// descriptors, timers, user events, etc.).
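+///
+/// It is an owned snapshot of the `libc::port_event` returned by
+/// `port_getn(3C)`, with `portev_user` widened to a `usize` so it can be
+/// unpacked into a [`ConnKey`] or matched against `VsockEvent` values.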
+#[derive(Debug, Clone)]
+struct PortEvent {
+    /// The events that occurred (source-specific)
+    events: i32,
+    /// The source of the event
+    source: EventSource,
+    /// The object associated with the event (interpretation depends on source)
+    #[allow(dead_code)]
+    object: usize,
+    /// User-defined data provided during association
+    user: usize,
+}
+
+impl PortEvent {
+    fn from_raw(event: libc::port_event) -> Self {
+        PortEvent {
+            events: event.portev_events,
+            source: EventSource::from_raw(event.portev_source),
+            object: event.portev_object,
+            user: event.portev_user as usize,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io::{Read, Write};
+    use std::net::TcpListener;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    use iddqd::IdHashMap;
+
+    use crate::hw::virtio::testutil::{QueueWriter, TestVirtQueues, VqSize};
+    use crate::hw::virtio::vsock::{VsockVq, VSOCK_RX_QUEUE, VSOCK_TX_QUEUE};
+    use crate::vsock::packet::{
+        VsockPacketHeader, VIRTIO_VSOCK_OP_CREDIT_UPDATE,
+        VIRTIO_VSOCK_OP_REQUEST, VIRTIO_VSOCK_OP_RESPONSE, VIRTIO_VSOCK_OP_RST,
+        VIRTIO_VSOCK_OP_RW, VIRTIO_VSOCK_OP_SHUTDOWN,
+        VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE, VIRTIO_VSOCK_SHUTDOWN_F_SEND,
+        VIRTIO_VSOCK_TYPE_STREAM,
+    };
+    use crate::vsock::proxy::{VsockPortMapping, CONN_TX_BUF_SIZE};
+    use crate::vsock::VSOCK_HOST_CID;
+
+    use super::VsockPoller;
+
+    fn test_logger() -> slog::Logger {
+        use slog::Drain;
+        let decorator = slog_term::TermDecorator::new().stderr().build();
+        let drain = slog_term::FullFormat::new(decorator).build().fuse();
+        let drain = slog_async::Async::new(drain).build().fuse();
+        slog::Logger::root(drain, slog::o!("component" => "vsock-test"))
+    }
+
+    const QUEUE_SIZE: u16 = 64;
+    const PAGE_SIZE: u64 = 0x1000;
+
+    /// Bind a TCP listener on an ephemeral port and return it along with an
+    /// `IdHashMap` that maps `vsock_port` to the listener's actual address.
+    fn bind_test_backend(
+        vsock_port: u32,
+    ) -> (TcpListener, IdHashMap<VsockPortMapping>) {
+        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
+        let addr = listener.local_addr().unwrap();
+        let mut backends = IdHashMap::new();
+        backends.insert_overwrite(VsockPortMapping::new(vsock_port, addr));
+        (listener, backends)
+    }
+
+    /// Test harness for vsock poller tests using shared testutil infrastructure.
+    struct VsockTestHarness {
+        tvqs: TestVirtQueues,
+        rx_writer: QueueWriter,
+        tx_writer: QueueWriter,
+    }
+
+    impl VsockTestHarness {
+        fn new() -> Self {
+            let tvqs = TestVirtQueues::new(&[
+                VqSize::new(QUEUE_SIZE), // RX
+                VqSize::new(QUEUE_SIZE), // TX
+                VqSize::new(1),          // Event
+            ]);
+
+            // RX and TX use separate data regions
+            let rx_writer = tvqs.writer(VSOCK_RX_QUEUE as usize, 0);
+            let tx_writer =
+                tvqs.writer(VSOCK_TX_QUEUE as usize, PAGE_SIZE * 16);
+
+            Self { tvqs, rx_writer, tx_writer }
+        }
+
+        fn make_vsock_vq(&self) -> VsockVq {
+            let queues: Vec<_> =
+                self.tvqs.queues().iter().cloned().collect();
+            let acc = self.tvqs.mem_acc().child(Some("vsock-vq".to_string()));
+            VsockVq::new(queues, acc)
+        }
+
+        /// Add a writable descriptor to the RX queue and publish it.
+        fn add_rx_writable(&mut self, len: u32) -> u16 {
+            let d = self.rx_writer.add_writable(self.tvqs.mem_acc(), len);
+            self.rx_writer.publish_avail(self.tvqs.mem_acc(), d);
+            d
+        }
+
+        /// Add a readable descriptor to the TX queue.
+        fn add_tx_readable(&mut self, data: &[u8]) -> u16 {
+            self.tx_writer.add_readable(self.tvqs.mem_acc(), data)
+        }
+
+        /// Publish a descriptor on the TX queue.
+        fn publish_tx(&mut self, head: u16) {
+            self.tx_writer.publish_avail(self.tvqs.mem_acc(), head);
+        }
+
+        /// Chain two TX descriptors together.
+        fn chain_tx(&mut self, from: u16, to: u16) {
+            self.tx_writer.chain(self.tvqs.mem_acc(), from, to);
+        }
+
+        /// Reset TX writer cursors for reuse.
+        fn reset_tx_cursors(&mut self) {
+            self.tx_writer.reset_cursors();
+        }
+
+        /// Reset RX writer cursors for reuse.
+        fn reset_rx_cursors(&mut self) {
+            self.rx_writer.reset_cursors();
+        }
+
+        /// Read a vsock packet header and data from a used ring entry.
+        fn read_vsock_packet(
+            &self,
+            used_index: u16,
+        ) -> (VsockPacketHeader, Vec<u8>) {
+            let mem_acc = self.tvqs.mem_acc();
+            let elem = self.rx_writer.read_used_elem(mem_acc, used_index);
+            let desc_id = elem.id as u16;
+            let total_len = elem.len as usize;
+
+            // Read the entire buffer (header + data)
+            let buf =
+                self.rx_writer.read_desc_data(mem_acc, desc_id, total_len);
+
+            // Parse header from the first bytes
+            let hdr_size = std::mem::size_of::<VsockPacketHeader>();
+            let mut hdr = VsockPacketHeader::default();
+            unsafe {
+                std::ptr::copy_nonoverlapping(
+                    buf.as_ptr(),
+                    &mut hdr as *mut VsockPacketHeader as *mut u8,
+                    hdr_size,
+                );
+            }
+
+            // Data is everything after the header
+            let data = buf[hdr_size..].to_vec();
+
+            (hdr, data)
+        }
+
+        fn rx_used_idx(&self) -> u16 {
+            self.rx_writer.used_idx(self.tvqs.mem_acc())
+        }
+
+        fn tx_used_idx(&self) -> u16 {
+            self.tx_writer.used_idx(self.tvqs.mem_acc())
+        }
+    }
+
+    /// Helper: serialize a VsockPacketHeader to bytes.
+    fn hdr_as_bytes(hdr: &VsockPacketHeader) -> &[u8] {
+        unsafe {
+            std::slice::from_raw_parts(
+                hdr as *const VsockPacketHeader as *const u8,
+                std::mem::size_of::<VsockPacketHeader>(),
+            )
+        }
+    }
+
+    /// Spin until a condition is met, with a timeout.
+    fn wait_for_condition<F>(mut f: F, timeout_ms: u64)
+    where
+        F: FnMut() -> bool,
+    {
+        let start = std::time::Instant::now();
+        let timeout = Duration::from_millis(timeout_ms);
+        while !f() {
+            if start.elapsed() > timeout {
+                panic!("timed out waiting for condition");
+            }
+            std::thread::sleep(Duration::from_millis(1));
+        }
+    }
+
+    #[test]
+    fn request_receives_response() {
+        let vsock_port = 3000;
+        let guest_port = 1234;
+        let guest_cid: u32 = 50;
+        let (_listener, backends) = bind_test_backend(vsock_port);
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        harness.add_rx_writable(256);
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        let mut hdr = VsockPacketHeader::default();
+        hdr.set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_REQUEST)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        let (resp_hdr, _) = harness.read_vsock_packet(0);
+        assert_eq!(resp_hdr.op(), VIRTIO_VSOCK_OP_RESPONSE);
+        assert_eq!(resp_hdr.src_cid(), VSOCK_HOST_CID);
+        assert_eq!(resp_hdr.dst_cid(), guest_cid as u64);
+        assert_eq!(resp_hdr.src_port(), vsock_port);
+        assert_eq!(resp_hdr.dst_port(), guest_port);
+        assert_eq!(resp_hdr.socket_type(), VIRTIO_VSOCK_TYPE_STREAM);
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn rw_with_invalid_socket_type_receives_rst() {
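+        // A packet whose socket type is not VIRTIO_VSOCK_TYPE_STREAM should
+        // be answered with a RST regardless of whether a connection exists,
+        // since the socket type check happens before the connection lookup.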
+ let guest_cid: u32 = 50; + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = + VsockPoller::new(guest_cid, vq, log, IdHashMap::new()).unwrap(); + + harness.add_rx_writable(256); + + let notify = poller.notify_handle(); + let handle = poller.run(); + + let invalid_socket_type: u16 = 0xBEEF; + let mut hdr = VsockPacketHeader::default(); + hdr.set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(5555) + .set_dst_port(8080) + .set_len(0) + .set_socket_type(invalid_socket_type) + .set_op(VIRTIO_VSOCK_OP_RW) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + let (resp_hdr, _) = harness.read_vsock_packet(0); + assert_eq!(resp_hdr.op(), VIRTIO_VSOCK_OP_RST); + assert_eq!(resp_hdr.src_cid(), VSOCK_HOST_CID); + assert_eq!(resp_hdr.dst_cid(), guest_cid as u64); + assert_eq!(resp_hdr.src_port(), 8080); + assert_eq!(resp_hdr.dst_port(), 5555); + + notify.shutdown().unwrap(); + handle.join().unwrap(); + } + + #[test] + fn request_then_rw_delivers_data() { + let vsock_port = 3000; + let guest_port = 1234; + let guest_cid: u32 = 50; + let (listener, backends) = bind_test_backend(vsock_port); + listener.set_nonblocking(false).unwrap(); + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap(); + + for _ in 0..4 { + harness.add_rx_writable(4096); + } + + let notify = poller.notify_handle(); + let handle = poller.run(); + + // Send REQUEST + let mut req_hdr = VsockPacketHeader::default(); + req_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_REQUEST) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + // Accept TCP connection and wait for RESPONSE + let mut accepted = listener.accept().unwrap().0; + accepted.set_nonblocking(false).unwrap(); + accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + // Send RW packet with data payload + let payload = b"hello from guest via vsock!"; + let mut rw_hdr = VsockPacketHeader::default(); + rw_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(payload.len() as u32) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_RW) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr)); + let d_body = harness.add_tx_readable(payload); + harness.chain_tx(d_hdr, d_body); + harness.publish_tx(d_hdr); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + // Read from accepted TCP stream and verify + let mut buf = vec![0u8; payload.len()]; + accepted.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, payload); + + notify.shutdown().unwrap(); + handle.join().unwrap(); + } + + #[test] + fn credit_update_sent_after_flushing_half_buffer() { + let vsock_port = 4000; + let guest_port = 2000; + let guest_cid: u32 = 50; + let (listener, backends) = 
bind_test_backend(vsock_port); + listener.set_nonblocking(false).unwrap(); + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap(); + + // Provide plenty of RX descriptors for RESPONSE + credit updates + for _ in 0..16 { + harness.add_rx_writable(4096); + } + + let notify = poller.notify_handle(); + let handle = poller.run(); + + // Establish connection + let mut req_hdr = VsockPacketHeader::default(); + req_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_REQUEST) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + let mut accepted = listener.accept().unwrap().0; + accepted.set_nonblocking(false).unwrap(); + accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + // Send enough data to exceed half the buffer capacity (64KB). + let chunk_size = 8192; + let num_chunks = (CONN_TX_BUF_SIZE / 2) / chunk_size + 1; + let payload = vec![0xAB_u8; chunk_size]; + let total_sent = num_chunks * chunk_size; + let mut tx_consumed = 1u16; // REQUEST was consumed + + for _ in 0..num_chunks { + // Reuse descriptor slots each iteration + harness.reset_tx_cursors(); + + let mut rw_hdr = VsockPacketHeader::default(); + rw_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(payload.len() as u32) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_RW) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr)); + let d_body = harness.add_tx_readable(&payload); + harness.chain_tx(d_hdr, d_body); + harness.publish_tx(d_hdr); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + tx_consumed += 1; + wait_for_condition(|| harness.tx_used_idx() >= tx_consumed, 5000); + } + + // Drain the data from the accepted socket to confirm it arrived + let mut buf = vec![0u8; total_sent]; + accepted.read_exact(&mut buf).unwrap(); + assert!(buf.iter().all(|&b| b == 0xAB)); + + // Look for a CREDIT_UPDATE in the RX used entries + let rx_used = harness.rx_used_idx(); + assert!(rx_used >= 2, "expected at least RESPONSE + CREDIT_UPDATE"); + + let mut found_credit_update = false; + for i in 1..rx_used { + let (hdr, _) = harness.read_vsock_packet(i); + if hdr.op() == VIRTIO_VSOCK_OP_CREDIT_UPDATE { + assert_eq!(hdr.src_cid(), VSOCK_HOST_CID); + assert_eq!(hdr.dst_cid(), guest_cid as u64); + assert_eq!(hdr.src_port(), vsock_port); + assert_eq!(hdr.dst_port(), guest_port); + assert_eq!(hdr.buf_alloc(), CONN_TX_BUF_SIZE as u32); + found_credit_update = true; + break; + } + } + assert!(found_credit_update, "expected a CREDIT_UPDATE on RX queue"); + + notify.shutdown().unwrap(); + handle.join().unwrap(); + } + + #[test] + fn rst_removes_established_connection() { + let vsock_port = 5000; + let guest_port = 3000; + let guest_cid: u32 = 50; + let (listener, backends) = bind_test_backend(vsock_port); + listener.set_nonblocking(false).unwrap(); + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = VsockPoller::new(guest_cid, vq, log, 
backends).unwrap(); + + for _ in 0..4 { + harness.add_rx_writable(4096); + } + + let notify = poller.notify_handle(); + let handle = poller.run(); + + // Send REQUEST + let mut req_hdr = VsockPacketHeader::default(); + req_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_REQUEST) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + let mut accepted = listener.accept().unwrap().0; + accepted.set_nonblocking(false).unwrap(); + accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + // Send RST + let mut rst_hdr = VsockPacketHeader::default(); + rst_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_RST) + .set_buf_alloc(0) + .set_fwd_cnt(0); + + let d_rst = harness.add_tx_readable(hdr_as_bytes(&rst_hdr)); + harness.publish_tx(d_rst); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + // Wait for the RST to be consumed + wait_for_condition(|| harness.tx_used_idx() >= 2, 5000); + + // Verify the TCP connection was closed by reading from the + // accepted stream. + let mut buf = [0u8; 1]; + let result = accepted.read(&mut buf); + match result { + Ok(0) => {} + Err(_) => {} + Ok(n) => panic!("expected EOF or error, got {n} bytes"), + } + + notify.shutdown().unwrap(); + handle.join().unwrap(); + } + + #[test] + fn end_to_end_guest_to_host() { + let vsock_port = 7000; + let guest_port = 5000; + let guest_cid: u32 = 50; + let (listener, backends) = bind_test_backend(vsock_port); + listener.set_nonblocking(false).unwrap(); + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap(); + + // Pre-populate RX queue with writable descriptors for RESPONSE + data + for _ in 0..8 { + harness.add_rx_writable(4096); + } + + let notify = poller.notify_handle(); + let handle = poller.run(); + + // Write REQUEST packet into TX queue + let mut req_hdr = VsockPacketHeader::default(); + req_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_REQUEST) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + // Accept the TCP connection (blocks until poller connects) + let mut accepted = listener.accept().unwrap().0; + accepted.set_nonblocking(false).unwrap(); + accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + + // Wait for RESPONSE on RX queue + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + // Guest->Host: send RW packet with payload + let payload = b"hello from guest via vsock end-to-end!"; + let mut rw_hdr = VsockPacketHeader::default(); + rw_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(payload.len() as u32) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + 
.set_op(VIRTIO_VSOCK_OP_RW) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr)); + let d_body = harness.add_tx_readable(payload); + harness.chain_tx(d_hdr, d_body); + harness.publish_tx(d_hdr); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + // Read from accepted TCP stream, and verify guest->host data + let mut buf = vec![0u8; payload.len()]; + accepted.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, payload, "guest->host data mismatch"); + + // Host->Guest: write data into accepted TCP stream + let host_payload = b"reply from host via vsock!"; + accepted.write_all(host_payload).unwrap(); + accepted.flush().unwrap(); + + // Wait for RW packet on RX queue (RESPONSE was 1, now expect 2+) + wait_for_condition(|| harness.rx_used_idx() >= 2, 5000); + + // Read back the RW packet from RX used ring entry 1 + let (resp_hdr, host_buf) = harness.read_vsock_packet(1); + + assert_eq!(resp_hdr.op(), VIRTIO_VSOCK_OP_RW); + assert_eq!(resp_hdr.src_port(), vsock_port); + assert_eq!(resp_hdr.dst_port(), guest_port); + assert_eq!(&host_buf, host_payload, "host->guest data mismatch"); + + notify.shutdown().unwrap(); + handle.join().unwrap(); + } + + #[test] + fn rx_blocked_resumes_when_descriptors_available() { + let vsock_port = 6000; + let guest_port = 4000; + let guest_cid: u32 = 50; + let (listener, backends) = bind_test_backend(vsock_port); + listener.set_nonblocking(false).unwrap(); + + let mut harness = VsockTestHarness::new(); + let vq = harness.make_vsock_vq(); + let log = test_logger(); + let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap(); + + // Provide only one RX descriptor, just enough for the RESPONSE. + harness.add_rx_writable(4096); + + let notify = poller.notify_handle(); + let handle = poller.run(); + + // Send REQUEST + let mut req_hdr = VsockPacketHeader::default(); + req_hdr + .set_src_cid(guest_cid) + .set_dst_cid(VSOCK_HOST_CID as u32) + .set_src_port(guest_port) + .set_dst_port(vsock_port) + .set_len(0) + .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM) + .set_op(VIRTIO_VSOCK_OP_REQUEST) + .set_buf_alloc(65536) + .set_fwd_cnt(0); + + let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr)); + harness.publish_tx(d_tx); + notify.queue_notify(VSOCK_TX_QUEUE).unwrap(); + + let mut accepted = listener.accept().unwrap().0; + accepted.set_nonblocking(false).unwrap(); + wait_for_condition(|| harness.rx_used_idx() >= 1, 5000); + + // The RESPONSE consumed the only RX descriptor. Write data from + // the host side. 
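+        // With no descriptors left, the poller should park this connection
+        // on its `rx_blocked` list rather than dropping the data.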
+        let host_data = b"data from the host side";
+        accepted.write_all(host_data).unwrap();
+        accepted.flush().unwrap();
+
+        // Give the poller time to attempt delivery (and get blocked)
+        std::thread::sleep(Duration::from_millis(100));
+
+        // Verify no new used entries appeared (still just the RESPONSE)
+        assert_eq!(harness.rx_used_idx(), 1);
+
+        // Add new RX descriptors and notify
+        harness.reset_rx_cursors();
+        harness.add_rx_writable(4096);
+        notify.queue_notify(VSOCK_RX_QUEUE).unwrap();
+
+        // Wait for the data to be delivered
+        wait_for_condition(|| harness.rx_used_idx() >= 2, 5000);
+
+        let (rw_hdr, payload) = harness.read_vsock_packet(1);
+        assert_eq!(rw_hdr.op(), VIRTIO_VSOCK_OP_RW);
+        assert_eq!(rw_hdr.src_port(), vsock_port);
+        assert_eq!(rw_hdr.dst_port(), guest_port);
+        assert_eq!(&payload, host_data);
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    /// End-to-end test with large data transfers in both directions,
+    /// exercising rx_blocked, credit updates, and descriptor replenishment
+    /// across many batches of reused descriptor slots.
+    #[test]
+    fn end_to_end_large_data() {
+        let total_bytes: usize = 10 * 1024 * 1024;
+
+        let vsock_port = 8000;
+        let guest_port = 6000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide initial RX descriptors for RESPONSE + credit updates
+        for _ in 0..8 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Establish connection
+        //
+        // Use a large buf_alloc so host->guest credit doesn't run out
+        // before we've transferred all the data.
+        let buf_alloc = total_bytes as u32 * 2;
+
+        let mut req_hdr = VsockPacketHeader::default();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_REQUEST)
+            .set_buf_alloc(buf_alloc)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        let accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // A reader thread drains the TCP socket while the main thread
+        // injects RW packets in batches, reusing descriptor slots and
+        // guest memory between batches.
+        let guest_data: Vec<u8> =
+            (0..total_bytes).map(|i| (i % 251) as u8).collect();
+
+        // Track how many bytes the reader has consumed so we can apply
+        // backpressure and avoid overflowing the poller's VsockBuf.
+        let bytes_read = Arc::new(AtomicUsize::new(0));
+        let tcp_reader = {
+            let mut stream = accepted.try_clone().unwrap();
+            let len = total_bytes;
+            let progress = Arc::clone(&bytes_read);
+            std::thread::spawn(move || {
+                let mut result = Vec::with_capacity(len);
+                let mut chunk = vec![0u8; 65536];
+                let mut total = 0;
+                while total < len {
+                    let n = stream.read(&mut chunk).unwrap();
+                    assert!(n > 0, "unexpected EOF after {total}/{len}");
+                    result.extend_from_slice(&chunk[..n]);
+                    total += n;
+                    progress.store(total, Ordering::Release);
+                }
+                result
+            })
+        };
+
+        let chunk_size = 4096;
+        let batch_packets = 8; // 8 packets × 2 descs = 16 descs per batch
+        let mut guest_sent = 0usize;
+        // TX used_idx starts at 1 (the REQUEST was consumed)
+        let mut tx_consumed = 1u16;
+
+        while guest_sent < total_bytes {
+            let remaining = (total_bytes - guest_sent).div_ceil(chunk_size);
+            let this_batch = std::cmp::min(batch_packets, remaining);
+
+            // Backpressure: don't let in-flight data exceed VsockBuf
+            // capacity. The poller buffers TX data in VsockBuf (128KB)
+            // and flushes via POLLOUT. If we push faster than the
+            // flush rate, the buffer overflows and panics.
+            let after_send = guest_sent + this_batch * chunk_size;
+            loop {
+                let read = bytes_read.load(Ordering::Acquire);
+                if after_send <= read + CONN_TX_BUF_SIZE {
+                    break;
+                }
+                std::thread::sleep(Duration::from_millis(1));
+            }
+
+            // Reuse the same descriptor slots and data region each batch.
+            // Safe because we wait for the previous batch to be fully
+            // consumed before overwriting.
+            harness.reset_tx_cursors();
+
+            for i in 0..this_batch {
+                let offset = guest_sent + i * chunk_size;
+                let end = std::cmp::min(offset + chunk_size, total_bytes);
+                let payload = &guest_data[offset..end];
+
+                let mut rw_hdr = VsockPacketHeader::default();
+                rw_hdr
+                    .set_src_cid(guest_cid)
+                    .set_dst_cid(VSOCK_HOST_CID as u32)
+                    .set_src_port(guest_port)
+                    .set_dst_port(vsock_port)
+                    .set_len(payload.len() as u32)
+                    .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+                    .set_op(VIRTIO_VSOCK_OP_RW)
+                    .set_buf_alloc(buf_alloc)
+                    .set_fwd_cnt(0);
+
+                let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr));
+                let d_body = harness.add_tx_readable(payload);
+                harness.chain_tx(d_hdr, d_body);
+                harness.publish_tx(d_hdr);
+            }
+
+            notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+            // Wait for the poller to consume this entire batch before
+            // we overwrite the descriptor slots in the next iteration.
+            tx_consumed += this_batch as u16;
+            wait_for_condition(|| harness.tx_used_idx() >= tx_consumed, 10000);
+
+            guest_sent += this_batch * chunk_size;
+            if guest_sent > total_bytes {
+                guest_sent = total_bytes;
+            }
+        }
+
+        let received = tcp_reader.join().unwrap();
+        assert_eq!(received.len(), total_bytes);
+        assert!(received == guest_data, "guest->host data mismatch");
+
+        // A writer thread pushes data into the TCP socket while the
+        // main thread replenishes RX descriptors in batches, reads
+        // completed used entries, and reuses descriptor slots once
+        // the entire batch has been consumed.
+        let host_data: Vec<u8> =
+            (0..total_bytes).map(|i| ((i + 7) % 251) as u8).collect();
+
+        let tcp_writer = {
+            let mut stream = accepted.try_clone().unwrap();
+            let data = host_data.clone();
+            std::thread::spawn(move || {
+                stream.write_all(&data).unwrap();
+            })
+        };
+
+        let mut host_to_guest = Vec::with_capacity(total_bytes);
+
+        // Skip all used entries produced before this phase (the RESPONSE
+        // plus any credit updates from the guest->host phase).
+        let mut rx_next_used = harness.rx_used_idx();
+        let rx_batch = 16u16;
+        let mut descs_outstanding = 0u16;
+
+        while host_to_guest.len() < total_bytes {
+            // When all outstanding descriptors have been consumed we can
+            // safely reuse the descriptor slots and data region.
+            if descs_outstanding == 0 {
+                harness.reset_rx_cursors();
+
+                for _ in 0..rx_batch {
+                    harness.add_rx_writable(4096);
+                    descs_outstanding += 1;
+                }
+                notify.queue_notify(VSOCK_RX_QUEUE).unwrap();
+            }
+
+            // Wait for at least one new used entry.
+            wait_for_condition(|| harness.rx_used_idx() > rx_next_used, 10000);
+
+            // Drain all currently available used entries.
+            let current_used = harness.rx_used_idx();
+            while rx_next_used < current_used {
+                let (hdr, data) = harness.read_vsock_packet(rx_next_used);
+                rx_next_used += 1;
+                descs_outstanding -= 1;
+
+                if hdr.op() == VIRTIO_VSOCK_OP_RW {
+                    host_to_guest.extend_from_slice(&data);
+                }
+                // Credit updates and other control packets are silently
+                // consumed; they're expected here.
+            }
+        }
+
+        tcp_writer.join().unwrap();
+        assert_eq!(host_to_guest.len(), total_bytes);
+        assert!(host_to_guest == host_data, "host->guest data mismatch");
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    /// Closing the host-side TCP socket should cause the poller to send a
+    /// VIRTIO_VSOCK_OP_SHUTDOWN packet with both VIRTIO_VSOCK_SHUTDOWN_F_SEND
+    /// and VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE to the guest, indicating the host
+    /// will neither send nor accept any more data.
+    #[test]
+    fn host_socket_eof_sends_shutdown() {
+        let vsock_port = 9000;
+        let guest_port = 7000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide RX descriptors for RESPONSE + SHUTDOWN
+        for _ in 0..4 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Establish connection
+        let mut req_hdr = VsockPacketHeader::default();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VIRTIO_VSOCK_TYPE_STREAM)
+            .set_op(VIRTIO_VSOCK_OP_REQUEST)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Accept the connection, wait for RESPONSE
+        let accepted = listener.accept().unwrap().0;
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Close the host-side socket to produce EOF
+        drop(accepted);
+
+        // The poller should detect EOF on the next POLLIN and send a
+        // SHUTDOWN packet to the guest.
+        wait_for_condition(|| harness.rx_used_idx() >= 2, 5000);
+
+        // Read back the packet from RX used ring entry 1
+        let (hdr, _data) = harness.read_vsock_packet(1);
+
+        assert_eq!(hdr.op(), VIRTIO_VSOCK_OP_SHUTDOWN);
+        assert_eq!(hdr.src_cid(), VSOCK_HOST_CID);
+        assert_eq!(hdr.dst_cid(), guest_cid as u64);
+        assert_eq!(hdr.src_port(), vsock_port);
+        assert_eq!(hdr.dst_port(), guest_port);
+        assert_eq!(
+            hdr.flags(),
+            VIRTIO_VSOCK_SHUTDOWN_F_SEND | VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE
+        );
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+}
diff --git a/lib/propolis/src/vsock/poller_stub.rs b/lib/propolis/src/vsock/poller_stub.rs
new file mode 100644
index 000000000..71bcdd1ee
--- /dev/null
+++ b/lib/propolis/src/vsock/poller_stub.rs
@@ -0,0 +1,54 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::thread::JoinHandle;
+
+use iddqd::IdHashMap;
+use slog::Logger;
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::vsock::proxy::VsockPortMapping;
+
+bitflags! {
+    pub struct PollEvents: i32 {
+        const IN = libc::POLLIN as i32;
+        const OUT = libc::POLLOUT as i32;
+    }
+}
+
+pub struct VsockPollerNotify;
+
+impl VsockPollerNotify {
+    pub fn queue_notify(&self, _id: u16) -> std::io::Result<()> {
+        Err(std::io::Error::other(
+            "not available on non-illumos systems",
+        ))
+    }
+}
+
+pub struct VsockPoller;
+
+impl VsockPoller {
+    pub fn new(
+        _cid: u32,
+        _queues: VsockVq,
+        _log: Logger,
+        _port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> std::io::Result<Self> {
+        Err(std::io::Error::other(
+            "VsockPoller is not available on non-illumos systems",
+        ))
+    }
+
+    pub fn notify_handle(&self) -> VsockPollerNotify {
+        VsockPollerNotify {}
+    }
+
+    pub fn run(self) -> JoinHandle<()> {
+        std::thread::Builder::new()
+            .name("vsock-event-loop".to_string())
+            .spawn(move || {})
+            .expect("failed to spawn vsock event loop")
+    }
+}
diff --git a/lib/propolis/src/vsock/proxy.rs b/lib/propolis/src/vsock/proxy.rs
new file mode 100644
index 000000000..969ca21f2
--- /dev/null
+++ b/lib/propolis/src/vsock/proxy.rs
@@ -0,0 +1,352 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::net::SocketAddr;
+use std::net::TcpStream;
+use std::num::NonZeroUsize;
+use std::num::Wrapping;
+use std::os::fd::AsRawFd;
+use std::os::fd::RawFd;
+use std::thread::JoinHandle;
+use std::time::Duration;
+
+use iddqd::IdHashItem;
+use iddqd::IdHashMap;
+use serde::Deserialize;
+use slog::error;
+use slog::Logger;
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::vsock::buffer::VsockBuf;
+use crate::vsock::buffer::VsockBufError;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketHeader;
+use crate::vsock::poller::PollEvents;
+use crate::vsock::poller::VsockPoller;
+use crate::vsock::poller::VsockPollerNotify;
+use crate::vsock::VsockBackend;
+use crate::vsock::VsockError;
+
+/// Default buffer size for guest->host data.
+pub const CONN_TX_BUF_SIZE: usize = 1024 * 128;
+
+/// Connection lifecycle state for a vsock connection.
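+///
+/// Connections start in `Init` when the guest's REQUEST is accepted, move to
+/// `Established` once we answer with a RESPONSE, and wind down through
+/// `Closing`, whose flags record which halves have been shut down; a
+/// connection with both halves closed is removed once its buffered data has
+/// been flushed.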
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ConnState {
+    /// The guest has sent us a VIRTIO_VSOCK_OP_REQUEST
+    Init,
+    /// We have sent VIRTIO_VSOCK_OP_RESPONSE - connection can send/recv data
+    Established,
+    /// The connection is in the process of closing - read and write halves
+    /// are tracked separately
+    Closing {
+        read: bool,
+        write: bool,
+    },
+}
+
+#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
+pub struct ConnKey {
+    /// The port the guest is transmitting to.
+    pub(crate) host_port: u32,
+    /// The port the guest is transmitting from.
+    pub(crate) guest_port: u32,
+}
+
+// This impl allows us to convert to and from a portev_user object (see
+// port_associate(3C)). The conversion to and from a usize allows us to encode
+// the key in the pointer value itself rather than allocating memory.
+//
+// NB: This object is defined as a `*mut c_void` and therefore will not be
+// 64 bits on all platforms, but we currently only support x86_64 hardware;
+// we leave a static assertion behind as a future hint to ourselves.
+impl ConnKey {
+    /// Pack the host and guest ports into a `usize`, e.g. host port 0x12 and
+    /// guest port 0x34 pack to 0x0000_0012_0000_0034.
+    pub fn to_portev_user(self) -> usize {
+        static_assertions::assert_eq_size!(u64, usize);
+        ((self.host_port as usize) << 32) | (self.guest_port as usize)
+    }
+
+    /// Unpack the host and guest ports from a `usize`.
+    pub fn from_portev_user(val: usize) -> Self {
+        Self { host_port: (val >> 32) as u32, guest_port: val as u32 }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ProxyConnError {
+    #[error("Failed to connect to vsock backend {backend}: {source}")]
+    Socket {
+        backend: SocketAddr,
+        #[source]
+        source: std::io::Error,
+    },
+    #[error("Failed to put socket into nonblocking mode: {0}")]
+    NonBlocking(#[source] std::io::Error),
+    #[error("Cannot transition connection from {from:?} to {to:?}")]
+    InvalidStateTransition { from: ConnState, to: ConnState },
+}
+
+/// An established guest<=>host connection
+#[derive(Debug)]
+pub struct VsockProxyConn {
+    pub(crate) socket: TcpStream,
+    /// Current connection state.
+    state: ConnState,
+    /// Ring buffer used to receive packets from the guest tx virt queue.
+    vbuf: VsockBuf,
+    /// Bytes we've consumed from vbuf (forwarded to socket).
+    fwd_cnt: Wrapping<u32>,
+    /// The fwd_cnt value we last sent to the guest in a credit update.
+    last_fwd_cnt_sent: Wrapping<u32>,
+    /// Bytes we've sent to the guest from the socket.
+    tx_cnt: Wrapping<u32>,
+    /// Guest's buffer allocation.
+    peer_buf_alloc: u32,
+    /// Bytes the guest has consumed from their buffer.
+    peer_fwd_cnt: Wrapping<u32>,
+}
+
+impl VsockProxyConn {
+    /// Create a new `VsockProxyConn` connected to an underlying host socket.
+    pub fn new(addr: &SocketAddr) -> Result<Self, ProxyConnError> {
+        let socket =
+            TcpStream::connect_timeout(addr, Duration::from_millis(100))
+                .map_err(|e| ProxyConnError::Socket {
+                    backend: *addr,
+                    source: e,
+                })?;
+        socket.set_nonblocking(true).map_err(ProxyConnError::NonBlocking)?;
+
+        Ok(Self {
+            socket,
+            state: ConnState::Init,
+            vbuf: VsockBuf::new(NonZeroUsize::new(CONN_TX_BUF_SIZE).unwrap()),
+            fwd_cnt: Wrapping(0),
+            last_fwd_cnt_sent: Wrapping(0),
+            tx_cnt: Wrapping(0),
+            peer_buf_alloc: 0,
+            peer_fwd_cnt: Wrapping(0),
+        })
+    }
+
+    /// Set of `PollEvents` that this connection is interested in.
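+    ///
+    /// `POLLOUT` is requested while guest data is waiting to be flushed to
+    /// the host socket, and `POLLIN` while the guest can still accept data.
+    /// `None` means the fd does not need to be re-associated with the event
+    /// port right now.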
+    pub fn poll_interests(&self) -> Option<PollEvents> {
+        let mut interests = PollEvents::empty();
+        interests.set(PollEvents::OUT, self.has_buffered_data());
+        interests.set(PollEvents::IN, self.guest_can_read());
+
+        match interests.is_empty() {
+            true => None,
+            false => Some(interests),
+        }
+    }
+
+    /// Returns `true` if the connection has data pending in its ring buffer
+    /// that needs to be flushed to the underlying socket.
+    pub fn has_buffered_data(&self) -> bool {
+        !self.vbuf.is_empty()
+    }
+
+    /// Set the connection to established.
+    pub fn set_established(&mut self) -> Result<(), ProxyConnError> {
+        match self.state {
+            ConnState::Init => self.state = ConnState::Established,
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Established,
+                })
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Check whether we may still read from the host socket on the guest's
+    /// behalf, i.e. the guest has not shut down its receive half.
+    pub fn guest_can_read(&self) -> bool {
+        matches!(
+            self.state,
+            ConnState::Established | ConnState::Closing { read: false, .. }
+        )
+    }
+
+    pub fn shutdown_guest_read(&mut self) -> Result<(), ProxyConnError> {
+        self.state = match self.state {
+            ConnState::Established => {
+                ConnState::Closing { read: true, write: false }
+            }
+            ConnState::Closing { write, .. } => {
+                ConnState::Closing { read: true, write }
+            }
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Closing { read: true, write: false },
+                })
+            }
+        };
+
+        Ok(())
+    }
+
+    pub fn shutdown_guest_write(&mut self) -> Result<(), ProxyConnError> {
+        self.state = match self.state {
+            ConnState::Established => {
+                ConnState::Closing { read: false, write: true }
+            }
+            ConnState::Closing { read, .. } => {
+                ConnState::Closing { read, write: true }
+            }
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Closing { read: false, write: true },
+                })
+            }
+        };
+
+        Ok(())
+    }
+
+    /// Check if the connection should be removed.
+    pub fn should_close(&self) -> bool {
+        matches!(self.state, ConnState::Closing { read: true, write: true })
+    }
+
+    /// Update peer credit info from a packet header.
+    pub fn update_peer_credit(&mut self, header: &VsockPacketHeader) {
+        self.peer_buf_alloc = header.buf_alloc();
+        self.peer_fwd_cnt = Wrapping(header.fwd_cnt());
+    }
+
+    /// Process a packet received from the guest tx queue.
+    pub fn recv_packet(
+        &mut self,
+        packet: VsockPacket,
+    ) -> Result<(), VsockBufError> {
+        self.vbuf.push(packet.data)
+    }
+
+    pub fn flush(&mut self) -> std::io::Result<usize> {
+        self.vbuf.write_to(&mut self.socket)
+    }
+
+    /// Calculate how much data we can send to the guest based on their credit.
+    pub fn peer_credit(&self) -> u32 {
+        let in_flight = (self.tx_cnt - self.peer_fwd_cnt).0;
+        self.peer_buf_alloc.saturating_sub(in_flight)
+    }
+
+    /// Update fwd_cnt after consuming data from vbuf.
+    pub fn update_fwd_cnt(&mut self, bytes: u32) {
+        self.fwd_cnt += Wrapping(bytes);
+    }
+
+    /// Update tx_cnt after sending data to the guest.
+    pub fn update_tx_cnt(&mut self, bytes: u32) {
+        self.tx_cnt += Wrapping(bytes);
+    }
+
+    /// Get our current fwd_cnt to report to the guest.
+    pub fn fwd_cnt(&self) -> u32 {
+        self.fwd_cnt.0
+    }
+
+    /// Get our buffer allocation to report to the guest.
+    pub fn buf_alloc(&self) -> u32 {
+        self.vbuf.capacity() as u32
+    }
+
+    /// Check if we should send a credit update to the guest.
+    ///
+    /// Returns true if we've consumed more than half of our buffer capacity
+    /// since the last credit update was sent.
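+    ///
+    /// The halfway threshold keeps credit traffic low: rather than updating
+    /// the guest on every flush, we let it see a stale (smaller) credit
+    /// window until a meaningful amount of buffer space has been reclaimed.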
+    pub fn needs_credit_update(&self) -> bool {
+        let bytes_consumed_since_update =
+            (self.fwd_cnt - self.last_fwd_cnt_sent).0;
+        bytes_consumed_since_update > (self.vbuf.capacity() / 2) as u32
+    }
+
+    /// Mark that we've sent a credit update with the current fwd_cnt.
+    pub fn mark_credit_sent(&mut self) {
+        self.last_fwd_cnt_sent = self.fwd_cnt;
+    }
+
+    pub fn get_fd(&self) -> RawFd {
+        self.socket.as_raw_fd()
+    }
+}
+
+#[derive(Deserialize, Debug, Clone, Copy)]
+pub struct VsockPortMapping {
+    port: u32,
+    // TODO: this could be extended to support Unix sockets as well.
+    addr: SocketAddr,
+}
+
+impl VsockPortMapping {
+    pub fn new(port: u32, addr: SocketAddr) -> Self {
+        Self { port, addr }
+    }
+
+    pub fn addr(&self) -> &SocketAddr {
+        &self.addr
+    }
+}
+
+impl IdHashItem for VsockPortMapping {
+    type Key<'a> = u32;
+
+    fn key(&self) -> Self::Key<'_> {
+        self.port
+    }
+
+    iddqd::id_upcast!();
+}
+
+/// virtio-vsock backend that proxies between a guest and host TCP sockets.
+pub struct VsockProxy {
+    log: Logger,
+    poller: VsockPollerNotify,
+    _evloop_handle: JoinHandle<()>,
+}
+
+impl VsockProxy {
+    pub fn new(
+        cid: u32,
+        queues: VsockVq,
+        log: Logger,
+        port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> Self {
+        let evloop =
+            VsockPoller::new(cid, queues, log.clone(), port_mappings).unwrap();
+        let poller = evloop.notify_handle();
+        let jh = evloop.run();
+
+        Self { log, poller, _evloop_handle: jh }
+    }
+
+    /// Notification from the vsock device that one of the queues has had an
+    /// event.
+    fn queue_notify(&self, vq_id: u16) -> std::io::Result<()> {
+        self.poller.queue_notify(vq_id)
+    }
+}
+
+impl VsockBackend for VsockProxy {
+    fn queue_notify(&self, queue_id: u16) -> Result<(), VsockError> {
+        self.queue_notify(queue_id)
+            // Log the raw error in addition to returning the top-level
+            // `VsockError`
+            .inspect_err(|_e| {
+                error!(&self.log,
+                    "failed to send virtqueue notification";
+                    "queue" => %queue_id,
+                )
+            })
+            .map_err(|_| VsockError::QueueNotify(queue_id))
+    }
+}