diff --git a/doeff/effects/spawn.py b/doeff/effects/spawn.py index 943d572d..32b01923 100644 --- a/doeff/effects/spawn.py +++ b/doeff/effects/spawn.py @@ -15,7 +15,10 @@ import doeff_vm -from doeff.handlers.spawn_handler import spawn_intercept_handler +from doeff.handlers.spawn_handler import ( + spawn_intercept_handler, + sync_spawn_intercept_handler, +) from ._program_types import ProgramLike from ._validators import ensure_dict_str_any, ensure_program_like @@ -267,5 +270,6 @@ def _validate_priority(priority: int) -> int: "promise_id_of", "spawn", "spawn_intercept_handler", + "sync_spawn_intercept_handler", "task_id_of", ] diff --git a/doeff/handlers/await_handlers.py b/doeff/handlers/await_handlers.py index bf278f67..2c86e519 100644 --- a/doeff/handlers/await_handlers.py +++ b/doeff/handlers/await_handlers.py @@ -10,10 +10,12 @@ from doeff.do import do from doeff.effects.base import Effect +from doeff.effects.external_promise import ExternalPromise from doeff.effects.external_promise import CreateExternalPromise -from doeff.effects.wait import Wait +from doeff.effects.wait import wait PythonAsyncioAwaitEffect = doeff_vm.PythonAsyncioAwaitEffect +sync_await_handler = doeff_vm.sync_await_handler _loop_lock = threading.Lock() _loop_thread: threading.Thread | None = None @@ -125,16 +127,33 @@ def _on_done(completed: Any) -> None: future.add_done_callback(_on_done) -@do -def sync_await_handler(effect: Effect, k: Any): - """Handle Await effects via background-loop bridge for sync execution.""" - if isinstance(effect, PythonAsyncioAwaitEffect): - promise = yield CreateExternalPromise() - _submit_awaitable(effect.awaitable, promise) - value = yield Wait(promise.future) - return (yield doeff_vm.Resume(k, value)) +def _external_promise_from_handle(handle: Any) -> ExternalPromise[Any]: + if isinstance(handle, ExternalPromise): + return handle - yield doeff_vm.Pass() + if not isinstance(handle, dict) or handle.get("type") != "ExternalPromise": + raise TypeError( + "Expected ExternalPromise handle dict or ExternalPromise instance, " + f"got {type(handle).__name__}" + ) + + promise_id = handle.get("promise_id") + if not isinstance(promise_id, int): + raise TypeError("ExternalPromise handle missing integer promise_id") + + completion_queue = handle.get("completion_queue") + if completion_queue is None: + raise TypeError("ExternalPromise handle missing completion_queue") + + return ExternalPromise( + _handle=handle, + _completion_queue=completion_queue, + _id=promise_id, + ) + + +def _submit_awaitable_handle(awaitable: Awaitable[Any], handle: Any) -> None: + _submit_awaitable(awaitable, _external_promise_from_handle(handle)) @do @@ -154,7 +173,7 @@ async def _kickoff() -> None: asyncio.get_running_loop().create_task(_run_and_complete()) _ = yield doeff_vm.PythonAsyncSyntaxEscape(action=_kickoff) - value = yield Wait(promise.future) + value = yield wait(promise.future) return (yield doeff_vm.Resume(k, value)) yield doeff_vm.Pass() diff --git a/doeff/handlers/spawn_handler.py b/doeff/handlers/spawn_handler.py index 0c71e279..559dfeae 100644 --- a/doeff/handlers/spawn_handler.py +++ b/doeff/handlers/spawn_handler.py @@ -19,4 +19,14 @@ def spawn_intercept_handler(effect: Effect, k: Any): yield doeff_vm.Pass() -__all__ = ["spawn_intercept_handler"] +@do +def sync_spawn_intercept_handler(effect: Effect, k: Any): + from doeff.effects.spawn import SpawnEffect, coerce_task_handle + + if isinstance(effect, SpawnEffect): + raw = yield doeff_vm.Delegate() + return (yield doeff_vm.Transfer(k, coerce_task_handle(raw))) + yield doeff_vm.Pass() + + +__all__ = ["spawn_intercept_handler", "sync_spawn_intercept_handler"] diff --git a/doeff/rust_vm.py b/doeff/rust_vm.py index 67d134a7..18c5a287 100644 --- a/doeff/rust_vm.py +++ b/doeff/rust_vm.py @@ -493,11 +493,11 @@ def default_handlers() -> list[Any]: """ vm = _vm() from doeff.handlers.await_handlers import sync_await_handler - from doeff.handlers.spawn_handler import spawn_intercept_handler + from doeff.handlers.spawn_handler import sync_spawn_intercept_handler return [ *_core_handler_sentinels(vm), - spawn_intercept_handler, + sync_spawn_intercept_handler, sync_await_handler, ] @@ -646,6 +646,7 @@ def WithIntercept( "result_safe", "lazy_ask", "await_handler", + "sync_await_handler", } @@ -693,6 +694,7 @@ def __getattr__(name: str) -> Any: "WithIntercept", "async_run", "await_handler", + "sync_await_handler", "default_async_handlers", "default_handlers", "lazy_ask", diff --git a/packages/doeff-core-effects/src/handlers/mod.rs b/packages/doeff-core-effects/src/handlers/mod.rs index 3e15ddce..b826005a 100644 --- a/packages/doeff-core-effects/src/handlers/mod.rs +++ b/packages/doeff-core-effects/src/handlers/mod.rs @@ -8,14 +8,14 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyModule}; +use pyo3::types::PyDict; use crate::continuation::Continuation; use crate::do_ctrl::DoCtrl; use crate::effect::{ dispatch_from_shared, dispatch_into_python, dispatch_ref_as_python, DispatchEffect, - PyAcquireSemaphore, PyAsk, PyCreateSemaphore, PyGet, PyLocal, PyModify, PyPut, - PyPythonAsyncioAwaitEffect, PyReleaseSemaphore, PyResultSafeEffect, PyTell, + PyAcquireSemaphore, PyAsk, PyCreateExternalPromise, PyCreateSemaphore, PyGet, PyLocal, + PyModify, PyPut, PyPythonAsyncioAwaitEffect, PyReleaseSemaphore, PyResultSafeEffect, PyTell, }; use crate::error::VMError; use crate::ir_stream::{IRStream, IRStreamStep, StreamLocation}; @@ -299,22 +299,72 @@ fn parse_result_safe_python_effect(effect: &PyShared) -> Result }) } -fn get_sync_await_runner() -> Result { +fn get_sync_await_submitter() -> Result { Python::attach(|py| { - let module = PyModule::from_code( - py, - c"import asyncio\n\ndef _run_awaitable_sync(awaitable):\n return asyncio.run(awaitable)\n", - c"_doeff_await_bridge", - c"_doeff_await_bridge", - ) - .map_err(|e| e.to_string())?; + let module = py + .import("doeff.handlers.await_handlers") + .map_err(|e| e.to_string())?; let runner = module - .getattr("_run_awaitable_sync") + .getattr("_submit_awaitable_handle") .map_err(|e| e.to_string())?; Ok(PyShared::new(runner.unbind())) }) } +fn create_external_promise_effect() -> Result { + Python::attach(|py| { + let effect = py + .get_type::() + .call0() + .map_err(|e| pyerr_to_exception(py, e))?; + Ok(dispatch_from_shared(PyShared::new(effect.unbind()))) + }) +} + +fn external_promise_future(value: &Value) -> Result, PyException> { + let Value::ExternalPromise(handle) = value else { + return Err(PyException::type_error(format!( + "CreateExternalPromise returned non-external-promise value: {:?}", + value + ))); + }; + + Python::attach(|py| { + let handle_obj = value + .to_pyobject(py) + .map_err(|e| pyerr_to_exception(py, e))?; + let future_type = py + .import("doeff.effects.spawn") + .and_then(|module| module.getattr("Future")) + .map_err(|e| pyerr_to_exception(py, e))?; + let kwargs = PyDict::new(py); + kwargs + .set_item("_handle", handle_obj) + .map_err(|e| pyerr_to_exception(py, e))?; + if let Some(queue) = &handle.completion_queue { + kwargs + .set_item("_completion_queue", queue.bind(py)) + .map_err(|e| pyerr_to_exception(py, e))?; + } + future_type + .call((), Some(&kwargs)) + .map(|future| future.unbind().into_any()) + .map_err(|e| pyerr_to_exception(py, e)) + }) +} + +fn wait_on_external_promise_effect(value: &Value) -> Result { + let future = external_promise_future(value)?; + Python::attach(|py| { + let effect = py + .import("doeff.effects.wait") + .and_then(|module| module.getattr("wait")) + .and_then(|wait_fn| wait_fn.call1((future.bind(py),))) + .map_err(|e| pyerr_to_exception(py, e))?; + Ok(dispatch_from_shared(PyShared::new(effect.unbind()))) + }) +} + // --------------------------------------------------------------------------- // AwaitHandlerFactory + AwaitHandlerProgram // --------------------------------------------------------------------------- @@ -345,23 +395,56 @@ impl IRStreamFactory for AwaitHandlerFactory { } } +#[derive(Debug)] +enum AwaitPhase { + Idle, + AwaitExternalPromise { + continuation: Continuation, + awaitable: PyShared, + }, + AwaitSubmission { + continuation: Continuation, + promise: Value, + }, + AwaitResult { + continuation: Continuation, + }, +} + #[derive(Debug)] struct AwaitHandlerProgram { - pending_k: Option, + phase: AwaitPhase, } impl AwaitHandlerProgram { fn new() -> Self { - AwaitHandlerProgram { pending_k: None } + AwaitHandlerProgram { + phase: AwaitPhase::Idle, + } } fn current_phase_name(&self) -> &'static str { - if self.pending_k.is_some() { - "AwaitBridgeResult" - } else { - "Idle" + match self.phase { + AwaitPhase::Idle => "Idle", + AwaitPhase::AwaitExternalPromise { .. } => "AwaitExternalPromise", + AwaitPhase::AwaitSubmission { .. } => "AwaitSubmission", + AwaitPhase::AwaitResult { .. } => "AwaitResult", } } + + fn yield_perform(effect: Result) -> IRStreamStep { + match effect { + Ok(effect) => IRStreamStep::Yield(DoCtrl::Perform { effect }), + Err(exc) => IRStreamStep::Throw(exc), + } + } + + fn transfer_throw(continuation: Continuation, exception: PyException) -> IRStreamStep { + IRStreamStep::Yield(DoCtrl::TransferThrow { + continuation, + exception, + }) + } } impl IRStreamProgram for AwaitHandlerProgram { @@ -376,20 +459,11 @@ impl IRStreamProgram for AwaitHandlerProgram { if let Some(obj) = dispatch_into_python(effect.clone()) { return match parse_await_python_effect(&obj) { Ok(Some(awaitable)) => { - let runner = match get_sync_await_runner() { - Ok(func) => func, - Err(msg) => { - return IRStreamStep::Throw(PyException::type_error(format!( - "failed to initialize await runner: {msg}" - ))); - } + self.phase = AwaitPhase::AwaitExternalPromise { + continuation: k, + awaitable: PyShared::new(awaitable), }; - self.pending_k = Some(k); - IRStreamStep::NeedsPython(PythonCall::CallFunc { - func: runner, - args: vec![Value::Python(PyShared::new(awaitable))], - kwargs: vec![], - }) + Self::yield_perform(create_external_promise_effect()) } Ok(None) => IRStreamStep::Yield(DoCtrl::Pass { effect: dispatch_from_shared(obj), @@ -415,13 +489,45 @@ impl IRStreamProgram for AwaitHandlerProgram { _store: &mut RustStore, _scope: &mut ScopeStore, ) -> IRStreamStep { - if let Some(continuation) = self.pending_k.take() { - return IRStreamStep::Yield(DoCtrl::Resume { + match std::mem::replace(&mut self.phase, AwaitPhase::Idle) { + AwaitPhase::AwaitExternalPromise { + continuation, + awaitable, + } => { + let submitter = match get_sync_await_submitter() { + Ok(func) => func, + Err(msg) => { + return Self::transfer_throw( + continuation, + PyException::type_error(format!( + "failed to initialize await submitter: {msg}" + )), + ) + } + }; + self.phase = AwaitPhase::AwaitSubmission { + continuation, + promise: value.clone(), + }; + IRStreamStep::NeedsPython(PythonCall::CallFunc { + func: submitter, + args: vec![Value::Python(awaitable), value], + kwargs: vec![], + }) + } + AwaitPhase::AwaitSubmission { + continuation, + promise, + } => { + self.phase = AwaitPhase::AwaitResult { continuation }; + Self::yield_perform(wait_on_external_promise_effect(&promise)) + } + AwaitPhase::AwaitResult { continuation } => IRStreamStep::Yield(DoCtrl::Resume { continuation, value, - }); + }), + AwaitPhase::Idle => IRStreamStep::Return(value), } - IRStreamStep::Return(value) } fn throw( @@ -430,13 +536,12 @@ impl IRStreamProgram for AwaitHandlerProgram { _store: &mut RustStore, _scope: &mut ScopeStore, ) -> IRStreamStep { - if let Some(continuation) = self.pending_k.take() { - return IRStreamStep::Yield(DoCtrl::TransferThrow { - continuation, - exception: exc, - }); + match std::mem::replace(&mut self.phase, AwaitPhase::Idle) { + AwaitPhase::AwaitExternalPromise { continuation, .. } + | AwaitPhase::AwaitSubmission { continuation, .. } + | AwaitPhase::AwaitResult { continuation } => Self::transfer_throw(continuation, exc), + AwaitPhase::Idle => IRStreamStep::Throw(exc), } - IRStreamStep::Throw(exc) } } @@ -1361,7 +1466,10 @@ impl IRStreamProgram for WriterHandlerProgram { return match parse_writer_python_effect(&obj) { Ok(Some(message)) => { store.tell(message); - IRStreamStep::Yield(DoCtrl::Resume { + // Writer is a pure effect-to-store append. Tail-switching back to the + // user continuation avoids stacking extra sibling handler segments in + // hot logging paths used inside scheduler-driven Spawn/Gather loops. + IRStreamStep::Yield(DoCtrl::Transfer { continuation: k, value: Value::Unit, }) @@ -1782,22 +1890,27 @@ mod tests { let mut program = AwaitHandlerProgram::new(); let continuation = make_test_continuation(); let continuation_id = continuation.cont_id; - program.pending_k = Some(continuation); + program.phase = AwaitPhase::AwaitResult { continuation }; let location = IRStream::debug_location(&program).expect("await debug location"); assert_eq!(location.function_name, "AwaitHandler"); - assert_eq!(location.phase.as_deref(), Some("AwaitBridgeResult")); + assert_eq!(location.phase.as_deref(), Some("AwaitResult")); - let step = IRStream::resume(&mut program, Value::Int(12), &mut store, &mut scope); + let step = IRStream::resume( + &mut program, + Value::List(vec![Value::Int(12)]), + &mut store, + &mut scope, + ); match step { - IRStreamStep::Yield(DoCtrl::Resume { + IRStreamStep::Yield(DoCtrl::Transfer { continuation, value, }) => { assert_eq!(continuation.cont_id, continuation_id); assert_eq!(value.as_int(), Some(12)); } - _ => panic!("expected IRStream Yield(Resume)"), + _ => panic!("expected IRStream Yield(Transfer)"), } let location = IRStream::debug_location(&program).expect("await debug location"); diff --git a/packages/doeff-core-effects/src/scheduler/mod.rs b/packages/doeff-core-effects/src/scheduler/mod.rs index 03f670c3..13bd2fe7 100644 --- a/packages/doeff-core-effects/src/scheduler/mod.rs +++ b/packages/doeff-core-effects/src/scheduler/mod.rs @@ -5,7 +5,7 @@ use std::cmp::Ordering as CmpOrdering; use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, OnceLock, Weak}; use pyo3::prelude::*; @@ -64,6 +64,9 @@ pub enum SchedulerEffect { Gather { items: Vec, }, + Wait { + item: Waitable, + }, Race { items: Vec, }, @@ -174,7 +177,7 @@ struct SemaphoreRuntimeState { holders: HashMap, u64>, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] enum WaitMode { All, Any, @@ -199,6 +202,8 @@ struct WaitRequest { continuation: Continuation, items: Vec, mode: WaitMode, + // Shared across per-item registrations so each completion updates the same waiter state. + remaining: Arc, waiting_task: Option, waiting_store: RustStore, } @@ -215,6 +220,28 @@ impl WaitRequest { }, } } + + fn note_completion(&self) -> bool { + loop { + let remaining = self.remaining.load(Ordering::Relaxed); + if remaining == 0 { + return true; + } + + if self + .remaining + .compare_exchange( + remaining, + remaining - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) + .is_ok() + { + return remaining == 1; + } + } + } } #[derive(Clone, Debug)] @@ -476,6 +503,9 @@ pub struct SchedulerState { pub promises: HashMap, semaphores: HashMap, waiters: HashMap>, + waitables_by_owner: HashMap>, + active_wait_owners: HashSet, + external_waiter_count: usize, pending_gather_fail_fast: HashMap, external_completion_queue: Option, cancel_requested: HashSet, @@ -706,6 +736,21 @@ fn parse_scheduler_python_effect( return Ok(Some(SchedulerEffect::Race { items: waitables })); } + let wait_effect_type = py + .import("doeff.effects.wait") + .and_then(|module| module.getattr("WaitEffect")) + .map_err(|e| e.to_string())?; + if obj + .is_instance(&wait_effect_type) + .map_err(|e| e.to_string())? + { + let future = obj.getattr("future").map_err(|e| e.to_string())?; + let Some(item) = extract_waitable(&future) else { + return Err("WaitEffect.future must be a waitable handle".to_string()); + }; + return Ok(Some(SchedulerEffect::Wait { item })); + } + if obj.extract::>().is_ok() { return Ok(Some(SchedulerEffect::CreatePromise)); } @@ -1013,6 +1058,9 @@ impl SchedulerState { promises: HashMap::new(), semaphores: HashMap::new(), waiters: HashMap::new(), + waitables_by_owner: HashMap::new(), + active_wait_owners: HashSet::new(), + external_waiter_count: 0, pending_gather_fail_fast: HashMap::new(), external_completion_queue: None, cancel_requested: HashSet::new(), @@ -1080,9 +1128,7 @@ impl SchedulerState { } fn has_external_waiters(&self) -> bool { - self.waiters.iter().any(|(item, waiters)| { - matches!(item, Waitable::ExternalPromise(_)) && !waiters.is_empty() - }) + self.external_waiter_count > 0 } fn parse_external_completion_item( @@ -1683,6 +1729,11 @@ impl SchedulerState { let Some(waiters_for_item) = self.waiters.remove(&waitable) else { return highest_ready; }; + if matches!(waitable, Waitable::ExternalPromise(_)) { + self.external_waiter_count = self + .external_waiter_count + .saturating_sub(waiters_for_item.len()); + } for waiter in waiters_for_item { let waiter_id = waiter.continuation.cont_id; @@ -1694,10 +1745,7 @@ impl SchedulerState { continue; } - let ready = match waiter.mode { - WaitMode::All => self.all_done(&waiter.items), - WaitMode::Any => self.any_done(&waiter.items), - }; + let ready = waiter.note_completion(); if ready { if let Some(woken) = self.stage_ready_waiter(waiter) { @@ -1762,6 +1810,8 @@ impl SchedulerState { None => WaitOwner::Root { cont_id }, }; self.pending_gather_fail_fast.remove(&owner); + self.active_wait_owners.remove(&owner); + let waitables = self.waitables_by_owner.remove(&owner).unwrap_or_default(); if let Some(ready_resume) = self.ready_root_resumes.get(&cont_id) { if ready_resume.waiting_task == waiting_task { self.ready_root_resumes.remove(&cont_id); @@ -1812,12 +1862,23 @@ impl SchedulerState { } } - self.waiters.retain(|_key, pending| { - pending.retain(|waiter| { - !(waiter.waiting_task == waiting_task && waiter.continuation.cont_id == cont_id) - }); - !pending.is_empty() - }); + for waitable in waitables { + let mut should_remove_key = false; + if let Some(pending) = self.waiters.get_mut(&waitable) { + let original_len = pending.len(); + pending.retain(|waiter| { + !(waiter.waiting_task == waiting_task && waiter.continuation.cont_id == cont_id) + }); + let removed = original_len.saturating_sub(pending.len()); + if matches!(waitable, Waitable::ExternalPromise(_)) { + self.external_waiter_count = self.external_waiter_count.saturating_sub(removed); + } + should_remove_key = pending.is_empty(); + } + if should_remove_key { + self.waiters.remove(&waitable); + } + } } pub fn task_cont(&self, task_id: TaskId) -> Option { @@ -1828,20 +1889,8 @@ impl SchedulerState { } fn owner_is_waiting(&self, owner: WaitOwner) -> bool { - if self.pending_gather_fail_fast.contains_key(&owner) { - return true; - } - match owner { - WaitOwner::Task { task_id, cont_id } => matches!( - self.tasks.get(&task_id), - Some(TaskState::Pending { cont, .. }) if cont.cont_id == cont_id - ), - WaitOwner::Root { cont_id } => self.waiters.values().any(|requests| { - requests.iter().any(|request| { - request.waiting_task.is_none() && request.continuation.cont_id == cont_id - }) - }), - } + self.pending_gather_fail_fast.contains_key(&owner) + || self.active_wait_owners.contains(&owner) } fn gather_wait_request_for_failed_task(&self, running_task: TaskId) -> Option { @@ -2006,36 +2055,59 @@ impl SchedulerState { None } - pub fn wait_on_all(&mut self, items: &[Waitable], k: Continuation, store: &RustStore) { + fn register_waiter( + &mut self, + items: &[Waitable], + k: Continuation, + store: &RustStore, + mode: WaitMode, + ) { + let pending_items: Vec<_> = items + .iter() + .copied() + .filter(|item| !self.is_done(*item)) + .collect(); + if pending_items.is_empty() { + return; + } + + let owner = match self.current_task { + Some(task_id) => WaitOwner::Task { + task_id, + cont_id: k.cont_id, + }, + None => WaitOwner::Root { cont_id: k.cont_id }, + }; + self.active_wait_owners.insert(owner); + self.waitables_by_owner.insert(owner, pending_items.clone()); + + let remaining = match mode { + WaitMode::All => pending_items.len(), + WaitMode::Any => 1, + }; let waiter = WaitRequest { continuation: k, items: items.to_vec(), - mode: WaitMode::All, + mode, + remaining: Arc::new(AtomicUsize::new(remaining)), waiting_task: self.current_task, waiting_store: store.clone(), }; - for item in items { - if !self.is_done(*item) { - self.waiters.entry(*item).or_default().push(waiter.clone()); + for item in pending_items { + if matches!(item, Waitable::ExternalPromise(_)) { + self.external_waiter_count += 1; } + self.waiters.entry(item).or_default().push(waiter.clone()); } } - pub fn wait_on_any(&mut self, items: &[Waitable], k: Continuation, store: &RustStore) { - let waiter = WaitRequest { - continuation: k, - items: items.to_vec(), - mode: WaitMode::Any, - waiting_task: self.current_task, - waiting_store: store.clone(), - }; + pub fn wait_on_all(&mut self, items: &[Waitable], k: Continuation, store: &RustStore) { + self.register_waiter(items, k, store, WaitMode::All); + } - for item in items { - if !self.is_done(*item) { - self.waiters.entry(*item).or_default().push(waiter.clone()); - } - } + pub fn wait_on_any(&mut self, items: &[Waitable], k: Continuation, store: &RustStore) { + self.register_waiter(items, k, store, WaitMode::Any); } fn is_done(&self, item: Waitable) -> bool { @@ -2047,14 +2119,6 @@ impl SchedulerState { } } - fn all_done(&self, items: &[Waitable]) -> bool { - items.iter().all(|item| self.is_done(*item)) - } - - fn any_done(&self, items: &[Waitable]) -> bool { - items.iter().any(|item| self.is_done(*item)) - } - fn collect_all_result(&mut self, items: &[Waitable]) -> Option> { self.execution_context_task_override = None; let mut results = Vec::with_capacity(items.len()); @@ -2288,6 +2352,78 @@ impl SchedulerState { TransferNextOutcome::None => resume_to_continuation(k, Value::Unit), } } + + fn transfer_ready_owner( + &mut self, + owner: WaitOwner, + store: &mut RustStore, + ) -> Result, PyException> { + match owner { + WaitOwner::Task { task_id, cont_id } => { + if !self.ready_task_ids.contains(&task_id) { + return Ok(None); + } + + let (task_k, resume_outcome, merge_items) = match self.tasks.get_mut(&task_id) { + Some(TaskState::Pending { + cont, + resume_outcome, + pending_log_merge_items, + .. + }) if cont.cont_id == cont_id => { + let continuation = cont.clone(); + let outcome = resume_outcome.take(); + let pending_merge = pending_log_merge_items.take(); + (continuation, outcome, pending_merge) + } + _ => return Ok(None), + }; + + self.ready_task_ids.remove(&task_id); + if let Some(old_id) = self.current_task { + if old_id != task_id { + self.save_task_store(old_id, store)?; + } + } + self.load_task_store(task_id, store)?; + self.current_task = Some(task_id); + if let Some(items) = merge_items.as_ref() { + self.merge_gather_logs(items, store); + } + let step = match resume_outcome { + Some(Err(error)) => throw_to_continuation(task_k, error), + Some(Ok(value)) => resume_to_continuation(task_k, value), + None => transfer_to_continuation(task_k, Value::Unit), + }; + Ok(Some(step)) + } + WaitOwner::Root { cont_id } => { + let Some(ready_root) = self.ready_root_resumes.remove(&cont_id) else { + return Ok(None); + }; + self.clear_waiters_for_owner( + ready_root.waiting_task, + ready_root.continuation.cont_id, + ); + if let Some(waiting_task) = ready_root.waiting_task { + self.load_task_store(waiting_task, store)?; + self.current_task = Some(waiting_task); + } else { + *store = ready_root.waiting_store; + self.current_task = None; + } + if let Some(items) = &ready_root.merge_items { + self.merge_gather_logs(items, store); + } + + let step = match ready_root.outcome { + Ok(value) => resume_to_continuation(ready_root.continuation, value), + Err(error) => throw_to_continuation(ready_root.continuation, error), + }; + Ok(Some(step)) + } + } + } } impl Drop for SchedulerState { @@ -2397,6 +2533,22 @@ impl SchedulerProgram { make_async_external_wait_step().unwrap_or_else(IRStreamStep::Throw) } + fn wait_owner(&self, waiting_task: Option, cont_id: ContId) -> WaitOwner { + match (&self.phase, waiting_task) { + ( + SchedulerPhase::Driving { + owner, + running_task, + }, + Some(task_id), + ) if *running_task == task_id => *owner, + _ => match waiting_task { + Some(task_id) => WaitOwner::Task { task_id, cont_id }, + None => WaitOwner::Root { cont_id }, + }, + } + } + fn continue_simple_transfer( &mut self, k_user: Continuation, @@ -2450,6 +2602,16 @@ impl SchedulerProgram { store: &mut RustStore, ) -> IRStreamStep { let mut state = self.state.lock().expect("Scheduler lock poisoned"); + if !state.owner_is_waiting(owner) { + match state.transfer_ready_owner(owner, store) { + Ok(Some(step)) => { + self.phase = SchedulerPhase::Idle; + return step; + } + Ok(None) => {} + Err(error) => return IRStreamStep::Throw(error), + } + } match state.transfer_next(store) { TransferNextOutcome::Step(step) => { let next_running_task = state.current_task; @@ -2534,6 +2696,16 @@ impl SchedulerProgram { store: &mut RustStore, ) -> IRStreamStep { let mut state = self.state.lock().expect("Scheduler lock poisoned"); + if !state.owner_is_waiting(owner) { + match state.transfer_ready_owner(owner, store) { + Ok(Some(step)) => { + self.phase = SchedulerPhase::Idle; + return step; + } + Ok(None) => {} + Err(error) => return IRStreamStep::Throw(error), + } + } match state.transfer_next(store) { TransferNextOutcome::Step(step) => { let next_running_task = state.current_task; @@ -2592,6 +2764,38 @@ impl SchedulerProgram { } state.wait_on_all(&items, k_user.clone(), store); + let owner = self.wait_owner(waiting_task, k_user.cont_id); + drop(state); + self.continue_wait_transfer( + owner, + "deadlock: Gather blocked with no runnable tasks".to_string(), + store, + ) + } + + fn handle_wait( + &mut self, + k_user: Continuation, + item: Waitable, + store: &mut RustStore, + ) -> IRStreamStep { + let mut state = self.state.lock().expect("Scheduler lock poisoned"); + let items = [item]; + let waiting_task = state.current_task; + if let Some(result) = state.collect_any_result(&items) { + state.clear_waiters_for_owner(waiting_task, k_user.cont_id); + return match result { + Ok(value) => resume_to_continuation(k_user, value), + Err(error) => throw_to_continuation(k_user, error), + }; + } + if let Some(waiting_task) = waiting_task { + if let Err(error) = state.suspend_task_for_wait(waiting_task, k_user.clone()) { + return IRStreamStep::Throw(error); + } + } + + state.wait_on_any(&items, k_user.clone(), store); let owner = match waiting_task { Some(task_id) => WaitOwner::Task { task_id, @@ -2604,7 +2808,7 @@ impl SchedulerProgram { drop(state); self.continue_wait_transfer( owner, - "deadlock: Gather blocked with no runnable tasks".to_string(), + "deadlock: Wait blocked with no runnable tasks".to_string(), store, ) } @@ -2631,15 +2835,7 @@ impl SchedulerProgram { } state.wait_on_any(&items, k_user.clone(), store); - let owner = match waiting_task { - Some(task_id) => WaitOwner::Task { - task_id, - cont_id: k_user.cont_id, - }, - None => WaitOwner::Root { - cont_id: k_user.cont_id, - }, - }; + let owner = self.wait_owner(waiting_task, k_user.cont_id); drop(state); self.continue_wait_transfer( owner, @@ -2814,8 +3010,7 @@ impl SchedulerProgram { if let Err(store_error) = state.save_task_store(running_task, store) { return IRStreamStep::Throw(store_error); } - if let Err(done_error) = - state.mark_task_done(running_task, Err(error.clone())) + if let Err(done_error) = state.mark_task_done(running_task, Err(error.clone())) { return IRStreamStep::Throw(done_error); } @@ -2827,14 +3022,6 @@ impl SchedulerProgram { } } - if task_already_done && !owner_still_waiting { - self.phase = SchedulerPhase::Idle; - return match outcome_for_fallback { - Ok(value) => IRStreamStep::Return(value), - Err(error) => IRStreamStep::Throw(error), - }; - } - if !task_already_done { if let Err(error) = state.save_task_store(running_task, store) { return IRStreamStep::Throw(error); @@ -2846,6 +3033,17 @@ impl SchedulerProgram { owner_still_waiting = state.owner_is_waiting(owner); } + if !owner_still_waiting { + match state.transfer_ready_owner(owner, store) { + Ok(Some(step)) => { + self.phase = SchedulerPhase::Idle; + return step; + } + Ok(None) => {} + Err(error) => return IRStreamStep::Throw(error), + } + } + match state.transfer_next(store) { TransferNextOutcome::Step(step) => { let next_running_task = state.current_task; @@ -2959,6 +3157,8 @@ impl IRStreamProgram for SchedulerProgram { SchedulerEffect::Gather { items } => self.handle_gather(k_user, items, store), + SchedulerEffect::Wait { item } => self.handle_wait(k_user, item, store), + SchedulerEffect::Race { items } => self.handle_race(k_user, items, store), SchedulerEffect::CreatePromise => { @@ -3027,24 +3227,7 @@ impl IRStreamProgram for SchedulerProgram { } } state.wait_on_any(&items, k_user.clone(), store); - let owner = match (&self.phase, waiting_task) { - ( - SchedulerPhase::Driving { - owner, - running_task, - }, - Some(task_id), - ) if *running_task == task_id => *owner, - _ => match waiting_task { - Some(task_id) => WaitOwner::Task { - task_id, - cont_id: k_user.cont_id, - }, - None => WaitOwner::Root { - cont_id: k_user.cont_id, - }, - }, - }; + let owner = self.wait_owner(waiting_task, k_user.cont_id); drop(state); self.continue_wait_transfer( owner, @@ -3528,6 +3711,68 @@ mod tests { .into_any() } + fn make_waitable_promise_object(py: Python<'_>, promise_id: PromiseId) -> Py { + let handle = PyDict::new(py); + handle + .set_item("type", "Promise") + .expect("waitable promise should accept type"); + handle + .set_item("promise_id", promise_id.raw()) + .expect("waitable promise should accept promise_id"); + + let types_mod = py + .import("types") + .expect("failed to import Python types module"); + let namespace = types_mod + .getattr("SimpleNamespace") + .expect("types.SimpleNamespace must exist"); + let kwargs = PyDict::new(py); + kwargs + .set_item("_handle", handle) + .expect("namespace should accept _handle"); + namespace + .call((), Some(&kwargs)) + .expect("failed to construct waitable promise object") + .unbind() + .into_any() + } + + fn make_waitable_external_promise_object(py: Python<'_>, promise_id: PromiseId) -> Py { + let handle = PyDict::new(py); + handle + .set_item("type", "ExternalPromise") + .expect("waitable external promise should accept type"); + handle + .set_item("promise_id", promise_id.raw()) + .expect("waitable external promise should accept promise_id"); + + let types_mod = py + .import("types") + .expect("failed to import Python types module"); + let namespace = types_mod + .getattr("SimpleNamespace") + .expect("types.SimpleNamespace must exist"); + let kwargs = PyDict::new(py); + kwargs + .set_item("_handle", handle) + .expect("namespace should accept _handle"); + namespace + .call((), Some(&kwargs)) + .expect("failed to construct waitable external promise object") + .unbind() + .into_any() + } + + fn make_wait_effect(py: Python<'_>, item: Py) -> DispatchEffect { + let effect = py + .import("doeff.effects.wait") + .and_then(|module| module.getattr("wait")) + .and_then(|wait_fn| wait_fn.call1((item.bind(py),))) + .expect("failed to construct WaitEffect") + .unbind(); + dispatch_from_shared(PyShared::new(effect)) + } + fn make_semaphore_object(py: Python<'_>, semaphore_id: u64, state_id: u64) -> Py { Py::new( py, @@ -3587,6 +3832,44 @@ mod tests { dispatch_from_shared(PyShared::new(effect)) } + fn make_gather_effect(py: Python<'_>, items: Vec>) -> DispatchEffect { + let items = PyTuple::new(py, items) + .expect("failed to build GatherEffect.items tuple") + .unbind() + .into_any(); + let partial_results = py.None(); + let effect = Py::new( + py, + PyClassInitializer::from(PyEffectBase { + tag: DoExprTag::Effect as u8, + }) + .add_subclass(PyGather { + items, + _partial_results: partial_results, + }), + ) + .expect("failed to construct GatherEffect") + .into_any(); + dispatch_from_shared(PyShared::new(effect)) + } + + fn make_race_effect(py: Python<'_>, items: Vec>) -> DispatchEffect { + let futures = PyTuple::new(py, items) + .expect("failed to build RaceEffect.futures tuple") + .unbind() + .into_any(); + let effect = Py::new( + py, + PyClassInitializer::from(PyEffectBase { + tag: DoExprTag::Effect as u8, + }) + .add_subclass(PyRace { futures }), + ) + .expect("failed to construct RaceEffect") + .into_any(); + dispatch_from_shared(PyShared::new(effect)) + } + #[test] fn test_transfer_to_continuation_started_emits_transfer() { let cont = make_test_continuation(); @@ -5238,6 +5521,244 @@ mod tests { }); } + #[test] + fn test_blocked_gather_preserves_outer_driving_owner() { + Python::attach(|py| { + let state = Arc::new(Mutex::new(SchedulerState::new())); + let mut program = SchedulerProgram::new(state.clone()); + let mut store = RustStore::new(); + let mut _scope = ScopeStore::default(); + + let gather_owner_k = make_test_continuation(); + let waiter_k = make_test_continuation(); + let runnable_k = make_test_continuation(); + let driver_k = make_test_continuation(); + + let (promise_id, waiting_task, runnable_task) = { + let mut guard = state.lock().expect("Scheduler lock poisoned"); + let promise_id = guard.alloc_promise_id(); + guard.promises.insert(promise_id, PromiseState::Pending); + + let waiting_task = guard.alloc_task_id(); + guard.tasks.insert( + waiting_task, + TaskState::Pending { + cont: waiter_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + + let runnable_task = guard.alloc_task_id(); + guard.tasks.insert( + runnable_task, + TaskState::Pending { + cont: runnable_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + guard.enqueue_ready_task(runnable_task, PRIORITY_NORMAL); + guard.current_task = Some(waiting_task); + + (promise_id, waiting_task, runnable_task) + }; + + program.phase = SchedulerPhase::Driving { + owner: WaitOwner::Root { + cont_id: gather_owner_k.cont_id, + }, + running_task: waiting_task, + }; + + let step = IRStreamProgram::start( + &mut program, + py, + make_gather_effect(py, vec![make_waitable_promise_object(py, promise_id)]), + driver_k, + &mut store, + &mut _scope, + ); + + assert!( + step_targets_cont_id(&step, runnable_k.cont_id), + "blocked gather should transfer into another runnable task, got {:?}", + step + ); + assert!(matches!( + &program.phase, + SchedulerPhase::Driving { + owner: WaitOwner::Root { cont_id }, + running_task, + } if *cont_id == gather_owner_k.cont_id + && *running_task == runnable_task + )); + }); + } + + #[test] + fn test_blocked_race_preserves_outer_driving_owner() { + Python::attach(|py| { + let state = Arc::new(Mutex::new(SchedulerState::new())); + let mut program = SchedulerProgram::new(state.clone()); + let mut store = RustStore::new(); + let mut _scope = ScopeStore::default(); + + let race_owner_k = make_test_continuation(); + let waiter_k = make_test_continuation(); + let runnable_k = make_test_continuation(); + let driver_k = make_test_continuation(); + + let (promise_id, waiting_task, runnable_task) = { + let mut guard = state.lock().expect("Scheduler lock poisoned"); + let promise_id = guard.alloc_promise_id(); + guard.promises.insert(promise_id, PromiseState::Pending); + + let waiting_task = guard.alloc_task_id(); + guard.tasks.insert( + waiting_task, + TaskState::Pending { + cont: waiter_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + + let runnable_task = guard.alloc_task_id(); + guard.tasks.insert( + runnable_task, + TaskState::Pending { + cont: runnable_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + guard.enqueue_ready_task(runnable_task, PRIORITY_NORMAL); + guard.current_task = Some(waiting_task); + + (promise_id, waiting_task, runnable_task) + }; + + program.phase = SchedulerPhase::Driving { + owner: WaitOwner::Root { + cont_id: race_owner_k.cont_id, + }, + running_task: waiting_task, + }; + + let step = IRStreamProgram::start( + &mut program, + py, + make_race_effect(py, vec![make_waitable_promise_object(py, promise_id)]), + driver_k, + &mut store, + &mut _scope, + ); + + assert!( + step_targets_cont_id(&step, runnable_k.cont_id), + "blocked race should transfer into another runnable task, got {:?}", + step + ); + assert!(matches!( + &program.phase, + SchedulerPhase::Driving { + owner: WaitOwner::Root { cont_id }, + running_task, + } if *cont_id == race_owner_k.cont_id + && *running_task == runnable_task + )); + }); + } + + #[test] + fn test_blocked_wait_uses_direct_wait_owner() { + Python::attach(|py| { + let state = Arc::new(Mutex::new(SchedulerState::new())); + let mut program = SchedulerProgram::new(state.clone()); + let mut store = RustStore::new(); + let mut _scope = ScopeStore::default(); + + let outer_owner_k = make_test_continuation(); + let waiter_k = make_test_continuation(); + let runnable_k = make_test_continuation(); + let driver_k = make_test_continuation(); + + let (promise_id, waiting_task, runnable_task) = { + let mut guard = state.lock().expect("Scheduler lock poisoned"); + let promise_id = guard.alloc_promise_id(); + guard.promises.insert(promise_id, PromiseState::Pending); + + let waiting_task = guard.alloc_task_id(); + guard.tasks.insert( + waiting_task, + TaskState::Pending { + cont: waiter_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + + let runnable_task = guard.alloc_task_id(); + guard.tasks.insert( + runnable_task, + TaskState::Pending { + cont: runnable_k.clone(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + guard.enqueue_ready_task(runnable_task, PRIORITY_NORMAL); + guard.current_task = Some(waiting_task); + + (promise_id, waiting_task, runnable_task) + }; + + program.phase = SchedulerPhase::Driving { + owner: WaitOwner::Root { + cont_id: outer_owner_k.cont_id, + }, + running_task: waiting_task, + }; + + let step = IRStreamProgram::start( + &mut program, + py, + make_wait_effect(py, make_waitable_external_promise_object(py, promise_id)), + driver_k.clone(), + &mut store, + &mut _scope, + ); + + assert!( + step_targets_cont_id(&step, runnable_k.cont_id), + "blocked wait should transfer into another runnable task, got {:?}", + step + ); + assert!(matches!( + &program.phase, + SchedulerPhase::Driving { + owner: WaitOwner::Task { task_id, cont_id }, + running_task, + } if *task_id == waiting_task + && *cont_id == driver_k.cont_id + && *running_task == runnable_task + )); + }); + } + // ----------------------------------------------------------------------- // ISSUE-VM-003: Gather collects results from multiple tasks/promises // ----------------------------------------------------------------------- @@ -5651,6 +6172,79 @@ mod tests { assert_eq!(result.unwrap().as_int(), Some(99)); } + #[test] + fn test_wait_on_all_counter_tracks_only_pending_items() { + let mut state = SchedulerState::new(); + let done_task = state.alloc_task_id(); + let pending_task = state.alloc_task_id(); + + state.tasks.insert( + done_task, + TaskState::Done { + result: Ok(Value::Int(10)), + store: TaskStore::Shared, + }, + ); + state.tasks.insert( + pending_task, + TaskState::Pending { + cont: make_test_continuation(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + + let waiter = make_test_continuation(); + state.wait_on_all( + &[Waitable::Task(done_task), Waitable::Task(pending_task)], + waiter, + &RustStore::new(), + ); + + let waiters = state + .waiters + .get(&Waitable::Task(pending_task)) + .expect("pending task should have a registered waiter"); + assert_eq!(waiters.len(), 1); + assert_eq!(waiters[0].remaining.load(Ordering::Relaxed), 1); + } + + #[test] + fn test_wait_on_any_counter_only_needs_first_completion() { + let mut state = SchedulerState::new(); + let t0 = state.alloc_task_id(); + let t1 = state.alloc_task_id(); + + for task_id in [t0, t1] { + state.tasks.insert( + task_id, + TaskState::Pending { + cont: make_test_continuation(), + store: TaskStore::Shared, + resume_outcome: None, + priority: PRIORITY_NORMAL, + pending_log_merge_items: None, + }, + ); + } + + let waiter = make_test_continuation(); + state.wait_on_any( + &[Waitable::Task(t0), Waitable::Task(t1)], + waiter, + &RustStore::new(), + ); + + let waiters = state + .waiters + .get(&Waitable::Task(t0)) + .expect("pending task should have a registered waiter"); + assert_eq!(waiters.len(), 1); + assert_eq!(waiters[0].remaining.load(Ordering::Relaxed), 1); + } + #[test] fn test_scheduler_store_save_load() { let mut state = SchedulerState::new(); diff --git a/packages/doeff-core-effects/src/sentinels.rs b/packages/doeff-core-effects/src/sentinels.rs index 16afbc2a..0307c9ef 100644 --- a/packages/doeff-core-effects/src/sentinels.rs +++ b/packages/doeff-core-effects/src/sentinels.rs @@ -99,6 +99,13 @@ pub fn register_sentinels(m: &Bound<'_, PyModule>) -> PyResult<()> { "AwaitHandler".to_string(), ))), )?; + m.add( + "sync_await_handler", + PyRustHandlerSentinel::new(Arc::new(RustKleisli::new( + Arc::new(AwaitHandlerFactory), + "sync_await_handler".to_string(), + ))), + )?; m.add_function(wrap_pyfunction!(_notify_semaphore_handle_dropped, m)?)?; m.add_function(wrap_pyfunction!(_debug_scheduler_semaphore_count, m)?)?; m.add_function(wrap_pyfunction!(_debug_semaphore_exists, m)?)?; diff --git a/packages/doeff-vm-core/src/interceptor_state.rs b/packages/doeff-vm-core/src/interceptor_state.rs index 679491c8..8e3528c8 100644 --- a/packages/doeff-vm-core/src/interceptor_state.rs +++ b/packages/doeff-vm-core/src/interceptor_state.rs @@ -35,9 +35,22 @@ impl InterceptorState { ) -> Vec { let mut chain = Vec::new(); let mut seen = HashSet::new(); - Self::walk_segment_chain(current_segment, segments, &mut chain, &mut seen); + let mut visited_segments = HashSet::new(); + Self::walk_segment_chain( + current_segment, + segments, + &mut chain, + &mut seen, + &mut visited_segments, + ); for origin_seg_id in dispatch_origin_segments { - Self::walk_segment_chain(Some(*origin_seg_id), segments, &mut chain, &mut seen); + Self::walk_segment_chain( + Some(*origin_seg_id), + segments, + &mut chain, + &mut seen, + &mut visited_segments, + ); } chain } @@ -47,9 +60,13 @@ impl InterceptorState { segments: &SegmentArena, chain: &mut Vec, seen: &mut HashSet, + visited_segments: &mut HashSet, ) { let mut cursor = start; while let Some(seg_id) = cursor { + if !visited_segments.insert(seg_id) { + break; + } let Some(seg) = segments.get(seg_id) else { break; }; @@ -166,3 +183,76 @@ impl InterceptorState { Ok(body_seg) } } + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::Python; + use std::sync::Arc; + + use crate::do_ctrl::DoCtrl; + use crate::kleisli::{Kleisli, KleisliDebugInfo}; + use crate::value::Value; + + #[derive(Debug)] + struct DummyKleisli; + + impl Kleisli for DummyKleisli { + fn apply(&self, _py: Python<'_>, _args: Vec) -> Result { + unreachable!("test dummy should never be invoked") + } + + fn debug_info(&self) -> KleisliDebugInfo { + KleisliDebugInfo { + name: "DummyKleisli".to_string(), + file: None, + line: None, + } + } + } + + fn dummy_interceptor_segment(marker: Marker, caller: Option) -> Segment { + let mut seg = Segment::new(marker, caller); + seg.kind = SegmentKind::InterceptorBoundary { + interceptor: Arc::new(DummyKleisli), + types: None, + mode: InterceptMode::Include, + metadata: None, + }; + seg + } + + #[test] + fn current_chain_deduplicates_shared_segment_tails() { + let mut arena = SegmentArena::new(); + + let tail_marker = Marker::fresh(); + let tail = arena.alloc(dummy_interceptor_segment(tail_marker, None)); + + let shared_marker = Marker::fresh(); + let shared = arena.alloc(dummy_interceptor_segment(shared_marker, Some(tail))); + + let current_marker = Marker::fresh(); + let current = arena.alloc(dummy_interceptor_segment(current_marker, Some(shared))); + + let origin_a_marker = Marker::fresh(); + let origin_a = arena.alloc(dummy_interceptor_segment(origin_a_marker, Some(shared))); + + let origin_b_marker = Marker::fresh(); + let origin_b = arena.alloc(dummy_interceptor_segment(origin_b_marker, Some(shared))); + + let state = InterceptorState::default(); + let chain = state.current_chain(Some(current), &arena, &[origin_a, origin_b]); + + assert_eq!( + chain, + vec![ + current_marker, + shared_marker, + tail_marker, + origin_a_marker, + origin_b_marker, + ] + ); + } +} diff --git a/packages/doeff-vm-core/src/vm/dispatch.rs b/packages/doeff-vm-core/src/vm/dispatch.rs index 1dad32a5..ab6872c1 100644 --- a/packages/doeff-vm-core/src/vm/dispatch.rs +++ b/packages/doeff-vm-core/src/vm/dispatch.rs @@ -436,30 +436,49 @@ impl VM { &self, scope: &Continuation, ) -> Option { - let mut start_seg_id = scope.segment_id; - if self.segments.get(start_seg_id).is_none() { - return None; - } + self.root_delegate_parent_segment_id( + scope, + "EvalInScope parent chain must be Delegate-created dispatch continuations", + ) + } - // When EvalInScope is reached through Delegate chains, the continuation - // passed to handlers may wrap the original effect-site continuation in - // `parent`. Replay should use the origin scope so wrapper interceptors - // around the effect site remain visible. - let mut cursor = scope.parent.as_deref(); + fn root_delegate_parent_segment_id( + &self, + continuation: &Continuation, + assert_message: &str, + ) -> Option { + let mut start_seg_id = self + .continuation_chain_segment_id(continuation) + .or_else(|| continuation.captured_caller)?; + + // Delegate wraps the original effect-site continuation in `parent`. + // Operations that need the caller-visible handler stack or interceptor + // topology must resolve that wrapper back to the original segment. + let mut cursor = continuation.parent.as_deref(); while let Some(parent) = cursor { - assert!( - parent.dispatch_id.is_some(), - "EvalInScope parent chain must be Delegate-created dispatch continuations" - ); - if self.segments.get(parent.segment_id).is_none() { - break; + assert!(parent.dispatch_id.is_some(), "{}", assert_message); + if let Some(parent_seg_id) = self + .continuation_chain_segment_id(parent) + .or_else(|| parent.captured_caller) + { + start_seg_id = parent_seg_id; } - start_seg_id = parent.segment_id; cursor = parent.parent.as_deref(); } Some(start_seg_id) } + fn continuation_chain_segment_id(&self, continuation: &Continuation) -> Option { + continuation + .captured_caller + .filter(|seg_id| self.segments.get(*seg_id).is_some()) + .or_else(|| { + self.segments + .get(continuation.segment_id) + .map(|_| continuation.segment_id) + }) + } + fn structural_kind_for_marker(&self, marker: Marker) -> SegmentKind { let Some(entry) = self.interceptor_state.get_entry(marker) else { return SegmentKind::Normal; @@ -994,7 +1013,13 @@ impl VM { .and_then(|dispatch_id| { self.dispatch_origin_for_dispatch_id(dispatch_id) .map(|origin| { - self.handlers_in_caller_chain(origin.k_origin.segment_id) + let chain_start = self + .root_delegate_parent_segment_id( + &origin.k_origin, + "Dispatch parent chain must be Delegate-created continuations", + ) + .unwrap_or(origin.k_origin.segment_id); + self.handlers_in_caller_chain(chain_start) .into_iter() .map(|entry| entry.prompt_seg_id) .collect() @@ -1594,7 +1619,13 @@ impl VM { dispatch_id.raw() ))); }; - let handler_chain = self.handlers_in_caller_chain(origin.k_origin.segment_id); + let handler_chain_start = self + .root_delegate_parent_segment_id( + &origin.k_origin, + "Delegate parent chain must be Delegate-created continuations", + ) + .unwrap_or(origin.k_origin.segment_id); + let handler_chain = self.handlers_in_caller_chain(handler_chain_start); let Some(from_idx) = handler_chain .iter() .position(|entry| entry.marker == current_marker) @@ -1731,7 +1762,6 @@ impl VM { self.current_seg_mut().mode = Mode::Throw(original_exception); return StepEvent::Continue; } - self.dispatch_fatal_error_event(VMError::delegate_no_outer_handler(effect)) } @@ -1826,7 +1856,12 @@ impl VM { // handlers. Continuation installation handles deduplication when these // handlers are reapplied from within active dispatch contexts. let handlers = self - .handlers_in_caller_chain(origin.k_origin.segment_id) + .root_delegate_parent_segment_id( + &origin.k_origin, + "GetHandlers parent chain must be Delegate-created continuations", + ) + .map(|seg_id| self.handlers_in_caller_chain(seg_id)) + .unwrap_or_else(|| self.handlers_in_caller_chain(origin.k_origin.segment_id)) .into_iter() .map(|entry| entry.handler) .collect::>(); diff --git a/packages/doeff-vm/doeff_vm/__init__.py b/packages/doeff-vm/doeff_vm/__init__.py index 24898694..0ed93760 100644 --- a/packages/doeff-vm/doeff_vm/__init__.py +++ b/packages/doeff-vm/doeff_vm/__init__.py @@ -181,6 +181,7 @@ def validated_nesting_to_generator(self): scheduler = _ext.scheduler lazy_ask = _ext.lazy_ask await_handler = _ext.await_handler +sync_await_handler = _ext.sync_await_handler CreateContinuation = _ext.CreateContinuation GetContinuation = _ext.GetContinuation GetHandlers = _ext.GetHandlers @@ -356,6 +357,7 @@ def validated_nesting_to_generator(self): "_SchedulerTaskCompleted", "async_run", "await_handler", + "sync_await_handler", "lazy_ask", "reader", "result_safe", diff --git a/packages/doeff-vm/src/pyvm.rs b/packages/doeff-vm/src/pyvm.rs index fbea1ee0..31fedb7c 100644 --- a/packages/doeff-vm/src/pyvm.rs +++ b/packages/doeff-vm/src/pyvm.rs @@ -400,6 +400,12 @@ pub struct PyVM { vm: VM, } +enum SyncDriverLoopOutcome { + Done(Value), + VmError(VMError), + PythonException(PyException), +} + #[pymethods] impl PyVM { #[new] @@ -433,31 +439,20 @@ impl PyVM { ) -> PyResult { self.start_with_expr(py, program)?; - let (result, traceback_data) = loop { - let event = py.detach(|| self.run_rust_steps()); - match event { - StepEvent::Done(value) => match value.to_pyobject(py) { - Ok(v) => break (Ok(v.unbind()), None), - Err(e) => { - let exc = pyerr_to_exception(py, e)?; - break (Err(exc), None); - } - }, - StepEvent::Error(e) => { - let (pyerr, traceback_data) = vmerror_to_pyerr_with_traceback_data(py, e); - let exc = pyerr_to_exception(py, pyerr)?; - break (Err(exc), traceback_data); + let (result, traceback_data) = match py.detach(|| self.run_sync_driver_loop()) { + SyncDriverLoopOutcome::Done(value) => match value.to_pyobject(py) { + Ok(v) => (Ok(v.unbind()), None), + Err(e) => { + let exc = pyerr_to_exception(py, e)?; + (Err(exc), None) } - StepEvent::NeedsPython(call) => { - let outcome = self.execute_python_call(py, call)?; - if let Err(e) = self.vm.receive_python_result(outcome) { - let (pyerr, traceback_data) = vmerror_to_pyerr_with_traceback_data(py, e); - let exc = pyerr_to_exception(py, pyerr)?; - break (Err(exc), traceback_data); - } - } - StepEvent::Continue => unreachable!("handled in run_rust_steps"), + }, + SyncDriverLoopOutcome::VmError(e) => { + let (pyerr, traceback_data) = vmerror_to_pyerr_with_traceback_data(py, e); + let exc = pyerr_to_exception(py, pyerr)?; + (Err(exc), traceback_data) } + SyncDriverLoopOutcome::PythonException(exc) => (Err(exc), None), }; self.vm.end_active_run_session(); @@ -769,6 +764,35 @@ impl PyVM { } } + fn run_sync_driver_loop(&mut self) -> SyncDriverLoopOutcome { + loop { + match self.run_rust_steps() { + StepEvent::Done(value) => return SyncDriverLoopOutcome::Done(value), + StepEvent::Error(error) => return SyncDriverLoopOutcome::VmError(error), + StepEvent::NeedsPython(call) => { + let outcome = Python::attach(|py| self.execute_python_call(py, call)); + match outcome { + Ok(outcome) => { + if let Err(error) = self.vm.receive_python_result(outcome) { + return SyncDriverLoopOutcome::VmError(error); + } + } + Err(pyerr) => { + let exception = Python::attach(|py| pyerr_to_exception(py, pyerr)) + .unwrap_or_else(|err| { + PyException::runtime_error(format!( + "failed to convert Python error during sync driver loop: {err}" + )) + }); + return SyncDriverLoopOutcome::PythonException(exception); + } + } + } + StepEvent::Continue => unreachable!("handled in run_rust_steps"), + } + } + } + fn step_once_error_tuple(&self, py: Python<'_>, e: VMError) -> PyResult> { let (pyerr, traceback_data) = vmerror_to_pyerr_with_traceback_data(py, e); let err_obj = pyerr.value(py).clone().into_any(); diff --git a/tests/core/test_rust_vm_api_strict.py b/tests/core/test_rust_vm_api_strict.py index ef35a183..f29711b3 100644 --- a/tests/core/test_rust_vm_api_strict.py +++ b/tests/core/test_rust_vm_api_strict.py @@ -29,7 +29,7 @@ def test_default_handlers_requires_module_sentinels(monkeypatch: pytest.MonkeyPa def test_default_handlers_are_module_sentinels_only(monkeypatch: pytest.MonkeyPatch) -> None: from doeff.effects.future import sync_await_handler - from doeff.effects.spawn import spawn_intercept_handler + from doeff.effects.spawn import sync_spawn_intercept_handler sentinels = { "state": object(), @@ -51,7 +51,7 @@ def test_default_handlers_are_module_sentinels_only(monkeypatch: pytest.MonkeyPa sentinels["result_safe"], sentinels["scheduler"], sentinels["lazy_ask"], - spawn_intercept_handler, + sync_spawn_intercept_handler, sync_await_handler, ] diff --git a/tests/core/test_sa009_async_handler_spec_gaps.py b/tests/core/test_sa009_async_handler_spec_gaps.py index 30e8456a..4bdfeffd 100644 --- a/tests/core/test_sa009_async_handler_spec_gaps.py +++ b/tests/core/test_sa009_async_handler_spec_gaps.py @@ -144,6 +144,26 @@ async def test_async_run_handler_stack_matches_passed_handlers(self) -> None: vm_handler_stack=_result_value(result), ) + def test_spawned_child_preserves_full_default_handler_stack(self) -> None: + handlers = default_handlers() + + @do + def child(): + return (yield doeff_vm.GetHandlers()) + + @do + def program(): + task = yield Spawn(child(), daemon=False) + values = yield Gather(task) + return values[0] + + result = run(program(), handlers=handlers) + assert _result_is_ok(result), getattr(result, "display", lambda: repr(result))() + _assert_vm_handler_stack_matches_passed_handlers( + passed_handlers=handlers, + vm_handler_stack=_result_value(result), + ) + class TestNoHandlerSwapContract: def test_no_normalize_async_handlers_function(self) -> None: