From 224fdbe22799b797049005ea71d073073201805c Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Mon, 5 Aug 2024 17:29:09 -0700 Subject: [PATCH] fix(idm): process all idm messages in the same frame and use childwait exit notification for exec (fixes #290) (#302) --- crates/daemon/src/idm.rs | 50 ++++++++++++++++++---------------- crates/krata/src/idm/client.rs | 13 +++++---- crates/xen/xenstore/src/sys.rs | 13 --------- crates/zone/src/background.rs | 10 ++++--- crates/zone/src/childwait.rs | 25 +++++++++-------- crates/zone/src/exec.rs | 33 +++++++++++++++------- 6 files changed, 77 insertions(+), 67 deletions(-) diff --git a/crates/daemon/src/idm.rs b/crates/daemon/src/idm.rs index 6ff16074..37581474 100644 --- a/crates/daemon/src/idm.rs +++ b/crates/daemon/src/idm.rs @@ -136,34 +136,36 @@ impl DaemonIdm { if let Some(data) = data { let buffer = buffers.entry(domid).or_insert_with_key(|_| BytesMut::new()); buffer.extend_from_slice(&data); - if buffer.len() < 6 { - continue; - } + loop { + if buffer.len() < 6 { + break; + } - if buffer[0] != 0xff || buffer[1] != 0xff { - buffer.clear(); - continue; - } + if buffer[0] != 0xff || buffer[1] != 0xff { + buffer.clear(); + break; + } - let size = (buffer[2] as u32 | (buffer[3] as u32) << 8 | (buffer[4] as u32) << 16 | (buffer[5] as u32) << 24) as usize; - let needed = size + 6; - if buffer.len() < needed { - continue; - } - let mut packet = buffer.split_to(needed); - packet.advance(6); - match IdmTransportPacket::decode(packet) { - Ok(packet) => { - let _ = client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds).await?; - let guard = self.feeds.lock().await; - if let Some(feed) = guard.get(&domid) { - let _ = feed.try_send(packet.clone()); - } - let _ = self.snoop_sender.send(DaemonIdmSnoopPacket { from: domid, to: 0, packet }); + let size = (buffer[2] as u32 | (buffer[3] as u32) << 8 | (buffer[4] as u32) << 16 | (buffer[5] as u32) << 24) as usize; + let needed = size + 6; + if buffer.len() < needed { + break; } + let mut packet = buffer.split_to(needed); + packet.advance(6); + match IdmTransportPacket::decode(packet) { + Ok(packet) => { + let _ = client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds).await?; + let guard = self.feeds.lock().await; + if let Some(feed) = guard.get(&domid) { + let _ = feed.try_send(packet.clone()); + } + let _ = self.snoop_sender.send(DaemonIdmSnoopPacket { from: domid, to: 0, packet }); + } - Err(packet) => { - warn!("received invalid packet from domain {}: {}", domid, packet); + Err(packet) => { + warn!("received invalid packet from domain {}: {}", domid, packet); + } } } } else { diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index a9352500..6e34f59e 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -9,6 +9,7 @@ use std::{ }; use anyhow::{anyhow, Result}; +use bytes::{BufMut, BytesMut}; use log::{debug, error}; use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, SetArg}; use prost::Message; @@ -96,10 +97,12 @@ impl IdmBackend for IdmFileBackend { async fn send(&mut self, packet: IdmTransportPacket) -> Result<()> { let mut file = self.write.lock().await; - let data = packet.encode_to_vec(); - file.write_all(&[0xff, 0xff]).await?; - file.write_u32_le(data.len() as u32).await?; - file.write_all(&data).await?; + let length = packet.encoded_len(); + let mut buffer = BytesMut::with_capacity(6 + length); + buffer.put_slice(&[0xff, 0xff]); + buffer.put_u32_le(length as u32); + packet.encode(&mut buffer)?; + file.write_all(&buffer).await?; Ok(()) } } @@ -488,7 +491,7 @@ impl IdmClient { error!("unable to send idm packet, packet size exceeded (tried to send {} bytes)", length); continue; } - backend.send(packet).await?; + backend.send(packet.clone()).await?; }, None => { diff --git a/crates/xen/xenstore/src/sys.rs b/crates/xen/xenstore/src/sys.rs index f25e1df4..fb575191 100644 --- a/crates/xen/xenstore/src/sys.rs +++ b/crates/xen/xenstore/src/sys.rs @@ -143,19 +143,6 @@ pub const XSD_ERROR_EPERM: XsdError = XsdError { pub const XSD_WATCH_PATH: u32 = 0; pub const XSD_WATCH_TOKEN: u32 = 1; -#[repr(C)] -pub struct XenDomainInterface { - req: [i8; 1024], - rsp: [i8; 1024], - req_cons: u32, - req_prod: u32, - rsp_cons: u32, - rsp_prod: u32, - server_features: u32, - connection: u32, - error: u32, -} - pub const XS_PAYLOAD_MAX: u32 = 4096; pub const XS_ABS_PATH_MAX: u32 = 3072; pub const XS_REL_PATH_MAX: u32 = 2048; diff --git a/crates/zone/src/background.rs b/crates/zone/src/background.rs index bb44d514..40ba8d91 100644 --- a/crates/zone/src/background.rs +++ b/crates/zone/src/background.rs @@ -39,6 +39,7 @@ impl ZoneBackground { let mut event_subscription = self.idm.subscribe().await?; let mut requests_subscription = self.idm.requests().await?; let mut request_streams_subscription = self.idm.request_streams().await?; + let mut wait_subscription = self.wait.subscribe().await?; loop { select! { x = event_subscription.recv() => match x { @@ -85,9 +86,9 @@ impl ZoneBackground { } }, - event = self.wait.recv() => match event { - Some(event) => self.child_event(event).await?, - None => { + event = wait_subscription.recv() => match event { + Ok(event) => self.child_event(event).await?, + Err(_) => { break; } } @@ -128,9 +129,10 @@ impl ZoneBackground { &mut self, handle: IdmClientStreamResponseHandle, ) -> Result<()> { + let wait = self.wait.clone(); if let Some(RequestType::ExecStream(_)) = &handle.initial.request { tokio::task::spawn(async move { - let exec = ZoneExecTask { handle }; + let exec = ZoneExecTask { wait, handle }; if let Err(error) = exec.run().await { let _ = exec .handle diff --git a/crates/zone/src/childwait.rs b/crates/zone/src/childwait.rs index 331e8c00..309aabab 100644 --- a/crates/zone/src/childwait.rs +++ b/crates/zone/src/childwait.rs @@ -11,7 +11,7 @@ use anyhow::Result; use libc::{c_int, waitpid, WEXITSTATUS, WIFEXITED}; use log::warn; use nix::unistd::Pid; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::broadcast::{channel, Receiver, Sender}; const CHILD_WAIT_QUEUE_LEN: usize = 10; @@ -21,18 +21,19 @@ pub struct ChildEvent { pub status: c_int, } +#[derive(Clone)] pub struct ChildWait { - receiver: Receiver, + sender: Sender, signal: Arc, - _task: JoinHandle<()>, + _task: Arc>, } impl ChildWait { pub fn new() -> Result { - let (sender, receiver) = channel(CHILD_WAIT_QUEUE_LEN); + let (sender, _) = channel(CHILD_WAIT_QUEUE_LEN); let signal = Arc::new(AtomicBool::new(false)); let mut processor = ChildWaitTask { - sender, + sender: sender.clone(), signal: signal.clone(), }; let task = thread::spawn(move || { @@ -41,14 +42,14 @@ impl ChildWait { } }); Ok(ChildWait { - receiver, + sender, signal, - _task: task, + _task: Arc::new(task), }) } - pub async fn recv(&mut self) -> Option { - self.receiver.recv().await + pub async fn subscribe(&self) -> Result> { + Ok(self.sender.subscribe()) } } @@ -68,7 +69,7 @@ impl ChildWaitTask { pid: Pid::from_raw(pid), status: WEXITSTATUS(status), }; - let _ = self.sender.try_send(event); + let _ = self.sender.send(event); if self.signal.load(Ordering::Acquire) { return Ok(()); @@ -80,6 +81,8 @@ impl ChildWaitTask { impl Drop for ChildWait { fn drop(&mut self) { - self.signal.store(true, Ordering::Release); + if Arc::strong_count(&self.signal) <= 1 { + self.signal.store(true, Ordering::Release); + } } } diff --git a/crates/zone/src/exec.rs b/crates/zone/src/exec.rs index 35f0f59c..941fc14a 100644 --- a/crates/zone/src/exec.rs +++ b/crates/zone/src/exec.rs @@ -1,6 +1,12 @@ use std::{collections::HashMap, process::Stdio}; use anyhow::{anyhow, Result}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + join, + process::Command, +}; + use krata::idm::{ client::IdmClientStreamResponseHandle, internal::{ @@ -9,13 +15,11 @@ use krata::idm::{ }, internal::{response::Response as ResponseType, Request, Response}, }; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - join, - process::Command, -}; + +use crate::childwait::ChildWait; pub struct ZoneExecTask { + pub wait: ChildWait, pub handle: IdmClientStreamResponseHandle, } @@ -58,6 +62,7 @@ impl ZoneExecTask { start.working_directory.clone() }; + let mut wait_subscription = self.wait.subscribe().await?; let mut child = Command::new(exe) .args(cmd) .envs(env) @@ -69,6 +74,7 @@ impl ZoneExecTask { .spawn() .map_err(|error| anyhow!("failed to spawn: {}", error))?; + let pid = child.id().ok_or_else(|| anyhow!("pid is not provided"))?; let mut stdin = child .stdin .take() @@ -150,12 +156,19 @@ impl ZoneExecTask { } }); - let exit = child.wait().await?; - let code = exit.code().unwrap_or(-1); - - let _ = join!(stdout_task, stderr_task); - stdin_task.abort(); + let data_task = tokio::task::spawn(async move { + let _ = join!(stdout_task, stderr_task); + stdin_task.abort(); + }); + let code = loop { + if let Ok(event) = wait_subscription.recv().await { + if event.pid.as_raw() as u32 == pid { + break event.status; + } + } + }; + data_task.await?; let response = Response { response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate { exited: true,