diff --git a/dev_tests/src/ratchet.rs b/dev_tests/src/ratchet.rs index 99887e9cd..06c44a080 100644 --- a/dev_tests/src/ratchet.rs +++ b/dev_tests/src/ratchet.rs @@ -39,7 +39,7 @@ fn ratchet_globals() -> Result<()> { ("litebox_platform_linux_userland/", 5), ("litebox_platform_lvbs/", 24), ("litebox_platform_multiplex/", 1), - ("litebox_platform_windows_userland/", 7), + ("litebox_platform_windows_userland/", 8), ("litebox_runner_linux_userland/", 1), ("litebox_runner_lvbs/", 4), ("litebox_runner_snp/", 1), diff --git a/litebox/src/event/wait.rs b/litebox/src/event/wait.rs index af0bdf14d..22d11ae87 100644 --- a/litebox/src/event/wait.rs +++ b/litebox/src/event/wait.rs @@ -374,6 +374,10 @@ impl<'a, Platform: RawSyncPrimitivesProvider + TimeProvider> WaitContext<'a, Pla /// evaluating the wait and interrupt conditions so that wakeups are not /// missed. fn start_wait(&self) { + self.waker + .0 + .platform + .on_interruptible_wait_start(self.waker); self.waker .0 .set_state(ThreadState::WAITING, Ordering::SeqCst); @@ -384,6 +388,7 @@ impl<'a, Platform: RawSyncPrimitivesProvider + TimeProvider> WaitContext<'a, Pla self.waker .0 .set_state(ThreadState::RUNNING_IN_HOST, Ordering::Relaxed); + self.waker.0.platform.on_interruptible_wait_end(); } /// Checks whether the wait should be interrupted. If not, then performs diff --git a/litebox/src/platform/mod.rs b/litebox/src/platform/mod.rs index 3e81b0dfc..95002c9e4 100644 --- a/litebox/src/platform/mod.rs +++ b/litebox/src/platform/mod.rs @@ -95,6 +95,34 @@ pub trait ThreadProvider: RawPointerProvider { /// [`EnterShim`]: crate::shim::EnterShim /// [`EnterShim::interrupt`]: crate::shim::EnterShim::interrupt fn interrupt_thread(&self, thread: &Self::ThreadHandle); + + /// Runs `f` on the current thread after performing any platform-specific + /// thread registration needed for [`current_thread`](Self::current_thread) + /// and related functionality to work. + /// + /// This is intended for test threads that do not go through the normal + /// [`spawn_thread`](Self::spawn_thread) / guest entry path. The platform + /// sets up thread state before calling `f` and tears it down afterward. + /// + /// The default implementation simply calls `f()` with no additional setup. + /// Platforms that require explicit thread registration should override this. + #[cfg(debug_assertions)] + fn run_test_thread(f: impl FnOnce() -> R) -> R { + f() + } +} + +/// Provider for consuming platform-originating signals. +/// +/// Platforms can record signals (e.g., `SIGINT`) and the shim should call +/// [`SignalProvider::take_pending_signals`] to consume them. +pub trait SignalProvider { + /// Atomically take all pending asynchronous signals (e.g., SIGINT and SIGALRM) + /// for the current thread, passing each one to `f`. + /// + /// Platforms that support asynchronous signals should override this method. + #[allow(unused_variables, reason = "no-op by default")] + fn take_pending_signals(&self, f: impl FnMut(crate::shim::Signal)) {} } /// Punch through any functionality for a particular platform that is not explicitly part of the @@ -220,6 +248,25 @@ where /// A provider of raw mutexes pub trait RawMutexProvider { type RawMutex: RawMutex; + + /// Called when a thread enters an interruptible wait. + /// + /// The passed `waker` should live at least until the thread leaves the interruptible + /// wait (i.e., [`on_interruptible_wait_end`](Self::on_interruptible_wait_end) is called). + /// The platform can use the `waker` to wake up the thread while it is in the interruptible wait. + /// + /// This is a no-op by default. + #[allow(unused_variables)] + fn on_interruptible_wait_start(&self, waker: &crate::event::wait::Waker) + where + Self: crate::sync::RawSyncPrimitivesProvider + Sized, + { + } + + /// Called when a thread leaves an interruptible wait. + /// + /// This is a no-op by default. + fn on_interruptible_wait_end(&self) {} } /// A raw mutex/lock API; expected to roughly match (or even be implemented using) a Linux futex. diff --git a/litebox/src/shim.rs b/litebox/src/shim.rs index 432eb6c5e..6ffee69fa 100644 --- a/litebox/src/shim.rs +++ b/litebox/src/shim.rs @@ -133,3 +133,108 @@ impl Exception { /// #PF pub const PAGE_FAULT: Self = Self(14); } + +/// A signal number. +/// +/// Signal numbers are 1-based and must be in the range 1–63. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct Signal(u32); + +impl Signal { + /// SIGINT (signal 2) — interrupt from keyboard (Ctrl+C). + pub const SIGINT: Self = Self(2); + /// SIGALRM (signal 14) — timer signal from `alarm`. + pub const SIGALRM: Self = Self(14); + + /// Create a `Signal` from a raw signal number. + /// + /// Returns `None` if `signum` is outside the valid range 1–63. + pub const fn from_raw(signum: u32) -> Option { + if signum >= 1 && signum <= 63 { + Some(Self(signum)) + } else { + None + } + } + + /// Returns the raw signal number. + pub const fn as_raw(self) -> u32 { + self.0 + } +} + +/// A set of [`Signal`]s, stored as a 64-bit bitmask. +/// +/// Bit `(signum - 1)` is set when signal `signum` is present in the set. +/// Because signal numbers are 1-based and capped at 63, all 63 possible +/// signals fit in a single `u64`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SigSet(u64); + +impl SigSet { + /// An empty signal set. + pub const fn empty() -> Self { + Self(0) + } + + /// Returns `true` if the set contains no signals. + pub const fn is_empty(&self) -> bool { + self.0 == 0 + } + + /// Adds `signal` to the set. + pub const fn add(&mut self, signal: Signal) { + self.0 |= 1 << (signal.0 - 1); + } + + /// Returns a new set that is `self` with `signal` added. + #[must_use] + pub const fn with(self, signal: Signal) -> Self { + Self(self.0 | (1 << (signal.0 - 1))) + } + + /// Removes `signal` from the set. + pub const fn remove(&mut self, signal: Signal) { + self.0 &= !(1 << (signal.0 - 1)); + } + + /// Returns `true` if the set contains `signal`. + pub const fn contains(&self, signal: Signal) -> bool { + (self.0 & (1 << (signal.0 - 1))) != 0 + } + + /// Removes and returns the lowest-numbered signal in the set, or `None` + /// if empty. + pub fn pop_lowest(&mut self) -> Option { + if self.0 == 0 { + return None; + } + let bit = self.0.trailing_zeros(); + self.0 &= !(1u64 << bit); + // bit is 0–62, so bit + 1 is 1–63 — always valid. + Some(Signal(bit + 1)) + } + + /// Creates a `SigSet` from a raw `u64` bitmask. + pub const fn from_u64(bits: u64) -> Self { + Self(bits) + } + + /// Returns the underlying `u64` bitmask. + pub const fn as_u64(&self) -> u64 { + self.0 + } +} + +impl Iterator for SigSet { + type Item = Signal; + + fn next(&mut self) -> Option { + self.pop_lowest() + } + + fn size_hint(&self) -> (usize, Option) { + let count = self.0.count_ones() as usize; + (count, Some(count)) + } +} diff --git a/litebox_common_linux/src/signal/mod.rs b/litebox_common_linux/src/signal/mod.rs index 713a24559..1f3794241 100644 --- a/litebox_common_linux/src/signal/mod.rs +++ b/litebox_common_linux/src/signal/mod.rs @@ -102,6 +102,17 @@ impl TryFrom for Signal { } } } +impl TryFrom for litebox::shim::Signal { + type Error = Signal; + + fn try_from(value: Signal) -> Result { + match value { + Signal::SIGINT => Ok(Self::SIGINT), + Signal::SIGALRM => Ok(Self::SIGALRM), + _ => Err(value), + } + } +} /// The default disposition of a signal. pub enum SignalDisposition { diff --git a/litebox_platform_linux_kernel/src/lib.rs b/litebox_platform_linux_kernel/src/lib.rs index 9a4b7d230..3983e6a65 100644 --- a/litebox_platform_linux_kernel/src/lib.rs +++ b/litebox_platform_linux_kernel/src/lib.rs @@ -10,13 +10,13 @@ use core::sync::atomic::AtomicU64; use core::{arch::asm, sync::atomic::AtomicU32}; use litebox::mm::linux::PageRange; +use litebox::platform::RawPointerProvider; use litebox::platform::page_mgmt::FixedAddressBehavior; use litebox::platform::{ DebugLogProvider, IPInterfaceProvider, ImmediatelyWokenUp, PageManagementProvider, Provider, - Punchthrough, PunchthroughProvider, PunchthroughToken, RawMutexProvider, TimeProvider, - UnblockedOrTimedOut, + Punchthrough, PunchthroughProvider, PunchthroughToken, RawMutexProvider, SignalProvider, + TimeProvider, UnblockedOrTimedOut, }; -use litebox::platform::{RawMutex as _, RawPointerProvider}; use litebox_common_linux::PunchthroughSyscall; use litebox_common_linux::errno::Errno; @@ -79,6 +79,7 @@ impl<'a, Host: HostInterface> PunchthroughToken for LinuxPunchthroughToken<'a, H } impl Provider for LinuxKernel {} +impl SignalProvider for LinuxKernel {} // TODO: implement pointer validation to ensure the pointers are in user space. type UserConstPtr = litebox::platform::common_providers::userspace_pointers::UserConstPtr< @@ -180,33 +181,16 @@ impl RawMutex { val: u32, timeout: Option, ) -> Result { - loop { - // No need to wait if the value already changed. - if self - .underlying_atomic() - .load(core::sync::atomic::Ordering::Relaxed) - != val - { - return Err(ImmediatelyWokenUp); + match Host::block_or_maybe_timeout(&self.inner, val, timeout) { + Ok(()) | Err(Errno::EINTR) => Ok(UnblockedOrTimedOut::Unblocked), + Err(Errno::EAGAIN) => { + // If the futex value does not match val, then the call fails + // immediately with the error EAGAIN. + Err(ImmediatelyWokenUp) } - - let ret = Host::block_or_maybe_timeout(&self.inner, val, timeout); - - match ret { - Ok(()) => { - return Ok(UnblockedOrTimedOut::Unblocked); - } - Err(Errno::EAGAIN | Errno::EINTR) => { - // If the futex value does not match val, then the call fails - // immediately with the error EAGAIN. - return Err(ImmediatelyWokenUp); - } - Err(Errno::ETIMEDOUT) => { - return Ok(UnblockedOrTimedOut::TimedOut); - } - Err(e) => { - todo!("Error: {:?}", e); - } + Err(Errno::ETIMEDOUT) => Ok(UnblockedOrTimedOut::TimedOut), + Err(e) => { + todo!("Error: {:?}", e); } } } diff --git a/litebox_platform_linux_userland/src/lib.rs b/litebox_platform_linux_userland/src/lib.rs index 3653f7000..f41a38f78 100644 --- a/litebox_platform_linux_userland/src/lib.rs +++ b/litebox_platform_linux_userland/src/lib.rs @@ -11,6 +11,7 @@ use std::cell::Cell; use std::os::fd::{AsRawFd as _, FromRawFd as _}; use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; use std::time::Duration; +use std::unimplemented; use litebox::fs::OFlags; use litebox::platform::UnblockedOrTimedOut; @@ -28,6 +29,84 @@ mod syscall_intercept; extern crate alloc; +// --------------------------------------------------------------------------- +// TLS (`.tbss`) access helpers +// +// On x86_64, the ELF TLS model uses `@tpoff`; on x86 it uses `@ntpoff`. +// At guest-host transitions we swap `fs` and `gs`, so after the swap the host TLS base +// is in the normal segment register. Before the swap (e.g. in a signal +// handler that fires while the guest is running), the host TLS base is +// in the *saved* segment register (`gs` on x86_64, `fs` on x86). +// +// The macros below produce string literals so they can be used inside +// `concat!()` within `core::arch::asm!()`. +// --------------------------------------------------------------------------- + +/// TLS relocation suffix: `"@tpoff"` on x86_64, `"@ntpoff"` on x86. +#[cfg(target_arch = "x86_64")] +macro_rules! tls_suffix { + () => { + "@tpoff" + }; +} +#[cfg(target_arch = "x86")] +macro_rules! tls_suffix { + () => { + "@ntpoff" + }; +} + +/// Segment register used for TLS after the fs/gs swap (normal host context). +#[cfg(target_arch = "x86_64")] +macro_rules! tls_seg { + () => { + "fs" + }; +} +#[cfg(target_arch = "x86")] +macro_rules! tls_seg { + () => { + "gs" + }; +} + +/// Segment register where the host TLS base is saved before the swap +/// (signal handler context while the guest is running). +#[cfg(target_arch = "x86_64")] +macro_rules! saved_tls_seg { + () => { + "gs" + }; +} +#[cfg(target_arch = "x86")] +macro_rules! saved_tls_seg { + () => { + "fs" + }; +} + +/// Full TLS memory operand for a `.tbss` variable in normal host context +/// (after the fs/gs swap). +/// +/// Example: `tls!("pending_host_signals")` expands to +/// `"fs:pending_host_signals@tpoff"` on x86_64. +macro_rules! tls { + ($var:literal) => { + concat!(tls_seg!(), ":", $var, tls_suffix!()) + }; +} + +/// Full TLS memory operand for a `.tbss` variable accessed via the *saved* +/// segment register (before the fs/gs swap, e.g. from a signal handler). +/// +/// Example: `saved_tls!("in_guest")` expands to +/// `"gs:in_guest@tpoff"` on x86_64. +macro_rules! saved_tls { + ($var:literal) => { + concat!(saved_tls_seg!(), ":", $var, tls_suffix!()) + }; +} + /// The userland Linux platform. /// /// This implements the main [`litebox::platform::Provider`] trait, i.e., implements all platform @@ -308,6 +387,31 @@ impl LinuxUserland { impl litebox::platform::Provider for LinuxUserland {} +impl litebox::platform::SignalProvider for LinuxUserland { + fn take_pending_signals(&self, mut f: impl FnMut(litebox::shim::Signal)) { + let sigs = take_pending_host_signals(); + for sig in sigs { + f(sig); + } + } +} + +/// Atomically takes the per-thread pending host signal bitmask. +fn take_pending_host_signals() -> litebox::shim::SigSet { + // Atomically swap the per-thread pending signals with zero. + // Only the low 32 bits are used (covers traditional signals 1-31). + let lo: u32; + unsafe { + core::arch::asm!( + "xor {tmp:e}, {tmp:e}", + concat!("xchg DWORD PTR ", tls!("pending_host_signals"), ", {tmp:e}"), + tmp = out(reg) lo, + options(nostack) + ); + } + litebox::shim::SigSet::from_u64(u64::from(lo)) +} + /// Runs a guest thread using the provided shim and the given initial context. /// /// This will run until the thread terminates or returns. @@ -389,6 +493,14 @@ in_guest: .globl interrupt interrupt: .byte 0 + .align 4 +.globl pending_host_signals +pending_host_signals: + .long 0 + .align 8 +.globl wait_waker_addr +wait_waker_addr: + .quad 0 " ); @@ -798,6 +910,14 @@ in_guest: .globl interrupt interrupt: .byte 0 + .align 4 +.globl pending_host_signals +pending_host_signals: + .long 0 + .align 4 +.global wait_waker_addr +wait_waker_addr: + .long 0 " ); @@ -837,6 +957,46 @@ unsafe extern "fastcall" fn switch_to_guest(ctx: &litebox_common_linux::PtRegs) ); } +/// Non-guest threads (e.g., network workers, background tasks) should call this +/// function at the start of their execution so the kernel only delivers +/// `SIGALRM` / `SIGINT` to guest threads, which have the proper signal-handler +/// context to re-enter the shim. +fn block_guest_signals() { + unsafe { + let mut set: libc::sigset_t = std::mem::zeroed(); + libc::sigemptyset(&raw mut set); + libc::sigaddset(&raw mut set, libc::SIGALRM); + libc::sigaddset(&raw mut set, libc::SIGINT); + libc::pthread_sigmask(libc::SIG_BLOCK, &raw const set, std::ptr::null_mut()); + } +} + +/// Spawn a non-guest ("host") thread that automatically blocks guest interrupt +/// signals before running `f`. +/// +/// Every background thread created by a runner (network workers, I/O helpers, +/// etc.) should use this function instead of [`std::thread::spawn`] to ensure +/// that `SIGALRM` and `SIGINT` are only delivered to guest threads. +/// +/// # Example +/// +/// ```ignore +/// let handle = litebox_platform_linux_userland::spawn_host_thread(move || { +/// // This thread will never receive SIGALRM or SIGINT. +/// do_background_work(); +/// }); +/// ``` +pub fn spawn_host_thread(f: F) -> std::thread::JoinHandle +where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, +{ + std::thread::spawn(move || { + block_guest_signals(); + f() + }) +} + fn thread_start( init_thread: Box< dyn litebox::shim::InitThread, @@ -926,10 +1086,63 @@ impl litebox::platform::ThreadProvider for LinuxUserland { fn interrupt_thread(&self, thread: &Self::ThreadHandle) { thread.interrupt(); } + + #[cfg(debug_assertions)] + fn run_test_thread(f: impl FnOnce() -> R) -> R { + // Sets `gsbase = fsbase` (x86_64) or `fs = gs` (x86) on the current thread + // to mirror the TLS base used in guest context, so that test threads can use the + // same TLS access code as guest threads. + #[cfg(target_arch = "x86_64")] + unsafe { + core::arch::asm!( + "rdfsbase {tmp}", + "wrgsbase {tmp}", + tmp = out(reg) _, + options(nostack, preserves_flags), + ); + } + #[cfg(target_arch = "x86")] + { + unsafe { + core::arch::asm!( + "mov {tmp:x}, gs", + "mov fs, {tmp:x}", + tmp = out(reg) _, + options(nostack, preserves_flags), + ); + } + } + + ThreadHandle::run_with_handle(f) + } } impl litebox::platform::RawMutexProvider for LinuxUserland { type RawMutex = RawMutex; + + fn on_interruptible_wait_start(&self, waker: &litebox::event::wait::Waker) + where + Self: litebox::sync::RawSyncPrimitivesProvider, + { + let waker_ptr = waker as *const litebox::event::wait::Waker; + unsafe { + core::arch::asm!( + concat!("mov ", tls!("wait_waker_addr"), ", {}"), + in(reg) waker_ptr, + options(nostack, preserves_flags), + ); + } + } + + fn on_interruptible_wait_end(&self) { + unsafe { + core::arch::asm!( + concat!("mov ", tls!("wait_waker_addr"), ", {zero}"), + zero = in(reg) 0usize, + options(nostack, preserves_flags), + ); + } + } } pub struct RawMutex { @@ -949,30 +1162,21 @@ impl RawMutex { val: u32, timeout: Option, ) -> Result { - // We immediately wake up (without even hitting syscalls) if we can clearly see that the - // value is different. - if self.inner.load(Ordering::SeqCst) != val { - return Err(ImmediatelyWokenUp); - } - // We wait on the futex, with a timeout if needed - loop { - break match futex_timeout( - &self.inner, - FutexOperation::Wait, - /* expected value */ val, - timeout, - /* ignored */ None, - ) { - Ok(0) => Ok(UnblockedOrTimedOut::Unblocked), - Err(syscalls::Errno::EAGAIN) => Err(ImmediatelyWokenUp), - Err(syscalls::Errno::ETIMEDOUT) => Ok(UnblockedOrTimedOut::TimedOut), - Err(syscalls::Errno::EINTR) => continue, - Err(e) => { - panic!("Unexpected errno={e} for FUTEX_WAIT") - } - _ => unreachable!(), - }; + match futex_timeout( + &self.inner, + FutexOperation::Wait, + /* expected value */ val, + timeout, + /* ignored */ None, + ) { + Ok(0) | Err(syscalls::Errno::EINTR) => Ok(UnblockedOrTimedOut::Unblocked), + Err(syscalls::Errno::EAGAIN) => Err(ImmediatelyWokenUp), + Err(syscalls::Errno::ETIMEDOUT) => Ok(UnblockedOrTimedOut::TimedOut), + Err(e) => { + panic!("Unexpected errno={e} for FUTEX_WAIT") + } + _ => unreachable!(), } } } @@ -1618,14 +1822,8 @@ impl ThreadContext<'_> { // now (by calling into the shim), and it might be set again by the shim // before returning. unsafe { - #[cfg(target_arch = "x86_64")] core::arch::asm!( - "mov BYTE PTR fs:interrupt@tpoff, 0", - options(nostack, preserves_flags) - ); - #[cfg(target_arch = "x86")] - core::arch::asm!( - "mov BYTE PTR gs:interrupt@ntpoff, 0", + concat!("mov BYTE PTR ", tls!("interrupt"), ", 0"), options(nostack, preserves_flags) ); } @@ -1755,6 +1953,25 @@ fn register_exception_handlers() { ); } } + + // Note that non-guest threads should block these signals, so it always fires on a guest thread. + let traditional_signals = &[libc::SIGINT, libc::SIGALRM]; + for &sig in traditional_signals { + unsafe { + let mut sa: libc::sigaction = core::mem::zeroed(); + sa.sa_flags = libc::SA_SIGINFO | libc::SA_ONSTACK; + sa.sa_sigaction = interrupt_signal_handler as *const () as usize; + // Block the interrupt signal while handling signals + libc::sigaddset(&raw mut sa.sa_mask, interrupt_signal); + let mut old_sa = core::mem::zeroed(); + sigaction(sig, Some(&sa), &mut old_sa); + assert_eq!( + old_sa.sa_sigaction, + libc::SIG_DFL, + "signal {sig} handler already installed", + ); + } + } }); } @@ -2156,12 +2373,112 @@ unsafe fn next_signal_handler( } } +/// Async-signal-safe wake of a thread blocked in an interruptible wait. +/// +/// This is the signal-handler counterpart of `WaitStateInner::wake()`: it +/// CAS's the condvar from WAITING to WOKEN and issues a futex wake so the +/// blocked thread returns from `futex_wait`. +/// +/// `waker_addr` is the raw address read from the `wait_waker_addr` TLS +/// variable (0 means no waker is registered). +fn try_wake_wait_waker(waker_addr: usize) { + if waker_addr == 0 { + return; + } + // SAFETY: waker_addr points to a valid Waker whose + // lifetime spans the entire interruptible wait, set by + // RawMutexProvider::on_interruptible_wait_start. + let waker = unsafe { &*(waker_addr as *const litebox::event::wait::Waker) }; + waker.wake(); +} + +/// Records a pending host signal in the `.tbss` bitmask and wakes any condvar +/// the thread is blocked on. +/// +/// # Safety +/// +/// Must be called from a signal handler on a guest thread whose saved host TLS +/// segment register is valid. +unsafe fn record_pending_signal(signal: litebox::shim::Signal) { + let mask: u32 = 1u32 << (signal.as_raw() - 1); + unsafe { + core::arch::asm!( + concat!("lock or DWORD PTR ", saved_tls!("pending_host_signals"), ", {mask:e}"), + mask = in(reg) mask, + options(nostack) + ); + } + let waker_addr: usize; + unsafe { + core::arch::asm!( + concat!("mov {}, ", saved_tls!("wait_waker_addr")), + out(reg) waker_addr, + options(nostack, preserves_flags) + ); + } + try_wake_wait_waker(waker_addr); +} + /// Signal handler for interrupt signals. unsafe fn interrupt_signal_handler( - _signum: libc::c_int, + signum: libc::c_int, _info: &mut libc::siginfo_t, context: &mut libc::ucontext_t, ) { + let raise_signal = |signum: libc::c_int| { + // block the signal on this non-guest thread so the kernel won't + // deliver it here again, then re-raise as process-directed so a + // guest thread picks it up. + // + // This should only be called by test threads (spawned via cargo test). + // Other non-guest threads like network worker threads should have already blocked these signals. + unsafe { + let mut set: libc::sigset_t = core::mem::zeroed(); + libc::sigemptyset(&raw mut set); + libc::sigaddset(&raw mut set, signum); + libc::pthread_sigmask(libc::SIG_BLOCK, &raw const set, std::ptr::null_mut()); + libc::kill(libc::getpid(), signum); + } + }; + + // Record host-originated signals (SIGINT, SIGALRM, etc.) in the + // per-thread pending bitmask so the shim can forward them to the guest. + // TODO: no realtime signal support for now. + if signum > 0 && signum < 32 { + // Only record signals that can be forwarded to the guest as + // litebox::shim::Signal. Unknown signals are silently dropped. + let Ok(signal) = litebox_common_linux::signal::Signal::try_from(signum) else { + return; + }; + let Ok(signal) = litebox::shim::Signal::try_from(signal) else { + return; + }; + + // Check whether the saved host TLS segment is valid (i.e. this is a + // guest thread). If not, re-raise the signal process-wide. + let is_guest_thread; + #[cfg(target_arch = "x86_64")] + { + let gsbase: u64; + unsafe { core::arch::asm!("rdgsbase {}", out(reg) gsbase) }; + is_guest_thread = gsbase != 0; + } + #[cfg(target_arch = "x86")] + { + let fs: u16; + unsafe { core::arch::asm!("mov {:x}, fs", out(reg) fs, options(nostack, nomem)) }; + is_guest_thread = fs != 0; + } + + if is_guest_thread { + // SAFETY: we verified the saved host TLS segment is valid above. + unsafe { record_pending_signal(signal) }; + } else { + raise_signal(signum); + return; + } + } + // The interrupt signal can arrive in different contexts: // 1. The thread is running in the host at the beginning of the syscall // handler. Do nothing--the syscall handler will handle the interrupt. diff --git a/litebox_platform_windows_userland/Cargo.toml b/litebox_platform_windows_userland/Cargo.toml index 6a20afe97..ed637f9a1 100644 --- a/litebox_platform_windows_userland/Cargo.toml +++ b/litebox_platform_windows_userland/Cargo.toml @@ -9,6 +9,7 @@ litebox = { path = "../litebox/", version = "0.1.0" } litebox_common_linux = { path = "../litebox_common_linux", version = "0.1.0" } windows-sys = { version = "0.60.2", features = [ + "Win32_System_Console", "Win32_System_Memory", "Win32_System_Threading", "Win32_System_SystemInformation", diff --git a/litebox_platform_windows_userland/src/lib.rs b/litebox_platform_windows_userland/src/lib.rs index fcfa3959f..71178de39 100644 --- a/litebox_platform_windows_userland/src/lib.rs +++ b/litebox_platform_windows_userland/src/lib.rs @@ -252,6 +252,14 @@ impl WindowsUserland { let _ = AddVectoredExceptionHandler(0, Some(vectored_exception_handler)); } + // Register a console control handler to receive Ctrl+C + unsafe { + windows_sys::Win32::System::Console::SetConsoleCtrlHandler( + Some(ctrl_c_handler), + 1, // TRUE — add the handler + ); + } + Box::leak(Box::new(platform)) } @@ -322,20 +330,26 @@ impl WindowsUserland { impl litebox::platform::Provider for WindowsUserland {} -/// Runs a guest thread using the provided shim and the given initial context. -/// -/// This will run until the thread terminates. +impl litebox::platform::SignalProvider for WindowsUserland { + fn take_pending_signals(&self, mut f: impl FnMut(litebox::shim::Signal)) { + let bits = get_tls_ptr().map_or(0, |p| { + unsafe { &*p } + .pending_host_signals + .swap(0, Ordering::SeqCst) + }); + let sigs = litebox::shim::SigSet::from_u64(u64::from(bits)); + for signal in sigs { + f(signal); + } + } +} + +/// Ensures the module-wide TLS slot index ([`TLS_INDEX`]) has been allocated. /// -/// # Safety -/// The context must be valid guest context. -#[expect( - clippy::missing_panics_doc, - reason = "the caller cannot control whether this will panic" -)] -pub unsafe fn run_thread( - shim: impl litebox::shim::EnterShim, - ctx: &mut litebox_common_linux::PtRegs, -) { +/// This must be called before any code that reads `TLS_INDEX`. Both +/// [`run_thread`] (guest threads) and [`run_test_thread`](WindowsUserland::run_test_thread) +/// (test threads) go through here. +fn ensure_tls_index() { // Allocate a TLS slot for this module if not already done. This is used as // a place to store data across calls to the guest, since all the registers // are used by the guest and will be clobbered. @@ -353,6 +367,19 @@ pub unsafe fn run_thread( ); TLS_INDEX.store(index, Ordering::Relaxed); }); +} + +/// Runs a guest thread using the provided shim and the given initial context. +/// +/// This will run until the thread terminates. +/// +/// # Safety +/// The context must be valid guest context. +pub unsafe fn run_thread( + shim: impl litebox::shim::EnterShim, + ctx: &mut litebox_common_linux::PtRegs, +) { + ensure_tls_index(); run_thread_inner(&shim, ctx); } @@ -360,25 +387,11 @@ fn run_thread_inner( shim: &dyn litebox::shim::EnterShim, ctx: &mut litebox_common_linux::PtRegs, ) { - let tls_index = TLS_INDEX.load(Ordering::Relaxed); - let tls_state = TlsState { - host_sp: Cell::new(core::ptr::null_mut()), - host_bp: Cell::new(core::ptr::null_mut()), - guest_context_top: std::ptr::from_mut(ctx).wrapping_add(1).into(), - scratch: 0.into(), - is_in_guest: false.into(), - interrupt: false.into(), - continue_context: Box::default(), - }; - unsafe { - windows_sys::Win32::System::Threading::TlsSetValue( - tls_index, - core::ptr::from_ref(&tls_state).cast(), - ); - } - let _tls_guard = litebox::utils::defer(|| unsafe { - windows_sys::Win32::System::Threading::TlsSetValue(tls_index, core::ptr::null()); - }); + let tls_state = TlsState::new(); + tls_state + .guest_context_top + .set(std::ptr::from_mut(ctx).wrapping_add(1)); + let mut thread_ctx = ThreadContext { shim, ctx, @@ -400,6 +413,52 @@ struct TlsState { interrupt: Cell, continue_context: Box>, + /// Bitmask of pending host-originated signals for this thread. + pending_host_signals: AtomicU32, + /// Pointer to the `Waker` currently being waited on, or null if not + /// waiting. + /// + /// Set by [`RawMutexProvider::on_interruptible_wait_start`] and cleared by + /// [`RawMutexProvider::on_interruptible_wait_end`]. + waiting_waker: std::sync::atomic::AtomicPtr>, +} + +impl TlsState { + /// Creates a new `TlsState` with all fields zeroed / defaulted. + fn new() -> Self { + Self { + host_sp: Cell::new(core::ptr::null_mut()), + host_bp: Cell::new(core::ptr::null_mut()), + guest_context_top: core::ptr::null_mut::().into(), + scratch: 0.into(), + is_in_guest: false.into(), + interrupt: false.into(), + continue_context: Box::default(), + pending_host_signals: AtomicU32::new(0), + waiting_waker: std::sync::atomic::AtomicPtr::new(std::ptr::null_mut()), + } + } +} + +/// Stores `tls` in the current thread's Windows TLS slot. +/// +/// # Safety +/// +/// The caller must ensure `tls` remains valid for the duration of its use. +unsafe fn install_tls(tls: &TlsState) { + let tls_index = TLS_INDEX.load(Ordering::Relaxed); + unsafe { + windows_sys::Win32::System::Threading::TlsSetValue( + tls_index, + core::ptr::from_ref(tls).cast(), + ); + } +} + +/// Clears the current thread's Windows TLS slot. +fn uninstall_tls() { + let tls_index = TLS_INDEX.load(Ordering::Relaxed); + unsafe { windows_sys::Win32::System::Threading::TlsSetValue(tls_index, core::ptr::null()) }; } fn get_tls_ptr() -> Option<*const TlsState> { @@ -407,9 +466,12 @@ fn get_tls_ptr() -> Option<*const TlsState> { if tls_index == u32::MAX { return None; } - Some(unsafe { - windows_sys::Win32::System::Threading::TlsGetValue(tls_index).cast::() - }) + let ptr = + unsafe { windows_sys::Win32::System::Threading::TlsGetValue(tls_index).cast::() }; + if ptr.is_null() { + return None; + } + Some(ptr) } /// Runs the guest thread until it terminates. @@ -747,6 +809,33 @@ impl litebox::platform::ThreadProvider for WindowsUserland { thread.interrupt(current.as_ref()); }); } + + #[cfg(debug_assertions)] + fn run_test_thread(f: impl FnOnce() -> R) -> R { + // Ensure the module-wide TLS slot is allocated. + ensure_tls_index(); + let tls = TlsState::new(); + ThreadHandle::run_with_handle(&tls, f) + } +} + +/// Console control handler registered via `SetConsoleCtrlHandler`. +/// +/// When the user presses Ctrl+C, this sets the SIGINT bit on every active +/// managed thread and interrupts them so the shim can deliver the signal. +unsafe extern "system" fn ctrl_c_handler(ctrl_type: u32) -> i32 { + if ctrl_type != windows_sys::Win32::System::Console::CTRL_C_EVENT { + return 0; // FALSE — let the next handler deal with it + } + + // Pick one arbitrary thread to deliver the signal to. + let thread = ACTIVE_THREADS.lock().unwrap().first().cloned(); + + if let Some(thread) = thread { + thread.deliver_signal(litebox::shim::Signal::SIGINT); + } + + 1 // TRUE — we handled it } #[derive(Clone)] @@ -764,34 +853,88 @@ thread_local! { static CURRENT_THREAD_HANDLE: RefCell> = const { RefCell::new(None) }; } +/// Global registry of all active managed thread handles. +/// +/// The Ctrl+C handler picks a thread from this list to deliver SIGINT. +/// Threads are registered in [`ThreadHandle::run_with_handle`] and +/// removed when the guard drops. +/// +/// TODO: This global list only works when we support a single process. For +/// multi-process support, each process (or `WindowsUserland` instance) should +/// track its own thread list. +static ACTIVE_THREADS: Mutex> = Mutex::new(alloc::vec::Vec::new()); + impl ThreadHandle { - /// Runs `f`, ensuring that [`CURRENT_THREAD_HANDLE`] is set while in the call to `f`. - fn run_with_handle(tls: &TlsState, f: impl FnOnce() -> R) -> R { + /// Creates a [`ThreadHandle`] referencing the calling OS thread. + fn for_current_thread(tls: &TlsState) -> ThreadHandle { let win_handle = unsafe { std::os::windows::io::BorrowedHandle::borrow_raw( windows_sys::Win32::System::Threading::GetCurrentThread(), ) }; - let handle = ThreadHandle(Arc::new(Mutex::new(Some(ThreadHandleInner { + ThreadHandle(Arc::new(Mutex::new(Some(ThreadHandleInner { handle: win_handle .try_clone_to_owned() .expect("failed to clone current thread handle"), tls: SendConstPtr(tls), - })))); + })))) + } + + /// Runs `f`, ensuring that [`CURRENT_THREAD_HANDLE`] is set while in the call to `f`. + fn run_with_handle(tls: &TlsState, f: impl FnOnce() -> R) -> R { + // Safety: `tls_state` lives for the duration of this call. + unsafe { install_tls(tls) }; + + let handle = Self::for_current_thread(tls); + ACTIVE_THREADS.lock().unwrap().push(handle.clone()); CURRENT_THREAD_HANDLE.with_borrow_mut(|current| { assert!( current.is_none(), - "nested run_with_handle calls are not supported" + "thread is already registered with LiteBox", ); - *current = Some(handle); + *current = Some(handle.clone()); }); - let _guard = litebox::utils::defer(|| { + let _guard = litebox::utils::defer(move || { let current = CURRENT_THREAD_HANDLE.take().unwrap(); + // Remove from the global registry. + ACTIVE_THREADS + .lock() + .unwrap() + .retain(|h| !Arc::ptr_eq(&h.0, ¤t.0)); *current.0.lock().unwrap() = None; + uninstall_tls(); }); f() } + /// Sets a pending signal on this thread, wakes it from any condvar wait, + /// and interrupts it so the shim processes the signal promptly. + fn deliver_signal(&self, signal: litebox::shim::Signal) { + let bit: u32 = 1 << (signal.as_raw() - 1); + + // Set the pending signal bit and wake the condvar in one lock scope. + { + let inner = self.0.lock().unwrap(); + if let Some(inner) = inner.as_ref() { + // Safety: the TLS pointer is valid as long as the thread is + // alive, and we hold the thread handle lock. + let tls = unsafe { &*inner.tls.0 }; + tls.pending_host_signals.fetch_or(bit, Ordering::SeqCst); + + let waker = tls.waiting_waker.load(Ordering::Acquire); + if !waker.is_null() { + // Safety: `waker` points to a valid `Waker` + // whose lifetime spans the interruptible wait, set by + // `RawMutexProvider::on_interruptible_wait_start`. + let waker = unsafe { &*waker }; + waker.wake(); + } + } + } + + self.interrupt(None); + } + /// Interrupt the thread represented by this handle, where `current` is the /// current thread's handle if it is managed by LiteBox. /// @@ -989,6 +1132,23 @@ fn is_in_ntdll_or_this(ip: usize) -> bool { impl litebox::platform::RawMutexProvider for WindowsUserland { type RawMutex = RawMutex; + + fn on_interruptible_wait_start(&self, waker: &litebox::event::wait::Waker) + where + Self: litebox::sync::RawSyncPrimitivesProvider, + { + if let Some(tls) = get_tls_ptr().map(|p| unsafe { &*p }) { + tls.waiting_waker + .store(std::ptr::from_ref(waker).cast_mut(), Ordering::Release); + } + } + + fn on_interruptible_wait_end(&self) { + if let Some(tls) = get_tls_ptr().map(|p| unsafe { &*p }) { + tls.waiting_waker + .store(std::ptr::null_mut(), Ordering::Release); + } + } } // A skeleton of a raw mutex for Windows. @@ -1004,17 +1164,12 @@ impl RawMutex { } } + #[expect(clippy::unnecessary_wraps)] fn block_or_maybe_timeout( &self, val: u32, timeout: Option, ) -> Result { - // We immediately wake up (without even hitting syscalls) if we can clearly see that the - // value is different. - if self.inner.load(Ordering::SeqCst) != val { - return Err(ImmediatelyWokenUp); - } - // Compute timeout in ms let timeout_ms = match timeout { None => Win32_Threading::INFINITE, // no timeout @@ -1039,15 +1194,8 @@ impl RawMutex { // Check why WaitOnAddress failed let err = unsafe { GetLastError() }; match err { - Win32_Foundation::ERROR_TIMEOUT => { - // Timed out - Ok(UnblockedOrTimedOut::TimedOut) - } - e => { - // Other error, possibly spurious wakeup or value changed - // Continue the loop to check the value again - panic!("Unexpected error={e} for WaitOnAddress"); - } + Win32_Foundation::ERROR_TIMEOUT => Ok(UnblockedOrTimedOut::TimedOut), + e => panic!("Unexpected error={e} for WaitOnAddress"), } } } @@ -1720,6 +1868,7 @@ unsafe extern "C-unwind" fn exception_handler( } unsafe extern "C-unwind" fn interrupt_handler(thread_ctx: &mut ThreadContext<'_>) { + thread_ctx.tls.is_in_guest.set(false); thread_ctx.call_shim(|shim, ctx, interrupt| { if interrupt { shim.interrupt(ctx) diff --git a/litebox_runner_linux_userland/src/lib.rs b/litebox_runner_linux_userland/src/lib.rs index 6fa68a5c3..9a17f1247 100644 --- a/litebox_runner_linux_userland/src/lib.rs +++ b/litebox_runner_linux_userland/src/lib.rs @@ -271,7 +271,7 @@ pub fn run(cli_args: CliArgs) -> Result<()> { let net_worker = if cli_args.tun_device_name.is_some() { let shim = shim.clone(); let shutdown_clone = shutdown.clone(); - let child = std::thread::spawn(move || { + let child = litebox_platform_linux_userland::spawn_host_thread(move || { const DEFAULT_TIMEOUT: core::time::Duration = core::time::Duration::from_millis(5); pin_thread_to_cpu(0); diff --git a/litebox_runner_linux_userland/tests/sigint.c b/litebox_runner_linux_userland/tests/sigint.c new file mode 100644 index 000000000..f3b49d943 --- /dev/null +++ b/litebox_runner_linux_userland/tests/sigint.c @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// Tests: SIGINT delivery (self-sent via raise/kill) + +#define _POSIX_C_SOURCE 200809L +#include +#include +#include +#include +#include + +#define TEST_ASSERT(cond, msg) do { \ + if (!(cond)) { \ + fprintf(stderr, "FAIL: %s (line %d): %s (errno=%d)\n", \ + __func__, __LINE__, msg, errno); \ + return 1; \ + } \ +} while(0) + +static volatile sig_atomic_t sigint_count = 0; + +static void sigint_handler(int sig) { + (void)sig; + sigint_count++; +} + +// Test that SIGINT default action terminates via exit_group when no handler +// is installed. We can't test this directly (it would kill us), so instead +// verify that installing SIG_IGN causes SIGINT to be ignored. +int test_sigint_ignore(void) { + struct sigaction sa; + sa.sa_handler = SIG_IGN; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + TEST_ASSERT(sigaction(SIGINT, &sa, NULL) == 0, "sigaction(SIG_IGN) failed"); + + // Sending SIGINT should be silently ignored. + TEST_ASSERT(raise(SIGINT) == 0, "raise(SIGINT) failed"); + + // Restore default handler. + sa.sa_handler = SIG_DFL; + sigaction(SIGINT, &sa, NULL); + + printf("sigint_ignore: PASS\n"); + return 0; +} + +// Test that SIGINT is delivered to a custom handler via raise(). +int test_sigint_raise(void) { + struct sigaction sa; + sa.sa_handler = sigint_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + TEST_ASSERT(sigaction(SIGINT, &sa, NULL) == 0, "sigaction failed"); + + sigint_count = 0; + TEST_ASSERT(raise(SIGINT) == 0, "raise(SIGINT) failed"); + + TEST_ASSERT(sigint_count == 1, "SIGINT handler should have been called exactly once"); + + // Restore default handler. + sa.sa_handler = SIG_DFL; + sigaction(SIGINT, &sa, NULL); + + printf("sigint_raise: PASS\n"); + return 0; +} + +// Test that SIGINT is delivered via kill(getpid(), SIGINT). +int test_sigint_kill(void) { + struct sigaction sa; + sa.sa_handler = sigint_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + TEST_ASSERT(sigaction(SIGINT, &sa, NULL) == 0, "sigaction failed"); + + sigint_count = 0; + TEST_ASSERT(kill(getpid(), SIGINT) == 0, "kill(getpid(), SIGINT) failed"); + + TEST_ASSERT(sigint_count == 1, "SIGINT handler should have been called exactly once"); + + // Restore default handler. + sa.sa_handler = SIG_DFL; + sigaction(SIGINT, &sa, NULL); + + printf("sigint_kill: PASS\n"); + return 0; +} + +// Test that SIGINT can be blocked and then delivered when unblocked. +int test_sigint_block_unblock(void) { + struct sigaction sa; + sa.sa_handler = sigint_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + TEST_ASSERT(sigaction(SIGINT, &sa, NULL) == 0, "sigaction failed"); + + sigint_count = 0; + + // Block SIGINT. + sigset_t block_set, old_set; + sigemptyset(&block_set); + sigaddset(&block_set, SIGINT); + TEST_ASSERT(sigprocmask(SIG_BLOCK, &block_set, &old_set) == 0, "sigprocmask(BLOCK) failed"); + + // Send SIGINT while blocked -- should be queued, not delivered. + TEST_ASSERT(raise(SIGINT) == 0, "raise(SIGINT) failed"); + TEST_ASSERT(sigint_count == 0, "SIGINT should be pending, not delivered"); + + // Unblock SIGINT -- should now be delivered. + TEST_ASSERT(sigprocmask(SIG_SETMASK, &old_set, NULL) == 0, "sigprocmask(RESTORE) failed"); + TEST_ASSERT(sigint_count == 1, "SIGINT should have been delivered after unblocking"); + + // Restore default handler. + sa.sa_handler = SIG_DFL; + sigaction(SIGINT, &sa, NULL); + + printf("sigint_block_unblock: PASS\n"); + return 0; +} + +// Test that multiple SIGINTs are coalesced (standard signals are not queued). +int test_sigint_coalesce(void) { + struct sigaction sa; + sa.sa_handler = sigint_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 0; + + TEST_ASSERT(sigaction(SIGINT, &sa, NULL) == 0, "sigaction failed"); + + sigint_count = 0; + + // Block SIGINT. + sigset_t block_set, old_set; + sigemptyset(&block_set); + sigaddset(&block_set, SIGINT); + TEST_ASSERT(sigprocmask(SIG_BLOCK, &block_set, &old_set) == 0, "sigprocmask(BLOCK) failed"); + + // Send multiple SIGINTs while blocked. + for (int i = 0; i < 5; i++) { + raise(SIGINT); + } + TEST_ASSERT(sigint_count == 0, "SIGINT should not be delivered while blocked"); + + // Unblock -- only one delivery expected (standard signal coalescing). + TEST_ASSERT(sigprocmask(SIG_SETMASK, &old_set, NULL) == 0, "sigprocmask(RESTORE) failed"); + TEST_ASSERT(sigint_count == 1, + "only one SIGINT should be delivered (standard signal coalescing)"); + + // Restore default handler. + sa.sa_handler = SIG_DFL; + sigaction(SIGINT, &sa, NULL); + + printf("sigint_coalesce: PASS\n"); + return 0; +} + +int main(void) { + printf("Starting SIGINT tests...\n"); + + if (test_sigint_ignore() != 0) return 1; + if (test_sigint_raise() != 0) return 1; + if (test_sigint_kill() != 0) return 1; + if (test_sigint_block_unblock() != 0) return 1; + if (test_sigint_coalesce() != 0) return 1; + + printf("All SIGINT tests passed!\n"); + return 0; +} diff --git a/litebox_shim_linux/src/syscalls/process.rs b/litebox_shim_linux/src/syscalls/process.rs index 05e015244..aea95fd6a 100644 --- a/litebox_shim_linux/src/syscalls/process.rs +++ b/litebox_shim_linux/src/syscalls/process.rs @@ -1508,6 +1508,8 @@ impl Task { #[cfg(test)] mod tests { + extern crate std; + #[cfg(target_arch = "x86_64")] #[test] fn test_arch_prctl() { @@ -1606,4 +1608,79 @@ mod tests { "prctl get_name returned unexpected comm for too long name" ); } + + /// Installing a custom handler for SIGINT: a background OS thread sends + /// a real SIGINT via `libc::kill`, which should interrupt a blocking sleep + /// with `EINTR`. + /// Target Linux only because it use tgkill syscall to send signal to specific thread. + #[cfg(all(target_os = "linux", debug_assertions))] + #[test] + fn test_sigint_with_custom_handler() { + use litebox_common_linux::signal::{SaFlags, SigAction, SigSet, Signal}; + use litebox_common_linux::{ClockId, TimerFlags, Timespec}; + + let callback_addr = 0x1000usize; // dummy non-null address for the callback + let task = crate::syscalls::tests::init_platform(None); + ::run_test_thread(|| { + let act = SigAction { + sigaction: callback_addr, + flags: SaFlags::RESTORER, + #[cfg(target_pointer_width = "64")] + __pad: 0, + restorer: 0, + mask: SigSet::empty(), + }; + let act_ptr = crate::ConstPtr::from_ptr(&raw const act); + task.sys_rt_sigaction( + Signal::SIGINT, + Some(act_ptr), + None, + core::mem::size_of::(), + ) + .expect("rt_sigaction failed"); + + // Spawn a plain OS thread that sends a real SIGINT to this + // specific thread after a short delay, giving it time to enter nanosleep. + let pid = unsafe { libc::getpid() }; + let tid = unsafe { libc::syscall(libc::SYS_gettid) }; + let handle = std::thread::spawn(move || { + std::thread::sleep(std::time::Duration::from_millis(200)); + // Safety: sending a signal to a thread in our own process is always valid. + let ret = unsafe { libc::syscall(libc::SYS_tgkill, pid, tid, libc::SIGINT) }; + assert_eq!(ret, 0, "tgkill failed"); + }); + + let mut request = Timespec { + tv_sec: 10, + tv_nsec: 0, + }; + let result = task.sys_clock_nanosleep( + ClockId::Monotonic, + TimerFlags::empty(), + litebox_common_linux::TimeParam::Timespec64(crate::MutPtr::from_ptr( + &raw mut request, + )), + litebox_common_linux::TimeParam::None, + ); + assert_eq!( + result, + Err(litebox_common_linux::errno::Errno::EINTR), + "nanosleep should be interrupted by SIGINT from background thread" + ); + + // `process_signals` is called when about to switch back to userspace, so simulate that here. + let mut stack = [0u8; 4096]; + #[cfg(target_arch = "x86_64")] + let mut regs = litebox_common_linux::PtRegs { rsp: stack.as_mut_ptr() as usize + stack.len(), ..Default::default() }; + #[cfg(target_arch = "x86")] + let mut regs = litebox_common_linux::PtRegs { esp: stack.as_mut_ptr() as usize + stack.len(), ..Default::default() }; + task.process_signals(&mut regs); + assert_eq!( + regs.get_ip(), callback_addr, + "after processing signals, execution should be redirected to the custom handler" + ); + + handle.join().expect("background thread panicked"); + }); + } } diff --git a/litebox_shim_linux/src/syscalls/signal/mod.rs b/litebox_shim_linux/src/syscalls/signal/mod.rs index 9adeabaf1..49b7d9e34 100644 --- a/litebox_shim_linux/src/syscalls/signal/mod.rs +++ b/litebox_shim_linux/src/syscalls/signal/mod.rs @@ -36,6 +36,8 @@ use litebox_platform_multiplex::Platform; pub(crate) struct SignalState { /// Pending thread signals. pending: RefCell, + /// Pending process signals (shared across all threads). + shared_pending: Arc>, /// Currently blocked signals. blocked: Cell, /// Signal handlers. @@ -50,6 +52,7 @@ impl SignalState { pub fn new_process() -> Self { Self { pending: RefCell::new(PendingSignals::new()), + shared_pending: Arc::new(Mutex::new(PendingSignals::new())), blocked: Cell::new(SigSet::empty()), handlers: RefCell::new(Arc::new(SignalHandlers::new())), altstack: Cell::new(SigAltStack { @@ -72,6 +75,8 @@ impl SignalState { Self { // Reset pending pending: RefCell::new(PendingSignals::new()), + // Share process-wide pending signals + shared_pending: self.shared_pending.clone(), // Preserve blocked blocked: Cell::new(self.blocked.get()), // Share handlers across tasks @@ -281,7 +286,7 @@ fn siginfo_exception(signal: Signal, fault_address: usize) -> Siginfo { /// Creates a `Siginfo` for a signal sent by a user process via `kill()`, /// `tkill()`, or `tgkill()`. -fn siginfo_kill(signal: Signal) -> Siginfo { +pub(crate) fn siginfo_kill(signal: Signal) -> Siginfo { Siginfo { signo: signal.as_i32(), errno: 0, @@ -527,24 +532,38 @@ impl Task { /// Returns whether there are any pending signals that can be delivered. pub(crate) fn has_pending_signals(&self) -> bool { - let pending = self.signals.pending.borrow().pending & !self.signals.blocked.get(); - !pending.is_empty() + let blocked = self.signals.blocked.get(); + let thread_pending = self.signals.pending.borrow().pending & !blocked; + if !thread_pending.is_empty() { + return true; + } + let shared_pending = self.signals.shared_pending.lock().pending & !blocked; + !shared_pending.is_empty() } /// Deliver any pending signals. pub(crate) fn process_signals(&self, ctx: &mut PtRegs) { loop { - let mut pending = self.signals.pending.borrow_mut(); - let Some(signal) = pending.next(self.signals.blocked.get()) else { - break; + let blocked = self.signals.blocked.get(); + let (signal, siginfo) = { + let mut pending = self.signals.pending.borrow_mut(); + if let Some(signal) = pending.next(blocked) { + (signal, pending.remove(signal)) + } else { + // Then try shared pending. + let mut shared = self.signals.shared_pending.lock(); + if let Some(signal) = shared.next(blocked) { + (signal, shared.remove(signal)) + } else { + break; + } + } }; if self.is_exiting() { // Don't deliver any more signals if exiting. return; } - let siginfo: Siginfo = pending.remove(signal); - drop(pending); let action = self.signals.handlers.borrow().inner.lock()[signal].action; #[expect(clippy::match_same_arms)] match action.sigaction { @@ -586,14 +605,57 @@ impl Task { } } + pub(crate) fn take_pending_signals(&self, sig: litebox::shim::Signal) { + let signal = match sig { + litebox::shim::Signal::SIGALRM => litebox_common_linux::signal::Signal::SIGALRM, + litebox::shim::Signal::SIGINT => litebox_common_linux::signal::Signal::SIGINT, + _ => unimplemented!(), + }; + self.send_shared_signal(signal, siginfo_kill(signal)); + } + + /// Returns whether the given signal is currently being ignored. + fn is_signal_ignored(&self, signal: Signal) -> bool { + // SIGKILL and SIGSTOP can never be ignored. + if signal == Signal::SIGKILL || signal == Signal::SIGSTOP { + return false; + } + // Blocked signals are never ignored, since the signal handler may + // change by the time it is unblocked. + if self.signals.blocked.get().contains(signal) { + return false; + } + let handlers = self.signals.handlers.borrow(); + let inner = handlers.inner.lock(); + match inner[signal].action.sigaction { + SIG_IGN => true, + SIG_DFL => matches!(signal.default_disposition(), SignalDisposition::Ignore), + _ => false, + } + } + /// Only supports sending signals to self for now. - fn send_signal(&self, signal: Signal, siginfo: Siginfo) { + pub(crate) fn send_signal(&self, signal: Signal, siginfo: Siginfo) { + if self.is_signal_ignored(signal) { + return; + } self.signals .pending .borrow_mut() .push(&self.process().limits, signal, siginfo); } + /// Sends a process-directed signal (stored in shared_pending). + pub(crate) fn send_shared_signal(&self, signal: Signal, siginfo: Siginfo) { + if self.is_signal_ignored(signal) { + return; + } + self.signals + .shared_pending + .lock() + .push(&self.process().limits, signal, siginfo); + } + /// Forces a signal to be delivered on next call to `check_for_signals`. fn force_signal(&self, signal: Signal, force_exit: bool) { let siginfo = Siginfo { diff --git a/litebox_shim_linux/src/wait.rs b/litebox_shim_linux/src/wait.rs index 43a35e742..9c0a1a375 100644 --- a/litebox_shim_linux/src/wait.rs +++ b/litebox_shim_linux/src/wait.rs @@ -37,6 +37,10 @@ impl Task { #[must_use] pub(crate) fn prepare_to_run_guest(&self, ctx: &mut litebox_common_linux::PtRegs) -> bool { self.wait_state.0.prepare_to_run_guest(|| { + use litebox::platform::SignalProvider as _; + self.global.platform.take_pending_signals(|signal| { + self.take_pending_signals(signal); + }); self.process_signals(ctx); !self.is_exiting() }) @@ -45,6 +49,10 @@ impl Task { impl litebox::event::wait::CheckForInterrupt for Task { fn check_for_interrupt(&self) -> bool { + use litebox::platform::SignalProvider as _; + self.global.platform.take_pending_signals(|sig| { + self.take_pending_signals(sig); + }); self.is_exiting() || self.has_pending_signals() } }