diff --git a/marimo/_pyodide/pyodide_session.py b/marimo/_pyodide/pyodide_session.py index dc2c6b9a324..5cd049ea92f 100644 --- a/marimo/_pyodide/pyodide_session.py +++ b/marimo/_pyodide/pyodide_session.py @@ -483,7 +483,7 @@ def _launch_pyodide_kernel( ) if is_edit_mode: - signal.signal(signal.SIGINT, handlers.construct_interrupt_handler(ctx)) + signal.signal(signal.SIGINT, handlers.construct_interrupt_handler()) async def listen_completion() -> None: while True: diff --git a/marimo/_runtime/context/kernel_context.py b/marimo/_runtime/context/kernel_context.py index 9f99a02402c..44cf94d77e7 100644 --- a/marimo/_runtime/context/kernel_context.py +++ b/marimo/_runtime/context/kernel_context.py @@ -24,6 +24,7 @@ from marimo._ast.app import InternalApp from marimo._messaging.types import KernelStreams + from marimo._runtime.runner.scheduler import Scheduler from marimo._runtime.runtime import Kernel from marimo._runtime.state import State from marimo._runtime.virtual_file import VirtualFileStorageType @@ -41,6 +42,14 @@ class KernelRuntimeContext(RuntimeContext): _app: InternalApp | None = None _id_provider: IDProvider | None = None _execution_context: ExecutionContext | None = None + # Set while a Scheduler's `async with` is open. Lookup goes through + # the currently-installed context — not one captured at install + # time — so embedded-app child contexts route SIGINT correctly. + _active_scheduler: Scheduler | None = None + + @property + def active_scheduler(self) -> Scheduler | None: + return self._active_scheduler @property def graph(self) -> DirectedGraph: diff --git a/marimo/_runtime/executor/evaluator.py b/marimo/_runtime/executor/evaluator.py index 523e384ffa3..cc5deb43438 100644 --- a/marimo/_runtime/executor/evaluator.py +++ b/marimo/_runtime/executor/evaluator.py @@ -3,24 +3,18 @@ from __future__ import annotations -import asyncio -import contextlib -import functools -import signal -import threading from dataclasses import replace -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from marimo import _loggers from marimo._entrypoints.registry import EntryPointRegistry -from marimo._runtime.control_flow import MarimoInterrupt from marimo._runtime.executor.executor import DefaultExecutor, Executor from marimo._runtime.executor.lifecycles import ExecutionLifecycle, Skip from marimo._runtime.runner.result import RunResult from marimo._types.globals import MutableGlobals if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Callable from marimo._ast.cell import CellImpl @@ -92,18 +86,6 @@ def evaluate_sync( return self._teardown_chain(cell, glbls, completed, result) - async def evaluate_interruptible( - self, cell: CellImpl, glbls: MutableGlobals - ) -> RunResult: - """Await `evaluate` with SIGINT capture for coroutine cells.""" - if not cell.is_coroutine(): - return await self.evaluate(cell, glbls) - future = asyncio.ensure_future(self.evaluate(cell, glbls)) - if threading.current_thread() is threading.main_thread(): - with _cancel_on_sigint(future): - return await future - return await future - def _setup_chain( self, cell: CellImpl, glbls: MutableGlobals ) -> tuple[list[ExecutionLifecycle], Skip | None, BaseException | None]: @@ -196,48 +178,3 @@ def resolve_executor() -> Executor: e, ) return DefaultExecutor() - - -# Adapted from -# https://github.com/ipython/ipykernel/blob/eddd3e666a82ebec287168b0da7cfa03639a3772/ipykernel/ipkernel.py#L312 -@contextlib.contextmanager -def _cancel_on_sigint(future: asyncio.Future[Any]) -> Iterator[None]: - """Cancel `future` if a SIGINT arrives during evaluation.""" - sigint_future: asyncio.Future[int] = asyncio.Future() - - def cancel_unless_done(f: asyncio.Future[Any], _: Any) -> None: - if f.cancelled() or f.done(): - return - f.cancel() - - sigint_future.add_done_callback( - functools.partial(cancel_unless_done, future) - ) - future.add_done_callback( - functools.partial(cancel_unless_done, sigint_future) - ) - - # Capture the previously-installed SIGINT handler *before* we install - # ours so `handle_sigint` can invoke it for its side effects - # (kernel broadcast, duckdb interrupt). For async cells the actual - # halt comes from cancelling the future, not from a raised - # `MarimoInterrupt` — so we swallow that here. - prior_sigint = signal.getsignal(signal.SIGINT) - - def handle_sigint(signum: int, frame: Any) -> None: - if sigint_future.cancelled() or sigint_future.done(): - return - sigint_future.set_result(1) - if callable(prior_sigint): - try: - prior_sigint(signum, frame) - except MarimoInterrupt: - # The kernel's handler raises MarimoInterrupt for sync - # halt; we cancel the future instead. - pass - - save_sigint = signal.signal(signal.SIGINT, handle_sigint) - try: - yield - finally: - signal.signal(signal.SIGINT, save_sigint) diff --git a/marimo/_runtime/handlers.py b/marimo/_runtime/handlers.py index 5ccd2bc7a5f..16a750dd08f 100644 --- a/marimo/_runtime/handlers.py +++ b/marimo/_runtime/handlers.py @@ -10,6 +10,7 @@ from marimo._messaging.notification_utils import broadcast_notification from marimo._runtime.context import get_context from marimo._runtime.context.kernel_context import KernelRuntimeContext +from marimo._runtime.context.types import safe_get_context from marimo._runtime.control_flow import MarimoInterrupt LOGGER = _loggers.marimo_logger() @@ -20,35 +21,53 @@ from marimo._runtime.runtime import Kernel -def construct_interrupt_handler( - context: KernelRuntimeContext, -) -> Callable[[int, Any], None]: +def construct_interrupt_handler() -> Callable[[int, Any], None]: def interrupt_handler(signum: int, frame: Any) -> None: """Tries to interrupt the kernel.""" del signum del frame + # Resolve the *currently installed* context, not one captured at + # install time — embedded apps swap in their own child context. + ctx = safe_get_context() + if not isinstance(ctx, KernelRuntimeContext): + return + + # `execution_context` is a per-task ContextVar — unreadable from + # this thread while user tasks are suspended in `select()`. The + # scheduler publication is the authoritative "is a run in flight" + # signal; `execution_context` is opportunistic (only used for + # the duckdb hook below). + sched = ctx.active_scheduler + exec_ctx = ctx.execution_context + if sched is None and exec_ctx is None: + return + LOGGER.info("Interrupt request received") - # TODO(akshayka): if kernel is in `run` but not executing, - # it won't be interrupted, which isn't right ... but the - # probability of that happening is low. - if context.execution_context is not None: - broadcast_notification(InterruptedNotification()) - # DuckDB connections are sometimes left in an inconsistent - # state when interrupted by a SIGINT. Manually interrupting - # duckdb through its own API seems to be safer. - if context.execution_context.duckdb_connection is not None: - try: - context.execution_context.duckdb_connection.interrupt() - except Exception as e: - # Coarse try/except; let's not kill the kernel if something - # goes wrong. - LOGGER.warning( - "Failed to interrupt running duckdb connection. This " - "may be a bug in duckdb or marimo. %s", - e, - ) - raise MarimoInterrupt + broadcast_notification(InterruptedNotification()) + + # DuckDB connections are sometimes left in an inconsistent state + # when interrupted by a SIGINT; route through duckdb's own API. + if exec_ctx is not None and exec_ctx.duckdb_connection is not None: + try: + exec_ctx.duckdb_connection.interrupt() + except Exception as e: + LOGGER.warning( + "Failed to interrupt running duckdb connection. This " + "may be a bug in duckdb or marimo. %s", + e, + ) + + if sched is not None and sched.has_active_tasks(): + # Async cell in flight: cancel via the loop. Raising from a + # signal handler escapes into asyncio internals and surfaces + # as an internal-error empty RunResult. + sched.cancel_all() + return + + if sched is not None: + sched.cancel_all() + raise MarimoInterrupt return interrupt_handler diff --git a/marimo/_runtime/runner/cell_runner.py b/marimo/_runtime/runner/cell_runner.py index 5cfec2a87d9..6c49fe7681b 100644 --- a/marimo/_runtime/runner/cell_runner.py +++ b/marimo/_runtime/runner/cell_runner.py @@ -33,6 +33,7 @@ StrictLifecycle, resolve_executor, ) +from marimo._runtime.executor.executor import _strip_frame from marimo._runtime.marimo_pdb import MarimoPdb from marimo._runtime.runner.hook_context import ( CancelledCells, @@ -45,6 +46,7 @@ create_sql_error_from_exception, is_sql_parse_error, ) +from marimo._types.globals import MutableGlobals from marimo._types.ids import CellId_t LOGGER = marimo_logger() @@ -52,6 +54,7 @@ if TYPE_CHECKING: from collections import deque + from marimo._ast.cell import CellImpl from marimo._runtime.runner.hooks import NotebookCellHooks from marimo._runtime.state import State @@ -107,7 +110,7 @@ def __init__( self, roots: set[CellId_t], graph: dataflow.DirectedGraph, - glbls: dict[Any, Any], + glbls: MutableGlobals, debugger: MarimoPdb | None, hooks: NotebookCellHooks, execution_mode: OnCellChangeType = "autorun", @@ -397,6 +400,23 @@ def _run_result_from_exception( output=output, exception=exception ), unwrapped_exception + async def evaluate_interruptible(self, cell: CellImpl) -> RunResult: + """Evaluate `cell`. Coroutine cells run as a scheduler-tracked + task so the SIGINT-handler's `cancel_all` can preempt them — a + plain `await` is not cancellable from another thread.""" + coro = self._evaluator.evaluate(cell, self.glbls) + if not cell.is_coroutine(): + return await coro + try: + async with self._scheduler.start_task(cell.cell_id, coro) as task: + return await task + except asyncio.CancelledError as exc: + # SIGINT cancelled the task at any point — either pre-admit + # (start_task refused entry) or mid-await. Hand back to + # `_finalize_run_result` so it converts to MarimoInterrupt + # rather than escaping to the broad except below. + return RunResult(output=None, exception=exc) + async def run(self, cell_id: CellId_t) -> RunResult: """Run a cell.""" if self.debugger is not None: @@ -409,9 +429,7 @@ async def run(self, cell_id: CellId_t) -> RunResult: # returned RunResult; cell_id-specific classification + side # effects are applied below in `_finalize_run_result`. try: - raw_result = await self._evaluator.evaluate_interruptible( - cell, self.glbls - ) + raw_result = await self.evaluate_interruptible(cell) run_result = self._finalize_run_result(raw_result, cell_id) except BaseException: # Defensive: an unexpected escape from the Evaluator or a bug @@ -459,7 +477,10 @@ def _finalize_run_result( return raw_result if isinstance(exc, asyncio.exceptions.CancelledError): - # Surface cancellation as a MarimoInterrupt for downstream handling. + # Drop the two marimo frames above user code (the evaluator's + # `await execute_cell_async` and the executor's `await eval`) + # before surfacing as MarimoInterrupt. + _strip_frame(exc, 2) tmpio = io.StringIO() traceback.print_exception( type(exc), exc, exc.__traceback__, file=tmpio @@ -603,6 +624,34 @@ def _find_first_blocked_missing_ref( return defining_cell_id return None + async def _run_one( + self, + cell_id: CellId_t, + pre_exec_ctx: Any, + post_exec_ctx: Any, + ) -> None: + cell = self.graph.cells[cell_id] + for pre_hook in self._hooks.pre_execution_hooks: + pre_hook(cell, pre_exec_ctx) + LOGGER.debug("Running cell %s", cell_id) + + if self.execution_context is not None: + try: + with self.execution_context(cell_id) as exc_ctx: + run_result = await self.run(cell_id) + run_result.accumulated_output = exc_ctx.output + for post_hook in self._hooks.post_execution_hooks: + post_hook(cell, post_exec_ctx, run_result) + except KeyboardInterrupt: + LOGGER.error( + "A keyboard interrupt was raised but not handled by " + "the runner." + ) + else: + run_result = await self.run(cell_id) + for post_hook in self._hooks.post_execution_hooks: + post_hook(cell, post_exec_ctx, run_result) + async def run_all(self) -> None: from marimo._runtime.runner.hook_context import ( OnFinishHookContext, @@ -642,8 +691,51 @@ async def run_all(self) -> None: user_config=self.user_config, ) - while self.pending(): - cell_id = self.pop_cell() + # `async with self._scheduler` publishes the scheduler on the + # current context for SIGINT routing — must wrap the prescan + # too, otherwise an interrupt during a large prescan finds no + # scheduler and is dropped. + # + # `try/except KeyboardInterrupt` catches the raise produced by + # the sync-path SIGINT handler (which fires between any two + # bytecodes in the prescan or batch loop). `cancel_all` already + # ran on the scheduler, so `interrupted` is True and the queue + # is halted; suppress the raise so `on_finish_hooks` still fire + # and the kernel control loop doesn't see a `BaseException`. + async with self._scheduler: + try: + await self._dispatch_runnable(pre_exec_ctx, post_exec_ctx) + except KeyboardInterrupt: + LOGGER.info("Runner interrupted via SIGINT") + + finish_ctx = OnFinishHookContext( + graph=self.graph, + cells_to_run=self.cells_to_run, + interrupted=self.interrupted, + cancelled_cells=self.cancelled_cells, + exceptions=self.exceptions, + ) + LOGGER.debug("Running on_finish hooks") + for finish_hook in self._hooks.on_finish_hooks: + finish_hook(finish_ctx) + + async def _dispatch_runnable( + self, + pre_exec_ctx: Any, + post_exec_ctx: Any, + ) -> None: + """Filter the queue then run each runnable cell. + + Prescan iterates a snapshot of `_cells_to_run` and *removes* + filtered cells (cancelled/disabled) from the queue in place. + The live queue therefore always represents "cells still to + run" — a SIGINT mid-prescan leaves on_finish_hooks with the + correct remaining set (no spurious interrupted for filtered + cells, and the cell currently being filtered stays visible). + """ + for cell_id in list(self._scheduler.cells_to_run): + if self._scheduler.interrupted: + break LOGGER.debug("Cell runner processing %s", cell_id) cell = self.graph.cells[cell_id] @@ -680,52 +772,29 @@ async def run_all(self) -> None: LOGGER.debug("%s cancelled", cell_id) cell.set_run_result_status("cancelled") cell.set_runtime_state("idle") + self._scheduler.cells_to_run.remove(cell_id) continue if cell.config.disabled: LOGGER.debug("%s disabled", cell_id) cell.set_run_result_status("disabled") cell.set_runtime_state("idle") + self._scheduler.cells_to_run.remove(cell_id) continue if self.graph.is_disabled(cell_id): LOGGER.debug("%s disabled transitively", cell_id) cell.set_run_result_status("disabled") cell.set_runtime_state("disabled-transitively") + self._scheduler.cells_to_run.remove(cell_id) continue - - LOGGER.debug("Running pre_execution hooks") - for pre_hook in self._hooks.pre_execution_hooks: - pre_hook(cell, pre_exec_ctx) - LOGGER.debug("Running cell %s", cell_id) - if self.execution_context is not None: - try: - # TODO(akshayka): The execution context should be pushed - # down to as close to kernel execution as possible. - with self.execution_context(cell_id) as exc_ctx: - run_result = await self.run(cell_id) - run_result.accumulated_output = exc_ctx.output - LOGGER.debug("Running post_execution hooks in context") - for post_hook in self._hooks.post_execution_hooks: - post_hook(cell, post_exec_ctx, run_result) - except KeyboardInterrupt: - LOGGER.error( - """ - A keyboard interrupt was raised but not handled by the runner. - """ - ) - - else: - run_result = await self.run(cell_id) - LOGGER.debug("Running post_execution hooks out of context") - for post_hook in self._hooks.post_execution_hooks: - post_hook(cell, post_exec_ctx, run_result) - - finish_ctx = OnFinishHookContext( - graph=self.graph, - cells_to_run=self.cells_to_run, - interrupted=self.interrupted, - cancelled_cells=self.cancelled_cells, - exceptions=self.exceptions, - ) - LOGGER.debug("Running on_finish hooks") - for finish_hook in self._hooks.on_finish_hooks: - finish_hook(finish_ctx) + # Runnable: leave in queue for dispatch. + + for batch in self._scheduler.batch(): + for cell_id in batch: + # Re-check: an earlier cell in this run may have + # cancelled its descendants while we were dispatching. + if self.cancelled(cell_id): + cell = self.graph.cells[cell_id] + cell.set_run_result_status("cancelled") + cell.set_runtime_state("idle") + continue + await self._run_one(cell_id, pre_exec_ctx, post_exec_ctx) diff --git a/marimo/_runtime/runner/hooks_post_execution.py b/marimo/_runtime/runner/hooks_post_execution.py index 6efefc29de0..b7607bc569d 100644 --- a/marimo/_runtime/runner/hooks_post_execution.py +++ b/marimo/_runtime/runner/hooks_post_execution.py @@ -98,7 +98,9 @@ def _set_run_result_status( ctx: PostExecutionHookContext, run_result: cell_runner.RunResult, ) -> None: - if isinstance(run_result.exception, MarimoInterruptionError): + if isinstance(run_result.exception, MarimoInterrupt): + # `MarimoInterruptionError` is a broadcast payload (never raised); + # the exception held here is the raised `MarimoInterrupt`. cell.set_run_result_status("interrupted") elif cell.cell_id in ctx.cancelled_cells: cell.set_run_result_status("cancelled") diff --git a/marimo/_runtime/runner/scheduler.py b/marimo/_runtime/runner/scheduler.py index 6205d2670dc..4f347e2ec77 100644 --- a/marimo/_runtime/runner/scheduler.py +++ b/marimo/_runtime/runner/scheduler.py @@ -1,35 +1,79 @@ # Copyright 2026 Marimo. All rights reserved. -"""Scheduler owns the cell queue and cancellation state""" +"""Scheduler: per-run cell queue, cancellation, and async-task tracking. + +Singular-scheduler invariant +---------------------------- +At most one `Runner.run_all()` is on the stack per `KernelRuntimeContext` +at any time. The kernel serializes runs — control requests queue and +state-update cascades only re-enter after `run_all()` returns. Embedded +apps push a child `KernelRuntimeContext` (not a nested scheduler on the +same context). Under this invariant, `KernelRuntimeContext._active_scheduler` +can be a singular field. + +A future non-blocking `AsyncScheduler.submit()` (returning before +dispatch completes) will break this invariant by allowing concurrent +schedulers on one context; that PR will need to promote +`_active_scheduler` to a plural `OrderedDict[int, Scheduler]`. +`__aenter__` below fails loudly if the invariant is ever silently +broken. +""" from __future__ import annotations +import asyncio from collections import deque -from typing import TYPE_CHECKING, Protocol +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any, Protocol +from marimo import _loggers from marimo._runtime import dataflow +from marimo._runtime.context.types import safe_get_context from marimo._runtime.runner.hook_context import CancelledCells if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import ( + AsyncIterator, + Coroutine, + Iterable, + Iterator, + Sequence, + ) + + from typing_extensions import Self from marimo._runtime.dataflow import DirectedGraph + from marimo._runtime.runner.result import RunResult from marimo._types.ids import CellId_t +LOGGER = _loggers.marimo_logger() + class Scheduler(Protocol): - """Cell queue + cancellation. Surface for future scheduler types.""" + """Cell queue + cancellation + async task tracking.""" def pending(self) -> bool: ... def pop_cell(self) -> CellId_t: ... def cancel(self, cell_id: CellId_t) -> None: ... def cancelled(self, cell_id: CellId_t) -> bool: ... def batch( - self, cell_ids: Iterable[CellId_t] - ) -> Iterator[list[CellId_t]]: ... + self, cell_ids: Iterable[CellId_t] | None = ... + ) -> Iterator[Iterable[CellId_t]]: ... + def requeue(self, cell_ids: Iterable[CellId_t]) -> None: ... + + def start_task( + self, + cell_id: CellId_t, + coro: Coroutine[Any, Any, RunResult], + ) -> AsyncIterator[asyncio.Task[RunResult]]: ... + def has_active_tasks(self) -> bool: ... + def cancel_all(self) -> None: ... + + async def __aenter__(self) -> Self: ... + async def __aexit__(self, *exc_info: Any) -> None: ... class SequentialScheduler: - """Single-threaded FIFO queue + cancellation.""" + """Single-threaded FIFO queue + cancellation + async task tracking.""" def __init__( self, @@ -40,6 +84,7 @@ def __init__( self._cancelled = CancelledCells() self._graph = graph self._interrupted = False + self._active: dict[CellId_t, asyncio.Task[Any]] = {} def pending(self) -> bool: return not self._interrupted and len(self._cells_to_run) > 0 @@ -47,15 +92,25 @@ def pending(self) -> bool: def pop_cell(self) -> CellId_t: return self._cells_to_run.popleft() - def batch(self, cell_ids: Iterable[CellId_t]) -> Iterator[list[CellId_t]]: - """Yield batches of cells to execute. - - Sequential default: one cell per batch. + def batch( + self, cell_ids: Iterable[CellId_t] | None = None + ) -> Iterator[Iterable[CellId_t]]: + """Yield batches of cells to execute (1-tuple per batch here). + + If `cell_ids` is given, the queue is replaced first — kept for + callers that drive the scheduler without going through + `requeue` (and for tests). When `None`, consume the existing + `_cells_to_run` as-is. """ + if cell_ids is not None: + self.requeue(cell_ids) + while self._cells_to_run and not self._interrupted: + yield (self._cells_to_run.popleft(),) + + def requeue(self, cell_ids: Iterable[CellId_t]) -> None: + """Replace the pending queue with `cell_ids`.""" self._cells_to_run.clear() self._cells_to_run.extend(cell_ids) - while self._cells_to_run and not self._interrupted: - yield [self._cells_to_run.popleft()] def cancel(self, cell_id: CellId_t) -> None: """Mark a cell and its descendants as cancelled.""" @@ -87,3 +142,92 @@ def cancelled_cells(self) -> CancelledCells: def cells_to_run(self) -> deque[CellId_t]: """The live queue. Mutates as cells are popped.""" return self._cells_to_run + + @asynccontextmanager + async def start_task( + self, + cell_id: CellId_t, + coro: Coroutine[Any, Any, RunResult], + ) -> AsyncIterator[asyncio.Task[RunResult]]: + """Atomically create and register a task for `coro`. + + Closes the SIGINT race where a task is created before being + tracked: `ensure_future` and `_register_task` run as two + adjacent synchronous statements (the narrowest gap in pure + Python), then `_interrupted` is re-checked. A SIGINT delivered + between them flips `_interrupted` via `cancel_all`; the + re-check cancels the freshly-registered task before the loop + ever resumes it. + """ + if self._interrupted: + coro.close() + raise asyncio.CancelledError + task = asyncio.ensure_future(coro) + self._register_task(cell_id, task) + if self._interrupted: + task.cancel() + try: + yield task + finally: + self._unregister_task(cell_id) + + def _register_task( + self, cell_id: CellId_t, task: asyncio.Task[Any] + ) -> None: + self._active[cell_id] = task + + def _unregister_task(self, cell_id: CellId_t) -> None: + self._active.pop(cell_id, None) + + def has_active_tasks(self) -> bool: + return any(not t.done() for t in self._active.values()) + + def cancel_all(self) -> None: + # Set `_interrupted` first so a SIGINT arriving between cells + # (no task registered) still halts the queue. + # + # `call_soon_threadsafe` is required: a plain `task.cancel()` + # from the signal-handler thread queues the cancel but doesn't + # wake the loop's `select()` — the task keeps sleeping until + # its next scheduled wakeup. + self._interrupted = True + for task in list(self._active.values()): + if task.done(): + continue + task.get_loop().call_soon_threadsafe(task.cancel) + + async def __aenter__(self) -> Self: + # Late import to avoid a cycle through the runtime context tree. + from marimo._runtime.context.kernel_context import ( + KernelRuntimeContext, + ) + + ctx = safe_get_context() + if isinstance(ctx, KernelRuntimeContext): + if ctx._active_scheduler is not None: + # See module docstring: a second `async with scheduler` + # on the same context means nested or concurrent runs, + # which the singular `_active_scheduler` design does not + # support. Fail loudly so the regression surfaces in + # tests rather than as a silent SIGINT-routing bug. + raise RuntimeError( + "A scheduler is already active on this context; " + "concurrent runs are not supported. This indicates " + "either a re-entrant Runner.run_all or a future " + "non-blocking scheduler that should be promoting " + "_active_scheduler to plural." + ) + ctx._active_scheduler = self + return self + + async def __aexit__(self, *exc_info: Any) -> None: + from marimo._runtime.context.kernel_context import ( + KernelRuntimeContext, + ) + + ctx = safe_get_context() + if ( + isinstance(ctx, KernelRuntimeContext) + and ctx._active_scheduler is self + ): + ctx._active_scheduler = None diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 7b21537a45e..e359d77b205 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -2450,7 +2450,6 @@ def _create_streams( def _install_subprocess_handlers( kernel: Kernel, - ctx: KernelRuntimeContext, user_config: MarimoConfig, interrupt_queue: QueueType[bool] | None, ) -> None: @@ -2460,7 +2459,7 @@ def _install_subprocess_handlers( register_formatters(theme=user_config["display"]["theme"]) - signal.signal(signal.SIGINT, handlers.construct_interrupt_handler(ctx)) + signal.signal(signal.SIGINT, handlers.construct_interrupt_handler()) if sys.platform == "win32": if interrupt_queue is not None: @@ -2544,7 +2543,7 @@ def launch_kernel( # Read theme from kernel.user_config — create_kernel may have # mutated it for run mode (autorun + auto_reload off). _install_subprocess_handlers( - kernel, ctx, kernel.user_config, interrupt_queue + kernel, kernel.user_config, interrupt_queue ) # The control loop is asynchronous so that (a) user code can use diff --git a/tests/_runtime/runner/test_cell_runner.py b/tests/_runtime/runner/test_cell_runner.py index 4ab2f601517..b96449f9f8d 100644 --- a/tests/_runtime/runner/test_cell_runner.py +++ b/tests/_runtime/runner/test_cell_runner.py @@ -1,5 +1,6 @@ # Copyright 2026 Marimo. All rights reserved. import traceback +from typing import Any import pytest @@ -397,11 +398,60 @@ async def fake_evaluate(cell, glbls): # type: ignore[no-untyped-def] del cell, glbls return RunResult(output=None, exception=asyncio.CancelledError()) - monkeypatch.setattr( - runner._evaluator, "evaluate_interruptible", fake_evaluate - ) + monkeypatch.setattr(runner._evaluator, "evaluate", fake_evaluate) with capture_stderr(): await runner.run(er.cell_id) assert runner.interrupted is True + + +async def test_run_all_swallows_sigint_raise_and_fires_on_finish( + execution_kernel: Kernel, + exec_req: ExecReqProvider, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """SIGINT delivered while `run_all` is in the prescan (or between + cells) raises `MarimoInterrupt` (== `KeyboardInterrupt`). `run_all` + must catch it so on_finish_hooks still fire and the + `KeyboardInterrupt` doesn't unwind past the kernel control loop. + """ + k = execution_kernel + await k.run([er := exec_req.get("123")]) + del er + + on_finish_calls: list[Any] = [] + + class _Hooks(NotebookCellHooks): + on_finish_hooks = (lambda ctx: on_finish_calls.append(ctx),) + + runner = Runner( + roots=set(k.graph.cells.keys()), + graph=k.graph, + glbls=k.globals, + debugger=k.debugger, + hooks=_Hooks(), + ) + + # Simulate the sync-path SIGINT handler: cancel the queue and raise. + def _raise_via_prescan(_cell_id: Any) -> Any: + runner._scheduler.cancel_all() + raise KeyboardInterrupt + + monkeypatch.setattr( + runner, "_find_first_blocked_missing_ref", _raise_via_prescan + ) + + # Must not raise; on_finish_hooks must still fire. + await runner.run_all() + + assert runner.interrupted is True + assert len(on_finish_calls) == 1 + assert on_finish_calls[0].interrupted is True + # Codex P2 regression: cells_to_run must still be visible to + # on_finish_hooks. SIGINT mid-prescan must not leave the queue + # drained. + assert list(on_finish_calls[0].cells_to_run), ( + "on_finish_hooks must see the cells that were still queued " + "when SIGINT fired" + ) diff --git a/tests/_runtime/runner/test_hooks.py b/tests/_runtime/runner/test_hooks.py index b18a9c538a7..8915421159f 100644 --- a/tests/_runtime/runner/test_hooks.py +++ b/tests/_runtime/runner/test_hooks.py @@ -72,3 +72,61 @@ def test_set_status_idle_is_last_post_execution_hook(self) -> None: ) assert POST_EXECUTION_HOOKS[-1] is _set_status_idle + + +class TestSetRunResultStatus: + """`MarimoInterrupt` in `run_result.exception` must map to the + `interrupted` status so the UI distinguishes user-stop from error. + """ + + def _hook(self, exception, cancelled_cells=None): # type: ignore[no-untyped-def] + from unittest.mock import MagicMock + + from marimo._runtime.runner.hooks_post_execution import ( + _set_run_result_status, + ) + from marimo._runtime.runner.result import RunResult + + cell = MagicMock() + cell.cell_id = "c0" + ctx = MagicMock() + ctx.cancelled_cells = cancelled_cells or set() + run_result = RunResult(output=None, exception=exception) + _set_run_result_status(cell, ctx, run_result) + return cell.set_run_result_status.call_args + + def test_marimo_interrupt_sets_interrupted(self) -> None: + from marimo._runtime.control_flow import MarimoInterrupt + + call = self._hook(MarimoInterrupt()) + + assert call is not None + assert call.args[0] == "interrupted" + + def test_marimo_interrupt_takes_precedence_over_cancelled(self) -> None: + """Interrupt wins over cancellation: a cell that interrupted + itself also lands in `cancelled_cells` via its descendants pass.""" + from marimo._runtime.control_flow import MarimoInterrupt + + call = self._hook(MarimoInterrupt(), cancelled_cells={"c0"}) + + assert call is not None + assert call.args[0] == "interrupted" + + def test_cancelled_when_not_interrupt(self) -> None: + call = self._hook(ValueError("boom"), cancelled_cells={"c0"}) + + assert call is not None + assert call.args[0] == "cancelled" + + def test_exception_when_not_interrupt_or_cancelled(self) -> None: + call = self._hook(ValueError("boom")) + + assert call is not None + assert call.args[0] == "exception" + + def test_no_exception_is_success(self) -> None: + call = self._hook(None) + + assert call is not None + assert call.args[0] == "success" diff --git a/tests/_runtime/test_executor_evaluator.py b/tests/_runtime/test_executor_evaluator.py index 6e05053b917..10f5ffed519 100644 --- a/tests/_runtime/test_executor_evaluator.py +++ b/tests/_runtime/test_executor_evaluator.py @@ -13,7 +13,8 @@ from __future__ import annotations import asyncio -from typing import Any +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any import pytest @@ -24,8 +25,12 @@ ExecutionLifecycle, Skip, ) +from marimo._runtime.runner.cell_runner import Runner from marimo._runtime.runner.result import RunResult +if TYPE_CHECKING: + from collections.abc import AsyncIterator + class _Recorder: """Lifecycle that records setup/teardown calls into a shared log.""" @@ -523,7 +528,7 @@ def teardown( assert lifecycle.name == "mine" -# --- Surface 4: _cancel_on_sigint + evaluate_interruptible ------------------ +# --- Async cancellation ----------------------------------------------------- def _async_body(src: str) -> Any: @@ -534,154 +539,195 @@ def _async_body(src: str) -> Any: return compile(src, "", "exec", flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT) -async def test_cancel_on_sigint_installs_and_restores_handler( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """`_cancel_on_sigint` swaps in its own handler on enter and - restores the previously-installed one on exit.""" - import signal - - from marimo._runtime.executor.evaluator import _cancel_on_sigint - - def prior(signum: int, frame: Any) -> None: - del signum, frame - - signal_calls: list[tuple[int, Any]] = [] +async def test_executor_async_cancellation_propagates_unwrapped() -> None: + """`asyncio.CancelledError` must propagate unwrapped through + `DefaultExecutor.execute_cell_async` — wrapping it as + `MarimoRuntimeException` would mask the cancellation.""" - def fake_signal(signum: int, handler: Any) -> Any: - signal_calls.append((signum, handler)) - return prior + class _AsyncCell: + cell_id = "0" + body = _async_body("import asyncio\nawait asyncio.sleep(100)") + last_expr = compile("None", "", "eval") - monkeypatch.setattr(signal, "signal", fake_signal) - monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) + def is_coroutine(self) -> bool: + return True - fut: asyncio.Future[Any] = asyncio.Future() - with _cancel_on_sigint(fut): - # On enter: a new handler installed (not the prior). - assert signal_calls, "no signal.signal call recorded on enter" - assert signal_calls[0][0] == signal.SIGINT - assert signal_calls[0][1] is not prior + task = asyncio.create_task( + DefaultExecutor().execute_cell_async(_AsyncCell(), {}) # type: ignore[arg-type] + ) + # Yield so the task enters the awaited sleep before we cancel. + await asyncio.sleep(0) + task.cancel() - # On exit: prior handler restored as the last call. - assert signal_calls[-1] == (signal.SIGINT, prior) + with pytest.raises(asyncio.CancelledError): + await task -async def test_cancel_on_sigint_handler_cancels_future_and_chains_prior( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """The installed handler must cancel the wrapped future and invoke - the previously-installed handler for its side effects.""" - import signal +async def test_start_task_cancel_all_propagates() -> None: + """`cancel_all` schedules cancellation via `call_soon_threadsafe` so a + loop blocked in `select()` wakes immediately; a plain `Future.cancel` + leaves the loop sleeping until the task's next scheduled wakeup.""" + from marimo._runtime.runner.scheduler import SequentialScheduler - from marimo._runtime.executor.evaluator import _cancel_on_sigint + sched = SequentialScheduler( + cells_to_run=[], + graph=None, # type: ignore[arg-type] + ) - prior_calls: list[tuple[int, Any]] = [] + async def slow() -> RunResult: + await asyncio.sleep(60) + return RunResult(output=None, exception=None) - def prior(signum: int, frame: Any) -> None: - prior_calls.append((signum, frame)) + async with sched.start_task("c0", slow()) as task: # type: ignore[arg-type] + await asyncio.sleep(0) + assert sched.has_active_tasks() + sched.cancel_all() + with pytest.raises(asyncio.CancelledError): + await task - captured: list[Any] = [] + assert sched.interrupted is True - def fake_signal(signum: int, handler: Any) -> Any: - captured.append(handler) - return prior - monkeypatch.setattr(signal, "signal", fake_signal) - monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) +async def test_start_task_cancels_when_interrupted_pre_entry() -> None: + """`start_task` must refuse to admit a new task once `cancel_all` has + fired — otherwise a SIGINT racing in just before the task is + registered could leave the freshly-created task running detached.""" + from marimo._runtime.runner.scheduler import SequentialScheduler - fut: asyncio.Future[Any] = asyncio.Future() - with _cancel_on_sigint(fut): - marimo_handler = captured[0] - marimo_handler(signal.SIGINT, None) - # Cancellation propagates through done-callbacks asynchronously; - # yield to the loop so they fire. - await asyncio.sleep(0) + sched = SequentialScheduler( + cells_to_run=[], + graph=None, # type: ignore[arg-type] + ) + sched.cancel_all() # flips _interrupted - assert fut.cancelled() - assert prior_calls == [(signal.SIGINT, None)] + async def body() -> RunResult: + await asyncio.sleep(60) + return RunResult(output=None, exception=None) + coro = body() + with pytest.raises(asyncio.CancelledError): + async with sched.start_task("c0", coro): # type: ignore[arg-type] + pass + # `coro` was closed before becoming a task; nothing to leak. + assert not sched.has_active_tasks() -async def test_cancel_on_sigint_swallows_marimo_interrupt_from_prior_handler( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Prior handler raising `MarimoInterrupt` must not escape — the - kernel's sync-mode raise is irrelevant for async cells, where the - halt comes from cancelling the future.""" - import signal - from marimo._runtime.control_flow import MarimoInterrupt - from marimo._runtime.executor.evaluator import _cancel_on_sigint +async def test_runner_evaluate_interruptible_routes_async_cells_to_scheduler() -> ( + None +): + """`Runner.evaluate_interruptible` must funnel coroutine cells + through `scheduler.start_task` so the SIGINT-handler's `cancel_all` + can preempt them.""" + + class _StubScheduler: + def __init__(self) -> None: + self.started: list[tuple[str, Any]] = [] + + @asynccontextmanager + async def start_task( + self, cell_id: str, coro: Any + ) -> AsyncIterator[asyncio.Task[Any]]: + self.started.append((cell_id, coro)) + task = asyncio.ensure_future(coro) + try: + yield task + finally: + if not task.done(): + task.cancel() - def prior(signum: int, frame: Any) -> None: - raise MarimoInterrupt + class _AsyncCell: + cell_id = "c0" + body = _async_body("x = 1") + last_expr = compile("None", "", "eval") - captured: list[Any] = [] + def is_coroutine(self) -> bool: + return True - def fake_signal(signum: int, handler: Any) -> Any: - captured.append(handler) - return prior + class _RunnerStub: + def __init__(self) -> None: + self.glbls: dict[str, Any] = {} + self._scheduler = _StubScheduler() + self._evaluator = Evaluator( + executor=DefaultExecutor(), lifecycles=[] + ) - monkeypatch.setattr(signal, "signal", fake_signal) - monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) + evaluate_interruptible = ( + Runner.evaluate_interruptible # type: ignore[attr-defined] + ) - fut: asyncio.Future[Any] = asyncio.Future() - with _cancel_on_sigint(fut): - marimo_handler = captured[0] - # No exception escapes — the wrapper catches MarimoInterrupt - # from the prior handler. - marimo_handler(signal.SIGINT, None) - await asyncio.sleep(0) - assert fut.cancelled() + runner = _RunnerStub() + result = await runner.evaluate_interruptible(_AsyncCell()) # type: ignore[arg-type] + assert result.exception is None + assert runner._scheduler.started, ( + "async cell must be routed through scheduler.start_task" + ) -async def test_executor_async_cancellation_propagates_unwrapped() -> None: - """`asyncio.CancelledError` must propagate unwrapped through - `DefaultExecutor.execute_cell_async` — wrapping it as - `MarimoRuntimeException` would mask the cancellation.""" +async def test_runner_evaluate_interruptible_surfaces_cancelled_as_run_result() -> ( + None +): + """When `start_task` refuses to admit a coroutine cell because + `cancel_all` already fired, the resulting `CancelledError` must come + back as `RunResult(exception=CancelledError)`. The broad-except path + in `Runner.run` would otherwise log an internal error and emit an + empty success-like result, masking the interrupt.""" + from marimo._runtime.runner.scheduler import SequentialScheduler + + sched = SequentialScheduler( + cells_to_run=[], + graph=None, # type: ignore[arg-type] + ) + sched.cancel_all() # pre-admit refusal path class _AsyncCell: - cell_id = "0" - body = _async_body("import asyncio\nawait asyncio.sleep(100)") + cell_id = "c0" + body = _async_body("x = 1") last_expr = compile("None", "", "eval") def is_coroutine(self) -> bool: return True - task = asyncio.create_task( - DefaultExecutor().execute_cell_async(_AsyncCell(), {}) # type: ignore[arg-type] - ) - # Yield so the task enters the awaited sleep before we cancel. - await asyncio.sleep(0) - task.cancel() + class _RunnerStub: + def __init__(self) -> None: + self.glbls: dict[str, Any] = {} + self._scheduler = sched + self._evaluator = Evaluator( + executor=DefaultExecutor(), lifecycles=[] + ) - with pytest.raises(asyncio.CancelledError): - await task + evaluate_interruptible = ( + Runner.evaluate_interruptible # type: ignore[attr-defined] + ) + runner = _RunnerStub() + result = await runner.evaluate_interruptible(_AsyncCell()) # type: ignore[arg-type] + assert isinstance(result.exception, asyncio.CancelledError) -async def test_evaluate_interruptible_no_op_for_sync_cell() -> None: - """Sync cells: `evaluate_interruptible` returns the same shape as a - direct `evaluate()` call. The SIGINT-handler wrap is for async only.""" - class _SyncCell: - cell_id = "0" - body = compile("x = 1", "", "exec") - last_expr = compile("x", "", "eval") +async def test_scheduler_async_context_publishes_on_kernel_context() -> None: + """`async with scheduler` sets `_active_scheduler` on entry and + clears it on exit so the SIGINT handler can find the scheduler.""" + from unittest.mock import MagicMock - def is_coroutine(self) -> bool: - return False - - ev = Evaluator(executor=DefaultExecutor(), lifecycles=[]) - - sync_glbls: dict[str, Any] = {} - interruptible_glbls: dict[str, Any] = {} + from marimo._runtime.context.kernel_context import ( + KernelRuntimeContext, + ) + from marimo._runtime.context.types import _THREAD_LOCAL_CONTEXT + from marimo._runtime.runner.scheduler import SequentialScheduler - direct = await ev.evaluate(_SyncCell(), sync_glbls) # type: ignore[arg-type] - interruptible = await ev.evaluate_interruptible( - _SyncCell(), # type: ignore[arg-type] - interruptible_glbls, + sched = SequentialScheduler( + cells_to_run=[], + graph=None, # type: ignore[arg-type] ) - assert direct.output == interruptible.output == 1 - assert direct.exception is None - assert interruptible.exception is None - assert direct.accumulated_output == interruptible.accumulated_output + # `spec=KernelRuntimeContext` makes `isinstance` accept the mock. + ctx = MagicMock(spec=KernelRuntimeContext) + ctx._active_scheduler = None + prior = _THREAD_LOCAL_CONTEXT.runtime_context + _THREAD_LOCAL_CONTEXT.runtime_context = ctx + try: + async with sched: + assert ctx._active_scheduler is sched + assert ctx._active_scheduler is None + finally: + _THREAD_LOCAL_CONTEXT.runtime_context = prior diff --git a/tests/_runtime/test_interrupt_handlers.py b/tests/_runtime/test_interrupt_handlers.py index 4d832755050..2f9436c4dba 100644 --- a/tests/_runtime/test_interrupt_handlers.py +++ b/tests/_runtime/test_interrupt_handlers.py @@ -7,6 +7,7 @@ import pytest from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime.context.kernel_context import KernelRuntimeContext from marimo._runtime.context.types import ExecutionContext from marimo._runtime.handlers import construct_interrupt_handler from marimo._runtime.runtime import MarimoInterrupt @@ -14,6 +15,15 @@ HAS_DUCKDB = DependencyManager.duckdb.has() +def _kernel_context_mock(exec_ctx: ExecutionContext) -> MagicMock: + """Mock that satisfies the handler's `isinstance(KernelRuntimeContext)` + check; `active_scheduler=None` selects the sync raise path.""" + ctx = MagicMock(spec=KernelRuntimeContext) + ctx.execution_context = exec_ctx + ctx.active_scheduler = None + return ctx + + @pytest.mark.skipif(not HAS_DUCKDB, reason="DuckDB not installed") def test_duckdb_interrupt_handler_called_when_connection_present(): """Test that duckdb.interrupt() is called when a connection is present.""" @@ -25,15 +35,13 @@ def test_duckdb_interrupt_handler_called_when_connection_present(): # Create an execution context with a duckdb connection exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False) - # Mock the context to return our execution context - with patch("marimo._runtime.handlers.get_context") as mock_get_context: - mock_context = MagicMock() - mock_context.execution_context = exec_ctx - mock_get_context.return_value = mock_context + with patch( + "marimo._runtime.handlers.safe_get_context" + ) as mock_safe_get_context: + mock_safe_get_context.return_value = _kernel_context_mock(exec_ctx) - # Verify interrupt() is called when connection is set with exec_ctx.with_connection(mock_conn): - interrupt_handler = construct_interrupt_handler(mock_context) + interrupt_handler = construct_interrupt_handler() # Trigger the interrupt handler with pytest.raises(MarimoInterrupt): @@ -51,12 +59,13 @@ def test_duckdb_interrupt_handler_no_error_when_connection_none(): exec_ctx.duckdb_connection = None # Mock the context to return our execution context - with patch("marimo._runtime.handlers.get_context") as mock_get_context: - mock_context = MagicMock() - mock_context.execution_context = exec_ctx - mock_get_context.return_value = mock_context + with patch( + "marimo._runtime.handlers.safe_get_context" + ) as mock_safe_get_context: + mock_context = _kernel_context_mock(exec_ctx) + mock_safe_get_context.return_value = mock_context - interrupt_handler = construct_interrupt_handler(mock_context) + interrupt_handler = construct_interrupt_handler() # Should not raise error from duckdb interrupt (only MarimoInterrupt) with pytest.raises(MarimoInterrupt): @@ -76,14 +85,15 @@ def test_duckdb_interrupt_handler_exception_handling(): exec_ctx = ExecutionContext(cell_id="cell_id", setting_element_value=False) # Mock the context to return our execution context - with patch("marimo._runtime.handlers.get_context") as mock_get_context: - mock_context = MagicMock() - mock_context.execution_context = exec_ctx - mock_get_context.return_value = mock_context + with patch( + "marimo._runtime.handlers.safe_get_context" + ) as mock_safe_get_context: + mock_context = _kernel_context_mock(exec_ctx) + mock_safe_get_context.return_value = mock_context # Make interrupt() raise an exception with exec_ctx.with_connection(mock_conn): - interrupt_handler = construct_interrupt_handler(mock_context) + interrupt_handler = construct_interrupt_handler() # Should raise MarimoInterrupt, not RuntimeError # The RuntimeError should be caught and logged @@ -92,3 +102,57 @@ def test_duckdb_interrupt_handler_exception_handling(): # Verify interrupt was attempted mock_conn.interrupt.assert_called_once() + + +def test_sigint_between_cells_cancels_queue_and_raises() -> None: + """SIGINT landing between two cells (scheduler still running its + queue, no cell installed in `execution_context`) must halt the + queue and raise `MarimoInterrupt`. Regression for the P2 where the + handler returned early on `execution_context is None` before + consulting `active_scheduler`.""" + sched = MagicMock() + sched.has_active_tasks.return_value = False + + ctx = MagicMock(spec=KernelRuntimeContext) + ctx.execution_context = None + ctx.active_scheduler = sched + + with patch("marimo._runtime.handlers.safe_get_context", return_value=ctx): + interrupt_handler = construct_interrupt_handler() + with pytest.raises(MarimoInterrupt): + interrupt_handler(signal.SIGINT, None) + + sched.cancel_all.assert_called_once() + + +def test_sigint_with_active_async_task_cancels_without_raising() -> None: + """When an async cell is in flight (scheduler reports active tasks), + the handler must call `cancel_all` and return — raising from a + signal handler would escape into asyncio internals and surface as + an internal-error empty RunResult.""" + sched = MagicMock() + sched.has_active_tasks.return_value = True + + ctx = MagicMock(spec=KernelRuntimeContext) + ctx.execution_context = None + ctx.active_scheduler = sched + + with patch("marimo._runtime.handlers.safe_get_context", return_value=ctx): + interrupt_handler = construct_interrupt_handler() + # No exception raised. + interrupt_handler(signal.SIGINT, None) + + sched.cancel_all.assert_called_once() + + +def test_sigint_with_no_scheduler_and_no_cell_is_noop() -> None: + """No scheduler installed and no cell in flight — the handler must + return silently without raising or calling broadcast.""" + ctx = MagicMock(spec=KernelRuntimeContext) + ctx.execution_context = None + ctx.active_scheduler = None + + with patch("marimo._runtime.handlers.safe_get_context", return_value=ctx): + interrupt_handler = construct_interrupt_handler() + # No exception raised. + interrupt_handler(signal.SIGINT, None) diff --git a/tests/_runtime/test_scheduler.py b/tests/_runtime/test_scheduler.py index 58a0a517ea6..f6b52ad0563 100644 --- a/tests/_runtime/test_scheduler.py +++ b/tests/_runtime/test_scheduler.py @@ -67,7 +67,9 @@ def fake_closure(graph: object, roots: set[CellId_t]) -> set[CellId_t]: def test_batch_yields_singletons() -> None: sched = SequentialScheduler([], graph=_empty_graph()) cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")] - batches = list(sched.batch(cells)) + # batch() yields iterables, not indexable lists — callers iterate with + # ``for cell_id in batch:`` rather than ``batch[0]``. + batches = [list(b) for b in sched.batch(cells)] assert batches == [["a"], ["b"], ["c"]] @@ -75,7 +77,7 @@ def test_batch_respects_interrupt() -> None: sched = SequentialScheduler([], graph=_empty_graph()) cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")] iterator = sched.batch(cells) - assert next(iterator) == ["a"] + assert list(next(iterator)) == ["a"] sched.interrupted = True # Generator stops once interrupted is set. remaining = list(iterator)