diff --git a/packages/doeff-linter/docs/rules/DOEFF032.md b/packages/doeff-linter/docs/rules/DOEFF032.md new file mode 100644 index 00000000..07bab71d --- /dev/null +++ b/packages/doeff-linter/docs/rules/DOEFF032.md @@ -0,0 +1,62 @@ +# DOEFF032: Prefer Transfer for Tail Resume + +## Summary + +When a handler is finished and only wants to hand a value to the continuation, prefer: + +```python +yield Transfer(k, value) +``` + +instead of: + +```python +return (yield Resume(k, value)) +``` + +`Resume` keeps the handler alive because it may need to receive the continuation result back for +post-processing. In tail position that extra liveness is unnecessary and can retain large locals in +memory longer than needed. + +## Violation + +```python +@do +def handler(effect: Effect, k: object): + if isinstance(effect, LoadBigPayload): + payload = build_payload(effect) + return (yield Resume(k, payload)) + yield Pass() +``` + +## Preferred + +```python +@do +def handler(effect: Effect, k: object): + if isinstance(effect, LoadBigPayload): + payload = build_payload(effect) + yield Transfer(k, payload) + yield Pass() +``` + +## When Not To Use Transfer + +Keep `Resume` when the handler truly needs the continuation result: + +```python +@do +def handler(effect: Effect, k: object): + if isinstance(effect, Ping): + resumed = yield Resume(k, effect.value) + return resumed * 3 + yield Pass() +``` + +## Suppression + +If the tail `Resume` is intentional, suppress it on that line: + +```python +return (yield Resume(k, payload)) # noqa: DOEFF032 +``` diff --git a/packages/doeff-linter/src/lib.rs b/packages/doeff-linter/src/lib.rs index e9030909..5a019c6d 100644 --- a/packages/doeff-linter/src/lib.rs +++ b/packages/doeff-linter/src/lib.rs @@ -440,6 +440,13 @@ def process(): p: Program = process()"#, violation_line: 2, }, + NoqaTestCase { + rule_id: "DOEFF032", + triggering_code: r#"@do +def handler(effect, k): + return (yield Resume(k, effect.value))"#, + violation_line: 3, + }, ] } @@ -715,4 +722,3 @@ p: Program = process()"#, result.join("\n") } } - diff --git a/packages/doeff-linter/src/main.rs b/packages/doeff-linter/src/main.rs index fb42feab..65922754 100644 --- a/packages/doeff-linter/src/main.rs +++ b/packages/doeff-linter/src/main.rs @@ -540,6 +540,11 @@ fn get_rule_info(rule_id: &str) -> RuleInfo { description: "Avoid creating Program entrypoints by calling @do wrappers that only forward args to a single yielded call and return it.", fix: "Replace `p_x: Program[...] = wrapper(...)` with `p_x: Program[...] = underlying(...)` using the same arguments. If the wrapper is intentional (naming/tracing), add `# noqa: DOEFF031`.", }, + "DOEFF032" => RuleInfo { + name: "Prefer Transfer for Tail Resume", + description: "Tail-position `return (yield Resume(k, value))` keeps the handler frame alive while the resumed continuation runs.", + fix: "Replace tail-position `return (yield Resume(k, value))` with `yield Transfer(k, value)`. Keep `Resume` only when you need post-resume processing.", + }, "NOQA001" => RuleInfo { name: "Malformed noqa Comment", description: "The noqa comment format appears incorrect and may not suppress the intended rule.", diff --git a/packages/doeff-linter/src/rules/doeff032_no_tail_resume_return.rs b/packages/doeff-linter/src/rules/doeff032_no_tail_resume_return.rs new file mode 100644 index 00000000..f7129a25 --- /dev/null +++ b/packages/doeff-linter/src/rules/doeff032_no_tail_resume_return.rs @@ -0,0 +1,245 @@ +//! DOEFF032: Prefer Transfer for tail-position Resume +//! +//! `return (yield Resume(k, value))` keeps the handler generator alive until the resumed +//! continuation returns. In tail position, `yield Transfer(k, value)` is explicit and lets the VM +//! abandon the handler frame immediately. + +use crate::models::{RuleContext, Severity, Violation}; +use crate::rules::base::LintRule; +use rustpython_ast::{Expr, Stmt}; + +pub struct NoTailResumeReturnRule; + +impl NoTailResumeReturnRule { + pub fn new() -> Self { + Self + } + + fn is_resume_call(expr: &Expr) -> bool { + let Expr::Call(call) = expr else { + return false; + }; + + match &*call.func { + Expr::Name(name) => name.id.as_str() == "Resume", + Expr::Attribute(attr) => attr.attr.as_str() == "Resume", + _ => false, + } + } + + fn is_tail_resume_return(stmt: &Stmt) -> Option { + let Stmt::Return(return_stmt) = stmt else { + return None; + }; + let value = return_stmt.value.as_ref()?; + let Expr::Yield(yield_expr) = &**value else { + return None; + }; + let yielded = yield_expr.value.as_ref()?; + Self::is_resume_call(yielded).then(|| return_stmt.range.start().to_usize()) + } + + fn check_stmt(stmt: &Stmt, violations: &mut Vec, file_path: &str) { + if let Some(offset) = Self::is_tail_resume_return(stmt) { + violations.push(Violation::new( + "DOEFF032".to_string(), + "\ +`return (yield Resume(k, value))` keeps the handler frame alive while continuation `k` runs.\n\ +In tail position, prefer `yield Transfer(k, value)` so the handler is abandoned explicitly.\n\ +If you intentionally need post-resume processing, keep `Resume`; otherwise replace the tail \ +`Resume` with `Transfer`." + .to_string(), + offset, + file_path.to_string(), + Severity::Warning, + )); + } + + match stmt { + Stmt::FunctionDef(func) => { + for s in &func.body { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::AsyncFunctionDef(func) => { + for s in &func.body { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::ClassDef(class_def) => { + for s in &class_def.body { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::If(if_stmt) => { + for s in &if_stmt.body { + Self::check_stmt(s, violations, file_path); + } + for s in &if_stmt.orelse { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::While(while_stmt) => { + for s in &while_stmt.body { + Self::check_stmt(s, violations, file_path); + } + for s in &while_stmt.orelse { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::For(for_stmt) => { + for s in &for_stmt.body { + Self::check_stmt(s, violations, file_path); + } + for s in &for_stmt.orelse { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::AsyncFor(for_stmt) => { + for s in &for_stmt.body { + Self::check_stmt(s, violations, file_path); + } + for s in &for_stmt.orelse { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::With(with_stmt) => { + for s in &with_stmt.body { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::AsyncWith(with_stmt) => { + for s in &with_stmt.body { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::Try(try_stmt) => { + for s in &try_stmt.body { + Self::check_stmt(s, violations, file_path); + } + for handler in &try_stmt.handlers { + let rustpython_ast::ExceptHandler::ExceptHandler(handler) = handler; + for s in &handler.body { + Self::check_stmt(s, violations, file_path); + } + } + for s in &try_stmt.orelse { + Self::check_stmt(s, violations, file_path); + } + for s in &try_stmt.finalbody { + Self::check_stmt(s, violations, file_path); + } + } + Stmt::Match(match_stmt) => { + for case in &match_stmt.cases { + for s in &case.body { + Self::check_stmt(s, violations, file_path); + } + } + } + _ => {} + } + } +} + +impl LintRule for NoTailResumeReturnRule { + fn rule_id(&self) -> &str { + "DOEFF032" + } + + fn description(&self) -> &str { + "Prefer Transfer over tail-position Resume" + } + + fn check(&self, context: &RuleContext) -> Vec { + let mut violations = Vec::new(); + Self::check_stmt(context.stmt, &mut violations, context.file_path); + violations + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rustpython_ast::Mod; + use rustpython_parser::{parse, Mode}; + + fn check_code(code: &str) -> Vec { + let ast = parse(code, Mode::Module, "test.py").unwrap(); + let rule = NoTailResumeReturnRule::new(); + let mut violations = Vec::new(); + + if let Mod::Module(module) = &ast { + for stmt in &module.body { + let context = RuleContext { + stmt, + file_path: "test.py", + source: code, + ast: &ast, + }; + violations.extend(rule.check(&context)); + } + } + + violations + } + + #[test] + fn test_tail_resume_return_is_flagged() { + let code = r#" +@do +def handler(effect, k): + return (yield Resume(k, effect.value)) +"#; + let violations = check_code(code); + assert_eq!(violations.len(), 1); + assert!(violations[0].message.contains("Transfer")); + } + + #[test] + fn test_attribute_resume_return_is_flagged() { + let code = r#" +@do +def handler(effect, k): + return (yield doeff_vm.Resume(k, effect.value)) +"#; + let violations = check_code(code); + assert_eq!(violations.len(), 1); + } + + #[test] + fn test_resume_with_post_processing_is_allowed() { + let code = r#" +@do +def handler(effect, k): + resumed = yield Resume(k, effect.value) + return resumed * 3 +"#; + let violations = check_code(code); + assert_eq!(violations.len(), 0); + } + + #[test] + fn test_transfer_is_not_flagged() { + let code = r#" +@do +def handler(effect, k): + yield Transfer(k, effect.value) +"#; + let violations = check_code(code); + assert_eq!(violations.len(), 0); + } + + #[test] + fn test_nested_tail_resume_return_is_flagged() { + let code = r#" +@do +def handler(effect, k): + if effect.ready: + return (yield Resume(k, effect.value)) + yield Pass() +"#; + let violations = check_code(code); + assert_eq!(violations.len(), 1); + } +} diff --git a/packages/doeff-linter/src/rules/mod.rs b/packages/doeff-linter/src/rules/mod.rs index c404329e..7983a18b 100644 --- a/packages/doeff-linter/src/rules/mod.rs +++ b/packages/doeff-linter/src/rules/mod.rs @@ -29,6 +29,7 @@ pub mod doeff023_pipeline_marker; pub mod doeff024_no_recover_ask; pub mod doeff030_ask_result_type_annotation; pub mod doeff031_no_redundant_do_wrapper_entrypoint; +pub mod doeff032_no_tail_resume_return; use base::LintRule; use std::collections::HashMap; @@ -64,6 +65,7 @@ pub fn get_all_rules() -> Vec> { Box::new( doeff031_no_redundant_do_wrapper_entrypoint::NoRedundantDoWrapperEntrypointRule::new(), ), + Box::new(doeff032_no_tail_resume_return::NoTailResumeReturnRule::new()), ] } @@ -103,7 +105,7 @@ mod tests { #[test] fn test_all_rules_loaded() { let rules = get_all_rules(); - assert_eq!(rules.len(), 26); + assert_eq!(rules.len(), 27); let rule_ids: Vec<_> = rules.iter().map(|r| r.rule_id()).collect(); assert!(rule_ids.contains(&"DOEFF001")); @@ -124,6 +126,7 @@ mod tests { assert!(rule_ids.contains(&"DOEFF024")); assert!(rule_ids.contains(&"DOEFF030")); assert!(rule_ids.contains(&"DOEFF031")); + assert!(rule_ids.contains(&"DOEFF032")); } #[test] diff --git a/tests/public_api/test_types_001_handler_protocol.py b/tests/public_api/test_types_001_handler_protocol.py index a975a8d4..303af79d 100644 --- a/tests/public_api/test_types_001_handler_protocol.py +++ b/tests/public_api/test_types_001_handler_protocol.py @@ -122,6 +122,41 @@ def main(): result = run(_prog(main), handlers=default_handlers()) assert result.value == 45 # handler gets 15, returns 15*3 + def test_resume_unwinds_after_remainder_completes(self) -> None: + events: list[str] = [] + + @do + def handler(effect: Effect, k): + if isinstance(effect, _CustomEffect): + events.append(f"before:{effect.value}") + resume_value = yield Resume(k, effect.value) + events.append(f"after:{effect.value}:{resume_value}") + return resume_value + else: + yield Delegate() + + def body(): + first = yield _CustomEffect(1) + events.append(f"body:{first}") + second = yield _CustomEffect(2) + events.append(f"body:{second}") + return "done" + + def main(): + result = yield WithHandler(handler=handler, expr=_prog(body)) + return result + + result = run(_prog(main), handlers=default_handlers()) + assert result.value == "done" + assert events == [ + "before:1", + "body:1", + "before:2", + "body:2", + "after:2:done", + "after:1:done", + ] + class TestHP03BReturnEffect: def test_handler_returning_effect_raises_typeerror(self) -> None: diff --git a/tests/test_try_finally_in_do.py b/tests/test_try_finally_in_do.py index dc0cba64..54c4045f 100644 --- a/tests/test_try_finally_in_do.py +++ b/tests/test_try_finally_in_do.py @@ -3,7 +3,6 @@ from dataclasses import dataclass import doeff - from doeff import ( AcquireSemaphore, CreateSemaphore, diff --git a/tests/test_vm_memory_leak.py b/tests/test_vm_memory_leak.py new file mode 100644 index 00000000..63281ce9 --- /dev/null +++ b/tests/test_vm_memory_leak.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import gc +import resource +import weakref +from dataclasses import dataclass + +import doeff_vm + +from doeff import Pass, default_handlers, do, run +from doeff.effects.base import EffectBase +from doeff.rust_vm import WithHandler + + +@dataclass(frozen=True) +class BigDataEffect(EffectBase): + iteration: int + + +@dataclass(frozen=True) +class TinyEffect(EffectBase): + iteration: int + + +class CountAlivePayloads(EffectBase): + pass + + +class WeakPayload: + __slots__ = ("__weakref__", "payload") + + def __init__(self, payload: list[int]) -> None: + self.payload = payload + + +@do +def big_data_handler(effect: EffectBase, k): + if isinstance(effect, BigDataEffect): + big_payload = list(range(100_000)) + yield doeff_vm.Transfer(k, big_payload) + if isinstance(effect, TinyEffect): + yield doeff_vm.Transfer(k, effect.iteration) + yield Pass() + + +@do +def _sequential_yield_discard(n: int): + for i in range(n): + _unused = yield BigDataEffect(iteration=i) + del _unused + return n + + +@do +def _sequential_yield_keep_last(n: int): + last = None + for i in range(n): + last = yield BigDataEffect(iteration=i) + return last + + +@do +def _tiny_loop(n: int): + for i in range(n): + _ = yield TinyEffect(iteration=i) + del _ + return n + + +def _rss_mb() -> float: + usage = resource.getrusage(resource.RUSAGE_SELF) + return usage.ru_maxrss / 1024 + + +def _run_program(program): + return run(WithHandler(big_data_handler, program), handlers=default_handlers()) + + +N_ITERATIONS = 200 +MAX_ALLOWED_GROWTH_MB = 50 + + +def test_tail_transfer_releases_large_payloads_during_run() -> None: + # Transfer is the explicit tail-position protocol: the handler is abandoned + # immediately instead of staying suspended on the remainder continuation. + payload_refs: list[weakref.ReferenceType[WeakPayload]] = [] + + @do + def handler(effect: EffectBase, k): + if isinstance(effect, BigDataEffect): + payload = WeakPayload(list(range(100_000))) + payload_refs.append(weakref.ref(payload)) + yield doeff_vm.Transfer(k, payload) + if isinstance(effect, CountAlivePayloads): + alive = sum(ref() is not None for ref in payload_refs) + yield doeff_vm.Transfer(k, alive) + yield Pass() + + @do + def program(): + samples: list[int] = [] + for i in range(40): + payload = yield BigDataEffect(iteration=i) + del payload + if (i + 1) % 10 == 0: + samples.append((yield CountAlivePayloads())) + return samples + + result = run(WithHandler(handler, program()), handlers=default_handlers()) + + assert result.is_ok(), f"Program failed: {result.error}" + assert result.value == [1, 1, 1, 1] + assert max(result.value) <= 1 + + +def test_sequential_discard_bounded_memory() -> None: + gc.collect() + rss_before = _rss_mb() + + result = _run_program(_sequential_yield_discard(N_ITERATIONS)) + + gc.collect() + rss_after = _rss_mb() + delta = rss_after - rss_before + + assert result.is_ok(), f"Program failed: {result.error}" + assert result.value == N_ITERATIONS + assert delta < MAX_ALLOWED_GROWTH_MB, ( + f"Memory leak detected! RSS grew by {delta:.0f} MB " + f"for {N_ITERATIONS} iterations yielding ~800KB each " + f"(expected <{MAX_ALLOWED_GROWTH_MB} MB)." + ) + + +def test_sequential_keep_last_bounded_memory() -> None: + gc.collect() + rss_before = _rss_mb() + + result = _run_program(_sequential_yield_keep_last(N_ITERATIONS)) + + gc.collect() + rss_after = _rss_mb() + delta = rss_after - rss_before + + assert result.is_ok(), f"Program failed: {result.error}" + assert isinstance(result.value, list) + assert len(result.value) == 100_000 + assert delta < MAX_ALLOWED_GROWTH_MB, ( + f"Memory leak detected! RSS grew by {delta:.0f} MB " + f"for {N_ITERATIONS} iterations yielding ~800KB each " + f"(expected <{MAX_ALLOWED_GROWTH_MB} MB)." + ) + + +def test_control_small_effects_low_memory() -> None: + gc.collect() + rss_before = _rss_mb() + + result = _run_program(_tiny_loop(N_ITERATIONS)) + + gc.collect() + rss_after = _rss_mb() + delta = rss_after - rss_before + + assert result.is_ok(), f"Program failed: {result.error}" + assert result.value == N_ITERATIONS + assert delta < 20, f"Even tiny effects leaked {delta:.0f} MB."