diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 54b2f20c47..c46756f2f6 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -33,6 +33,9 @@ pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024; pub(crate) struct RawSocket { pub remote_cid: u32, pub remote_port: u32, + /// The listen port this connection was accepted on. Zero for listener and + /// outbound-connect sockets. + pub listen_port: u32, pub fwd_cnt: u32, pub peer_fwd_cnt: u32, pub peer_buf_alloc: u32, @@ -48,6 +51,7 @@ impl RawSocket { Self { remote_cid: 0, remote_port: 0, + listen_port: 0, fwd_cnt: 0, peer_fwd_cnt: 0, peer_buf_alloc: 0, @@ -78,7 +82,22 @@ async fn vsock_run() { let mut vsock_guard = VSOCK_MAP.lock(); let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); - let Some(raw) = vsock_guard.get_mut_socket(port) else { + // For Rw/Shutdown/CreditUpdate/Response packets, prefer a connected socket that was + // accepted on this port (matched by remote cid and port) over the listener entry itself. + let header_cid_inner: u32 = header_cid; + let raw_port = header.src_port.to_ne(); + let raw = if matches!(op, Op::Rw | Op::Shutdown | Op::CreditUpdate | Op::Response) { + if let Some(conn) = vsock_guard.get_mut_connected(port, header_cid_inner, raw_port) + { + conn + } else if let Some(s) = vsock_guard.get_mut_socket(port) { + s + } else { + return; + } + } else if let Some(s) = vsock_guard.get_mut_socket(port) { + s + } else { return; }; @@ -206,9 +225,48 @@ impl VsockMap { self.port_map.get_mut(&port) } + /// Look up a connected socket by its original listen port and the remote + /// endpoint. Used to route packets of an established connection after it has been moved + /// to an ephemeral port by `move_to_ephemeral`. 
+ pub fn get_mut_connected( + &mut self, + listen_port: u32, + remote_cid: u32, + remote_port: u32, + ) -> Option<&mut RawSocket> { + self.port_map.values_mut().find(|raw| { + raw.state == VsockState::Connected + && raw.listen_port == listen_port + && raw.remote_cid == remote_cid + && raw.remote_port == remote_port + }) + } + pub fn remove_socket(&mut self, port: u32) { self.port_map.remove(&port); } + + /// Move the socket at `listen_port` to the first vacant port in `u32::MAX / 4..u32::MAX`, mark it `Connected`, reset the + /// listener entry to `Listen`, and return the chosen port. + pub fn move_to_ephemeral(&mut self, listen_port: u32) -> io::Result { + let mut conn = self.port_map.remove(&listen_port).ok_or(Errno::Inval)?; + conn.state = VsockState::Connected; + conn.listen_port = listen_port; + + for ep in u32::MAX / 4..u32::MAX { + if let btree_map::Entry::Vacant(v) = self.port_map.entry(ep) { + v.insert(conn); + self.port_map + .insert(listen_port, RawSocket::new(VsockState::Listen)); + return Ok(ep); + } + } + + // No ephemeral port available: put a fresh `Listen` entry back at `listen_port`. NOTE(review): the pending connection held in `conn` is dropped here, not restored. + self.port_map + .insert(listen_port, RawSocket::new(VsockState::Listen)); + Err(Errno::Badf) + } } pub(crate) fn init() { diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index a6c0131941..29003ad4b2 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -51,6 +51,9 @@ impl ObjectInterface for NullSocket {} pub struct Socket { port: u32, + /// The port this socket is bound/listening on. Stays fixed across accepts + /// while `port` is updated to the ephemeral connection port after each accept. 
+ listen_port: u32, cid: u32, is_nonblocking: bool, } @@ -59,6 +62,7 @@ impl Socket { pub fn new() -> Self { Self { port: 0, + listen_port: 0, cid: u32::MAX, is_nonblocking: false, } @@ -139,6 +143,7 @@ impl ObjectInterface for Socket { match endpoint { ListenEndpoint::Vsock(ep) => { self.port = ep.port; + self.listen_port = ep.port; if let Some(cid) = ep.cid { self.cid = cid; } else { @@ -234,10 +239,10 @@ impl ObjectInterface for Socket { } async fn accept(&mut self) -> io::Result<(Arc>, Endpoint)> { - let port = self.port; + let port = self.listen_port; let cid = self.cid; - let endpoint = future::poll_fn(|cx| { + let (conn_port, endpoint) = future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); let raw = guard.get_mut_socket(port).ok_or(Errno::Inval)?; @@ -251,44 +256,46 @@ impl ObjectInterface for Socket { } } VsockState::ReceiveRequest => { - let result = { - const HEADER_SIZE: usize = size_of::(); - let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); - let local_cid = driver_guard.get_cid(); - - driver_guard.send_packet(HEADER_SIZE, |buffer| { - let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; - - response.src_cid = le64::from_ne(local_cid); - response.dst_cid = le64::from_ne(raw.remote_cid.into()); - response.src_port = le32::from_ne(port); - response.dst_port = le32::from_ne(raw.remote_port); - response.len = le32::from_ne(0); - response.type_ = le16::from_ne(Type::Stream.into()); - if local_cid != u64::from(cid) && cid != u32::MAX { - response.op = le16::from_ne(Op::Rst.into()); - } else { - response.op = le16::from_ne(Op::Response.into()); - } - response.flags = le32::from_ne(0); - response.buf_alloc = le32::from_ne( - crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, - ); - response.fwd_cnt = le32::from_ne(raw.fwd_cnt); - }); + const HEADER_SIZE: usize = size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + 
driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid.into()); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + if local_cid != u64::from(cid) && cid != u32::MAX { + response.op = le16::from_ne(Op::Rst.into()); + } else { + response.op = le16::from_ne(Op::Response.into()); + } + response.flags = le32::from_ne(0); + response.buf_alloc = + le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); + }); - raw.state = VsockState::Connected; + let endpoint = VsockEndpoint::new(raw.remote_port, raw.remote_cid); - Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid)) - }; + // Move the accepted connection to an ephemeral port so the + // listener entry can be reset to Listen for the next accept. + let conn_port = guard.move_to_ephemeral(port)?; - Poll::Ready(result) + Poll::Ready(Ok((conn_port, endpoint))) } _ => Poll::Ready(Err(Errno::Badf)), } }) .await?; + // This Socket now tracks the accepted connection, not the listener. 
+ self.port = conn_port; + Ok(( Arc::new(async_lock::RwLock::new(NullSocket::new().into())), Endpoint::Vsock(endpoint), diff --git a/xtask/src/ci/qemu.rs b/xtask/src/ci/qemu.rs index 0e60815491..3f7d95ef25 100644 --- a/xtask/src/ci/qemu.rs +++ b/xtask/src/ci/qemu.rs @@ -176,12 +176,19 @@ impl Qemu { .any(|feature| feature == "client"); test_vsock(has_client)?; } + "vsock_server" => { + test_vsock_server()?; + } _ => {} } if matches!( image_name, - "axum-example" | "http_server" | "http_server_poll" | "http_server_select" | "vsock" + "axum-example" + | "http_server" + | "http_server_poll" + | "http_server_select" + | "vsock" | "vsock_server" ) || self.devices.contains(&Device::CadenceGem) // sifive_u, on which we test CadenceGem, does not support software shutdowns, so we have to terminate the machine ourselves. { @@ -624,6 +631,30 @@ fn test_vsock(has_client: bool) -> Result<()> { Ok(()) } +fn test_vsock_server() -> Result<()> { + const PORT: u32 = 9975; + const CONNECTIONS: usize = 2; + + thread::sleep(Duration::from_secs(10)); + let first_stream = VsockStream::connect_with_cid_port(3, PORT)?; + + let do_ping_pong = |mut stream: VsockStream| -> Result<()> { + stream.write_all(b"ping")?; + let mut buf = [0u8; 64]; + let n = stream.read(&mut buf)?; + let msg = from_utf8(&buf[..n])?; + ensure!(msg == "pong", "expected 'pong', got {msg:?}"); + Ok(()) + }; + + do_ping_pong(first_stream)?; + for _ in 1..CONNECTIONS { + do_ping_pong(VsockStream::connect_with_cid_port(3, PORT)?)?; + } + + Ok(()) +} + fn test_http_server(guest_ip: IpAddr) -> Result<()> { thread::sleep(Duration::from_secs(10)); let url = format!("http://{guest_ip}:9975");